| import random |
| import uuid |
| from dataclasses import dataclass, field |
| from typing import Dict, Optional |
|
|
| from smartgrid_mas.engine.policies import ( |
| adaptive_stackelberg_action, |
| heuristic_joint_action, |
| random_joint_action, |
| ) |
| from smartgrid_mas.engine.control import ReliabilityDispatchControlAgent |
| from smartgrid_mas.engine.dynamics import evolve_grid |
| from smartgrid_mas.engine.ldu import enforce_dispatch |
| from smartgrid_mas.engine.market import clear_market |
| from smartgrid_mas.engine.reward import compute_reward |
| from smartgrid_mas.models import ( |
| DispatchAction, |
| JointAction, |
| MarketObservation, |
| MarketReward, |
| ResetResponse, |
| StateResponse, |
| StepResponse, |
| ) |
| from smartgrid_mas.tasks import TaskConfig, get_task, list_tasks |
|
|
|
|
# Human-readable description of the expected action payload and of the processing
# order; it is attached to reset responses and to every observation.
SCHEMA_INFO = " ".join(
    [
        "Provide a JointAction with supply and demand bids from multiple agents plus EV charge/discharge",
        "commands. Market clears bids first, then the Reliability Dispatch Control Agent proposes corrective dispatch,",
        "and the Physics-Constrained Safety Shield enforces physical feasibility and logs corrections.",
    ]
)
|
|
|
|
@dataclass
class Session:
    """Mutable state of a single simulation episode.

    Holds the task configuration, the episode's random generator, the current
    grid quantities, and the counters accumulated while the episode is stepped.
    Only ``task`` and ``rng`` are required; every other field has a default.
    """

    # --- identity and configuration ---
    task: TaskConfig
    rng: random.Random
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))  # fresh UUID4 string per session

    # --- episode progress ---
    step: int = 0
    done: bool = False

    # --- current grid quantities ---
    demand_mwh: float = 0.0
    renewable_mwh: float = 0.0
    peaker_capacity_mwh: float = 0.0
    ev_storage_mwh: float = 0.0
    ev_storage_capacity_mwh: float = 0.0

    # --- prices and carry-over from the previous step ---
    base_price: float = 0.0
    last_clearing_price: float = 0.0
    prior_gap: float = 0.0

    # --- running totals and histories ---
    correction_count: int = 0
    infeasible_actions: int = 0
    total_demand_met: float = 0.0
    total_cost: float = 0.0
    reward_history: list = field(default_factory=list)
    event_log: list = field(default_factory=list)

    # --- regime flags ---
    shock_seen: bool = False
    contingency_seen: bool = False
    contingency_type: str = "none"
    operator_override_enabled: bool = False

    # --- forecasts and their errors ---
    forecast_demand_mwh: float = 0.0
    forecast_renewable_mwh: float = 0.0
    load_forecast_error_mwh: float = 0.0
    renewable_forecast_error_mwh: float = 0.0

    # --- dispatch memory and contingency multipliers ---
    previous_peaker_dispatch_mwh: float = 0.0
    previous_ev_discharge_mwh: float = 0.0
    peaker_online: bool = False
    contingency_peaker_multiplier: float = 1.0
    contingency_loss_multiplier: float = 1.0

    # --- reliability and emissions counters ---
    total_emissions_tco2: float = 0.0
    blackout_steps: int = 0
    reserve_commitment_events: int = 0
    emergency_dispatch_events: int = 0
    stability_events: int = 0
    peaker_activation_timer: int = 0

    # Agent id -> personality label.
    personalities: Dict[str, str] = field(default_factory=dict)

    def to_observation(self, hint: Optional[str] = None, error_message: Optional[str] = None) -> MarketObservation:
        """Snapshot the current state as a ``MarketObservation``.

        Energy quantities and prices are rounded to 3 decimals and the scarcity
        index to 4. ``hint`` and ``error_message`` are passed through unchanged.
        """
        if self.shock_seen:
            regime_message = "Shock regime active; renewable volatility is elevated"
        else:
            regime_message = "Normal regime; optimize demand satisfaction with low infeasibility"

        # Fraction of demand not covered by renewables, floored at zero; the
        # 1e-6 floor on the divisor guards against a zero demand.
        uncovered_share = (self.demand_mwh - self.renewable_mwh) / max(self.demand_mwh, 1e-6)
        scarcity = round(max(0.0, uncovered_share), 4)

        return MarketObservation(
            step=self.step,
            steps_taken=self.step,
            max_steps=self.task.max_steps,
            demand_mwh=round(self.demand_mwh, 3),
            renewable_availability_mwh=round(self.renewable_mwh, 3),
            peaker_capacity_mwh=round(self.peaker_capacity_mwh, 3),
            ev_storage_mwh=round(self.ev_storage_mwh, 3),
            ev_storage_capacity_mwh=round(self.ev_storage_capacity_mwh, 3),
            last_clearing_price=round(self.last_clearing_price, 3),
            leader_price_signal=round(self.base_price, 3),
            scarcity_index=scarcity,
            shock_active=self.shock_seen,
            forecast_demand_mwh=round(self.forecast_demand_mwh, 3),
            forecast_renewable_mwh=round(self.forecast_renewable_mwh, 3),
            load_forecast_error_mwh=round(self.load_forecast_error_mwh, 3),
            renewable_forecast_error_mwh=round(self.renewable_forecast_error_mwh, 3),
            contingency_active=self.contingency_seen,
            contingency_type=self.contingency_type,
            operator_override_enabled=self.operator_override_enabled,
            public_signal=regime_message,
            schema_info=SCHEMA_INFO,
            hint=hint,
            error_message=error_message,
        )
