adityss commited on
Commit
88da572
·
1 Parent(s): e531486

Add coordinator endpoint tests and project readiness verification script

Browse files

- Created `test_coordinator.py` to test `/coordinator/reset` and `/coordinator/step` endpoints, including multi-step episode functionality.
- Added `verify_readiness.py` to check essential files, directories, and key features for project readiness before submission.

baseline_scores.json CHANGED
@@ -1,23 +1,23 @@
1
  {
2
- "model": "<your-active-model>",
3
- "api_base": "<your-active-endpoint>",
4
  "episodes_per_task": 1,
5
  "seed_base": 1000,
6
  "fast_mode": true,
7
  "llm_every": 8,
8
  "max_steps": null,
9
  "task_averages": {
10
- "3": 0.7278
11
  },
12
- "overall_average": 0.7278,
13
  "all_results": [
14
  {
15
- "task_id": 3,
16
- "seed": 1300,
17
- "total_reward": 248.19888206740697,
18
  "total_steps": 96,
19
- "elapsed_sec": 1.187589406967163,
20
- "score": 0.7278,
21
  "sub_scores": {},
22
  "exploit_detected": false
23
  }
 
1
  {
2
+ "model": "Qwen/Qwen2.5-7B-Instruct",
3
+ "api_base": "https://api-inference.huggingface.co/v1",
4
  "episodes_per_task": 1,
5
  "seed_base": 1000,
6
  "fast_mode": true,
7
  "llm_every": 8,
8
  "max_steps": null,
9
  "task_averages": {
10
+ "1": 0.5482
11
  },
12
+ "overall_average": 0.5482,
13
  "all_results": [
14
  {
15
+ "task_id": 1,
16
+ "seed": 1100,
17
+ "total_reward": 249.22208122816207,
18
  "total_steps": 96,
19
+ "elapsed_sec": 1.4036986827850342,
20
+ "score": 0.5482,
21
  "sub_scores": {},
22
  "exploit_detected": false
23
  }
inference.py CHANGED
@@ -163,11 +163,15 @@ def get_llm_client() -> OpenAI:
163
 
164
  # ── LLM Agent ────────────────────────────────────────────────────────────────
165
  class LLMAgent:
166
- def __init__(self):
167
- self.client = get_llm_client()
168
  self.model = MODEL_NAME
169
- self.fallback_mode = False
170
  self.instruction_card: Optional[dict] = None # set for task 4 episodes
 
 
 
 
171
 
172
  def set_instruction_card(self, card: Optional[dict]) -> None:
173
  """Store the instruction card received from reset for task 4 episodes."""
@@ -175,7 +179,7 @@ class LLMAgent:
175
 
176
  def choose_action(self, obs: dict, task_id: int) -> dict:
177
  """Prompt the LLM with current observation, return parsed action dict."""
178
- if self.fallback_mode:
179
  return self._heuristic_action(obs)
180
 
181
  task_desc = TASK_DESCRIPTIONS.get(task_id, TASK_DESCRIPTIONS[1])
@@ -224,6 +228,10 @@ Strategy hints:
224
  Respond with ONLY a JSON action:
225
  {ACTION_SCHEMA}"""
226
 
 
 
 
 
227
  for attempt in range(MAX_RETRIES):
228
  try:
229
  completion = self.client.chat.completions.create(
@@ -379,6 +387,16 @@ class GridMindEnvClient:
379
  print(f"[ERROR] Failed to step environment: {e}", file=sys.stderr)
380
  return None
381
 
 
 
 
 
 
 
 
 
 
 
382
  def simulate(self, actions: list[dict]) -> Optional[dict]:
383
  """Predict the next state using the world modeling API without advancing the real environment."""
384
  try:
@@ -476,93 +494,212 @@ def run_episode(
476
  if total_steps >= step_limit:
477
  break
478
 
479
- if fast_mode:
480
- action = agent._heuristic_action(obs)
481
- else:
482
- if llm_reuse_remaining <= 0:
483
- cached_action = agent.choose_action(obs, task_id)
484
- llm_reuse_remaining = max(1, llm_every)
485
- action = cached_action
486
-
487
- # C5: World Modeling - Use /simulate when efficiency is low or faults active
488
- hvac_eff = obs.get("hvac_efficiency", 1.0)
489
- active_faults_list = obs.get("active_faults", [])
490
- use_simulation = not fast_mode and (use_planning or hvac_eff < 0.7 or len(active_faults_list) > 0)
491
-
492
- sim_result = None
493
- sim_reward = None
494
- if use_simulation:
495
- try:
496
- sim_result = env_client.simulate([action])
497
- if sim_result and "results" in sim_result and len(sim_result["results"]) > 0:
498
- sim_reward = float(sim_result["results"][0]["reward"])
499
- print(f"🔮 SIMULATE → predicted_reward={sim_reward:.4f} | committed", file=sys.stderr)
500
- except Exception as e:
501
- print(f"🔮 SIMULATE → failed ({e}), proceeding without", file=sys.stderr)
502
-
503
- # Check if simulation predicts poor reward vs running average
504
- if sim_reward is not None and running_avg != 0.0 and sim_reward < running_avg - 0.3:
505
- # Ask LLM for alternative action with simulation warning
506
- print(f"⚠️ SIMULATION RESULT: proposed action yields reward {sim_reward:.3f} "
507
- f"which is below your running average {running_avg:.3f}. "
508
- f"Consider reducing HVAC load or increasing load shed fraction.", file=sys.stderr)
509
- # Get a revised action from the LLM
510
- revised_action = agent.choose_action(obs, task_id)
511
- action = revised_action
512
-
513
- step_resp = env_client.step(action)
514
- if step_resp is None or not isinstance(step_resp, dict) or "observation" not in step_resp:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
  log_step(
516
- step=total_steps + 1,
517
- action="null",
518
- reward=0.0,
519
- done=True,
520
- error="invalid step response from environment",
521
  )
522
- break
523
-
524
- if not fast_mode:
525
- llm_reuse_remaining -= 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
 
527
- obs = step_resp["observation"]
528
- raw_reward = float(step_resp["reward"])
529
- total_reward += raw_reward
530
- raw_rewards.append(raw_reward)
531
 
532
- # Update running average for world model comparison
533
- if total_steps > 0:
534
- running_avg = running_avg * 0.9 + raw_reward * 0.1
 
535
 
536
- if raw_reward < reward_min:
537
- reward_min = raw_reward
538
- if raw_reward > reward_max:
539
- reward_max = raw_reward
540
 
541
- total_steps += 1
542
- done = bool(step_resp.get("done", False))
 
 
543
 
544
- normalized_reward = normalize_reward(raw_reward, reward_min, reward_max)
 
545
 
546
- action_json = json.dumps(action, separators=(',', ':'))
547
- last_action_error = step_resp.get("last_action_error")
548
- log_step(
549
- step=total_steps,
550
- action=action_json,
551
- reward=normalized_reward,
552
- done=done,
553
- error=last_action_error,
554
- )
555
 
556
- if verbose and total_steps % 16 == 0:
557
- print(
558
- f" step={total_steps:02d} price=${obs['current_price']:.3f} "
559
- f"temp={obs['indoor_temperature']:.1f}°C "
560
- f"stress={obs['grid_stress_signal']:.2f} "
561
- f"cost=${obs['cumulative_cost']:.2f}",
562
- flush=True,
563
- file=sys.stderr,
564
  )
565
 
 
 
 
 
 
 
 
 
 
 
 
 
566
  success = bool(step_resp.get("done", False))
567
 
568
  except Exception as e:
@@ -734,7 +871,7 @@ def main() -> None:
734
  print("Environment server not reachable.", file=sys.stderr)
735
  sys.exit(1)
736
 
737
- agent = LLMAgent()
738
  all_results: list[dict[str, Any]] = []
739
 
740
  # Determine task list: use --task if specified, otherwise all
 
163
 
164
  # ── LLM Agent ────────────────────────────────────────────────────────────────
165
  class LLMAgent:
166
+ def __init__(self, fast_mode: bool = False):
167
+ self.client = None
168
  self.model = MODEL_NAME
169
+ self.fallback_mode = fast_mode # Start in fallback if fast mode
170
  self.instruction_card: Optional[dict] = None # set for task 4 episodes
171
+
172
+ # Only initialize LLM client if not in fast mode
173
+ if not fast_mode:
174
+ self.client = get_llm_client()
175
 
176
  def set_instruction_card(self, card: Optional[dict]) -> None:
177
  """Store the instruction card received from reset for task 4 episodes."""
 
179
 
180
  def choose_action(self, obs: dict, task_id: int) -> dict:
181
  """Prompt the LLM with current observation, return parsed action dict."""
182
+ if self.fallback_mode or self.client is None:
183
  return self._heuristic_action(obs)
184
 
185
  task_desc = TASK_DESCRIPTIONS.get(task_id, TASK_DESCRIPTIONS[1])
 
228
  Respond with ONLY a JSON action:
229
  {ACTION_SCHEMA}"""
230
 
231
+ # If no client available, use heuristic
232
+ if self.client is None:
233
+ return self._heuristic_action(obs)
234
+
235
  for attempt in range(MAX_RETRIES):
236
  try:
237
  completion = self.client.chat.completions.create(
 
387
  print(f"[ERROR] Failed to step environment: {e}", file=sys.stderr)
388
  return None
389
 
390
+ def coordinator_step(self, actions: list[dict]) -> Optional[dict]:
391
+ """Multi-agent step: send per-building actions to /coordinator/step."""
392
+ try:
393
+ r = requests.post(f"{self.base}/coordinator/step", json=actions, timeout=self.timeout)
394
+ r.raise_for_status()
395
+ return r.json()
396
+ except Exception as e:
397
+ print(f"[ERROR] Failed to coordinator step: {e}", file=sys.stderr)
398
+ return None
399
+
400
  def simulate(self, actions: list[dict]) -> Optional[dict]:
401
  """Predict the next state using the world modeling API without advancing the real environment."""
402
  try:
 
494
  if total_steps >= step_limit:
495
  break
496
 
497
+ if coordinator:
498
+ # ─────────────────────────────────────────────────────
499
+ # Multi-Agent Coordinator Mode (Theme 1)
500
+ # ──────────────────────────────────────���──────────────
501
+ building_actions = []
502
+ action_jsons = []
503
+
504
+ # Get LLM action for each building
505
+ for bid, building_obs in enumerate(obs_list):
506
+ if fast_mode:
507
+ action = agent._heuristic_action(building_obs)
508
+ else:
509
+ if llm_reuse_remaining <= 0:
510
+ action = agent.choose_action(building_obs, task_id)
511
+ llm_reuse_remaining = max(1, llm_every)
512
+ else:
513
+ action = cached_action
514
+
515
+ action["building_id"] = bid
516
+ building_actions.append(action)
517
+ action_jsons.append(json.dumps(action, separators=(',', ':')))
518
+
519
+ if not fast_mode:
520
+ llm_reuse_remaining -= 1
521
+
522
+ # Execute coordinator step with all building actions
523
+ coord_resp = env_client.coordinator_step(building_actions)
524
+ if coord_resp is None or not isinstance(coord_resp, (dict, list)):
525
+ log_step(
526
+ step=total_steps + 1,
527
+ action="null",
528
+ reward=0.0,
529
+ done=True,
530
+ error="invalid coordinator step response",
531
+ )
532
+ break
533
+
534
+ # Process responses from all buildings
535
+ # coord_resp can be either an array directly or a dict with "responses" key
536
+ if isinstance(coord_resp, list):
537
+ responses = coord_resp
538
+ done = False # Will be set from responses or episode state
539
+ else:
540
+ responses = coord_resp.get("responses", [])
541
+ done = bool(coord_resp.get("done", False))
542
+
543
+ obs_list = []
544
+ step_rewards = []
545
+
546
+ for i, resp in enumerate(responses):
547
+ if isinstance(resp, dict):
548
+ if "observation" in resp:
549
+ obs_list.append(resp["observation"])
550
+ reward = float(resp.get("reward", 0.0))
551
+ else:
552
+ reward = 0.0
553
+ step_rewards.append(reward)
554
+
555
+ if not obs_list:
556
+ log_step(
557
+ step=total_steps + 1,
558
+ action="null",
559
+ reward=0.0,
560
+ done=True,
561
+ error="no observations in coordinator response",
562
+ )
563
+ break
564
+
565
+ obs = obs_list[0] # Use primary building for logging
566
+
567
+ # Aggregate reward (mean of all buildings)
568
+ raw_reward = sum(step_rewards) / len(step_rewards) if step_rewards else 0.0
569
+ if isinstance(coord_resp, list) and len(responses) > 0:
570
+ done = bool(responses[-1].get("done", False)) if isinstance(responses[-1], dict) else False
571
+
572
+ # Log primary building action and aggregated reward
573
+ primary_action_json = action_jsons[0] if action_jsons else "null"
574
+ total_reward += raw_reward
575
+ raw_rewards.append(raw_reward)
576
+
577
+ # Update running average
578
+ if total_steps > 0:
579
+ running_avg = running_avg * 0.9 + raw_reward * 0.1
580
+
581
+ if raw_reward < reward_min:
582
+ reward_min = raw_reward
583
+ if raw_reward > reward_max:
584
+ reward_max = raw_reward
585
+
586
+ total_steps += 1
587
+ normalized_reward = normalize_reward(raw_reward, reward_min, reward_max)
588
+
589
  log_step(
590
+ step=total_steps,
591
+ action=primary_action_json,
592
+ reward=normalized_reward,
593
+ done=done,
594
+ error=None,
595
  )
596
+
597
+ if verbose and total_steps % 16 == 0:
598
+ temps = [o.get('indoor_temperature', 21) for o in obs_list]
599
+ costs = [o.get('cumulative_cost', 0) for o in obs_list]
600
+ print(
601
+ f" step={total_steps:02d} buildings={len(obs_list)} "
602
+ f"temps={[f'{t:.1f}' for t in temps]} "
603
+ f"costs=${sum(costs):.2f}",
604
+ flush=True,
605
+ file=sys.stderr,
606
+ )
607
+
608
+ step_resp = {"done": done}
609
+
610
+ else:
611
+ # ─────────────────────────────────────────────────────
612
+ # Single-Building Mode (default)
613
+ # ─────────────────────────────────────────────────────
614
+ if fast_mode:
615
+ action = agent._heuristic_action(obs)
616
+ else:
617
+ if llm_reuse_remaining <= 0:
618
+ cached_action = agent.choose_action(obs, task_id)
619
+ llm_reuse_remaining = max(1, llm_every)
620
+ action = cached_action
621
+
622
+ # C5: World Modeling - Use /simulate when efficiency is low or faults active
623
+ hvac_eff = obs.get("hvac_efficiency", 1.0)
624
+ active_faults_list = obs.get("active_faults", [])
625
+ use_simulation = not fast_mode and (use_planning or hvac_eff < 0.7 or len(active_faults_list) > 0)
626
+
627
+ sim_result = None
628
+ sim_reward = None
629
+ if use_simulation:
630
+ try:
631
+ sim_result = env_client.simulate([action])
632
+ if sim_result and "results" in sim_result and len(sim_result["results"]) > 0:
633
+ sim_reward = float(sim_result["results"][0]["reward"])
634
+ print(f"🔮 SIMULATE → predicted_reward={sim_reward:.4f} | committed", file=sys.stderr)
635
+ except Exception as e:
636
+ print(f"🔮 SIMULATE → failed ({e}), proceeding without", file=sys.stderr)
637
+
638
+ # Check if simulation predicts poor reward vs running average
639
+ if sim_reward is not None and running_avg != 0.0 and sim_reward < running_avg - 0.3:
640
+ # Ask LLM for alternative action with simulation warning
641
+ print(f"⚠️ SIMULATION RESULT: proposed action yields reward {sim_reward:.3f} "
642
+ f"which is below your running average {running_avg:.3f}. "
643
+ f"Consider reducing HVAC load or increasing load shed fraction.", file=sys.stderr)
644
+ # Get a revised action from the LLM
645
+ revised_action = agent.choose_action(obs, task_id)
646
+ action = revised_action
647
+
648
+ step_resp = env_client.step(action)
649
+ if step_resp is None or not isinstance(step_resp, dict) or "observation" not in step_resp:
650
+ log_step(
651
+ step=total_steps + 1,
652
+ action="null",
653
+ reward=0.0,
654
+ done=True,
655
+ error="invalid step response from environment",
656
+ )
657
+ break
658
 
659
+ if not fast_mode:
660
+ llm_reuse_remaining -= 1
 
 
661
 
662
+ obs = step_resp["observation"]
663
+ raw_reward = float(step_resp["reward"])
664
+ total_reward += raw_reward
665
+ raw_rewards.append(raw_reward)
666
 
667
+ # Update running average for world model comparison
668
+ if total_steps > 0:
669
+ running_avg = running_avg * 0.9 + raw_reward * 0.1
 
670
 
671
+ if raw_reward < reward_min:
672
+ reward_min = raw_reward
673
+ if raw_reward > reward_max:
674
+ reward_max = raw_reward
675
 
676
+ total_steps += 1
677
+ done = bool(step_resp.get("done", False))
678
 
679
+ normalized_reward = normalize_reward(raw_reward, reward_min, reward_max)
 
 
 
 
 
 
 
 
680
 
681
+ action_json = json.dumps(action, separators=(',', ':'))
682
+ last_action_error = step_resp.get("last_action_error")
683
+ log_step(
684
+ step=total_steps,
685
+ action=action_json,
686
+ reward=normalized_reward,
687
+ done=done,
688
+ error=last_action_error,
689
  )
690
 
691
+ if verbose and total_steps % 16 == 0:
692
+ print(
693
+ f" step={total_steps:02d} price=${obs['current_price']:.3f} "
694
+ f"temp={obs['indoor_temperature']:.1f}°C "
695
+ f"stress={obs['grid_stress_signal']:.2f} "
696
+ f"cost=${obs['cumulative_cost']:.2f}",
697
+ flush=True,
698
+ file=sys.stderr,
699
+ )
700
+
701
+ step_resp = {"done": done}
702
+
703
  success = bool(step_resp.get("done", False))
704
 
705
  except Exception as e:
 
871
  print("Environment server not reachable.", file=sys.stderr)
872
  sys.exit(1)
873
 
874
+ agent = LLMAgent(fast_mode=args.fast_mode)
875
  all_results: list[dict[str, Any]] = []
876
 
877
  # Determine task list: use --task if specified, otherwise all
main.go CHANGED
@@ -149,6 +149,8 @@ func (s *Server) routes() *http.ServeMux {
149
  mux.HandleFunc("/ping", s.handlePing)
150
  mux.HandleFunc("/reset", s.handleReset)
151
  mux.HandleFunc("/step", s.handleStep)
 
 
152
  mux.HandleFunc("/state", s.handleState)
153
  mux.HandleFunc("/replay", s.handleReplay)
154
  mux.HandleFunc("/grade", s.handleGrade)
@@ -312,6 +314,80 @@ func (s *Server) handleStep(w http.ResponseWriter, r *http.Request) {
312
  }
313
  }
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  // ── /state ───────────────────────────────────────────────────────────────────
316
 
317
  func (s *Server) handleState(w http.ResponseWriter, r *http.Request) {
@@ -511,15 +587,15 @@ func getClientIP(r *http.Request) string {
511
  // ── /ws (WebSocket) ───────────────────────────────────────────────────────────
512
 
513
  type WSMessage struct {
514
- Type string `json:"type"`
515
- Data json.RawMessage `json:"data,omitempty"`
516
- Seed *int64 `json:"seed,omitempty"`
517
- TaskID int `json:"task_id,omitempty"`
518
  }
519
 
520
  type WSResetMessage struct {
521
- Seed *int64 `json:"seed,omitempty"`
522
- TaskID int `json:"task_id,omitempty"`
523
  NumBuildings int `json:"num_buildings,omitempty"`
524
  }
525
 
@@ -634,13 +710,13 @@ func (s *Server) handleWSReset(conn *websocket.Conn, data json.RawMessage) {
634
  "thermal_storage_level": obs.ThermalStorageLevel,
635
  "process_demand": obs.ProcessDemand,
636
  "current_price": obs.CurrentPrice,
637
- "grid_stress_signal": obs.GridStressSignal,
638
- "carbon_intensity": obs.CarbonIntensity,
639
- "hour_of_day": obs.HourOfDay,
640
- "batch_queue": obs.BatchQueue,
641
- "cumulative_cost": obs.CumulativeCost,
642
- "step": obs.Step,
643
- "building_id": obs.BuildingID,
644
  },
645
  "reward": nil,
646
  "done": false,
@@ -699,13 +775,13 @@ func (s *Server) handleWSStep(conn *websocket.Conn, data json.RawMessage) {
699
  "thermal_storage_level": obs.Observation.ThermalStorageLevel,
700
  "process_demand": obs.Observation.ProcessDemand,
701
  "current_price": obs.Observation.CurrentPrice,
702
- "grid_stress_signal": obs.Observation.GridStressSignal,
703
- "carbon_intensity": obs.Observation.CarbonIntensity,
704
- "hour_of_day": obs.Observation.HourOfDay,
705
- "batch_queue": obs.Observation.BatchQueue,
706
- "cumulative_cost": obs.Observation.CumulativeCost,
707
- "step": obs.Observation.Step,
708
- "building_id": obs.Observation.BuildingID,
709
  },
710
  "reward": obs.Reward,
711
  "done": done,
@@ -735,8 +811,8 @@ func (s *Server) handleWSResetDirect(conn *websocket.Conn, seed *int64, taskID i
735
  }
736
 
737
  resp := s.envMgr.Reset(env.ResetRequest{
738
- Seed: seed,
739
- TaskID: taskID,
740
  NumBuildings: 1,
741
  })
742
 
@@ -747,13 +823,13 @@ func (s *Server) handleWSResetDirect(conn *websocket.Conn, seed *int64, taskID i
747
  "thermal_storage_level": obs.ThermalStorageLevel,
748
  "process_demand": obs.ProcessDemand,
749
  "current_price": obs.CurrentPrice,
750
- "grid_stress_signal": obs.GridStressSignal,
751
- "carbon_intensity": obs.CarbonIntensity,
752
- "hour_of_day": obs.HourOfDay,
753
- "batch_queue": obs.BatchQueue,
754
- "cumulative_cost": obs.CumulativeCost,
755
- "step": obs.Step,
756
- "building_id": obs.BuildingID,
757
  },
758
  "reward": nil,
759
  "done": false,
@@ -809,13 +885,13 @@ func (s *Server) handleWSStepDirect(conn *websocket.Conn, msgBytes []byte) {
809
  "thermal_storage_level": obs.Observation.ThermalStorageLevel,
810
  "process_demand": obs.Observation.ProcessDemand,
811
  "current_price": obs.Observation.CurrentPrice,
812
- "grid_stress_signal": obs.Observation.GridStressSignal,
813
- "carbon_intensity": obs.Observation.CarbonIntensity,
814
- "hour_of_day": obs.Observation.HourOfDay,
815
- "batch_queue": obs.Observation.BatchQueue,
816
- "cumulative_cost": obs.Observation.CumulativeCost,
817
- "step": obs.Observation.Step,
818
- "building_id": obs.Observation.BuildingID,
819
  },
820
  "reward": obs.Reward,
821
  "done": done,
 
149
  mux.HandleFunc("/ping", s.handlePing)
150
  mux.HandleFunc("/reset", s.handleReset)
151
  mux.HandleFunc("/step", s.handleStep)
152
+ mux.HandleFunc("/coordinator/reset", s.handleCoordinatorReset)
153
+ mux.HandleFunc("/coordinator/step", s.handleCoordinatorStep)
154
  mux.HandleFunc("/state", s.handleState)
155
  mux.HandleFunc("/replay", s.handleReplay)
156
  mux.HandleFunc("/grade", s.handleGrade)
 
314
  }
315
  }
316
 
317
+ // ── /coordinator/reset ──────────────────────────────────────────────────────
318
+
319
+ func (s *Server) handleCoordinatorReset(w http.ResponseWriter, r *http.Request) {
320
+ if r.Method != http.MethodPost {
321
+ http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
322
+ return
323
+ }
324
+ var req env.ResetRequest
325
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
326
+ // Allow empty body → defaults
327
+ req = env.ResetRequest{TaskID: 1, NumBuildings: 3}
328
+ }
329
+ if req.TaskID == 0 {
330
+ req.TaskID = 1
331
+ }
332
+ if req.NumBuildings == 0 {
333
+ req.NumBuildings = 3
334
+ }
335
+ resp := s.envMgr.Reset(req)
336
+ w.Header().Set("Content-Type", "application/json")
337
+ json.NewEncoder(w).Encode(resp)
338
+ }
339
+
340
+ // ── /coordinator/step ───────────────────────────────────────────────────────
341
+
342
+ func (s *Server) handleCoordinatorStep(w http.ResponseWriter, r *http.Request) {
343
+ if r.Method != http.MethodPost {
344
+ http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
345
+ return
346
+ }
347
+
348
+ start := time.Now()
349
+
350
+ // Accept array of actions (one per building)
351
+ var actions []env.ActionModel
352
+
353
+ body := make([]byte, 0, 512)
354
+ buf := make([]byte, 512)
355
+ for {
356
+ n, err := r.Body.Read(buf)
357
+ body = append(body, buf[:n]...)
358
+ if err != nil {
359
+ break
360
+ }
361
+ }
362
+
363
+ if err := json.Unmarshal(body, &actions); err != nil {
364
+ atomic.AddInt64(&metrics.errorCount, 1)
365
+ http.Error(w, "invalid action array: "+err.Error(), http.StatusBadRequest)
366
+ return
367
+ }
368
+
369
+ // If empty array provided, use defaults
370
+ if len(actions) == 0 {
371
+ actions = []env.ActionModel{{HVACPowerLevel: 0.5, BuildingID: 0}}
372
+ }
373
+
374
+ responses, _ := s.envMgr.Step(actions)
375
+
376
+ latency := float64(time.Since(start).Microseconds()) / 1000.0
377
+ for _, resp := range responses {
378
+ metrics.recordStep(latency, resp.Reward)
379
+ }
380
+ if len(actions) > 0 {
381
+ metrics.recordAction(actions[0].HVACPowerLevel)
382
+ }
383
+
384
+ w.Header().Set("Content-Type", "application/json")
385
+ w.Header().Set("Access-Control-Allow-Origin", "*")
386
+
387
+ // Always return array format for coordinator
388
+ json.NewEncoder(w).Encode(responses)
389
+ }
390
+
391
  // ── /state ───────────────────────────────────────────────────────────────────
392
 
393
  func (s *Server) handleState(w http.ResponseWriter, r *http.Request) {
 
587
  // ── /ws (WebSocket) ───────────────────────────────────────────────────────────
588
 
589
  type WSMessage struct {
590
+ Type string `json:"type"`
591
+ Data json.RawMessage `json:"data,omitempty"`
592
+ Seed *int64 `json:"seed,omitempty"`
593
+ TaskID int `json:"task_id,omitempty"`
594
  }
595
 
596
  type WSResetMessage struct {
597
+ Seed *int64 `json:"seed,omitempty"`
598
+ TaskID int `json:"task_id,omitempty"`
599
  NumBuildings int `json:"num_buildings,omitempty"`
600
  }
601
 
 
710
  "thermal_storage_level": obs.ThermalStorageLevel,
711
  "process_demand": obs.ProcessDemand,
712
  "current_price": obs.CurrentPrice,
713
+ "grid_stress_signal": obs.GridStressSignal,
714
+ "carbon_intensity": obs.CarbonIntensity,
715
+ "hour_of_day": obs.HourOfDay,
716
+ "batch_queue": obs.BatchQueue,
717
+ "cumulative_cost": obs.CumulativeCost,
718
+ "step": obs.Step,
719
+ "building_id": obs.BuildingID,
720
  },
721
  "reward": nil,
722
  "done": false,
 
775
  "thermal_storage_level": obs.Observation.ThermalStorageLevel,
776
  "process_demand": obs.Observation.ProcessDemand,
777
  "current_price": obs.Observation.CurrentPrice,
778
+ "grid_stress_signal": obs.Observation.GridStressSignal,
779
+ "carbon_intensity": obs.Observation.CarbonIntensity,
780
+ "hour_of_day": obs.Observation.HourOfDay,
781
+ "batch_queue": obs.Observation.BatchQueue,
782
+ "cumulative_cost": obs.Observation.CumulativeCost,
783
+ "step": obs.Observation.Step,
784
+ "building_id": obs.Observation.BuildingID,
785
  },
786
  "reward": obs.Reward,
787
  "done": done,
 
811
  }
812
 
813
  resp := s.envMgr.Reset(env.ResetRequest{
814
+ Seed: seed,
815
+ TaskID: taskID,
816
  NumBuildings: 1,
817
  })
818
 
 
823
  "thermal_storage_level": obs.ThermalStorageLevel,
824
  "process_demand": obs.ProcessDemand,
825
  "current_price": obs.CurrentPrice,
826
+ "grid_stress_signal": obs.GridStressSignal,
827
+ "carbon_intensity": obs.CarbonIntensity,
828
+ "hour_of_day": obs.HourOfDay,
829
+ "batch_queue": obs.BatchQueue,
830
+ "cumulative_cost": obs.CumulativeCost,
831
+ "step": obs.Step,
832
+ "building_id": obs.BuildingID,
833
  },
834
  "reward": nil,
835
  "done": false,
 
885
  "thermal_storage_level": obs.Observation.ThermalStorageLevel,
886
  "process_demand": obs.Observation.ProcessDemand,
887
  "current_price": obs.Observation.CurrentPrice,
888
+ "grid_stress_signal": obs.Observation.GridStressSignal,
889
+ "carbon_intensity": obs.Observation.CarbonIntensity,
890
+ "hour_of_day": obs.Observation.HourOfDay,
891
+ "batch_queue": obs.Observation.BatchQueue,
892
+ "cumulative_cost": obs.Observation.CumulativeCost,
893
+ "step": obs.Observation.Step,
894
+ "building_id": obs.Observation.BuildingID,
895
  },
