Spaces:
Sleeping
Sleeping
| """ | |
| GridMind-RL Baseline Inference Script | |
| -------------------------------------- | |
| Runs an LLM agent against all 3 tasks for N episodes each. | |
| Uses OpenAI-compatible API via API_BASE_URL / MODEL_NAME / HF_TOKEN environment variables. | |
| Usage: | |
| export MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct | |
| export HF_TOKEN=hf_xxxx | |
| python inference.py | |
| # or: python python/inference.py [--episodes 1] [--llm-every 4] [--fast-mode] | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| import time | |
| from typing import Any | |
| import requests | |
| from openai import OpenAI | |
| # ── Constants ────────────────────────────────────────────────────────────── | |
# Endpoints and credentials are taken from the process environment, falling
# back to local / Hugging Face router defaults.
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
HF_TOKEN = os.environ.get("HF_TOKEN", "")

DEFAULT_EPISODES = 1
DEFAULT_SEED_BASE = 1000
MAX_RETRIES = 3

# 96 steps × 15 min = 24 h (must match env.EpisodeSteps)
EPISODE_STEPS = 96
LAST_STEP_INDEX = EPISODE_STEPS - 1

# System message sent with every LLM request.
SYSPROMPT = """You are GridMind, an expert industrial energy management controller.
You control a building's HVAC, thermal storage, batch job scheduling, and load shedding.
Your goal is to minimize electricity costs while maintaining comfort and meeting grid demand-response signals.
Always respond with a single valid JSON object matching the action schema. No explanation needed."""

# One human-readable objective per task id; used as the head of the user prompt.
TASK_DESCRIPTIONS = {
    1: (
        "Task 1 (Easy - Cost Minimization): Minimize total energy cost over 24 hours. "
        "No temperature or batch constraints. "
        "Use cheap off-peak periods and thermal storage."
    ),
    2: (
        "Task 2 (Medium - Temperature Management): Minimize cost AND keep indoor "
        "temperature within 19-23°C at all times. Balance comfort vs cost."
    ),
    3: (
        "Task 3 (Hard - Full Demand Response): Minimize cost, maintain temperature, "
        "respond to grid stress (shed when grid_stress_signal > 0.7), "
        "schedule batch jobs, minimize carbon."
    ),
}

# Action template appended verbatim to the end of the user prompt.
ACTION_SCHEMA_STR = """{
"hvac_power_level": <float 0.0-1.0>,
"thermal_charge_rate": <float -1.0 to 1.0>,
"batch_job_slot": <int 0-4>,
"load_shed_fraction": <float 0.0-0.5>,
"building_id": 0
}"""
def extract_json_object(text: str) -> dict[str, Any] | None:
    """Return the first JSON object embedded in *text*, or ``None``.

    Every ``{`` in *text* is tried, left to right, as the possible start of a
    JSON object; the first one from which a complete object can be decoded
    wins. Decoding is done by the real JSON parser rather than by counting
    braces, so braces that appear inside string values do not truncate the
    object, and a stray ``{...}`` fragment that is not JSON (for example in
    an explanatory preamble) does not hide a valid object that follows it.

    Args:
        text: Arbitrary text, typically a raw LLM completion.

    Returns:
        The decoded object as a dict, or ``None`` when no position in *text*
        starts a valid JSON object.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start >= 0:
        try:
            # raw_decode tolerates trailing text after the object, which is
            # exactly what a chatty completion looks like.
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            # A document that starts with "{" always decodes to a dict.
            return obj
    return None
| # ── Environment client ─────────────────────────────────────────────────────── | |
class GridMindEnvClient:
    """Thin HTTP wrapper around the GridMind-RL environment server endpoints."""

    def __init__(self, base_url: str = ENV_URL, timeout: int = 30):
        # Strip trailing slashes so endpoint paths can be appended verbatim.
        self.base = base_url.rstrip("/")
        self.timeout = timeout

    @staticmethod
    def _checked_json(response):
        """Raise on an HTTP error status, otherwise return the decoded JSON body."""
        response.raise_for_status()
        return response.json()

    def health(self) -> bool:
        """Return True when GET /health answers with HTTP 200, False otherwise.

        Any exception raised while making the request is reported as False.
        Uses a fixed 5 second timeout instead of ``self.timeout``.
        """
        try:
            response = requests.get(f"{self.base}/health", timeout=5)
        except Exception:
            return False
        return response.status_code == 200

    def reset(self, task_id: int = 1, seed: int = 42, num_buildings: int = 1) -> dict:
        """POST /reset with the task, seed and building count; return the JSON reply."""
        body = {"task_id": task_id, "seed": seed, "num_buildings": num_buildings}
        return self._checked_json(
            requests.post(f"{self.base}/reset", json=body, timeout=self.timeout)
        )

    def step(self, action: dict) -> dict:
        """POST /step with *action* as the JSON body; return the JSON reply."""
        return self._checked_json(
            requests.post(f"{self.base}/step", json=action, timeout=self.timeout)
        )

    def grade(self) -> dict:
        """GET /grade and return the JSON reply."""
        return self._checked_json(
            requests.get(f"{self.base}/grade", timeout=self.timeout)
        )

    def state(self) -> dict:
        """GET /state and return the JSON reply."""
        return self._checked_json(
            requests.get(f"{self.base}/state", timeout=self.timeout)
        )
| # ── LLM agent ─────────────────────────────────────────────────────────────── | |
class LLMAgent:
    """OpenAI-compatible LLM agent that chooses actions given observations.

    The agent asks the chat-completions endpoint for a JSON action and
    sanitizes the reply. When the LLM cannot be used (repeated errors, or an
    error that mentions "402"/"depleted") it answers with a deterministic
    rule-based policy instead.
    """

    def __init__(self):
        self.client = OpenAI(
            base_url=API_BASE_URL,
            # Fall back to a placeholder key when HF_TOKEN is unset or empty.
            api_key=HF_TOKEN if HF_TOKEN else "none",
        )
        self.model = MODEL_NAME
        # Once set, every later choose_action() call skips the LLM entirely.
        self.fallback_mode = False

    def choose_action(self, obs: dict, task_id: int) -> dict:
        """Prompt the LLM with the current observation and return an action.

        Args:
            obs: Observation dict for the current step; missing keys are
                replaced by neutral defaults when building the prompt.
            task_id: Task identifier; unknown ids use the task 1 description.

        Returns:
            A sanitized action dict (see ``_clamp_action``), or the heuristic
            action when the LLM is unavailable or every attempt fails.
        """
        if self.fallback_mode:
            return self._heuristic_action(obs)
        task_desc = TASK_DESCRIPTIONS.get(task_id, TASK_DESCRIPTIONS[1])
        prompt = f"""{task_desc}
Current observation:
- Indoor temperature: {obs.get('indoor_temperature', 21):.1f}°C (target: 21°C, bounds: 19-23°C)
- Thermal storage level: {obs.get('thermal_storage_level', 0.5):.2f} (0=empty, 1=full)
- Process demand: {obs.get('process_demand', 15):.1f} kW
- Current electricity price: ${obs.get('current_price', 0.10):.4f}/kWh
- Grid stress signal: {obs.get('grid_stress_signal', 0):.3f} (>0.7 = critical, shed load!)
- Carbon intensity: {obs.get('carbon_intensity', 300):.0f} gCO2/kWh
- Hour of day: {obs.get('hour_of_day', 12)} (0=midnight, peak prices 8-12 and 17-21)
- Pending batch job deadlines: {obs.get('batch_queue', [])}
- Cumulative cost so far: ${obs.get('cumulative_cost', 0):.4f}
- Episode step: {obs.get('step', 0)}/{LAST_STEP_INDEX}
Strategy hints:
- Charge thermal storage when price < $0.08/kWh, discharge when price > $0.15/kWh
- Set HVAC low during peak prices (0.3-0.4) and use storage for temperature control
- Shed 30-50% load if grid_stress_signal > 0.7
- Schedule batch jobs early if deadline is close (slot 0 or 1)
Respond with ONLY a JSON action:
{ACTION_SCHEMA_STR}"""
        for attempt in range(MAX_RETRIES):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": SYSPROMPT},
                        {"role": "user", "content": prompt},
                    ],
                    max_tokens=128,
                    temperature=0.0,
                )
                # The message content can be None; treat that as an empty reply
                # so it fails as a JSON error rather than an AttributeError.
                content = (completion.choices[0].message.content or "").strip()
                parsed = extract_json_object(content)
                if parsed is None:
                    # No embedded object found: let json.loads either accept the
                    # bare content or raise a descriptive decode error.
                    parsed = json.loads(content)
                return self._clamp_action(parsed)
            except Exception as e:
                err_str = str(e)
                print(f" [LLM attempt {attempt+1}/{MAX_RETRIES}] error: {err_str}")
                if "402" in err_str or "depleted" in err_str:
                    print(" [WARN] Hugging Face free credits depleted! Switching to local heuristic agent for the rest of the simulation.")
                    self.fallback_mode = True
                    return self._heuristic_action(obs)
                # Back off before retrying, but do not delay the fallback
                # after the final attempt.
                if attempt < MAX_RETRIES - 1:
                    time.sleep(1)
        return self._heuristic_action(obs)

    @staticmethod
    def _coerce_bounded(raw, cast, default, low=None, high=None):
        """Convert *raw* with *cast* and clamp it, or return *default*.

        *default* is returned when the conversion fails (``None``, text that
        is not a number, infinity passed to ``int``) or yields NaN. Bounds set
        to ``None`` are not applied.
        """
        try:
            value = cast(raw)
        except (TypeError, ValueError, OverflowError):
            return default
        if value != value:  # NaN is the only value unequal to itself
            return default
        if high is not None:
            value = min(high, value)
        if low is not None:
            value = max(low, value)
        return value

    def _clamp_action(self, action: dict) -> dict:
        """Return *action* reduced to the known fields, typed and range-limited.

        Each field is sanitized independently: a missing, null, non-numeric or
        NaN value is replaced by that field's default instead of invalidating
        the whole action. ``building_id`` is converted to ``int`` but not
        range-limited.

        Raises:
            TypeError: If *action* is not a dict.
        """
        if not isinstance(action, dict):
            raise TypeError(
                f"LLM action must be a JSON object, got {type(action).__name__}"
            )
        coerce = self._coerce_bounded
        return {
            "hvac_power_level": coerce(action.get("hvac_power_level"), float, 0.5, 0.0, 1.0),
            "thermal_charge_rate": coerce(action.get("thermal_charge_rate"), float, 0.0, -1.0, 1.0),
            "batch_job_slot": coerce(action.get("batch_job_slot"), int, 0, 0, 4),
            "load_shed_fraction": coerce(action.get("load_shed_fraction"), float, 0.0, 0.0, 0.5),
            "building_id": coerce(action.get("building_id"), int, 0),
        }

    def _heuristic_action(self, obs: dict) -> dict:
        """Rule-based policy (deterministic given obs).

        Derives each action field from simple thresholds on price, grid
        stress, indoor temperature, storage level and the batch queue.
        """
        price = obs.get("current_price", 0.10)
        stress = obs.get("grid_stress_signal", 0.0)
        temp = obs.get("indoor_temperature", 21.0)
        storage = obs.get("thermal_storage_level", 0.5)
        queue = obs.get("batch_queue", [])

        # Base HVAC level follows the price band.
        hvac = 0.7 if price < 0.08 else (0.3 if price > 0.15 else 0.5)
        # Outside 19-23°C the temperature overrides the price-based level.
        # NOTE(review): raising power when hot and capping it when cold implies
        # hvac_power_level acts as cooling — confirm against the environment.
        if temp > 23.0:
            hvac = max(hvac, 0.8)
        elif temp < 19.0:
            hvac = min(hvac, 0.2)

        # Charge storage when cheap and not nearly full; discharge when
        # expensive and not nearly empty.
        charge = 0.0
        if price < 0.07 and storage < 0.8:
            charge = 0.5
        elif price > 0.15 and storage > 0.3:
            charge = -0.5

        shed = 0.0
        if stress > 0.7:
            shed = 0.4
        elif stress > 0.5:
            shed = 0.2

        # Use slot 0 when the smallest queued value is below 8, else slot 2.
        slot = 2
        if queue and min(queue) < 8:
            slot = 0

        return {
            "hvac_power_level": hvac,
            "thermal_charge_rate": charge,
            "batch_job_slot": slot,
            "load_shed_fraction": shed,
            "building_id": 0,
        }

    def _default_action(self) -> dict:
        """Return a fixed neutral action: half HVAC power, nothing else active."""
        return {
            "hvac_power_level": 0.5,
            "thermal_charge_rate": 0.0,
            "batch_job_slot": 0,
            "load_shed_fraction": 0.0,
            "building_id": 0,
        }
