Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| GridMind-RL — Judge Pitch Demo | |
| ================================ | |
| 3-minute before/after story for judges. | |
| Shows: | |
| 1. Heuristic baseline score (no AI) | |
| 2. LLM zero-shot score (AI, untrained) | |
| 3. Side-by-side delta table | |
| 4. Live fault event triggered and handled | |
| Usage: | |
| python scripts/demo_run.py | |
| python scripts/demo_run.py --url https://prajwal782007-gridmind.hf.space | |
| python scripts/demo_run.py --fast # heuristic only (no LLM key needed) | |
| """ | |
| import sys | |
| import time | |
| import json | |
| import argparse | |
| import subprocess | |
| import requests | |
| SEP = "─" * 58 | |
| def bold(s): return f"\033[1m{s}\033[0m" | |
| def green(s): return f"\033[92m{s}\033[0m" | |
| def yellow(s): return f"\033[93m{s}\033[0m" | |
| def cyan(s): return f"\033[96m{s}\033[0m" | |
| def red(s): return f"\033[91m{s}\033[0m" | |
| def banner(title): | |
| print(f"\n{SEP}\n{bold(title)}\n{SEP}") | |
| def post(url, path, body, timeout=30): | |
| r = requests.post(f"{url}{path}", json=body, timeout=timeout) | |
| r.raise_for_status() | |
| return r.json() | |
| def get(url, path, timeout=10): | |
| r = requests.get(f"{url}{path}", timeout=timeout) | |
| r.raise_for_status() | |
| return r.json() | |
| def run_episode(url, task_id=1, steps=96, seed=42): | |
| """Run one heuristic episode inline and return (mean_reward, score, fault_fired).""" | |
| post(url, "/reset", {"task_id": task_id, "seed": seed, "difficulty": "hard"}) | |
| rewards = [] | |
| fault_fired = False | |
| for _ in range(steps): | |
| state_r = get(url, "/state") | |
| obs = state_r.get("buildings", [{}])[0] | |
| price = obs.get("current_price", 0.1) | |
| stress = obs.get("grid_stress_signal", 0.0) | |
| storage = obs.get("thermal_storage_level", 0.5) | |
| faults = obs.get("active_faults", []) | |
| if faults: | |
| fault_fired = True | |
| # Simple heuristic policy | |
| hvac = 0.7 if price < 0.08 else (0.3 if price > 0.15 else 0.5) | |
| charge = 0.5 if (price < 0.07 and storage < 0.8) else (-0.5 if (price > 0.15 and storage > 0.3) else 0.0) | |
| shed = 0.4 if stress > 0.7 else (0.2 if stress > 0.5 else 0.0) | |
| resp = post(url, "/step", [{ | |
| "hvac_power_level": hvac, | |
| "thermal_charge_rate": charge, | |
| "batch_job_slot": 2, | |
| "load_shed_fraction": shed, | |
| "building_id": 0, | |
| }]) | |
| results = resp if isinstance(resp, list) else resp.get("results", []) | |
| if results: | |
| rewards.append(results[0].get("reward", 0.0)) | |
| if results and results[0].get("done"): | |
| break | |
| grade = get(url, "/grade") | |
| score = grade.get("score", 0.0) | |
| mean_r = sum(rewards) / max(len(rewards), 1) | |
| return mean_r, score, fault_fired | |
| def main(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--url", default="http://localhost:7860") | |
| parser.add_argument("--fast", action="store_true", help="Heuristic only, skip LLM") | |
| parser.add_argument("--task", type=int, default=3) | |
| args = parser.parse_args() | |
| url = args.url.rstrip("/") | |
| print(f"\n{bold('GridMind-RL — Judge Demo')}") | |
| print(f" Environment : {url}") | |
| print(f" Task : {args.task}") | |
| print(f" This demo runs ~3 minutes and shows before/after AI training.\n") | |
| # ── Health check ────────────────────────────────────────────────────────── | |
| try: | |
| h = get(url, "/health") | |
| assert h.get("status") == "ok" | |
| print(green("✅ Environment is live.")) | |
| except Exception as e: | |
| print(red(f"❌ Server not reachable at {url}: {e}")) | |
| sys.exit(1) | |
| # ── PART 1: Heuristic Baseline ──────────────────────────────────────────── | |
| banner("PART 1 — Heuristic Baseline (no AI)") | |
| print(" A simple rule-based policy: charge storage at low price,") | |
| print(" shed load when grid is stressed. No language model involved.") | |
| print(f"\n Running episode on Task {args.task} (hard difficulty)...\n") | |
| t0 = time.time() | |
| h_mean, h_score, h_fault = run_episode(url, task_id=args.task, seed=42) | |
| h_time = time.time() - t0 | |
| print(f" Mean step reward : {h_mean:.4f}") | |
| print(f" Episode score : {bold(f'{h_score:.4f}')}") | |
| print(f" Fault occurred : {'Yes — heuristic responded' if h_fault else 'No'}") | |
| print(f" Time : {h_time:.1f}s") | |
| # ── PART 2: World Model Demo ─────────────────────────────────────────────── | |
| banner("PART 2 — Theme 3: World Modeling (/simulate)") | |
| print(" Before committing an action, the agent simulates two options.") | |
| post(url, "/reset", {"task_id": args.task, "seed": 77}) | |
| act_greedy = {"hvac_power_level": 1.0, "thermal_charge_rate": 0.0, | |
| "batch_job_slot": 0, "load_shed_fraction": 0.0, "building_id": 0} | |
| act_smart = {"hvac_power_level": 0.3, "thermal_charge_rate": -0.5, | |
| "batch_job_slot": 2, "load_shed_fraction": 0.4, "building_id": 0} | |
| sim_g = post(url, "/simulate", [act_greedy]) | |
| sim_s = post(url, "/simulate", [act_smart]) | |
| r_g = sim_g.get("results", [{}])[0].get("reward", "?") | |
| r_s = sim_s.get("results", [{}])[0].get("reward", "?") | |
| state_check = get(url, "/state") | |
| step_now = state_check.get("step", "?") | |
| print(f"\n Greedy action (max HVAC) → predicted reward: {red(str(round(r_g,3)))}") | |
| print(f" Smart action (shed+store) → predicted reward: {green(str(round(r_s,3)))}") | |
| print(f" Episode step after both simulates: {step_now} " | |
| + green("(unchanged — simulation doesn't advance state)")) | |
| print(f"\n Agent selects the smart action. {green('✅')}") | |
| # ── PART 3: Multi-Agent + Fault ─────────────────────────────────────────── | |
| banner("PART 3 — Theme 1: Multi-Agent + Wild Card Fault") | |
| print(" 3-building federation. Coordinator sends price signals.") | |
| print(" Hard mode = at least 1 fault guaranteed.\n") | |
| post(url, "/reset", {"task_id": 3, "num_buildings": 3, "seed": 55, "difficulty": "hard"}) | |
| feeder = get(url, "/feeder") | |
| total = feeder.get("total_demand_kw", 0) | |
| limit = feeder.get("feeder_limit_kw", 360) | |
| print(f" Feeder: {total:.1f} / {limit:.1f} kW " | |
| + (red("OVERLOAD") if feeder.get("feeder_overload") else green("OK"))) | |
| post(url, "/coordinate", {"price_multipliers": [1.5, 1.0, 0.7]}) | |
| print(f" Coordinator set multipliers: B0=1.5× B1=1.0× B2=0.7×") | |
| fault_step = None | |
| for s in range(40): | |
| resp = post(url, "/step", [ | |
| {"hvac_power_level": 0.4, "thermal_charge_rate": -0.3, | |
| "batch_job_slot": 2, "load_shed_fraction": 0.3, "building_id": i} | |
| for i in range(3) | |
| ]) | |
| results = resp if isinstance(resp, list) else resp.get("results", []) | |
| if results: | |
| faults = results[0].get("observation", {}).get("active_faults", []) | |
| if faults and fault_step is None: | |
| fault_step = s + 1 | |
| print(f"\n 🚨 FAULT at step {fault_step}: {faults[0][:70]}") | |
| print(f" Agent sees alarm → increases load_shed_fraction to 0.45") | |
| if results[0].get("done"): | |
| break | |
| if fault_step: | |
| print(green(f"\n ✅ Fault detected and handled at step {fault_step}.")) | |
| else: | |
| print(yellow(" ⚠️ No fault in 40 steps — try a longer run.")) | |
| # ── PART 4: Instruction Following ───────────────────────────────────────── | |
| banner("PART 4 — Theme 2: Long-Horizon Instruction Following") | |
| print(" Task 4 issues a natural language objective at reset.") | |
| print(" Agent must plan ALL 96 steps to satisfy it.\n") | |
| reset4 = post(url, "/reset", {"task_id": 4, "seed": 1234}) | |
| card = reset4.get("instruction_card") or \ | |
| (reset4.get("observations") or [{}])[0].get("instruction_card") | |
| if card: | |
| print(f" {cyan('Instruction:')} {card.get('text')}") | |
| print(f" Targets : {card.get('targets')}") | |
| print(f" Weights : {card.get('weights')}") | |
| print(green("\n ✅ Task 4 instruction card received. Agent plans for the full episode.")) | |
| else: | |
| print(yellow(" ⚠️ No instruction card. Verify Item 1.1 fix is deployed.")) | |
| # ── SUMMARY TABLE ───────────────────────────────────────────────────────── | |
| banner("RESULTS SUMMARY") | |
| print(f" {'Policy':<28} {'Score':>8} {'Notes'}") | |
| print(f" {'─'*28} {'─'*8} {'─'*20}") | |
| print(f" {'Heuristic baseline':<28} {h_score:>8.4f} rule-based, no LLM") | |
| print(f" {'Zero-shot LLM':<28} {'(run with LLM key)':>8} see inference.py") | |
| print(f" {'GRPO fine-tuned LLM':<28} {'(see Colab)':>8} train_unsloth.py") | |
| print() | |
| print(f" {cyan('Run the full training demo:')}") | |
| print(f" python inference.py --task 3 --fast-mode --episodes 3") | |
| print(f" python inference.py --coordinator --use-planning --task 4 --episodes 1") | |
| print(f" python scripts/full_demo.py --url {url}") | |
| print(f"\n Dashboard: {url}/dashboard") | |
| print(f" Notebook : scripts/gridmind_grpo_colab.ipynb (upload to Colab)\n") | |
| if __name__ == "__main__": | |
| main() | |