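# File: DataAnalysis_Env / baseline.py
# The lines below are residue from the repository web viewer (avatar caption,
# last commit message, short commit hash), kept as comments so the module parses:
#   Mohammed-Altaf's picture
#   black format and isort code
#   a038a1e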
"""Baseline inference script for the Data Analysis Agent environment.
Uses the OpenAI API to run a model (gpt-4o-mini) against all 6 tasks
and produces reproducible baseline scores.
The script uses DataAnalysisClient (WebSocket) because the HTTP endpoints
are stateless — each request gets a fresh env instance. State (namespace,
task, dataset) only persists within a WebSocket session.
Tasks 1-3 use only the pandas DataFrame (df). Tasks 4-6 are cross-source:
they also require querying a SQLite database via sqlite3.connect(db_path).
Usage:
OPENAI_API_KEY=sk-... uv run python baseline.py
OPENAI_API_KEY=sk-... uv run python baseline.py --base-url http://localhost:8000
"""
import argparse
import json
import os
import sys
from openai import OpenAI
from client import DataAnalysisClient
from helpers.prompts import SYSTEM_PROMPT
from models import DataAction
def run_task(
    openai_client: OpenAI,
    env_client: DataAnalysisClient,
    task_id: int,
    max_steps: int = 15,
    model: str = "gpt-4o-mini",
) -> float:
    """Run a single task using the OpenAI API as the agent.

    The environment is reset for ``task_id``, then the model is queried in a
    loop. Each reply must be a JSON object whose ``"action"`` key is either
    ``"execute_code"`` (with ``"code"``) or ``"submit_answer"`` (with
    ``"answer"``). Code results are fed back to the model; submitting an
    answer ends the episode.

    Args:
        openai_client: The OpenAI client instance.
        env_client: The connected DataAnalysisClient (sync wrapper).
        task_id: Which task to run (1–6).
        max_steps: Maximum agent steps before giving up.
        model: Chat model name passed to the OpenAI API.

    Returns:
        The final score for this task, or 0.0 when no answer was submitted
        within ``max_steps`` or the environment reported no score.
    """
    result = env_client.reset(task_id=task_id)
    obs = result.observation
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Task: {obs.task_description}\n\nDataset Info:\n{obs.dataset_info}",
        },
    ]
    print(f"\n--- Task {task_id} ---")
    print(f"Question: {obs.task_description}")

    for step in range(max_steps):
        response = openai_client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.0,
        )
        # The API may return no text content (None); treat that as an empty,
        # unparseable reply instead of crashing on None.strip().
        reply = (response.choices[0].message.content or "").strip()
        if reply:
            # Record the model's reply verbatim so the history stays faithful.
            messages.append({"role": "assistant", "content": reply})

        try:
            action = _parse_action(reply)
        except ValueError:
            # Covers json.JSONDecodeError (a ValueError subclass) and replies
            # that are valid JSON but not an object.
            messages.append(
                {
                    "role": "user",
                    "content": "Invalid JSON. Please respond with valid JSON only.",
                }
            )
            continue

        action_type = action.get("action", "")
        if action_type == "execute_code":
            # `or ""` guards against an explicit `"code": null` in the reply.
            result = env_client.step(DataAction(action_type="execute_code", code=action.get("code") or ""))
            obs = result.observation
            result_text = f"Output: {obs.output}" if not obs.error else f"Error: {obs.error}"
            print(f"  Step {step + 1}: execute_code -> {result_text[:120]}")
            messages.append({"role": "user", "content": result_text})
        elif action_type == "submit_answer":
            result = env_client.step(DataAction(action_type="submit_answer", answer=action.get("answer", "")))
            obs = result.observation
            score = obs.metadata.get("score", 0.0) if obs.metadata else result.reward
            # NOTE(review): reward / metadata score are assumed to possibly be
            # None; normalise so formatting and the float return type hold.
            score = float(score) if score is not None else 0.0
            print(f"  Step {step + 1}: submit_answer -> '{action.get('answer', '')}'")
            print(f"  Score: {score:.2f}")
            return score
        else:
            messages.append(
                {
                    "role": "user",
                    "content": f"Unknown action '{action_type}'. Use 'execute_code' or 'submit_answer'.",
                }
            )

    print("  Max steps reached without submitting an answer.")
    return 0.0


def _parse_action(reply: str) -> dict:
    """Decode the agent's reply into an action dict.

    A leading Markdown code fence (optionally tagged ``json``) is stripped
    before decoding.

    Raises:
        ValueError: If the reply is not valid JSON or is not a JSON object.
    """
    payload = reply
    if payload.startswith("```"):
        # Keep only the text between the first pair of fences.
        payload = payload.split("```")[1]
        if payload.startswith("json"):
            payload = payload[4:]
        payload = payload.strip()
    action = json.loads(payload)
    if not isinstance(action, dict):
        raise ValueError("agent reply must be a JSON object")
    return action
def main() -> None:
    """Run baseline inference across all 6 tasks and report scores.

    Reads the environment server URL from ``--base-url`` and the OpenAI key
    from the ``OPENAI_API_KEY`` environment variable, runs every task over a
    single client session, then prints per-task scores and their average.

    Exits with status 1 (message on stderr) when ``OPENAI_API_KEY`` is unset.
    """
    parser = argparse.ArgumentParser(description="Baseline inference for Data Analysis Env")
    parser.add_argument(
        "--base-url",
        default="http://localhost:8000",
        help="Environment server URL (default: http://localhost:8000)",
    )
    args = parser.parse_args()

    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        # sys.exit(str) prints the message to stderr and exits with status 1,
        # keeping the error out of stdout where the results are reported.
        sys.exit("Error: OPENAI_API_KEY environment variable is required.")

    openai_client = OpenAI(api_key=api_key)

    print("=" * 55)
    print("Data Analysis Agent - Baseline Inference")
    print(f"Server: {args.base_url}")
    print("Model: gpt-4o-mini")
    print("=" * 55)

    # Single source of truth for which tasks exist and how hard they are;
    # dicts keep insertion order, so tasks run as 1..6.
    difficulties = {
        1: "Easy",
        2: "Medium",
        3: "Medium",
        4: "Hard",
        5: "Hard",
        6: "Hard",
    }

    scores = {}
    # All tasks share one client session opened by the context manager.
    with DataAnalysisClient(base_url=args.base_url).sync() as env_client:
        for task_id in difficulties:
            scores[task_id] = run_task(openai_client, env_client, task_id)

    print("\n" + "=" * 55)
    print("RESULTS")
    print("=" * 55)
    for task_id, score in scores.items():
        print(f"  Task {task_id} ({difficulties[task_id]:6s}): {score:.2f}")
    avg = sum(scores.values()) / len(scores)
    print(f"\n  Average Score: {avg:.2f}")
    print("=" * 55)
# Run the baseline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()