import asyncio
import os
import re
import sys
from typing import List, Optional

from openai import OpenAI
from openenv.core.env_client import StepResult
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from data_analysis_env import ( |
| DataAnalysisAction, |
| DataAnalysisEnv, |
| TASKS, |
| ) |
|
|
|
|
# Credentials: prefer HF_TOKEN, fall back to API_KEY (presence validated in main()).
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")

# OpenAI-compatible inference endpoint and model driving the agent loop.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
BENCHMARK = "data_analysis_env"  # label used only in [START] log lines
MAX_STEPS = 20  # hard cap on agent/environment steps per task
|
|
|
|
# System prompts per benchmark task, keyed by the task name passed to
# env.reset(task=...). The tool signatures listed here are what the model is
# told it may call -- NOTE(review): keep in sync with the tools actually
# exposed by DataAnalysisEnv if the environment changes.
TASK_INSTRUCTIONS = {
    "task_1": """You are a data analysis assistant. Your task is to:
1. Load the CSV file 'simple.csv'
2. Calculate the mean of the 'price' column

Available tools:
- load_csv(filename='filename.csv')
- show_data()
- show_columns()
- calculate(column='column_name', operation='mean|median|sum|count|std|min|max')

Start by loading the data, then calculate the mean of the price column.""",
    "task_2": """You are a data analysis assistant. Your task is to:
1. Load the CSV file 'dirty.csv'
2. Fill missing values (use mean)
3. Remove duplicate rows
4. Calculate the median of the 'age' column

Available tools:
- load_csv(filename='filename.csv')
- fill_missing(value='mean|median|zero|value')
- remove_duplicates()
- show_data()
- show_columns()
- calculate(column='column_name', operation='mean|median|sum|count|std|min|max')

Start by loading the data, then clean it, then calculate the median.""",
    "task_3": """You are a data analysis assistant. Your task is to:
1. Load 'sales.csv' and 'products.csv'
2. Merge them on 'product_id'
3. Group by 'category' and sum the 'sales' column
4. Get the final result

Available tools:
- load_csv(filename='filename.csv')
- merge_datasets(filename='filename.csv', on='column_name')
- show_data()
- show_columns()
- group_by(group_column='column_name', agg_column='column_name', operation='sum|mean|count')
- calculate(column='column_name', operation='sum|mean|count')
- get_result()

Start by loading both files, then merge, then group and aggregate.""",
}
|
|
|
|
def get_action_from_response(response: str) -> Optional[DataAnalysisAction]:
    """Parse a model response of the form ``tool_name(key=value, ...)``.

    Args:
        response: Raw text returned by the chat model.

    Returns:
        A DataAnalysisAction for the requested tool call, or None when the
        response cannot be interpreted as one.

    Value coercion: surrounding single/double quotes are stripped; the bare
    literals none/true/false (case-insensitive) become None/True/False;
    numeric-looking strings become int or float; everything else stays str.
    """
    response = response.strip()

    # "done" or an explicit get_result() call both map to the terminal tool.
    if response.lower() in ["done", "get_result()"]:
        return DataAnalysisAction(tool="get_result", parameters={})

    # A tool call must at least look like name(...).
    if "(" not in response or ")" not in response:
        return None

    try:
        tool_name = response.split("(")[0].strip()
        params_str = response.split("(")[1].split(")")[0].strip()

        parameters = {}
        if params_str:
            # BUGFIX: extract key=value pairs with a quote-aware regex instead
            # of a naive split(","), so quoted values containing commas (e.g.
            # filename='a,b.csv') are kept intact. Unquoted values still end
            # at the next comma, matching the previous behavior.
            for key, raw_value in re.findall(
                r"(\w+)\s*=\s*('[^']*'|\"[^\"]*\"|[^,]+)", params_str
            ):
                value = raw_value.strip().strip("'\"")

                if value.lower() == "none":
                    value = None
                elif value.lower() == "true":
                    value = True
                elif value.lower() == "false":
                    value = False
                else:
                    # Numeric literals: "." picks float, otherwise try int;
                    # non-numeric strings pass through unchanged.
                    try:
                        value = float(value) if "." in value else int(value)
                    except ValueError:
                        pass

                parameters[key] = value

        return DataAnalysisAction(tool=tool_name, parameters=parameters)

    except Exception as e:
        # Best-effort parser: log and let the caller re-prompt the model.
        print(f"Error parsing action: {e}", file=sys.stderr)
        return None
