ShreeshantXD committed on
Commit
9fd03cb
·
1 Parent(s): 427e52b

Fix inference.py: handle missing API key gracefully, wrap all exceptions

Browse files
Files changed (1) hide show
  1. python/inference.py +65 -21
python/inference.py CHANGED
@@ -56,7 +56,7 @@ API_BASE_URL = os.getenv("API_BASE_URL", "<your-active-endpoint>")
56
  HF_TOKEN = os.getenv("HF_TOKEN")
57
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or HF_TOKEN
58
  if not OPENAI_API_KEY:
59
- raise ValueError("HF_TOKEN or OPENAI_API_KEY environment variable is required")
60
  DEFAULT_EPISODES = 1
61
  DEFAULT_SEED_BASE = 1000
62
  MAX_RETRIES = 3
@@ -121,26 +121,42 @@ class GridMindEnvClient:
121
  except Exception:
122
  return False
123
 
124
- def reset(self, task_id: int = 1, seed: int = 42, num_buildings: int = 1) -> dict:
125
- payload = {"task_id": task_id, "seed": seed, "num_buildings": num_buildings}
126
- r = requests.post(f"{self.base}/reset", json=payload, timeout=self.timeout)
127
- r.raise_for_status()
128
- return r.json()
 
 
 
 
129
 
130
- def step(self, action: dict) -> dict:
131
- r = requests.post(f"{self.base}/step", json=action, timeout=self.timeout)
132
- r.raise_for_status()
133
- return r.json()
 
 
 
 
134
 
135
  def grade(self) -> dict:
136
- r = requests.get(f"{self.base}/grade", timeout=self.timeout)
137
- r.raise_for_status()
138
- return r.json()
 
 
 
 
139
 
140
- def state(self) -> dict:
141
- r = requests.get(f"{self.base}/state", timeout=self.timeout)
142
- r.raise_for_status()
143
- return r.json()
 
 
 
 
144
 
145
 
146
  # ── LLM agent ───────────────────────────────────────────────────────────────
