"""
models.py — Typed Pydantic models for the USAR OpenEnv environment.

All observations, actions, and rewards are fully typed and validated.
These form the contract between the environment server and any agent.
"""

from __future__ import annotations

from enum import Enum
from typing import Dict, List, Optional

from pydantic import BaseModel, Field, field_validator


# ---------------------------------------------------------------------------
# Enumerations
# ---------------------------------------------------------------------------


class ActionType(str, Enum):
    """Every discrete action the incident commander can take in one step."""

    ASSIGN_TEAM = "assign_team"  # Assign a rescue team to a site
    RECALL_TEAM = "recall_team"  # Pull a team back to staging
    REQUEST_AIR = "request_air"  # Call helicopter support for one site
    DEPLOY_MEDKIT = "deploy_medkit"  # Airdrop medical supplies to a site
    TRIAGE_EVACUATE = "triage_evacuate"  # Evacuate critical survivors from site
    WAIT = "wait"  # Hold position / gather intel


class DebrisType(str, Enum):
    """Kind of debris at a collapse site, from fastest to slowest to clear."""

    LIGHT = "light"  # Collapsed interior walls — fast clearance
    MODERATE = "moderate"  # Partial floor/roof collapse — moderate time
    HEAVY = "heavy"  # Reinforced concrete — requires heavy equipment


class SiteStatus(str, Enum):
    """Lifecycle state of a single collapse site."""

    ACTIVE = "active"  # Survivors trapped, rescue ongoing or pending
    CLEARED = "cleared"  # All accessible survivors rescued
    COLLAPSED = "collapsed"  # Secondary collapse — site inaccessible
    EVACUATED = "evacuated"  # Site fully evacuated by air support


class TeamStatus(str, Enum):
    """Availability state of a single rescue team."""

    IDLE = "idle"  # At staging, ready for assignment
    DEPLOYED = "deployed"  # Currently working a site
    FATIGUED = "fatigued"  # Needs rest before next deployment
    AIRBORNE = "airborne"  # In helicopter — temporarily unavailable


# ---------------------------------------------------------------------------
# Sub-models
# ---------------------------------------------------------------------------


class SiteObservation(BaseModel):
    """State of a single collapse site as observed by the incident commander."""

    site_id: int = Field(..., description="Unique site index 0–7")
    name: str = Field(..., description="Human-readable site name")
    status: SiteStatus
    debris_type: DebrisType
    trapped_survivors: int = Field(..., ge=0, description="Confirmed trapped count")
    critical_count: int = Field(..., ge=0, description="Life-threatening injuries")
    survival_probability: float = Field(
        ..., ge=0.0, le=1.0, description="Current avg survival prob"
    )
    decay_rate: float = Field(
        ..., ge=0.0, le=1.0, description="Survival prob drop per step"
    )
    rescue_progress: float = Field(
        ..., ge=0.0, le=1.0, description="Fraction of site cleared"
    )
    assigned_team_id: Optional[int] = Field(
        None, description="Team currently working this site"
    )
    requires_air_support: bool = Field(
        False, description="Heavy debris requires air asset"
    )
    secondary_collapse_risk: float = Field(
        ..., ge=0.0, le=1.0, description="Prob of collapse next step"
    )
    distance_from_staging: float = Field(
        ..., ge=0.0, description="Travel time in steps"
    )
    survivors_rescued: int = Field(
        0, ge=0, description="Already rescued from this site"
    )


class TeamObservation(BaseModel):
    """State of a single rescue team."""

    team_id: int = Field(..., description="Unique team index 0–2")
    name: str
    status: TeamStatus
    current_site_id: Optional[int] = Field(
        None, description="Site being worked, None if idle"
    )
    fatigue_level: float = Field(
        ..., ge=0.0, le=1.0, description="0=fresh, 1=exhausted"
    )
    specialization: str = Field(
        ..., description="heavy_rescue | medical | swift_water"
    )
    steps_at_site: int = Field(0, ge=0, description="Steps spent at current site")
    efficiency_multiplier: float = Field(
        1.0, ge=0.1, le=2.0, description="Performance modifier"
    )


