# grid / smartgrid_mas / env.py
# Last commit: "Big update" (427a79e) -- committer avatar: Not-OmKar
import random
import uuid
from dataclasses import dataclass, field
from typing import Dict, Optional
from smartgrid_mas.engine.policies import (
adaptive_stackelberg_action,
heuristic_joint_action,
random_joint_action,
)
from smartgrid_mas.engine.control import ReliabilityDispatchControlAgent
from smartgrid_mas.engine.dynamics import evolve_grid
from smartgrid_mas.engine.ldu import enforce_dispatch
from smartgrid_mas.engine.market import clear_market
from smartgrid_mas.engine.reward import compute_reward
from smartgrid_mas.models import (
DispatchAction,
JointAction,
MarketObservation,
MarketReward,
ResetResponse,
StateResponse,
StepResponse,
)
from smartgrid_mas.tasks import TaskConfig, get_task, list_tasks
# Human-readable description of the action protocol. It is attached to every
# observation and to the reset response so clients know what to submit.
SCHEMA_INFO = " ".join(
    (
        "Provide a JointAction with supply and demand bids from multiple agents plus EV charge/discharge",
        "commands. Market clears bids first, then the Reliability Dispatch Control Agent proposes corrective dispatch,",
        "and the Physics-Constrained Safety Shield enforces physical feasibility and logs corrections.",
    )
)
@dataclass
class Session:
    """Mutable state of a single simulation episode.

    Holds the task configuration, the episode's random source, the current
    grid quantities, running episode totals and the event log. Every field
    except ``task`` and ``rng`` has a zero/neutral default; the starting values
    of an episode are supplied by whoever constructs the session.
    """

    # --- identity and progress ---
    task: TaskConfig
    rng: random.Random
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    step: int = 0
    done: bool = False

    # --- current grid quantities and prices ---
    demand_mwh: float = 0.0
    renewable_mwh: float = 0.0
    peaker_capacity_mwh: float = 0.0
    ev_storage_mwh: float = 0.0
    ev_storage_capacity_mwh: float = 0.0
    base_price: float = 0.0
    last_clearing_price: float = 0.0
    prior_gap: float = 0.0

    # --- running episode totals and history ---
    correction_count: int = 0
    infeasible_actions: int = 0
    total_demand_met: float = 0.0
    total_cost: float = 0.0
    reward_history: list = field(default_factory=list)
    event_log: list = field(default_factory=list)

    # --- regime flags ---
    shock_seen: bool = False
    contingency_seen: bool = False
    contingency_type: str = "none"
    operator_override_enabled: bool = False

    # --- forecasts and their errors ---
    forecast_demand_mwh: float = 0.0
    forecast_renewable_mwh: float = 0.0
    load_forecast_error_mwh: float = 0.0
    renewable_forecast_error_mwh: float = 0.0

    # --- carry-over from the previous dispatch ---
    previous_peaker_dispatch_mwh: float = 0.0
    previous_ev_discharge_mwh: float = 0.0
    peaker_online: bool = False
    contingency_peaker_multiplier: float = 1.0
    contingency_loss_multiplier: float = 1.0

    # --- reliability / emissions counters ---
    total_emissions_tco2: float = 0.0
    blackout_steps: int = 0
    reserve_commitment_events: int = 0
    emergency_dispatch_events: int = 0
    stability_events: int = 0
    peaker_activation_timer: int = 0

    # Agent id -> personality label.
    personalities: Dict[str, str] = field(default_factory=dict)

    def to_observation(self, hint: Optional[str] = None, error_message: Optional[str] = None) -> MarketObservation:
        """Build the ``MarketObservation`` describing the current state.

        Args:
            hint: Optional hint text, passed through unchanged.
            error_message: Optional error text, passed through unchanged.

        Returns:
            An observation whose energy and price quantities are rounded to
            three decimals and whose scarcity index is rounded to four.
        """
        if self.shock_seen:
            regime_text = "Shock regime active; renewable volatility is elevated"
        else:
            regime_text = "Normal regime; optimize demand satisfaction with low infeasibility"

        # Share of demand not covered by renewables, floored at zero; the
        # 1e-6 floor on the denominator avoids dividing by zero demand.
        net_shortfall = self.demand_mwh - self.renewable_mwh
        scarcity = max(0.0, net_shortfall / max(self.demand_mwh, 1e-6))

        # Observation field name -> value, all reported with 3 decimals.
        three_decimal_fields = {
            "demand_mwh": self.demand_mwh,
            "renewable_availability_mwh": self.renewable_mwh,
            "peaker_capacity_mwh": self.peaker_capacity_mwh,
            "ev_storage_mwh": self.ev_storage_mwh,
            "ev_storage_capacity_mwh": self.ev_storage_capacity_mwh,
            "last_clearing_price": self.last_clearing_price,
            "leader_price_signal": self.base_price,
            "forecast_demand_mwh": self.forecast_demand_mwh,
            "forecast_renewable_mwh": self.forecast_renewable_mwh,
            "load_forecast_error_mwh": self.load_forecast_error_mwh,
            "renewable_forecast_error_mwh": self.renewable_forecast_error_mwh,
        }
        rounded = {name: round(value, 3) for name, value in three_decimal_fields.items()}

        return MarketObservation(
            step=self.step,
            steps_taken=self.step,
            max_steps=self.task.max_steps,
            scarcity_index=round(scarcity, 4),
            shock_active=self.shock_seen,
            contingency_active=self.contingency_seen,
            contingency_type=self.contingency_type,
            operator_override_enabled=self.operator_override_enabled,
            public_signal=regime_text,
            schema_info=SCHEMA_INFO,
            hint=hint,
            error_message=error_message,
            **rounded,
        )