|
|
|
|
async def run_task(client: OpenAI, env: DataAnalysisEnv, task_name: str) -> dict:
    """Run one benchmark task end to end and return a result summary.

    Drives a ReAct-style loop: ask the model for the next tool call, parse
    it, execute it in the environment, and feed the tool output back until
    the environment reports done or MAX_STEPS is exhausted.

    Args:
        client: OpenAI-compatible chat client used to query the model.
        env: Data-analysis environment that executes the parsed actions.
        task_name: Key into TASK_INSTRUCTIONS and the env's task set.

    Returns:
        dict with keys: task, success (score >= 0.7), steps, score, rewards.
    """
    print(f"[START] task={task_name} env={BENCHMARK} model={MODEL_NAME}")

    instruction = TASK_INSTRUCTIONS.get(task_name, "")
    messages = [
        {"role": "system", "content": instruction},
        {"role": "user", "content": "Begin the analysis task."},
    ]

    step = 0
    rewards = []
    last_error = None

    result = await env.reset(task=task_name)
    obs = result.observation
    # obs.reward may be None (e.g. right after reset); treat missing as 0.0.
    reward_val = obs.reward if obs.reward is not None else 0.0

    print(
        f"[STEP] step={step} action=reset task={task_name} reward={reward_val:.2f} done={result.done} error=null"
    )

    while not result.done and step < MAX_STEPS:
        step += 1

        # NOTE(review): this client call is synchronous and blocks the event
        # loop; acceptable for this sequential benchmark script.
        response = (
            client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages
                + [{"role": "assistant", "content": f"Previous output: {obs.output}"}],
                temperature=0.1,
                max_tokens=500,
            )
            .choices[0]
            .message.content
        )

        action = get_action_from_response(response)

        if action is None:
            last_error = "Could not parse action"
            # BUGFIX: format reward_val, not obs.reward -- obs.reward can be
            # None here (e.g. first step after reset), which crashed the
            # :.2f conversion with a TypeError.
            print(
                f"[STEP] step={step} action='{response}' reward={reward_val:.2f} done=false error={last_error}"
            )
            messages.append(
                {
                    "role": "user",
                    "content": f"Invalid action format. Please use tool_name(param1=value1, param2=value2). Error: {last_error}",
                }
            )
            continue

        result = await env.step(action)
        obs = result.observation
        reward_val = obs.reward if obs.reward is not None else 0.0
        rewards.append(reward_val)

        error_str = obs.error if obs.error else "null"
        print(
            f"[STEP] step={step} action={action.tool}({action.parameters}) reward={reward_val:.2f} done={result.done} error={error_str}"
        )

        if obs.error:
            # Surface the environment error so the model can self-correct.
            last_error = obs.error
            messages.append(
                {
                    "role": "user",
                    "content": f"Error: {obs.error}. Please try a different tool or correct parameters.",
                }
            )
        else:
            messages.append(
                {
                    "role": "user",
                    "content": f"Tool executed successfully. Output: {obs.output}",
                }
            )

        if result.done:
            break

    # BUGFIX: obs.reward may still be None if the episode ended without a
    # reward; the bare assignment previously made `score >= 0.7` and the
    # :.2f format raise TypeError.
    score = obs.reward if obs.reward is not None else 0.0
    success = score >= 0.7

    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={step} score={score:.2f} rewards={rewards_str}"
    )

    return {
        "task": task_name,
        "success": success,
        "steps": step,
        "score": score,
        "rewards": rewards,
    }
|
|
|
|
async def main():
    """Entry point: run every benchmark task in sequence and print a summary."""
    api_key = API_KEY
    if not api_key:
        print(
            "Error: HF_TOKEN or API_KEY environment variable not set", file=sys.stderr
        )
        sys.exit(1)

    client = OpenAI(api_key=api_key, base_url=API_BASE_URL)

    # The environment server address is configurable via ENV_URL.
    env = DataAnalysisEnv(base_url=os.getenv("ENV_URL", "http://localhost:8000"))

    outcomes = []
    for name in ["task_1", "task_2", "task_3"]:
        try:
            outcomes.append(await run_task(client, env, name))
        except Exception as exc:
            # A failed task is recorded as a zero-score result so the
            # summary still covers every task.
            print(f"Error running {name}: {exc}", file=sys.stderr)
            outcomes.append(
                {
                    "task": name,
                    "success": False,
                    "steps": 0,
                    "score": 0.0,
                    "rewards": [],
                }
            )

    mean_score = sum(entry["score"] for entry in outcomes) / len(outcomes)
    print("\n=== Summary ===")
    print(f"Average Score: {mean_score:.2f}")
    for entry in outcomes:
        print(f" {entry['task']}: {entry['score']:.2f} ({'PASS' if entry['success'] else 'FAIL'})")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async benchmark driver to completion.
    asyncio.run(main())
|
|