Spaces:
Sleeping
Sleeping
Commit ·
0c731dd
1
Parent(s): 82f3f96
added inference.py script
Browse files- inference.py +208 -0
- pyproject.toml +1 -0
- uv.lock +2 -0
inference.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference script for the Data Analysis Agent environment.
|
| 2 |
+
|
| 3 |
+
Runs a language model agent against all 3 tasks and reports scores.
|
| 4 |
+
Uses the OpenAI-compatible client pointed at API_BASE_URL.
|
| 5 |
+
|
| 6 |
+
Required environment variables (set in .env or shell):
|
| 7 |
+
API_BASE_URL OpenAI-compatible LLM API endpoint
|
| 8 |
+
MODEL_NAME Model identifier to use for inference
|
| 9 |
+
HF_TOKEN API key (Hugging Face token or other provider key)
|
| 10 |
+
|
| 11 |
+
Optional:
|
| 12 |
+
ENV_SERVER_URL Environment server URL (default: http://localhost:7860)
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
uv run python inference.py
|
| 16 |
+
uv run python inference.py --env-url http://localhost:8000
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
|
| 24 |
+
from dotenv import load_dotenv
|
| 25 |
+
from openai import OpenAI
|
| 26 |
+
|
| 27 |
+
from client import DataAnalysisClient
|
| 28 |
+
from models import DataAction
|
| 29 |
+
|
| 30 |
+
# Load .env file if present (safe — does not override already-set shell vars)
|
| 31 |
+
load_dotenv()
|
| 32 |
+
|
| 33 |
+
TEMPERATURE = 0.0
|
| 34 |
+
MAX_TOKENS = 1024
|
| 35 |
+
MAX_STEPS = 15 # Per task — keeps total runtime well under 20 min
|
| 36 |
+
|
| 37 |
+
SYSTEM_PROMPT = """You are a data analyst. You are given a dataset loaded as a pandas DataFrame called `df`.
|
| 38 |
+
You can execute Python/pandas code to explore the dataset and answer the question.
|
| 39 |
+
|
| 40 |
+
Rules:
|
| 41 |
+
- Use `print()` to see results of your code
|
| 42 |
+
- The DataFrame `df` is pre-loaded with pandas as `pd` and numpy as `np`
|
| 43 |
+
- When you have the answer, submit it in the exact format requested
|
| 44 |
+
- Be precise with numbers and formatting
|
| 45 |
+
|
| 46 |
+
Respond with JSON in one of these formats:
|
| 47 |
+
1. To execute code: {"action": "execute_code", "code": "your python code here"}
|
| 48 |
+
2. To submit answer: {"action": "submit_answer", "answer": "your answer here"}
|
| 49 |
+
|
| 50 |
+
Respond with ONLY the JSON, no other text."""
|
| 51 |
+
|
| 52 |
+
FALLBACK_ACTION = json.dumps({"action": "submit_answer", "answer": "unknown"})
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def parse_model_action(response_text: str) -> dict:
|
| 56 |
+
"""Parse the model's raw text response into an action dict.
|
| 57 |
+
|
| 58 |
+
Handles plain JSON and markdown code block wrapping.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
response_text: Raw string returned by the model.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
Parsed action dict, or a fallback submit_answer on failure.
|
| 65 |
+
"""
|
| 66 |
+
text = response_text.strip()
|
| 67 |
+
if text.startswith("```"):
|
| 68 |
+
parts = text.split("```")
|
| 69 |
+
if len(parts) >= 2:
|
| 70 |
+
text = parts[1]
|
| 71 |
+
if text.startswith("json"):
|
| 72 |
+
text = text[4:]
|
| 73 |
+
text = text.strip()
|
| 74 |
+
try:
|
| 75 |
+
return json.loads(text)
|
| 76 |
+
except json.JSONDecodeError:
|
| 77 |
+
return json.loads(FALLBACK_ACTION)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def run_task(openai_client: OpenAI, env_client: DataAnalysisClient, task_id: int) -> float:
|
| 81 |
+
"""Run a single task episode using the language model as the agent.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
openai_client: Configured OpenAI-compatible client.
|
| 85 |
+
env_client: Connected DataAnalysisClient (sync wrapper).
|
| 86 |
+
task_id: Task to evaluate (1 = easy, 2 = medium, 3 = hard).
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
Final score for this task between 0.0 and 1.0.
|
| 90 |
+
"""
|
| 91 |
+
result = env_client.reset(task_id=task_id)
|
| 92 |
+
obs = result.observation
|
| 93 |
+
|
| 94 |
+
messages = [
|
| 95 |
+
{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
|
| 96 |
+
{
|
| 97 |
+
"role": "user",
|
| 98 |
+
"content": [
|
| 99 |
+
{
|
| 100 |
+
"type": "text",
|
| 101 |
+
"text": f"Task: {obs.task_description}\n\nDataset Info:\n{obs.dataset_info}",
|
| 102 |
+
}
|
| 103 |
+
],
|
| 104 |
+
},
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
print(f"\n--- Task {task_id} ---")
|
| 108 |
+
print(f"Question: {obs.task_description}")
|
| 109 |
+
|
| 110 |
+
for step in range(MAX_STEPS):
|
| 111 |
+
try:
|
| 112 |
+
completion = openai_client.chat.completions.create(
|
| 113 |
+
model=os.environ["MODEL_NAME"],
|
| 114 |
+
messages=messages,
|
| 115 |
+
temperature=TEMPERATURE,
|
| 116 |
+
max_tokens=MAX_TOKENS,
|
| 117 |
+
stream=False,
|
| 118 |
+
)
|
| 119 |
+
response_text = completion.choices[0].message.content or ""
|
| 120 |
+
except Exception as exc:
|
| 121 |
+
print(f" Model request failed ({exc}). Using fallback action.")
|
| 122 |
+
response_text = FALLBACK_ACTION
|
| 123 |
+
|
| 124 |
+
action = parse_model_action(response_text)
|
| 125 |
+
action_type = action.get("action", "")
|
| 126 |
+
print(f" Step {step + 1}: model suggested -> {action_type}")
|
| 127 |
+
|
| 128 |
+
if action_type == "execute_code":
|
| 129 |
+
step_result = env_client.step(
|
| 130 |
+
DataAction(action_type="execute_code", code=action.get("code", ""))
|
| 131 |
+
)
|
| 132 |
+
step_obs = step_result.observation
|
| 133 |
+
result_text = f"Output: {step_obs.output}" if not step_obs.error else f"Error: {step_obs.error}"
|
| 134 |
+
print(f" -> {result_text[:120]}")
|
| 135 |
+
|
| 136 |
+
messages.append({"role": "assistant", "content": response_text})
|
| 137 |
+
messages.append({"role": "user", "content": [{"type": "text", "text": result_text}]})
|
| 138 |
+
|
| 139 |
+
elif action_type == "submit_answer":
|
| 140 |
+
step_result = env_client.step(
|
| 141 |
+
DataAction(action_type="submit_answer", answer=action.get("answer", ""))
|
| 142 |
+
)
|
| 143 |
+
step_obs = step_result.observation
|
| 144 |
+
score = step_obs.metadata.get("score", 0.0) if step_obs.metadata else step_result.reward
|
| 145 |
+
print(f" -> submitted: '{action.get('answer', '')}' | score: {score:.2f}")
|
| 146 |
+
return float(score)
|
| 147 |
+
|
| 148 |
+
else:
|
| 149 |
+
messages.append({"role": "assistant", "content": response_text})
|
| 150 |
+
messages.append({
|
| 151 |
+
"role": "user",
|
| 152 |
+
"content": [{"type": "text", "text": f"Unknown action '{action_type}'. Use 'execute_code' or 'submit_answer'."}],
|
| 153 |
+
})
|
| 154 |
+
|
| 155 |
+
print(f" Reached max steps ({MAX_STEPS}). No answer submitted.")
|
| 156 |
+
return 0.0
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def main():
|
| 160 |
+
"""Run inference across all 3 tasks and print final scores."""
|
| 161 |
+
parser = argparse.ArgumentParser(description="Data Analysis Agent inference script")
|
| 162 |
+
parser.add_argument(
|
| 163 |
+
"--env-url",
|
| 164 |
+
default=os.environ.get("ENV_SERVER_URL", "http://localhost:7860"),
|
| 165 |
+
help="Environment server URL (default: http://localhost:7860)",
|
| 166 |
+
)
|
| 167 |
+
args = parser.parse_args()
|
| 168 |
+
|
| 169 |
+
# Validate required environment variables
|
| 170 |
+
missing = [v for v in ("API_BASE_URL", "MODEL_NAME", "HF_TOKEN") if not os.environ.get(v)]
|
| 171 |
+
if missing:
|
| 172 |
+
print(f"Error: Missing required environment variables: {', '.join(missing)}")
|
| 173 |
+
print("Set them in your shell or create a .env file (see .env.example).")
|
| 174 |
+
sys.exit(1)
|
| 175 |
+
|
| 176 |
+
openai_client = OpenAI(
|
| 177 |
+
base_url=os.environ["API_BASE_URL"],
|
| 178 |
+
api_key=os.environ["HF_TOKEN"],
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
print("=" * 55)
|
| 182 |
+
print("Data Analysis Agent — Inference")
|
| 183 |
+
print(f"Server : {args.env_url}")
|
| 184 |
+
print(f"Model : {os.environ['MODEL_NAME']}")
|
| 185 |
+
print(f"API : {os.environ['API_BASE_URL']}")
|
| 186 |
+
print("=" * 55)
|
| 187 |
+
|
| 188 |
+
scores = {}
|
| 189 |
+
difficulties = {1: "Easy", 2: "Medium", 3: "Hard"}
|
| 190 |
+
|
| 191 |
+
# Each task gets its own isolated WebSocket session
|
| 192 |
+
for task_id in [1, 2, 3]:
|
| 193 |
+
with DataAnalysisClient(base_url=args.env_url).sync() as env_client:
|
| 194 |
+
score = run_task(openai_client, env_client, task_id)
|
| 195 |
+
scores[task_id] = score
|
| 196 |
+
|
| 197 |
+
print("\n" + "=" * 55)
|
| 198 |
+
print("RESULTS")
|
| 199 |
+
print("=" * 55)
|
| 200 |
+
for task_id, score in scores.items():
|
| 201 |
+
print(f" Task {task_id} ({difficulties[task_id]:6s}): {score:.2f}")
|
| 202 |
+
avg = sum(scores.values()) / len(scores)
|
| 203 |
+
print(f"\n Average Score : {avg:.2f}")
|
| 204 |
+
print("=" * 55)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
if __name__ == "__main__":
|
| 208 |
+
main()
|
pyproject.toml
CHANGED
|
@@ -14,6 +14,7 @@ dependencies = [
|
|
| 14 |
"openai>=1.0.0",
|
| 15 |
"black>=26.3.1",
|
| 16 |
"isort>=8.0.1",
|
|
|
|
| 17 |
]
|
| 18 |
|
| 19 |
[project.scripts]
|
|
|
|
| 14 |
"openai>=1.0.0",
|
| 15 |
"black>=26.3.1",
|
| 16 |
"isort>=8.0.1",
|
| 17 |
+
"python-dotenv>=1.2.2",
|
| 18 |
]
|
| 19 |
|
| 20 |
[project.scripts]
|
uv.lock
CHANGED
|
@@ -1176,6 +1176,7 @@ dependencies = [
|
|
| 1176 |
{ name = "openenv-core" },
|
| 1177 |
{ name = "pandas" },
|
| 1178 |
{ name = "pydantic" },
|
|
|
|
| 1179 |
{ name = "uvicorn" },
|
| 1180 |
]
|
| 1181 |
|
|
@@ -1189,6 +1190,7 @@ requires-dist = [
|
|
| 1189 |
{ name = "openenv-core", specifier = ">=0.2.3" },
|
| 1190 |
{ name = "pandas", specifier = ">=2.0.0" },
|
| 1191 |
{ name = "pydantic", specifier = ">=2.0.0" },
|
|
|
|
| 1192 |
{ name = "uvicorn", specifier = ">=0.24.0" },
|
| 1193 |
]
|
| 1194 |
|
|
|
|
| 1176 |
{ name = "openenv-core" },
|
| 1177 |
{ name = "pandas" },
|
| 1178 |
{ name = "pydantic" },
|
| 1179 |
+
{ name = "python-dotenv" },
|
| 1180 |
{ name = "uvicorn" },
|
| 1181 |
]
|
| 1182 |
|
|
|
|
| 1190 |
{ name = "openenv-core", specifier = ">=0.2.3" },
|
| 1191 |
{ name = "pandas", specifier = ">=2.0.0" },
|
| 1192 |
{ name = "pydantic", specifier = ">=2.0.0" },
|
| 1193 |
+
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1194 |
{ name = "uvicorn", specifier = ">=0.24.0" },
|
| 1195 |
]
|
| 1196 |
|