adityss committed on
Commit
ebe8fa5
·
1 Parent(s): c009bc5

feat: add multi-agent and planning CLI flags to inference and expose environment metadata via /info endpoint

Browse files
Files changed (2) hide show
  1. inference.py +17 -2
  2. main.go +50 -0
inference.py CHANGED
@@ -429,6 +429,8 @@ def run_episode(
429
  llm_every: int,
430
  max_steps: Optional[int],
431
  verbose: bool = False,
 
 
432
  ) -> dict[str, Any]:
433
  """Run a single episode and emit hackathon-compliant stdout format."""
434
  task_name = f"gridmind-task-{task_id}"
@@ -450,7 +452,8 @@ def run_episode(
450
  obs: dict[str, Any] = {}
451
 
452
  try:
453
- reset_resp = env_client.reset(task_id=task_id, seed=seed)
 
454
  if reset_resp is None:
455
  raise RuntimeError("reset failed")
456
  obs_list = reset_resp.get("observations", [{}])
@@ -483,7 +486,7 @@ def run_episode(
483
  # C5: World Modeling - Use /simulate when efficiency is low or faults active
484
  hvac_eff = obs.get("hvac_efficiency", 1.0)
485
  active_faults_list = obs.get("active_faults", [])
486
- use_simulation = not fast_mode and (hvac_eff < 0.7 or len(active_faults_list) > 0)
487
 
488
  sim_result = None
489
  sim_reward = None
@@ -705,6 +708,16 @@ def main() -> None:
705
  action="store_true",
706
  help="Enable automatic task curriculum (Theme 4: Self-Improvement)",
707
  )
 
 
 
 
 
 
 
 
 
 
708
  args = parser.parse_args()
709
 
710
  server_proc = start_environment_server(port=7860)
@@ -751,6 +764,8 @@ def main() -> None:
751
  llm_every=args.llm_every,
752
  max_steps=args.max_steps,
753
  verbose=args.verbose,
 
 
754
  )
755
  task_scores.append(float(result["score"]))
756
  all_results.append(result)
 
429
  llm_every: int,
430
  max_steps: Optional[int],
431
  verbose: bool = False,
432
+ coordinator: bool = False,
433
+ use_planning: bool = False,
434
  ) -> dict[str, Any]:
435
  """Run a single episode and emit hackathon-compliant stdout format."""
436
  task_name = f"gridmind-task-{task_id}"
 
452
  obs: dict[str, Any] = {}
453
 
454
  try:
455
+ num_buildings = 3 if coordinator else 1
456
+ reset_resp = env_client.reset(task_id=task_id, seed=seed, num_buildings=num_buildings)
457
  if reset_resp is None:
458
  raise RuntimeError("reset failed")
459
  obs_list = reset_resp.get("observations", [{}])
 
486
  # C5: World Modeling - Use /simulate when efficiency is low or faults active
487
  hvac_eff = obs.get("hvac_efficiency", 1.0)
488
  active_faults_list = obs.get("active_faults", [])
489
+ use_simulation = not fast_mode and (use_planning or hvac_eff < 0.7 or len(active_faults_list) > 0)
490
 
491
  sim_result = None
492
  sim_reward = None
 
708
  action="store_true",
709
  help="Enable automatic task curriculum (Theme 4: Self-Improvement)",
710
  )
711
+ parser.add_argument(
712
+ "--coordinator",
713
+ action="store_true",
714
+ help="Multi-building coordinator mode: reset with 3 buildings (Theme 1: Multi-Agent)",
715
+ )
716
+ parser.add_argument(
717
+ "--use-planning",
718
+ action="store_true",
719
+ help="Force /simulate world-model call on every step (Theme 3: World Modeling)",
720
+ )
721
  args = parser.parse_args()
722
 
723
  server_proc = start_environment_server(port=7860)
 
764
  llm_every=args.llm_every,
765
  max_steps=args.max_steps,
766
  verbose=args.verbose,
767
+ coordinator=args.coordinator,
768
+ use_planning=args.use_planning,
769
  )
770
  task_scores.append(float(result["score"]))
771
  all_results.append(result)
main.go CHANGED
@@ -158,6 +158,7 @@ func (s *Server) routes() *http.ServeMux {
158
  mux.HandleFunc("/tasks", s.handleTasks)
159
  mux.HandleFunc("/metrics", s.handleMetrics)
160
  mux.HandleFunc("/ws", s.handleWebSocket)
 
161
  // Reverse proxy for dashboard (runs on port 7861 internally)
162
  mux.HandleFunc("/dashboard", s.handleDashboardProxy)
163
  mux.HandleFunc("/dashboard/", s.handleDashboardProxy)
@@ -879,3 +880,52 @@ func withCORS(next http.Handler) http.Handler {
879
  next.ServeHTTP(w, r)
880
  })
881
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  mux.HandleFunc("/tasks", s.handleTasks)
159
  mux.HandleFunc("/metrics", s.handleMetrics)
160
  mux.HandleFunc("/ws", s.handleWebSocket)
161
+ mux.HandleFunc("/info", s.handleInfo)
162
  // Reverse proxy for dashboard (runs on port 7861 internally)
163
  mux.HandleFunc("/dashboard", s.handleDashboardProxy)
164
  mux.HandleFunc("/dashboard/", s.handleDashboardProxy)
 
880
  next.ServeHTTP(w, r)
881
  })
882
  }
883
+
884
+ // handleInfo returns OpenEnv-standard metadata for automated validators and judges.
885
+ func (s *Server) handleInfo(w http.ResponseWriter, r *http.Request) {
886
+ if r.Method != http.MethodGet {
887
+ http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
888
+ return
889
+ }
890
+ info := map[string]interface{}{
891
+ "name": "gridmind-rl",
892
+ "version": "2.0.0",
893
+ "description": "Multi-building industrial energy management RL environment with instruction-following, world modeling, fault injection, and curriculum learning.",
894
+ "multi_agent": true,
895
+ "themes": []string{
896
+ "multi-agent",
897
+ "long-horizon-planning",
898
+ "world-modeling",
899
+ "self-improvement",
900
+ },
901
+ "observation_space": map[string]interface{}{
902
+ "type": "dict",
903
+ "fields": []string{
904
+ "indoor_temperature", "thermal_storage_level", "current_price",
905
+ "grid_stress_signal", "carbon_intensity", "hour_of_day", "step",
906
+ "hvac_efficiency", "process_demand", "cumulative_cost",
907
+ "batch_queue", "active_faults", "instruction_card",
908
+ },
909
+ },
910
+ "action_space": map[string]interface{}{
911
+ "type": "dict",
912
+ "fields": map[string]string{
913
+ "hvac_power_level": "float [0.0, 1.0]",
914
+ "thermal_charge_rate": "float [-1.0, 1.0]",
915
+ "batch_job_slot": "int [0, 4]",
916
+ "load_shed_fraction": "float [0.0, 0.5]",
917
+ "building_id": "int [0, N_buildings-1]",
918
+ },
919
+ },
920
+ "endpoints": []string{
921
+ "POST /reset", "POST /step", "GET /grade", "GET /tasks",
922
+ "GET /state", "POST /simulate", "GET /feeder", "POST /coordinate",
923
+ "GET /health", "GET /info",
924
+ },
925
+ "hf_space": "https://lo-kyu-gridmind.hf.space",
926
+ "github": "https://github.com/LO-Kyu/gridmind",
927
+ }
928
+ w.Header().Set("Content-Type", "application/json")
929
+ w.Header().Set("Access-Control-Allow-Origin", "*")
930
+ json.NewEncoder(w).Encode(info)
931
+ }