"""HTTP API for the OpenEnv Smart Grid MarketSim environment.

Exposes a single shared ``SmartGridMarketEnv`` instance over FastAPI:
low-level ``/reset`` / ``/step`` / ``/state`` endpoints, helper endpoints that
ask the environment for policy or dispatch actions, a server-sent-events feed,
and convenience endpoints that roll out whole episodes.
"""

import json
import os
import random
import time
from typing import Optional

from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel

from smartgrid_mas.engine.policies import (
    adaptive_stackelberg_action,
    heuristic_joint_action,
    random_joint_action,
)
from smartgrid_mas.env import SmartGridMarketEnv
from smartgrid_mas.demo_page import build_demo_html
from smartgrid_mas.models import DispatchAction, JointAction, ResetRequest, StepRequest

app = FastAPI(
    title="OpenEnv Smart Grid MarketSim",
    description="Multi-agent market simulator with a Reliability Dispatch Control Agent and a Physics-Constrained Safety Shield.",
    version="0.1.0",
)

# NOTE(review): wildcard origins/methods/headers combined with
# allow_credentials=True is discouraged by the FastAPI CORS documentation;
# confirm whether credentials are really needed or pin explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Single environment instance shared by every request handler in this module.
env = SmartGridMarketEnv()

# Fixed inputs used by /run-demo-mode so repeated demo runs use the same
# policy, personality, task and seed.
DEMO_MODE_CONFIG = {
    "policy": "adaptive",
    "personality": "balanced",
    "task_id": "default",
    "seed": 42,
}


class InferenceRequest(BaseModel):
    """Parameters for a full-episode rollout (see ``_rollout_inference``)."""

    policy: str = "heuristic"
    personality: str = "balanced"
    task_id: str = "default"
    seed: Optional[int] = 42
    dispatcher_enabled: bool = True


class ShockRequest(BaseModel):
    """Body of ``/inject-shock``: size of the renewable drop to inject."""

    renewable_drop_mwh: float = 20.0


class PolicyActionRequest(BaseModel):
    """Body of ``/act``: which policy and personality to query."""

    policy: str = "adaptive"
    personality: str = "balanced"


class DispatchActionRequest(BaseModel):
    """Body of ``/dispatch-act``: personality and optional cleared volume."""

    personality: str = "balanced"
    cleared_mwh: Optional[float] = None


class OverrideRequest(BaseModel):
    """Body of ``/operator-override``: desired override flag."""

    enabled: bool = True


class ResilienceDemoRequest(BaseModel):
    """Body of ``/run-resilience-demo``: task, seed and the two policies to compare."""

    task_id: str = "stress_shock"
    seed: int = 314
    baseline_policy: str = "random"
    candidate_policy: str = "adaptive"


def _select_action(policy: str, personality: str, obs, rng: random.Random):
    """Return the joint action for one step under the named policy.

    ``"random"`` draws from ``rng``; ``"adaptive"`` uses the adaptive
    Stackelberg policy; any other value falls back to the heuristic policy.
    """
    if policy == "random":
        return random_joint_action(obs, rng)
    if policy == "adaptive":
        return adaptive_stackelberg_action(obs, personality=personality)
    return heuristic_joint_action(obs, personality=personality)


def _rollout_inference(request: InferenceRequest) -> dict:
    """Reset the shared environment and step it until the episode is done.

    Args:
        request: Policy, personality, task, seed and dispatcher flag to use.

    Returns:
        A dict with the echoed request fields, the number of steps, the mean
        of the per-step ``reward["score"]`` (rounded to 4 places) and the
        full per-step trajectory.
    """
    reset_resp = env.reset(task_id=request.task_id, seed=request.seed)
    sid = reset_resp.session_id
    obs = reset_resp.observation
    rng = random.Random(request.seed)
    trajectory = []

    while True:
        action = _select_action(request.policy, request.personality, obs, rng)
        # With the dispatcher enabled no explicit dispatch action is passed;
        # otherwise a default-constructed DispatchAction is sent each step.
        dispatch_action = None if request.dispatcher_enabled else DispatchAction()
        result = env.step(action=action, session_id=sid, dispatch_action=dispatch_action)
        trajectory.append(
            {
                "step": len(trajectory) + 1,
                "action": action.model_dump(),
                "dispatch_action": result.info.get("dispatch_action"),
                "reward": result.reward.model_dump(),
                "info": result.info,
            }
        )
        obs = result.observation
        if result.done:
            break

    avg_reward = sum(t["reward"]["score"] for t in trajectory) / max(1, len(trajectory))
    return {
        "success": True,
        "policy": request.policy,
        "personality": request.personality,
        "task_id": request.task_id,
        "seed": request.seed,
        "dispatcher_enabled": request.dispatcher_enabled,
        "steps": len(trajectory),
        "average_reward": round(avg_reward, 4),
        "trajectory": trajectory,
    }


