# grid / main.py
# Not-OmKar's picture
# Big update
# 427a79e
import json
import os
import random
import time
from typing import Optional

from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel

from smartgrid_mas.demo_page import build_demo_html
from smartgrid_mas.engine.policies import (
    adaptive_stackelberg_action,
    heuristic_joint_action,
    random_joint_action,
)
from smartgrid_mas.env import SmartGridMarketEnv
from smartgrid_mas.models import DispatchAction, JointAction, ResetRequest, StepRequest
# ASGI application; the route decorators further down register handlers on it.
app = FastAPI(
    title="OpenEnv Smart Grid MarketSim",
    description="Multi-agent market simulator with a Reliability Dispatch Control Agent and a Physics-Constrained Safety Shield.",
    version="0.1.0",
)

# CORS: the API is open to any origin, method and header. Credentialed
# requests are NOT enabled, because FastAPI's CORSMiddleware documentation
# states that none of allow_origins / allow_methods / allow_headers may be
# ["*"] when allow_credentials is True. The handlers in this file identify a
# session through the `session_id` query parameter, not through cookies.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Single module-level environment instance used by every request handler.
env = SmartGridMarketEnv()

# Fixed InferenceRequest field values used by the /run-demo-mode endpoint.
DEMO_MODE_CONFIG = {
    "policy": "adaptive",
    "personality": "balanced",
    "task_id": "default",
    "seed": 42,
}
class InferenceRequest(BaseModel):
    """Settings for a full-episode rollout (body of ``/run-inference``)."""

    # "random" or "adaptive"; any other value selects the heuristic policy.
    policy: str = "heuristic"
    # Forwarded to the adaptive / heuristic policy functions.
    personality: str = "balanced"
    # Forwarded to ``env.reset``.
    task_id: str = "default"
    # Forwarded to ``env.reset`` and used to seed the rollout's RNG.
    seed: Optional[int] = 42
    # When False, each step is given an explicit default ``DispatchAction()``;
    # when True, ``None`` is passed as the dispatch action.
    dispatcher_enabled: bool = True
class ShockRequest(BaseModel):
    """Body of ``/inject-shock``."""

    # Forwarded unchanged to ``env.inject_shock``.
    renewable_drop_mwh: float = 20.0
class PolicyActionRequest(BaseModel):
    """Body of ``/act``; both fields are forwarded to ``env.policy_action``."""

    policy: str = "adaptive"
    personality: str = "balanced"
class DispatchActionRequest(BaseModel):
    """Body of ``/dispatch-act``; fields are forwarded to ``env.dispatch_action``."""

    personality: str = "balanced"
    # Optional; ``None`` is forwarded as-is when the caller omits it.
    cleared_mwh: Optional[float] = None
class OverrideRequest(BaseModel):
    """Body of ``/operator-override``."""

    # Forwarded to ``env.set_operator_override``.
    enabled: bool = True
class ResilienceDemoRequest(BaseModel):
    """Body of ``/run-resilience-demo``: one task/seed run under two policies."""

    # Task and seed are shared by the baseline and the candidate episode.
    task_id: str = "stress_shock"
    seed: int = 314
    # Policy names: "random" or "adaptive"; anything else selects the heuristic.
    baseline_policy: str = "random"
    candidate_policy: str = "adaptive"
def _rollout_inference(request: InferenceRequest) -> dict:
    """Run one complete episode on the shared ``env`` and record every step.

    The environment is reset with ``request.task_id`` / ``request.seed`` and
    then stepped until a step result reports ``done``.

    Args:
        request: Rollout settings. ``policy`` selects ``"random"`` or
            ``"adaptive"``; any other value falls back to the heuristic
            policy. When ``dispatcher_enabled`` is false, a fresh default
            ``DispatchAction()`` is passed to every step; otherwise ``None``
            is passed as the dispatch action.

    Returns:
        A dict echoing the request settings together with ``steps``,
        ``average_reward`` (rounded to 4 decimals) and the per-step
        ``trajectory`` list.
    """
    reset_resp = env.reset(task_id=request.task_id, seed=request.seed)
    sid = reset_resp.session_id
    obs = reset_resp.observation
    rng = random.Random(request.seed)

    # The policy is fixed for the whole episode, so resolve it once instead of
    # re-testing the policy name on every step.
    if request.policy == "random":
        def choose_action(observation):
            return random_joint_action(observation, rng)
    elif request.policy == "adaptive":
        def choose_action(observation):
            return adaptive_stackelberg_action(observation, personality=request.personality)
    else:
        def choose_action(observation):
            return heuristic_joint_action(observation, personality=request.personality)

    trajectory = []
    done = False
    while not done:
        action = choose_action(obs)
        # A new DispatchAction is built per step so no instance is shared
        # between consecutive env.step calls.
        dispatch_action = None if request.dispatcher_enabled else DispatchAction()
        result = env.step(action=action, session_id=sid, dispatch_action=dispatch_action)
        trajectory.append(
            {
                "step": len(trajectory) + 1,
                "action": action.model_dump(),
                "dispatch_action": result.info.get("dispatch_action"),
                "reward": result.reward.model_dump(),
                "info": result.info,
            }
        )
        obs = result.observation
        done = result.done

    avg_reward = sum(t["reward"]["score"] for t in trajectory) / max(1, len(trajectory))
    return {
        "success": True,
        "policy": request.policy,
        "personality": request.personality,
        "task_id": request.task_id,
        "seed": request.seed,
        "dispatcher_enabled": request.dispatcher_enabled,
        "steps": len(trajectory),
        "average_reward": round(avg_reward, 4),
        "trajectory": trajectory,
    }
