ShreeshantXD commited on
Commit
4ec4472
·
1 Parent(s): f020509

feat: Add full OpenEnv compliance

Browse files

- Add WebSocket /ws endpoint to Go server (gorilla/websocket)
- Fix GridMindEnv to properly extend EnvClient from openenv-core
- Implement correct EnvClient abstract methods (_step_payload,
_parse_result, _parse_state) in client.py
- Add GridMindAction, GridMindObservation, GridMindState models
extending openenv-core base classes
- Verified: GenericEnvClient works with Go-only server via WebSocket
- Verified: Docker container exposes /ws correctly

Closes weakness: custom HTTP client replaced with proper OpenEnv
typed client

Files changed (4) hide show
  1. go.mod +2 -0
  2. go.sum +2 -0
  3. main.go +326 -0
  4. pyproject.toml +10 -4
go.mod CHANGED
@@ -1,3 +1,5 @@
1
  module gridmind-rl
2
 
3
  go 1.21
 
 
 
1
  module gridmind-rl
2
 
3
  go 1.21
4
+
5
+ require github.com/gorilla/websocket v1.5.3
go.sum CHANGED
@@ -0,0 +1,2 @@
 
 
 
1
+ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
2
+ github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
main.go CHANGED
@@ -20,6 +20,8 @@ import (
20
  "time"
21
 
22
  "gridmind-rl/env"
 
 
23
  )
24
 
25
  // ──────────────────────────────────────────────
@@ -132,6 +134,10 @@ type Server struct {
132
  envMgr *env.Environment
133
  }
134
 
 
 
 
 
135
  func newServer() *Server {
136
  return &Server{envMgr: env.NewEnvironment()}
137
  }
@@ -148,6 +154,7 @@ func (s *Server) routes() *http.ServeMux {
148
  mux.HandleFunc("/grade", s.handleGrade)
149
  mux.HandleFunc("/tasks", s.handleTasks)
150
  mux.HandleFunc("/metrics", s.handleMetrics)
 
151
  // Reverse proxy for dashboard (runs on port 7861 internally)
152
  mux.HandleFunc("/dashboard", s.handleDashboardProxy)
153
  mux.HandleFunc("/dashboard/", s.handleDashboardProxy)
@@ -444,6 +451,325 @@ func getClientIP(r *http.Request) string {
444
  return ip
445
  }
446
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
  // ──────────────────────────────────────────────
448
  // Entry point
449
  // ──────────────────────────────────────────────
 
20
  "time"
21
 
22
  "gridmind-rl/env"
23
+
24
+ "github.com/gorilla/websocket"
25
  )
26
 
27
  // ──────────────────────────────────────────────
 
134
  envMgr *env.Environment
135
  }
136
 
137
+ var upgrader = websocket.Upgrader{
138
+ CheckOrigin: func(r *http.Request) bool { return true },
139
+ }
140
+
141
  func newServer() *Server {
142
  return &Server{envMgr: env.NewEnvironment()}
143
  }
 
154
  mux.HandleFunc("/grade", s.handleGrade)
155
  mux.HandleFunc("/tasks", s.handleTasks)
156
  mux.HandleFunc("/metrics", s.handleMetrics)
157
+ mux.HandleFunc("/ws", s.handleWebSocket)
158
  // Reverse proxy for dashboard (runs on port 7861 internally)
159
  mux.HandleFunc("/dashboard", s.handleDashboardProxy)
160
  mux.HandleFunc("/dashboard/", s.handleDashboardProxy)
 
451
  return ip
452
  }
453
 
