Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from typing import Any | |
| from openenv.core.rubrics import Rubric, WeightedSum | |
| from .sim.world import World | |
@dataclass
class ScoreBreakdown:
    """Per-dimension grading result together with the combined total.

    Declared as a dataclass so that it can be built from keyword arguments
    (which is how ``grade`` constructs it). Without the decorator the
    annotated fields below generate no ``__init__`` and any keyword
    construction raises ``TypeError``.
    """

    throughput: float
    emergency: float
    fairness: float
    efficiency: float
    planning: float
    safety: float
    total: float

    def as_dict(self) -> dict[str, float]:
        """Return all scores rounded to 4 decimal places.

        Returns:
            A dict keyed by field name, in declaration order
            (the six dimensions first, ``total`` last).
        """
        field_names = (
            "throughput",
            "emergency",
            "fairness",
            "efficiency",
            "planning",
            "safety",
            "total",
        )
        return {name: round(getattr(self, name), 4) for name in field_names}
class WorldRubric(Rubric):
    """Base class for rubrics that grade a ``World`` directly.

    Subclasses implement :meth:`score`; :meth:`forward` adapts the
    ``(action, observation)`` call shape by treating the observation as the
    world and ignoring the action.
    """

    def score(self, world: World) -> float:
        """Return the score for ``world``; must be overridden by subclasses."""
        raise NotImplementedError

    def forward(self, action: Any, observation: Any) -> float:
        """Score ``observation`` as a world; ``action`` is not used."""
        return self.score(observation)
class ThroughputRubric(WorldRubric):
    """Scores how many spawned vehicles cleared and how quickly they did so."""

    @staticmethod
    def _slowdown(world: World, vehicle: Any) -> float:
        """Return elapsed ticks divided by the summed road length of the route."""
        # Sum of road lengths along the route serves as the ideal travel time
        # (presumably one length unit per tick - confirm against the sim).
        ideal = sum(world.roads[road_id].length for road_id in vehicle.route)
        # NOTE(review): a clear_tick of 0 is falsy and falls back to world.tick.
        elapsed = (vehicle.clear_tick or world.tick) - vehicle.spawn_tick
        return elapsed / max(1, ideal)

    def score(self, world: World) -> float:
        """Return ``0.3 * cleared fraction + 0.7 * speed score``.

        Returns 0.0 when nothing was spawned or nothing has cleared. The
        speed score is 1.0 for a mean slowdown of at most 1.0 and falls
        linearly to 0.0 at a mean slowdown of 1.8.
        """
        metrics = world.metrics
        total_spawned = metrics.spawned_civilian + metrics.spawned_emergency
        if total_spawned == 0:
            return 0.0

        finished = [veh for veh in world.vehicles.values() if veh.cleared]
        if not finished:
            return 0.0

        frac = len(finished) / total_spawned
        ratios = [self._slowdown(world, veh) for veh in finished]
        mean_slowdown = sum(ratios) / len(ratios)

        unclamped = 1.0 - (mean_slowdown - 1.0) / 0.8
        speed_score = max(0.0, min(1.0, unclamped))
        return 0.3 * frac + 0.7 * speed_score
class EmergencyRubric(WorldRubric):
    """Scores emergency-vehicle handling: fraction cleared and mean clear time."""

    def score(self, world: World) -> float:
        """Return the emergency score.

        Yields 1.0 when no emergency vehicle was spawned and 0.0 when some
        were spawned but no clear time has been recorded. Otherwise half of
        the score is the cleared fraction and the other half is the cleared
        fraction scaled by a speed score that reaches 0.0 at a mean clear
        time of 40 ticks.
        """
        metrics = world.metrics
        n_spawned = metrics.spawned_emergency
        if n_spawned == 0:
            return 1.0

        clear_frac = metrics.cleared_emergency / n_spawned
        clear_times = metrics.emergency_clear_times
        if not clear_times:
            return 0.0

        tick_budget = 40.0
        average_ticks = sum(clear_times) / len(clear_times)
        speed_score = max(0.0, 1.0 - average_ticks / tick_budget)
        return 0.5 * clear_frac + 0.5 * (clear_frac * speed_score)
class FairnessRubric(WorldRubric):
    """Penalises the longest wait recorded in the world's metrics."""

    def score(self, world: World) -> float:
        """Return 1.0 for no waiting, falling linearly to 0.0 at 150 ticks."""
        worst_wait = world.metrics.max_wait_ticks_seen
        wait_budget = 150
        if worst_wait > 0:
            return max(0.0, 1.0 - worst_wait / wait_budget)
        return 1.0