def _run_policy_episode(task_id: str, seed: int, policy: str, personality: str = "balanced") -> dict:
    """Run one episode on the shared ``env`` and aggregate reliability metrics.

    Args:
        task_id: Forwarded to ``env.reset``.
        seed: Forwarded to ``env.reset`` and used to seed the local RNG that
            drives the random policy.
        policy: ``"random"`` or ``"adaptive"``; any other value selects the
            heuristic policy.
        personality: Forwarded to the adaptive / heuristic policy functions.

    Returns:
        A dict with the average reward, the cost / emission / correction
        totals read from the final step's ``info["summary"]``, and counters
        accumulated from each step's ``info["dispatch"]`` (blackout steps,
        unmet energy, reserve / emergency / startup / stability events,
        minimum frequency and peak stability risk).
    """
    # A step counts as a stability event when its risk index reaches this value.
    stability_risk_threshold = 0.45

    reset_resp = env.reset(task_id=task_id, seed=seed)
    sid = reset_resp.session_id
    obs = reset_resp.observation
    rng = random.Random(seed)

    # The policy does not change during the episode; resolve it once.
    if policy == "random":
        def choose_action(observation):
            return random_joint_action(observation, rng)
    elif policy == "adaptive":
        def choose_action(observation):
            return adaptive_stackelberg_action(observation, personality=personality)
    else:
        def choose_action(observation):
            return heuristic_joint_action(observation, personality=personality)

    rewards = []
    blackout_steps = 0
    unmet_energy = 0.0
    reserve_events = 0
    emergency_events = 0
    startup_events = 0
    stability_events = 0
    min_frequency_hz = 50.0
    peak_stability_risk = 0.0

    while True:
        result = env.step(action=choose_action(obs), session_id=sid)
        rewards.append(result.reward.score)

        dispatch = result.info["dispatch"]
        frequency_hz = float(dispatch.get("frequency_hz", 50.0))
        # Coerce once so the peak tracker and the event counter see the same
        # numeric value.
        stability_risk = float(dispatch.get("stability_risk_index", 0.0))
        unmet = dispatch.get("unmet_demand_mwh", 0.0)

        min_frequency_hz = min(min_frequency_hz, frequency_hz)
        peak_stability_risk = max(peak_stability_risk, stability_risk)
        if dispatch.get("reserve_commitment_active", False):
            reserve_events += 1
        if dispatch.get("emergency_dispatch_triggered", False):
            emergency_events += 1
        if dispatch.get("startup_cost_usd", 0.0) > 0.0:
            startup_events += 1
        if stability_risk >= stability_risk_threshold:
            stability_events += 1

        unmet_energy += unmet
        if unmet > 0.0:
            # Any unmet demand in a step is counted as a blackout step.
            blackout_steps += 1

        obs = result.observation
        if result.done:
            break

    # The episode summary is read from the info of the terminal step.
    summary = result.info["summary"]
    return {
        "avg_reward": round(sum(rewards) / max(1, len(rewards)), 4),
        "total_cost_usd": summary["total_cost_usd"],
        "total_emissions_tco2": summary.get("total_emissions_tco2", 0.0),
        "blackout_steps": blackout_steps,
        "unmet_energy_mwh": round(unmet_energy, 3),
        "corrections": summary.get("ldu_corrections", 0),
        "reserve_commitment_events": reserve_events,
        "emergency_dispatch_events": emergency_events,
        "startup_events": startup_events,
        "stability_events": stability_events,
        "min_frequency_hz": round(min_frequency_hz, 4),
        "peak_stability_risk": round(peak_stability_risk, 4),
    }
@app.get("/")
def root():
    """Return a static service descriptor with links to the main entry points."""
    entry_points = {"docs": "/docs", "health": "/health", "demo": "/demo"}
    return {
        "name": "OpenEnv Smart Grid MarketSim",
        "status": "ready",
        **entry_points,
    }