896
  "reward": obs.Reward,
897
  "done": done,
scripts/gridmind_grpo_colab.ipynb CHANGED
@@ -2,521 +2,625 @@
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
 
5
  "metadata": {},
6
  "source": [
7
- "# GridMind-RL: GRPO Training with Unsloth + TRL\n",
8
  "\n",
9
- "Fine-tunes **Qwen2.5-1.5B-Instruct** (4-bit LoRA) to control industrial building HVAC,\n",
10
- "thermal storage, and batch scheduling via the live **GridMind-RL OpenEnv** environment.\n",
11
  "\n",
12
- "**Key fix:** This notebook uses episode-level rewards from the `/grade` endpoint \n",
13
- "not step-level rewards. This prevents mode collapse where the model\n",
14
- "finds one action and repeats it forever.\n",
 
 
 
 
15
  "\n",
16
  "| | |\n",
17
  "|---|---|\n",
18
  "| **Environment** | https://lo-kyu-gridmind.hf.space |\n",
19
  "| **Method** | GRPO (Group Relative Policy Optimization) |\n",
20
- "| **Framework** | Unsloth 4-bit LoRA + HF TRL |\n",
21
- "| **Model** | unsloth/Qwen2.5-1.5B-Instruct |\n",
22
- "| **Training** | 300 steps, T4 GPU (~40 min) |\n",
23
- "\n",
24
- "### What the agent learns:\n",
25
- "- Task 1: Charge storage off-peak, discharge at peak to minimize cost\n",
26
- "- Task 2: Balance temperature comfort vs HVAC energy spend\n",
27
- "- Task 3: Respond to grid stress (shed load), schedule batch jobs, minimize carbon"
28
  ]
