File size: 16,381 Bytes
1875b13
 
 
 
 
 
 
 
 
6d74982
 
1875b13
 
6d74982
 
1875b13
 
 
 
 
 
 
 
 
 
 
 
 
 
2e0c292
1875b13
6d74982
 
1875b13
6d74982
 
 
1875b13
 
 
 
 
 
 
6d74982
1875b13
6d74982
1875b13
 
 
 
 
 
 
 
 
 
 
6d74982
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1875b13
 
6d74982
1875b13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d74982
1875b13
 
 
 
 
 
 
 
 
2e0c292
1875b13
 
 
2e0c292
 
1875b13
 
 
 
 
 
 
 
 
 
 
 
 
 
6d74982
1875b13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d74982
1875b13
 
6d74982
 
 
 
 
1875b13
2e0c292
 
 
 
 
 
1875b13
 
 
 
 
 
 
 
 
 
 
 
 
 
6d74982
1875b13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d74982
1875b13
 
 
 
 
 
 
 
 
 
 
6d74982
 
 
 
 
 
 
1875b13
 
 
 
6d74982
 
 
 
 
 
 
 
 
 
 
 
 
1875b13
 
 
6d74982
 
1875b13
 
 
6d74982
 
 
 
 
1875b13
4c1963b
6d74982
 
1875b13
6d74982
 
 
 
 
 
 
 
 
2787b1e
6d74982
2787b1e
 
6d74982
 
 
1875b13
6d74982
1875b13
6d74982
 
 
 
 
 
 
 
 
 
 
1875b13
 
 
6d74982
 
1875b13
 
 
 
 
 
 
 
 
 
 
 
 
 
6d74982
 
1875b13
 
 
 
 
6d74982
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1875b13
 
 
 
 
 
 
 
6d74982
1875b13
 
 
 
 
 
 
6d74982
1875b13
 
 
6d74982
1875b13
 
 
6d74982
1875b13
 
6d74982
 
1875b13
 
 
6d74982
 
 
 
 
 
 
 
 
 
 
1875b13
6d74982
 
 
 
1875b13
 
 
 
 
 
 
 
 
 
6d74982
1875b13
6d74982
1875b13
 
 
 
 
 
 
 
 
 
 
 
 
6d74982
 
 
1875b13
 
 
 
6d74982
1875b13
6d74982
1875b13
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
"""
GridMind-RL Baseline Inference Script
--------------------------------------
Runs an LLM agent against all 3 tasks for N episodes each.
Uses OpenAI-compatible API via API_BASE_URL / MODEL_NAME / HF_TOKEN environment variables.

Usage:
    export MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct
    export HF_TOKEN=hf_xxxx
    python inference.py
    # or: python python/inference.py [--episodes 1] [--llm-every 4] [--fast-mode]
"""

from __future__ import annotations

import argparse
import json
import os
import sys
import time
from typing import Any

import requests
from openai import OpenAI

# ── Constants ──────────────────────────────────────────────────────────────

ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
HF_TOKEN = os.getenv("HF_TOKEN", "")
DEFAULT_EPISODES = 1
DEFAULT_SEED_BASE = 1000
MAX_RETRIES = 3
# 96 steps × 15 min = 24 h (must match env.EpisodeSteps)
EPISODE_STEPS = 96
LAST_STEP_INDEX = EPISODE_STEPS - 1

SYSPROMPT = """You are GridMind, an expert industrial energy management controller.
You control a building's HVAC, thermal storage, batch job scheduling, and load shedding.
Your goal is to minimize electricity costs while maintaining comfort and meeting grid demand-response signals.
Always respond with a single valid JSON object matching the action schema. No explanation needed."""

TASK_DESCRIPTIONS = {
    1: "Task 1 (Easy - Cost Minimization): Minimize total energy cost over 24 hours. No temperature or batch constraints. Use cheap off-peak periods and thermal storage.",
    2: "Task 2 (Medium - Temperature Management): Minimize cost AND keep indoor temperature within 19-23°C at all times. Balance comfort vs cost.",
    3: "Task 3 (Hard - Full Demand Response): Minimize cost, maintain temperature, respond to grid stress (shed when grid_stress_signal > 0.7), schedule batch jobs, minimize carbon.",
}

ACTION_SCHEMA_STR = """{
  "hvac_power_level": <float 0.0-1.0>,
  "thermal_charge_rate": <float -1.0 to 1.0>,
  "batch_job_slot": <int 0-4>,
  "load_shed_fraction": <float 0.0-0.5>,
  "building_id": 0
}"""