@app.get("/health")
def health():
    """Liveness probe: always answers with a static OK payload."""
    payload = dict(status="ok", service="openenv-smartgrid-marketsim")
    return payload
@app.post("/reset")
def reset(request: ResetRequest):
    """Reset the shared environment and return whatever ``env.reset`` returns.

    Raises:
        HTTPException: status 400 carrying the error text when ``env.reset``
            raises.
    """
    try:
        reset_response = env.reset(task_id=request.task_id, seed=request.seed)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    else:
        return reset_response
@app.post("/step")
def step(request: StepRequest, session_id: Optional[str] = Query(default=None)):
    """Advance the environment by one step and return the step result.

    The joint action and the optional dispatch action come from the request
    body; ``session_id`` comes from the query string.

    Raises:
        HTTPException: status 400 carrying the error text when ``env.step``
            raises.
    """
    step_kwargs = {
        "action": request.action,
        "session_id": session_id,
        "dispatch_action": request.dispatch_action,
    }
    try:
        step_result = env.step(**step_kwargs)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    else:
        return step_result
@app.get("/state")
def state(session_id: Optional[str] = Query(default=None)):
    """Return ``env.state`` for the given session.

    Raises:
        HTTPException: status 400 carrying the error text when ``env.state``
            raises.
    """
    try:
        current_state = env.state(session_id=session_id)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    else:
        return current_state
@app.post("/act")
def act(request: PolicyActionRequest, session_id: Optional[str] = Query(default=None)):
    """Ask the environment for the action the named policy would take.

    Returns:
        ``{"action": <dumped action>}``.

    Raises:
        HTTPException: status 400 carrying the error text when computing or
            serialising the action raises.
    """
    try:
        proposed = env.policy_action(
            policy=request.policy,
            personality=request.personality,
            session_id=session_id,
        )
        # Serialisation stays inside the try so its failures also map to 400.
        dumped = proposed.model_dump()
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    else:
        return {"action": dumped}
@app.post("/dispatch-act")
def dispatch_act(request: DispatchActionRequest, session_id: Optional[str] = Query(default=None)):
    """Ask the environment for a dispatch action.

    Returns:
        ``{"dispatch_action": <dumped action>}``.

    Raises:
        HTTPException: status 400 carrying the error text when computing or
            serialising the action raises.
    """
    try:
        proposed = env.dispatch_action(
            personality=request.personality,
            session_id=session_id,
            cleared_mwh=request.cleared_mwh,
        )
        # Serialisation stays inside the try so its failures also map to 400.
        dumped = proposed.model_dump()
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    else:
        return {"dispatch_action": dumped}
@app.get("/events")
def events(session_id: Optional[str] = Query(default=None)):
    """Return ``env.events`` for the given session.

    Raises:
        HTTPException: status 400 carrying the error text when ``env.events``
            raises.
    """
    try:
        event_payload = env.events(session_id=session_id)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    else:
        return event_payload
@app.get("/events/stream")
def events_stream(session_id: Optional[str] = Query(default=None), poll_ms: int = Query(default=650, ge=150, le=5000)):
    """Stream the session's events as ``text/event-stream``.

    ``env.state`` is called once up front so an unusable ``session_id`` is
    rejected before streaming starts. After that ``env.events`` is polled
    every ``poll_ms`` milliseconds: each event not yet sent is emitted as one
    ``data:`` frame containing its JSON, and a ``: keepalive`` comment frame
    is emitted when a poll finds nothing new.

    Args:
        session_id: Session to stream; forwarded to ``env.state`` / ``env.events``.
        poll_ms: Poll interval in milliseconds (validated to 150..5000).

    Raises:
        HTTPException: status 400 carrying the error text when the initial
            ``env.state`` call raises.
    """
    try:
        env.state(session_id=session_id)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    poll_seconds = poll_ms / 1000.0

    def event_generator():
        sent = 0  # number of events from the current log already emitted
        while True:
            events_list = env.events(session_id=session_id).get("events", [])
            if len(events_list) < sent:
                # The log is shorter than what was already sent (for example
                # after the session was reset): restart from its beginning
                # rather than waiting for it to outgrow the stale cursor.
                sent = 0
            pending = events_list[sent:]
            if pending:
                for item in pending:
                    yield f"data: {json.dumps(item)}\n\n"
                # Advance by what was actually emitted. Re-reading
                # len(events_list) here could skip events appended to the
                # list while this generator was suspended at a yield.
                sent += len(pending)
            else:
                yield ": keepalive\n\n"
            time.sleep(poll_seconds)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
@app.post("/inject-shock")
def inject_shock(request: ShockRequest, session_id: Optional[str] = Query(default=None)):
    """Forward a renewable-drop shock to ``env.inject_shock`` and return its result.

    Raises:
        HTTPException: status 400 carrying the error text when
            ``env.inject_shock`` raises.
    """
    drop_mwh = request.renewable_drop_mwh
    try:
        shock_result = env.inject_shock(session_id=session_id, renewable_drop_mwh=drop_mwh)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    else:
        return shock_result