29
  },
30
  {
31
  "cell_type": "code",
32
  "execution_count": null,
 
33
  "metadata": {},
34
  "outputs": [],
35
  "source": [
36
- "%%capture\n",
37
- "!pip install unsloth requests\n",
38
- "!pip install --no-deps bitsandbytes accelerate xformers peft trl triton\n",
39
- "!pip install --no-deps cut_cross_entropy unsloth_zoo\n",
40
- "!pip install \"datasets>=3.4.1,<4.0.0\" pandas matplotlib"
 
 
 
 
 
 
41
  ]
42
  },
43
  {
44
  "cell_type": "markdown",
 
45
  "metadata": {},
46
  "source": [
47
- "## Step 1 Verify the Live Environment"
48
  ]
49
  },
50
  {
51
  "cell_type": "code",
52
  "execution_count": null,
 
53
  "metadata": {},
54
  "outputs": [],
55
  "source": [
56
  "import requests\n",
 
 
57
  "\n",
58
  "ENV_URL = \"https://lo-kyu-gridmind.hf.space\"\n",
59
  "\n",
60
- "print(\"Environment health:\", requests.get(f\"{ENV_URL}/health\", timeout=10).json())\n",
61
- "print(\"\\nTasks available:\")\n",
62
- "for t in requests.get(f\"{ENV_URL}/tasks\", timeout=10).json():\n",
63
- " print(f\" Task {t['id']}: {t['name']} ({t['difficulty']})\")\n",
64
- "\n",
65
- "# Quick smoke test: reset + step + grade\n",
66
- "r = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": 1, \"seed\": 42}, timeout=30)\n",
67
- "obs = r.json()[\"observations\"][0]\n",
68
- "print(f\"\\nObservation keys: {list(obs.keys())}\")\n",
69
- "step_r = requests.post(f\"{ENV_URL}/step\", json=[{\n",
70
- " \"hvac_power_level\": 0.5, \"thermal_charge_rate\": 0.0,\n",
71
- " \"batch_job_slot\": 0, \"load_shed_fraction\": 0.0, \"building_id\": 0\n",
72
- "}], timeout=30)\n",
73
- "sr = step_r.json()\n",
74
- "print(f\"Step reward: {sr[0]['reward']:.3f}, done: {sr[0]['done']}\")\n",
75
- "grade_r = requests.get(f\"{ENV_URL}/grade\", timeout=30).json()\n",
76
- "print(f\"Episode score: {grade_r['score']:.3f}\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  ]
78
  },
79
  {
80
  "cell_type": "markdown",
 
81
  "metadata": {},
82
  "source": [
83
- "## Step 2 Load Unsloth Model"
84
  ]
85
  },
86
  {
87
  "cell_type": "code",
88
  "execution_count": null,
 
89
  "metadata": {},
90
  "outputs": [],
91
  "source": [
92
- "from unsloth import FastLanguageModel\n",
93
- "import torch\n",
94
  "\n",
95
- "max_seq_length = 512\n",
96
- "lora_rank = 16\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  "\n",
98
- "print(\"Loading model...\")\n",
99
- "model, tokenizer = FastLanguageModel.from_pretrained(\n",
100
- " model_name = \"unsloth/Qwen2.5-1.5B-Instruct\",\n",
101
- " max_seq_length = max_seq_length,\n",
102
- " load_in_4bit = True,\n",
103
- ")\n",
 
 
 
104
  "\n",
105
- "model = FastLanguageModel.get_peft_model(\n",
106
- " model,\n",
107
- " r = lora_rank,\n",
108
- " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
109
- " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
110
- " lora_alpha = lora_rank * 2,\n",
111
- " use_gradient_checkpointing = \"unsloth\",\n",
112
- " random_state = 42,\n",
113
- ")\n",
114
- "print(f\"Model loaded. Trainable params: {model.num_trainable_parameters():,}\")"
115
  ]
116
  },
117
  {
118
  "cell_type": "markdown",
 
119
  "metadata": {},
120
  "source": [
121
- "## Step 3 Build Diverse Training Prompts"
122
  ]
123
  },
124
  {
125
  "cell_type": "code",
126
  "execution_count": null,
 
127
  "metadata": {},
128
  "outputs": [],
129
  "source": [
130
- "import json, re, random\n",
131
- "\n",
132
- "random.seed(42)\n",
133
- "\n",
134
- "SCENARIOS = [\n",
135
- " # Off-peak: cheap electricity, agent should charge storage\n",
136
- " (\"off_peak\", \"price=$0.03/kWh\", \"grid_stress=0.0\", \"Charge thermal storage now — price is cheapest today\"),\n",
137
- " (\"off_peak\", \"price=$0.04/kWh\", \"grid_stress=0.0\", \"Off-peak period. Use this time to charge storage cheaply.\"),\n",
138
- " (\"off_peak\", \"price=$0.05/kWh\", \"grid_stress=0.0\", \"Low price window. Charge storage aggressively.\"),\n",
139
- " # Mid-peak: moderate price, balance HVAC and storage\n",
140
- " (\"mid_peak\", \"price=$0.12/kWh\", \"grid_stress=0.2\", \"Mid-peak pricing. Moderate HVAC, monitor grid.\"),\n",
141
- " (\"mid_peak\", \"price=$0.10/kWh\", \"grid_stress=0.1\", \"Moderate prices. Keep HVAC at setpoint.\"),\n",
142
- " # Peak: expensive, should discharge storage if available\n",
143
- " (\"peak\", \"price=$0.28/kWh\", \"grid_stress=0.4\", \"Peak pricing! Discharge storage, reduce HVAC if comfortable.\"),\n",
144
- " (\"peak\", \"price=$0.32/kWh\", \"grid_stress=0.5\", \"CRITICAL PEAK. Minimize consumption, shed non-critical load.\"),\n",
145
- " # Grid stress: respond to demand-response signal\n",
146
- " (\"grid_stress\", \"price=$0.20/kWh\", \"grid_stress=0.8\", \"GRID EMERGENCY. Shed load immediately (load_shed_fraction > 0.3).\"),\n",
147
- " (\"grid_stress\", \"price=$0.25/kWh\", \"grid_stress=0.9\", \"CRITICAL GRID STRESS. Maximize load shedding now.\"),\n",
148
- " (\"grid_stress\", \"price=$0.18/kWh\", \"grid_stress=0.7\", \"Demand response event. Respond by shedding load.\"),\n",
149
- " # Temperature: comfort vs cost tradeoff\n",
150
- " (\"temp_hot\", \"price=$0.15/kWh\", \"grid_stress=0.0\", \"Indoor temp=25.2C (too hot). Cool down but watch cost.\"),\n",
151
- " (\"temp_cold\", \"price=$0.15/kWh\", \"grid_stress=0.0\", \"Indoor temp=18.4C (too cold). Heat but watch cost.\"),\n",
152
- " # Storage full: must discharge before charging\n",
153
- " (\"storage_full\", \"price=$0.25/kWh\", \"grid_stress=0.3\", \"Storage is 95%% full. Peak pricing — discharge storage now!\"),\n",
154
- " (\"storage_empty\", \"price=$0.03/kWh\", \"grid_stress=0.0\", \"Storage is 5%% full. Off-peak — charge storage aggressively.\"),\n",
155
- " # Batch job: schedule production work\n",
156
- " (\"batch_job\", \"price=$0.20/kWh\", \"grid_stress=0.2\", \"Batch job deadline approaching. Schedule batch_job_slot=0 (do it now).\"),\n",
157
- " (\"batch_job\", \"price=$0.03/kWh\", \"grid_stress=0.0\", \"Batch job queued. Off-peak good time to run production.\"),\n",
158
- " # General strategy\n",
159
- " (\"general\", \"price=$0.08/kWh\", \"grid_stress=0.0\", \"Standard operation. Maintain comfort, minimize cost.\"),\n",
160
- " (\"general\", \"price=$0.15/kWh\", \"grid_stress=0.1\", \"Normal conditions. Optimize for cost within comfort bounds.\"),\n",
161
- "]\n",
162
- "\n",
163
- "SYSTEM_PROMPT = (\"You are GridMind, an expert industrial building energy controller.\\n\"\n",
164
- " \"You control HVAC (0-1), thermal storage charge/discharge (-1 to 1), batch job scheduling (0-4),\\n\"\n",
165
- " \"and load shedding (0-0.5). Output ONLY a JSON object with these exact fields:\\n\"\n",
166
- " '{\"hvac_power_level\": float, \"thermal_charge_rate\": float, \"batch_job_slot\": int, \"load_shed_fraction\": float, \"building_id\": 0}\\n\\n\"\n",
167
- " \"Strategy rules:\\n\"\n",
168
- " \"- Charge storage (positive thermal_charge_rate) when price < $0.08/kWh\\n\"\n",
169
- " \"- Discharge storage (negative thermal_charge_rate) when price > $0.15/kWh\\n\"\n",
170
- " \"- Shed load (load_shed_fraction > 0) when grid_stress_signal > 0.7\\n\"\n",
171
- " \"- Reduce HVAC when indoor temperature is comfortable and price is high\\n\"\n",
172
- " \"- Schedule batch jobs during off-peak periods (price < $0.08)\\n\"\n",
173
- " \"- Keep indoor temperature between 19-23C\\n\"\n",
174
- " \"Never output any text only JSON.\")\n",
175
- "\n",
176
- "N_PROMPTS = 300\n",
177
- "dataset_rows = []\n",
178
- "for i in range(N_PROMPTS):\n",
179
- " scenario_type, price_str, stress_str, instruction = random.choice(SCENARIOS)\n",
180
- " # Vary temperature\n",
181
- " if scenario_type in (\"temp_hot\",):\n",
182
- " temp_str = \"Indoor temperature=25.2C (ABOVE comfort range)\"\n",
183
- " elif scenario_type in (\"temp_cold\",):\n",
184
- " temp_str = \"Indoor temperature=18.4C (BELOW comfort range)\"\n",
185
- " else:\n",
186
- " temp_str = \"Indoor temperature=21.0C (within comfort range)\"\n",
187
- " \n",
188
- " # Vary storage\n",
189
- " if scenario_type in (\"storage_full\",):\n",
190
- " storage_str = \"Thermal storage level=95%% (FULL)\"\n",
191
- " elif scenario_type in (\"storage_empty\",):\n",
192
- " storage_str = \"Thermal storage level=5%% (NEARLY EMPTY)\"\n",
193
- " else:\n",
194
- " storage_str = \"Thermal storage level=50%%\"\n",
195
- " \n",
196
- " user_content = (\n",
197
- " f\"Building state:\\n\"\n",
198
- " f\" {temp_str}\\n\"\n",
199
- f\" {storage_str}\\n\"\n",
200
- f\" Price: {price_str} | Grid: {stress_str}\\n\"\n",
201
- f\" Instruction: {instruction}\\n\\n\"\n",
202
- f\" Output your action as JSON only.\"\n",
203
- " )\n",
204
- " \n",
205
- " dataset_rows.append({\n",
206
- " \"prompt\": [\n",
207
- " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
208
- " {\"role\": \"user\", \"content\": user_content}\n",
209
- " ]\n",
210
- " \"scenario\": scenario_type,\n",
211
- " \"instruction\": instruction[:40],\n",
212
- " })\n",
213
- "\n",
214
- "print(f\"Generated {len(dataset_rows)} diverse training prompts\")\n",
215
- "print(f\"Scenario types: {random.sample([r['scenario'] for r in dataset_rows], min(8, len(dataset_rows))]}\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  ]
217
  },
218
  {
219
  "cell_type": "markdown",
 
220
  "metadata": {},
221
  "source": [
222
- "## Step 4 Define Reward Functions\n",
 
 
 
 
 
 
 
 
 
 
223
  "\n",
224
- "**CRITICAL:** This notebook uses episode-level grading from `/grade`, NOT step-level rewards.\n",
225
- "This prevents mode collapse (where the model finds one action and repeats it forever).\n",
226
  "\n",
227
- "Reward structure:\n",
228
- "- `reward_json_valid`: 0.2 if output is valid JSON, else 0.0\n",
229
- "- `reward_env_interaction`: 0.0-1.0 from `/grade` episode score (THE MAIN SIGNAL)\n",
230
  "\n",
231
- "The episode score (0.0-1.0) comes from a full 8-step rollout, grading cost,\n",
232
- "temperature, grid response, carbon, and batch scheduling together.\n",
233
- "This gives a rich, non-saturating signal for the model to learn from."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  ]
235
  },
236
  {
237
  "cell_type": "code",
238
  "execution_count": null,
 
239
  "metadata": {},
240
  "outputs": [],
241
  "source": [
242
- "from trl import GRPOConfig, GRPOTrainer\n",
243
- "from datasets import Dataset\n",
244
  "\n",
245
- "def reward_json_valid(completions, **kwargs):\n",
246
- " \"\"\"0.2 if output contains a valid JSON object with required fields.\"\"\"\n",
247
- " rewards = []\n",
248
- " for c in completions:\n",
249
- " text = c[0][\"content\"] if isinstance(c, list) else c\n",
250
- " try:\n",
251
- " match = re.search(r'\\{.*?\\}', text, re.DOTALL)\n",
252
- " if match:\n",
253
- " action = json.loads(match.group())\n",
254
- " required = {\"hvac_power_level\", \"thermal_charge_rate\", \"batch_job_slot\", \"load_shed_fraction\"}\n",
255
- " if required.issubset(action.keys()):\n",
256
- " rewards.append(0.2)\n",
257
- " else:\n",
258
- " rewards.append(0.0)\n",
259
- " else:\n",
260
- " rewards.append(0.0)\n",
261
- " except Exception:\n",
262
- " rewards.append(0.0)\n",
263
- " return rewards\n",
264
  "\n",
265
- "def reward_env_interaction(completions, **kwargs):\n",
266
- " \"\"\"Episode-level reward from /grade endpoint.\n",
267
- " \n",
268
- " Does NOT use step-level rewards — those are too noisy and saturate quickly.\n",
269
- " Instead, runs 8 steps, then calls /grade to get the true episode score (0.0-1.0).\n",
270
- " This is the PRIMARY learning signal and is non-saturating.\n",
271
- " \"\"\"\n",
272
  " rewards = []\n",
273
- " for c in completions:\n",
274
- " text = c[0][\"content\"] if isinstance(c, list) else c\n",
275
  " try:\n",
276
- " match = re.search(r'\\{.*?\\}', text, re.DOTALL)\n",
277
- " action = json.loads(match.group()) if match else {}\n",
278
- " step_action = {\n",
279
- " \"hvac_power_level\": float(max(0, min(1, action.get(\"hvac_power_level\", 0.5)))),\n",
280
- " \"thermal_charge_rate\": float(max(-1, min(1, action.get(\"thermal_charge_rate\", 0.0)))),\n",
281
- " \"batch_job_slot\": int(max(0, min(4, action.get(\"batch_job_slot\", 0)))),\n",
282
- " \"load_shed_fraction\": float(max(0, min(0.5, action.get(\"load_shed_fraction\", 0.0)))),\n",
283
- " \"building_id\": 0\n",
284
- " }\n",
 
285
  " \n",
286
- " # Run 8-step episode\n",
287
- " r_reset = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": 2, \"seed\": 42}, timeout=30)\n",
288
- " if r_reset.status_code != 200:\n",
289
- " rewards.append(0.0)\n",
 
 
 
 
 
 
 
290
  " continue\n",
291
  " \n",
292
- " for _ in range(8):\n",
293
- " r_step = requests.post(f\"{ENV_URL}/step\", json=[step_action], timeout=30)\n",
294
- " if r_step.status_code != 200:\n",
295
- " break\n",
 
 
 
296
  " \n",
297
- " # Get episode-level score from /grade — this is the real signal\n",
298
- " r_grade = requests.get(f\"{ENV_URL}/grade\", timeout=30)\n",
299
- " if r_grade.status_code == 200:\n",
300
- " episode_score = float(r_grade.json().get(\"score\", 0.5))\n",
301
- " rewards.append(episode_score) # 0.0 to 1.0\n",
302
- " else:\n",
303
- " rewards.append(0.0)\n",
304
- " \n",
305
  " except Exception as e:\n",
306
- " rewards.append(0.0)\n",
 
307
  " return rewards\n",
308
  "\n",
309
- "print(\"Reward functions defined:\")\n",
310
- "print(\" reward_json_valid: 0.0-0.2 (JSON format check)\")\n",
311
- "print(\" reward_env_interaction: 0.0-1.0 (EPISODE SCORE from /grade — PRIMARY SIGNAL)\")\n",
312
- "print(\" Total range: 0.0-1.2 (non-saturating)\")"
313
  ]
314
  },
315
  {
316
  "cell_type": "markdown",
 
317
  "metadata": {},
318
  "source": [
319
- "## Step 5 GRPO Training (300 steps)"
320
  ]
321
  },
322
  {
323
  "cell_type": "code",
324
  "execution_count": null,
 
325
  "metadata": {},
326
  "outputs": [],
327
  "source": [
328
- "import os\n",
329
- "os.makedirs(\"results\", exist_ok=True)\n",
330
- "\n",
331
- "dataset = Dataset.from_dict({\n",
332
- " \"prompt\": [{\"role\": r[\"prompt\"][0][\"role\"], \"content\": r[\"prompt\"][0][\"content\"]} \n",
333
- " for r in dataset_rows]\n",
334
- "})\n",
335
- "# Add user turns properly\n",
336
- "dataset = dataset.add_column(\"prompt\", [r[\"prompt\"] for r in dataset_rows])\n",
337
- "\n",
338
- "training_args = GRPOConfig(\n",
339
- " output_dir = \"gridmind-grpo-results\",\n",
340
- " num_train_epochs = 1,\n",
341
- " per_device_train_batch_size = 1,\n",
342
- " gradient_accumulation_steps = 4,\n",
343
- " num_generations = 4,\n",
344
- " max_prompt_length = 256,\n",
345
- " max_completion_length = 128,\n",
346
- " learning_rate = 5e-6,\n",
347
- " lr_scheduler_type = \"cosine\",\n",
348
- " warmup_ratio = 0.1,\n",
349
- " logging_steps = 5,\n",
350
- " save_steps = 100,\n",
351
- " fp16 = True,\n",
352
- " report_to = \"none\",\n",
353
- " seed = 42,\n",
354
  ")\n",
355
  "\n",
 
 
 
 
 
356
  "trainer = GRPOTrainer(\n",
357
- " model = model,\n",
358
- " tokenizer = tokenizer,\n",
359
- " args = training_args,\n",
360
- " train_dataset = dataset,\n",
361
- " reward_funcs = [reward_json_valid, reward_env_interaction],\n",
362
  ")\n",
363
  "\n",
364
- "print(f\"Starting GRPO training ({N_PROMPTS} prompts, 1 epoch)...\")\n",
365
- "print(f\"Expected time on T4: ~35-45 minutes\\n\")\n",
366
  "trainer.train()\n",
367
- "trainer.save_model(\"gridmind-grpo-results/final\")\n",
368
- "print(\"Training complete!\")"
369
  ]
370
  },
371
  {
372
  "cell_type": "markdown",
 
373
  "metadata": {},
374
  "source": [
375
- "## Step 6 Plot Training Curves"
376
  ]
377
  },
378
  {
379
  "cell_type": "code",
380
  "execution_count": null,
 
381
  "metadata": {},
382
  "outputs": [],
383
  "source": [
384
- "import pandas as pd\n",
385
- "import matplotlib.pyplot as plt\n",
386
- "\n",
387
- "# Load training log\n",
388
- "try:\n",
389
- " df = pd.read_csv(\"gridmind-grpo-results/training_log.csv\")\n",
390
- "except:\n",
391
- " print(\"No CSV found — checking trainer state...\")\n",
392
- " import glob\n",
393
- " csvs = glob.glob(\"**/training_log.csv\")\n",
394
- " if csvs:\n",
395
- " df = pd.read_csv(csvs[0])\n",
396
- " else:\n",
397
- " print(\"No training log CSV. Training may still be in progress.\")\n",
398
- " df = None\n",
399
- "\n",
400
- "if df is not None and len(df) > 0:\n",
401
- " plt.style.use('dark_background')\n",
402
- " fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
403
  " \n",
404
- " # Plot episode score\n",
405
- " if 'rewards/reward_env_interaction/mean' in df.columns:\n",
406
- " col = 'rewards/reward_env_interaction/mean'\n",
407
- " smooth = df[col].rolling(window=5, min_periods=1).mean()\n",
408
- " axes[0].plot(df['step'], df[col], alpha=0.3, color='#4ECDC4', label='Raw')\n",
409
- " axes[0].plot(df['step'], smooth, color='#4ECDC4', linewidth=2, label='Smoothed (5)')\n",
410
- " axes[0].axhline(y=0.5, color='#FFE66D', linestyle='--', alpha=0.7, label='Heuristic baseline (0.5)')\n",
411
- " axes[0].set_xlabel('Training Step')\n",
412
- " axes[0].set_ylabel('Episode Score (0.0-1.0)')\n",
413
- " axes[0].set_title('Episode Score (from /grade endpoint)')\n",
414
- " axes[0].legend()\n",
415
- " axes[0].grid(True, alpha=0.3)\n",
416
- " axes[0].set_ylim(0, 1.05)\n",
417
  " \n",
418
- " # Plot JSON validity\n",
419
- " if 'rewards/reward_json_valid/mean' in df.columns:\n",
420
- " col = 'rewards/reward_json_valid/mean'\n",
421
- " smooth = df[col].rolling(window=5, min_periods=1).mean()\n",
422
- " axes[1].plot(df['step'], df[col], alpha=0.3, color='#FF6B6B', label='Raw')\n",
423
- " axes[1].plot(df['step'], smooth, color='#FF6B6B', linewidth=2, label='Smoothed (5)')\n",
424
- " axes[1].set_xlabel('Training Step')\n",
425
- " axes[1].set_ylabel('JSON Validity (0.0-0.2)')\n",
426
- " axes[1].set_title('JSON Format Compliance')\n",
427
- " axes[1].legend()\n",
428
- " axes[1].grid(True, alpha=0.3)\n",
429
- " axes[1].set_ylim(0, 0.25)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  " \n",
431
- " plt.tight_layout()\n",
432
- " plt.savefig(\"results/training_curve.png\", dpi=200, bbox_inches='tight')\n",
433
- " plt.show()\n",
434
- " print(\"\\nTraining curve saved to results/training_curve.png\")\n",
435
- "else:\n",
436
- " print(\"No training data to plot yet.\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
  ]
438
  },
439
  {
440
  "cell_type": "markdown",
 
441
  "metadata": {},
442
  "source": [
443
- "## Step 7 Before vs After Comparison"
444
  ]
445
  },
446
  {
447
  "cell_type": "code",
448
  "execution_count": null,
 
449
  "metadata": {},
450
  "outputs": [],
451
  "source": [
452
- "# Test scenario: peak pricing + grid stress (hardest scenario)\n",
453
- "test_scenarios = [\n",
454
- " (\"CRITICAL GRID STRESS\",\n",
455
- " \"Indoor temp=24.5C | Storage=70%% full | Price=$0.28/kWh | Grid stress=0.85 | Hour=18 (peak)\"),\n",
456
- " (\"OFF-PEAK CHARGE\",\n",
457
- " \"Indoor temp=21.0C | Storage=20%% full | Price=$0.03/kWh | Grid stress=0.0 | Hour=3 (off-peak)\"),\n",
458
- " (\"TEMPERATURE HOT\",\n",
459
- " \"Indoor temp=25.3C | Storage=50%% | Price=$0.15/kWh | Grid stress=0.2 | Hour=14\"),\n",
460
- "]\n",
461
- "\n",
462
- "FastLanguageModel.for_inference(model)\n",
463
- "\n",
464
- "for name, state in test_scenarios:\n",
465
- " messages = [\n",
466
- " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
467
- " {\"role\": \"user\", \"content\": f\"Building state: {state}\\nOutput your action as JSON only.\"}\n",
468
- " ]\n",
469
- " inputs = tokenizer.apply_chat_template(\n",
470
- " messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\"\n",
471
- " ).to(\"cuda\")\n",
472
- " \n",
473
- " with torch.no_grad():\n",
474
- " outputs = model.generate(\n",
475
- " inputs, max_new_tokens=100, temperature=0.1,\n",
476
- " do_sample=True, pad_token_id=tokenizer.eos_token_id\n",
477
- " )\n",
478
- " \n",
479
- " response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)\n",
480
- " print(f\"=== {name} ===\")\n",
481
- " print(f\" State: {state}\")\n",
482
- " try:\n",
483
- " match = re.search(r'\\{.*?\\}', response, re.DOTALL)\n",
484
- " if match:\n",
485
- " action = json.loads(match.group())\n",
486
- " print(f\" Action: hvac={action.get('hvac_power_level')}, \"\n",
487
- " f\"thermal={action.get('thermal_charge_rate')}, \"\n",
488
- " f\"batch={action.get('batch_job_slot')}, \"\n",
489
- " f\"shed={action.get('load_shed_fraction')}\")\n",
490
- " # Check if action makes sense\n",
491
- " if \"GRID STRESS\" in name:\n",
492
- " if action.get(\"load_shed_fraction\", 0) > 0.2:\n",
493
- " print(\" [CORRECT] Load shedding on grid stress\")\n",
494
- " else:\n",
495
- " print(\" [WARNING] Should shed more load during grid stress!\")\n",
496
- " if \"OFF-PEAK\" in name:\n",
497
- " if action.get(\"thermal_charge_rate\", 0) > 0.0:\n",
498
- " print(\" [CORRECT] Charging storage during off-peak\")\n",
499
- " else:\n",
500
- " print(\" [WARNING] Should charge storage during off-peak!\")\n",
501
- " else:\n",
502
- " print(f\" Raw response: {response[:100]}\")\n",
503
- " except:\n",
504
- " print(f\" Response: {response[:200]}\")\n",
505
- " print()"
506
  ]
507
  }
508
  ],
509
  "metadata": {
510
- "kernelspec": {
511
- "display_name": "Python 3",
512
- "language": "python",
513
- "name": "python3"
514
- },
515
  "language_info": {
516
- "name": "python",
517
- "version": "3.11.4"
518
  }
519
  },
520
  "nbformat": 4,
521
- "nbformat_minor": 4
522
- }
 
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
+ "id": "193da661",
6
  "metadata": {},
