refactor: update default model and API endpoint, enhance error handling, and add close method for compatibility
Browse files- inference.py +60 -95
inference.py
CHANGED
|
@@ -45,8 +45,8 @@ except ImportError:
|
|
| 45 |
# ── Constants ──────────────────────────────────────────────────────────────
|
| 46 |
|
| 47 |
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
|
| 48 |
-
MODEL_NAME = os.getenv("MODEL_NAME", "
|
| 49 |
-
API_BASE_URL = os.getenv("API_BASE_URL", "
|
| 50 |
|
| 51 |
# ── Environment Variable Handling ─────────────────────────────────────────
|
| 52 |
# The LLM API credential is read from HF_TOKEN or OPENAI_API_KEY environment variables
|
|
@@ -54,9 +54,7 @@ API_BASE_URL = os.getenv("API_BASE_URL", "<your-active-endpoint>")
|
|
| 54 |
# Primary: HF_TOKEN
|
| 55 |
# Fallback: OPENAI_API_KEY (for local testing/development)
|
| 56 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 57 |
-
OPENAI_API_KEY =
|
| 58 |
-
if not OPENAI_API_KEY:
|
| 59 |
-
print("[WARN] No HF_TOKEN or OPENAI_API_KEY set - will use heuristic mode if --fast-mode is set")
|
| 60 |
DEFAULT_EPISODES = 1
|
| 61 |
DEFAULT_SEED_BASE = 1000
|
| 62 |
MAX_RETRIES = 3
|
|
@@ -158,6 +156,10 @@ class GridMindEnvClient:
|
|
| 158 |
print(f"[ERROR] Failed to get state: {e}", file=sys.stderr)
|
| 159 |
return None
|
| 160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
# ── LLM agent ───────────────────────────────────────────────────────────────
|
| 163 |
|
|
@@ -168,10 +170,7 @@ class LLMAgent:
|
|
| 168 |
def __init__(self):
|
| 169 |
# Initialize OpenAI client with credentials from HF_TOKEN (per hackathon spec)
|
| 170 |
# The OPENAI_API_KEY variable contains the HF_TOKEN value passed by evaluators
|
| 171 |
-
self.client = OpenAI(
|
| 172 |
-
base_url=API_BASE_URL,
|
| 173 |
-
api_key=OPENAI_API_KEY or "dummy-key-for-fast-mode",
|
| 174 |
-
)
|
| 175 |
self.model = MODEL_NAME
|
| 176 |
self.fallback_mode = False
|
| 177 |
|
|
@@ -223,9 +222,12 @@ Respond with ONLY a JSON action:
|
|
| 223 |
return self._clamp_action(action)
|
| 224 |
except Exception as e:
|
| 225 |
err_str = str(e)
|
| 226 |
-
print(f" [LLM attempt {attempt+1}/{MAX_RETRIES}] error: {err_str}")
|
| 227 |
if "402" in err_str or "depleted" in err_str:
|
| 228 |
-
print(
|
|
|
|
|
|
|
|
|
|
| 229 |
self.fallback_mode = True
|
| 230 |
return self._heuristic_action(obs)
|
| 231 |
time.sleep(1)
|
|
@@ -311,44 +313,33 @@ def run_episode(
|
|
| 311 |
...
|
| 312 |
[END] success=<true|false> steps=<n> rewards=<r1,r2,...>
|
| 313 |
"""
|
| 314 |
-
reset_resp = env_client.reset(task_id=task_id, seed=seed)
|
| 315 |
-
if reset_resp is None:
|
| 316 |
-
print(f"[END] success=false steps=0 rewards=", flush=True)
|
| 317 |
-
return {
|
| 318 |
-
"task_id": task_id,
|
| 319 |
-
"seed": seed,
|
| 320 |
-
"total_reward": 0.0,
|
| 321 |
-
"total_steps": 0,
|
| 322 |
-
"elapsed_sec": 0.0,
|
| 323 |
-
"score": 0.0,
|
| 324 |
-
"sub_scores": {},
|
| 325 |
-
"exploit_detected": False,
|
| 326 |
-
}
|
| 327 |
-
obs_list = reset_resp.get("observations", [{}])
|
| 328 |
-
obs = obs_list[0] if obs_list else {}
|
| 329 |
-
|
| 330 |
task_name = f"gridmind-task-{task_id}"
|
| 331 |
-
|
| 332 |
-
# Emit [START] with required fields
|
| 333 |
print(f"[START] task={task_name} env=gridmind model={MODEL_NAME}", flush=True)
|
| 334 |
-
|
| 335 |
total_reward = 0.0
|
| 336 |
total_steps = 0
|
| 337 |
start_time = time.time()
|
| 338 |
-
step_resp: dict[str, Any] = {}
|
| 339 |
step_limit = EPISODE_STEPS if max_steps is None else min(max_steps, EPISODE_STEPS)
|
| 340 |
-
|
| 341 |
llm_reuse_remaining = 0
|
| 342 |
cached_action = agent._default_action()
|
| 343 |
-
|
| 344 |
step_rewards: list[float] = []
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
if fast_mode:
|
| 353 |
action = agent._heuristic_action(obs)
|
| 354 |
else:
|
|
@@ -359,35 +350,32 @@ def run_episode(
|
|
| 359 |
|
| 360 |
step_resp = env_client.step(action)
|
| 361 |
if step_resp is None or not isinstance(step_resp, dict) or "observation" not in step_resp:
|
| 362 |
-
last_error = "invalid step response from environment"
|
| 363 |
print(
|
| 364 |
f"[STEP] step={total_steps + 1} action=null "
|
| 365 |
-
f"reward=0.00 done=true error=
|
| 366 |
flush=True
|
| 367 |
)
|
| 368 |
break
|
| 369 |
-
|
| 370 |
if not fast_mode:
|
| 371 |
llm_reuse_remaining -= 1
|
| 372 |
-
|
| 373 |
obs = step_resp["observation"]
|
| 374 |
reward = float(step_resp["reward"])
|
| 375 |
total_reward += reward
|
| 376 |
step_rewards.append(reward)
|
| 377 |
total_steps += 1
|
| 378 |
done = bool(step_resp.get("done", False))
|
| 379 |
-
|
| 380 |
-
# Emit [STEP] with required fields (action as compact JSON, reward to 2 decimals)
|
| 381 |
action_json = json.dumps(action, separators=(',', ':'))
|
| 382 |
-
|
|
|
|
| 383 |
print(
|
| 384 |
f"[STEP] step={total_steps} action={action_json} "
|
| 385 |
f"reward={reward:.2f} done={'true' if done else 'false'} error={error_field}",
|
| 386 |
flush=True
|
| 387 |
)
|
| 388 |
-
|
| 389 |
-
last_error = None # Clear error after successful step
|
| 390 |
-
|
| 391 |
if verbose and total_steps % 16 == 0:
|
| 392 |
print(
|
| 393 |
f" step={total_steps:02d} price=${obs['current_price']:.3f} "
|
|
@@ -395,22 +383,30 @@ def run_episode(
|
|
| 395 |
f"stress={obs['grid_stress_signal']:.2f} "
|
| 396 |
f"cost=${obs['cumulative_cost']:.2f}",
|
| 397 |
flush=True,
|
|
|
|
| 398 |
)
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
print(
|
| 403 |
f"[STEP] step={total_steps + 1} action=null "
|
| 404 |
-
f"reward=0.00 done=true error=
|
| 405 |
flush=True
|
| 406 |
)
|
| 407 |
-
|
| 408 |
-
|
|
|
|
| 409 |
elapsed = time.time() - start_time
|
| 410 |
grade = env_client.grade()
|
| 411 |
-
|
| 412 |
-
# Emit [END] with required fields
|
| 413 |
-
success = last_error is None and step_resp.get("done", False)
|
| 414 |
rewards_str = ",".join(f"{r:.2f}" for r in step_rewards)
|
| 415 |
print(
|
| 416 |
f"[END] success={'true' if success else 'false'} steps={total_steps} rewards={rewards_str}",
|
|
@@ -459,41 +455,27 @@ def main() -> None:
|
|
| 459 |
)
|
| 460 |
args = parser.parse_args()
|
| 461 |
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
args.fast_mode = True
|
| 466 |
-
|
| 467 |
-
print("=" * 60)
|
| 468 |
-
print("GridMind-RL Baseline Inference")
|
| 469 |
-
print(f" Model: {MODEL_NAME}")
|
| 470 |
-
print(f" API: {API_BASE_URL}")
|
| 471 |
-
print(f" Env: {args.env_url}")
|
| 472 |
-
print(f" Episodes per task: {args.episodes}")
|
| 473 |
-
print(f" Fast mode: {args.fast_mode} | LLM every: {args.llm_every} steps")
|
| 474 |
-
print("=" * 60)
|
| 475 |
|
| 476 |
env_client = GridMindEnvClient(base_url=args.env_url)
|
| 477 |
|
| 478 |
-
print("\nWaiting for environment server...")
|
| 479 |
for attempt in range(30):
|
| 480 |
if env_client.health():
|
| 481 |
-
print(" [OK] Environment server is healthy")
|
| 482 |
break
|
| 483 |
time.sleep(2)
|
| 484 |
if attempt == 29:
|
| 485 |
-
print("
|
| 486 |
sys.exit(1)
|
| 487 |
|
| 488 |
agent = LLMAgent()
|
| 489 |
all_results: list[dict[str, Any]] = []
|
| 490 |
|
| 491 |
for task_id in [1, 2, 3]:
|
| 492 |
-
print(f"\n-- Task {task_id}: {TASK_DESCRIPTIONS[task_id][:60]}...")
|
| 493 |
task_scores: list[float] = []
|
| 494 |
for ep in range(args.episodes):
|
| 495 |
seed = DEFAULT_SEED_BASE + task_id * 100 + ep
|
| 496 |
-
print(f" Episode {ep+1}/{args.episodes} (seed={seed})")
|
| 497 |
result = run_episode(
|
| 498 |
env_client,
|
| 499 |
agent,
|
|
@@ -506,30 +488,14 @@ def main() -> None:
|
|
| 506 |
)
|
| 507 |
task_scores.append(float(result["score"]))
|
| 508 |
all_results.append(result)
|
| 509 |
-
|
| 510 |
-
f" → score={result['score']:.4f} | reward={result['total_reward']:.3f} | "
|
| 511 |
-
f"{result['elapsed_sec']:.1f}s | steps={result['total_steps']}"
|
| 512 |
-
)
|
| 513 |
-
|
| 514 |
-
avg_score = sum(task_scores) / len(task_scores)
|
| 515 |
-
print(f" Task {task_id} average score: {avg_score:.4f}")
|
| 516 |
-
|
| 517 |
-
print("\n" + "=" * 60)
|
| 518 |
-
print("BASELINE SCORES SUMMARY")
|
| 519 |
-
print("=" * 60)
|
| 520 |
-
print(f"{'Task':<10} {'Model':<30} {'Score':<10} {'Episodes':<10}")
|
| 521 |
-
print("-" * 60)
|
| 522 |
|
| 523 |
task_avgs: dict[int, float] = {}
|
| 524 |
for task_id in [1, 2, 3]:
|
| 525 |
scores = [float(r["score"]) for r in all_results if r["task_id"] == task_id]
|
| 526 |
avg = sum(scores) / len(scores) if scores else 0.0
|
| 527 |
task_avgs[task_id] = avg
|
| 528 |
-
print(f"Task {task_id:<6} {MODEL_NAME:<30} {avg:<10.4f} {len(scores)}")
|
| 529 |
-
|
| 530 |
-
print("-" * 60)
|
| 531 |
overall = sum(task_avgs.values()) / len(task_avgs)
|
| 532 |
-
print(f"{'Overall':<10} {'':<30} {overall:<10.4f}")
|
| 533 |
|
| 534 |
output = {
|
| 535 |
"model": MODEL_NAME,
|
|
@@ -545,7 +511,6 @@ def main() -> None:
|
|
| 545 |
}
|
| 546 |
with open(args.output, "w", encoding="utf-8") as f:
|
| 547 |
json.dump(output, f, indent=2)
|
| 548 |
-
print(f"\n[OK] Results saved to {args.output}")
|
| 549 |
|
| 550 |
|
| 551 |
if __name__ == "__main__":
|
|
|
|
| 45 |
# ── Constants ──────────────────────────────────────────────────────────────
|
| 46 |
|
| 47 |
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
|
| 48 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
|
| 49 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://api-inference.huggingface.co/v1")
|
| 50 |
|
| 51 |
# ── Environment Variable Handling ─────────────────────────────────────────
|
| 52 |
# The LLM API credential is read from HF_TOKEN or OPENAI_API_KEY environment variables
|
|
|
|
| 54 |
# Primary: HF_TOKEN
|
| 55 |
# Fallback: OPENAI_API_KEY (for local testing/development)
|
| 56 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 57 |
+
OPENAI_API_KEY = HF_TOKEN
|
|
|
|
|
|
|
| 58 |
DEFAULT_EPISODES = 1
|
| 59 |
DEFAULT_SEED_BASE = 1000
|
| 60 |
MAX_RETRIES = 3
|
|
|
|
| 156 |
print(f"[ERROR] Failed to get state: {e}", file=sys.stderr)
|
| 157 |
return None
|
| 158 |
|
| 159 |
+
def close(self) -> None:
|
| 160 |
+
"""Compatibility close hook for episode-finalization contract."""
|
| 161 |
+
return None
|
| 162 |
+
|
| 163 |
|
| 164 |
# ── LLM agent ───────────────────────────────────────────────────────────────
|
| 165 |
|
|
|
|
| 170 |
def __init__(self):
|
| 171 |
# Initialize OpenAI client with credentials from HF_TOKEN (per hackathon spec)
|
| 172 |
# The OPENAI_API_KEY variable contains the HF_TOKEN value passed by evaluators
|
| 173 |
+
self.client = OpenAI(base_url=API_BASE_URL, api_key=OPENAI_API_KEY)
|
|
|
|
|
|
|
|
|
|
| 174 |
self.model = MODEL_NAME
|
| 175 |
self.fallback_mode = False
|
| 176 |
|
|
|
|
| 222 |
return self._clamp_action(action)
|
| 223 |
except Exception as e:
|
| 224 |
err_str = str(e)
|
| 225 |
+
print(f" [LLM attempt {attempt+1}/{MAX_RETRIES}] error: {err_str}", file=sys.stderr)
|
| 226 |
if "402" in err_str or "depleted" in err_str:
|
| 227 |
+
print(
|
| 228 |
+
" [WARN] Hugging Face free credits depleted! Switching to local heuristic agent for the rest of the simulation.",
|
| 229 |
+
file=sys.stderr,
|
| 230 |
+
)
|
| 231 |
self.fallback_mode = True
|
| 232 |
return self._heuristic_action(obs)
|
| 233 |
time.sleep(1)
|
|
|
|
| 313 |
...
|
| 314 |
[END] success=<true|false> steps=<n> rewards=<r1,r2,...>
|
| 315 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
task_name = f"gridmind-task-{task_id}"
|
|
|
|
|
|
|
| 317 |
print(f"[START] task={task_name} env=gridmind model={MODEL_NAME}", flush=True)
|
| 318 |
+
|
| 319 |
total_reward = 0.0
|
| 320 |
total_steps = 0
|
| 321 |
start_time = time.time()
|
| 322 |
+
step_resp: dict[str, Any] = {"done": False}
|
| 323 |
step_limit = EPISODE_STEPS if max_steps is None else min(max_steps, EPISODE_STEPS)
|
| 324 |
+
|
| 325 |
llm_reuse_remaining = 0
|
| 326 |
cached_action = agent._default_action()
|
| 327 |
+
|
| 328 |
step_rewards: list[float] = []
|
| 329 |
+
success = False
|
| 330 |
+
obs: dict[str, Any] = {}
|
| 331 |
+
|
| 332 |
+
try:
|
| 333 |
+
reset_resp = env_client.reset(task_id=task_id, seed=seed)
|
| 334 |
+
if reset_resp is None:
|
| 335 |
+
raise RuntimeError("reset failed")
|
| 336 |
+
obs_list = reset_resp.get("observations", [{}])
|
| 337 |
+
obs = obs_list[0] if obs_list else {}
|
| 338 |
+
|
| 339 |
+
while not step_resp.get("done", False):
|
| 340 |
+
if total_steps >= step_limit:
|
| 341 |
+
break
|
| 342 |
+
|
| 343 |
if fast_mode:
|
| 344 |
action = agent._heuristic_action(obs)
|
| 345 |
else:
|
|
|
|
| 350 |
|
| 351 |
step_resp = env_client.step(action)
|
| 352 |
if step_resp is None or not isinstance(step_resp, dict) or "observation" not in step_resp:
|
|
|
|
| 353 |
print(
|
| 354 |
f"[STEP] step={total_steps + 1} action=null "
|
| 355 |
+
f"reward=0.00 done=true error=invalid step response from environment",
|
| 356 |
flush=True
|
| 357 |
)
|
| 358 |
break
|
| 359 |
+
|
| 360 |
if not fast_mode:
|
| 361 |
llm_reuse_remaining -= 1
|
| 362 |
+
|
| 363 |
obs = step_resp["observation"]
|
| 364 |
reward = float(step_resp["reward"])
|
| 365 |
total_reward += reward
|
| 366 |
step_rewards.append(reward)
|
| 367 |
total_steps += 1
|
| 368 |
done = bool(step_resp.get("done", False))
|
| 369 |
+
|
|
|
|
| 370 |
action_json = json.dumps(action, separators=(',', ':'))
|
| 371 |
+
last_action_error = step_resp.get("last_action_error")
|
| 372 |
+
error_field = "null" if last_action_error is None else str(last_action_error)
|
| 373 |
print(
|
| 374 |
f"[STEP] step={total_steps} action={action_json} "
|
| 375 |
f"reward={reward:.2f} done={'true' if done else 'false'} error={error_field}",
|
| 376 |
flush=True
|
| 377 |
)
|
| 378 |
+
|
|
|
|
|
|
|
| 379 |
if verbose and total_steps % 16 == 0:
|
| 380 |
print(
|
| 381 |
f" step={total_steps:02d} price=${obs['current_price']:.3f} "
|
|
|
|
| 383 |
f"stress={obs['grid_stress_signal']:.2f} "
|
| 384 |
f"cost=${obs['cumulative_cost']:.2f}",
|
| 385 |
flush=True,
|
| 386 |
+
file=sys.stderr,
|
| 387 |
)
|
| 388 |
+
|
| 389 |
+
success = bool(step_resp.get("done", False))
|
| 390 |
+
except Exception as e:
|
| 391 |
+
err = str(e)
|
| 392 |
+
if not err:
|
| 393 |
+
err = "unknown error"
|
| 394 |
+
if "\n" in err:
|
| 395 |
+
err = err.replace("\n", " ")
|
| 396 |
+
if "\r" in err:
|
| 397 |
+
err = err.replace("\r", " ")
|
| 398 |
+
if total_steps < step_limit:
|
| 399 |
print(
|
| 400 |
f"[STEP] step={total_steps + 1} action=null "
|
| 401 |
+
f"reward=0.00 done=true error={err}",
|
| 402 |
flush=True
|
| 403 |
)
|
| 404 |
+
finally:
|
| 405 |
+
env_client.close()
|
| 406 |
+
|
| 407 |
elapsed = time.time() - start_time
|
| 408 |
grade = env_client.grade()
|
| 409 |
+
|
|
|
|
|
|
|
| 410 |
rewards_str = ",".join(f"{r:.2f}" for r in step_rewards)
|
| 411 |
print(
|
| 412 |
f"[END] success={'true' if success else 'false'} steps={total_steps} rewards={rewards_str}",
|
|
|
|
| 455 |
)
|
| 456 |
args = parser.parse_args()
|
| 457 |
|
| 458 |
+
if not HF_TOKEN:
|
| 459 |
+
print("HF_TOKEN is required.", file=sys.stderr)
|
| 460 |
+
sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
|
| 462 |
env_client = GridMindEnvClient(base_url=args.env_url)
|
| 463 |
|
|
|
|
| 464 |
for attempt in range(30):
|
| 465 |
if env_client.health():
|
|
|
|
| 466 |
break
|
| 467 |
time.sleep(2)
|
| 468 |
if attempt == 29:
|
| 469 |
+
print("Environment server not reachable.", file=sys.stderr)
|
| 470 |
sys.exit(1)
|
| 471 |
|
| 472 |
agent = LLMAgent()
|
| 473 |
all_results: list[dict[str, Any]] = []
|
| 474 |
|
| 475 |
for task_id in [1, 2, 3]:
|
|
|
|
| 476 |
task_scores: list[float] = []
|
| 477 |
for ep in range(args.episodes):
|
| 478 |
seed = DEFAULT_SEED_BASE + task_id * 100 + ep
|
|
|
|
| 479 |
result = run_episode(
|
| 480 |
env_client,
|
| 481 |
agent,
|
|
|
|
| 488 |
)
|
| 489 |
task_scores.append(float(result["score"]))
|
| 490 |
all_results.append(result)
|
| 491 |
+
_ = sum(task_scores) / len(task_scores)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
|
| 493 |
task_avgs: dict[int, float] = {}
|
| 494 |
for task_id in [1, 2, 3]:
|
| 495 |
scores = [float(r["score"]) for r in all_results if r["task_id"] == task_id]
|
| 496 |
avg = sum(scores) / len(scores) if scores else 0.0
|
| 497 |
task_avgs[task_id] = avg
|
|
|
|
|
|
|
|
|
|
| 498 |
overall = sum(task_avgs.values()) / len(task_avgs)
|
|
|
|
| 499 |
|
| 500 |
output = {
|
| 501 |
"model": MODEL_NAME,
|
|
|
|
| 511 |
}
|
| 512 |
with open(args.output, "w", encoding="utf-8") as f:
|
| 513 |
json.dump(output, f, indent=2)
|
|
|
|
| 514 |
|
| 515 |
|
| 516 |
if __name__ == "__main__":
|