# grid2op-openenv / models.py
# Sidharth1743's picture
# tasks refined
# 54f256c
from __future__ import annotations
from typing import Any, Dict, List, Literal
from openenv.core.env_server.types import Action, Observation, State
from pydantic import BaseModel, Field
# Closed set of task identifiers accepted wherever a model declares `TaskId`.
TaskId = Literal[
    "single_fault",
    "n_minus_1",
    "cascade_prevent",
    "multi_stage_cascade",
]
# Accepted values for `BaselineRequest.scenario_mode`.
ScenarioMode = Literal["curriculum", "benchmark"]
class GridAction(Action):
    """JSON-serializable subset of Grid2Op actions."""

    # Every field declared here has a default, so none of them is required
    # by this class itself.

    # Line switching commands, keyed by line id.
    line_set: Dict[int, int] = Field(
        description="Map line id to status. Use -1 to disconnect and 1 to reconnect.",
        default_factory=dict,
    )

    # Generator redispatch commands, keyed by generator id.
    redispatch: Dict[int, float] = Field(
        description="Map generator id to redispatch delta in MW.",
        default_factory=dict,
    )

    # NOTE(review): the precedence stated in the description is not enforced
    # by this model; presumably the code that converts it into a Grid2Op
    # action applies it - confirm there.
    do_nothing: bool = Field(
        description="When true, ignore other fields and apply the native no-op action.",
        default=False,
    )
class GridObservation(Observation):
    """Typed subset of the Grid2Op observation surface."""

    # Every field below defaults to an empty list.
    # NOTE(review): the per-field meanings are inferred from the Grid2Op
    # observation attributes with the same names - confirm against the code
    # that populates this model.

    # presumably per-line loading ratio
    rho: List[float] = Field(default_factory=list)

    # presumably active power per generator / per load
    gen_p: List[float] = Field(default_factory=list)
    load_p: List[float] = Field(default_factory=list)

    # presumably True where a line is connected
    line_status: List[bool] = Field(default_factory=list)

    # presumably the per-line count of consecutive overflow steps
    timestep_overflow: List[int] = Field(default_factory=list)

    # Free-form entries; their keys are not constrained by this model.
    sensitivity_guidance: List[Dict[str, Any]] = Field(default_factory=list)
class EpisodeStepLog(BaseModel):
    """Structured per-step trace used by graders and debugging."""

    # --- Required: step identity and outcome -----------------------------
    # These six fields have no default and must always be supplied.
    step: int
    task_id: TaskId
    reward: float
    raw_reward: float
    done: bool
    max_rho: float

    # --- Optional per-step metrics ---------------------------------------
    # NOTE(review): the section headings from here on are a reading aid
    # derived from the field names; the exact semantics are defined by the
    # code that writes these logs.
    redispatch_mw: float = 0.0
    action_penalty: float = 0.0
    n1_security_score: float = 0.0
    reconnect_successful: bool = False

    # --- Stage / island bookkeeping --------------------------------------
    stage_index: int = 1
    steps_to_stage_boundary: int = 0
    available_load_ratio: float = 1.0
    available_island_ratio: float = 1.0
    stage_boundary_assessed: bool = False
    majority_islands_available: bool = False

    # --- Line loading summary --------------------------------------------
    overloaded_line_ids: List[int] = Field(default_factory=list)
    single_fault_target_threshold: float = 0.8
    all_lines_below_target: bool = False
    all_lines_below_80: bool = False
    all_lines_below_90: bool = False
    all_lines_below_100: bool = False

    # --- Topology --------------------------------------------------------
    disconnected_lines: List[int] = Field(default_factory=list)
    timestep_overflow: List[int] = Field(default_factory=list)
    safe_line_ratio: float = 0.0
    topology_change_count: int = 0
    auto_trip_detected: bool = False

    # --- Failure flags ---------------------------------------------------
    invalid_action: bool = False
    invalid_action_reason: str | None = None
    convergence_failed: bool = False

    # Free-form dict describing the action taken at this step; its keys are
    # not constrained by this model.
    action: Dict[str, Any] = Field(default_factory=dict)
class GridState(State):
    """Environment state for the current Grid2Op episode."""

    # Every field declared here has a default.

    # Environment and task selection.
    env_name: str = "l2rpn_case14_sandbox"
    task_id: TaskId = "single_fault"
    max_steps: int = 0

    # presumably the line and generator counts of the loaded grid - TODO confirm
    n_line: int = 0
    n_gen: int = 0

    # Most recent step outcome.
    last_reward: float = 0.0
    done: bool = False

    # Accumulated per-step records and free-form scenario details.
    episode_log: List[EpisodeStepLog] = Field(default_factory=list)
    scenario_metadata: Dict[str, Any] = Field(default_factory=dict)
class TaskInfo(BaseModel):
    # Descriptor for a single task: identifier, difficulty tier, description
    # text and step budget. All four fields are required.
    # Written as a comment rather than a docstring because pydantic publishes
    # a class docstring as the model's JSON-schema description.

    task_id: TaskId
    difficulty: Literal["easy", "medium", "hard"]
    description: str
    max_steps: int