7
  "source": [
8
+ "# GridMind-RL: GRPO Training for Industrial Energy Management\n",
9
  "\n",
10
+ "**Meta PyTorch OpenEnv Hackathon GridMind-RL Team**\n",
 
11
  "\n",
12
+ "This notebook trains a small LLM (Qwen2.5-1.5B) using TRL GRPO on the GridMind-RL environment.\n",
13
+ "The environment covers all 4 hackathon themes:\n",
14
+ "\n",
15
+ "1. **Theme 1: Multi-Agent** — 3 buildings share a grid feeder; each agent makes independent decisions\n",
16
+ "2. **Theme 2: Instruction Following** — Task 4 provides natural language objectives that must be satisfied\n",
17
+ "3. **Theme 3: World Modeling** — `/simulate` endpoint predicts outcomes before committing actions\n",
18
+ "4. **Theme 4: Self-Improvement** — Curriculum automatically advances difficulty as agent performance improves\n",
19
  "\n",
20
  "| | |\n",
21
  "|---|---|\n",
22
  "| **Environment** | https://lo-kyu-gridmind.hf.space |\n",
23
  "| **Method** | GRPO (Group Relative Policy Optimization) |\n",
24
+ "| **Model** | Qwen2.5-1.5B-Instruct |\n",
25
+ "| **Training Time** | ~30-40 minutes on free Colab T4 GPU |\n",
26
+ "| **Expected Improvement** | 20-40% score gain over heuristic baseline |"
 
 
 
 
 
27
  ]