def extract_json_object(text: str) -> dict[str, Any] | None:
    """Parse first balanced {...} JSON object from text (handles nested braces)."""
    start = text.find("{")
    if start < 0:
        return None
    depth = 0
    for i in range(start, len(text)):
        c = text[i]
        if c == "{":
            depth += 1
        elif c == "}":
            depth -= 1
            if depth == 0:
                try:
                    return json.loads(text[start : i + 1])
                except json.JSONDecodeError:
                    return None
    return None


# ── Environment client ───────────────────────────────────────────────────────


class GridMindEnvClient:
    """Simple HTTP client for the GridMind-RL Go environment server."""

    def __init__(self, base_url: str = ENV_URL, timeout: int = 30):
        self.base = base_url.rstrip("/")
        self.timeout = timeout

    def health(self) -> bool:
        try:
            r = requests.get(f"{self.base}/health", timeout=5)
            return r.status_code == 200
        except Exception:
            return False

    def reset(self, task_id: int = 1, seed: int = 42, num_buildings: int = 1) -> dict:
        payload = {"task_id": task_id, "seed": seed, "num_buildings": num_buildings}
        r = requests.post(f"{self.base}/reset", json=payload, timeout=self.timeout)
        r.raise_for_status()
        return r.json()

    def step(self, action: dict) -> dict:
        r = requests.post(f"{self.base}/step", json=action, timeout=self.timeout)
        r.raise_for_status()
        return r.json()

    def grade(self) -> dict:
        r = requests.get(f"{self.base}/grade", timeout=self.timeout)
        r.raise_for_status()
        return r.json()

    def state(self) -> dict:
        r = requests.get(f"{self.base}/state", timeout=self.timeout)
        r.raise_for_status()
        return r.json()


# ── LLM agent ───────────────────────────────────────────────────────────────


class LLMAgent:
    """OpenAI-compatible LLM agent that chooses actions given observations."""

    def __init__(self):
        self.client = OpenAI(
            base_url=API_BASE_URL,
            api_key=HF_TOKEN if HF_TOKEN else "none",
        )
        self.model = MODEL_NAME
        self.fallback_mode = False

    def choose_action(self, obs: dict, task_id: int) -> dict:
        """Prompt the LLM with current observation, return parsed action dict."""
        if self.fallback_mode:
            return self._heuristic_action(obs)
        task_desc = TASK_DESCRIPTIONS.get(task_id, TASK_DESCRIPTIONS[1])

        prompt = f"""{task_desc}

Current observation:
- Indoor temperature: {obs.get('indoor_temperature', 21):.1f}°C (target: 21°C, bounds: 19-23°C)
- Thermal storage level: {obs.get('thermal_storage_level', 0.5):.2f} (0=empty, 1=full)
- Process demand: {obs.get('process_demand', 15):.1f} kW
- Current electricity price: ${obs.get('current_price', 0.10):.4f}/kWh
- Grid stress signal: {obs.get('grid_stress_signal', 0):.3f} (>0.7 = critical, shed load!)
- Carbon intensity: {obs.get('carbon_intensity', 300):.0f} gCO2/kWh
- Hour of day: {obs.get('hour_of_day', 12)} (0=midnight, peak prices 8-12 and 17-21)
- Pending batch job deadlines: {obs.get('batch_queue', [])}
- Cumulative cost so far: ${obs.get('cumulative_cost', 0):.4f}
- Episode step: {obs.get('step', 0)}/{LAST_STEP_INDEX}

Strategy hints:
- Charge thermal storage when price < $0.08/kWh, discharge when price > $0.15/kWh
- Set HVAC low during peak prices (0.3-0.4) and use storage for temperature control
- Shed 30-50% load if grid_stress_signal > 0.7
- Schedule batch jobs early if deadline is close (slot 0 or 1)

Respond with ONLY a JSON action:
{ACTION_SCHEMA_STR}"""

        for attempt in range(MAX_RETRIES):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": SYSPROMPT},
                        {"role": "user", "content": prompt},
                    ],
                    max_tokens=128,
                    temperature=0.0,
                )
                content = completion.choices[0].message.content.strip()
                parsed = extract_json_object(content)
                if parsed is not None:
                    return self._clamp_action(parsed)
                action = json.loads(content)
                return self._clamp_action(action)
            except Exception as e:
                err_str = str(e)
                print(f"  [LLM attempt {attempt+1}/{MAX_RETRIES}] error: {err_str}")
                if "402" in err_str or "depleted" in err_str:
                    print("  [WARN] Hugging Face free credits depleted! Switching to local heuristic agent for the rest of the simulation.")
                    self.fallback_mode = True
                    return self._heuristic_action(obs)
                time.sleep(1)

        return self._heuristic_action(obs)

    def _clamp_action(self, action: dict) -> dict:
        return {
            "hvac_power_level": max(0.0, min(1.0, float(action.get("hvac_power_level", 0.5)))),
            "thermal_charge_rate": max(-1.0, min(1.0, float(action.get("thermal_charge_rate", 0.0)))),
            "batch_job_slot": max(0, min(4, int(action.get("batch_job_slot", 0)))),
            "load_shed_fraction": max(0.0, min(0.5, float(action.get("load_shed_fraction", 0.0)))),
            "building_id": int(action.get("building_id", 0)),
        }

    def _heuristic_action(self, obs: dict) -> dict:
        """Rule-based policy (deterministic given obs)."""
        price = obs.get("current_price", 0.10)
        stress = obs.get("grid_stress_signal", 0.0)
        temp = obs.get("indoor_temperature", 21.0)
        storage = obs.get("thermal_storage_level", 0.5)
        queue = obs.get("batch_queue", [])

        hvac = 0.7 if price < 0.08 else (0.3 if price > 0.15 else 0.5)
        if temp > 23.0:
            hvac = max(hvac, 0.8)
        elif temp < 19.0:
            hvac = min(hvac, 0.2)

        charge = 0.0
        if price < 0.07 and storage < 0.8:
            charge = 0.5
        elif price > 0.15 and storage > 0.3:
            charge = -0.5

        shed = 0.0
        if stress > 0.7:
            shed = 0.4
        elif stress > 0.5:
            shed = 0.2

        slot = 2
        if queue and min(queue) < 8:
            slot = 0

        return {
            "hvac_power_level": hvac,
            "thermal_charge_rate": charge,
            "batch_job_slot": slot,
            "load_shed_fraction": shed,
            "building_id": 0,
        }

    def _default_action(self) -> dict:
        return {
            "hvac_power_level": 0.5,
            "thermal_charge_rate": 0.0,
            "batch_job_slot": 0,
            "load_shed_fraction": 0.0,
            "building_id": 0,
        }


