File size: 9,646 Bytes
c395f6a
 
 
 
 
 
 
 
 
 
 
 
 
 
a4bc605
c395f6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
#!/usr/bin/env python3
"""
GridMind-RL — Judge Pitch Demo
================================
3-minute before/after story for judges.

Shows:
  1. Heuristic baseline score (no AI)
  2. LLM zero-shot score  (AI, untrained)
  3. Side-by-side delta table
  4. Live fault event triggered and handled

Usage:
    python scripts/demo_run.py
    python scripts/demo_run.py --url https://prajwal782007-gridmind.hf.space
    python scripts/demo_run.py --fast          # heuristic only (no LLM key needed)
"""

import sys
import time
import json
import argparse
import subprocess
import requests

# Horizontal rule (58 box-drawing dashes) that banner() prints above and below a section title.
SEP = "─" * 58

def bold(s):
    """Return *s* wrapped in the ANSI bold code followed by a reset."""
    start, reset = "\033[1m", "\033[0m"
    return f"{start}{s}{reset}"
def green(s):
    """Return *s* wrapped in the ANSI bright-green code followed by a reset."""
    start, reset = "\033[92m", "\033[0m"
    return f"{start}{s}{reset}"
def yellow(s):
    """Return *s* wrapped in the ANSI bright-yellow code followed by a reset."""
    start, reset = "\033[93m", "\033[0m"
    return f"{start}{s}{reset}"
def cyan(s):
    """Return *s* wrapped in the ANSI bright-cyan code followed by a reset."""
    start, reset = "\033[96m", "\033[0m"
    return f"{start}{s}{reset}"
def red(s):
    """Return *s* wrapped in the ANSI bright-red code followed by a reset."""
    start, reset = "\033[91m", "\033[0m"
    return f"{start}{s}{reset}"

def banner(title):
    """Print a blank line, then *title* in bold framed by SEP above and below."""
    # The leading empty element yields the blank line before the first rule.
    lines = ("", SEP, bold(title), SEP)
    print("\n".join(lines))

def post(url, path, body, timeout=30):
    """POST *body* as JSON to ``url`` + ``path`` and return the decoded JSON reply.

    ``raise_for_status()`` is called on the response before decoding, so an
    HTTP error status surfaces as an exception instead of a return value.
    """
    endpoint = f"{url}{path}"
    response = requests.post(endpoint, json=body, timeout=timeout)
    response.raise_for_status()
    payload = response.json()
    return payload

def get(url, path, timeout=10):
    """GET ``url`` + ``path`` and return the decoded JSON reply.

    ``raise_for_status()`` is called on the response before decoding, so an
    HTTP error status surfaces as an exception instead of a return value.
    """
    endpoint = f"{url}{path}"
    response = requests.get(endpoint, timeout=timeout)
    response.raise_for_status()
    payload = response.json()
    return payload

def run_episode(url, task_id=1, steps=96, seed=42):
    """Run one heuristic episode inline and return (mean_reward, score, fault_fired).

    The environment at ``url`` is reset with ``task_id`` and ``seed`` (the
    difficulty sent is always ``"hard"``).  Each iteration polls ``/state``,
    derives an action for building 0 from the first building's readings and
    posts it to ``/step``.  The loop stops after ``steps`` iterations or as
    soon as the first step result reports ``done``; ``/grade`` is then read.

    Args:
        url: Base URL of the environment server; endpoint paths such as
            ``/reset`` are appended to it directly.
        task_id: Task identifier forwarded to ``/reset``.
        steps: Maximum number of environment steps to take.
        seed: Seed forwarded to ``/reset``.

    Returns:
        A tuple ``(mean_reward, score, fault_fired)``:
        the mean of the collected per-step rewards (0.0 when none were
        collected), the ``score`` field of the ``/grade`` reply (0.0 when
        absent), and whether any polled state listed active faults.

    Raises:
        Whatever ``post``/``get`` raise (HTTP error statuses, timeouts,
        connection failures).
    """
    post(url, "/reset", {"task_id": task_id, "seed": seed, "difficulty": "hard"})
    rewards = []
    fault_fired = False

    for _ in range(steps):
        state = get(url, "/state")
        # A missing, null or empty "buildings" list falls back to the
        # heuristic defaults below instead of raising IndexError/TypeError.
        buildings = state.get("buildings") or [{}]
        obs = buildings[0] or {}
        price = obs.get("current_price", 0.1)
        stress = obs.get("grid_stress_signal", 0.0)
        storage = obs.get("thermal_storage_level", 0.5)

        if obs.get("active_faults"):
            fault_fired = True

        # Simple heuristic policy: run HVAC harder when power is cheap,
        # charge storage at low prices and discharge at high prices,
        # and shed load in proportion to grid stress.
        hvac = 0.7 if price < 0.08 else (0.3 if price > 0.15 else 0.5)
        if price < 0.07 and storage < 0.8:
            charge = 0.5
        elif price > 0.15 and storage > 0.3:
            charge = -0.5
        else:
            charge = 0.0
        shed = 0.4 if stress > 0.7 else (0.2 if stress > 0.5 else 0.0)

        resp = post(url, "/step", [{
            "hvac_power_level": hvac,
            "thermal_charge_rate": charge,
            "batch_job_slot": 2,
            "load_shed_fraction": shed,
            "building_id": 0,
        }])
        # The reply is accepted either as a bare list of results or as a
        # dict carrying them under "results".
        results = resp if isinstance(resp, list) else resp.get("results", [])
        if not results:
            continue
        first = results[0]
        rewards.append(first.get("reward", 0.0))
        if first.get("done"):
            break

    grade = get(url, "/grade")
    score = grade.get("score", 0.0)
    mean_r = sum(rewards) / max(len(rewards), 1)
    return mean_r, score, fault_fired