class TaskListResponse(BaseModel):
    # Response body pairing a list of TaskInfo entries with an action schema.
    # Both fields are required.
    # Written as a comment rather than a docstring because pydantic publishes
    # a class docstring as the model's JSON-schema description.

    tasks: List[TaskInfo]

    # Free-form dict - presumably the JSON schema of GridAction; TODO confirm
    # against the code that builds this response.
    action_schema: Dict[str, Any]
class GraderRequest(BaseModel):
    # Request body carrying a task id and the per-step episode log to be
    # graded. Both fields are required.
    # Written as a comment rather than a docstring because pydantic publishes
    # a class docstring as the model's JSON-schema description.

    task_id: TaskId
    episode_log: List[EpisodeStepLog]
class GraderResponse(BaseModel):
    # Response body carrying a task id and a score. Both fields are required.
    # Written as a comment rather than a docstring because pydantic publishes
    # a class docstring as the model's JSON-schema description.

    task_id: TaskId

    # No range constraint is declared on the score in this model.
    score: float
class BaselineRequest(BaseModel):
    # Parameters for a baseline run; every field has a default, so an empty
    # request body is valid.
    # Written as a comment rather than a docstring because pydantic publishes
    # a class docstring as the model's JSON-schema description.

    # Model identifier and generation length (max_tokens must be >= 1).
    model: str = "Qwen/Qwen3.5-9B"
    max_tokens: int = Field(300, ge=1)

    # Sampling settings. No range validation is declared for these here.
    # NOTE(review): names match common LLM sampling parameters - confirm how
    # the consumer forwards them.
    temperature: float = 0.7
    top_p: float = 0.8
    presence_penalty: float = 1.5
    top_k: int = 20
    min_p: float = 0.0
    repetition_penalty: float = 1.0
    enable_thinking: bool = False

    # Seed range: num_seeds must be >= 1 and seed_start must be >= 0.
    num_seeds: int = Field(5, ge=1)
    seed_start: int = Field(0, ge=0)

    # One of the ScenarioMode literals.
    scenario_mode: ScenarioMode = "benchmark"
class BaselineScores(BaseModel):
    # Per-task results for one model. All three fields are required, and
    # both maps are keyed by TaskId.
    # Written as a comment rather than a docstring because pydantic publishes
    # a class docstring as the model's JSON-schema description.

    model: str
    scores: Dict[TaskId, float]
    episode_lengths: Dict[TaskId, int]
class SimulationRequest(BaseModel):
    # Request body naming an episode and listing GridAction candidates.
    # Both fields are required.
    # Written as a comment rather than a docstring because pydantic publishes
    # a class docstring as the model's JSON-schema description.

    episode_id: str
    actions: List[GridAction]
class SimulationResult(BaseModel):
    # Outcome record for one GridAction.
    # Written as a comment rather than a docstring because pydantic publishes
    # a class docstring as the model's JSON-schema description.

    # Required: the action and its headline outcome.
    action: GridAction
    max_rho: float
    done: bool
    simulated_reward: float

    # Optional details; list fields default to empty.
    overloaded_line_ids: List[int] = Field(default_factory=list)
    disconnected_lines: List[int] = Field(default_factory=list)
    convergence_failed: bool = False

    # presumably string forms of exceptions raised during simulation - TODO confirm
    exceptions: List[str] = Field(default_factory=list)

    # Free-form payload; its keys are not constrained by this model.
    raw_result: Dict[str, Any] = Field(default_factory=dict)
class SimulationResponse(BaseModel):
    # Response body naming an episode and listing SimulationResult entries.
    # Both fields are required.
    # Written as a comment rather than a docstring because pydantic publishes
    # a class docstring as the model's JSON-schema description.

    episode_id: str
    results: List[SimulationResult]
class PlanningContextRequest(BaseModel):
    # Request body whose only (required) field is the episode id.
    # Written as a comment rather than a docstring because pydantic publishes
    # a class docstring as the model's JSON-schema description.

    episode_id: str
class RedispatchGeneratorContext(BaseModel):
    # Redispatch details for a single generator. Every field except
    # allowed_deltas is required.
    # Written as a comment rather than a docstring because pydantic publishes
    # a class docstring as the model's JSON-schema description.

    gen_id: int

    # presumably the generator's current active power in MW - TODO confirm
    p_mw: float

    # Ramp limits.
    max_ramp_up: float
    max_ramp_down: float

    # Bounds on the redispatch delta.
    allowed_delta_min: float
    allowed_delta_max: float

    # Defaults to an empty list.
    # NOTE(review): presumably a discrete menu of deltas within the bounds
    # above - confirm against the producer.
    allowed_deltas: List[float] = Field(default_factory=list)
class PlanningContextResponse(BaseModel):
    # Response body for a planning-context lookup. Only episode_id is
    # required; the remaining fields default to empty containers.
    # Written as a comment rather than a docstring because pydantic publishes
    # a class docstring as the model's JSON-schema description.

    episode_id: str

    # Free-form payload; its keys are not constrained by this model.
    graph_intelligence: Dict[str, Any] = Field(default_factory=dict)

    # Two similarly named fields: the first holds bare integers (presumably
    # generator ids - TODO confirm), the second holds one
    # RedispatchGeneratorContext per entry.
    redispatchable_generators: List[int] = Field(default_factory=list)
    redispatch_generators: List[RedispatchGeneratorContext] = Field(
        default_factory=list
    )