# ── Episode runner ───────────────────────────────────────────────────────────


def run_episode(
    env_client: GridMindEnvClient,
    agent: LLMAgent,
    task_id: int,
    seed: int,
    *,
    fast_mode: bool,
    llm_every: int,
    max_steps: int | None,
    verbose: bool = False,
) -> dict[str, Any]:
    """Run a single episode and return grade + metadata. Prints [START], [STEPn], [END]."""
    reset_resp = env_client.reset(task_id=task_id, seed=seed)
    obs = reset_resp["observations"][0]

    print("[START]", flush=True)

    total_reward = 0.0
    total_steps = 0
    start_time = time.time()
    step_resp: dict[str, Any] = {}
    step_limit = EPISODE_STEPS if max_steps is None else min(max_steps, EPISODE_STEPS)

    llm_reuse_remaining = 0
    cached_action = agent._default_action()

    while not step_resp.get("done", False):
        if total_steps >= step_limit:
            break

        if fast_mode:
            action = agent._heuristic_action(obs)
        else:
            if llm_reuse_remaining <= 0:
                cached_action = agent.choose_action(obs, task_id)
                llm_reuse_remaining = max(1, llm_every)
            action = cached_action

        step_resp = env_client.step(action)
        if step_resp is None or "observation" not in step_resp:
            print(f"  [WARN] step {total_steps}: invalid step response", flush=True)
            break

        if not fast_mode:
            llm_reuse_remaining -= 1

        obs = step_resp["observation"]
        total_reward += float(step_resp["reward"])
        total_steps += 1
        print(f"[STEP{total_steps}]", flush=True)

        if verbose and total_steps % 16 == 0:
            print(
                f"    step={total_steps:02d} price=${obs['current_price']:.3f} "
                f"temp={obs['indoor_temperature']:.1f}°C "
                f"stress={obs['grid_stress_signal']:.2f} "
                f"cost=${obs['cumulative_cost']:.2f} "
                f"reward={step_resp['reward']:.3f}",
                flush=True,
            )

    elapsed = time.time() - start_time
    grade = env_client.grade()
    print("[END]", flush=True)

    return {
        "task_id": task_id,
        "seed": seed,
        "total_reward": total_reward,
        "total_steps": total_steps,
        "elapsed_sec": elapsed,
        "score": grade.get("score", 0.0),
        "sub_scores": grade.get("sub_scores", {}),
        "exploit_detected": grade.get("exploit_detected", False),
    }