def _run_policy_episode(task_id: str, seed: int, policy: str, personality: str = "balanced") -> dict:
    """Run one full episode under ``policy`` and return aggregate metrics.

    Per-step values are read from ``result.info["dispatch"]``; totals for cost,
    emissions and corrections come from ``result.info["summary"]`` on the
    final step.

    Args:
        task_id: Task passed to ``env.reset``.
        seed: Seed for both the environment reset and the local RNG.
        policy: Policy name understood by ``_select_action``.
        personality: Personality forwarded to the non-random policies.

    Returns:
        A dict of episode-level reward, cost, emissions, blackout, reserve,
        emergency, startup and stability metrics.
    """
    reset_resp = env.reset(task_id=task_id, seed=seed)
    sid = reset_resp.session_id
    obs = reset_resp.observation
    rng = random.Random(seed)

    rewards = []
    blackout_steps = 0
    unmet_energy = 0.0
    reserve_events = 0
    emergency_events = 0
    startup_events = 0
    stability_events = 0
    min_frequency_hz = 50.0
    peak_stability_risk = 0.0

    while True:
        action = _select_action(policy, personality, obs, rng)
        result = env.step(action=action, session_id=sid)
        rewards.append(result.reward.score)

        dispatch = result.info["dispatch"]
        unmet = dispatch.get("unmet_demand_mwh", 0.0)
        min_frequency_hz = min(min_frequency_hz, float(dispatch.get("frequency_hz", 50.0)))
        peak_stability_risk = max(
            peak_stability_risk, float(dispatch.get("stability_risk_index", 0.0))
        )

        if dispatch.get("reserve_commitment_active", False):
            reserve_events += 1
        if dispatch.get("emergency_dispatch_triggered", False):
            emergency_events += 1
        if dispatch.get("startup_cost_usd", 0.0) > 0.0:
            startup_events += 1
        # A step counts as a stability event once the risk index reaches 0.45.
        if dispatch.get("stability_risk_index", 0.0) >= 0.45:
            stability_events += 1

        unmet_energy += unmet
        # Any step with unmet demand is counted as a blackout step.
        if unmet > 0.0:
            blackout_steps += 1

        obs = result.observation
        if result.done:
            summary = result.info["summary"]
            return {
                "avg_reward": round(sum(rewards) / max(1, len(rewards)), 4),
                "total_cost_usd": summary["total_cost_usd"],
                "total_emissions_tco2": summary.get("total_emissions_tco2", 0.0),
                "blackout_steps": blackout_steps,
                "unmet_energy_mwh": round(unmet_energy, 3),
                "corrections": summary.get("ldu_corrections", 0),
                "reserve_commitment_events": reserve_events,
                "emergency_dispatch_events": emergency_events,
                "startup_events": startup_events,
                "stability_events": stability_events,
                "min_frequency_hz": round(min_frequency_hz, 4),
                "peak_stability_risk": round(peak_stability_risk, 4),
            }


@app.get("/")
def root():
    """Return the service name, status and links to the main endpoints."""
    return {
        "name": "OpenEnv Smart Grid MarketSim",
        "status": "ready",
        "docs": "/docs",
        "health": "/health",
        "demo": "/demo",
    }


@app.get("/health")
def health():
    """Return a static liveness payload."""
    return {"status": "ok", "service": "openenv-smartgrid-marketsim"}


@app.post("/reset")
def reset(request: ResetRequest):
    """Reset the environment; any failure is reported as HTTP 400."""
    try:
        return env.reset(task_id=request.task_id, seed=request.seed)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.post("/step")