@@ -296,7 +312,19 @@ def run_episode(
296
  [END] success=<true|false> steps=<n> rewards=<r1,r2,...>
297
  """
298
  reset_resp = env_client.reset(task_id=task_id, seed=seed)
299
- obs = reset_resp["observations"][0]
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
  task_name = f"gridmind-task-{task_id}"
302
 
@@ -329,8 +357,13 @@ def run_episode(
329
  action = cached_action
330
 
331
  step_resp = env_client.step(action)
332
- if step_resp is None or "observation" not in step_resp:
333
- last_error = "invalid step response"
 
 
 
 
 
334
  break
335
 
336
  if not fast_mode:
@@ -424,6 +457,11 @@ def main() -> None:
424
  help="Stop after N steps (default: full episode). Grade uses partial episode.",
425
  )
426
  args = parser.parse_args()
 
 
 
 
 
427
 
428
  print("=" * 60)
429
  print("GridMind-RL Baseline Inference")
@@ -510,4 +548,10 @@ def main() -> None:
510
 
511
 
512
  if __name__ == "__main__":
513
- main()
 
 
 
 
 
 
 
# Credentials for the LLM endpoint: OPENAI_API_KEY takes precedence and
# HF_TOKEN is used as the fallback when it is unset or empty.
HF_TOKEN = os.getenv("HF_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or HF_TOKEN
if not OPENAI_API_KEY:
    # Warn instead of raising so the module still loads without a key.
    # The notice goes to stderr, like the script's other diagnostics, so that
    # stdout stays reserved for the structured [STEP]/[END] lines.
    print(
        "[WARN] No HF_TOKEN or OPENAI_API_KEY set - LLM calls are unavailable; "
        "main() will fall back to heuristic fast-mode",
        file=sys.stderr,
    )

# Default run parameters (their consumers are outside this excerpt).
DEFAULT_EPISODES = 1
DEFAULT_SEED_BASE = 1000
MAX_RETRIES = 3
 
121
  except Exception:
122
  return False
123
 
def reset(self, task_id: int = 1, seed: int = 42, num_buildings: int = 1) -> dict | None:
    """Start a new episode by POSTing to the environment's ``/reset`` endpoint.

    Args:
        task_id: Sent as ``task_id`` in the JSON payload.
        seed: Sent as ``seed`` in the JSON payload.
        num_buildings: Sent as ``num_buildings`` in the JSON payload.

    Returns:
        The decoded JSON object on success. ``None`` when the request fails,
        the server answers with an HTTP error status, the body is not valid
        JSON, or the body is not a JSON object. Failures are reported on
        stderr and never raised.
    """
    try:
        payload = {"task_id": task_id, "seed": seed, "num_buildings": num_buildings}
        r = requests.post(f"{self.base}/reset", json=payload, timeout=self.timeout)
        r.raise_for_status()
        data = r.json()
    except Exception as e:
        print(f"[ERROR] Failed to reset environment: {e}", file=sys.stderr)
        return None

    # Callers treat the result as a mapping (e.g. ``.get("observations")``),
    # so a JSON ``null``/list/scalar body is a failure, not a valid reset.
    if not isinstance(data, dict):
        print(
            "[ERROR] Failed to reset environment: "
            f"expected a JSON object, got {type(data).__name__}",
            file=sys.stderr,
        )
        return None
    return data
133
 
def step(self, action: dict) -> dict | None:
    """POST ``action`` as JSON to the environment's ``/step`` endpoint.

    Returns:
        Whatever the response body decodes to, or ``None`` if anything in the
        request/decoding sequence raises (the error is printed to stderr).
    """
    try:
        response = requests.post(
            f"{self.base}/step",
            json=action,
            timeout=self.timeout,
        )
        response.raise_for_status()
        result = response.json()
    except Exception as exc:
        # Best-effort by design: report and let the caller see ``None``.
        print(f"[ERROR] Failed to step environment: {exc}", file=sys.stderr)
        result = None
    return result
142
 
def grade(self) -> dict:
    """Fetch the episode grade from the environment's ``/grade`` endpoint.

    Returns:
        The decoded JSON object on success. On any failure (request error,
        HTTP error status, invalid JSON, or a body that is not a JSON
        object) a zero-score fallback dict with the keys ``score``,
        ``sub_scores`` and ``exploit_detected`` is returned, so the result is
        always a dict. Failures are reported on stderr and never raised.
    """
    fallback = {"score": 0.0, "sub_scores": {}, "exploit_detected": False}
    try:
        r = requests.get(f"{self.base}/grade", timeout=self.timeout)
        r.raise_for_status()
        data = r.json()
    except Exception as e:
        print(f"[ERROR] Failed to grade: {e}", file=sys.stderr)
        return fallback

    # Honour the ``-> dict`` contract: a JSON ``null``/list/scalar body is
    # treated like any other grading failure.
    if not isinstance(data, dict):
        print(
            f"[ERROR] Failed to grade: expected a JSON object, got {type(data).__name__}",
            file=sys.stderr,
        )
        return fallback
    return data
151
 
def state(self) -> dict | None:
    """GET the environment's ``/state`` endpoint and return the decoded body.

    Returns:
        The decoded JSON response, or ``None`` if the request, the status
        check or the JSON decoding raises (the error is printed to stderr).
    """
    try:
        response = requests.get(f"{self.base}/state", timeout=self.timeout)
        response.raise_for_status()
        snapshot = response.json()
    except Exception as exc:
        # Swallow deliberately: callers get ``None`` instead of an exception.
        print(f"[ERROR] Failed to get state: {exc}", file=sys.stderr)
        return None
    else:
        return snapshot
160
 
161
 
162
  # ── LLM agent ───────────────────────────────────────────────────────────────
 
312
  [END] success=<true|false> steps=<n> rewards=<r1,r2,...>
313
  """
314
  reset_resp = env_client.reset(task_id=task_id, seed=seed)
315
+ if reset_resp is None:
316
+ print(f"[END] success=false steps=0 rewards=", flush=True)
317
+ return {
318
+ "task_id": task_id,
319
+ "seed": seed,
320
+ "total_reward": 0.0,
321
+ "total_steps": 0,
322
+ "elapsed_sec": 0.0,
323
+ "score": 0.0,
324
+ "sub_scores": {},
325
+ "exploit_detected": False,
326
+ }
327
+ obs = reset_resp.get("observations", [{}])[0]
328
 
329
  task_name = f"gridmind-task-{task_id}"
330
 
 
357
  action = cached_action
358
 
359
  step_resp = env_client.step(action)
360
+ if step_resp is None or not isinstance(step_resp, dict) or "observation" not in step_resp:
361
+ last_error = "invalid step response from environment"
362
+ print(
363
+ f"[STEP] step={total_steps + 1} action=null "
364
+ f"reward=0.00 done=true error=\"{last_error}\"",
365
+ flush=True
366
+ )
367
  break
368
 
369
  if not fast_mode:
 
457
  help="Stop after N steps (default: full episode). Grade uses partial episode.",
458
  )
459
  args = parser.parse_args()
460
+
461
+ # Validate API key AFTER argparse (allows --fast-mode to bypass)
462
+ if not OPENAI_API_KEY and not args.fast_mode:
463
+ print("[WARN] No API key set, switching to fast-mode (heuristic)", file=sys.stderr)
464
+ args.fast_mode = True
465
 
466
  print("=" * 60)
467
  print("GridMind-RL Baseline Inference")
 
548
 
549
 
if __name__ == "__main__":
    # Top-level boundary: any exception escaping main() is reported on stderr
    # (one-line summary, then the full traceback) and the process exits with
    # status 1.
    try:
        main()
    except Exception as exc:
        print(f"[FATAL] Unhandled exception: {exc}", file=sys.stderr)
        import traceback

        traceback.print_exc(file=sys.stderr)
        raise SystemExit(1)