# ── Main ─────────────────────────────────────────────────────────────────────


def main() -> None:
    parser = argparse.ArgumentParser(description="GridMind-RL baseline inference")
    parser.add_argument("--episodes", type=int, default=DEFAULT_EPISODES)
    parser.add_argument("--env-url", type=str, default=ENV_URL)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--output", type=str, default="baseline_scores.json")
    parser.add_argument(
        "--fast-mode",
        action="store_true",
        help="Heuristic policy only (no LLM calls; fastest, fully reproducible).",
    )
    parser.add_argument(
        "--llm-every",
        type=int,
        default=4,
        metavar="N",
        help="Reuse the same LLM action for N consecutive steps (default: 4).",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=None,
        metavar="N",
        help="Stop after N steps (default: full episode). Grade uses partial episode.",
    )
    args = parser.parse_args()

    print("=" * 60)
    print("GridMind-RL Baseline Inference")
    print(f"  Model: {MODEL_NAME}")
    print(f"  API:   {API_BASE_URL}")
    print(f"  Env:   {args.env_url}")
    print(f"  Episodes per task: {args.episodes}")
    print(f"  Fast mode: {args.fast_mode} | LLM every: {args.llm_every} steps")
    print("=" * 60)

    env_client = GridMindEnvClient(base_url=args.env_url)

    print("\nWaiting for environment server...")
    for attempt in range(30):
        if env_client.health():
            print("  [OK] Environment server is healthy")
            break
        time.sleep(2)
        if attempt == 29:
            print("  [FAIL] Environment server not reachable. Exiting.")
            sys.exit(1)

    agent = LLMAgent()
    all_results: list[dict[str, Any]] = []

    for task_id in [1, 2, 3]:
        print(f"\n-- Task {task_id}: {TASK_DESCRIPTIONS[task_id][:60]}...")
        task_scores: list[float] = []
        for ep in range(args.episodes):
            seed = DEFAULT_SEED_BASE + task_id * 100 + ep
            print(f"  Episode {ep+1}/{args.episodes} (seed={seed})")
            result = run_episode(
                env_client,
                agent,
                task_id=task_id,
                seed=seed,
                fast_mode=args.fast_mode,
                llm_every=args.llm_every,
                max_steps=args.max_steps,
                verbose=args.verbose,
            )
            task_scores.append(float(result["score"]))
            all_results.append(result)
            print(
                f"    → score={result['score']:.4f} | reward={result['total_reward']:.3f} | "
                f"{result['elapsed_sec']:.1f}s | steps={result['total_steps']}"
            )

        avg_score = sum(task_scores) / len(task_scores)
        print(f"  Task {task_id} average score: {avg_score:.4f}")

    print("\n" + "=" * 60)
    print("BASELINE SCORES SUMMARY")
    print("=" * 60)
    print(f"{'Task':<10} {'Model':<30} {'Score':<10} {'Episodes':<10}")
    print("-" * 60)

    task_avgs: dict[int, float] = {}
    for task_id in [1, 2, 3]:
        scores = [float(r["score"]) for r in all_results if r["task_id"] == task_id]
        avg = sum(scores) / len(scores) if scores else 0.0
        task_avgs[task_id] = avg
        print(f"Task {task_id:<6} {MODEL_NAME:<30} {avg:<10.4f} {len(scores)}")

    print("-" * 60)
    overall = sum(task_avgs.values()) / len(task_avgs)
    print(f"{'Overall':<10} {'':<30} {overall:<10.4f}")

    output = {
        "model": MODEL_NAME,
        "api_base": API_BASE_URL,
        "episodes_per_task": args.episodes,
        "seed_base": DEFAULT_SEED_BASE,
        "fast_mode": args.fast_mode,
        "llm_every": args.llm_every,
        "max_steps": args.max_steps,
        "task_averages": {str(k): v for k, v in task_avgs.items()},
        "overall_average": overall,
        "all_results": all_results,
    }
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\n[OK] Results saved to {args.output}")


if __name__ == "__main__":
    main()