class ResourceState(BaseModel):
    """Shared resource pool available to the incident commander."""

    medkits_remaining: int = Field(..., ge=0)
    heavy_equipment_charges: int = Field(
        ..., ge=0, description="Uses of heavy machinery left"
    )
    air_support_available: bool = Field(
        ..., description="Helicopter currently available"
    )
    air_support_cooldown: int = Field(
        0, ge=0, description="Steps until air support returns"
    )
    communication_quality: float = Field(
        1.0, ge=0.0, le=1.0, description="1=clear, 0=blackout"
    )


class EventNotification(BaseModel):
    """A real-time event that occurred this step — aftershock, trapped comm, etc."""

    event_type: str
    # Defaults to None so events not tied to one site can omit it. Under
    # Pydantic v2 a bare ``Optional[int]`` with no default is *required*.
    site_id: Optional[int] = None
    description: str
    severity: float = Field(..., ge=0.0, le=1.0)


# ---------------------------------------------------------------------------
# Primary API models
# ---------------------------------------------------------------------------


class USARObservation(BaseModel):
    """
    Complete observable state returned after reset() or step().

    Reflects what a real incident commander would know.
    """

    step: int
    max_steps: int
    elapsed_hours: float = Field(..., description="Simulated hours since earthquake")
    sites: List[SiteObservation]
    teams: List[TeamObservation]
    resources: ResourceState
    total_survivors_rescued: int = Field(0, ge=0)
    total_survivors_lost: int = Field(0, ge=0)
    total_survivors_remaining: int = Field(0, ge=0)
    cumulative_reward: float
    events_this_step: List[EventNotification] = Field(default_factory=list)
    task_name: str
    scenario_description: str


class USARAction(BaseModel):
    """
    A single action submitted by the incident commander agent.

    Not all fields are required — they depend on action_type. This model only
    range-checks ``team_id`` and ``site_id`` when they are given. It does NOT
    enforce which fields a particular ``action_type`` needs; that check must
    happen wherever the action is executed.
    """

    action_type: ActionType
    team_id: Optional[int] = Field(
        None, description="Required for ASSIGN_TEAM / RECALL_TEAM"
    )
    site_id: Optional[int] = Field(
        None,
        description="Required for ASSIGN_TEAM / REQUEST_AIR / DEPLOY_MEDKIT / TRIAGE_EVACUATE",
    )
    justification: Optional[str] = Field(
        None, description="Agent's reasoning — used for logging"
    )

    @field_validator("team_id")
    @classmethod
    def validate_team_id(cls, v: Optional[int]) -> Optional[int]:
        """Reject a team index outside 0–2; ``None`` passes through unchanged.

        Raises:
            ValueError: if ``v`` is an int outside ``range(3)``.
        """
        if v is not None and v not in range(3):
            raise ValueError(f"team_id must be 0, 1, or 2; got {v}")
        return v

    @field_validator("site_id")
    @classmethod
    def validate_site_id(cls, v: Optional[int]) -> Optional[int]:
        """Reject a site index outside 0–7; ``None`` passes through unchanged.

        Raises:
            ValueError: if ``v`` is an int outside ``range(8)``.
        """
        if v is not None and v not in range(8):
            raise ValueError(f"site_id must be 0–7; got {v}")
        return v


class StepResult(BaseModel):
    """Full result of a step() call."""

    observation: USARObservation
    reward: float = Field(..., description="Immediate reward this step")
    done: bool
    info: Dict = Field(default_factory=dict)
    error: Optional[str] = None


class ResetRequest(BaseModel):
    """Parameters for resetting the environment to a task."""

    task: str = Field("single_site_rescue", description="Task name to run")
    seed: int = Field(42, description="Random seed for reproducibility")


class ResetResult(BaseModel):
    """Initial observation after a reset, with the task name and seed it used."""

    observation: USARObservation
    task: str
    seed: int


class GradeResult(BaseModel):
    """Final episode grading breakdown."""

    task: str
    score: float = Field(..., ge=0.0, le=1.0)
    survivors_rescued: int
    survivors_lost: int
    max_possible_rescued: int
    rescue_rate: float
    time_efficiency: float
    resource_efficiency: float
    critical_save_bonus: float
    penalty_total: float
    passed: bool
    threshold: float
    breakdown: Dict[str, float]
    narrative: str


class TaskInfo(BaseModel):
    """Metadata about a single task."""

    name: str
    difficulty: str
    description: str
    steps: int
    threshold: float
    sites: int
    teams: int
    scenario: str