Spaces:
Restarting
Restarting
"""
Multi-task configuration for the OpenEnv bus routing environment.

Three difficulty tiers — Easy, Medium, Hard — share the same
``BusRoutingEnv`` class but differ in the number of stops, passenger
demand, fuel constraints, and penalty intensity.  Five named tasks
(``task_1`` to ``task_5``) are derived from the three tier templates and
can be looked up by name (or by a legacy alias) with ``get_task``.
"""
| from __future__ import annotations | |
| import copy | |
| from dataclasses import dataclass | |
| from typing import Any, Dict | |
| from environment import BusRoutingEnv | |
# Public names of this module. The task objects are listed explicitly so
# that OpenEnv detection can find them.
__all__ = [
    # configuration type
    "TaskConfig",
    # numbered task instances
    "task_1", "task_2", "task_3", "task_4", "task_5",
    # name -> task registry
    "TASKS",
    # tier aliases
    "TASK_EASY", "TASK_MEDIUM", "TASK_HARD",
    # lookup helper
    "get_task",
]
@dataclass
class TaskConfig:
    """All parameters needed to instantiate a BusRoutingEnv for a task.

    The class is a dataclass so that instances can be created with keyword
    arguments (``TaskConfig(name="task_easy", num_stops=5, ...)``); any
    field that is not passed keeps the default declared below.  Instances
    are mutable, so a copy can be tweaked field by field.

    ``name``, ``description`` and ``difficulty`` are metadata only; every
    other field is forwarded unchanged to ``BusRoutingEnv`` by
    :meth:`build_env`.
    """

    # --- metadata (not passed to the environment) ---
    name: str = ""
    description: str = ""
    difficulty: str = "medium"  # easy | medium | hard

    # --- forwarded verbatim to BusRoutingEnv by build_env() ---
    num_stops: int = 10
    num_buses: int = 1
    max_steps: int = 150  # may be overridden at build time via EVAL_MAX_STEPS
    seed: int = 42
    bus_capacity: int = 30
    fuel_start: float = 100.0
    passenger_arrival_rate: float = 1.2
    large_queue_threshold: int = 10
    wait_time_threshold: int = 3
    fuel_cost_move: float = 1.0
    fuel_cost_wait: float = 0.2
    background_bus_pickup_fraction: float = 0.6
    new_stop_bonus: float = 1.0
    idle_camping_penalty: float = 0.6
    camping_grace_steps: int = 1
    nearby_queue_ignore_penalty: float = 1.5
    recent_window: int = 10
    recent_unvisited_bonus: float = 1.0
    repeat_stop_penalty: float = 0.5
    high_queue_reward_threshold: int = 6
    high_queue_visit_bonus: float = 2.0
    reward_clip: float = 10.0
    demand_profile: str = "synthetic"

    def build_env(self) -> BusRoutingEnv:
        """Create a ``BusRoutingEnv`` configured from this task.

        The episode length is taken from the ``EVAL_MAX_STEPS`` environment
        variable when it is set to a non-blank value, otherwise from
        ``self.max_steps``.

        Returns:
            A new ``BusRoutingEnv`` instance.

        Raises:
            ValueError: if ``EVAL_MAX_STEPS`` is set to something that
                ``int()`` cannot parse.
        """
        import os

        # A blank value (e.g. ``EVAL_MAX_STEPS=`` in a shell or compose file)
        # is treated like "unset" instead of crashing in int("").
        override = os.getenv("EVAL_MAX_STEPS", "").strip()
        m_steps = int(override) if override else int(self.max_steps)

        return BusRoutingEnv(
            num_stops=self.num_stops,
            num_buses=self.num_buses,
            max_steps=m_steps,
            seed=self.seed,
            bus_capacity=self.bus_capacity,
            fuel_start=self.fuel_start,
            passenger_arrival_rate=self.passenger_arrival_rate,
            large_queue_threshold=self.large_queue_threshold,
            wait_time_threshold=self.wait_time_threshold,
            fuel_cost_move=self.fuel_cost_move,
            fuel_cost_wait=self.fuel_cost_wait,
            background_bus_pickup_fraction=self.background_bus_pickup_fraction,
            new_stop_bonus=self.new_stop_bonus,
            idle_camping_penalty=self.idle_camping_penalty,
            camping_grace_steps=self.camping_grace_steps,
            nearby_queue_ignore_penalty=self.nearby_queue_ignore_penalty,
            recent_window=self.recent_window,
            recent_unvisited_bonus=self.recent_unvisited_bonus,
            repeat_stop_penalty=self.repeat_stop_penalty,
            high_queue_reward_threshold=self.high_queue_reward_threshold,
            high_queue_visit_bonus=self.high_queue_visit_bonus,
            reward_clip=self.reward_clip,
            demand_profile=self.demand_profile,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return a summary of the task as a plain dictionary.

        Only a subset of the fields is included (metadata plus the main
        size, fuel, demand and capacity settings); ``max_steps`` is the
        configured value, not the ``EVAL_MAX_STEPS`` override.
        """
        return {
            "name": self.name,
            "difficulty": self.difficulty,
            "description": self.description,
            "num_stops": self.num_stops,
            "num_buses": self.num_buses,
            "max_steps": self.max_steps,
            "fuel_start": self.fuel_start,
            "passenger_arrival_rate": self.passenger_arrival_rate,
            "fuel_cost_move": self.fuel_cost_move,
            "fuel_cost_wait": self.fuel_cost_wait,
            "large_queue_threshold": self.large_queue_threshold,
            "bus_capacity": self.bus_capacity,
        }
# Base configuration for the "easy" tier. Fields that are not listed here
# keep the defaults declared on TaskConfig.
_TASK_EASY_TEMPLATE = TaskConfig(
    # metadata
    name="task_easy",
    difficulty="easy",
    description="Easy template",
    # size, episode length, seed
    num_stops=5,
    num_buses=1,
    max_steps=100,
    seed=42,
    # capacity, demand and queue/wait thresholds
    bus_capacity=30,
    passenger_arrival_rate=0.6,
    demand_profile="off_peak",
    large_queue_threshold=12,
    wait_time_threshold=5,
    high_queue_reward_threshold=8,
    # fuel
    fuel_start=100.0,
    fuel_cost_move=0.5,
    fuel_cost_wait=0.1,
    # bonus / penalty settings
    new_stop_bonus=0.5,
    idle_camping_penalty=0.3,
    nearby_queue_ignore_penalty=0.5,
    repeat_stop_penalty=0.2,
    reward_clip=10.0,
)
# Base configuration for the "medium" tier. Fields that are not listed here
# keep the defaults declared on TaskConfig.
_TASK_MEDIUM_TEMPLATE = TaskConfig(
    # metadata
    name="task_medium",
    difficulty="medium",
    description="Medium template",
    # size, episode length, seed
    num_stops=10,
    num_buses=1,
    max_steps=150,
    seed=42,
    # capacity, demand and queue/wait thresholds
    bus_capacity=30,
    passenger_arrival_rate=1.2,
    demand_profile="weekday",
    large_queue_threshold=10,
    wait_time_threshold=3,
    high_queue_reward_threshold=6,
    # fuel
    fuel_start=100.0,
    fuel_cost_move=1.0,
    fuel_cost_wait=0.2,
    # bonus / penalty settings
    new_stop_bonus=1.0,
    idle_camping_penalty=0.6,
    nearby_queue_ignore_penalty=1.5,
    repeat_stop_penalty=0.5,
    reward_clip=10.0,
)
# Base configuration for the "hard" tier. Fields that are not listed here
# keep the defaults declared on TaskConfig.
_TASK_HARD_TEMPLATE = TaskConfig(
    # metadata
    name="task_hard",
    difficulty="hard",
    description="Hard template",
    # size, episode length, seed
    num_stops=12,
    num_buses=2,
    max_steps=200,
    seed=42,
    # capacity, demand and queue/wait thresholds
    bus_capacity=25,
    passenger_arrival_rate=2.0,
    demand_profile="peak_hour",
    large_queue_threshold=8,
    wait_time_threshold=2,
    high_queue_reward_threshold=5,
    # fuel
    fuel_start=80.0,
    fuel_cost_move=1.5,
    fuel_cost_wait=0.4,
    # bonus / penalty settings
    new_stop_bonus=1.5,
    high_queue_visit_bonus=3.0,
    idle_camping_penalty=1.0,
    camping_grace_steps=0,
    nearby_queue_ignore_penalty=2.5,
    repeat_stop_penalty=0.8,
    reward_clip=15.0,
)
def _derive_task(template, name, description, **overrides):
    """Deep-copy *template*, rename it, then apply *overrides* in order."""
    derived = copy.deepcopy(template)
    derived.name = name
    derived.description = description
    for field_name, value in overrides.items():
        setattr(derived, field_name, value)
    return derived


# Each task is an independent copy of a tier template, so changing one task
# never affects the template or its sibling tasks.
task_1 = _derive_task(_TASK_EASY_TEMPLATE, "task_1", "Easy task 1")
task_2 = _derive_task(_TASK_MEDIUM_TEMPLATE, "task_2", "Medium task 2")
task_3 = _derive_task(_TASK_HARD_TEMPLATE, "task_3", "Hard task 3")
task_4 = _derive_task(
    _TASK_MEDIUM_TEMPLATE,
    "task_4",
    "Medium task 4 (Alternative Seed)",
    seed=99,
)
task_5 = _derive_task(
    _TASK_HARD_TEMPLATE,
    "task_5",
    "Hard task 5 (Extreme Peak)",
    passenger_arrival_rate=2.5,
    seed=123,
)

# Registry keyed by task name, in task_1 .. task_5 order.
TASKS: Dict[str, TaskConfig] = {
    task.name: task for task in (task_1, task_2, task_3, task_4, task_5)
}

# Tier aliases point at the first task of each tier (same objects, not copies).
TASK_EASY, TASK_MEDIUM, TASK_HARD = task_1, task_2, task_3
def get_task(name: str) -> TaskConfig:
    """Look up a task configuration by name.

    The name is lower-cased and stripped of surrounding whitespace, then
    translated through a small table of legacy aliases before being looked
    up in ``TASKS``.

    Args:
        name: A key of ``TASKS`` or one of the legacy aliases
            (``easy``, ``medium``, ``hard``, ``task_11``, ``task_21``).

    Returns:
        The ``TaskConfig`` object stored in ``TASKS`` (the registry's own
        instance, not a copy).

    Raises:
        ValueError: if the resolved name is not a key of ``TASKS``.
    """
    aliases = {
        "easy": "task_1",
        "medium": "task_2",
        "hard": "task_3",
        "task_11": "task_2",
        "task_21": "task_3",
    }
    normalized = name.lower().strip()
    resolved = aliases.get(normalized, normalized)
    if resolved in TASKS:
        return TASKS[resolved]
    raise ValueError(f"Unknown task '{name}'. Choose from: {list(TASKS.keys())}")