Akshaykumarbm commited on
Commit
f8efd56
·
verified ·
1 Parent(s): 8e97e82

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. inference.py +224 -166
  2. pyproject.toml +6 -0
  3. uv.lock +6 -0
inference.py CHANGED
@@ -1,72 +1,41 @@
1
  """
2
- LLM-based Inference Script for Meeting Scheduling RL Environment.
3
- ===================================
4
- Uses OpenAI-compatible LLM via HF Router to intelligently schedule meetings.
5
-
6
- MANDATORY environment variables:
7
- API_BASE_URL The API endpoint for the LLM.
8
- MODEL_NAME The model identifier to use for inference.
9
- HF_TOKEN Your Hugging Face / API key.
10
-
11
- STDOUT FORMAT:
12
- [START] task=<task_name> env=scheduling_env model=<model_name>
13
- [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
14
- [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
15
- """
16
 
17
- import asyncio
 
 
 
 
 
 
 
 
 
 
18
  import json
19
  import os
 
20
  import textwrap
21
- from typing import Dict, List, Optional
22
 
23
  from openai import OpenAI
24
 
25
- from scheduling_env.client import SchedulingEnv
26
- from scheduling_env.models import SchedulingAction
27
 
28
- # ---------------------------------------------------------------------------
29
- # Configuration
30
- # ---------------------------------------------------------------------------
31
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
32
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
33
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
 
 
34
 
35
- ENV_REPO_ID = "Akshaykumarbm/scheduling_env"
36
- BENCHMARK = "scheduling_env"
37
- TASKS = ["task1_easy", "task2_medium", "task3_hard"]
38
- MAX_STEPS = 20
39
- TEMPERATURE = 0.3
40
- MAX_TOKENS = 512
41
 
42
- # ---------------------------------------------------------------------------
43
- # Logging helpers
44
- # ---------------------------------------------------------------------------
45
-
46
- def log_start(task: str, env: str, model: str) -> None:
47
- print(f"[START] task={task} env={env} model={model}", flush=True)
48
 
49
-
50
- def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
51
- error_val = error if error else "null"
52
- done_val = str(done).lower()
53
- print(
54
- f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
55
- flush=True,
56
- )
57
-
58
-
59
- def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
60
- rewards_str = ",".join(f"{r:.2f}" for r in rewards)
61
- print(
62
- f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}",
63
- flush=True,
64
- )
65
-
66
-
67
- # ---------------------------------------------------------------------------
68
- # LLM interaction
69
- # ---------------------------------------------------------------------------
70
 
71
  SYSTEM_PROMPT = textwrap.dedent("""\
72
  You are an AI meeting scheduling assistant. You must schedule a meeting by choosing actions.
@@ -95,28 +64,59 @@ Rules:
95
  """)
96
 
97
 
