data_analysis_env / inference.py
HimanshuSardana2's picture
fix: handle async client and None reward values
5e73079
import asyncio
import os
import sys
from typing import List, Optional
from openai import OpenAI
from openenv.core.env_client import StepResult
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from data_analysis_env import (
DataAnalysisAction,
DataAnalysisEnv,
TASKS,
)
# --- Runtime configuration (all overridable through environment variables) ---
# Credential for the OpenAI-compatible endpoint: HF_TOKEN wins, API_KEY is the fallback.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
# Base URL of the OpenAI-compatible chat-completions endpoint.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# Model identifier sent with every chat-completion request.
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# Benchmark label printed in the [START] log line of each task.
BENCHMARK = "data_analysis_env"
# Upper bound on agent loop iterations per task (parse failures count too).
MAX_STEPS = 20
# System prompts keyed by task name. run_task() sends the matching entry as the
# "system" message; each prompt lists the goal and the tool-call syntax the
# model is expected to answer with (one `tool(key=value, ...)` call per reply).
# These strings are sent verbatim to the model, so edits change agent behavior.
TASK_INSTRUCTIONS = {
"task_1": """You are a data analysis assistant. Your task is to:
1. Load the CSV file 'simple.csv'
2. Calculate the mean of the 'price' column
Available tools:
- load_csv(filename='filename.csv')
- show_data()
- show_columns()
- calculate(column='column_name', operation='mean|median|sum|count|std|min|max')
Start by loading the data, then calculate the mean of the price column.""",
"task_2": """You are a data analysis assistant. Your task is to:
1. Load the CSV file 'dirty.csv'
2. Fill missing values (use mean)
3. Remove duplicate rows
4. Calculate the median of the 'age' column
Available tools:
- load_csv(filename='filename.csv')
- fill_missing(value='mean|median|zero|value')
- remove_duplicates()
- show_data()
- show_columns()
- calculate(column='column_name', operation='mean|median|sum|count|std|min|max')
Start by loading the data, then clean it, then calculate the median.""",
"task_3": """You are a data analysis assistant. Your task is to:
1. Load 'sales.csv' and 'products.csv'
2. Merge them on 'product_id'
3. Group by 'category' and sum the 'sales' column
4. Get the final result
Available tools:
- load_csv(filename='filename.csv')
- merge_datasets(filename='filename.csv', on='column_name')
- show_data()
- show_columns()
- group_by(group_column='column_name', agg_column='column_name', operation='sum|mean|count')
- calculate(column='column_name', operation='sum|mean|count')
- get_result()
Start by loading both files, then merge, then group and aggregate.""",
}
def _split_call_arguments(arg_text: str) -> List[str]:
    """Split the text following a call's ``(`` into raw ``key=value`` segments.

    Scans up to the first ``)`` that is not inside quotes and splits on commas
    that are not inside quotes, so ``value='a,b'`` stays in one segment.
    When the quotes are unbalanced (e.g. a stray apostrophe) the quote-aware
    scan cannot be trusted, so it falls back to plain ``)`` / ``,`` splitting.
    """
    segments = []
    current = []
    open_quote = None
    for ch in arg_text:
        if open_quote is not None:
            # Inside a quoted value: everything is literal until the quote closes.
            if ch == open_quote:
                open_quote = None
            current.append(ch)
        elif ch in "'\"":
            open_quote = ch
            current.append(ch)
        elif ch == ",":
            segments.append("".join(current))
            current = []
        elif ch == ")":
            break
        else:
            current.append(ch)
    if open_quote is not None:
        return arg_text.split(")")[0].split(",")
    segments.append("".join(current))
    return segments


def _coerce_parameter_value(raw: str):
    """Convert one raw argument string into None / bool / int / float / str.

    Surrounding quotes are removed first; text that is not a recognised
    literal or number is returned unchanged as a string.
    """
    value = raw.strip().strip("'\"")
    lowered = value.lower()
    if lowered == "none":
        return None
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    try:
        # A dot selects float parsing; filenames such as 'simple.csv' fail the
        # conversion and deliberately stay strings.
        return float(value) if "." in value else int(value)
    except ValueError:
        return value


def get_action_from_response(response: str) -> Optional[DataAnalysisAction]:
    """Parse a model reply of the form ``tool(key=value, ...)`` into an action.

    Args:
        response: Raw text returned by the model. ``None`` or an empty string
            is tolerated and treated as unparseable.

    Returns:
        A ``DataAnalysisAction`` for the first call found in the reply, a
        ``get_result`` action when the reply is ``done`` / ``get_result()``,
        or ``None`` when no call can be extracted or the action cannot be
        constructed.
    """
    if not response:
        # Chat completions may carry no text content at all.
        return None

    # Models often wrap the call in backticks or a fenced code block.
    text = response.strip().strip("`").strip()
    if text.lower() in ("done", "get_result()"):
        return DataAnalysisAction(tool="get_result", parameters={})
    if "(" not in text or ")" not in text:
        return None

    head, _, arg_text = text.partition("(")
    # Keep only the token directly before "(" so leading prose or a code-fence
    # language tag ("python\nload_csv") does not leak into the tool name.
    head_tokens = head.split()
    tool_name = head_tokens[-1].strip("`") if head_tokens else ""
    if not tool_name:
        return None

    parameters = {}
    for segment in _split_call_arguments(arg_text):
        if "=" not in segment:
            # Positional or empty segments carry no key and are ignored.
            continue
        key, raw_value = segment.split("=", 1)
        parameters[key.strip()] = _coerce_parameter_value(raw_value)

    try:
        return DataAnalysisAction(tool=tool_name, parameters=parameters)
    except Exception as e:
        # The action type is project-defined; treat any construction failure
        # as an unparseable reply instead of aborting the task.
        print(f"Error parsing action: {e}", file=sys.stderr)
        return None