class EfficiencyRubric(WorldRubric):
    """Penalises green-light ticks that the metrics count as wasted."""

    def score(self, world: World) -> float:
        """Return ``1 - 6 * wasted_ratio``, floored at 0.0.

        The ratio divides wasted green ticks by ``tick * intersection count``;
        both the intersection count and the product are floored at 1 so the
        division is always defined.
        """
        n_intersections = max(1, len(world.intersections))
        denominator = max(1, world.tick * n_intersections)
        wasted_ratio = world.metrics.wasted_green_ticks / denominator
        return max(0.0, 1.0 - 6.0 * wasted_ratio)
class PlanningRubric(WorldRubric):
    """Scores intervention-budget discipline and penalises invalid actions."""

    def score(self, world: World) -> float:
        """Return the planning score in ``[0.0, 1.0]``.

        The base is 1.0 when the budget is 0; otherwise it is reduced by the
        overspend relative to the budget. Each invalid action then subtracts
        0.1, with the penalty capped at 1.0 and the result floored at 0.0.
        """
        allowed = world.interventions_budget
        spent = world.interventions_used
        invalid_count = world.metrics.invalid_actions

        if allowed == 0:
            base = 1.0
        else:
            overspend = max(0, spent - allowed)
            base = 1.0 - (overspend / max(1, allowed))

        invalid_penalty = min(1.0, invalid_count * 0.1)
        return max(0.0, base - invalid_penalty)
class SafetyRubric(WorldRubric):
    """Penalises gridlock events recorded in the world's metrics."""

    def score(self, world: World) -> float:
        """Return 1.0 with no gridlock, losing 0.5 per event down to 0.0."""
        events = world.metrics.gridlock_events
        return 1.0 if events == 0 else max(0.0, 1.0 - 0.5 * events)
# Maps each scoring dimension name to the rubric class that computes it.
RUBRIC_CLASSES = dict(
    throughput=ThroughputRubric,
    emergency=EmergencyRubric,
    fairness=FairnessRubric,
    efficiency=EfficiencyRubric,
    planning=PlanningRubric,
    safety=SafetyRubric,
)
# Per-task weight profiles: task name -> {dimension name -> weight}.
# Every profile lists the same six dimensions and its weights sum to 1.0.
TASK_WEIGHTS: dict[str, dict[str, float]] = {
    "grid_balanced": dict(
        throughput=0.40, emergency=0.15, fairness=0.15,
        efficiency=0.15, planning=0.05, safety=0.10,
    ),
    "demand_shift": dict(
        throughput=0.35, emergency=0.10, fairness=0.20,
        efficiency=0.15, planning=0.10, safety=0.10,
    ),
    "incident_corridor": dict(
        throughput=0.15, emergency=0.40, fairness=0.10,
        efficiency=0.10, planning=0.15, safety=0.10,
    ),
    "rush_hour_wave": dict(
        throughput=0.35, emergency=0.10, fairness=0.25,
        efficiency=0.10, planning=0.10, safety=0.10,
    ),
    "multi_crisis": dict(
        throughput=0.15, emergency=0.30, fairness=0.10,
        efficiency=0.10, planning=0.20, safety=0.15,
    ),
}

# Canonical ordering of the dimensions; rubrics and weights are lined up by it.
DIMENSION_ORDER = [
    "throughput",
    "emergency",
    "fairness",
    "efficiency",
    "planning",
    "safety",
]
def build_rubric(task: str) -> WeightedSum:
    """Build the combined rubric for ``task``.

    Args:
        task: Key into ``TASK_WEIGHTS``. An unknown task falls back to the
            ``"grid_balanced"`` weight profile instead of raising.

    Returns:
        A ``WeightedSum`` over one fresh rubric instance per dimension, with
        rubrics and weights both listed in ``DIMENSION_ORDER``.
    """
    if task in TASK_WEIGHTS:
        profile = TASK_WEIGHTS[task]
    else:
        profile = TASK_WEIGHTS["grid_balanced"]

    # Instantiate all rubrics first, then collect weights in the same order.
    components = [RUBRIC_CLASSES[name]() for name in DIMENSION_ORDER]
    coefficients = [profile[name] for name in DIMENSION_ORDER]
    return WeightedSum(components, coefficients)
def grade(world: World) -> ScoreBreakdown:
    """Grade ``world`` with the rubric built for ``world.task``.

    Returns:
        A ``ScoreBreakdown`` holding each dimension's score and the combined
        total clamped to ``[0.0, 1.0]``.
    """
    combined = build_rubric(world.task)
    # The rubric is called with no action; only the world is scored.
    raw_total = combined(None, world)
    total = max(0.0, min(1.0, raw_total))

    # Read each sub-rubric's score after the combined call above.
    # NOTE(review): relies on the private ``_rubric_list`` attribute and on
    # ``last_score`` being populated by that call - confirm against openenv.
    per_dimension = {
        name: combined._rubric_list[index].last_score
        for index, name in enumerate(DIMENSION_ORDER)
    }
    return ScoreBreakdown(**per_dimension, total=total)