@app.post("/operator-override")
def operator_override(request: OverrideRequest, session_id: Optional[str] = Query(default=None)):
    """Forward the override flag to ``env.set_operator_override`` and return its result.

    Raises:
        HTTPException: status 400 carrying the error text when
            ``env.set_operator_override`` raises.
    """
    flag = request.enabled
    try:
        override_result = env.set_operator_override(enabled=flag, session_id=session_id)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    else:
        return override_result
@app.get("/demo", response_class=HTMLResponse)
def demo_page():
    """Serve the HTML document produced by ``build_demo_html``."""
    page_html = build_demo_html()
    return HTMLResponse(page_html)
@app.get("/info")
def info():
    """Return the environment schema as reported by ``env.get_schema``."""
    schema = env.get_schema()
    return schema
@app.post("/run-inference")
def run_inference(request: InferenceRequest):
    """Run a full-episode rollout described by ``request`` and return its report.

    Returns:
        The dict produced by ``_rollout_inference``.

    Raises:
        HTTPException: status 400 carrying the error text when the rollout
            raises (for example when the environment rejects the requested
            task). This matches the error handling of the other
            environment-backed endpoints in this module instead of surfacing
            an unhandled 500.
    """
    try:
        return _rollout_inference(request)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.post("/run-demo-mode")
def run_demo_mode(dispatcher_enabled: bool = True):
    """Run the fixed demo rollout defined by ``DEMO_MODE_CONFIG``.

    Args:
        dispatcher_enabled: Copied onto the rollout request and echoed in the
            response.

    Returns:
        The ``_rollout_inference`` report extended with ``mode``,
        ``deterministic``, ``dispatcher_enabled`` and ``governing_claim``.
    """
    demo_request = InferenceRequest(**DEMO_MODE_CONFIG)
    demo_request.dispatcher_enabled = dispatcher_enabled
    report = _rollout_inference(demo_request)
    report.update(
        {
            "mode": "demo",
            "deterministic": True,
            "dispatcher_enabled": dispatcher_enabled,
            "governing_claim": (
                "Reliable grid balancing emerges when strategic bidding is constrained by a dispatch control agent and a physical safety shield."
            ),
        }
    )
    return report
@app.post("/run-resilience-demo")
def run_resilience_demo(request: ResilienceDemoRequest):
    """Run the same task and seed under two policies and compare the outcomes.

    Both episodes use the default personality of ``_run_policy_episode``.

    Returns:
        A dict with the request settings, the two per-episode metric dicts
        (``baseline`` / ``candidate``), ``catastrophic_failure_prevented``
        (true when the baseline had strictly more blackout steps than the
        candidate), baseline-minus-candidate deltas, and a narrative string
        chosen from that flag.

    Raises:
        HTTPException: status 400 carrying the error text when either episode
            raises (for example when the environment rejects the requested
            task). This matches the error handling of the other
            environment-backed endpoints in this module instead of surfacing
            an unhandled 500.
    """
    try:
        baseline = _run_policy_episode(
            task_id=request.task_id,
            seed=request.seed,
            policy=request.baseline_policy,
        )
        candidate = _run_policy_episode(
            task_id=request.task_id,
            seed=request.seed,
            policy=request.candidate_policy,
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    prevented = baseline["blackout_steps"] > candidate["blackout_steps"]
    return {
        "task_id": request.task_id,
        "seed": request.seed,
        "baseline_policy": request.baseline_policy,
        "candidate_policy": request.candidate_policy,
        "baseline": baseline,
        "candidate": candidate,
        "catastrophic_failure_prevented": prevented,
        # Every delta is baseline minus candidate, so a positive value means
        # the candidate run recorded fewer such events.
        "trajectory_comparison": {
            "blackout_step_delta": baseline["blackout_steps"] - candidate["blackout_steps"],
            "reserve_activation_delta": baseline["reserve_commitment_events"] - candidate["reserve_commitment_events"],
            "emergency_dispatch_delta": baseline["emergency_dispatch_events"] - candidate["emergency_dispatch_events"],
            "stability_event_delta": baseline["stability_events"] - candidate["stability_events"],
        },
        "narrative": (
            "Candidate policy preserved service continuity under contingency and forecast uncertainty, while improving reserve and stability outcomes."
            if prevented
            else "Candidate policy did not outperform baseline on blackout prevention for this seed."
        ),
    }
def main() -> None:
    """Start uvicorn serving ``app`` on all interfaces.

    The listening port is read from the ``PORT`` environment variable and
    falls back to 7860 when it is unset.
    """
    import uvicorn

    listen_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)


if __name__ == "__main__":
    main()