Spaces:
Running
Running
Commit ·
4ec4472
1
Parent(s): f020509
feat: Add full OpenEnv compliance
Browse files- Add WebSocket /ws endpoint to Go server (gorilla/websocket)
- Fix GridMindEnv to properly extend EnvClient from openenv-core
- Implement correct EnvClient abstract methods (_step_payload,
_parse_result, _parse_state) in client.py
- Add GridMindAction, GridMindObservation, GridMindState models
extending openenv-core base classes
- Verified: GenericEnvClient works with Go-only server via WebSocket
- Verified: Docker container exposes /ws correctly
Closes weakness: custom HTTP client replaced with proper OpenEnv
typed client
go.mod
CHANGED
|
@@ -1,3 +1,5 @@
|
|
| 1 |
module gridmind-rl
|
| 2 |
|
| 3 |
go 1.21
|
|
|
|
|
|
|
|
|
| 1 |
module gridmind-rl
|
| 2 |
|
| 3 |
go 1.21
|
| 4 |
+
|
| 5 |
+
require github.com/gorilla/websocket v1.5.3
|
go.sum
CHANGED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
| 2 |
+
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
main.go
CHANGED
|
@@ -20,6 +20,8 @@ import (
|
|
| 20 |
"time"
|
| 21 |
|
| 22 |
"gridmind-rl/env"
|
|
|
|
|
|
|
| 23 |
)
|
| 24 |
|
| 25 |
// ──────────────────────────────────────────────
|
|
@@ -132,6 +134,10 @@ type Server struct {
|
|
| 132 |
envMgr *env.Environment
|
| 133 |
}
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
func newServer() *Server {
|
| 136 |
return &Server{envMgr: env.NewEnvironment()}
|
| 137 |
}
|
|
@@ -148,6 +154,7 @@ func (s *Server) routes() *http.ServeMux {
|
|
| 148 |
mux.HandleFunc("/grade", s.handleGrade)
|
| 149 |
mux.HandleFunc("/tasks", s.handleTasks)
|
| 150 |
mux.HandleFunc("/metrics", s.handleMetrics)
|
|
|
|
| 151 |
// Reverse proxy for dashboard (runs on port 7861 internally)
|
| 152 |
mux.HandleFunc("/dashboard", s.handleDashboardProxy)
|
| 153 |
mux.HandleFunc("/dashboard/", s.handleDashboardProxy)
|
|
@@ -444,6 +451,325 @@ func getClientIP(r *http.Request) string {
|
|
| 444 |
return ip
|
| 445 |
}
|
| 446 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
// ──────────────────────────────────────────────
|
| 448 |
// Entry point
|
| 449 |
// ──────────────────────────────────────────────
|
|
|
|
| 20 |
"time"
|
| 21 |
|
| 22 |
"gridmind-rl/env"
|
| 23 |
+
|
| 24 |
+
"github.com/gorilla/websocket"
|
| 25 |
)
|
| 26 |
|
| 27 |
// ──────────────────────────────────────────────
|
|
|
|
| 134 |
envMgr *env.Environment
|
| 135 |
}
|
| 136 |
|
| 137 |
+
var upgrader = websocket.Upgrader{
|
| 138 |
+
CheckOrigin: func(r *http.Request) bool { return true },
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
func newServer() *Server {
|
| 142 |
return &Server{envMgr: env.NewEnvironment()}
|
| 143 |
}
|
|
|
|
| 154 |
mux.HandleFunc("/grade", s.handleGrade)
|
| 155 |
mux.HandleFunc("/tasks", s.handleTasks)
|
| 156 |
mux.HandleFunc("/metrics", s.handleMetrics)
|
| 157 |
+
mux.HandleFunc("/ws", s.handleWebSocket)
|
| 158 |
// Reverse proxy for dashboard (runs on port 7861 internally)
|
| 159 |
mux.HandleFunc("/dashboard", s.handleDashboardProxy)
|
| 160 |
mux.HandleFunc("/dashboard/", s.handleDashboardProxy)
|
|
|
|
| 451 |
return ip
|
| 452 |
}
|
| 453 |
|
| 454 |
+
// ── /ws (WebSocket) ───────────────────────────────────────────────────────────
|
| 455 |
+
|
| 456 |
+
type WSMessage struct {
|
| 457 |
+
Type string `json:"type"`
|
| 458 |
+
Data json.RawMessage `json:"data,omitempty"`
|
| 459 |
+
Seed *int64 `json:"seed,omitempty"`
|
| 460 |
+
TaskID int `json:"task_id,omitempty"`
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
type WSResetMessage struct {
|
| 464 |
+
Seed *int64 `json:"seed,omitempty"`
|
| 465 |
+
TaskID int `json:"task_id,omitempty"`
|
| 466 |
+
NumBuildings int `json:"num_buildings,omitempty"`
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
type WSStepMessage struct {
|
| 470 |
+
Action json.RawMessage `json:"action"`
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
| 474 |
+
conn, err := upgrader.Upgrade(w, r, nil)
|
| 475 |
+
if err != nil {
|
| 476 |
+
log.Printf("WebSocket upgrade error: %v", err)
|
| 477 |
+
return
|
| 478 |
+
}
|
| 479 |
+
defer conn.Close()
|
| 480 |
+
|
| 481 |
+
for {
|
| 482 |
+
// Read message from client
|
| 483 |
+
_, msgBytes, err := conn.ReadMessage()
|
| 484 |
+
if err != nil {
|
| 485 |
+
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
| 486 |
+
log.Printf("WebSocket error: %v", err)
|
| 487 |
+
}
|
| 488 |
+
break
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
var msg WSMessage
|
| 492 |
+
if err := json.Unmarshal(msgBytes, &msg); err != nil {
|
| 493 |
+
errMsg, _ := json.Marshal(map[string]string{"error": "invalid message format"})
|
| 494 |
+
conn.WriteMessage(websocket.TextMessage, errMsg)
|
| 495 |
+
continue
|
| 496 |
+
}
|
| 497 |
+
|
| 498 |
+
switch msg.Type {
|
| 499 |
+
case "reset":
|
| 500 |
+
// GenericEnvClient sends: {"type": "reset", "data": {"seed": 42}}
|
| 501 |
+
// We need to handle data payload if present
|
| 502 |
+
if len(msg.Data) > 0 {
|
| 503 |
+
s.handleWSReset(conn, msg.Data)
|
| 504 |
+
} else {
|
| 505 |
+
// Fallback to top-level fields (seed, task_id)
|
| 506 |
+
s.handleWSResetDirect(conn, msg.Seed, msg.TaskID)
|
| 507 |
+
}
|
| 508 |
+
case "step":
|
| 509 |
+
// GenericEnvClient sends: {"type": "step", "data": {"action": {...}}}
|
| 510 |
+
if len(msg.Data) > 0 {
|
| 511 |
+
s.handleWSStep(conn, msg.Data)
|
| 512 |
+
} else {
|
| 513 |
+
// Fallback to top-level action
|
| 514 |
+
s.handleWSStepDirect(conn, msgBytes)
|
| 515 |
+
}
|
| 516 |
+
case "state":
|
| 517 |
+
s.handleWSState(conn)
|
| 518 |
+
case "close":
|
| 519 |
+
break
|
| 520 |
+
default:
|
| 521 |
+
errMsg, _ := json.Marshal(map[string]string{"error": "unknown message type: " + msg.Type})
|
| 522 |
+
conn.WriteMessage(websocket.TextMessage, errMsg)
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
func (s *Server) handleWSReset(conn *websocket.Conn, data json.RawMessage) {
|
| 528 |
+
// GenericEnvClient sends: {"data": {"seed": 42}}
|
| 529 |
+
// Or: {"data": {"task_id": 1, "seed": 42}}
|
| 530 |
+
var reqData map[string]interface{}
|
| 531 |
+
if err := json.Unmarshal(data, &reqData); err != nil {
|
| 532 |
+
errMsg, _ := json.Marshal(map[string]string{"error": "invalid reset data: " + err.Error()})
|
| 533 |
+
conn.WriteMessage(websocket.TextMessage, errMsg)
|
| 534 |
+
return
|
| 535 |
+
}
|
| 536 |
+
|
| 537 |
+
var seed *int64
|
| 538 |
+
if seedVal, ok := reqData["seed"].(float64); ok {
|
| 539 |
+
s := int64(seedVal)
|
| 540 |
+
seed = &s
|
| 541 |
+
} else if seedVal, ok := reqData["seed"].(int64); ok {
|
| 542 |
+
seed = &seedVal
|
| 543 |
+
} else if seedVal, ok := reqData["seed"].(int); ok {
|
| 544 |
+
s := int64(seedVal)
|
| 545 |
+
seed = &s
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
taskID := 1
|
| 549 |
+
if taskIDVal, ok := reqData["task_id"].(float64); ok {
|
| 550 |
+
taskID = int(taskIDVal)
|
| 551 |
+
} else if taskIDVal, ok := reqData["task_id"].(int64); ok {
|
| 552 |
+
taskID = int(taskIDVal)
|
| 553 |
+
} else if taskIDVal, ok := reqData["task_id"].(int); ok {
|
| 554 |
+
taskID = taskIDVal
|
| 555 |
+
}
|
| 556 |
+
|
| 557 |
+
numBuildings := 1
|
| 558 |
+
if nbVal, ok := reqData["num_buildings"].(float64); ok {
|
| 559 |
+
numBuildings = int(nbVal)
|
| 560 |
+
} else if nbVal, ok := reqData["num_buildings"].(int64); ok {
|
| 561 |
+
numBuildings = int(nbVal)
|
| 562 |
+
} else if nbVal, ok := reqData["num_buildings"].(int); ok {
|
| 563 |
+
numBuildings = nbVal
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
resp := s.envMgr.Reset(env.ResetRequest{
|
| 567 |
+
Seed: seed,
|
| 568 |
+
TaskID: taskID,
|
| 569 |
+
NumBuildings: numBuildings,
|
| 570 |
+
})
|
| 571 |
+
|
| 572 |
+
// Build observation response
|
| 573 |
+
obs := resp.Observations[0]
|
| 574 |
+
respData := map[string]interface{}{
|
| 575 |
+
"observation": map[string]interface{}{
|
| 576 |
+
"indoor_temperature": obs.IndoorTemperature,
|
| 577 |
+
"thermal_storage_level": obs.ThermalStorageLevel,
|
| 578 |
+
"process_demand": obs.ProcessDemand,
|
| 579 |
+
"current_price": obs.CurrentPrice,
|
| 580 |
+
"grid_stress_signal": obs.GridStressSignal,
|
| 581 |
+
"carbon_intensity": obs.CarbonIntensity,
|
| 582 |
+
"hour_of_day": obs.HourOfDay,
|
| 583 |
+
"batch_queue": obs.BatchQueue,
|
| 584 |
+
"cumulative_cost": obs.CumulativeCost,
|
| 585 |
+
"step": obs.Step,
|
| 586 |
+
"building_id": obs.BuildingID,
|
| 587 |
+
},
|
| 588 |
+
"reward": nil,
|
| 589 |
+
"done": false,
|
| 590 |
+
"info": map[string]interface{}{"episode": resp.Episode, "task_id": resp.TaskID},
|
| 591 |
+
}
|
| 592 |
+
|
| 593 |
+
// Wrap in "data" field for GenericEnvClient compatibility
|
| 594 |
+
response := map[string]interface{}{
|
| 595 |
+
"data": respData,
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
respBytes, _ := json.Marshal(response)
|
| 599 |
+
conn.WriteMessage(websocket.TextMessage, respBytes)
|
| 600 |
+
}
|
| 601 |
+
|
| 602 |
+
func (s *Server) handleWSStep(conn *websocket.Conn, data json.RawMessage) {
|
| 603 |
+
// GenericEnvClient sends action directly in data: {"data": {...action fields...}}
|
| 604 |
+
var reqData map[string]interface{}
|
| 605 |
+
if err := json.Unmarshal(data, &reqData); err != nil {
|
| 606 |
+
errMsg, _ := json.Marshal(map[string]string{"error": "invalid step data: " + err.Error()})
|
| 607 |
+
conn.WriteMessage(websocket.TextMessage, errMsg)
|
| 608 |
+
return
|
| 609 |
+
}
|
| 610 |
+
|
| 611 |
+
// Handle two formats:
|
| 612 |
+
// 1. Direct action: {"data": {"hvac_power_level": 0.5, ...}}
|
| 613 |
+
// 2. Wrapped action: {"data": {"action": {"hvac_power_level": 0.5, ...}}}
|
| 614 |
+
var actionBytes []byte
|
| 615 |
+
if actionData, ok := reqData["action"]; ok {
|
| 616 |
+
// Wrapped format
|
| 617 |
+
actionBytes, _ = json.Marshal(actionData)
|
| 618 |
+
} else {
|
| 619 |
+
// Direct format - use the whole reqData as action
|
| 620 |
+
actionBytes = data
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
var action env.ActionModel
|
| 624 |
+
if err := json.Unmarshal(actionBytes, &action); err != nil {
|
| 625 |
+
errMsg, _ := json.Marshal(map[string]string{"error": "invalid action: " + err.Error()})
|
| 626 |
+
conn.WriteMessage(websocket.TextMessage, errMsg)
|
| 627 |
+
return
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
responses, done := s.envMgr.Step([]env.ActionModel{action})
|
| 631 |
+
|
| 632 |
+
// Record metrics
|
| 633 |
+
if len(responses) > 0 {
|
| 634 |
+
metrics.recordStep(0, responses[0].Reward)
|
| 635 |
+
metrics.recordAction(action.HVACPowerLevel)
|
| 636 |
+
}
|
| 637 |
+
|
| 638 |
+
obs := responses[0]
|
| 639 |
+
respData := map[string]interface{}{
|
| 640 |
+
"observation": map[string]interface{}{
|
| 641 |
+
"indoor_temperature": obs.Observation.IndoorTemperature,
|
| 642 |
+
"thermal_storage_level": obs.Observation.ThermalStorageLevel,
|
| 643 |
+
"process_demand": obs.Observation.ProcessDemand,
|
| 644 |
+
"current_price": obs.Observation.CurrentPrice,
|
| 645 |
+
"grid_stress_signal": obs.Observation.GridStressSignal,
|
| 646 |
+
"carbon_intensity": obs.Observation.CarbonIntensity,
|
| 647 |
+
"hour_of_day": obs.Observation.HourOfDay,
|
| 648 |
+
"batch_queue": obs.Observation.BatchQueue,
|
| 649 |
+
"cumulative_cost": obs.Observation.CumulativeCost,
|
| 650 |
+
"step": obs.Observation.Step,
|
| 651 |
+
"building_id": obs.Observation.BuildingID,
|
| 652 |
+
},
|
| 653 |
+
"reward": obs.Reward,
|
| 654 |
+
"done": done,
|
| 655 |
+
"info": obs.Info,
|
| 656 |
+
}
|
| 657 |
+
response := map[string]interface{}{"data": respData}
|
| 658 |
+
|
| 659 |
+
respBytes, _ := json.Marshal(response)
|
| 660 |
+
conn.WriteMessage(websocket.TextMessage, respBytes)
|
| 661 |
+
}
|
| 662 |
+
|
| 663 |
+
func (s *Server) handleWSState(conn *websocket.Conn) {
|
| 664 |
+
state := s.envMgr.GetState()
|
| 665 |
+
stateBytes, _ := json.Marshal(state)
|
| 666 |
+
conn.WriteMessage(websocket.TextMessage, stateBytes)
|
| 667 |
+
}
|
| 668 |
+
|
| 669 |
+
// Direct handlers for OpenEnv client format (action at top level)
|
| 670 |
+
|
| 671 |
+
func (s *Server) handleWSResetDirect(conn *websocket.Conn, seed *int64, taskID int) {
|
| 672 |
+
if seed == nil {
|
| 673 |
+
var s int64 = 42
|
| 674 |
+
seed = &s
|
| 675 |
+
}
|
| 676 |
+
if taskID == 0 {
|
| 677 |
+
taskID = 1
|
| 678 |
+
}
|
| 679 |
+
|
| 680 |
+
resp := s.envMgr.Reset(env.ResetRequest{
|
| 681 |
+
Seed: seed,
|
| 682 |
+
TaskID: taskID,
|
| 683 |
+
NumBuildings: 1,
|
| 684 |
+
})
|
| 685 |
+
|
| 686 |
+
obs := resp.Observations[0]
|
| 687 |
+
respData := map[string]interface{}{
|
| 688 |
+
"observation": map[string]interface{}{
|
| 689 |
+
"indoor_temperature": obs.IndoorTemperature,
|
| 690 |
+
"thermal_storage_level": obs.ThermalStorageLevel,
|
| 691 |
+
"process_demand": obs.ProcessDemand,
|
| 692 |
+
"current_price": obs.CurrentPrice,
|
| 693 |
+
"grid_stress_signal": obs.GridStressSignal,
|
| 694 |
+
"carbon_intensity": obs.CarbonIntensity,
|
| 695 |
+
"hour_of_day": obs.HourOfDay,
|
| 696 |
+
"batch_queue": obs.BatchQueue,
|
| 697 |
+
"cumulative_cost": obs.CumulativeCost,
|
| 698 |
+
"step": obs.Step,
|
| 699 |
+
"building_id": obs.BuildingID,
|
| 700 |
+
},
|
| 701 |
+
"reward": nil,
|
| 702 |
+
"done": false,
|
| 703 |
+
"info": map[string]interface{}{"episode": resp.Episode, "task_id": resp.TaskID},
|
| 704 |
+
}
|
| 705 |
+
response := map[string]interface{}{"data": respData}
|
| 706 |
+
|
| 707 |
+
respBytes, _ := json.Marshal(response)
|
| 708 |
+
conn.WriteMessage(websocket.TextMessage, respBytes)
|
| 709 |
+
}
|
| 710 |
+
|
| 711 |
+
func (s *Server) handleWSStepDirect(conn *websocket.Conn, msgBytes []byte) {
|
| 712 |
+
// Parse the original message to get action directly
|
| 713 |
+
var rawMsg map[string]interface{}
|
| 714 |
+
if err := json.Unmarshal(msgBytes, &rawMsg); err != nil {
|
| 715 |
+
errMsg, _ := json.Marshal(map[string]string{"error": "invalid step message: " + err.Error()})
|
| 716 |
+
conn.WriteMessage(websocket.TextMessage, errMsg)
|
| 717 |
+
return
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
actionData, ok := rawMsg["action"]
|
| 721 |
+
if !ok {
|
| 722 |
+
errMsg, _ := json.Marshal(map[string]string{"error": "missing action field"})
|
| 723 |
+
conn.WriteMessage(websocket.TextMessage, errMsg)
|
| 724 |
+
return
|
| 725 |
+
}
|
| 726 |
+
|
| 727 |
+
actionBytes, err := json.Marshal(actionData)
|
| 728 |
+
if err != nil {
|
| 729 |
+
errMsg, _ := json.Marshal(map[string]string{"error": "invalid action format"})
|
| 730 |
+
conn.WriteMessage(websocket.TextMessage, errMsg)
|
| 731 |
+
return
|
| 732 |
+
}
|
| 733 |
+
|
| 734 |
+
var action env.ActionModel
|
| 735 |
+
if err := json.Unmarshal(actionBytes, &action); err != nil {
|
| 736 |
+
errMsg, _ := json.Marshal(map[string]string{"error": "invalid action: " + err.Error()})
|
| 737 |
+
conn.WriteMessage(websocket.TextMessage, errMsg)
|
| 738 |
+
return
|
| 739 |
+
}
|
| 740 |
+
|
| 741 |
+
responses, done := s.envMgr.Step([]env.ActionModel{action})
|
| 742 |
+
|
| 743 |
+
if len(responses) > 0 {
|
| 744 |
+
metrics.recordStep(0, responses[0].Reward)
|
| 745 |
+
metrics.recordAction(action.HVACPowerLevel)
|
| 746 |
+
}
|
| 747 |
+
|
| 748 |
+
obs := responses[0]
|
| 749 |
+
respData := map[string]interface{}{
|
| 750 |
+
"observation": map[string]interface{}{
|
| 751 |
+
"indoor_temperature": obs.Observation.IndoorTemperature,
|
| 752 |
+
"thermal_storage_level": obs.Observation.ThermalStorageLevel,
|
| 753 |
+
"process_demand": obs.Observation.ProcessDemand,
|
| 754 |
+
"current_price": obs.Observation.CurrentPrice,
|
| 755 |
+
"grid_stress_signal": obs.Observation.GridStressSignal,
|
| 756 |
+
"carbon_intensity": obs.Observation.CarbonIntensity,
|
| 757 |
+
"hour_of_day": obs.Observation.HourOfDay,
|
| 758 |
+
"batch_queue": obs.Observation.BatchQueue,
|
| 759 |
+
"cumulative_cost": obs.Observation.CumulativeCost,
|
| 760 |
+
"step": obs.Observation.Step,
|
| 761 |
+
"building_id": obs.Observation.BuildingID,
|
| 762 |
+
},
|
| 763 |
+
"reward": obs.Reward,
|
| 764 |
+
"done": done,
|
| 765 |
+
"info": obs.Info,
|
| 766 |
+
}
|
| 767 |
+
response := map[string]interface{}{"data": respData}
|
| 768 |
+
|
| 769 |
+
respBytes, _ := json.Marshal(response)
|
| 770 |
+
conn.WriteMessage(websocket.TextMessage, respBytes)
|
| 771 |
+
}
|
| 772 |
+
|
| 773 |
// ──────────────────────────────────────────────
|
| 774 |
// Entry point
|
| 775 |
// ──────────────────────────────────────────────
|
pyproject.toml
CHANGED
|
@@ -55,12 +55,18 @@ Documentation = "https://github.com/meta-pytorch/OpenEnv"
|
|
| 55 |
|
| 56 |
[project.scripts]
|
| 57 |
server = "server.app:main"
|
| 58 |
-
gridmind-server = "
|
| 59 |
-
gridmind = "
|
| 60 |
-
gridmind-eval = "
|
| 61 |
|
| 62 |
[tool.setuptools]
|
| 63 |
-
packages = ["python", "server"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
[tool.black]
|
| 66 |
line-length = 100
|
|
|
|
| 55 |
|
| 56 |
[project.scripts]
|
| 57 |
server = "server.app:main"
|
| 58 |
+
gridmind-server = "gridmind.server:main"
|
| 59 |
+
gridmind = "gridmind.server:main"
|
| 60 |
+
gridmind-eval = "inference:main"
|
| 61 |
|
| 62 |
[tool.setuptools]
|
| 63 |
+
packages = ["python", "server", "gridmind"]
|
| 64 |
+
|
| 65 |
+
[tool.setuptools.package-data]
|
| 66 |
+
gridmind = ["openenv.yaml", "examples/*.py"]
|
| 67 |
+
|
| 68 |
+
[project.packages.find]
|
| 69 |
+
where = ["."]
|
| 70 |
|
| 71 |
[tool.black]
|
| 72 |
line-length = 100
|