| # ── Episode runner ─────────────────────────────────────────────────────────── | |
def run_episode(
    env_client: GridMindEnvClient,
    agent: LLMAgent,
    task_id: int,
    seed: int,
    *,
    fast_mode: bool,
    llm_every: int,
    max_steps: int | None,
    verbose: bool = False,
) -> dict[str, Any]:
    """Play one episode and return its grade together with run metadata.

    Progress markers ``[START]``, ``[STEPn]`` and ``[END]`` are printed as the
    episode advances. The episode ends when the environment reports ``done``,
    when the step limit is reached, or when a step response is unusable.

    Args:
        env_client: Client used to reset, step and grade the environment.
        agent: Source of actions (LLM-backed or heuristic).
        task_id: Task to run.
        seed: Seed passed to the environment reset.
        fast_mode: Use only the agent's heuristic policy (no LLM calls).
        llm_every: Number of consecutive steps that share one LLM action;
            values below 1 are treated as 1.
        max_steps: Optional cap on steps, never exceeding ``EPISODE_STEPS``.
        verbose: Print a status line every 16 steps.

    Returns:
        Dict with task id, seed, accumulated reward, step count, elapsed
        seconds and the score fields read from the grade response.
    """
    obs = env_client.reset(task_id=task_id, seed=seed)["observations"][0]
    print("[START]", flush=True)

    reward_sum = 0.0
    steps_done = 0
    started_at = time.time()
    response: dict[str, Any] = {}
    if max_steps is None:
        step_limit = EPISODE_STEPS
    else:
        step_limit = min(max_steps, EPISODE_STEPS)

    # The held action is replayed until its reuse budget reaches zero.
    reuse_budget = 0
    held_action = agent._default_action()

    while not response.get("done", False) and steps_done < step_limit:
        if fast_mode:
            action = agent._heuristic_action(obs)
        else:
            if reuse_budget <= 0:
                held_action = agent.choose_action(obs, task_id)
                reuse_budget = max(1, llm_every)
            action = held_action

        response = env_client.step(action)
        if response is None or "observation" not in response:
            print(f" [WARN] step {steps_done}: invalid step response", flush=True)
            break

        if not fast_mode:
            reuse_budget -= 1
        obs = response["observation"]
        reward_sum += float(response["reward"])
        steps_done += 1
        print(f"[STEP{steps_done}]", flush=True)

        if verbose and steps_done % 16 == 0:
            status_line = (
                f" step={steps_done:02d} price=${obs['current_price']:.3f} "
                f"temp={obs['indoor_temperature']:.1f}°C "
                f"stress={obs['grid_stress_signal']:.2f} "
                f"cost=${obs['cumulative_cost']:.2f} "
                f"reward={response['reward']:.3f}"
            )
            print(status_line, flush=True)

    elapsed = time.time() - started_at
    grade = env_client.grade()
    print("[END]", flush=True)

    summary: dict[str, Any] = {
        "task_id": task_id,
        "seed": seed,
        "total_reward": reward_sum,
        "total_steps": steps_done,
        "elapsed_sec": elapsed,
    }
    # Missing grade fields fall back to neutral values.
    summary["score"] = grade.get("score", 0.0)
    summary["sub_scores"] = grade.get("sub_scores", {})
    summary["exploit_detected"] = grade.get("exploit_detected", False)
    return summary
