| """Baseline inference script for the Data Analysis Agent environment. | |
| Uses the OpenAI API to run a model (gpt-4o-mini) against all 6 tasks | |
| and produces reproducible baseline scores. | |
| The script uses DataAnalysisClient (WebSocket) because the HTTP endpoints | |
| are stateless — each request gets a fresh env instance. State (namespace, | |
| task, dataset) only persists within a WebSocket session. | |
| Tasks 1-3 use only the pandas DataFrame (df). Tasks 4-6 are cross-source: | |
| they also require querying a SQLite database via sqlite3.connect(db_path). | |
| Usage: | |
| OPENAI_API_KEY=sk-... uv run python baseline.py | |
| OPENAI_API_KEY=sk-... uv run python baseline.py --base-url http://localhost:8000 | |
| """ | |

import argparse
import json
import os
import sys

from openai import OpenAI

from client import DataAnalysisClient
from helpers.prompts import SYSTEM_PROMPT
from models import DataAction
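
# For the cross-source tasks (4-6), the agent's submitted code runs inside the
# environment namespace, where the pandas DataFrame `df` and the SQLite path
# `db_path` are available (see module docstring). An illustrative, not
# task-specific, snippet the agent might execute — table and column names here
# are hypothetical:
#
#     import sqlite3
#     import pandas as pd
#     conn = sqlite3.connect(db_path)
#     orders = pd.read_sql_query("SELECT * FROM orders", conn)  # hypothetical table
#     print(df.merge(orders, on="customer_id").head())          # hypothetical join key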


def run_task(openai_client: OpenAI, env_client: DataAnalysisClient, task_id: int, max_steps: int = 15) -> float:
    """Run a single task using the OpenAI API as the agent.

    Args:
        openai_client: The OpenAI client instance.
        env_client: The connected DataAnalysisClient (sync wrapper).
        task_id: Which task to run (1-6).
        max_steps: Maximum agent steps before giving up.

    Returns:
        The final score for this task (0.0 to 1.0).
    """
    result = env_client.reset(task_id=task_id)
    obs = result.observation

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Task: {obs.task_description}\n\nDataset Info:\n{obs.dataset_info}",
        },
    ]

    print(f"\n--- Task {task_id} ---")
    print(f"Question: {obs.task_description}")
    for step in range(max_steps):
        response = openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            temperature=0.0,
        )
        assistant_msg = response.choices[0].message.content.strip()

        # Parse the agent's JSON response
        try:
            # Handle markdown code blocks if present
            if assistant_msg.startswith("```"):
                assistant_msg = assistant_msg.split("```")[1]
                if assistant_msg.startswith("json"):
                    assistant_msg = assistant_msg[4:]
                assistant_msg = assistant_msg.strip()
            action = json.loads(assistant_msg)
        except json.JSONDecodeError:
            messages.append({"role": "assistant", "content": assistant_msg})
            messages.append(
                {
                    "role": "user",
                    "content": "Invalid JSON. Please respond with valid JSON only.",
                }
            )
            continue

        action_type = action.get("action", "")

        if action_type == "execute_code":
            result = env_client.step(DataAction(action_type="execute_code", code=action.get("code", "")))
            obs = result.observation
            result_text = f"Output: {obs.output}" if not obs.error else f"Error: {obs.error}"
            print(f" Step {step + 1}: execute_code -> {result_text[:120]}")
            messages.append({"role": "assistant", "content": assistant_msg})
            messages.append({"role": "user", "content": result_text})
| elif action_type == "submit_answer": | |
| result = env_client.step(DataAction(action_type="submit_answer", answer=action.get("answer", ""))) | |
| obs = result.observation | |
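            # Prefer the grader's score from the observation metadata; fall back to
            # the step reward if the environment did not populate metadata.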
            score = obs.metadata.get("score", 0.0) if obs.metadata else result.reward
            print(f" Step {step + 1}: submit_answer -> '{action.get('answer', '')}'")
            print(f" Score: {score:.2f}")
            return score

        else:
            messages.append({"role": "assistant", "content": assistant_msg})
            messages.append(
                {
                    "role": "user",
                    "content": f"Unknown action '{action_type}'. Use 'execute_code' or 'submit_answer'.",
                }
            )

    print(" Max steps reached without submitting an answer.")
    return 0.0


def main():
    """Run baseline inference across all 6 tasks and report scores."""
    parser = argparse.ArgumentParser(description="Baseline inference for Data Analysis Env")
    parser.add_argument(
        "--base-url",
        default="http://localhost:8000",
        help="Environment server URL (default: http://localhost:8000)",
    )
    args = parser.parse_args()

    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY environment variable is required.")
        sys.exit(1)

    openai_client = OpenAI(api_key=api_key)

    print("=" * 55)
    print("Data Analysis Agent - Baseline Inference")
    print(f"Server: {args.base_url}")
    print("Model: gpt-4o-mini")
    print("=" * 55)

    scores = {}
    difficulties = {
        1: "Easy",
        2: "Medium",
        3: "Medium",
        4: "Hard",
        5: "Hard",
        6: "Hard",
    }
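
    # One WebSocket session is reused for all six tasks; reset(task_id=...) switches
    # tasks without dropping the connection, since env state lives in the session.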
    with DataAnalysisClient(base_url=args.base_url).sync() as env_client:
        for task_id in [1, 2, 3, 4, 5, 6]:
            score = run_task(openai_client, env_client, task_id)
            scores[task_id] = score

    print("\n" + "=" * 55)
    print("RESULTS")
    print("=" * 55)
    for task_id, score in scores.items():
        print(f" Task {task_id} ({difficulties[task_id]:6s}): {score:.2f}")

    avg = sum(scores.values()) / len(scores)
    print(f"\n Average Score: {avg:.2f}")
    print("=" * 55)
| if __name__ == "__main__": | |
| main() | |