def step(request: StepRequest, session_id: Optional[str] = Query(default=None)):
    """Advance the environment by one step; any failure is reported as HTTP 400."""
    try:
        return env.step(
            action=request.action,
            session_id=session_id,
            dispatch_action=request.dispatch_action,
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.get("/state")
def state(session_id: Optional[str] = Query(default=None)):
    """Return the environment state; any failure is reported as HTTP 400."""
    try:
        return env.state(session_id=session_id)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.post("/act")
def act(request: PolicyActionRequest, session_id: Optional[str] = Query(default=None)):
    """Ask the environment for a policy action and return it as a plain dict."""
    try:
        action = env.policy_action(
            policy=request.policy,
            personality=request.personality,
            session_id=session_id,
        )
        return {"action": action.model_dump()}
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.post("/dispatch-act")
def dispatch_act(request: DispatchActionRequest, session_id: Optional[str] = Query(default=None)):
    """Ask the environment for a dispatch action and return it as a plain dict."""
    try:
        action = env.dispatch_action(
            personality=request.personality,
            session_id=session_id,
            cleared_mwh=request.cleared_mwh,
        )
        return {"dispatch_action": action.model_dump()}
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.get("/events")
def events(session_id: Optional[str] = Query(default=None)):
    """Return the environment's event payload; any failure is reported as HTTP 400."""
    try:
        return env.events(session_id=session_id)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.get("/events/stream")
def events_stream(
    session_id: Optional[str] = Query(default=None),
    poll_ms: int = Query(default=650, ge=150, le=5000),
):
    """Stream new environment events as ``text/event-stream``.

    The session is checked once up front via ``env.state`` so an invalid
    session yields HTTP 400 before streaming starts. Afterwards the event list
    is polled every ``poll_ms`` milliseconds; unseen items are emitted as
    ``data:`` frames and an SSE comment is sent as a keepalive otherwise.
    """
    try:
        env.state(session_id=session_id)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    def event_generator():
        # NOTE(review): this loop has no exit condition of its own; it runs
        # until the generator is closed or env.events raises.
        last_len = 0
        while True:
            data = env.events(session_id=session_id)
            events_list = data.get("events", [])
            if len(events_list) > last_len:
                for item in events_list[last_len:]:
                    yield f"data: {json.dumps(item)}\n\n"
                last_len = len(events_list)
            else:
                yield ": keepalive\n\n"
            time.sleep(poll_ms / 1000.0)

    return StreamingResponse(event_generator(), media_type="text/event-stream")


@app.post("/inject-shock")
def inject_shock(request: ShockRequest, session_id: Optional[str] = Query(default=None)):
    """Inject a renewable-drop shock; any failure is reported as HTTP 400."""
    try:
        return env.inject_shock(
            session_id=session_id,
            renewable_drop_mwh=request.renewable_drop_mwh,
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.post("/operator-override")
def operator_override(request: OverrideRequest, session_id: Optional[str] = Query(default=None)):
    """Set the operator-override flag; any failure is reported as HTTP 400."""
    try:
        return env.set_operator_override(enabled=request.enabled, session_id=session_id)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.get("/demo", response_class=HTMLResponse)
def demo_page():
    """Serve the HTML produced by ``build_demo_html``."""
    return HTMLResponse(build_demo_html())


@app.get("/info")
def info():
    """Return the environment schema."""
    return env.get_schema()


@app.post("/run-inference")
def run_inference(request: InferenceRequest):
    """Roll out a full episode; failures are reported as HTTP 400.

    Errors are wrapped the same way as the other environment-backed endpoints
    so a bad ``task_id`` produces a readable 400 rather than an opaque 500.
    """
    try:
        return _rollout_inference(request)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.post("/run-demo-mode")
def run_demo_mode(dispatcher_enabled: bool = True):
    """Roll out an episode with ``DEMO_MODE_CONFIG`` and tag the result as a demo run."""
    request = InferenceRequest(**DEMO_MODE_CONFIG, dispatcher_enabled=dispatcher_enabled)
    result = _rollout_inference(request)
    result["mode"] = "demo"
    result["deterministic"] = True
    result["dispatcher_enabled"] = dispatcher_enabled
    result["governing_claim"] = (
        "Reliable grid balancing emerges when strategic bidding is constrained by a dispatch control agent and a physical safety shield."
    )
    return result


@app.post("/run-resilience-demo")
def run_resilience_demo(request: ResilienceDemoRequest):
    """Run baseline and candidate policies on the same task/seed and compare them.

    ``catastrophic_failure_prevented`` is true when the candidate has strictly
    fewer blackout steps than the baseline. Failures while running either
    episode are reported as HTTP 400, matching the other endpoints.
    """
    try:
        baseline = _run_policy_episode(
            task_id=request.task_id,
            seed=request.seed,
            policy=request.baseline_policy,
        )
        candidate = _run_policy_episode(
            task_id=request.task_id,
            seed=request.seed,
            policy=request.candidate_policy,
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    prevented = baseline["blackout_steps"] > candidate["blackout_steps"]
    return {
        "task_id": request.task_id,
        "seed": request.seed,
        "baseline_policy": request.baseline_policy,
        "candidate_policy": request.candidate_policy,
        "baseline": baseline,
        "candidate": candidate,
        "catastrophic_failure_prevented": prevented,
        # Every delta is baseline minus candidate.
        "trajectory_comparison": {
            "blackout_step_delta": baseline["blackout_steps"] - candidate["blackout_steps"],
            "reserve_activation_delta": baseline["reserve_commitment_events"]
            - candidate["reserve_commitment_events"],
            "emergency_dispatch_delta": baseline["emergency_dispatch_events"]
            - candidate["emergency_dispatch_events"],
            "stability_event_delta": baseline["stability_events"] - candidate["stability_events"],
        },
        "narrative": (
            "Candidate policy preserved service continuity under contingency and forecast uncertainty, while improving reserve and stability outcomes."
            if prevented
            else "Candidate policy did not outperform baseline on blackout prevention for this seed."
        ),
    }


def main() -> None:
    """Serve the app with uvicorn on all interfaces, on ``$PORT`` (default 7860)."""
    import uvicorn

    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)


if __name__ == "__main__":
    main()