Spaces:
Running
Running
Commit ·
32565e1
1
Parent(s): 2ede269
rewrite reward system
Browse files- inference.py +218 -232
inference.py
CHANGED
|
@@ -1,26 +1,18 @@
|
|
| 1 |
"""
|
| 2 |
-
GridMind-RL
|
| 3 |
-
----------------------------
|
| 4 |
Runs an LLM agent against all 3 tasks for N episodes each.
|
| 5 |
Uses the OpenAI Python client pointed at any OpenAI-compatible endpoint.
|
| 6 |
|
| 7 |
-
Required environment variables
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
# Option 2: Set env vars manually
|
| 17 |
-
export API_BASE_URL=https://openrouter.ai/api/v1
|
| 18 |
-
export MODEL_NAME=meta-llama/llama-3.1-8b-instruct:free
|
| 19 |
-
export OPENAI_API_KEY=sk-or-v1-xxxx
|
| 20 |
-
python inference.py
|
| 21 |
-
|
| 22 |
-
# Option 3: Fast mode (no LLM, heuristic only)
|
| 23 |
-
python inference.py --fast-mode --episodes 1
|
| 24 |
"""
|
| 25 |
|
| 26 |
from __future__ import annotations
|
|
@@ -36,27 +28,36 @@ from typing import Any, Optional
|
|
| 36 |
import requests
|
| 37 |
from openai import OpenAI
|
| 38 |
|
| 39 |
-
# ── Load .env file
|
| 40 |
try:
|
| 41 |
from dotenv import load_dotenv
|
| 42 |
-
load_dotenv()
|
| 43 |
except ImportError:
|
| 44 |
-
pass
|
| 45 |
-
|
| 46 |
-
# ── Constants ──────────────────────────────────────────────────────────────
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
|
| 52 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://api-inference.huggingface.co/v1")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
DEFAULT_EPISODES = 1
|
| 54 |
DEFAULT_SEED_BASE = 1000
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
SCORE_EPSILON = 0.01
|
| 59 |
|
|
|
|
| 60 |
SYSPROMPT = """You are GridMind, an expert industrial energy management controller.
|
| 61 |
You control a building's HVAC, thermal storage, batch job scheduling, and load shedding.
|
| 62 |
Your goal is to minimize electricity costs while maintaining comfort and meeting grid demand-response signals.
|
|
@@ -68,7 +69,7 @@ TASK_DESCRIPTIONS = {
|
|
| 68 |
3: "Task 3 (Hard - Full Demand Response): Minimize cost, maintain temperature, respond to grid stress (shed when grid_stress_signal > 0.7), schedule batch jobs, minimize carbon.",
|
| 69 |
}
|
| 70 |
|
| 71 |
-
|
| 72 |
"hvac_power_level": <float 0.0-1.0>,
|
| 73 |
"thermal_charge_rate": <float -1.0 to 1.0>,
|
| 74 |
"batch_job_slot": <int 0-4>,
|
|
@@ -77,8 +78,37 @@ ACTION_SCHEMA_STR = """{
|
|
| 77 |
}"""
|
| 78 |
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
start = text.find("{")
|
| 83 |
if start < 0:
|
| 84 |
return None
|
|
@@ -91,7 +121,7 @@ def extract_json_object(text: str) -> dict[str, Any] | None:
|
|
| 91 |
depth -= 1
|
| 92 |
if depth == 0:
|
| 93 |
try:
|
| 94 |
-
return json.loads(text[start
|
| 95 |
except json.JSONDecodeError:
|
| 96 |
return None
|
| 97 |
return None
|
|
@@ -106,85 +136,34 @@ def clamp_open_score(score: float) -> float:
|
|
| 106 |
return score
|
| 107 |
|
| 108 |
|
| 109 |
-
def
|
| 110 |
-
"""Normalize raw
|
| 111 |
-
if
|
| 112 |
-
return
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
if raw_range > 0:
|
| 117 |
-
return [clamp_open_score((r - raw_min) / raw_range) for r in raw_rewards]
|
| 118 |
-
return [0.5] * len(raw_rewards)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
# ── Environment client ───────────────────────────────────────────────────────
|
| 122 |
|
| 123 |
|
| 124 |
-
|
| 125 |
-
"""
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def health(self) -> bool:
|
| 132 |
-
try:
|
| 133 |
-
r = requests.get(f"{self.base}/health", timeout=5)
|
| 134 |
-
return r.status_code == 200
|
| 135 |
-
except Exception:
|
| 136 |
-
return False
|
| 137 |
-
|
| 138 |
-
def reset(self, task_id: int = 1, seed: int = 42, num_buildings: int = 1) -> dict | None:
|
| 139 |
-
try:
|
| 140 |
-
payload = {"task_id": task_id, "seed": seed, "num_buildings": num_buildings}
|
| 141 |
-
r = requests.post(f"{self.base}/reset", json=payload, timeout=self.timeout)
|
| 142 |
-
r.raise_for_status()
|
| 143 |
-
return r.json()
|
| 144 |
-
except Exception as e:
|
| 145 |
-
print(f"[ERROR] Failed to reset environment: {e}", file=sys.stderr)
|
| 146 |
-
return None
|
| 147 |
-
|
| 148 |
-
def step(self, action: dict) -> dict | None:
|
| 149 |
-
try:
|
| 150 |
-
r = requests.post(f"{self.base}/step", json=action, timeout=self.timeout)
|
| 151 |
-
r.raise_for_status()
|
| 152 |
-
return r.json()
|
| 153 |
-
except Exception as e:
|
| 154 |
-
print(f"[ERROR] Failed to step environment: {e}", file=sys.stderr)
|
| 155 |
-
return None
|
| 156 |
-
|
| 157 |
-
def grade(self) -> dict:
|
| 158 |
-
try:
|
| 159 |
-
r = requests.get(f"{self.base}/grade", timeout=self.timeout)
|
| 160 |
-
r.raise_for_status()
|
| 161 |
-
return r.json()
|
| 162 |
-
except Exception as e:
|
| 163 |
-
print(f"[ERROR] Failed to grade: {e}", file=sys.stderr)
|
| 164 |
-
return {"score": SCORE_EPSILON, "sub_scores": {}, "exploit_detected": False}
|
| 165 |
-
|
| 166 |
-
def state(self) -> dict | None:
|
| 167 |
-
try:
|
| 168 |
-
r = requests.get(f"{self.base}/state", timeout=self.timeout)
|
| 169 |
-
r.raise_for_status()
|
| 170 |
-
return r.json()
|
| 171 |
-
except Exception as e:
|
| 172 |
-
print(f"[ERROR] Failed to get state: {e}", file=sys.stderr)
|
| 173 |
-
return None
|
| 174 |
-
|
| 175 |
-
def close(self) -> None:
|
| 176 |
-
"""Compatibility close hook for episode-finalization contract."""
|
| 177 |
-
return None
|
| 178 |
|
| 179 |
|
| 180 |
-
# ── LLM
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
|
|
|
|
| 183 |
class LLMAgent:
|
| 184 |
-
"""OpenAI-compatible LLM agent that chooses actions given observations."""
|
| 185 |
-
|
| 186 |
def __init__(self):
|
| 187 |
-
self.client =
|
| 188 |
self.model = MODEL_NAME
|
| 189 |
self.fallback_mode = False
|
| 190 |
|
|
@@ -192,6 +171,7 @@ class LLMAgent:
|
|
| 192 |
"""Prompt the LLM with current observation, return parsed action dict."""
|
| 193 |
if self.fallback_mode:
|
| 194 |
return self._heuristic_action(obs)
|
|
|
|
| 195 |
task_desc = TASK_DESCRIPTIONS.get(task_id, TASK_DESCRIPTIONS[1])
|
| 196 |
|
| 197 |
prompt = f"""{task_desc}
|
|
@@ -206,12 +186,12 @@ Current observation:
|
|
| 206 |
- Hour of day: {obs.get('hour_of_day', 12)} (0=midnight, peak prices 8-12 and 17-21)
|
| 207 |
- Pending batch job deadlines: {obs.get('batch_queue', [])}
|
| 208 |
- Cumulative cost so far: ${obs.get('cumulative_cost', 0):.4f}
|
| 209 |
-
- Episode step: {obs.get('step', 0)}/{
|
| 210 |
|
| 211 |
IMPORTANT RULES:
|
| 212 |
-
- thermal_charge_rate:
|
| 213 |
- load_shed_fraction: MUST be 0.2-0.5 when grid_stress_signal > 0.7, otherwise 0.0
|
| 214 |
-
- shed load during grid stress to earn rewards
|
| 215 |
|
| 216 |
Strategy hints:
|
| 217 |
- Charge thermal storage (positive) when price < $0.08/kWh
|
|
@@ -221,7 +201,7 @@ Strategy hints:
|
|
| 221 |
- Schedule batch jobs early if deadline is close (slot 0 or 1)
|
| 222 |
|
| 223 |
Respond with ONLY a JSON action:
|
| 224 |
-
{
|
| 225 |
|
| 226 |
for attempt in range(MAX_RETRIES):
|
| 227 |
try:
|
|
@@ -244,10 +224,7 @@ Respond with ONLY a JSON action:
|
|
| 244 |
err_str = str(e)
|
| 245 |
print(f" [LLM attempt {attempt+1}/{MAX_RETRIES}] error: {err_str}", file=sys.stderr)
|
| 246 |
if "402" in err_str or "depleted" in err_str:
|
| 247 |
-
print(
|
| 248 |
-
" [WARN] Hugging Face free credits depleted! Switching to local heuristic agent for the rest of the simulation.",
|
| 249 |
-
file=sys.stderr,
|
| 250 |
-
)
|
| 251 |
self.fallback_mode = True
|
| 252 |
return self._heuristic_action(obs)
|
| 253 |
time.sleep(1)
|
|
@@ -264,7 +241,7 @@ Respond with ONLY a JSON action:
|
|
| 264 |
}
|
| 265 |
|
| 266 |
def _heuristic_action(self, obs: dict) -> dict:
|
| 267 |
-
"""Rule-based
|
| 268 |
price = obs.get("current_price", 0.10)
|
| 269 |
stress = obs.get("grid_stress_signal", 0.0)
|
| 270 |
temp = obs.get("indoor_temperature", 21.0)
|
|
@@ -311,9 +288,61 @@ Respond with ONLY a JSON action:
|
|
| 311 |
}
|
| 312 |
|
| 313 |
|
| 314 |
-
# ──
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
def run_episode(
|
| 318 |
env_client: GridMindEnvClient,
|
| 319 |
agent: LLMAgent,
|
|
@@ -322,19 +351,12 @@ def run_episode(
|
|
| 322 |
*,
|
| 323 |
fast_mode: bool,
|
| 324 |
llm_every: int,
|
| 325 |
-
max_steps: int
|
| 326 |
verbose: bool = False,
|
| 327 |
) -> dict[str, Any]:
|
| 328 |
-
"""Run a single episode and emit hackathon-compliant stdout format.
|
| 329 |
-
|
| 330 |
-
Emits:
|
| 331 |
-
[START] task=<name> env=gridmind model=<model>
|
| 332 |
-
[STEP] step=<n> action=<json> reward=<0.00> done=<true|false> error=<msg|null>
|
| 333 |
-
...
|
| 334 |
-
[END] success=<true|false> steps=<n> rewards=<r1,r2,...>
|
| 335 |
-
"""
|
| 336 |
task_name = f"gridmind-task-{task_id}"
|
| 337 |
-
|
| 338 |
|
| 339 |
total_reward = 0.0
|
| 340 |
total_steps = 0
|
|
@@ -345,7 +367,7 @@ def run_episode(
|
|
| 345 |
llm_reuse_remaining = 0
|
| 346 |
cached_action = agent._default_action()
|
| 347 |
|
| 348 |
-
|
| 349 |
reward_min = float('inf')
|
| 350 |
reward_max = float('-inf')
|
| 351 |
success = False
|
|
@@ -369,13 +391,15 @@ def run_episode(
|
|
| 369 |
cached_action = agent.choose_action(obs, task_id)
|
| 370 |
llm_reuse_remaining = max(1, llm_every)
|
| 371 |
action = cached_action
|
| 372 |
-
|
| 373 |
step_resp = env_client.step(action)
|
| 374 |
if step_resp is None or not isinstance(step_resp, dict) or "observation" not in step_resp:
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
|
|
|
|
|
|
| 379 |
)
|
| 380 |
break
|
| 381 |
|
|
@@ -385,29 +409,26 @@ def run_episode(
|
|
| 385 |
obs = step_resp["observation"]
|
| 386 |
raw_reward = float(step_resp["reward"])
|
| 387 |
total_reward += raw_reward
|
| 388 |
-
|
| 389 |
-
|
| 390 |
if raw_reward < reward_min:
|
| 391 |
reward_min = raw_reward
|
| 392 |
if raw_reward > reward_max:
|
| 393 |
reward_max = raw_reward
|
| 394 |
-
|
| 395 |
total_steps += 1
|
| 396 |
done = bool(step_resp.get("done", False))
|
| 397 |
|
| 398 |
-
|
| 399 |
-
if reward_range > 0:
|
| 400 |
-
normalized_reward = clamp_open_score((raw_reward - reward_min) / reward_range)
|
| 401 |
-
else:
|
| 402 |
-
normalized_reward = 0.5
|
| 403 |
|
| 404 |
action_json = json.dumps(action, separators=(',', ':'))
|
| 405 |
last_action_error = step_resp.get("last_action_error")
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
|
|
|
| 411 |
)
|
| 412 |
|
| 413 |
if verbose and total_steps % 16 == 0:
|
|
@@ -421,37 +442,33 @@ def run_episode(
|
|
| 421 |
)
|
| 422 |
|
| 423 |
success = bool(step_resp.get("done", False))
|
|
|
|
| 424 |
except Exception as e:
|
| 425 |
-
err = str(e)
|
| 426 |
-
|
| 427 |
-
err = "unknown error"
|
| 428 |
-
if "\n" in err:
|
| 429 |
-
err = err.replace("\n", " ")
|
| 430 |
-
if "\r" in err:
|
| 431 |
-
err = err.replace("\r", " ")
|
| 432 |
if total_steps < step_limit:
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
|
|
|
|
|
|
| 437 |
)
|
|
|
|
| 438 |
finally:
|
| 439 |
env_client.close()
|
| 440 |
|
| 441 |
elapsed = time.time() - start_time
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
print(
|
| 451 |
-
f"[END] success={'true' if success else 'false'} steps={total_steps} rewards={rewards_str}",
|
| 452 |
-
flush=True
|
| 453 |
)
|
| 454 |
-
|
| 455 |
return {
|
| 456 |
"task_id": task_id,
|
| 457 |
"seed": seed,
|
|
@@ -459,51 +476,37 @@ def run_episode(
|
|
| 459 |
"total_steps": total_steps,
|
| 460 |
"elapsed_sec": elapsed,
|
| 461 |
"score": episode_score,
|
| 462 |
-
"sub_scores":
|
| 463 |
-
"exploit_detected":
|
| 464 |
}
|
| 465 |
|
| 466 |
|
| 467 |
-
# ──
|
| 468 |
-
|
| 469 |
def start_environment_server(port: int = 7860) -> Optional[subprocess.Popen]:
|
| 470 |
-
"""Start the GridMind-RL environment server as a background process.
|
| 471 |
-
|
| 472 |
-
Returns:
|
| 473 |
-
A Popen object if the server was started, or None if it's already running.
|
| 474 |
-
"""
|
| 475 |
-
# First check if server is already running
|
| 476 |
try:
|
| 477 |
r = requests.get(f"http://localhost:{port}/health", timeout=2)
|
| 478 |
if r.status_code == 200:
|
| 479 |
print(f"[INFO] Environment server already running on port {port}", file=sys.stderr)
|
| 480 |
return None
|
| 481 |
except Exception:
|
| 482 |
-
pass
|
| 483 |
-
|
| 484 |
print(f"[INFO] Starting environment server on port {port}...", file=sys.stderr)
|
| 485 |
-
|
| 486 |
-
# Try to find and run the server
|
| 487 |
try:
|
| 488 |
-
# Prepare environment
|
| 489 |
env = os.environ.copy()
|
| 490 |
env["PORT"] = str(port)
|
| 491 |
-
|
| 492 |
-
env["PYTHONPATH"] = "." + os.pathsep + env["PYTHONPATH"]
|
| 493 |
-
else:
|
| 494 |
-
env["PYTHONPATH"] = "."
|
| 495 |
-
|
| 496 |
-
# Look for compiled Go binary first
|
| 497 |
binary_paths = [
|
| 498 |
-
"/usr/local/bin/gridmind-server",
|
| 499 |
-
"./gridmind-server",
|
| 500 |
-
"./gridmind-server.exe",
|
| 501 |
]
|
| 502 |
-
|
| 503 |
for binary_path in binary_paths:
|
| 504 |
if os.path.exists(binary_path):
|
| 505 |
try:
|
| 506 |
-
print(f"[INFO] Running Go binary: {binary_path}", file=sys.stderr)
|
| 507 |
proc = subprocess.Popen(
|
| 508 |
[binary_path],
|
| 509 |
env=env,
|
|
@@ -515,51 +518,40 @@ def start_environment_server(port: int = 7860) -> Optional[subprocess.Popen]:
|
|
| 515 |
return proc
|
| 516 |
except Exception as e:
|
| 517 |
print(f"[DEBUG] Failed with {binary_path}: {e}", file=sys.stderr)
|
| 518 |
-
|
| 519 |
-
# Try to compile Go binary if 'go' is available
|
| 520 |
try:
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
result = subprocess.run(
|
| 524 |
-
compile_cmd,
|
| 525 |
capture_output=True,
|
| 526 |
timeout=60,
|
| 527 |
cwd=".",
|
| 528 |
)
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
)
|
| 537 |
-
time.sleep(2)
|
| 538 |
-
if proc.poll() is None:
|
| 539 |
-
return proc
|
| 540 |
-
except Exception as e:
|
| 541 |
-
print(f"[DEBUG] Could not compile: {e}", file=sys.stderr)
|
| 542 |
-
|
| 543 |
-
# Fallback: try to run via Python server module
|
| 544 |
-
print(f"[INFO] Attempting Python server module...", file=sys.stderr)
|
| 545 |
proc = subprocess.Popen(
|
| 546 |
[sys.executable, "-m", "server.app"],
|
| 547 |
env=env,
|
| 548 |
stdout=subprocess.PIPE,
|
| 549 |
stderr=subprocess.PIPE,
|
| 550 |
-
cwd=".",
|
| 551 |
)
|
| 552 |
time.sleep(3)
|
| 553 |
if proc.poll() is None:
|
| 554 |
return proc
|
| 555 |
-
|
| 556 |
except Exception as e:
|
| 557 |
print(f"[WARNING] Could not start environment server: {e}", file=sys.stderr)
|
| 558 |
-
|
|
|
|
| 559 |
|
| 560 |
|
|
|
|
| 561 |
def main() -> None:
|
| 562 |
-
parser = argparse.ArgumentParser(description="GridMind-RL
|
| 563 |
parser.add_argument("--episodes", type=int, default=DEFAULT_EPISODES)
|
| 564 |
parser.add_argument("--env-url", type=str, default=ENV_URL)
|
| 565 |
parser.add_argument("--verbose", action="store_true")
|
|
@@ -567,31 +559,26 @@ def main() -> None:
|
|
| 567 |
parser.add_argument(
|
| 568 |
"--fast-mode",
|
| 569 |
action="store_true",
|
| 570 |
-
help="Heuristic policy only (no LLM calls
|
| 571 |
)
|
| 572 |
parser.add_argument(
|
| 573 |
"--llm-every",
|
| 574 |
type=int,
|
| 575 |
default=8,
|
| 576 |
metavar="N",
|
| 577 |
-
help="Reuse the same LLM action for N
|
| 578 |
)
|
| 579 |
parser.add_argument(
|
| 580 |
"--max-steps",
|
| 581 |
type=int,
|
| 582 |
default=None,
|
| 583 |
metavar="N",
|
| 584 |
-
help="Stop after N steps
|
| 585 |
)
|
| 586 |
args = parser.parse_args()
|
| 587 |
-
|
| 588 |
-
if not HF_TOKEN:
|
| 589 |
-
print("HF_TOKEN is required.", file=sys.stderr)
|
| 590 |
-
sys.exit(1)
|
| 591 |
|
| 592 |
-
# Start the environment server if not already running
|
| 593 |
server_proc = start_environment_server(port=7860)
|
| 594 |
-
|
| 595 |
try:
|
| 596 |
env_client = GridMindEnvClient(base_url=args.env_url)
|
| 597 |
|
|
@@ -622,13 +609,13 @@ def main() -> None:
|
|
| 622 |
)
|
| 623 |
task_scores.append(float(result["score"]))
|
| 624 |
all_results.append(result)
|
| 625 |
-
_ = sum(task_scores) / len(task_scores)
|
| 626 |
|
| 627 |
task_avgs: dict[int, float] = {}
|
| 628 |
for task_id in [1, 2, 3]:
|
| 629 |
scores = [float(r["score"]) for r in all_results if r["task_id"] == task_id]
|
| 630 |
avg = clamp_open_score(sum(scores) / len(scores)) if scores else SCORE_EPSILON
|
| 631 |
task_avgs[task_id] = avg
|
|
|
|
| 632 |
overall = clamp_open_score(sum(task_avgs.values()) / len(task_avgs))
|
| 633 |
|
| 634 |
output = {
|
|
@@ -645,14 +632,13 @@ def main() -> None:
|
|
| 645 |
}
|
| 646 |
with open(args.output, "w", encoding="utf-8") as f:
|
| 647 |
json.dump(output, f, indent=2)
|
|
|
|
| 648 |
finally:
|
| 649 |
-
# Clean up the server process if we started it
|
| 650 |
if server_proc:
|
| 651 |
try:
|
| 652 |
server_proc.terminate()
|
| 653 |
server_proc.wait(timeout=5)
|
| 654 |
-
except Exception
|
| 655 |
-
print(f"[WARNING] Failed to terminate server: {e}", file=sys.stderr)
|
| 656 |
try:
|
| 657 |
server_proc.kill()
|
| 658 |
except Exception:
|
|
|
|
| 1 |
"""
|
| 2 |
+
GridMind-RL Inference Script
|
| 3 |
+
----------------------------
|
| 4 |
Runs an LLM agent against all 3 tasks for N episodes each.
|
| 5 |
Uses the OpenAI Python client pointed at any OpenAI-compatible endpoint.
|
| 6 |
|
| 7 |
+
Required environment variables:
|
| 8 |
+
HF_TOKEN — Hugging Face / API token (mandatory, no default)
|
| 9 |
+
API_BASE_URL — API endpoint for the LLM (has default)
|
| 10 |
+
MODEL_NAME — Model identifier (has default)
|
| 11 |
|
| 12 |
+
STDOUT FORMAT (machine-parsed by judge):
|
| 13 |
+
[START] task=<task_name> env=<benchmark> model=<model_name>
|
| 14 |
+
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 15 |
+
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"""
|
| 17 |
|
| 18 |
from __future__ import annotations
|
|
|
|
| 28 |
import requests
|
| 29 |
from openai import OpenAI
|
| 30 |
|
| 31 |
+
# ── Load .env file ─────────────────────────────────────────────────────────
|
| 32 |
try:
|
| 33 |
from dotenv import load_dotenv
|
| 34 |
+
load_dotenv()
|
| 35 |
except ImportError:
|
| 36 |
+
pass
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
# ── Environment Variables ────────────────────────────────────────────────────
|
| 39 |
+
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
|
| 40 |
+
HF_TOKEN = os.getenv("HF_TOKEN") # Mandatory — no default
|
|
|
|
| 41 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://api-inference.huggingface.co/v1")
|
| 42 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
|
| 43 |
+
|
| 44 |
+
# ── Constants ────────────────────────────────────────────────────────────────
|
| 45 |
+
BENCHMARK = "gridmind"
|
| 46 |
+
EPISODE_STEPS = 96
|
| 47 |
+
LAST_STEP = EPISODE_STEPS - 1
|
| 48 |
+
MAX_RETRIES = 3
|
| 49 |
DEFAULT_EPISODES = 1
|
| 50 |
DEFAULT_SEED_BASE = 1000
|
| 51 |
+
|
| 52 |
+
# Reward range per step in this environment: (0.10, 0.90)
|
| 53 |
+
# Worst action -> 0.10, best action -> 0.90
|
| 54 |
+
REWARD_MIN = 0.10
|
| 55 |
+
REWARD_MAX = 0.90
|
| 56 |
+
|
| 57 |
+
# Score clamp buffer (never output exactly 0.0 or 1.0)
|
| 58 |
SCORE_EPSILON = 0.01
|
| 59 |
|
| 60 |
+
# ── System Prompt ────────────────────────────────────────────────────────────
|
| 61 |
SYSPROMPT = """You are GridMind, an expert industrial energy management controller.
|
| 62 |
You control a building's HVAC, thermal storage, batch job scheduling, and load shedding.
|
| 63 |
Your goal is to minimize electricity costs while maintaining comfort and meeting grid demand-response signals.
|
|
|
|
| 69 |
3: "Task 3 (Hard - Full Demand Response): Minimize cost, maintain temperature, respond to grid stress (shed when grid_stress_signal > 0.7), schedule batch jobs, minimize carbon.",
|
| 70 |
}
|
| 71 |
|
| 72 |
+
ACTION_SCHEMA = """{
|
| 73 |
"hvac_power_level": <float 0.0-1.0>,
|
| 74 |
"thermal_charge_rate": <float -1.0 to 1.0>,
|
| 75 |
"batch_job_slot": <int 0-4>,
|
|
|
|
| 78 |
}"""
|
| 79 |
|
| 80 |
|
| 81 |
+
# ── Logging Helpers (judge-parsed format) ────────────────────────────────────
|
| 82 |
+
def log_start(task: str, env_name: str, model: str) -> None:
|
| 83 |
+
"""[START] line — emitted once at episode begin."""
|
| 84 |
+
print(f"[START] task={task} env={env_name} model={model}", flush=True)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def log_step(step: int, action: str, reward: float, done: bool,
|
| 88 |
+
error: Optional[str] = None) -> None:
|
| 89 |
+
"""[STEP] line — emitted after each env.step() returns."""
|
| 90 |
+
error_val = error if error else "null"
|
| 91 |
+
done_val = str(done).lower()
|
| 92 |
+
print(
|
| 93 |
+
f"[STEP] step={step} action={action} reward={reward:.2f} "
|
| 94 |
+
f"done={done_val} error={error_val}",
|
| 95 |
+
flush=True,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
|
| 100 |
+
"""[END] line — always emitted (even on exception)."""
|
| 101 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 102 |
+
print(
|
| 103 |
+
f"[END] success={str(success).lower()} steps={steps} "
|
| 104 |
+
f"score={score:.3f} rewards={rewards_str}",
|
| 105 |
+
flush=True,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ── Utility Functions ─────────────────────────────────────────────────────────
|
| 110 |
+
def extract_json_object(text: str) -> Optional[dict[str, Any]]:
|
| 111 |
+
"""Parse first balanced {...} JSON object from text."""
|
| 112 |
start = text.find("{")
|
| 113 |
if start < 0:
|
| 114 |
return None
|
|
|
|
| 121 |
depth -= 1
|
| 122 |
if depth == 0:
|
| 123 |
try:
|
| 124 |
+
return json.loads(text[start:i + 1])
|
| 125 |
except json.JSONDecodeError:
|
| 126 |
return None
|
| 127 |
return None
|
|
|
|
| 136 |
return score
|
| 137 |
|
| 138 |
|
| 139 |
+
def normalize_reward(raw_reward: float, raw_min: float, raw_max: float) -> float:
|
| 140 |
+
"""Normalize raw reward to (REWARD_MIN, REWARD_MAX) range."""
|
| 141 |
+
if raw_max == raw_min:
|
| 142 |
+
return (REWARD_MIN + REWARD_MAX) / 2
|
| 143 |
+
normalized = (raw_reward - raw_min) / (raw_max - raw_min)
|
| 144 |
+
normalized = normalized * (REWARD_MAX - REWARD_MIN) + REWARD_MIN
|
| 145 |
+
return clamp_open_score(normalized)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
|
| 148 |
+
def compute_score(rewards: list[float]) -> float:
|
| 149 |
+
"""Return mean reward clamped strictly to (0.01, 0.99)."""
|
| 150 |
+
if not rewards:
|
| 151 |
+
return SCORE_EPSILON
|
| 152 |
+
mean_reward = sum(rewards) / len(rewards)
|
| 153 |
+
return clamp_open_score(round(mean_reward, 4))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
|
| 156 |
+
# ── LLM Client ───────────────────────────────────────────────────────────────
|
| 157 |
+
def get_llm_client() -> OpenAI:
|
| 158 |
+
if not HF_TOKEN:
|
| 159 |
+
raise EnvironmentError("HF_TOKEN environment variable is not set.")
|
| 160 |
+
return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 161 |
|
| 162 |
|
| 163 |
+
# ── LLM Agent ────────────────────────────────────────────────────────────────
|
| 164 |
class LLMAgent:
|
|
|
|
|
|
|
| 165 |
def __init__(self):
|
| 166 |
+
self.client = get_llm_client()
|
| 167 |
self.model = MODEL_NAME
|
| 168 |
self.fallback_mode = False
|
| 169 |
|
|
|
|
| 171 |
"""Prompt the LLM with current observation, return parsed action dict."""
|
| 172 |
if self.fallback_mode:
|
| 173 |
return self._heuristic_action(obs)
|
| 174 |
+
|
| 175 |
task_desc = TASK_DESCRIPTIONS.get(task_id, TASK_DESCRIPTIONS[1])
|
| 176 |
|
| 177 |
prompt = f"""{task_desc}
|
|
|
|
| 186 |
- Hour of day: {obs.get('hour_of_day', 12)} (0=midnight, peak prices 8-12 and 17-21)
|
| 187 |
- Pending batch job deadlines: {obs.get('batch_queue', [])}
|
| 188 |
- Cumulative cost so far: ${obs.get('cumulative_cost', 0):.4f}
|
| 189 |
+
- Episode step: {obs.get('step', 0)}/{LAST_STEP}
|
| 190 |
|
| 191 |
IMPORTANT RULES:
|
| 192 |
+
- thermal_charge_rate: NEGATIVE = DISCHARGE storage, POSITIVE = CHARGE
|
| 193 |
- load_shed_fraction: MUST be 0.2-0.5 when grid_stress_signal > 0.7, otherwise 0.0
|
| 194 |
+
- shed load during grid stress to earn rewards
|
| 195 |
|
| 196 |
Strategy hints:
|
| 197 |
- Charge thermal storage (positive) when price < $0.08/kWh
|
|
|
|
| 201 |
- Schedule batch jobs early if deadline is close (slot 0 or 1)
|
| 202 |
|
| 203 |
Respond with ONLY a JSON action:
|
| 204 |
+
{ACTION_SCHEMA}"""
|
| 205 |
|
| 206 |
for attempt in range(MAX_RETRIES):
|
| 207 |
try:
|
|
|
|
| 224 |
err_str = str(e)
|
| 225 |
print(f" [LLM attempt {attempt+1}/{MAX_RETRIES}] error: {err_str}", file=sys.stderr)
|
| 226 |
if "402" in err_str or "depleted" in err_str:
|
| 227 |
+
print(" [WARN] API credits depleted! Switching to heuristic agent.", file=sys.stderr)
|
|
|
|
|
|
|
|
|
|
| 228 |
self.fallback_mode = True
|
| 229 |
return self._heuristic_action(obs)
|
| 230 |
time.sleep(1)
|
|
|
|
| 241 |
}
|
| 242 |
|
| 243 |
def _heuristic_action(self, obs: dict) -> dict:
|
| 244 |
+
"""Rule-based fallback policy."""
|
| 245 |
price = obs.get("current_price", 0.10)
|
| 246 |
stress = obs.get("grid_stress_signal", 0.0)
|
| 247 |
temp = obs.get("indoor_temperature", 21.0)
|
|
|
|
| 288 |
}
|
| 289 |
|
| 290 |
|
| 291 |
+
# ── Environment Client ────────────────────────────────────────────────────────
|
| 292 |
+
class GridMindEnvClient:
|
| 293 |
+
def __init__(self, base_url: str = ENV_URL, timeout: int = 30):
|
| 294 |
+
self.base = base_url.rstrip("/")
|
| 295 |
+
self.timeout = timeout
|
| 296 |
|
| 297 |
+
def health(self) -> bool:
|
| 298 |
+
try:
|
| 299 |
+
r = requests.get(f"{self.base}/health", timeout=5)
|
| 300 |
+
return r.status_code == 200
|
| 301 |
+
except Exception:
|
| 302 |
+
return False
|
| 303 |
+
|
| 304 |
+
def reset(self, task_id: int = 1, seed: int = 42, num_buildings: int = 1) -> Optional[dict]:
|
| 305 |
+
try:
|
| 306 |
+
payload = {"task_id": task_id, "seed": seed, "num_buildings": num_buildings}
|
| 307 |
+
r = requests.post(f"{self.base}/reset", json=payload, timeout=self.timeout)
|
| 308 |
+
r.raise_for_status()
|
| 309 |
+
return r.json()
|
| 310 |
+
except Exception as e:
|
| 311 |
+
print(f"[ERROR] Failed to reset environment: {e}", file=sys.stderr)
|
| 312 |
+
return None
|
| 313 |
+
|
| 314 |
+
def step(self, action: dict) -> Optional[dict]:
|
| 315 |
+
try:
|
| 316 |
+
r = requests.post(f"{self.base}/step", json=action, timeout=self.timeout)
|
| 317 |
+
r.raise_for_status()
|
| 318 |
+
return r.json()
|
| 319 |
+
except Exception as e:
|
| 320 |
+
print(f"[ERROR] Failed to step environment: {e}", file=sys.stderr)
|
| 321 |
+
return None
|
| 322 |
+
|
| 323 |
+
def grade(self) -> dict:
|
| 324 |
+
try:
|
| 325 |
+
r = requests.get(f"{self.base}/grade", timeout=self.timeout)
|
| 326 |
+
r.raise_for_status()
|
| 327 |
+
return r.json()
|
| 328 |
+
except Exception as e:
|
| 329 |
+
print(f"[ERROR] Failed to grade: {e}", file=sys.stderr)
|
| 330 |
+
return {"score": SCORE_EPSILON, "sub_scores": {}, "exploit_detected": False}
|
| 331 |
+
|
| 332 |
+
def state(self) -> Optional[dict]:
|
| 333 |
+
try:
|
| 334 |
+
r = requests.get(f"{self.base}/state", timeout=self.timeout)
|
| 335 |
+
r.raise_for_status()
|
| 336 |
+
return r.json()
|
| 337 |
+
except Exception as e:
|
| 338 |
+
print(f"[ERROR] Failed to get state: {e}", file=sys.stderr)
|
| 339 |
+
return None
|
| 340 |
|
| 341 |
+
def close(self) -> None:
|
| 342 |
+
return None
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# ── Episode Runner ────────────────────────────────────────────────────────────
|
| 346 |
def run_episode(
|
| 347 |
env_client: GridMindEnvClient,
|
| 348 |
agent: LLMAgent,
|
|
|
|
| 351 |
*,
|
| 352 |
fast_mode: bool,
|
| 353 |
llm_every: int,
|
| 354 |
+
max_steps: Optional[int],
|
| 355 |
verbose: bool = False,
|
| 356 |
) -> dict[str, Any]:
|
| 357 |
+
"""Run a single episode and emit hackathon-compliant stdout format."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
task_name = f"gridmind-task-{task_id}"
|
| 359 |
+
log_start(task=task_name, env_name=BENCHMARK, model=MODEL_NAME)
|
| 360 |
|
| 361 |
total_reward = 0.0
|
| 362 |
total_steps = 0
|
|
|
|
| 367 |
llm_reuse_remaining = 0
|
| 368 |
cached_action = agent._default_action()
|
| 369 |
|
| 370 |
+
raw_rewards: list[float] = []
|
| 371 |
reward_min = float('inf')
|
| 372 |
reward_max = float('-inf')
|
| 373 |
success = False
|
|
|
|
| 391 |
cached_action = agent.choose_action(obs, task_id)
|
| 392 |
llm_reuse_remaining = max(1, llm_every)
|
| 393 |
action = cached_action
|
| 394 |
+
|
| 395 |
step_resp = env_client.step(action)
|
| 396 |
if step_resp is None or not isinstance(step_resp, dict) or "observation" not in step_resp:
|
| 397 |
+
log_step(
|
| 398 |
+
step=total_steps + 1,
|
| 399 |
+
action="null",
|
| 400 |
+
reward=0.0,
|
| 401 |
+
done=True,
|
| 402 |
+
error="invalid step response from environment",
|
| 403 |
)
|
| 404 |
break
|
| 405 |
|
|
|
|
| 409 |
obs = step_resp["observation"]
|
| 410 |
raw_reward = float(step_resp["reward"])
|
| 411 |
total_reward += raw_reward
|
| 412 |
+
raw_rewards.append(raw_reward)
|
| 413 |
+
|
| 414 |
if raw_reward < reward_min:
|
| 415 |
reward_min = raw_reward
|
| 416 |
if raw_reward > reward_max:
|
| 417 |
reward_max = raw_reward
|
| 418 |
+
|
| 419 |
total_steps += 1
|
| 420 |
done = bool(step_resp.get("done", False))
|
| 421 |
|
| 422 |
+
normalized_reward = normalize_reward(raw_reward, reward_min, reward_max)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
action_json = json.dumps(action, separators=(',', ':'))
|
| 425 |
last_action_error = step_resp.get("last_action_error")
|
| 426 |
+
log_step(
|
| 427 |
+
step=total_steps,
|
| 428 |
+
action=action_json,
|
| 429 |
+
reward=normalized_reward,
|
| 430 |
+
done=done,
|
| 431 |
+
error=last_action_error,
|
| 432 |
)
|
| 433 |
|
| 434 |
if verbose and total_steps % 16 == 0:
|
|
|
|
| 442 |
)
|
| 443 |
|
| 444 |
success = bool(step_resp.get("done", False))
|
| 445 |
+
|
| 446 |
except Exception as e:
|
| 447 |
+
err = str(e) or "unknown error"
|
| 448 |
+
err = err.replace("\n", " ").replace("\r", " ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
if total_steps < step_limit:
|
| 450 |
+
log_step(
|
| 451 |
+
step=total_steps + 1,
|
| 452 |
+
action="null",
|
| 453 |
+
reward=0.0,
|
| 454 |
+
done=True,
|
| 455 |
+
error=err,
|
| 456 |
)
|
| 457 |
+
|
| 458 |
finally:
|
| 459 |
env_client.close()
|
| 460 |
|
| 461 |
elapsed = time.time() - start_time
|
| 462 |
+
normalized_rewards = [normalize_reward(r, reward_min, reward_max) for r in raw_rewards]
|
| 463 |
+
episode_score = compute_score(normalized_rewards)
|
| 464 |
+
|
| 465 |
+
log_end(
|
| 466 |
+
success=success,
|
| 467 |
+
steps=total_steps,
|
| 468 |
+
score=episode_score,
|
| 469 |
+
rewards=normalized_rewards,
|
|
|
|
|
|
|
|
|
|
| 470 |
)
|
| 471 |
+
|
| 472 |
return {
|
| 473 |
"task_id": task_id,
|
| 474 |
"seed": seed,
|
|
|
|
| 476 |
"total_steps": total_steps,
|
| 477 |
"elapsed_sec": elapsed,
|
| 478 |
"score": episode_score,
|
| 479 |
+
"sub_scores": {},
|
| 480 |
+
"exploit_detected": False,
|
| 481 |
}
|
| 482 |
|
| 483 |
|
| 484 |
+
# ── Environment Server Starter ────────────────────────────────────────────────
|
|
|
|
| 485 |
def start_environment_server(port: int = 7860) -> Optional[subprocess.Popen]:
|
| 486 |
+
"""Start the GridMind-RL environment server as a background process."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
try:
|
| 488 |
r = requests.get(f"http://localhost:{port}/health", timeout=2)
|
| 489 |
if r.status_code == 200:
|
| 490 |
print(f"[INFO] Environment server already running on port {port}", file=sys.stderr)
|
| 491 |
return None
|
| 492 |
except Exception:
|
| 493 |
+
pass
|
| 494 |
+
|
| 495 |
print(f"[INFO] Starting environment server on port {port}...", file=sys.stderr)
|
| 496 |
+
|
|
|
|
| 497 |
try:
|
|
|
|
| 498 |
env = os.environ.copy()
|
| 499 |
env["PORT"] = str(port)
|
| 500 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
binary_paths = [
|
| 502 |
+
"/usr/local/bin/gridmind-server",
|
| 503 |
+
"./gridmind-server",
|
| 504 |
+
"./gridmind-server.exe",
|
| 505 |
]
|
| 506 |
+
|
| 507 |
for binary_path in binary_paths:
|
| 508 |
if os.path.exists(binary_path):
|
| 509 |
try:
|
|
|
|
| 510 |
proc = subprocess.Popen(
|
| 511 |
[binary_path],
|
| 512 |
env=env,
|
|
|
|
| 518 |
return proc
|
| 519 |
except Exception as e:
|
| 520 |
print(f"[DEBUG] Failed with {binary_path}: {e}", file=sys.stderr)
|
| 521 |
+
|
|
|
|
| 522 |
try:
|
| 523 |
+
subprocess.run(
|
| 524 |
+
["go", "build", "-o", "gridmind-server", "main.go"],
|
|
|
|
|
|
|
| 525 |
capture_output=True,
|
| 526 |
timeout=60,
|
| 527 |
cwd=".",
|
| 528 |
)
|
| 529 |
+
proc = subprocess.Popen(["./gridmind-server"], env=env)
|
| 530 |
+
time.sleep(2)
|
| 531 |
+
if proc.poll() is None:
|
| 532 |
+
return proc
|
| 533 |
+
except Exception:
|
| 534 |
+
pass
|
| 535 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
proc = subprocess.Popen(
|
| 537 |
[sys.executable, "-m", "server.app"],
|
| 538 |
env=env,
|
| 539 |
stdout=subprocess.PIPE,
|
| 540 |
stderr=subprocess.PIPE,
|
|
|
|
| 541 |
)
|
| 542 |
time.sleep(3)
|
| 543 |
if proc.poll() is None:
|
| 544 |
return proc
|
| 545 |
+
|
| 546 |
except Exception as e:
|
| 547 |
print(f"[WARNING] Could not start environment server: {e}", file=sys.stderr)
|
| 548 |
+
|
| 549 |
+
return None
|
| 550 |
|
| 551 |
|
| 552 |
+
# ── Main ─────────────────────────────────────────────────────────────────────
|
| 553 |
def main() -> None:
|
| 554 |
+
parser = argparse.ArgumentParser(description="GridMind-RL inference script")
|
| 555 |
parser.add_argument("--episodes", type=int, default=DEFAULT_EPISODES)
|
| 556 |
parser.add_argument("--env-url", type=str, default=ENV_URL)
|
| 557 |
parser.add_argument("--verbose", action="store_true")
|
|
|
|
| 559 |
parser.add_argument(
|
| 560 |
"--fast-mode",
|
| 561 |
action="store_true",
|
| 562 |
+
help="Heuristic policy only (no LLM calls).",
|
| 563 |
)
|
| 564 |
parser.add_argument(
|
| 565 |
"--llm-every",
|
| 566 |
type=int,
|
| 567 |
default=8,
|
| 568 |
metavar="N",
|
| 569 |
+
help="Reuse the same LLM action for N steps (default: 8).",
|
| 570 |
)
|
| 571 |
parser.add_argument(
|
| 572 |
"--max-steps",
|
| 573 |
type=int,
|
| 574 |
default=None,
|
| 575 |
metavar="N",
|
| 576 |
+
help="Stop after N steps.",
|
| 577 |
)
|
| 578 |
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
|
|
|
|
| 580 |
server_proc = start_environment_server(port=7860)
|
| 581 |
+
|
| 582 |
try:
|
| 583 |
env_client = GridMindEnvClient(base_url=args.env_url)
|
| 584 |
|
|
|
|
| 609 |
)
|
| 610 |
task_scores.append(float(result["score"]))
|
| 611 |
all_results.append(result)
|
|
|
|
| 612 |
|
| 613 |
task_avgs: dict[int, float] = {}
|
| 614 |
for task_id in [1, 2, 3]:
|
| 615 |
scores = [float(r["score"]) for r in all_results if r["task_id"] == task_id]
|
| 616 |
avg = clamp_open_score(sum(scores) / len(scores)) if scores else SCORE_EPSILON
|
| 617 |
task_avgs[task_id] = avg
|
| 618 |
+
|
| 619 |
overall = clamp_open_score(sum(task_avgs.values()) / len(task_avgs))
|
| 620 |
|
| 621 |
output = {
|
|
|
|
| 632 |
}
|
| 633 |
with open(args.output, "w", encoding="utf-8") as f:
|
| 634 |
json.dump(output, f, indent=2)
|
| 635 |
+
|
| 636 |
finally:
|
|
|
|
| 637 |
if server_proc:
|
| 638 |
try:
|
| 639 |
server_proc.terminate()
|
| 640 |
server_proc.wait(timeout=5)
|
| 641 |
+
except Exception:
|
|
|
|
| 642 |
try:
|
| 643 |
server_proc.kill()
|
| 644 |
except Exception:
|