def _step_results(resp):
    """Return the result list from a /step or /simulate reply (bare list or {"results": [...]})."""
    if isinstance(resp, list):
        return resp
    return resp.get("results") or []


def _first_reward(resp):
    """Return the first result's "reward" from a reply, or None when there is none."""
    results = _step_results(resp)
    return results[0].get("reward") if results else None


def _format_reward(value):
    """Format a reward rounded to 3 decimals, or "?" when it is missing or non-numeric."""
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return str(round(value, 3))
    return "?"


def _check_health(url):
    """Exit with status 1 unless GET /health answers with status "ok"."""
    try:
        health = get(url, "/health")
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if health.get("status") != "ok":
            raise RuntimeError(f"unexpected health response: {health}")
    except Exception as e:
        # Top-level boundary: any failure here means the demo cannot run.
        print(red(f"❌ Server not reachable at {url}: {e}"))
        sys.exit(1)
    print(green("✅ Environment is live."))


def _demo_baseline(url, task):
    """Run and report the rule-based baseline episode; return its score."""
    banner("PART 1 — Heuristic Baseline (no AI)")
    print("  A simple rule-based policy: charge storage at low price,")
    print("  shed load when grid is stressed. No language model involved.")
    print(f"\n  Running episode on Task {task} (hard difficulty)...\n")

    # perf_counter is monotonic, so the elapsed time cannot go negative
    # if the wall clock is adjusted mid-run.
    t0 = time.perf_counter()
    h_mean, h_score, h_fault = run_episode(url, task_id=task, seed=42)
    h_time = time.perf_counter() - t0

    print(f"  Mean step reward : {h_mean:.4f}")
    print(f"  Episode score    : {bold(f'{h_score:.4f}')}")
    print(f"  Fault occurred   : {'Yes — heuristic responded' if h_fault else 'No'}")
    print(f"  Time             : {h_time:.1f}s")
    return h_score


def _demo_world_model(url, task):
    """Simulate two candidate actions via /simulate and report their predicted rewards."""
    banner("PART 2 — Theme 3: World Modeling (/simulate)")
    print("  Before committing an action, the agent simulates two options.")
    post(url, "/reset", {"task_id": task, "seed": 77})

    act_greedy = {"hvac_power_level": 1.0, "thermal_charge_rate": 0.0,
                  "batch_job_slot": 0, "load_shed_fraction": 0.0, "building_id": 0}
    act_smart = {"hvac_power_level": 0.3, "thermal_charge_rate": -0.5,
                 "batch_job_slot": 2, "load_shed_fraction": 0.4, "building_id": 0}

    # Record the step before simulating so the "unchanged" claim is verified
    # rather than merely asserted.
    step_before = get(url, "/state").get("step", "?")
    r_g = _first_reward(post(url, "/simulate", [act_greedy]))
    r_s = _first_reward(post(url, "/simulate", [act_smart]))
    step_now = get(url, "/state").get("step", "?")

    if step_now == step_before:
        verdict = green("(unchanged — simulation doesn't advance state)")
    else:
        verdict = red(f"(changed from {step_before} — simulation advanced state)")

    print(f"\n  Greedy action (max HVAC) → predicted reward: {red(_format_reward(r_g))}")
    print(f"  Smart action  (shed+store) → predicted reward: {green(_format_reward(r_s))}")
    print(f"  Episode step after both simulates: {step_now}  " + verdict)
    print(f"\n  Agent selects the smart action. {green('✅')}")