| # ── Main ───────────────────────────────────────────────────────────────────── | |
def main(argv: list[str] | None = None) -> None:
    """Run the baseline over tasks 1-3 and write the scores to a JSON file.

    Parses command-line options, waits for the environment server to report
    healthy, runs the requested number of episodes per task, prints a summary
    table and saves all results to the ``--output`` path.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read the process command line.

    Exits with status 1 when the environment server never becomes healthy,
    and with an argparse usage error when ``--episodes`` is below 1.
    """
    parser = argparse.ArgumentParser(description="GridMind-RL baseline inference")
    parser.add_argument("--episodes", type=int, default=DEFAULT_EPISODES)
    parser.add_argument("--env-url", type=str, default=ENV_URL)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--output", type=str, default="baseline_scores.json")
    parser.add_argument(
        "--fast-mode",
        action="store_true",
        help="Heuristic policy only (no LLM calls; fastest, fully reproducible).",
    )
    parser.add_argument(
        "--llm-every",
        type=int,
        default=4,
        metavar="N",
        help="Reuse the same LLM action for N consecutive steps (default: 4).",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=None,
        metavar="N",
        help="Stop after N steps (default: full episode). Grade uses partial episode.",
    )
    args = parser.parse_args(argv)
    # With no episodes the per-task average would divide by zero, so reject
    # the value up front with a usage error.
    if args.episodes < 1:
        parser.error("--episodes must be at least 1")

    print("=" * 60)
    print("GridMind-RL Baseline Inference")
    print(f" Model: {MODEL_NAME}")
    print(f" API: {API_BASE_URL}")
    print(f" Env: {args.env_url}")
    print(f" Episodes per task: {args.episodes}")
    print(f" Fast mode: {args.fast_mode} | LLM every: {args.llm_every} steps")
    print("=" * 60)

    env_client = GridMindEnvClient(base_url=args.env_url)
    print("\nWaiting for environment server...")
    # Poll the health endpoint up to 30 times, 2 seconds apart.
    for _ in range(30):
        if env_client.health():
            print(" [OK] Environment server is healthy")
            break
        time.sleep(2)
    else:
        print(" [FAIL] Environment server not reachable. Exiting.")
        sys.exit(1)

    agent = LLMAgent()
    task_ids = (1, 2, 3)
    all_results: list[dict[str, Any]] = []
    for task_id in task_ids:
        print(f"\n-- Task {task_id}: {TASK_DESCRIPTIONS[task_id][:60]}...")
        task_scores: list[float] = []
        for ep in range(args.episodes):
            # Distinct seed per (task, episode) pair.
            seed = DEFAULT_SEED_BASE + task_id * 100 + ep
            print(f" Episode {ep+1}/{args.episodes} (seed={seed})")
            result = run_episode(
                env_client,
                agent,
                task_id=task_id,
                seed=seed,
                fast_mode=args.fast_mode,
                llm_every=args.llm_every,
                max_steps=args.max_steps,
                verbose=args.verbose,
            )
            task_scores.append(float(result["score"]))
            all_results.append(result)
            print(
                f" → score={result['score']:.4f} | reward={result['total_reward']:.3f} | "
                f"{result['elapsed_sec']:.1f}s | steps={result['total_steps']}"
            )
        avg_score = sum(task_scores) / len(task_scores)
        print(f" Task {task_id} average score: {avg_score:.4f}")

    print("\n" + "=" * 60)
    print("BASELINE SCORES SUMMARY")
    print("=" * 60)
    print(f"{'Task':<10} {'Model':<30} {'Score':<10} {'Episodes':<10}")
    print("-" * 60)
    task_avgs: dict[int, float] = {}
    for task_id in task_ids:
        scores = [float(r["score"]) for r in all_results if r["task_id"] == task_id]
        avg = sum(scores) / len(scores) if scores else 0.0
        task_avgs[task_id] = avg
        print(f"Task {task_id:<6} {MODEL_NAME:<30} {avg:<10.4f} {len(scores)}")
    print("-" * 60)
    # Unweighted mean of the per-task averages.
    overall = sum(task_avgs.values()) / len(task_avgs)
    print(f"{'Overall':<10} {'':<30} {overall:<10.4f}")

    output = {
        "model": MODEL_NAME,
        "api_base": API_BASE_URL,
        "episodes_per_task": args.episodes,
        "seed_base": DEFAULT_SEED_BASE,
        "fast_mode": args.fast_mode,
        "llm_every": args.llm_every,
        "max_steps": args.max_steps,
        # JSON object keys must be strings.
        "task_averages": {str(k): v for k, v in task_avgs.items()},
        "overall_average": overall,
        "all_results": all_results,
    }
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\n[OK] Results saved to {args.output}")


if __name__ == "__main__":
    main()