async def run_task(client: OpenAI, env: DataAnalysisEnv, task_name: str) -> dict:
    """Drive one task: ask the model for tool calls and apply them to the env.

    The loop runs until the environment reports ``done`` or ``MAX_STEPS``
    iterations have been used (replies that cannot be parsed count as steps).
    Progress is logged to stdout as ``[START]`` / ``[STEP]`` / ``[END]`` lines.

    Args:
        client: OpenAI-compatible client used for chat completions.
        env: Environment client; ``reset`` and ``step`` are awaited.
        task_name: Key into ``TASK_INSTRUCTIONS`` and the task passed to
            ``env.reset``.

    Returns:
        A dict with ``task``, ``success`` (final score >= 0.7), ``steps``,
        ``score`` (reward of the last observation, 0.0 when missing) and
        ``rewards`` (one entry per executed environment step).
    """
    print(f"[START] task={task_name} env={BENCHMARK} model={MODEL_NAME}")
    instruction = TASK_INSTRUCTIONS.get(task_name, "")
    messages = [
        {"role": "system", "content": instruction},
        {"role": "user", "content": "Begin the analysis task."},
    ]
    step = 0
    rewards = []

    result = await env.reset(task=task_name)
    obs = result.observation
    # The environment may report no reward; always log/score it as 0.0.
    reward_val = obs.reward if obs.reward is not None else 0.0
    print(
        f"[STEP] step={step} action=reset task={task_name} reward={reward_val:.2f} done={result.done} error=null"
    )

    while not result.done and step < MAX_STEPS:
        step += 1
        # NOTE(review): this is a synchronous call inside a coroutine, so the
        # event loop is blocked while the model responds.
        # NOTE(review): the model's own replies are never appended to
        # `messages`; only environment feedback is. Confirm this is intended.
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages
            + [{"role": "assistant", "content": f"Previous output: {obs.output}"}],
            temperature=0.1,
            max_tokens=500,
        )
        # The message content can be None; normalise so parsing never crashes.
        response = completion.choices[0].message.content or ""

        action = get_action_from_response(response)
        if action is None:
            parse_error = "Could not parse action"
            # Use the None-safe reward here: formatting obs.reward directly
            # raises TypeError when the environment reported no reward.
            print(
                f"[STEP] step={step} action='{response}' reward={reward_val:.2f} done=false error={parse_error}"
            )
            messages.append(
                {
                    "role": "user",
                    "content": f"Invalid action format. Please use tool_name(param1=value1, param2=value2). Error: {parse_error}",
                }
            )
            continue

        result = await env.step(action)
        obs = result.observation
        reward_val = obs.reward if obs.reward is not None else 0.0
        rewards.append(reward_val)
        error_str = obs.error if obs.error else "null"
        print(
            f"[STEP] step={step} action={action.tool}({action.parameters}) reward={reward_val:.2f} done={result.done} error={error_str}"
        )

        if obs.error:
            feedback = f"Error: {obs.error}. Please try a different tool or correct parameters."
        else:
            feedback = f"Tool executed successfully. Output: {obs.output}"
        messages.append({"role": "user", "content": feedback})

    # reward_val always mirrors the latest observation and is never None, so
    # the threshold comparison and the formatting below cannot fail.
    score = reward_val
    success = score >= 0.7
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={step} score={score:.2f} rewards={rewards_str}"
    )
    return {
        "task": task_name,
        "success": success,
        "steps": step,
        "score": score,
        "rewards": rewards,
    }
async def main():
    """Run every benchmark task in order and print a score summary.

    Exits with status 1 when no API credential is configured. A task that
    raises is reported on stderr and recorded as a failed, zero-score result
    so the remaining tasks still run.
    """
    if not API_KEY:
        print(
            "Error: HF_TOKEN or API_KEY environment variable not set", file=sys.stderr
        )
        sys.exit(1)

    llm_client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
    env_url = os.getenv("ENV_URL", "http://localhost:8000")
    environment = DataAnalysisEnv(base_url=env_url)

    summaries = []
    for name in ("task_1", "task_2", "task_3"):
        try:
            outcome = await run_task(llm_client, environment, name)
        except Exception as exc:
            print(f"Error running {name}: {exc}", file=sys.stderr)
            # Placeholder entry keeps the summary aligned with the task list.
            outcome = {
                "task": name,
                "success": False,
                "steps": 0,
                "score": 0.0,
                "rewards": [],
            }
        summaries.append(outcome)

    total_score = sum(entry["score"] for entry in summaries)
    avg_score = total_score / len(summaries)
    print("\n=== Summary ===")
    print(f"Average Score: {avg_score:.2f}")
    for entry in summaries:
        verdict = "PASS" if entry["success"] else "FAIL"
        print(f" {entry['task']}: {entry['score']:.2f} ({verdict})")


# Script entry point: run the async driver to completion.
if __name__ == "__main__":
    asyncio.run(main())