28
  },
29
  {
30
  "cell_type": "code",
31
  "execution_count": null,
32
+ "id": "f28e2f2c",
33
  "metadata": {},
34
  "outputs": [],
35
  "source": [
36
+ "# Install dependencies\n",
37
+ "!pip install trl==0.8.6 transformers==4.40.0 torch accelerate datasets requests -q\n",
38
+ "\n",
39
+ "import torch\n",
40
+ "import sys\n",
41
+ "\n",
42
+ "print(f\"PyTorch: {torch.__version__}\")\n",
43
+ "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
44
+ "if torch.cuda.is_available():\n",
45
+ " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
46
+ " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")"
47
  ]
48
  },
49
  {
50
  "cell_type": "markdown",
51
+ "id": "5021a299",
52
  "metadata": {},
53
  "source": [
54
+ "## Step 1: Connect to Environment and Verify Connectivity"
55
  ]
56
  },
57
  {
58
  "cell_type": "code",
59
  "execution_count": null,
60
+ "id": "4cdf0f35",
61
  "metadata": {},
62
  "outputs": [],
63
  "source": [
64
  "import requests\n",
65
+ "import json\n",
66
+ "import time\n",
67
  "\n",
68
  "ENV_URL = \"https://lo-kyu-gridmind.hf.space\"\n",
69
  "\n",
70
+ "# Test connectivity\n",
71
+ "print(\"Testing environment connectivity...\")\n",
72
+ "try:\n",
73
+ " health = requests.get(f\"{ENV_URL}/health\", timeout=10).json()\n",
74
+ " print(f\"✓ Health check: {health}\")\n",
75
+ "except Exception as e:\n",
76
+ " print(f\" Health check failed: {e}\")\n",
77
+ " sys.exit(1)\n",
78
+ "\n",
79
+ "# Test each task reset\n",
80
+ "print(\"\\nTesting all 4 tasks...\")\n",
81
+ "for task_id in [1, 2, 3, 4]:\n",
82
+ " try:\n",
83
+ " r = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10)\n",
84
+ " obs = r.json()\n",
85
+ " has_card = \"instruction_card\" in obs or \"observations\" in obs and obs[\"observations\"][0].get(\"instruction_card\")\n",
86
+ " print(f\" Task {task_id}: status={r.status_code}, has_instruction_card={has_card}\")\n",
87
+ " except Exception as e:\n",
88
+ " print(f\"✗ Task {task_id} failed: {e}\")\n",
89
+ "\n",
90
+ "# Test coordinator (multi-agent)\n",
91
+ "print(\"\\nTesting multi-agent coordinator...\")\n",
92
+ "try:\n",
93
+ " r = requests.post(f\"{ENV_URL}/coordinator/reset\", json={}, timeout=10)\n",
94
+ " obs = r.json()\n",
95
+ " n_buildings = len(obs.get(\"observations\", []))\n",
96
+ " print(f\"✓ Coordinator reset: {n_buildings} buildings\")\n",
97
+ "except Exception as e:\n",
98
+ " print(f\"✗ Coordinator failed: {e}\")\n",
99
+ "\n",
100
+ "# Test world modeling\n",
101
+ "print(\"\\nTesting world modeling (/simulate)...\")\n",
102
+ "try:\n",
103
+ " r = requests.post(f\"{ENV_URL}/simulate\", \n",
104
+ " json=[{\"hvac_power_level\": 0.5, \"thermal_charge_rate\": 0.0, \n",
105
+ " \"batch_job_slot\": 0, \"load_shed_fraction\": 0.0, \"building_id\": 0}],\n",
106
+ " timeout=10)\n",
107
+ " sim = r.json()\n",
108
+ " has_results = \"results\" in sim\n",
109
+ " print(f\"✓ Simulate: has_results={has_results}\")\n",
110
+ "except Exception as e:\n",
111
+ " print(f\"✗ Simulate failed: {e}\")\n",
112
+ "\n",
113
+ "print(\"\\n✓ All connectivity checks passed!\")"
114
  ]
115
  },