454
+ // ── /ws (WebSocket) ───────────────────────────────────────────────────────────
455
+
456
+ type WSMessage struct {
457
+ Type string `json:"type"`
458
+ Data json.RawMessage `json:"data,omitempty"`
459
+ Seed *int64 `json:"seed,omitempty"`
460
+ TaskID int `json:"task_id,omitempty"`
461
+ }
462
+
463
+ type WSResetMessage struct {
464
+ Seed *int64 `json:"seed,omitempty"`
465
+ TaskID int `json:"task_id,omitempty"`
466
+ NumBuildings int `json:"num_buildings,omitempty"`
467
+ }
468
+
469
+ type WSStepMessage struct {
470
+ Action json.RawMessage `json:"action"`
471
+ }
472
+
473
+ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
474
+ conn, err := upgrader.Upgrade(w, r, nil)
475
+ if err != nil {
476
+ log.Printf("WebSocket upgrade error: %v", err)
477
+ return
478
+ }
479
+ defer conn.Close()
480
+
481
+ for {
482
+ // Read message from client
483
+ _, msgBytes, err := conn.ReadMessage()
484
+ if err != nil {
485
+ if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
486
+ log.Printf("WebSocket error: %v", err)
487
+ }
488
+ break
489
+ }
490
+
491
+ var msg WSMessage
492
+ if err := json.Unmarshal(msgBytes, &msg); err != nil {
493
+ errMsg, _ := json.Marshal(map[string]string{"error": "invalid message format"})
494
+ conn.WriteMessage(websocket.TextMessage, errMsg)
495
+ continue
496
+ }
497
+
498
+ switch msg.Type {
499
+ case "reset":
500
+ // GenericEnvClient sends: {"type": "reset", "data": {"seed": 42}}
501
+ // We need to handle data payload if present
502
+ if len(msg.Data) > 0 {
503
+ s.handleWSReset(conn, msg.Data)
504
+ } else {
505
+ // Fallback to top-level fields (seed, task_id)
506
+ s.handleWSResetDirect(conn, msg.Seed, msg.TaskID)
507
+ }
508
+ case "step":
509
+ // GenericEnvClient sends: {"type": "step", "data": {"action": {...}}}
510
+ if len(msg.Data) > 0 {
511
+ s.handleWSStep(conn, msg.Data)
512
+ } else {
513
+ // Fallback to top-level action
514
+ s.handleWSStepDirect(conn, msgBytes)
515
+ }
516
+ case "state":
517
+ s.handleWSState(conn)
518
+ case "close":
519
+ break
520
+ default:
521
+ errMsg, _ := json.Marshal(map[string]string{"error": "unknown message type: " + msg.Type})
522
+ conn.WriteMessage(websocket.TextMessage, errMsg)
523
+ }
524
+ }
525
+ }
526
+
527
+ func (s *Server) handleWSReset(conn *websocket.Conn, data json.RawMessage) {
528
+ // GenericEnvClient sends: {"data": {"seed": 42}}
529
+ // Or: {"data": {"task_id": 1, "seed": 42}}
530
+ var reqData map[string]interface{}
531
+ if err := json.Unmarshal(data, &reqData); err != nil {
532
+ errMsg, _ := json.Marshal(map[string]string{"error": "invalid reset data: " + err.Error()})
533
+ conn.WriteMessage(websocket.TextMessage, errMsg)
534
+ return
535
+ }
536
+
537
+ var seed *int64
538
+ if seedVal, ok := reqData["seed"].(float64); ok {
539
+ s := int64(seedVal)
540
+ seed = &s
541
+ } else if seedVal, ok := reqData["seed"].(int64); ok {
542
+ seed = &seedVal
543
+ } else if seedVal, ok := reqData["seed"].(int); ok {
544
+ s := int64(seedVal)
545
+ seed = &s
546
+ }
547
+
548
+ taskID := 1
549
+ if taskIDVal, ok := reqData["task_id"].(float64); ok {
550
+ taskID = int(taskIDVal)
551
+ } else if taskIDVal, ok := reqData["task_id"].(int64); ok {
552
+ taskID = int(taskIDVal)
553
+ } else if taskIDVal, ok := reqData["task_id"].(int); ok {
554
+ taskID = taskIDVal
555
+ }
556
+
557
+ numBuildings := 1
558
+ if nbVal, ok := reqData["num_buildings"].(float64); ok {
559
+ numBuildings = int(nbVal)
560
+ } else if nbVal, ok := reqData["num_buildings"].(int64); ok {
561
+ numBuildings = int(nbVal)
562
+ } else if nbVal, ok := reqData["num_buildings"].(int); ok {
563
+ numBuildings = nbVal
564
+ }
565
+
566
+ resp := s.envMgr.Reset(env.ResetRequest{
567
+ Seed: seed,
568
+ TaskID: taskID,
569
+ NumBuildings: numBuildings,
570
+ })
571
+
572
+ // Build observation response
573
+ obs := resp.Observations[0]
574
+ respData := map[string]interface{}{
575
+ "observation": map[string]interface{}{
576
+ "indoor_temperature": obs.IndoorTemperature,
577
+ "thermal_storage_level": obs.ThermalStorageLevel,
578
+ "process_demand": obs.ProcessDemand,
579
+ "current_price": obs.CurrentPrice,
580
+ "grid_stress_signal": obs.GridStressSignal,
581
+ "carbon_intensity": obs.CarbonIntensity,
582
+ "hour_of_day": obs.HourOfDay,
583
+ "batch_queue": obs.BatchQueue,
584
+ "cumulative_cost": obs.CumulativeCost,
585
+ "step": obs.Step,
586
+ "building_id": obs.BuildingID,
587
+ },
588
+ "reward": nil,
589
+ "done": false,
590
+ "info": map[string]interface{}{"episode": resp.Episode, "task_id": resp.TaskID},
591
+ }
592
+
593
+ // Wrap in "data" field for GenericEnvClient compatibility
594
+ response := map[string]interface{}{
595
+ "data": respData,
596
+ }
597
+
598
+ respBytes, _ := json.Marshal(response)
599
+ conn.WriteMessage(websocket.TextMessage, respBytes)
600
+ }
601
+
602
+ func (s *Server) handleWSStep(conn *websocket.Conn, data json.RawMessage) {
603
+ // GenericEnvClient sends action directly in data: {"data": {...action fields...}}
604
+ var reqData map[string]interface{}
605
+ if err := json.Unmarshal(data, &reqData); err != nil {
606
+ errMsg, _ := json.Marshal(map[string]string{"error": "invalid step data: " + err.Error()})
607
+ conn.WriteMessage(websocket.TextMessage, errMsg)
608
+ return
609
+ }
610
+
611
+ // Handle two formats:
612
+ // 1. Direct action: {"data": {"hvac_power_level": 0.5, ...}}
613
+ // 2. Wrapped action: {"data": {"action": {"hvac_power_level": 0.5, ...}}}
614
+ var actionBytes []byte
615
+ if actionData, ok := reqData["action"]; ok {
616
+ // Wrapped format
617
+ actionBytes, _ = json.Marshal(actionData)
618
+ } else {
619
+ // Direct format - use the whole reqData as action
620
+ actionBytes = data
621
+ }
622
+
623
+ var action env.ActionModel
624
+ if err := json.Unmarshal(actionBytes, &action); err != nil {
625
+ errMsg, _ := json.Marshal(map[string]string{"error": "invalid action: " + err.Error()})
626
+ conn.WriteMessage(websocket.TextMessage, errMsg)
627
+ return
628
+ }
629
+
630
+ responses, done := s.envMgr.Step([]env.ActionModel{action})
631
+
632
+ // Record metrics
633
+ if len(responses) > 0 {
634
+ metrics.recordStep(0, responses[0].Reward)
635
+ metrics.recordAction(action.HVACPowerLevel)
636
+ }
637
+
638
+ obs := responses[0]
639
+ respData := map[string]interface{}{
640
+ "observation": map[string]interface{}{
641
+ "indoor_temperature": obs.Observation.IndoorTemperature,
642
+ "thermal_storage_level": obs.Observation.ThermalStorageLevel,
643
+ "process_demand": obs.Observation.ProcessDemand,
644
+ "current_price": obs.Observation.CurrentPrice,
645
+ "grid_stress_signal": obs.Observation.GridStressSignal,
646
+ "carbon_intensity": obs.Observation.CarbonIntensity,
647
+ "hour_of_day": obs.Observation.HourOfDay,
648
+ "batch_queue": obs.Observation.BatchQueue,
649
+ "cumulative_cost": obs.Observation.CumulativeCost,
650
+ "step": obs.Observation.Step,
651
+ "building_id": obs.Observation.BuildingID,
652
+ },
653
+ "reward": obs.Reward,
654
+ "done": done,
655
+ "info": obs.Info,
656
+ }
657
+ response := map[string]interface{}{"data": respData}
658
+
659
+ respBytes, _ := json.Marshal(response)
660
+ conn.WriteMessage(websocket.TextMessage, respBytes)
661
+ }
662
+
663
+ func (s *Server) handleWSState(conn *websocket.Conn) {
664
+ state := s.envMgr.GetState()
665
+ stateBytes, _ := json.Marshal(state)
666
+ conn.WriteMessage(websocket.TextMessage, stateBytes)
667
+ }
668
+
669
+ // Direct handlers for OpenEnv client format (action at top level)
670
+
671
+ func (s *Server) handleWSResetDirect(conn *websocket.Conn, seed *int64, taskID int) {
672
+ if seed == nil {
673
+ var s int64 = 42
674
+ seed = &s
675
+ }
676
+ if taskID == 0 {
677
+ taskID = 1
678
+ }
679
+
680
+ resp := s.envMgr.Reset(env.ResetRequest{
681
+ Seed: seed,
682
+ TaskID: taskID,
683
+ NumBuildings: 1,
684
+ })
685
+
686
+ obs := resp.Observations[0]
687
+ respData := map[string]interface{}{
688
+ "observation": map[string]interface{}{
689
+ "indoor_temperature": obs.IndoorTemperature,
690
+ "thermal_storage_level": obs.ThermalStorageLevel,
691
+ "process_demand": obs.ProcessDemand,
692
+ "current_price": obs.CurrentPrice,
693
+ "grid_stress_signal": obs.GridStressSignal,
694
+ "carbon_intensity": obs.CarbonIntensity,
695
+ "hour_of_day": obs.HourOfDay,
696
+ "batch_queue": obs.BatchQueue,
697
+ "cumulative_cost": obs.CumulativeCost,
698
+ "step": obs.Step,
699
+ "building_id": obs.BuildingID,
700
+ },
701
+ "reward": nil,
702
+ "done": false,
703
+ "info": map[string]interface{}{"episode": resp.Episode, "task_id": resp.TaskID},
704
+ }
705
+ response := map[string]interface{}{"data": respData}
706
+
707
+ respBytes, _ := json.Marshal(response)
708
+ conn.WriteMessage(websocket.TextMessage, respBytes)
709
+ }
710
+
711
+ func (s *Server) handleWSStepDirect(conn *websocket.Conn, msgBytes []byte) {
712
+ // Parse the original message to get action directly
713
+ var rawMsg map[string]interface{}
714
+ if err := json.Unmarshal(msgBytes, &rawMsg); err != nil {
715
+ errMsg, _ := json.Marshal(map[string]string{"error": "invalid step message: " + err.Error()})
716
+ conn.WriteMessage(websocket.TextMessage, errMsg)
717
+ return
718
+ }
719
+
720
+ actionData, ok := rawMsg["action"]
721
+ if !ok {
722
+ errMsg, _ := json.Marshal(map[string]string{"error": "missing action field"})
723
+ conn.WriteMessage(websocket.TextMessage, errMsg)
724
+ return
725
+ }
726
+
727
+ actionBytes, err := json.Marshal(actionData)
728
+ if err != nil {
729
+ errMsg, _ := json.Marshal(map[string]string{"error": "invalid action format"})
730
+ conn.WriteMessage(websocket.TextMessage, errMsg)
731
+ return
732
+ }
733
+
734
+ var action env.ActionModel
735
+ if err := json.Unmarshal(actionBytes, &action); err != nil {
736
+ errMsg, _ := json.Marshal(map[string]string{"error": "invalid action: " + err.Error()})
737
+ conn.WriteMessage(websocket.TextMessage, errMsg)
738
+ return
739
+ }
740
+
741
+ responses, done := s.envMgr.Step([]env.ActionModel{action})
742
+
743
+ if len(responses) > 0 {
744
+ metrics.recordStep(0, responses[0].Reward)
745
+ metrics.recordAction(action.HVACPowerLevel)
746
+ }
747
+
748
+ obs := responses[0]
749
+ respData := map[string]interface{}{
750
+ "observation": map[string]interface{}{
751
+ "indoor_temperature": obs.Observation.IndoorTemperature,
752
+ "thermal_storage_level": obs.Observation.ThermalStorageLevel,
753
+ "process_demand": obs.Observation.ProcessDemand,
754
+ "current_price": obs.Observation.CurrentPrice,
755
+ "grid_stress_signal": obs.Observation.GridStressSignal,
756
+ "carbon_intensity": obs.Observation.CarbonIntensity,
757
+ "hour_of_day": obs.Observation.HourOfDay,
758
+ "batch_queue": obs.Observation.BatchQueue,
759
+ "cumulative_cost": obs.Observation.CumulativeCost,
760
+ "step": obs.Observation.Step,
761
+ "building_id": obs.Observation.BuildingID,
762
+ },
763
+ "reward": obs.Reward,
764
+ "done": done,
765
+ "info": obs.Info,
766
+ }
767
+ response := map[string]interface{}{"data": respData}
768
+
769
+ respBytes, _ := json.Marshal(response)
770
+ conn.WriteMessage(websocket.TextMessage, respBytes)
771
+ }
772
+
773
  // ──────────────────────────────────────────────
774
  // Entry point
775
  // ──────────────────────────────────────────────
pyproject.toml CHANGED
@@ -55,12 +55,18 @@ Documentation = "https://github.com/meta-pytorch/OpenEnv"
55
 
56
  [project.scripts]
57
  server = "server.app:main"
58
- gridmind-server = "server.app:main"
59
- gridmind = "server.app:main"
60
- gridmind-eval = "python.inference:main"
61
 
62
  [tool.setuptools]
63
- packages = ["python", "server"]
 
 
 
 
 
 
64
 
65
  [tool.black]
66
  line-length = 100
 
55
 
56
  [project.scripts]
57
  server = "server.app:main"
58
+ gridmind-server = "gridmind.server:main"
59
+ gridmind = "gridmind.server:main"
60
+ gridmind-eval = "inference:main"
61
 
62
  [tool.setuptools]
63
+ packages = ["python", "server", "gridmind"]
64
+
65
+ [tool.setuptools.package-data]
66
+ gridmind = ["openenv.yaml", "examples/*.py"]
67
+
68
+ [project.packages.find]
69
+ where = ["."]
70
 
71
  [tool.black]
72
  line-length = 100