Spaces:
Running
Running
Commit ·
9fd03cb
1
Parent(s): 427e52b
Fix inference.py: handle missing API key gracefully, wrap all exceptions
Browse files- python/inference.py +65 -21
python/inference.py
CHANGED
|
@@ -56,7 +56,7 @@ API_BASE_URL = os.getenv("API_BASE_URL", "<your-active-endpoint>")
|
|
| 56 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 57 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or HF_TOKEN
|
| 58 |
if not OPENAI_API_KEY:
|
| 59 |
-
|
| 60 |
DEFAULT_EPISODES = 1
|
| 61 |
DEFAULT_SEED_BASE = 1000
|
| 62 |
MAX_RETRIES = 3
|
|
@@ -121,26 +121,42 @@ class GridMindEnvClient:
|
|
| 121 |
except Exception:
|
| 122 |
return False
|
| 123 |
|
| 124 |
-
def reset(self, task_id: int = 1, seed: int = 42, num_buildings: int = 1) -> dict:
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
def step(self, action: dict) -> dict:
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
def grade(self) -> dict:
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
-
def state(self) -> dict:
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
|
| 146 |
# ── LLM agent ───────────────────────────────────────────────────────────────
|
|
@@ -296,7 +312,19 @@ def run_episode(
|
|
| 296 |
[END] success=<true|false> steps=<n> rewards=<r1,r2,...>
|
| 297 |
"""
|
| 298 |
reset_resp = env_client.reset(task_id=task_id, seed=seed)
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
|
| 301 |
task_name = f"gridmind-task-{task_id}"
|
| 302 |
|
|
@@ -329,8 +357,13 @@ def run_episode(
|
|
| 329 |
action = cached_action
|
| 330 |
|
| 331 |
step_resp = env_client.step(action)
|
| 332 |
-
if step_resp is None or "observation" not in step_resp:
|
| 333 |
-
last_error = "invalid step response"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
break
|
| 335 |
|
| 336 |
if not fast_mode:
|
|
@@ -424,6 +457,11 @@ def main() -> None:
|
|
| 424 |
help="Stop after N steps (default: full episode). Grade uses partial episode.",
|
| 425 |
)
|
| 426 |
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
|
| 428 |
print("=" * 60)
|
| 429 |
print("GridMind-RL Baseline Inference")
|
|
@@ -510,4 +548,10 @@ def main() -> None:
|
|
| 510 |
|
| 511 |
|
| 512 |
if __name__ == "__main__":
|
| 513 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 57 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or HF_TOKEN
|
| 58 |
if not OPENAI_API_KEY:
|
| 59 |
+
print("[WARN] No HF_TOKEN or OPENAI_API_KEY set - will use heuristic mode if --fast-mode is set")
|
| 60 |
DEFAULT_EPISODES = 1
|
| 61 |
DEFAULT_SEED_BASE = 1000
|
| 62 |
MAX_RETRIES = 3
|
|
|
|
| 121 |
except Exception:
|
| 122 |
return False
|
| 123 |
|
| 124 |
+
def reset(self, task_id: int = 1, seed: int = 42, num_buildings: int = 1) -> dict | None:
|
| 125 |
+
try:
|
| 126 |
+
payload = {"task_id": task_id, "seed": seed, "num_buildings": num_buildings}
|
| 127 |
+
r = requests.post(f"{self.base}/reset", json=payload, timeout=self.timeout)
|
| 128 |
+
r.raise_for_status()
|
| 129 |
+
return r.json()
|
| 130 |
+
except Exception as e:
|
| 131 |
+
print(f"[ERROR] Failed to reset environment: {e}", file=sys.stderr)
|
| 132 |
+
return None
|
| 133 |
|
| 134 |
+
def step(self, action: dict) -> dict | None:
|
| 135 |
+
try:
|
| 136 |
+
r = requests.post(f"{self.base}/step", json=action, timeout=self.timeout)
|
| 137 |
+
r.raise_for_status()
|
| 138 |
+
return r.json()
|
| 139 |
+
except Exception as e:
|
| 140 |
+
print(f"[ERROR] Failed to step environment: {e}", file=sys.stderr)
|
| 141 |
+
return None
|
| 142 |
|
| 143 |
def grade(self) -> dict:
|
| 144 |
+
try:
|
| 145 |
+
r = requests.get(f"{self.base}/grade", timeout=self.timeout)
|
| 146 |
+
r.raise_for_status()
|
| 147 |
+
return r.json()
|
| 148 |
+
except Exception as e:
|
| 149 |
+
print(f"[ERROR] Failed to grade: {e}", file=sys.stderr)
|
| 150 |
+
return {"score": 0.0, "sub_scores": {}, "exploit_detected": False}
|
| 151 |
|
| 152 |
+
def state(self) -> dict | None:
|
| 153 |
+
try:
|
| 154 |
+
r = requests.get(f"{self.base}/state", timeout=self.timeout)
|
| 155 |
+
r.raise_for_status()
|
| 156 |
+
return r.json()
|
| 157 |
+
except Exception as e:
|
| 158 |
+
print(f"[ERROR] Failed to get state: {e}", file=sys.stderr)
|
| 159 |
+
return None
|
| 160 |
|
| 161 |
|
| 162 |
# ── LLM agent ───────────────────────────────────────────────────────────────
|
|
|
|
| 312 |
[END] success=<true|false> steps=<n> rewards=<r1,r2,...>
|
| 313 |
"""
|
| 314 |
reset_resp = env_client.reset(task_id=task_id, seed=seed)
|
| 315 |
+
if reset_resp is None:
|
| 316 |
+
print(f"[END] success=false steps=0 rewards=", flush=True)
|
| 317 |
+
return {
|
| 318 |
+
"task_id": task_id,
|
| 319 |
+
"seed": seed,
|
| 320 |
+
"total_reward": 0.0,
|
| 321 |
+
"total_steps": 0,
|
| 322 |
+
"elapsed_sec": 0.0,
|
| 323 |
+
"score": 0.0,
|
| 324 |
+
"sub_scores": {},
|
| 325 |
+
"exploit_detected": False,
|
| 326 |
+
}
|
| 327 |
+
obs = reset_resp.get("observations", [{}])[0]
|
| 328 |
|
| 329 |
task_name = f"gridmind-task-{task_id}"
|
| 330 |
|
|
|
|
| 357 |
action = cached_action
|
| 358 |
|
| 359 |
step_resp = env_client.step(action)
|
| 360 |
+
if step_resp is None or not isinstance(step_resp, dict) or "observation" not in step_resp:
|
| 361 |
+
last_error = "invalid step response from environment"
|
| 362 |
+
print(
|
| 363 |
+
f"[STEP] step={total_steps + 1} action=null "
|
| 364 |
+
f"reward=0.00 done=true error=\"{last_error}\"",
|
| 365 |
+
flush=True
|
| 366 |
+
)
|
| 367 |
break
|
| 368 |
|
| 369 |
if not fast_mode:
|
|
|
|
| 457 |
help="Stop after N steps (default: full episode). Grade uses partial episode.",
|
| 458 |
)
|
| 459 |
args = parser.parse_args()
|
| 460 |
+
|
| 461 |
+
# Validate API key AFTER argparse (allows --fast-mode to bypass)
|
| 462 |
+
if not OPENAI_API_KEY and not args.fast_mode:
|
| 463 |
+
print("[WARN] No API key set, switching to fast-mode (heuristic)", file=sys.stderr)
|
| 464 |
+
args.fast_mode = True
|
| 465 |
|
| 466 |
print("=" * 60)
|
| 467 |
print("GridMind-RL Baseline Inference")
|
|
|
|
| 548 |
|
| 549 |
|
| 550 |
if __name__ == "__main__":
|
| 551 |
+
try:
|
| 552 |
+
main()
|
| 553 |
+
except Exception as e:
|
| 554 |
+
print(f"[FATAL] Unhandled exception: {e}", file=sys.stderr)
|
| 555 |
+
import traceback
|
| 556 |
+
traceback.print_exc(file=sys.stderr)
|
| 557 |
+
sys.exit(1)
|