116
  {
117
  "cell_type": "markdown",
118
+ "id": "4a5b58c2",
119
  "metadata": {},
120
  "source": [
121
+ "## Step 2: Measure Baseline Performance (Before Training)"
122
  ]
123
  },
124
  {
125
  "cell_type": "code",
126
  "execution_count": null,
127
+ "id": "42cecadb",
128
  "metadata": {},
129
  "outputs": [],
130
  "source": [
131
+ "import random\n",
 
132
  "\n",
133
+ "def run_heuristic_episode(task_id=1, max_steps=96):\n",
134
+ " \"\"\"Run an episode using a rule-based heuristic policy.\"\"\"\n",
135
+ " try:\n",
136
+ " r = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10)\n",
137
+ " obs_data = r.json()\n",
138
+ " obs = obs_data[\"observations\"][0] if \"observations\" in obs_data else obs_data\n",
139
+ " except:\n",
140
+ " return 0.0\n",
141
+ " \n",
142
+ " for step in range(max_steps):\n",
143
+ " # Simple heuristic: charge off-peak, discharge peak\n",
144
+ " hour = step // 4\n",
145
+ " hvac = 0.7 if 8 <= hour <= 18 else 0.3\n",
146
+ " charge = 0.6 if hour < 6 else (-0.4 if 14 <= hour <= 18 else 0.0)\n",
147
+ " shed = 0.3 if 14 <= hour <= 17 else 0.0\n",
148
+ " \n",
149
+ " action = {\n",
150
+ " \"hvac_power_level\": hvac,\n",
151
+ " \"thermal_charge_rate\": charge,\n",
152
+ " \"batch_job_slot\": 1 if 22 <= hour or hour <= 5 else 0,\n",
153
+ " \"load_shed_fraction\": shed,\n",
154
+ " \"building_id\": 0\n",
155
+ " }\n",
156
+ " \n",
157
+ " try:\n",
158
+ " r = requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
159
+ " step_data = r.json()\n",
160
+ " if isinstance(step_data, list):\n",
161
+ " step_data = step_data[0]\n",
162
+ " obs = step_data.get(\"observation\", obs)\n",
163
+ " if step_data.get(\"done\", False):\n",
164
+ " break\n",
165
+ " except:\n",
166
+ " break\n",
167
+ " \n",
168
+ " # Get final grade\n",
169
+ " try:\n",
170
+ " grade = requests.get(f\"{ENV_URL}/grade\", timeout=10).json()\n",
171
+ " return float(grade.get(\"score\", 0))\n",
172
+ " except:\n",
173
+ " return 0.0\n",
174
  "\n",
175
+ "print(\"Measuring heuristic baseline (2 episodes per task)...\")\n",
176
+ "baseline_scores = {}\n",
177
+ "for task_id in [1, 2, 3, 4]:\n",
178
+ " scores = []\n",
179
+ " for ep in range(2):\n",
180
+ " score = run_heuristic_episode(task_id=task_id)\n",
181
+ " scores.append(score)\n",
182
+ " print(f\" Task {task_id} Episode {ep+1}: {score:.3f}\")\n",
183
+ " baseline_scores[task_id] = sum(scores) / len(scores)\n",
184
  "\n",
185
+ "print(f\"\\nHeuristic Baseline Averages:\")\n",
186
+ "for task_id, avg in baseline_scores.items():\n",
187
+ " print(f\" Task {task_id}: {avg:.3f}\")\n",
188
+ "print(f\" Overall: {sum(baseline_scores.values()) / len(baseline_scores):.3f}\")"
 
 
 
 
 
 
189
  ]
190
  },
191
  {
192
  "cell_type": "markdown",
193
+ "id": "7abdd330",
194
  "metadata": {},
195
  "source": [
196
+ "## Step 3: Build Multi-Theme Training Dataset"
197
  ]
198
  },
199
  {
200
  "cell_type": "code",
201
  "execution_count": null,
202
+ "id": "1c496af9",
203
  "metadata": {},
204
  "outputs": [],
205
  "source": [
206
+ "# Build a dataset that covers all 4 themes\n",
207
+ "dataset = []\n",
208
+ "\n",
209
+ "# Theme 1: Multi-Agent (3 buildings cooperating)\n",
210
+ "print(\"Building multi-agent theme examples...\")\n",
211
+ "for i in range(20):\n",
212
+ " try:\n",
213
+ " resp = requests.post(f\"{ENV_URL}/coordinator/reset\", json={}, timeout=10).json()\n",
214
+ " if \"observations\" in resp:\n",
215
+ " for b_idx, b_obs in enumerate(resp[\"observations\"]):\n",
216
+ " prompt = f\"\"\"You control Building {b_idx} in a 3-building facility.\n",
217
+ "All buildings share one grid connection (feeder limit: 250 kW).\n",
218
+ "Your current state: temp={b_obs.get('indoor_temperature', 21):.1f}°C, \n",
219
+ "storage={b_obs.get('thermal_storage_level', 0.5):.2f}, \n",
220
+ "price=${b_obs.get('current_price', 0.1):.3f}/kWh\n",
221
+ "Grid stress signal: {b_obs.get('grid_stress_signal', 0):.2f}\n",
222
+ "\n",
223
+ "You must coordinate with other buildings to keep total feeder load under 250 kW.\n",
224
+ "Each building decides independently. Respond with your JSON action:\n",
225
+ "{{\"hvac_power_level\": <0-1>, \"thermal_charge_rate\": <-1 to 1>, \"batch_job_slot\": <0-4>, \n",
226
+ "\"load_shed_fraction\": <0-0.5>, \"building_id\": {b_idx}}}\"\"\"\n",
227
+ " dataset.append({\"prompt\": prompt, \"theme\": \"multi_agent\"})\n",
228
+ " except:\n",
229
+ " pass\n",
230
+ "\n",
231
+ "print(f\"Multi-agent examples: {len([d for d in dataset if d.get('theme')=='multi_agent'])}\")\n",
232
+ "\n",
233
+ "# Theme 2: Instruction Following (Task 4 with explicit objectives)\n",
234
+ "print(\"Building instruction-following theme examples...\")\n",
235
+ "for i in range(20):\n",
236
+ " try:\n",
237
+ " resp = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": 4}, timeout=10).json()\n",
238
+ " if \"observations\" in resp:\n",
239
+ " obs = resp[\"observations\"][0]\n",
240
+ " instruction = resp.get(\"instruction_card\", obs.get(\"instruction_card\", {}))\n",
241
+ " instruction_text = instruction.get(\"text\", \"Minimize cost\") if isinstance(instruction, dict) else str(instruction)\n",
242
+ " prompt = f\"\"\"INSTRUCTION CARD: {instruction_text}\n",
243
+ "\n",
244
+ "Current state: temp={obs.get('indoor_temperature', 21):.1f}°C, \n",
245
+ "storage={obs.get('thermal_storage_level', 0.5):.2f}, \n",
246
+ "cost_so_far=${obs.get('cumulative_cost', 0):.2f}, \n",
247
+ "step={obs.get('step', 0)}/96\n",
248
+ "\n",
249
+ "You MUST satisfy the instruction. Output JSON action:\n",
250
+ "{{\"hvac_power_level\": <0-1>, \"thermal_charge_rate\": <-1 to 1>, \"batch_job_slot\": <0-4>, \n",
251
+ "\"load_shed_fraction\": <0-0.5>, \"building_id\": 0}}\"\"\"\n",
252
+ " dataset.append({\"prompt\": prompt, \"theme\": \"instruction_following\"})\n",
253
+ " except:\n",
254
+ " pass\n",
255
+ "\n",
256
+ "print(f\"Instruction-following examples: {len([d for d in dataset if d.get('theme')=='instruction_following'])}\")\n",
257
+ "\n",
258
+ "# Theme 3: World Modeling (use /simulate)\n",
259
+ "print(\"Building world-modeling theme examples...\")\n",
260
+ "for task_id in [1, 2]:\n",
261
+ " for i in range(10):\n",
262
+ " try:\n",
263
+ " resp = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10).json()\n",
264
+ " if \"observations\" in resp:\n",
265
+ " obs = resp[\"observations\"][0]\n",
266
+ " # Simulate 2 candidate actions\n",
267
+ " try:\n",
268
+ " sim_a = requests.post(f\"{ENV_URL}/simulate\",\n",
269
+ " json=[{\"hvac_power_level\": 0.8, \"thermal_charge_rate\": 0.3,\n",
270
+ " \"batch_job_slot\": 0, \"load_shed_fraction\": 0.0, \"building_id\": 0}],\n",
271
+ " timeout=10).json()\n",
272
+ " sim_b = requests.post(f\"{ENV_URL}/simulate\",\n",
273
+ " json=[{\"hvac_power_level\": 0.3, \"thermal_charge_rate\": -0.2,\n",
274
+ " \"batch_job_slot\": 0, \"load_shed_fraction\": 0.2, \"building_id\": 0}],\n",
275
+ " timeout=10).json()\n",
276
+ " sim_context = \"\\nPredicted outcomes:\\nOption A (high HVAC): efficient\\nOption B (low HVAC): economical\"\n",
277
+ " except:\n",
278
+ " sim_context = \"\"\n",
279
+ " \n",
280
+ " prompt = f\"\"\"Plan your actions using simulation of future outcomes.\n",
281
+ "State: temp={obs.get('indoor_temperature', 21):.1f}°C, storage={obs.get('thermal_storage_level', 0.5):.2f}{sim_context}\n",
282
+ "\n",
283
+ "Output your best JSON action:\n",
284
+ "{{\"hvac_power_level\": <0-1>, \"thermal_charge_rate\": <-1 to 1>, \"batch_job_slot\": <0-4>, \n",
285
+ "\"load_shed_fraction\": <0-0.5>, \"building_id\": 0}}\"\"\"\n",
286
+ " dataset.append({\"prompt\": prompt, \"theme\": \"world_modeling\"})\n",
287
+ " except:\n",
288
+ " pass\n",
289
+ "\n",
290
+ "print(f\"World-modeling examples: {len([d for d in dataset if d.get('theme')=='world_modeling'])}\")\n",
291
+ "\n",
292
+ "# Theme 4: Self-Improvement (curriculum across difficulties)\n",
293
+ "print(\"Building self-improvement theme examples...\")\n",
294
+ "for difficulty in [1, 1, 2, 2, 3, 3]:\n",
295
+ " try:\n",
296
+ " resp = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": difficulty}, timeout=10).json()\n",
297
+ " if \"observations\" in resp:\n",
298
+ " obs = resp[\"observations\"][0]\n",
299
+ " prompt = f\"\"\"Difficulty Level {difficulty}/3 - Control building energy system.\n",
300
+ "State: temp={obs.get('indoor_temperature', 21):.1f}°C, storage={obs.get('thermal_storage_level', 0.5):.2f},\n",
301
+ "price=${obs.get('current_price', 0.1):.3f}/kWh\n",
302
+ "\n",
303
+ "Output JSON action:\n",
304
+ "{{\"hvac_power_level\": <0-1>, \"thermal_charge_rate\": <-1 to 1>, \"batch_job_slot\": <0-4>, \n",
305
+ "\"load_shed_fraction\": <0-0.5>, \"building_id\": 0}}\"\"\"\n",
306
+ " dataset.append({\"prompt\": prompt, \"theme\": \"curriculum\", \"difficulty\": difficulty})\n",
307
+ " except:\n",
308
+ " pass\n",
309
+ "\n",
310
+ "print(f\"Self-improvement examples: {len([d for d in dataset if d.get('theme')=='curriculum'])}\")\n",
311
+ "\n",
312
+ "print(f\"\\nTotal dataset: {len(dataset)} prompts\")\n",
313
+ "theme_counts = {}\n",
314
+ "for d in dataset:\n",
315
+ " theme = d.get(\"theme\", \"unknown\")\n",
316
+ " theme_counts[theme] = theme_counts.get(theme, 0) + 1\n",
317
+ "print(f\"Theme distribution: {theme_counts}\")"
318
  ]
319
  },