98
- def format_observation(obs, step: int) -> str:
99
- """Convert a SchedulingObservation into a user prompt for the LLM."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  parts = [
101
- f"Step {step}/{obs.max_steps}",
102
- f"Meeting to schedule: {obs.requested_duration} min, priority {obs.requested_priority}",
103
- f"Attendees: {', '.join(obs.attendee_ids)}",
104
- f"Collective working hours: {obs.collective_work_hours.get('min_start_hour', 9)}:00 - {obs.collective_work_hours.get('max_end_hour', 17)}:00",
105
  ]
106
 
107
- if obs.preference_constraints:
108
- parts.append(f"Preferences: max {obs.preference_constraints.get('max_meetings_per_day', 'N/A')} meetings/day, "
109
- f"buffer required: {obs.preference_constraints.get('requires_buffer', False)}, "
110
- f"buffer mins: {obs.preference_constraints.get('buffer_minutes', 0)}")
 
 
 
 
 
 
111
 
112
  # Busy slots grouped by attendee
113
  busy_by_attendee: Dict[str, List] = {}
114
- for slot in obs.busy_slots:
115
- att = slot["attendee"]
116
  busy_by_attendee.setdefault(att, []).append(slot)
117
 
118
  parts.append("\nCalendars:")
119
- for att in obs.attendee_ids:
120
  slots = busy_by_attendee.get(att, [])
121
  if slots:
122
  slot_strs = [
@@ -128,26 +128,29 @@ def format_observation(obs, step: int) -> str:
128
  else:
129
  parts.append(f" {att}: (no meetings)")
130
 
131
- if obs.current_proposal:
132
- parts.append(f"\nCurrent proposal: {obs.current_proposal['start']} to {obs.current_proposal['end']}")
 
133
 
134
- if obs.conflicts:
135
- parts.append(f"\nConflicts ({len(obs.conflicts)}):")
136
- for c in obs.conflicts:
 
137
  parts.append(
138
  f" - {c['attendee']}: {c['start']} to {c['end']} "
139
  f"(priority {c['priority']}, {c['summary']}, id: {c['meeting_id']})"
140
  )
141
 
142
- if obs.error_message:
143
- parts.append(f"\nLast error: {obs.error_message}")
 
144
 
145
- parts.append(f"\nRescheduled so far: {obs.num_rescheduled}")
146
- parts.append(f"Preference penalty: {obs.preference_penalty}")
147
 
148
- if not obs.current_proposal and not obs.conflicts:
149
  parts.append("\nAction needed: propose a time slot for the meeting.")
150
- elif obs.conflicts:
151
  parts.append("\nAction needed: reschedule a conflict (lower-priority only) or propose a different slot.")
152
  else:
153
  parts.append("\nAction needed: no conflicts remain - you should finalize.")
@@ -155,12 +158,35 @@ def format_observation(obs, step: int) -> str:
155
  return "\n".join(parts)
156
 
157
 
158
- def parse_llm_response(text: str, obs) -> SchedulingAction:
159
- """Parse LLM JSON response into a SchedulingAction, with fallback."""
160
- # Extract JSON from response (handle markdown code blocks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  cleaned = text.strip()
 
 
162
  if "```" in cleaned:
163
- # Extract content between code fences
164
  lines = cleaned.split("\n")
165
  json_lines = []
166
  in_block = False
@@ -172,7 +198,7 @@ def parse_llm_response(text: str, obs) -> SchedulingAction:
172
  json_lines.append(line)
173
  cleaned = "\n".join(json_lines).strip()
174
 
175
- # Try to find JSON object in the response
176
  start = cleaned.find("{")
177
  end = cleaned.rfind("}") + 1
178
  if start >= 0 and end > start:
@@ -180,50 +206,36 @@ def parse_llm_response(text: str, obs) -> SchedulingAction:
180
 
181
  try:
182
  data = json.loads(cleaned)
183
- return SchedulingAction(**data)
184
- except (json.JSONDecodeError, Exception) as e:
185
- print(f"[DEBUG] Failed to parse LLM response: {e}. Response: {text[:200]}", flush=True)
186
- # Fallback: if we have no proposal yet, propose at first available hour
187
- if obs.current_proposal is None:
188
- min_h = obs.collective_work_hours.get("min_start_hour", 9)
189
- return SchedulingAction(
190
- action_type="propose_slot",
191
- proposed_start=f"2025-04-07T{min_h:02d}:00:00+00:00",
192
- proposed_duration=obs.requested_duration,
193
- )
194
- elif not obs.conflicts:
195
- return SchedulingAction(action_type="finalize")
196
- else:
197
- return SchedulingAction(action_type="reject")
198
-
 
 
 
 
 
 
199
 
200
- def get_llm_action(client: OpenAI, obs, step: int) -> SchedulingAction:
201
- """Query the LLM and return a SchedulingAction."""
202
- user_prompt = format_observation(obs, step)
203
- try:
204
- completion = client.chat.completions.create(
205
- model=MODEL_NAME,
206
- messages=[
207
- {"role": "system", "content": SYSTEM_PROMPT},
208
- {"role": "user", "content": user_prompt},
209
- ],
210
- temperature=TEMPERATURE,
211
- max_tokens=MAX_TOKENS,
212
- stream=False,
213
- )
214
- text = (completion.choices[0].message.content or "").strip()
215
- return parse_llm_response(text, obs)
216
- except Exception as exc:
217
- print(f"[DEBUG] LLM request failed: {exc}", flush=True)
218
- return parse_llm_response("", obs)
219
 
 
220
 
221
- # ---------------------------------------------------------------------------
222
- # Main loop
223
- # ---------------------------------------------------------------------------
224
 
225
- async def run_task(env, client: OpenAI, task_id: str) -> None:
226
- """Run a single scheduling task."""
227
  rewards: List[float] = []
228
  steps_taken = 0
229
  score = 0.0
@@ -232,62 +244,108 @@ async def run_task(env, client: OpenAI, task_id: str) -> None:
232
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
233
 
234
  try:
235
- result = await env.reset(task_id=task_id)
236
- obs = result.observation
237
-
238
- for step in range(1, MAX_STEPS + 1):
239
- if result.done:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  break
241
 
242
- action = get_llm_action(client, obs, step)
243
-
244
- result = await env.step(action)
245
- obs = result.observation
246
-
247
- reward = result.reward or 0.0
248
- done = result.done
249
- error = obs.error_message
250
 
251
  rewards.append(reward)
252
- steps_taken = step
253
-
254
- action_str = action.action_type
255
- if action.action_type == "propose_slot":
256
- action_str = f"propose_slot({action.proposed_start},{action.proposed_duration}m)"
257
- elif action.action_type == "reschedule_meeting":
258
- action_str = f"reschedule({action.meeting_id_to_move}->{action.new_start_time})"
259
 
260
- log_step(step=step, action=action_str, reward=reward, done=done, error=error)
261
-
262
- if done:
263
- break
264
-
265
- # Score is the final reward (0.0-1.0 from calculate_final_reward)
266
  score = rewards[-1] if rewards else 0.0
267
- score = min(max(score, 0.0), 1.0)
268
- success = obs.success if hasattr(obs, "success") else (score > 0.0)
 
269
 
270
  except Exception as exc:
271
- print(f"[DEBUG] Task {task_id} error: {exc}", flush=True)
272
 
273
  finally:
274
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
275
 
276
 
277
- async def main() -> None:
278
- llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
279
 
280
- env = await SchedulingEnv.from_env(ENV_REPO_ID)
 
281
 
282
- try:
283
- for task_id in TASKS:
284
- await run_task(env, llm_client, task_id)
285
- finally:
286
- try:
287
- await env.close()
288
- except Exception as e:
289
- print(f"[DEBUG] env.close() error: {e}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
 
292
  if __name__ == "__main__":
293
- asyncio.run(main())
 
1
  """
