JaamCTRL-OpenEnv / env /models.py
Akshara
Fresh start: Clean repo without large file history - ready for HF Spaces
f4f43f0
"""
env/models.py
───────────────────────────────────────────────────────────────────────────────
OpenEnv-compliant Pydantic models for JaamCTRL.
Defines:
- Observation: typed, structured observation from the environment
- Action: agent action (multi-discrete phase selection)
- Reward: reward breakdown with detailed metrics
───────────────────────────────────────────────────────────────────────────────
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class ObservationData(BaseModel):
"""
Structured environment observation.
All vectors are zero-padded to size 3 (max intersections).
Inactive intersections have zeros in their slots.
"""
queue_lengths: List[List[float]] = Field(
..., description="Queue length per lane (3 intersections × 4 lanes)"
)
current_phase: List[int] = Field(
..., description="Current traffic light phase (0-3) per intersection"
)
phase_elapsed: List[float] = Field(
..., description="Seconds elapsed in current phase per intersection"
)
probe_density: List[List[float]] = Field(
..., description="GPS probe vehicle density in 8 spatial bins per intersection"
)
incident_active: List[bool] = Field(
..., description="Whether an incident is active per intersection"
)
time_of_day_norm: float = Field(
..., description="Normalized time of day (0.0 to 1.0, 24-hour cycle)"
)
class Config:
json_schema_extra = {
"example": {
"queue_lengths": [[10.0, 8.0, 5.0, 3.0], [0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0]],
"current_phase": [0, 2, 0],
"phase_elapsed": [7.5, 3.2, 0.0],
"probe_density": [[0.5, 0.6, 0.7, 0.4, 0.3, 0.2, 0.1, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
"incident_active": [False, False, False],
"time_of_day_norm": 0.5,
}
}
class ActionData(BaseModel):
"""
Agent action: one phase (0-3) per active traffic light.
- Task 1: 1 phase (Barakhamba only)
- Task 2: 2 phases (Barakhamba + CP Core)
- Task 3: 3 phases (full corridor)
"""
phases: List[int] = Field(
..., description="Phase selection (0-3) per traffic light, zero-padded to size 3"
)
def validate_phases(self):
"""All phase values must be 0-3."""
for phase in self.phases:
assert 0 <= phase <= 3, f"Invalid phase {phase}: must be 0-3"
class Config:
json_schema_extra = {
"example": {"phases": [0, 2, 1]}
}
class RewardBreakdown(BaseModel):
"""
Detailed breakdown of reward components for transparency.
"""
r_queue: float = Field(..., description="Penalty from total queue length")
r_throughput: float = Field(..., description="Reward from vehicles exiting")
r_stops: float = Field(..., description="Penalty from new stop events")
r_long_wait: float = Field(..., description="Hard penalty for excessive waits")
r_green_wave: float = Field(..., description="Bonus for green-wave coordination")
r_overflow: float = Field(..., description="Penalty for lane overflow")
r_incident: float = Field(..., description="Bonus when incident queue clears")
r_thrash: float = Field(..., description="Penalty for rapid phase switching")
class Config:
json_schema_extra = {
"example": {
"r_queue": -2.5,
"r_throughput": 1.2,
"r_stops": -0.3,
"r_long_wait": 0.0,
"r_green_wave": 0.5,
"r_overflow": 0.0,
"r_incident": 0.0,
"r_thrash": 0.0,
}
}
class RewardData(BaseModel):
"""
Final reward and breakdown.
"""
total: float = Field(..., description="Total reward for this step (clipped to [-10, +5])")
breakdown: RewardBreakdown = Field(..., description="Reward component breakdown")
class Config:
json_schema_extra = {
"example": {
"total": -0.1,
"breakdown": {
"r_queue": -2.5,
"r_throughput": 1.2,
"r_stops": -0.3,
"r_long_wait": 0.0,
"r_green_wave": 0.5,
"r_overflow": 0.0,
"r_incident": 0.0,
"r_thrash": 0.0,
}
}
}
class StepOutput(BaseModel):
"""
Complete output of a single environment step.
"""
observation: ObservationData
reward: float
done: bool
info: Dict[str, Any] = Field(
default_factory=dict,
description="Additional info: episode_metrics, metrics summary, etc."
)
class TaskMetrics(BaseModel):
"""
Final task performance metrics for grading.
"""
avg_delay_reduction: float = Field(
..., description="% reduction in avg vehicle delay vs. fixed-time baseline"
)
avg_throughput_improvement: float = Field(
..., description="% improvement in vehicles exiting intersections"
)
overflow_events: int = Field(
..., description="Number of lane overflow events (lanes exceeded capacity)"
)
total_reward: float = Field(
..., description="Accumulated reward over full episode"
)
episode_length: int = Field(
..., description="Number of steps completed"
)
class Config:
json_schema_extra = {
"example": {
"avg_delay_reduction": 28.5,
"avg_throughput_improvement": 15.2,
"overflow_events": 0,
"total_reward": 125.3,
"episode_length": 300,
}
}