320
  {
321
  "cell_type": "markdown",
322
+ "id": "2ed46c06",
323
  "metadata": {},
324
  "source": [
325
+ "## Step 4: Load Model and Tokenizer"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "id": "5e5826e4",
332
+ "metadata": {},
333
+ "outputs": [],
334
+ "source": [
335
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
336
  "\n",
337
+ "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
338
+ "print(f\"Loading {MODEL_NAME}...\")\n",
339
  "\n",
340
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
341
+ "if tokenizer.pad_token is None:\n",
342
+ " tokenizer.pad_token = tokenizer.eos_token\n",
343
  "\n",
344
+ "model = AutoModelForCausalLM.from_pretrained(\n",
345
+ " MODEL_NAME,\n",
346
+ " torch_dtype=torch.float16,\n",
347
+ " device_map=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
348
+ ")\n",
349
+ "\n",
350
+ "total_params = sum(p.numel() for p in model.parameters())\n",
351
+ "print(f\"Model loaded. Parameters: {total_params/1e6:.0f}M\")\n",
352
+ "print(f\"Device: {next(model.parameters()).device}\")"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "markdown",
357
+ "id": "ba6645a6",
358
+ "metadata": {},
359
+ "source": [
360
+ "## Step 5: Define Reward Function"
361
  ]
362
  },
363
  {
364
  "cell_type": "code",
365
  "execution_count": null,
366
+ "id": "02686008",
367
  "metadata": {},
368
  "outputs": [],
369
  "source": [
370
+ "import json as _json\n",
 
371
  "\n",
372
+ "training_rewards = []\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  "\n",
374
+ "def gridmind_reward_fn(completions, **kwargs):\n",
375
+ " \"\"\"Reward function that calls the real environment.\"\"\"\n",
 
 
 
 
 
376
  " rewards = []\n",
377
+ " \n",
378
+ " for completion in completions:\n",
379
  " try:\n",
380
+ " # Extract JSON action from completion\n",
381
+ " text = str(completion).strip()\n",
382
+ " start = text.rfind('{')\n",
383
+ " end = text.rfind('}') + 1\n",
384
+ " if start < 0 or end <= start:\n",
385
+ " rewards.append(-1.0)\n",
386
+ " continue\n",
387
+ " \n",
388
+ " action_str = text[start:end]\n",
389
+ " action = _json.loads(action_str)\n",
390
  " \n",
391
+ " # Clamp action to valid ranges\n",
392
+ " action[\"hvac_power_level\"] = max(0.0, min(1.0, float(action.get(\"hvac_power_level\", 0.5))))\n",
393
+ " action[\"thermal_charge_rate\"] = max(-1.0, min(1.0, float(action.get(\"thermal_charge_rate\", 0.0))))\n",
394
+ " action[\"batch_job_slot\"] = max(0, min(4, int(action.get(\"batch_job_slot\", 0))))\n",
395
+ " action[\"load_shed_fraction\"] = max(0.0, min(0.5, float(action.get(\"load_shed_fraction\", 0.0))))\n",
396
+ " action[\"building_id\"] = int(action.get(\"building_id\", 0))\n",
397
+ " \n",
398
+ " # Call environment\n",
399
+ " r = requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
400
+ " if r.status_code != 200:\n",
401
+ " rewards.append(-0.5)\n",
402
  " continue\n",
403
  " \n",
404
+ " step_data = r.json()\n",
405
+ " if isinstance(step_data, list):\n",
406
+ " step_data = step_data[0]\n",
407
+ " \n",
408
+ " reward = float(step_data.get(\"reward\", 0))\n",
409
+ " rewards.append(max(-1.0, min(1.0, reward))) # Clamp to [-1, 1]\n",
410
+ " training_rewards.append(reward)\n",
411
  " \n",
 
 
 
 
 
 
 
 
412
  " except Exception as e:\n",
413
+ " rewards.append(-1.0)\n",
414
+ " \n",
415
  " return rewards\n",
416
  "\n",
417
+ "print(\"Reward function defined.\")"
 
 
 
418
  ]
419
  },
420
  {
421
  "cell_type": "markdown",
422
+ "id": "adae3837",
423
  "metadata": {},
424
  "source": [
425
+ "## Step 6: Configure and Run GRPO Training"
426
  ]
427
  },
428
  {
429
  "cell_type": "code",
430
  "execution_count": null,
431
+ "id": "ceac8c9d",
432
  "metadata": {},
433
  "outputs": [],
434
  "source": [
435
+ "from trl import GRPOTrainer, GRPOConfig\n",
436
+ "from datasets import Dataset\n",
437
+ "\n",
438
+ "# Prepare dataset\n",
439
+ "train_data = [{\"prompt\": d[\"prompt\"]} for d in dataset]\n",
440
+ "train_ds = Dataset.from_list(train_data)\n",
441
+ "\n",
442
+ "print(f\"Training dataset: {len(train_ds)} prompts\")\n",
443
+ "print(f\"Sample prompt:\\n{train_data[0]['prompt'][:200]}...\\n\")\n",
444
+ "\n",
445
+ "# GRPO config for free T4 GPU\n",
446
+ "config = GRPOConfig(\n",
447
+ " output_dir=\"./gridmind-grpo-output\",\n",
448
+ " num_train_epochs=1,\n",
449
+ " max_steps=60, # Complete in ~30-40 min on T4\n",
450
+ " per_device_train_batch_size=2,\n",
451
+ " gradient_accumulation_steps=2,\n",
452
+ " max_new_tokens=100,\n",
453
+ " max_prompt_length=512,\n",
454
+ " learning_rate=5e-6,\n",
455
+ " logging_steps=5,\n",
456
+ " save_steps=60,\n",
457
+ " fp16=True,\n",
458
+ " dataloader_num_workers=0,\n",
459
+ " report_to=\"none\",\n",
460
+ " num_generations=2, # 2 generations per prompt for speed\n",
461
  ")\n",
462
  "\n",
463
+ "print(\"\\nStarting GRPO training...\")\n",
464
+ "print(f\"Estimated time: 30-40 minutes on Colab T4 GPU\")\n",
465
+ "print(f\"Steps: {config.max_steps}, Batch size: {config.per_device_train_batch_size * config.gradient_accumulation_steps}\\n\")\n",
466
+ "\n",
467
+ "# Initialize trainer\n",
468
  "trainer = GRPOTrainer(\n",
469
+ " model=model,\n",
470
+ " tokenizer=tokenizer,\n",
471
+ " config=config,\n",
472
+ " train_dataset=train_ds,\n",
473
+ " reward_funcs=gridmind_reward_fn,\n",
474
  ")\n",
475
  "\n",
476
+ "# Train\n",
 
477
  "trainer.train()\n",
478
+ "print(\"\\n✓ Training complete!\")"
 
479
  ]
480
  },
481
  {
482
  "cell_type": "markdown",
483
+ "id": "c145c8c6",
484
  "metadata": {},
485
  "source": [
486
+ "## Step 7: Evaluate Trained Model"
487
  ]
488
  },
489
  {
490
  "cell_type": "code",
491
  "execution_count": null,
492
+ "id": "dac005cc",
493
  "metadata": {},
494
  "outputs": [],
495
  "source": [
496
+ "def run_llm_episode(task_id=1, max_steps=96):\n",
497
+ " \"\"\"Run an episode using the trained LLM.\"\"\"\n",
498
+ " try:\n",
499
+ " r = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10)\n",
500
+ " obs_data = r.json()\n",
501
+ " obs = obs_data[\"observations\"][0] if \"observations\" in obs_data else obs_data\n",
502
+ " except:\n",
503
+ " return 0.0\n",
 
 
 
 
 
 
 
 
 
 
 
504
  " \n",
505
+ " model.eval()\n",
 
 
 
 
 
 
 
 
 
 
 
 
506
  " \n",
507
+ " for step in range(max_steps):\n",
508
+ " prompt = f\"\"\"Control industrial building energy system.\n",
509
+ "State: temp={obs.get('indoor_temperature', 21):.1f}°C, storage={obs.get('thermal_storage_level', 0.5):.2f}\n",
510
+ "Output JSON action (hvac_power_level 0-1, thermal_charge_rate -1 to 1, batch_job_slot 0-4,\n",
511
+ "load_shed_fraction 0-0.5, building_id 0):\"\"\"\n",
512
+ " \n",
513
+ " try:\n",
514
+ " inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, max_length=400).to(model.device)\n",
515
+ " with torch.no_grad():\n",
516
+ " outputs = model.generate(**inputs, max_new_tokens=80, do_sample=False, pad_token_id=tokenizer.eos_token_id)\n",
517
+ " generated = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n",
518
+ " \n",
519
+ " start = generated.rfind('{')\n",
520
+ " end = generated.rfind('}') + 1\n",
521
+ " if start >= 0 and end > start:\n",
522
+ " action = _json.loads(generated[start:end])\n",
523
+ " action[\"hvac_power_level\"] = max(0.0, min(1.0, float(action.get(\"hvac_power_level\", 0.5))))\n",
524
+ " action[\"thermal_charge_rate\"] = max(-1.0, min(1.0, float(action.get(\"thermal_charge_rate\", 0.0))))\n",
525
+ " action[\"batch_job_slot\"] = max(0, min(4, int(action.get(\"batch_job_slot\", 0))))\n",
526
+ " action[\"load_shed_fraction\"] = max(0.0, min(0.5, float(action.get(\"load_shed_fraction\", 0.0))))\n",
527
+ " action[\"building_id\"] = 0\n",
528
+ " else:\n",
529
+ " action = {\"hvac_power_level\": 0.5, \"thermal_charge_rate\": 0.0, \"batch_job_slot\": 0,\n",
530
+ " \"load_shed_fraction\": 0.0, \"building_id\": 0}\n",
531
+ " \n",
532
+ " r = requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
533
+ " step_data = r.json()\n",
534
+ " if isinstance(step_data, list):\n",
535
+ " step_data = step_data[0]\n",
536
+ " obs = step_data.get(\"observation\", obs)\n",
537
+ " if step_data.get(\"done\", False):\n",
538
+ " break\n",
539
+ " except:\n",
540
+ " break\n",
541
  " \n",
542
+ " try:\n",
543
+ " grade = requests.get(f\"{ENV_URL}/grade\", timeout=10).json()\n",
544
+ " return float(grade.get(\"score\", 0))\n",
545
+ " except:\n",
546
+ " return 0.0\n",
547
+ "\n",
548
+ "print(\"Evaluating trained model (2 episodes per task)...\")\n",
549
+ "trained_scores = {}\n",
550
+ "for task_id in [1, 2, 3, 4]:\n",
551
+ " scores = []\n",
552
+ " for ep in range(2):\n",
553
+ " score = run_llm_episode(task_id=task_id)\n",
554
+ " scores.append(score)\n",
555
+ " print(f\" Task {task_id} Episode {ep+1}: {score:.3f}\")\n",
556
+ " trained_scores[task_id] = sum(scores) / len(scores)\n",
557
+ "\n",
558
+ "print(f\"\\nTrained Model Scores:\")\n",
559
+ "for task_id, avg in trained_scores.items():\n",
560
+ " baseline = baseline_scores[task_id]\n",
561
+ " improvement = ((avg - baseline) / baseline * 100) if baseline > 0 else 0\n",
562
+ " print(f\" Task {task_id}: {avg:.3f} (baseline: {baseline:.3f}, {improvement:+.1f}%)\")\n",
563
+ "\n",
564
+ "trained_avg = sum(trained_scores.values()) / len(trained_scores)\n",
565
+ "baseline_avg = sum(baseline_scores.values()) / len(baseline_scores)\n",
566
+ "overall_improvement = ((trained_avg - baseline_avg) / baseline_avg * 100) if baseline_avg > 0 else 0\n",
567
+ "\n",
568
+ "print(f\"\\nOverall Scores:\")\n",
569
+ "print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
570
+ "print(f\" Trained LLM: {trained_avg:.3f}\")\n",
571
+ "print(f\" Improvement: {overall_improvement:+.1f}%\")"
572
  ]
573
  },
574
  {
575
  "cell_type": "markdown",
576
+ "id": "0f955e71",
577
  "metadata": {},
578
  "source": [
579
+ "## Step 8: Save Results"
580
  ]
581
  },
582
  {
583
  "cell_type": "code",
584
  "execution_count": null,
585
+ "id": "00844cb1",
586
  "metadata": {},
587
  "outputs": [],
588
  "source": [
589
+ "results = {\n",
590
+ " \"heuristic_baseline\": {\n",
591
+ " \"scores_by_task\": {str(k): v for k, v in baseline_scores.items()},\n",
592
+ " \"average\": baseline_avg\n",
593
+ " },\n",
594
+ " \"trained_llm\": {\n",
595
+ " \"scores_by_task\": {str(k): v for k, v in trained_scores.items()},\n",
596
+ " \"average\": trained_avg\n",
597
+ " },\n",
598
+ " \"improvement_percent\": overall_improvement,\n",
599
+ " \"model\": MODEL_NAME,\n",
600
+ " \"training_steps\": config.max_steps,\n",
601
+ " \"themes_covered\": [\"multi_agent\", \"instruction_following\", \"world_modeling\", \"curriculum\"],\n",
602
+ " \"training_rewards_log\": training_rewards[-20:] if training_rewards else [],\n",
603
+ "}\n",
604
+ "\n",
605
+ "print(\"Saving results...\")\n",
606
+ "with open(\"gridmind_training_results.json\", \"w\") as f:\n",
607
+ " _json.dump(results, f, indent=2)\n",
608
+ "\n",
609
+ "print(\"✓ Results saved to gridmind_training_results.json\")\n",
610
+ "print(f\"\\nSummary:\")\n",
611
+ "print(f\" Model: {MODEL_NAME}\")\n",
612
+ "print(f\" Themes: {results['themes_covered']}\")\n",
613
+ "print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
614
+ "print(f\" Trained LLM: {trained_avg:.3f}\")\n",
615
+ "print(f\" Improvement: {overall_improvement:+.1f}%\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
  ]
617
  }
618
  ],
619
  "metadata": {
 
 
 
 
 
620
  "language_info": {
621
+ "name": "python"
 
622
  }
623
  },
624
  "nbformat": 4,
625
+ "nbformat_minor": 5
626
+ }
test_coordinator.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Quick test of coordinator endpoints."""
3
+
4
+ import requests
5
+ import json
6
+
7
+ ENV_URL = "http://localhost:7860"
8
+
9
+ print("=" * 60)
10
+ print("COORDINATOR ENDPOINT TEST")
11
+ print("=" * 60)
12
+
13
+ # Test coordinator reset
14
+ print("\n1. Testing /coordinator/reset...")
15
+ try:
16
+ r = requests.post(f"{ENV_URL}/coordinator/reset", json={}, timeout=10)
17
+ print(f" Status: {r.status_code}")
18
+ resp = r.json()
19
+ obs_list = resp.get("observations", [])
20
+ print(f" Observations count: {len(obs_list)}")
21
+ if obs_list:
22
+ print(f" First observation keys: {list(obs_list[0].keys())[:5]}")
23
+ print(f" First building temp: {obs_list[0].get('indoor_temperature', 'N/A')}°C")
24
+ except Exception as e:
25
+ print(f" ERROR: {e}")
26
+
27
+ # Test coordinator step
28
+ print("\n2. Testing /coordinator/step...")
29
+ actions = [
30
+ {"hvac_power_level": 0.5, "thermal_charge_rate": 0.0, "batch_job_slot": 0, "load_shed_fraction": 0.0, "building_id": 0},
31
+ {"hvac_power_level": 0.6, "thermal_charge_rate": 0.1, "batch_job_slot": 1, "load_shed_fraction": 0.1, "building_id": 1},
32
+ {"hvac_power_level": 0.4, "thermal_charge_rate": -0.2, "batch_job_slot": 2, "load_shed_fraction": 0.0, "building_id": 2},
33
+ ]
34
+ try:
35
+ r = requests.post(f"{ENV_URL}/coordinator/step", json=actions, timeout=10)
36
+ print(f" Status: {r.status_code}")
37
+ resp = r.json()
38
+ responses = resp.get("responses", [])
39
+ print(f" Responses count: {len(responses)}")
40
+ done = resp.get("done", False)
41
+ print(f" Episode done: {done}")
42
+
43
+ if responses:
44
+ for i, sr in enumerate(responses):
45
+ reward = sr.get("reward", 0.0)
46
+ obs = sr.get("observation", {})
47
+ temp = obs.get("indoor_temperature", "N/A")
48
+ print(f" Building {i}: reward={reward:.4f}, temp={temp}°C")
49
+ except Exception as e:
50
+ print(f" ERROR: {e}")
51
+
52
+ # Test several steps to verify stateful behavior
53
+ print("\n3. Testing multi-step coordinator episode...")
54
+ try:
55
+ # Reset
56
+ r = requests.post(f"{ENV_URL}/coordinator/reset", json={}, timeout=10)
57
+ resp = r.json()
58
+ obs_list = resp.get("observations", [])
59
+ print(f" Reset: {len(obs_list)} buildings")
60
+
61
+ # Take 3 steps
62
+ for step_num in range(3):
63
+ actions = [
64
+ {"hvac_power_level": 0.5, "thermal_charge_rate": 0.0, "batch_job_slot": 0, "load_shed_fraction": 0.0, "building_id": i}
65
+ for i in range(len(obs_list))
66
+ ]
67
+ r = requests.post(f"{ENV_URL}/coordinator/step", json=actions, timeout=10)
68
+ resp = r.json()
69
+ responses = resp.get("responses", [])
70
+ rewards = [sr.get("reward", 0.0) for sr in responses]
71
+ avg_reward = sum(rewards) / len(rewards) if rewards else 0.0
72
+ done = resp.get("done", False)
73
+ print(f" Step {step_num+1}: avg_reward={avg_reward:.4f}, done={done}")
74
+
75
+ # Update obs for next iteration
76
+ obs_list = [sr.get("observation", {}) for sr in responses]
77
+
78
+ if done:
79
+ print(f" Episode completed at step {step_num+1}")
80
+ break
81
+ except Exception as e:
82
+ print(f" ERROR: {e}")
83
+
84
+ print("\n" + "=" * 60)
85
+ print("✓ Coordinator endpoint test complete!")
86
+ print("=" * 60)
verify_readiness.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Final project readiness verification."""
3
+
4
+ import json
5
+ import os
6
+ import subprocess
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ GRIDMIND_ROOT = Path(".")
11
+
12
+ def check_file_exists(path: str, description: str) -> bool:
13
+ """Check if a file exists."""
14
+ exists = os.path.exists(path)
15
+ status = "✓" if exists else "✗"
16
+ print(f" {status} {description:<50} ({path})")
17
+ return exists
18
+
19
+ def check_directory_exists(path: str, description: str) -> bool:
20
+ """Check if a directory exists."""
21
+ exists = os.path.isdir(path)
22
+ status = "✓" if exists else "✗"
23
+ print(f" {status} {description:<50} ({path})")
24
+ return exists
25
+
26
+ def check_file_size(path: str, min_bytes: int, description: str) -> bool:
27
+ """Check if a file exists and is above minimum size."""
28
+ if not os.path.exists(path):
29
+ print(f" ✗ {description:<50} (not found)")
30
+ return False
31
+ size = os.path.getsize(path)
32
+ ok = size >= min_bytes
33
+ status = "✓" if ok else "✗"
34
+ print(f" {status} {description:<50} ({size} bytes, min {min_bytes})")
35
+ return ok
36
+
37
+ print("=" * 70)
38
+ print("GridMind-RL PROJECT READINESS CHECK")
39
+ print("=" * 70)
40
+
41
+ all_ok = True
42
+
43
+ # 1. Essential Files
44
+ print("\n1. ESSENTIAL FILES")
45
+ all_ok &= check_file_exists("main.go", "Go server main file")
46
+ all_ok &= check_file_exists("inference.py", "Python inference script")
47
+ all_ok &= check_file_exists("go.mod", "Go module file")
48
+ all_ok &= check_file_exists("go.sum", "Go dependencies")
49
+
50
+ # 2. Environment Module
51
+ print("\n2. ENVIRONMENT PACKAGE")
52
+ all_ok &= check_directory_exists("env", "Environment package directory")
53
+ all_ok &= check_file_exists("env/environment.go", "Main environment logic")
54
+ all_ok &= check_file_exists("env/models.go", "Data models")
55
+ all_ok &= check_file_exists("env/rewards.go", "Reward computation")
56
+ all_ok &= check_file_exists("env/faults.go", "Fault system")
57
+ all_ok &= check_file_exists("env/tasks.go", "Task definitions")
58
+
59
+ # 3. Python Module
60
+ print("\n3. PYTHON PACKAGE")
61
+ all_ok &= check_directory_exists("python", "Python package directory")
62
+ all_ok &= check_file_exists("python/__init__.py", "Python package init")
63
+ all_ok &= check_file_exists("python/models.py", "Python models")
64
+ all_ok &= check_file_size("python/requirements.txt", 100, "Python requirements")
65
+
66
+ # 4. Notebooks
67
+ print("\n4. NOTEBOOKS")
68
+ all_ok &= check_file_size("scripts/gridmind_grpo_colab.ipynb", 20000, "Colab notebook (≥20KB)")
69
+
70
+ # 5. Dashboard
71
+ print("\n5. DASHBOARD")
72
+ all_ok &= check_directory_exists("dashboard", "Dashboard directory")
73
+ all_ok &= check_file_exists("dashboard/server.py", "Dashboard server")
74
+ all_ok &= check_file_exists("dashboard/static/index.html", "Dashboard HTML")
75
+ all_ok &= check_file_exists("dashboard/static/dashboard.js", "Dashboard JavaScript")
76
+
77
+ # 6. Test Files
78
+ print("\n6. TEST/DEMO FILES")
79
+ all_ok &= check_file_exists("scripts/demo_run.py", "Demo runner")
80
+ all_ok &= check_file_exists("scripts/full_demo.py", "Full demo")
81
+ all_ok &= check_file_exists("tests/environment_test.go", "Go tests")
82
+
83
+ # 7. README & Docs
84
+ print("\n7. DOCUMENTATION")
85
+ all_ok &= check_file_exists("README.md", "README")
86
+ all_ok &= check_file_exists("HF_BLOG_POST.md", "Blog post")
87
+
88
+ # 8. Key Features Check
89
+ print("\n8. KEY FEATURES (Code Inspection)")
90
+ try:
91
+ with open("inference.py", encoding="utf-8-sig", errors="ignore") as f:
92
+ content = f.read()
93
+ has_coordinator = "--coordinator" in content and "coordinator_step" in content
94
+ has_curriculum = "CurriculumManager" in content
95
+ has_planning = "--use-planning" in content and "simulate" in content
96
+ status = "✓" if has_coordinator else "✗"
97
+ print(f" {status} Multi-Agent Coordinator mode (Theme 1)")
98
+ status = "✓" if has_curriculum else "✗"
99
+ print(f" {status} Curriculum Learning (Theme 4)")
100
+ status = "✓" if has_planning else "✗"
101
+ print(f" {status} World Modeling (/simulate) (Theme 3)")
102
+ all_ok &= has_coordinator and has_curriculum and has_planning
103
+ except Exception as e:
104
+ print(f" ✗ Could not read inference.py: {e}")
105
+ all_ok = False
106
+
107
+ try:
108
+ with open("main.go", encoding="utf-8-sig", errors="ignore") as f:
109
+ content = f.read()
110
+ has_coord_reset = "handleCoordinatorReset" in content
111
+ has_coord_step = "handleCoordinatorStep" in content
112
+ has_simulate = "handleSimulate" in content
113
+ has_reset = "handleReset" in content
114
+ status = "✓" if has_coord_reset else "✗"
115
+ print(f" {status} /coordinator/reset endpoint")
116
+ status = "✓" if has_coord_step else "✗"
117
+ print(f" {status} /coordinator/step endpoint")
118
+ status = "✓" if has_simulate else "✗"
119
+ print(f" {status} /simulate endpoint (world modeling)")
120
+ status = "✓" if has_reset else "✗"
121
+ print(f" {status} /reset endpoint (task 1-4 support)")
122
+ all_ok &= has_coord_reset and has_coord_step and has_simulate and has_reset
123
+ except Exception as e:
124
+ print(f" ✗ Could not read main.go: {e}")
125
+ all_ok = False
126
+
127
+ # 9. Test Quick Functionality
128
+ print("\n9. QUICK FUNCTIONALITY TEST")
129
+ try:
130
+ import requests
131
+ health = requests.get("http://localhost:7860/health", timeout=5)
132
+ if health.status_code == 200:
133
+ print(f" ✓ Server health check passed (port 7860)")
134
+ else:
135
+ print(f" ✗ Server health check failed ({health.status_code})")
136
+ all_ok = False
137
+ except Exception as e:
138
+ print(f" ✗ Could not reach server: {e}")
139
+ all_ok = False
140
+
141
+ # Final Summary
142
+ print("\n" + "=" * 70)
143
+ if all_ok:
144
+ print("✓ PROJECT READY FOR SUBMISSION")
145
+ print("=" * 70)
146
+ sys.exit(0)
147
+ else:
148
+ print("✗ SOME CHECKS FAILED - REVIEW REQUIRED")
149
+ print("=" * 70)
150
+ sys.exit(1)