2
+ inference.py - Meeting Scheduling OpenEnv Agent
3
+
4
+ Runs an LLM agent through all 3 scheduling tasks and emits structured stdout logs.
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ Required environment variables:
7
+ API_BASE_URL LLM API endpoint (OpenAI-compatible)
8
+ MODEL_NAME Model identifier
9
+ HF_TOKEN HuggingFace / API key
10
+
11
+ Stdout format (must not deviate):
12
+ [START] task=<task> env=<benchmark> model=<model>
13
+ [STEP] step=<n> action=<action> reward=<0.00> done=<true|false> error=<msg|null>
14
+ [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
15
+ """
16
+ import argparse
17
  import json
18
  import os
19
+ import sys
20
  import textwrap
21
+ from typing import Any, Dict, List, Optional
22
 
23
  from openai import OpenAI
24
 
25
+ # -- Config -------------------------------------------------------------------
 
26
 
 
 
 
 
27
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
28
+ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
29
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
30
+ ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
31
 
32
+ BENCHMARK = "scheduling_env"
33
+ MAX_STEPS = 20
34
+ TEMPERATURE = 0.3
 
 
 
35
 
36
+ TASK_IDS = ["task1_easy", "task2_medium", "task3_hard"]
 
 
 
 
 
37
 
38
+ # -- System prompt ------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  SYSTEM_PROMPT = textwrap.dedent("""\
41
  You are an AI meeting scheduling assistant. You must schedule a meeting by choosing actions.
 
64
  """)
65
 
66
 
67
+ # -- Logging helpers (judge-parsed format) ------------------------------------
68
+
69
+ def log_start(task: str, env: str, model: str) -> None:
70
+ print(f"[START] task={task} env={env} model={model}", flush=True)
71
+
72
+
73
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str] = None) -> None:
74
+ error_val = error if error else "null"
75
+ done_val = str(done).lower()
76
+ print(
77
+ f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
78
+ flush=True,
79
+ )
80
+
81
+
82
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
83
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
84
+ print(
85
+ f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
86
+ flush=True,
87
+ )
88
+
89
+
90
+ # -- Observation formatting ---------------------------------------------------
91
+
92
+ def format_observation(obs: Dict[str, Any], step: int) -> str:
93
+ """Convert observation dict into a user prompt for the LLM."""
94
+ max_steps = obs.get("max_steps", MAX_STEPS)
95
  parts = [
96
+ f"Step {step}/{max_steps}",
97
+ f"Meeting to schedule: {obs.get('requested_duration', '?')} min, priority {obs.get('requested_priority', '?')}",
98
+ f"Attendees: {', '.join(obs.get('attendee_ids', []))}",
 
99
  ]
100
 
101
+ work_hours = obs.get("collective_work_hours", {})
102
+ parts.append(f"Collective working hours: {work_hours.get('min_start_hour', 9)}:00 - {work_hours.get('max_end_hour', 17)}:00")
103
+
104
+ prefs = obs.get("preference_constraints", {})
105
+ if prefs:
106
+ parts.append(
107
+ f"Preferences: max {prefs.get('max_meetings_per_day', 'N/A')} meetings/day, "
108
+ f"buffer required: {prefs.get('requires_buffer', False)}, "
109
+ f"buffer mins: {prefs.get('buffer_minutes', 0)}"
110
+ )
111
 
112
  # Busy slots grouped by attendee
113
  busy_by_attendee: Dict[str, List] = {}
114
+ for slot in obs.get("busy_slots", []):
115
+ att = slot.get("attendee", "unknown")
116
  busy_by_attendee.setdefault(att, []).append(slot)
117
 
118
  parts.append("\nCalendars:")
119
+ for att in obs.get("attendee_ids", []):
120
  slots = busy_by_attendee.get(att, [])
121
  if slots:
122
  slot_strs = [
 
128
  else:
129
  parts.append(f" {att}: (no meetings)")
130
 
131
+ proposal = obs.get("current_proposal")
132
+ if proposal:
133
+ parts.append(f"\nCurrent proposal: {proposal['start']} to {proposal['end']}")
134
 
135
+ conflicts = obs.get("conflicts", [])
136
+ if conflicts:
137
+ parts.append(f"\nConflicts ({len(conflicts)}):")
138
+ for c in conflicts:
139
  parts.append(
140
  f" - {c['attendee']}: {c['start']} to {c['end']} "
141
  f"(priority {c['priority']}, {c['summary']}, id: {c['meeting_id']})"
142
  )
143
 
144
+ error_msg = obs.get("error_message")
145
+ if error_msg:
146
+ parts.append(f"\nLast error: {error_msg}")
147
 
148
+ parts.append(f"\nRescheduled so far: {obs.get('num_rescheduled', 0)}")
149
+ parts.append(f"Preference penalty: {obs.get('preference_penalty', 0.0)}")
150
 
151
+ if not proposal and not conflicts:
152
  parts.append("\nAction needed: propose a time slot for the meeting.")
153
+ elif conflicts:
154
  parts.append("\nAction needed: reschedule a conflict (lower-priority only) or propose a different slot.")
155
  else:
156
  parts.append("\nAction needed: no conflicts remain - you should finalize.")
 
158
  return "\n".join(parts)
159
 
160
 
161
+ # -- LLM call -----------------------------------------------------------------
162
+
163
+ def call_llm(client: OpenAI, obs: Dict[str, Any], step: int) -> Dict[str, Any]:
164
+ """Ask the LLM for the next action given the current observation."""
165
+ user_prompt = format_observation(obs, step)
166
+
167
+ try:
168
+ completion = client.chat.completions.create(
169
+ model=MODEL_NAME,
170
+ messages=[
171
+ {"role": "system", "content": SYSTEM_PROMPT},
172
+ {"role": "user", "content": user_prompt},
173
+ ],
174
+ temperature=TEMPERATURE,
175
+ max_tokens=512,
176
+ )
177
+ text = (completion.choices[0].message.content or "").strip()
178
+ return parse_llm_response(text, obs)
179
+ except Exception as exc:
180
+ print(f"[DEBUG] LLM error: {exc}", file=sys.stderr, flush=True)
181
+ return fallback_action(obs)
182
+
183
+
184
+ def parse_llm_response(text: str, obs: Dict[str, Any]) -> Dict[str, Any]:
185
+ """Parse LLM JSON response into an action dict, with fallback."""
186
  cleaned = text.strip()
187
+
188
+ # Handle markdown code blocks
189
  if "```" in cleaned:
 
190
  lines = cleaned.split("\n")
191
  json_lines = []
192
  in_block = False
 
198
  json_lines.append(line)
199
  cleaned = "\n".join(json_lines).strip()
200
 
201
+ # Extract JSON object
202
  start = cleaned.find("{")
203
  end = cleaned.rfind("}") + 1
204
  if start >= 0 and end > start:
 
206
 
207
  try:
208
  data = json.loads(cleaned)
209
+ if "action_type" not in data:
210
+ raise ValueError("No action_type in response")
211
+ return data
212
+ except (json.JSONDecodeError, ValueError) as e:
213
+ print(f"[DEBUG] Parse error: {e}. Response: {text[:200]}", file=sys.stderr, flush=True)
214
+ return fallback_action(obs)
215
+
216
+
217
+ def fallback_action(obs: Dict[str, Any]) -> Dict[str, Any]:
218
+ """Produce a safe fallback action based on current observation state."""
219
+ if obs.get("current_proposal") is None:
220
+ min_h = obs.get("collective_work_hours", {}).get("min_start_hour", 9)
221
+ duration = obs.get("requested_duration", 30)
222
+ return {
223
+ "action_type": "propose_slot",
224
+ "proposed_start": f"2025-04-07T{min_h:02d}:00:00+00:00",
225
+ "proposed_duration": duration,
226
+ }
227
+ elif not obs.get("conflicts"):
228
+ return {"action_type": "finalize"}
229
+ else:
230
+ return {"action_type": "reject"}
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
+ # -- Episode runner -----------------------------------------------------------
234
 
235
+ def run_episode(client: OpenAI, task_id: str) -> None:
236
+ """Run one full episode for a task, emitting [START]/[STEP]/[END] logs."""
237
+ import requests
238
 
 
 
239
  rewards: List[float] = []
240
  steps_taken = 0
241
  score = 0.0
 
244
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
245
 
246
  try:
247
+ # Reset environment
248
+ try:
249
+ resp = requests.post(
250
+ f"{ENV_URL}/reset",
251
+ json={"task_id": task_id},
252
+ timeout=30,
253
+ )
254
+ resp.raise_for_status()
255
+ reset_data = resp.json()
256
+ except Exception as e:
257
+ print(f"[DEBUG] Reset failed: {e}", file=sys.stderr, flush=True)
258
+ log_end(success=False, steps=0, score=0.0, rewards=[])
259
+ return
260
+
261
+ observation = reset_data.get("observation", reset_data)
262
+ done = reset_data.get("done", False)
263
+
264
+ # Episode loop
265
+ while not done and steps_taken < MAX_STEPS:
266
+ steps_taken += 1
267
+
268
+ # Get action from LLM
269
+ action = call_llm(client, observation, steps_taken)
270
+ action_type = action.get("action_type", "unknown")
271
+
272
+ # Build compact action string for logging
273
+ if action_type == "propose_slot":
274
+ action_str = f"propose_slot({action.get('proposed_start', '?')[:16]},{action.get('proposed_duration', '?')}m)"
275
+ elif action_type == "reschedule_meeting":
276
+ action_str = f"reschedule({action.get('meeting_id_to_move', '?')[:20]})"
277
+ else:
278
+ action_str = action_type
279
+
280
+ # Execute step
281
+ try:
282
+ step_resp = requests.post(
283
+ f"{ENV_URL}/step",
284
+ json={"action": action},
285
+ timeout=30,
286
+ )
287
+ step_resp.raise_for_status()
288
+ step_data = step_resp.json()
289
+ except Exception as e:
290
+ print(f"[DEBUG] Step failed: {e}", file=sys.stderr, flush=True)
291
+ rewards.append(0.0)
292
+ log_step(step=steps_taken, action=action_str, reward=0.0, done=True, error=str(e))
293
  break
294
 
295
+ observation = step_data.get("observation", {})
296
+ reward = step_data.get("reward", 0.0) or 0.0
297
+ done = step_data.get("done", False)
298
+ error = observation.get("error_message")
 
 
 
 
299
 
300
  rewards.append(reward)
301
+ log_step(step=steps_taken, action=action_str, reward=reward, done=done, error=error)
 
 
 
 
 
 
302
 
303
+ # Final score is the last reward (0.0-1.0 from calculate_final_reward)
 
 
 
 
 
304
  score = rewards[-1] if rewards else 0.0
305
+ # Clamp to (0.01, 0.99) as required by judge
306
+ score = max(0.01, min(score, 0.99))
307
+ success = score > 0.3
308
 
309
  except Exception as exc:
310
+ print(f"[DEBUG] Episode error: {exc}", file=sys.stderr, flush=True)
311
 
312
  finally:
313
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
314
 
315
 
316
+ # -- Main ---------------------------------------------------------------------
 
317
 
318
+ def main():
319
+ global ENV_URL
320
 
321
+ parser = argparse.ArgumentParser(description="Scheduling env baseline inference")
322
+ parser.add_argument("--task", choices=TASK_IDS, help="Run a specific task only")
323
+ parser.add_argument("--all", action="store_true", help="Run all 3 tasks (default)")
324
+ parser.add_argument("--url", default=ENV_URL, help="Environment base URL")
325
+ args = parser.parse_args()
326
+
327
+ ENV_URL = args.url
328
+
329
+ # Check for TASK_NAME environment variable (judge may set this)
330
+ target_task = os.getenv("TASK_NAME")
331
+ if target_task:
332
+ if "task1" in target_task or "easy" in target_task:
333
+ args.task = "task1_easy"
334
+ elif "task2" in target_task or "medium" in target_task:
335
+ args.task = "task2_medium"
336
+ elif "task3" in target_task or "hard" in target_task:
337
+ args.task = "task3_hard"
338
+
339
+ if not HF_TOKEN:
340
+ print("[ERROR] HF_TOKEN environment variable not set", file=sys.stderr)
341
+ sys.exit(1)
342
+
343
+ client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
344
+ tasks = [args.task] if args.task else TASK_IDS
345
+
346
+ for task_id in tasks:
347
+ run_episode(client, task_id)
348
 
349
 
350
  if __name__ == "__main__":
351
+ main()
pyproject.toml CHANGED
@@ -16,8 +16,14 @@ requires-python = ">=3.10"
16
  dependencies = [
17
  # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
  "openenv-core[core]>=0.2.2",
 
 
19
  # OpenAI client for LLM-based inference
20
  "openai>=1.0.0",
 
 
 
 
21
  ]
22
 
23
  [project.optional-dependencies]
 
16
  dependencies = [
17
  # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
  "openenv-core[core]>=0.2.2",
19
+ # Pydantic for data models
20
+ "pydantic>=2.5.0,<3.0.0",
21
  # OpenAI client for LLM-based inference
22
  "openai>=1.0.0",
23
+ # HTTP requests for inference script
24
+ "requests>=2.31.0",
25
+ # ASGI server
26
+ "uvicorn>=0.44.0",
27
  ]
28
 
29
  [project.optional-dependencies]
uv.lock CHANGED
@@ -1605,6 +1605,9 @@ source = { editable = "." }
1605
  dependencies = [
1606
  { name = "openai" },
1607
  { name = "openenv-core", extra = ["core"] },
 
 
 
1608
  ]
1609
 
1610
  [package.optional-dependencies]
@@ -1617,8 +1620,11 @@ dev = [
1617
  requires-dist = [
1618
  { name = "openai", specifier = ">=1.0.0" },
1619
  { name = "openenv-core", extras = ["core"], specifier = ">=0.2.2" },
 
1620
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
1621
  { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0.0" },
 
 
1622
  ]
1623
  provides-extras = ["dev"]
1624
 
 
1605
  dependencies = [
1606
  { name = "openai" },
1607
  { name = "openenv-core", extra = ["core"] },
1608
+ { name = "pydantic" },
1609
+ { name = "requests" },
1610
+ { name = "uvicorn" },
1611
  ]
1612
 
1613
  [package.optional-dependencies]
 
1620
  requires-dist = [
1621
  { name = "openai", specifier = ">=1.0.0" },
1622
  { name = "openenv-core", extras = ["core"], specifier = ">=0.2.2" },
1623
+ { name = "pydantic", specifier = ">=2.5.0,<3.0.0" },
1624
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
1625
  { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0.0" },
1626
+ { name = "requests", specifier = ">=2.31.0" },
1627
+ { name = "uvicorn", specifier = ">=0.44.0" },
1628
  ]
1629
  provides-extras = ["dev"]
1630