Mohammed-Altaf committed on
Commit
0c731dd
·
1 Parent(s): 82f3f96

added inference.py script

Browse files
Files changed (3) hide show
  1. inference.py +208 -0
  2. pyproject.toml +1 -0
  3. uv.lock +2 -0
inference.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference script for the Data Analysis Agent environment.
2
+
3
+ Runs a language model agent against all 3 tasks and reports scores.
4
+ Uses the OpenAI-compatible client pointed at API_BASE_URL.
5
+
6
+ Required environment variables (set in .env or shell):
7
+ API_BASE_URL OpenAI-compatible LLM API endpoint
8
+ MODEL_NAME Model identifier to use for inference
9
+ HF_TOKEN API key (Hugging Face token or other provider key)
10
+
11
+ Optional:
12
+ ENV_SERVER_URL Environment server URL (default: http://localhost:7860)
13
+
14
+ Usage:
15
+ uv run python inference.py
16
+ uv run python inference.py --env-url http://localhost:8000
17
+ """
18
+
19
+ import argparse
20
+ import json
21
+ import os
22
+ import sys
23
+
24
+ from dotenv import load_dotenv
25
+ from openai import OpenAI
26
+
27
+ from client import DataAnalysisClient
28
+ from models import DataAction
29
+
30
+ # Load .env file if present (safe — does not override already-set shell vars)
31
+ load_dotenv()
32
+
33
# Sampling settings passed to every chat-completion request.
TEMPERATURE = 0.0
MAX_TOKENS = 1024
MAX_STEPS = 15  # Per task — keeps total runtime well under 20 min

# System message sent at the start of every task conversation. It defines the
# two JSON action shapes ("execute_code" / "submit_answer") the agent loop
# understands, so it must stay in sync with the action handling code.
SYSTEM_PROMPT = """You are a data analyst. You are given a dataset loaded as a pandas DataFrame called `df`.
You can execute Python/pandas code to explore the dataset and answer the question.

Rules:
- Use `print()` to see results of your code
- The DataFrame `df` is pre-loaded with pandas as `pd` and numpy as `np`
- When you have the answer, submit it in the exact format requested
- Be precise with numbers and formatting

Respond with JSON in one of these formats:
1. To execute code: {"action": "execute_code", "code": "your python code here"}
2. To submit answer: {"action": "submit_answer", "answer": "your answer here"}

Respond with ONLY the JSON, no other text."""
51
+
52
# JSON text of the action used whenever the model's reply cannot be used.
FALLBACK_ACTION = json.dumps({"action": "submit_answer", "answer": "unknown"})


def _loads_action_dict(candidate: str):
    """Return ``candidate`` decoded as a JSON object, or None if it is not one."""
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    # Valid JSON that is not an object (list, string, number, ...) cannot be
    # an action: callers use ``action.get(...)`` on the result.
    return parsed if isinstance(parsed, dict) else None


def parse_model_action(response_text: str) -> dict:
    """Parse the model's raw text response into an action dict.

    Handles plain JSON, markdown code block wrapping, and a JSON object
    surrounded by extra prose.

    Args:
        response_text: Raw string returned by the model. ``None`` is treated
            as an empty response.

    Returns:
        Parsed action dict, or a fallback submit_answer on failure. The
        result is always a dict, so callers can safely use ``.get``.
    """
    text = (response_text or "").strip()

    # Unwrap a leading markdown fence: ```json\n{...}\n``` or ```\n{...}\n```.
    if text.startswith("```"):
        parts = text.split("```")
        if len(parts) >= 2:
            text = parts[1]
            if text.startswith("json"):
                text = text[4:]
            text = text.strip()

    action = _loads_action_dict(text)
    if action is not None:
        return action

    # Last resort: the reply may wrap the JSON object in prose, so try the
    # span from the first "{" to the last "}".
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        action = _loads_action_dict(text[start:end + 1])
        if action is not None:
            return action

    return json.loads(FALLBACK_ACTION)
78
+
79
+
80
def run_task(openai_client: OpenAI, env_client: DataAnalysisClient, task_id: int) -> float:
    """Run a single task episode using the language model as the agent.

    The environment is reset for ``task_id``, then the model is queried up to
    ``MAX_STEPS`` times. Each reply is parsed into an action: ``execute_code``
    runs the code in the environment and feeds the output (or error) back to
    the model; ``submit_answer`` submits the answer and ends the episode; any
    other action produces a corrective message to the model.

    Args:
        openai_client: Configured OpenAI-compatible client.
        env_client: Connected DataAnalysisClient (sync wrapper).
        task_id: Task to evaluate (1 = easy, 2 = medium, 3 = hard).

    Returns:
        Final score for this task as a float; 0.0 when no answer was
        submitted within ``MAX_STEPS`` or when the environment reports no
        score.
    """
    result = env_client.reset(task_id=task_id)
    obs = result.observation

    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Task: {obs.task_description}\n\nDataset Info:\n{obs.dataset_info}",
                }
            ],
        },
    ]

    print(f"\n--- Task {task_id} ---")
    print(f"Question: {obs.task_description}")

    for step in range(MAX_STEPS):
        try:
            completion = openai_client.chat.completions.create(
                model=os.environ["MODEL_NAME"],
                messages=messages,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
                stream=False,
            )
            response_text = completion.choices[0].message.content or ""
        except Exception as exc:
            # Best-effort: a failed request ends the episode via the fallback
            # submit_answer action instead of aborting the whole run.
            print(f"  Model request failed ({exc}). Using fallback action.")
            response_text = FALLBACK_ACTION

        action = parse_model_action(response_text)
        action_type = action.get("action", "")
        print(f"  Step {step + 1}: model suggested -> {action_type}")

        if action_type == "execute_code":
            step_result = env_client.step(
                DataAction(action_type="execute_code", code=action.get("code", ""))
            )
            step_obs = step_result.observation
            result_text = f"Output: {step_obs.output}" if not step_obs.error else f"Error: {step_obs.error}"
            print(f"  -> {result_text[:120]}")

            # Record both sides of the exchange so the model sees its own
            # code and what it produced on the next turn.
            messages.append({"role": "assistant", "content": response_text})
            messages.append({"role": "user", "content": [{"type": "text", "text": result_text}]})

        elif action_type == "submit_answer":
            step_result = env_client.step(
                DataAction(action_type="submit_answer", answer=action.get("answer", ""))
            )
            step_obs = step_result.observation
            metadata = step_obs.metadata
            score = metadata.get("score", 0.0) if metadata else step_result.reward
            # Guard against a missing (None) score/reward, which would
            # otherwise crash the ":.2f" formatting and float() below.
            score = float(score) if score is not None else 0.0
            print(f"  -> submitted: '{action.get('answer', '')}' | score: {score:.2f}")
            return score

        else:
            messages.append({"role": "assistant", "content": response_text})
            messages.append({
                "role": "user",
                "content": [{"type": "text", "text": f"Unknown action '{action_type}'. Use 'execute_code' or 'submit_answer'."}],
            })

    print(f"  Reached max steps ({MAX_STEPS}). No answer submitted.")
    return 0.0
157
+
158
+
159
def main():
    """Run inference across all 3 tasks and print final scores.

    Reads ``--env-url`` from the command line (defaulting to the
    ``ENV_SERVER_URL`` environment variable, then ``http://localhost:7860``),
    checks that ``API_BASE_URL``, ``MODEL_NAME`` and ``HF_TOKEN`` are set
    (exiting with status 1 otherwise), runs tasks 1-3 in order, and prints a
    per-task and average score summary.
    """
    parser = argparse.ArgumentParser(description="Data Analysis Agent inference script")
    parser.add_argument(
        "--env-url",
        default=os.environ.get("ENV_SERVER_URL", "http://localhost:7860"),
        help="Environment server URL (default: http://localhost:7860)",
    )
    args = parser.parse_args()

    # Validate required environment variables
    missing = [v for v in ("API_BASE_URL", "MODEL_NAME", "HF_TOKEN") if not os.environ.get(v)]
    if missing:
        print(f"Error: Missing required environment variables: {', '.join(missing)}")
        print("Set them in your shell or create a .env file (see .env.example).")
        sys.exit(1)

    openai_client = OpenAI(
        base_url=os.environ["API_BASE_URL"],
        api_key=os.environ["HF_TOKEN"],
    )

    print("=" * 55)
    print("Data Analysis Agent — Inference")
    print(f"Server : {args.env_url}")
    print(f"Model : {os.environ['MODEL_NAME']}")
    print(f"API : {os.environ['API_BASE_URL']}")
    print("=" * 55)

    scores = {}
    difficulties = {1: "Easy", 2: "Medium", 3: "Hard"}

    # Each task gets its own isolated WebSocket session
    for task_id in [1, 2, 3]:
        with DataAnalysisClient(base_url=args.env_url).sync() as env_client:
            score = run_task(openai_client, env_client, task_id)
            scores[task_id] = score

    print("\n" + "=" * 55)
    print("RESULTS")
    print("=" * 55)
    for task_id, score in scores.items():
        print(f" Task {task_id} ({difficulties[task_id]:6s}): {score:.2f}")
    avg = sum(scores.values()) / len(scores)
    print(f"\n Average Score : {avg:.2f}")
    print("=" * 55)


if __name__ == "__main__":
    main()
pyproject.toml CHANGED
@@ -14,6 +14,7 @@ dependencies = [
14
  "openai>=1.0.0",
15
  "black>=26.3.1",
16
  "isort>=8.0.1",
 
17
  ]
18
 
19
  [project.scripts]
 
14
  "openai>=1.0.0",
15
  "black>=26.3.1",
16
  "isort>=8.0.1",
17
+ "python-dotenv>=1.2.2",
18
  ]
19
 
20
  [project.scripts]
uv.lock CHANGED
@@ -1176,6 +1176,7 @@ dependencies = [
1176
  { name = "openenv-core" },
1177
  { name = "pandas" },
1178
  { name = "pydantic" },
 
1179
  { name = "uvicorn" },
1180
  ]
1181
 
@@ -1189,6 +1190,7 @@ requires-dist = [
1189
  { name = "openenv-core", specifier = ">=0.2.3" },
1190
  { name = "pandas", specifier = ">=2.0.0" },
1191
  { name = "pydantic", specifier = ">=2.0.0" },
 
1192
  { name = "uvicorn", specifier = ">=0.24.0" },
1193
  ]
1194
 
 
1176
  { name = "openenv-core" },
1177
  { name = "pandas" },
1178
  { name = "pydantic" },
1179
+ { name = "python-dotenv" },
1180
  { name = "uvicorn" },
1181
  ]
1182
 
 
1190
  { name = "openenv-core", specifier = ">=0.2.3" },
1191
  { name = "pandas", specifier = ">=2.0.0" },
1192
  { name = "pydantic", specifier = ">=2.0.0" },
1193
+ { name = "python-dotenv", specifier = ">=1.2.2" },
1194
  { name = "uvicorn", specifier = ">=0.24.0" },
1195
  ]
1196