|
|
|
|
class SmartGridMarketEnv:
    """In-memory, multi-session smart-grid market environment.

    Each ``reset`` creates a new ``Session``; the other public methods operate
    on the session named by ``session_id`` or, when that is omitted, on the
    most recently created one.
    """

    def __init__(self):
        """Start with an empty session registry."""
        # NOTE(review): nothing in this class removes entries, so the registry
        # grows with every reset() for the lifetime of the environment.
        self._sessions: Dict[str, Session] = {}
        # Session used when a caller does not pass a session id.
        self._latest_session_id: Optional[str] = None

    def reset(self, task_id: str = "default", seed: Optional[int] = None) -> ResetResponse:
        """Create a new session for ``task_id`` and return its initial observation.

        Args:
            task_id: Identifier passed to ``get_task``.
            seed: Seed for the session's ``random.Random``; ``None`` lets the
                generator seed itself.
        """
        task = get_task(task_id)
        rng = random.Random(seed)
        session = Session(
            task=task,
            rng=rng,
            demand_mwh=task.initial_demand_mwh,
            renewable_mwh=task.initial_renewable_mwh,
            peaker_capacity_mwh=task.peaker_capacity_mwh,
            ev_storage_mwh=task.ev_storage_mwh,
            ev_storage_capacity_mwh=task.ev_storage_capacity_mwh,
            base_price=task.base_price_usd_per_mwh,
            last_clearing_price=task.base_price_usd_per_mwh,
            # The draw order below determines the seeded outcome; keep it stable.
            personalities={
                "renewable_1": rng.choice(["opportunistic", "balanced"]),
                "peaker_1": rng.choice(["greedy", "balanced", "risk_averse"]),
                "industrial_1": rng.choice(["risk_averse", "balanced"]),
                "ev_1": rng.choice(["balanced", "risk_averse"]),
            },
            forecast_demand_mwh=task.initial_demand_mwh,
            forecast_renewable_mwh=task.initial_renewable_mwh,
        )
        self._sessions[session.session_id] = session
        self._latest_session_id = session.session_id

        return ResetResponse(
            session_id=session.session_id,
            task_id=task.task_id,
            task_description=task.description,
            schema_info=SCHEMA_INFO,
            steps_taken=0,
            observation=session.to_observation(hint=task.hint),
        )

    def step(
        self,
        action: JointAction,
        session_id: Optional[str] = None,
        dispatch_action: Optional[DispatchAction] = None,
    ) -> StepResponse:
        """Advance the session by one step.

        Order of operations: clear the market on the (possibly overridden)
        bids, obtain a dispatch action, apply the peaker start-up delay, run
        ``enforce_dispatch``, score the result, update the running totals,
        evolve the grid, and log the step.

        Args:
            action: Joint bids and EV commands submitted by the agents.
            session_id: Target session; defaults to the latest one.
            dispatch_action: Optional controller action. When ``None`` one is
                produced by ``ReliabilityDispatchControlAgent``.

        Returns:
            The ``StepResponse`` for this step. If the episode is already
            finished, a terminal response carrying an error message.

        Raises:
            KeyError: If no session can be resolved.
        """
        session = self._get_session(session_id)
        if session.done:
            return self._finished_step_response(session)

        # With operator override enabled the submitted action is replaced by
        # the risk-averse heuristic policy.
        applied_action = action
        if session.operator_override_enabled:
            applied_action = heuristic_joint_action(session.to_observation(), personality="risk_averse")

        market = clear_market(applied_action.bids, leader_price_signal=session.base_price)
        if dispatch_action is None:
            dispatch_action = self._resolve_dispatch_action(session, market, session.to_observation())

        # A strongly negative peaker adjustment must not produce a negative
        # capacity, so the sum is floored at zero.
        effective_peaker_capacity = max(
            0.0,
            session.peaker_capacity_mwh * session.contingency_peaker_multiplier
            + (dispatch_action.reserve_activation_mwh + dispatch_action.peaker_adjustment_mwh),
        )

        # Non-negative storage dispatch adds to EV discharge, negative adds to charge.
        applied_ev_charge = applied_action.ev_charge_mwh
        applied_ev_discharge = applied_action.ev_discharge_mwh
        if dispatch_action.storage_dispatch_mwh >= 0.0:
            applied_ev_discharge += dispatch_action.storage_dispatch_mwh
        else:
            applied_ev_charge += abs(dispatch_action.storage_dispatch_mwh)

        effective_peaker_capacity = self._apply_peaker_startup_delay(
            session, market, dispatch_action, effective_peaker_capacity
        )

        # The enforcement layer sees the cleared volume shifted by the
        # controller's corrective redispatch; the raw market result is kept
        # untouched for logging.
        adjusted_market = dict(market)
        adjusted_market["cleared_mwh"] = max(
            0.0, adjusted_market.get("cleared_mwh", 0.0) + dispatch_action.corrective_redispatch_mwh
        )
        adjusted_market["dispatcher_action"] = dispatch_action.model_dump()

        dispatch, next_storage = enforce_dispatch(
            market_result=adjusted_market,
            demand_mwh=session.demand_mwh,
            renewable_available_mwh=session.renewable_mwh,
            peaker_capacity_mwh=effective_peaker_capacity,
            ev_storage_mwh=session.ev_storage_mwh,
            ev_storage_capacity_mwh=session.ev_storage_capacity_mwh,
            ev_charge_mwh=applied_ev_charge,
            ev_discharge_mwh=applied_ev_discharge,
            reserve_margin_ratio=session.task.reserve_margin_ratio,
            reserve_commitment_threshold_ratio=session.task.reserve_commitment_threshold_ratio,
            peaker_ramp_limit_mwh=session.task.peaker_ramp_limit_mwh,
            ev_ramp_limit_mwh=session.task.ev_ramp_limit_mwh,
            previous_peaker_dispatch_mwh=session.previous_peaker_dispatch_mwh,
            previous_ev_discharge_mwh=session.previous_ev_discharge_mwh,
            previous_peaker_online=session.peaker_online,
            peaker_startup_cost_usd=session.task.peaker_startup_cost_usd,
            peaker_emission_factor_tco2_per_mwh=session.task.peaker_emission_factor_tco2_per_mwh,
            transmission_loss_multiplier=session.contingency_loss_multiplier,
            carbon_price_usd_per_tco2=session.task.carbon_price_usd_per_tco2,
            enable_reserve_logic=session.task.enable_reserve_logic,
            enable_ramp_limits=session.task.enable_ramp_limits,
            enable_startup_emissions=session.task.enable_startup_emissions,
        )

        # A falsy clearing price falls back to the leader's base price.
        clearing_price = market["clearing_price"] or session.base_price
        reward = compute_reward(
            dispatch=dispatch,
            clearing_price=clearing_price,
            demand_mwh=session.demand_mwh,
            prior_gap=session.prior_gap,
            carbon_price_usd_per_tco2=session.task.carbon_price_usd_per_tco2,
        )

        self._record_step_outcome(session, dispatch, next_storage, clearing_price, reward)

        # Built before the grid evolves so the views describe the step just
        # played; this also draws from the session RNG ahead of evolve_grid.
        private_views = self._build_private_agent_views(session, market, dispatch)

        dyn_info = self._advance_grid(session)

        session.event_log.append(
            {
                "step": session.step,
                "market": market,
                "dispatch_action": dispatch_action.model_dump(),
                "dispatch": dispatch,
                "reward": reward.model_dump(),
                "dynamics": dyn_info,
                "agent_private_views": private_views,
            }
        )

        done = session.step >= session.task.max_steps
        session.done = done

        info = {
            "market": market,
            "dispatch_action": dispatch_action.model_dump(),
            "dispatch": dispatch,
            "dynamics": dyn_info,
            "agent_private_views": private_views,
            "summary": self._build_summary(session, market),
        }

        return StepResponse(
            observation=session.to_observation(),
            reward=reward,
            done=done,
            truncated=False,
            info=info,
        )

    def _finished_step_response(self, session: Session) -> StepResponse:
        """Build the response returned when ``step`` is called on a finished episode."""
        return StepResponse(
            observation=session.to_observation(error_message="Episode finished. Call reset."),
            # Reward of an all-zero dispatch; demand is floored at 1.0 before
            # being handed to compute_reward.
            reward=compute_reward(
                dispatch={
                    "delivered_supply_mwh": 0.0,
                    "unmet_demand_mwh": 0.0,
                    "oversupply_mwh": 0.0,
                    "correction_count": 0,
                    "storage_loss_mwh": 0.0,
                    "renewable_dispatch_mwh": 0.0,
                },
                clearing_price=session.last_clearing_price,
                demand_mwh=max(1.0, session.demand_mwh),
                prior_gap=0.0,
            ),
            done=True,
            truncated=False,
            info={"error": "episode_done"},
        )

    def _apply_peaker_startup_delay(
        self,
        session: Session,
        market: Dict,
        dispatch_action: DispatchAction,
        effective_peaker_capacity: float,
    ) -> float:
        """Arm, cancel or tick the peaker start-up timer and return the usable capacity.

        Returns ``0.0`` while the timer is counting down, otherwise the
        capacity that was passed in.
        """
        expected_residual = max(0.0, market.get("cleared_mwh", 0.0) - session.renewable_mwh)
        delay_steps = session.task.peaker_activation_delay_steps

        # NOTE(review): once the countdown reaches zero nothing here marks the
        # start-up as completed; if the dispatch result still reports the
        # peaker offline and a residual remains, the next step arms the delay
        # again. Confirm this is intended.
        if (
            delay_steps > 0
            and not session.peaker_online
            and session.peaker_activation_timer == 0
            and expected_residual > 0.0
        ):
            session.peaker_activation_timer = delay_steps
            session.event_log.append(
                {
                    "step": session.step,
                    "type": "peaker_startup_delay",
                    "delay_steps": delay_steps,
                }
            )

        # Any positive controller intervention cancels a pending start-up delay.
        dispatch_override_active = (
            dispatch_action.reserve_activation_mwh > 0.0
            or dispatch_action.peaker_adjustment_mwh > 0.0
            or dispatch_action.storage_dispatch_mwh > 0.0
            or dispatch_action.corrective_redispatch_mwh > 0.0
        )
        if dispatch_override_active and session.peaker_activation_timer > 0:
            session.peaker_activation_timer = 0

        if session.peaker_activation_timer > 0:
            session.peaker_activation_timer -= 1
            return 0.0
        return effective_peaker_capacity

    def _record_step_outcome(
        self,
        session: Session,
        dispatch: Dict,
        next_storage: float,
        clearing_price: float,
        reward: MarketReward,
    ) -> None:
        """Fold one step's dispatch result and reward into the session's running state."""
        session.step += 1
        session.ev_storage_mwh = next_storage
        session.last_clearing_price = clearing_price
        session.prior_gap = dispatch["delivered_supply_mwh"] - session.demand_mwh
        session.previous_peaker_dispatch_mwh = dispatch.get("peaker_dispatch_mwh", 0.0)
        session.previous_ev_discharge_mwh = dispatch.get("ev_discharge_mwh", 0.0)
        session.peaker_online = bool(dispatch.get("peaker_online", False))

        session.correction_count += dispatch["correction_count"]
        if dispatch["correction_count"] > 0:
            # A step counts as infeasible when at least one correction was needed.
            session.infeasible_actions += 1

        # Delivery beyond demand does not count as additional demand met.
        session.total_demand_met += min(session.demand_mwh, dispatch["delivered_supply_mwh"])
        energy_cost = dispatch["delivered_supply_mwh"] * session.last_clearing_price
        session.total_cost += (
            energy_cost + dispatch.get("startup_cost_usd", 0.0) + dispatch.get("emissions_cost_usd", 0.0)
        )
        session.total_emissions_tco2 += dispatch.get("emissions_tco2", 0.0)

        if dispatch["unmet_demand_mwh"] > 0.0:
            session.blackout_steps += 1
        if dispatch.get("reserve_commitment_active", False):
            session.reserve_commitment_events += 1
        if dispatch.get("emergency_dispatch_triggered", False):
            session.emergency_dispatch_events += 1
        if dispatch.get("stability_risk_index", 0.0) >= 0.45:
            session.stability_events += 1
        session.reward_history.append(reward.score)

    def _advance_grid(self, session: Session) -> Dict:
        """Run ``evolve_grid`` for the session, store the new grid state, and return its info dict."""
        next_demand, next_renewable, next_price, dyn_info = evolve_grid(
            demand_mwh=session.demand_mwh,
            renewable_mwh=session.renewable_mwh,
            base_price_usd_per_mwh=session.base_price,
            step=session.step,
            task=session.task,
            rng=session.rng,
        )
        session.demand_mwh = next_demand
        session.renewable_mwh = next_renewable
        session.base_price = next_price
        # The "seen" flags are sticky: once set they stay set for the episode.
        session.shock_seen = session.shock_seen or dyn_info["shock_active"]
        session.contingency_seen = session.contingency_seen or dyn_info.get("contingency_active", False)
        session.contingency_type = dyn_info.get("contingency_type", "none")
        session.forecast_demand_mwh = dyn_info.get("forecast_demand_mwh", session.demand_mwh)
        session.forecast_renewable_mwh = dyn_info.get("forecast_renewable_mwh", session.renewable_mwh)
        session.load_forecast_error_mwh = dyn_info.get("load_forecast_error_mwh", 0.0)
        session.renewable_forecast_error_mwh = dyn_info.get("renewable_forecast_error_mwh", 0.0)
        session.contingency_peaker_multiplier = dyn_info.get("peaker_capacity_multiplier", 1.0)
        session.contingency_loss_multiplier = dyn_info.get("transmission_loss_multiplier", 1.0)
        return dyn_info

    def _build_summary(self, session: Session, market: Dict) -> Dict:
        """Assemble the cumulative episode summary included in each step's ``info``."""
        return {
            # reward_history is non-empty here: step() appends before calling this.
            "avg_reward": round(sum(session.reward_history) / len(session.reward_history), 4),
            "total_demand_met_mwh": round(session.total_demand_met, 3),
            "total_cost_usd": round(session.total_cost, 3),
            "total_emissions_tco2": round(session.total_emissions_tco2, 4),
            "blackout_steps": session.blackout_steps,
            "infeasible_actions": session.infeasible_actions,
            "ldu_corrections": session.correction_count,
            "reserve_commitment_events": session.reserve_commitment_events,
            "emergency_dispatch_events": session.emergency_dispatch_events,
            "stability_events": session.stability_events,
            "leader_adjusted_bids": market["leader_adjusted_bids"],
            "personality_map": session.personalities,
            "operator_override_enabled": session.operator_override_enabled,
        }

    def policy_action(
        self,
        policy: str = "adaptive",
        personality: str = "balanced",
        session_id: Optional[str] = None,
    ) -> JointAction:
        """Produce a joint action from a built-in policy for the current observation.

        ``"random"`` and ``"heuristic"`` select those policies; any other value
        selects the adaptive Stackelberg policy.
        """
        session = self._get_session(session_id)
        obs = session.to_observation()
        if policy == "random":
            return random_joint_action(obs, session.rng)
        if policy == "heuristic":
            return heuristic_joint_action(obs, personality=personality)
        return adaptive_stackelberg_action(obs, personality=personality)

    def dispatch_action(
        self,
        personality: str = "balanced",
        session_id: Optional[str] = None,
        cleared_mwh: Optional[float] = None,
    ) -> DispatchAction:
        """Ask the reliability dispatch controller for an action.

        When ``cleared_mwh`` is ``None`` the observed demand is used in its place.
        """
        session = self._get_session(session_id)
        obs = session.to_observation()
        controller = ReliabilityDispatchControlAgent(personality=personality)
        target_mwh = obs.demand_mwh if cleared_mwh is None else cleared_mwh
        return controller.act(obs, cleared_mwh=float(target_mwh))

    def state(self, session_id: Optional[str] = None) -> StateResponse:
        """Return the session's task id, progress and current observation."""
        session = self._get_session(session_id)
        return StateResponse(
            current_task_id=session.task.task_id,
            steps_taken=session.step,
            episode_done=session.done,
            observation=session.to_observation(),
        )

    def events(self, session_id: Optional[str] = None) -> Dict:
        """Return the session id together with its 50 most recent log entries."""
        session = self._get_session(session_id)
        return {"session_id": session.session_id, "events": session.event_log[-50:]}

    def inject_shock(self, session_id: Optional[str] = None, renewable_drop_mwh: float = 20.0) -> Dict:
        """Manually reduce renewable availability and mark the shock regime as seen.

        A negative ``renewable_drop_mwh`` is treated as zero, and availability
        never goes below zero.
        """
        session = self._get_session(session_id)
        drop_mwh = max(0.0, renewable_drop_mwh)
        before = session.renewable_mwh
        session.renewable_mwh = max(0.0, before - drop_mwh)
        session.shock_seen = True
        event = {
            "step": session.step,
            "type": "manual_shock",
            "renewable_before_mwh": round(before, 3),
            "renewable_after_mwh": round(session.renewable_mwh, 3),
            "drop_mwh": round(drop_mwh, 3),
        }
        session.event_log.append(event)
        return {
            "session_id": session.session_id,
            "shock_event": event,
            "observation": session.to_observation(),
        }

    def get_schema(self) -> Dict:
        """Return JSON schemas for the action, observation and reward models plus the task list."""
        return {
            "action_schema": JointAction.model_json_schema(),
            "dispatch_action_schema": DispatchAction.model_json_schema(),
            "observation_schema": MarketObservation.model_json_schema(),
            "reward_schema": MarketReward.model_json_schema(),
            "tasks": list_tasks(),
            "notes": "Hybrid Theme 1+2+3.1 baseline implementation with the Physics-Constrained Safety Shield as core physical layer",
        }

    def set_operator_override(self, enabled: bool, session_id: Optional[str] = None) -> Dict:
        """Enable or disable the operator override and log the change."""
        session = self._get_session(session_id)
        session.operator_override_enabled = bool(enabled)
        event = {
            "step": session.step,
            "type": "operator_override",
            "enabled": session.operator_override_enabled,
        }
        session.event_log.append(event)
        return {
            "session_id": session.session_id,
            "operator_override_enabled": session.operator_override_enabled,
            "event": event,
        }

    def _get_session(self, session_id: Optional[str]) -> Session:
        """Resolve a session, falling back to the latest one when no id is given.

        Raises:
            KeyError: If there is no session yet or the id is unknown.
        """
        sid = session_id or self._latest_session_id
        if sid is None or sid not in self._sessions:
            raise KeyError("No active session. Call /reset first.")
        return self._sessions[sid]

    def _build_private_agent_views(self, session: Session, market: Dict, dispatch: Dict) -> Dict[str, Dict]:
        """Build the per-agent private signals attached to a step's info and event log.

        Consumes one draw from ``session.rng`` for the renewable agent's
        forecast bias.
        """
        scarcity = max(0.0, (session.demand_mwh - session.renewable_mwh) / max(session.demand_mwh, 1e-6))
        spread = max(0.0, session.base_price - session.last_clearing_price)
        # Same fallback as step(): a missing or falsy clearing price means
        # "use the base price". A plain .get() default would pass None through
        # and fail on the arithmetic below.
        clearing_price = market.get("clearing_price") or session.base_price
        return {
            "renewable_1": {
                "personality": session.personalities.get("renewable_1", "balanced"),
                "curtailment_risk": round(max(0.0, session.renewable_mwh - market.get("cleared_mwh", 0.0)), 3),
                "forecast_bias": round(session.rng.uniform(-3.0, 3.0), 3),
            },
            "peaker_1": {
                "personality": session.personalities.get("peaker_1", "balanced"),
                "scarcity_index": round(scarcity, 4),
                # NOTE(review): 42.0 looks like a fixed reference price — confirm its meaning.
                "margin_signal": round(clearing_price - 42.0, 3),
            },
            "industrial_1": {
                "personality": session.personalities.get("industrial_1", "balanced"),
                "budget_pressure": round(clearing_price / max(session.base_price, 1e-6), 4),
                "unmet_demand_mwh": dispatch["unmet_demand_mwh"],
            },
            "ev_1": {
                "personality": session.personalities.get("ev_1", "balanced"),
                "soc_ratio": round(session.ev_storage_mwh / max(session.ev_storage_capacity_mwh, 1e-6), 4),
                "arbitrage_spread": round(spread, 3),
            },
        }

    def _resolve_dispatch_action(
        self,
        session: Session,
        market: Dict,
        observation: MarketObservation,
        personality: str = "balanced",
    ) -> DispatchAction:
        """Produce the controller's dispatch action when the caller supplied none.

        Uses the market's cleared volume, or the observed demand when the
        market result has no ``cleared_mwh`` entry.
        """
        controller = ReliabilityDispatchControlAgent(personality=personality)
        return controller.act(observation, cleared_mwh=float(market.get("cleared_mwh", observation.demand_mwh)))
|
|