def _demo_multi_agent(url):
    """Run a 3-building episode for up to 40 steps and report the first fault seen."""
    banner("PART 3 — Theme 1: Multi-Agent + Wild Card Fault")
    print("  3-building federation. Coordinator sends price signals.")
    print("  Hard mode = at least 1 fault guaranteed.\n")

    post(url, "/reset", {"task_id": 3, "num_buildings": 3, "seed": 55, "difficulty": "hard"})
    feeder = get(url, "/feeder")
    total = feeder.get("total_demand_kw", 0)
    limit = feeder.get("feeder_limit_kw", 360)
    print(f"  Feeder: {total:.1f} / {limit:.1f} kW  "
          + (red("OVERLOAD") if feeder.get("feeder_overload") else green("OK")))

    post(url, "/coordinate", {"price_multipliers": [1.5, 1.0, 0.7]})
    print("  Coordinator set multipliers: B0=1.5×  B1=1.0×  B2=0.7×")

    fault_step = None
    shed = 0.3
    for s in range(40):
        resp = post(url, "/step", [
            {"hvac_power_level": 0.4, "thermal_charge_rate": -0.3,
             "batch_job_slot": 2, "load_shed_fraction": shed, "building_id": i}
            for i in range(3)
        ])
        results = _step_results(resp)
        if not results:
            continue
        faults = (results[0].get("observation") or {}).get("active_faults") or []
        if faults and fault_step is None:
            fault_step = s + 1
            # Actually apply the response the message below announces;
            # previously the action stayed at 0.3 for the whole run.
            shed = 0.45
            # str() keeps the preview working if a fault entry is not a string.
            print(f"\n  🚨 FAULT at step {fault_step}: {str(faults[0])[:70]}")
            print("     Agent sees alarm → increases load_shed_fraction to 0.45")
        if results[0].get("done"):
            break

    if fault_step is not None:
        print(green(f"\n  ✅ Fault detected and handled at step {fault_step}."))
    else:
        print(yellow("  ⚠️  No fault in 40 steps — try a longer run."))


def _demo_instruction_following(url):
    """Reset task 4 and print the instruction card from the reset reply, if any."""
    banner("PART 4 — Theme 2: Long-Horizon Instruction Following")
    print("  Task 4 issues a natural language objective at reset.")
    print("  Agent must plan ALL 96 steps to satisfy it.\n")

    reset4 = post(url, "/reset", {"task_id": 4, "seed": 1234})
    # The card is looked up at the top level first, then on the first observation.
    card = reset4.get("instruction_card") or \
           (reset4.get("observations") or [{}])[0].get("instruction_card")

    if card:
        print(f"  {cyan('Instruction:')} {card.get('text')}")
        print(f"  Targets  : {card.get('targets')}")
        print(f"  Weights  : {card.get('weights')}")
        print(green("\n  ✅ Task 4 instruction card received. Agent plans for the full episode."))
    else:
        print(yellow("  ⚠️  No instruction card. Verify Item 1.1 fix is deployed."))


def _print_summary(url, h_score):
    """Print the results table and follow-up commands."""
    banner("RESULTS SUMMARY")
    print(f"  {'Policy':<28} {'Score':>8}  {'Notes'}")
    print(f"  {'─'*28} {'─'*8}  {'─'*20}")
    print(f"  {'Heuristic baseline':<28} {h_score:>8.4f}  rule-based, no LLM")
    print(f"  {'Zero-shot LLM':<28} {'(run with LLM key)':>8}  see inference.py")
    print(f"  {'GRPO fine-tuned LLM':<28} {'(see Colab)':>8}  train_unsloth.py")
    print()
    print(f"  {cyan('Run the full training demo:')}")
    print("    python inference.py --task 3 --fast-mode --episodes 3")
    print("    python inference.py --coordinator --use-planning --task 4 --episodes 1")
    print(f"    python scripts/full_demo.py --url {url}")
    print(f"\n  Dashboard: {url}/dashboard")
    print("  Notebook : scripts/gridmind_grpo_colab.ipynb (upload to Colab)\n")


def main():
    """Parse CLI arguments and run the four-part judge demo against the server.

    Flags: ``--url`` (server base URL, default ``http://localhost:7860``),
    ``--task`` (task id for parts 1 and 2, default 3) and ``--fast``.
    Exits with status 1 when the health check fails; errors from later
    requests propagate as exceptions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--url",  default="http://localhost:7860")
    # NOTE: --fast is parsed for CLI compatibility but is not read anywhere in
    # this function; no part of the demo below calls an LLM.
    parser.add_argument("--fast", action="store_true", help="Heuristic only, skip LLM")
    parser.add_argument("--task", type=int, default=3)
    args = parser.parse_args()
    # Endpoint paths start with "/", so drop any trailing slash from the base.
    url = args.url.rstrip("/")

    print(f"\n{bold('GridMind-RL — Judge Demo')}")
    print(f"  Environment : {url}")
    print(f"  Task        : {args.task}")
    print("  This demo runs ~3 minutes and shows before/after AI training.\n")

    _check_health(url)
    h_score = _demo_baseline(url, args.task)
    _demo_world_model(url, args.task)
    _demo_multi_agent(url)
    _demo_instruction_following(url)
    _print_summary(url, h_score)

# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()