class SmartGridMarketEnv:
    """Session-based smart-grid market simulation environment.

    Every ``reset`` creates an independent ``Session`` that is kept in memory
    and addressed by its ``session_id``. Methods that accept
    ``session_id=None`` act on the most recently created session.

    Args:
        max_sessions: Optional upper bound on the number of sessions kept in
            memory. When set, the oldest sessions are discarded whenever a
            ``reset`` pushes the count above the bound. ``None`` (the default)
            keeps every session.

    Raises:
        ValueError: If ``max_sessions`` is given and is smaller than 1.
    """

    def __init__(self, max_sessions: Optional[int] = None):
        if max_sessions is not None and max_sessions < 1:
            raise ValueError("max_sessions must be a positive integer or None")
        self._max_sessions = max_sessions
        self._sessions: Dict[str, Session] = {}
        self._latest_session_id: Optional[str] = None

    def reset(self, task_id: str = "default", seed: Optional[int] = None) -> ResetResponse:
        """Start a new episode for ``task_id`` and make it the latest session.

        Args:
            task_id: Identifier resolved through ``get_task``.
            seed: Seed for the session's private ``random.Random``.

        Returns:
            A ``ResetResponse`` carrying the new session id and the initial
            observation (with the task hint attached).
        """
        task = get_task(task_id)
        rng = random.Random(seed)
        # Drawn in a fixed order so a given seed always yields the same map.
        personalities = {
            "renewable_1": rng.choice(["opportunistic", "balanced"]),
            "peaker_1": rng.choice(["greedy", "balanced", "risk_averse"]),
            "industrial_1": rng.choice(["risk_averse", "balanced"]),
            "ev_1": rng.choice(["balanced", "risk_averse"]),
        }
        session = Session(
            task=task,
            rng=rng,
            demand_mwh=task.initial_demand_mwh,
            renewable_mwh=task.initial_renewable_mwh,
            peaker_capacity_mwh=task.peaker_capacity_mwh,
            ev_storage_mwh=task.ev_storage_mwh,
            ev_storage_capacity_mwh=task.ev_storage_capacity_mwh,
            base_price=task.base_price_usd_per_mwh,
            last_clearing_price=task.base_price_usd_per_mwh,
            personalities=personalities,
            forecast_demand_mwh=task.initial_demand_mwh,
            forecast_renewable_mwh=task.initial_renewable_mwh,
        )
        self._sessions[session.session_id] = session
        self._latest_session_id = session.session_id
        self._evict_oldest_sessions()
        return ResetResponse(
            session_id=session.session_id,
            task_id=task.task_id,
            task_description=task.description,
            schema_info=SCHEMA_INFO,
            steps_taken=0,
            observation=session.to_observation(hint=task.hint),
        )

    def step(
        self,
        action: JointAction,
        session_id: Optional[str] = None,
        dispatch_action: Optional[DispatchAction] = None,
    ) -> StepResponse:
        """Advance one session by a single market/dispatch/dynamics step.

        The joint action's bids are cleared, a dispatcher action is applied
        (the supplied one, or one produced by the built-in controller),
        ``enforce_dispatch`` resolves the physical dispatch, the reward is
        computed, and finally ``evolve_grid`` produces the next grid state.

        Args:
            action: Bids plus EV charge/discharge commands. Replaced by a
                risk-averse heuristic action while operator override is on.
            session_id: Target session; ``None`` selects the latest session.
            dispatch_action: Optional dispatcher command. When ``None`` the
                controller is asked for one.

        Returns:
            A ``StepResponse`` with the next observation, the reward, the
            ``done`` flag and an ``info`` dict (market, dispatch, dynamics,
            private agent views and a running summary).

        Raises:
            KeyError: If no matching session exists.
        """
        session = self._get_session(session_id)
        if session.done:
            return self._finished_episode_response(session)

        applied_action = action
        if session.operator_override_enabled:
            applied_action = heuristic_joint_action(session.to_observation(), personality="risk_averse")

        market = clear_market(applied_action.bids, leader_price_signal=session.base_price)
        # Explicit None check: a supplied action must never be replaced just
        # because the object happens to evaluate as falsy.
        if dispatch_action is None:
            dispatch_action = self._resolve_dispatch_action(session, market, session.to_observation())

        effective_peaker_capacity = self._effective_peaker_capacity(session, market, dispatch_action)

        # A non-negative storage command adds discharge, a negative one adds charge.
        applied_ev_charge = applied_action.ev_charge_mwh
        applied_ev_discharge = applied_action.ev_discharge_mwh
        storage_command = dispatch_action.storage_dispatch_mwh
        if storage_command >= 0.0:
            applied_ev_discharge += storage_command
        else:
            applied_ev_charge += abs(storage_command)

        # The dispatcher may shift the cleared volume; never below zero.
        adjusted_market = dict(market)
        adjusted_market["cleared_mwh"] = max(
            0.0,
            adjusted_market.get("cleared_mwh", 0.0) + dispatch_action.corrective_redispatch_mwh,
        )
        adjusted_market["dispatcher_action"] = dispatch_action.model_dump()

        dispatch, next_storage = enforce_dispatch(
            market_result=adjusted_market,
            demand_mwh=session.demand_mwh,
            renewable_available_mwh=session.renewable_mwh,
            peaker_capacity_mwh=effective_peaker_capacity,
            ev_storage_mwh=session.ev_storage_mwh,
            ev_storage_capacity_mwh=session.ev_storage_capacity_mwh,
            ev_charge_mwh=applied_ev_charge,
            ev_discharge_mwh=applied_ev_discharge,
            reserve_margin_ratio=session.task.reserve_margin_ratio,
            reserve_commitment_threshold_ratio=session.task.reserve_commitment_threshold_ratio,
            peaker_ramp_limit_mwh=session.task.peaker_ramp_limit_mwh,
            ev_ramp_limit_mwh=session.task.ev_ramp_limit_mwh,
            previous_peaker_dispatch_mwh=session.previous_peaker_dispatch_mwh,
            previous_ev_discharge_mwh=session.previous_ev_discharge_mwh,
            previous_peaker_online=session.peaker_online,
            peaker_startup_cost_usd=session.task.peaker_startup_cost_usd,
            peaker_emission_factor_tco2_per_mwh=session.task.peaker_emission_factor_tco2_per_mwh,
            transmission_loss_multiplier=session.contingency_loss_multiplier,
            carbon_price_usd_per_tco2=session.task.carbon_price_usd_per_tco2,
            enable_reserve_logic=session.task.enable_reserve_logic,
            enable_ramp_limits=session.task.enable_ramp_limits,
            enable_startup_emissions=session.task.enable_startup_emissions,
        )

        # A missing/zero clearing price falls back to the leader's base price.
        clearing_price = market["clearing_price"] or session.base_price
        reward = compute_reward(
            dispatch=dispatch,
            clearing_price=clearing_price,
            demand_mwh=session.demand_mwh,
            prior_gap=session.prior_gap,
            carbon_price_usd_per_tco2=session.task.carbon_price_usd_per_tco2,
        )
        self._record_dispatch_outcome(session, dispatch, next_storage, clearing_price, reward)

        # Built before evolve_grid on purpose: the views describe the state
        # that was just dispatched, and both calls draw from session.rng, so
        # their order determines the random sequence.
        private_views = self._build_private_agent_views(session, market, dispatch)
        next_demand, next_renewable, next_price, dyn_info = evolve_grid(
            demand_mwh=session.demand_mwh,
            renewable_mwh=session.renewable_mwh,
            base_price_usd_per_mwh=session.base_price,
            step=session.step,
            task=session.task,
            rng=session.rng,
        )
        self._apply_dynamics(session, next_demand, next_renewable, next_price, dyn_info)

        session.event_log.append(
            {
                "step": session.step,
                "market": market,
                "dispatch_action": dispatch_action.model_dump(),
                "dispatch": dispatch,
                "reward": reward.model_dump(),
                "dynamics": dyn_info,
                "agent_private_views": private_views,
            }
        )

        done = session.step >= session.task.max_steps
        session.done = done
        info = {
            "market": market,
            "dispatch_action": dispatch_action.model_dump(),
            "dispatch": dispatch,
            "dynamics": dyn_info,
            "agent_private_views": private_views,
            "summary": self._episode_summary(session, market),
        }
        return StepResponse(
            observation=session.to_observation(),
            reward=reward,
            done=done,
            truncated=False,
            info=info,
        )

    def policy_action(
        self,
        policy: str = "adaptive",
        personality: str = "balanced",
        session_id: Optional[str] = None,
    ) -> JointAction:
        """Return a joint action produced by one of the built-in policies.

        ``"random"`` and ``"heuristic"`` select the matching policy; any other
        value falls through to the adaptive Stackelberg policy.

        Raises:
            KeyError: If no matching session exists.
        """
        session = self._get_session(session_id)
        obs = session.to_observation()
        if policy == "random":
            return random_joint_action(obs, session.rng)
        if policy == "heuristic":
            return heuristic_joint_action(obs, personality=personality)
        return adaptive_stackelberg_action(obs, personality=personality)

    def dispatch_action(
        self,
        personality: str = "balanced",
        session_id: Optional[str] = None,
        cleared_mwh: Optional[float] = None,
    ) -> DispatchAction:
        """Ask the dispatch controller for an action on the current state.

        Args:
            personality: Personality handed to the controller.
            session_id: Target session; ``None`` selects the latest session.
            cleared_mwh: Cleared volume to plan against; defaults to the
                observed demand when omitted.

        Raises:
            KeyError: If no matching session exists.
        """
        session = self._get_session(session_id)
        obs = session.to_observation()
        controller = ReliabilityDispatchControlAgent(personality=personality)
        target_mwh = cleared_mwh if cleared_mwh is not None else obs.demand_mwh
        return controller.act(obs, cleared_mwh=float(target_mwh))

    def state(self, session_id: Optional[str] = None) -> StateResponse:
        """Return task id, step count, done flag and current observation."""
        session = self._get_session(session_id)
        return StateResponse(
            current_task_id=session.task.task_id,
            steps_taken=session.step,
            episode_done=session.done,
            observation=session.to_observation(),
        )

    def events(self, session_id: Optional[str] = None) -> Dict:
        """Return the session id together with its 50 most recent events."""
        session = self._get_session(session_id)
        return {"session_id": session.session_id, "events": session.event_log[-50:]}

    def inject_shock(self, session_id: Optional[str] = None, renewable_drop_mwh: float = 20.0) -> Dict:
        """Manually reduce renewable availability and mark the shock regime.

        Negative drops are treated as zero and availability never goes below
        zero. The shock is appended to the session's event log.
        """
        session = self._get_session(session_id)
        drop_mwh = max(0.0, renewable_drop_mwh)
        before = session.renewable_mwh
        session.renewable_mwh = max(0.0, session.renewable_mwh - drop_mwh)
        session.shock_seen = True
        event = {
            "step": session.step,
            "type": "manual_shock",
            "renewable_before_mwh": round(before, 3),
            "renewable_after_mwh": round(session.renewable_mwh, 3),
            "drop_mwh": round(drop_mwh, 3),
        }
        session.event_log.append(event)
        return {
            "session_id": session.session_id,
            "shock_event": event,
            "observation": session.to_observation(),
        }

    def get_schema(self) -> Dict:
        """Return JSON schemas for the action/observation/reward models plus the task list."""
        return {
            "action_schema": JointAction.model_json_schema(),
            "dispatch_action_schema": DispatchAction.model_json_schema(),
            "observation_schema": MarketObservation.model_json_schema(),
            "reward_schema": MarketReward.model_json_schema(),
            "tasks": list_tasks(),
            "notes": "Hybrid Theme 1+2+3.1 baseline implementation with the Physics-Constrained Safety Shield as core physical layer",
        }

    def set_operator_override(self, enabled: bool, session_id: Optional[str] = None) -> Dict:
        """Switch operator override on or off and log the change.

        While enabled, ``step`` ignores the submitted joint action and uses a
        risk-averse heuristic action instead.
        """
        session = self._get_session(session_id)
        session.operator_override_enabled = bool(enabled)
        event = {
            "step": session.step,
            "type": "operator_override",
            "enabled": session.operator_override_enabled,
        }
        session.event_log.append(event)
        return {
            "session_id": session.session_id,
            "operator_override_enabled": session.operator_override_enabled,
            "event": event,
        }

    def _get_session(self, session_id: Optional[str]) -> Session:
        """Look up a session, falling back to the latest one for a falsy id.

        Raises:
            KeyError: If there is no latest session or the id is unknown.
        """
        sid = session_id or self._latest_session_id
        if sid is None or sid not in self._sessions:
            raise KeyError("No active session. Call /reset first.")
        return self._sessions[sid]

    def _evict_oldest_sessions(self) -> None:
        """Drop the oldest sessions until the configured bound is respected."""
        if self._max_sessions is None:
            return
        # Dicts iterate in insertion order, so the first key is the oldest.
        while len(self._sessions) > self._max_sessions:
            oldest_id = next(iter(self._sessions))
            del self._sessions[oldest_id]

    def _finished_episode_response(self, session: Session) -> StepResponse:
        """Build the response returned when ``step`` is called after the episode ended."""
        idle_dispatch = {
            "delivered_supply_mwh": 0.0,
            "unmet_demand_mwh": 0.0,
            "oversupply_mwh": 0.0,
            "correction_count": 0,
            "storage_loss_mwh": 0.0,
            "renewable_dispatch_mwh": 0.0,
        }
        return StepResponse(
            observation=session.to_observation(error_message="Episode finished. Call reset."),
            reward=compute_reward(
                dispatch=idle_dispatch,
                clearing_price=session.last_clearing_price,
                demand_mwh=max(1.0, session.demand_mwh),
                prior_gap=0.0,
            ),
            done=True,
            truncated=False,
            info={"error": "episode_done"},
        )

    def _effective_peaker_capacity(self, session: Session, market: Dict, dispatch_action: DispatchAction) -> float:
        """Return the peaker capacity usable this step and update the start-up timer."""
        capacity = session.peaker_capacity_mwh * session.contingency_peaker_multiplier
        capacity += dispatch_action.reserve_activation_mwh + dispatch_action.peaker_adjustment_mwh
        # A large negative adjustment must not turn into negative capacity.
        capacity = max(0.0, capacity)

        delay_steps = session.task.peaker_activation_delay_steps
        expected_residual = max(0.0, market.get("cleared_mwh", 0.0) - session.renewable_mwh)
        # NOTE(review): once the countdown reaches zero with the peaker still
        # offline, this condition holds again on the next step that has
        # residual demand and re-arms the delay. Unless enforce_dispatch can
        # report the peaker online with zero capacity, only the dispatcher
        # override below ever ends the delay -- confirm this is intended.
        if (
            delay_steps > 0
            and not session.peaker_online
            and session.peaker_activation_timer == 0
            and expected_residual > 0.0
        ):
            session.peaker_activation_timer = delay_steps
            session.event_log.append(
                {
                    "step": session.step,
                    "type": "peaker_startup_delay",
                    "delay_steps": delay_steps,
                }
            )

        # Any positive dispatcher command cancels a pending start-up delay.
        dispatch_override_active = (
            dispatch_action.reserve_activation_mwh > 0.0
            or dispatch_action.peaker_adjustment_mwh > 0.0
            or dispatch_action.storage_dispatch_mwh > 0.0
            or dispatch_action.corrective_redispatch_mwh > 0.0
        )
        if dispatch_override_active and session.peaker_activation_timer > 0:
            session.peaker_activation_timer = 0

        if session.peaker_activation_timer > 0:
            # Still starting up: no peaker capacity this step.
            session.peaker_activation_timer -= 1
            return 0.0
        return capacity

    def _record_dispatch_outcome(
        self,
        session: Session,
        dispatch: Dict,
        next_storage: float,
        clearing_price: float,
        reward: MarketReward,
    ) -> None:
        """Fold one step's dispatch result and reward into the session totals."""
        delivered = dispatch["delivered_supply_mwh"]
        session.step += 1
        session.ev_storage_mwh = next_storage
        session.last_clearing_price = clearing_price
        session.prior_gap = delivered - session.demand_mwh
        session.previous_peaker_dispatch_mwh = dispatch.get("peaker_dispatch_mwh", 0.0)
        session.previous_ev_discharge_mwh = dispatch.get("ev_discharge_mwh", 0.0)
        session.peaker_online = bool(dispatch.get("peaker_online", False))

        session.correction_count += dispatch["correction_count"]
        if dispatch["correction_count"] > 0:
            session.infeasible_actions += 1

        session.total_demand_met += min(session.demand_mwh, delivered)
        energy_cost = delivered * session.last_clearing_price
        session.total_cost += (
            energy_cost + dispatch.get("startup_cost_usd", 0.0) + dispatch.get("emissions_cost_usd", 0.0)
        )
        session.total_emissions_tco2 += dispatch.get("emissions_tco2", 0.0)

        if dispatch["unmet_demand_mwh"] > 0.0:
            session.blackout_steps += 1
        if dispatch.get("reserve_commitment_active", False):
            session.reserve_commitment_events += 1
        if dispatch.get("emergency_dispatch_triggered", False):
            session.emergency_dispatch_events += 1
        if dispatch.get("stability_risk_index", 0.0) >= 0.45:
            session.stability_events += 1
        session.reward_history.append(reward.score)

    def _apply_dynamics(
        self,
        session: Session,
        next_demand: float,
        next_renewable: float,
        next_price: float,
        dyn_info: Dict,
    ) -> None:
        """Copy the output of ``evolve_grid`` into the session state."""
        session.demand_mwh = next_demand
        session.renewable_mwh = next_renewable
        session.base_price = next_price
        # The "seen" flags latch: once set they stay set for the episode.
        session.shock_seen = session.shock_seen or dyn_info["shock_active"]
        session.contingency_seen = session.contingency_seen or dyn_info.get("contingency_active", False)
        session.contingency_type = dyn_info.get("contingency_type", "none")
        session.forecast_demand_mwh = dyn_info.get("forecast_demand_mwh", session.demand_mwh)
        session.forecast_renewable_mwh = dyn_info.get("forecast_renewable_mwh", session.renewable_mwh)
        session.load_forecast_error_mwh = dyn_info.get("load_forecast_error_mwh", 0.0)
        session.renewable_forecast_error_mwh = dyn_info.get("renewable_forecast_error_mwh", 0.0)
        session.contingency_peaker_multiplier = dyn_info.get("peaker_capacity_multiplier", 1.0)
        session.contingency_loss_multiplier = dyn_info.get("transmission_loss_multiplier", 1.0)

    def _episode_summary(self, session: Session, market: Dict) -> Dict:
        """Return the running episode summary placed in ``info["summary"]``."""
        return {
            "avg_reward": round(sum(session.reward_history) / len(session.reward_history), 4),
            "total_demand_met_mwh": round(session.total_demand_met, 3),
            "total_cost_usd": round(session.total_cost, 3),
            "total_emissions_tco2": round(session.total_emissions_tco2, 4),
            "blackout_steps": session.blackout_steps,
            "infeasible_actions": session.infeasible_actions,
            "ldu_corrections": session.correction_count,
            "reserve_commitment_events": session.reserve_commitment_events,
            "emergency_dispatch_events": session.emergency_dispatch_events,
            "stability_events": session.stability_events,
            "leader_adjusted_bids": market["leader_adjusted_bids"],
            "personality_map": session.personalities,
            "operator_override_enabled": session.operator_override_enabled,
        }

    def _build_private_agent_views(self, session: Session, market: Dict, dispatch: Dict) -> Dict[str, Dict]:
        """Build the per-agent private signals attached to each step.

        Draws one value from ``session.rng`` (the renewable forecast bias).
        """
        scarcity = max(0.0, (session.demand_mwh - session.renewable_mwh) / max(session.demand_mwh, 1e-6))
        spread = max(0.0, session.base_price - session.last_clearing_price)
        # Same fallback as step(): a missing, None or zero clearing price is
        # replaced by the base price instead of raising or skewing the signals.
        clearing_price = market.get("clearing_price") or session.base_price
        return {
            "renewable_1": {
                "personality": session.personalities.get("renewable_1", "balanced"),
                "curtailment_risk": round(max(0.0, session.renewable_mwh - market.get("cleared_mwh", 0.0)), 3),
                "forecast_bias": round(session.rng.uniform(-3.0, 3.0), 3),
            },
            "peaker_1": {
                "personality": session.personalities.get("peaker_1", "balanced"),
                "scarcity_index": round(scarcity, 4),
                "margin_signal": round(clearing_price - 42.0, 3),
            },
            "industrial_1": {
                "personality": session.personalities.get("industrial_1", "balanced"),
                "budget_pressure": round(clearing_price / max(session.base_price, 1e-6), 4),
                "unmet_demand_mwh": dispatch["unmet_demand_mwh"],
            },
            "ev_1": {
                "personality": session.personalities.get("ev_1", "balanced"),
                "soc_ratio": round(session.ev_storage_mwh / max(session.ev_storage_capacity_mwh, 1e-6), 4),
                "arbitrage_spread": round(spread, 3),
            },
        }

    def _resolve_dispatch_action(
        self,
        session: Session,
        market: Dict,
        observation: MarketObservation,
        personality: str = "balanced",
    ) -> DispatchAction:
        """Ask the controller for a dispatch action against the cleared volume.

        Falls back to the observed demand when the market result carries no
        ``cleared_mwh`` entry.
        """
        controller = ReliabilityDispatchControlAgent(personality=personality)
        return controller.act(observation, cleared_mwh=float(market.get("cleared_mwh", observation.demand_mwh)))