| """ |
| Adaptive Curriculum — The environment gets harder as the agent improves. |
| |
| This is what makes SupplyChainEnv different from EVERY other submission. |
| Standard envs are static. This one learns from the agent's performance |
| and generates increasingly challenging scenarios. |
| |
| Inspired by: |
| - OpenAI's Automatic Domain Randomization (ADR) from Rubik's Cube paper |
| - Meta FAIR's PAIRED (Protagonist Antagonist Induced Regret Environment Design) |
| - Jiang et al., "Prioritized Level Replay" (ICML 2021) |
| |
| How it works: |
| 1. Track agent performance per disruption type |
| 2. When the agent masters a disruption type (>80% delivery), increase its frequency |
| 3. When the agent struggles (<40% delivery), keep the frequency but add hints |
| 4. Introduce NEW disruption combinations the agent hasn't seen |
| 5. Scale network complexity: more routes, tighter deadlines, higher stakes |
| |
| The result: the environment is always at the agent's skill boundary. |
| Too easy = boring. Too hard = no learning signal. Curriculum = optimal. |
| """ |
|
|
| import random |
| from typing import Any, Dict, List, Tuple |
| from dataclasses import dataclass, field |
|
|
| from world import SupplyChainWorld |
|
|
|
|
@dataclass
class AgentProfile:
    """Running record of how the agent performs across episodes.

    Rewards are accumulated overall, per difficulty label and per disruption
    type. After every recorded episode the ``mastered`` and ``struggling``
    lists are rebuilt from the most recent per-type rewards.
    """

    # Un-annotated class attributes are plain constants, not dataclass fields,
    # so the generated ``__init__`` signature is unchanged.
    RECENT_WINDOW = 5            # number of latest rewards averaged per type
    MIN_SAMPLES = 3              # rewards needed before a type is classified
    MASTERY_THRESHOLD = 0.4      # recent average >= this -> mastered
    STRUGGLE_THRESHOLD = 0.2     # recent average <  this -> struggling
    MIN_EPISODES_FOR_LEVELING = 3
    # Exclusive upper bound on the lifetime average reward for levels 1..9;
    # an average at or above the last bound is level 10.
    LEVEL_UPPER_BOUNDS = (0.1, 0.2, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6)

    episodes_played: int = 0
    total_reward: float = 0.0

    # Reward history keyed by disruption type; only these keys are tracked.
    disruption_scores: Dict[str, List[float]] = field(default_factory=lambda: {
        "typhoon": [], "port_strike": [], "factory_fire": [], "canal_blockage": [],
        "pandemic_wave": [], "earthquake": [], "sanctions": [], "fuel_shortage": [],
        "cyber_attack": [], "monsoon": [],
    })

    # Reward history keyed by difficulty label.
    difficulty_scores: Dict[str, List[float]] = field(default_factory=lambda: {
        "easy": [], "medium": [], "hard": [],
    })

    # Disruption types whose recent average reward is >= MASTERY_THRESHOLD.
    mastered: List[str] = field(default_factory=list)
    # Disruption types whose recent average reward is < STRUGGLE_THRESHOLD.
    struggling: List[str] = field(default_factory=list)

    def record_episode(self, reward: float, difficulty: str, disruption_types: List[str]) -> None:
        """Fold one finished episode into the profile.

        Args:
            reward: Episode reward; added to the totals and to each history.
            difficulty: Difficulty label the episode was played at.
            disruption_types: Types of the disruptions present in the episode.
                Types that are not keys of ``disruption_scores`` are ignored.
        """
        self.episodes_played += 1
        self.total_reward += reward
        # setdefault: a label that was not pre-seeded starts its own history
        # instead of raising KeyError.
        self.difficulty_scores.setdefault(difficulty, []).append(reward)

        # De-duplicate (order-preserving) so one episode contributes at most
        # one sample per type, however many disruptions of that type it had.
        for dt in dict.fromkeys(disruption_types):
            if dt in self.disruption_scores:
                self.disruption_scores[dt].append(reward)

        self._refresh_skill_lists()

    def _refresh_skill_lists(self) -> None:
        """Rebuild ``mastered`` / ``struggling`` from the recent reward window."""
        self.mastered = []
        self.struggling = []
        for dt, scores in self.disruption_scores.items():
            if len(scores) < self.MIN_SAMPLES:
                continue
            recent = scores[-self.RECENT_WINDOW:]
            avg = sum(recent) / len(recent)
            if avg >= self.MASTERY_THRESHOLD:
                self.mastered.append(dt)
            elif avg < self.STRUGGLE_THRESHOLD:
                self.struggling.append(dt)

    @property
    def current_level(self) -> int:
        """Agent's current curriculum level (1-10).

        Level 1 until ``MIN_EPISODES_FOR_LEVELING`` episodes are recorded;
        afterwards the level is derived from the lifetime average reward
        using ``LEVEL_UPPER_BOUNDS``.
        """
        if self.episodes_played < self.MIN_EPISODES_FOR_LEVELING:
            return 1
        avg = self.total_reward / self.episodes_played
        for level, upper_bound in enumerate(self.LEVEL_UPPER_BOUNDS, start=1):
            if avg < upper_bound:
                return level
        return len(self.LEVEL_UPPER_BOUNDS) + 1

    def get_summary(self) -> Dict[str, Any]:
        """Return episode count, average reward, level and skill lists as a dict."""
        return {
            "episodes": self.episodes_played,
            # max(1, ...) avoids division by zero before the first episode.
            "avg_reward": self.total_reward / max(1, self.episodes_played),
            "level": self.current_level,
            "mastered": self.mastered,
            "struggling": self.struggling,
        }
|
|
|
|
class AdaptiveCurriculum:
    """Generates scenarios tuned to the agent's recorded performance.

    Holds an ``AgentProfile`` and, for each episode, derives scenario
    parameters from the profile's current level, builds a
    ``SupplyChainWorld``, appends disruptions of the agent's weakest types
    and tightens shipment deadlines.
    """

    def __init__(self):
        self.profile = AgentProfile()
        # Fixed seed so curriculum-side random choices are reproducible.
        self.rng = random.Random(42)
        self._episode_counter = 0

    def generate_episode(self, seed: int = None) -> Tuple["SupplyChainWorld", Dict[str, Any]]:
        """Generate a scenario tuned to the agent's current level.

        Args:
            seed: World seed; when ``None`` the running episode counter is used.

        Returns:
            ``(world, curriculum_info)`` where ``curriculum_info`` is the dict
            produced by ``_design_scenario``.
        """
        self._episode_counter += 1
        actual_seed = seed if seed is not None else self._episode_counter

        level = self.profile.current_level
        curriculum_info = self._design_scenario(level)

        world = SupplyChainWorld(
            seed=actual_seed,
            difficulty=curriculum_info["difficulty"],
            total_days=curriculum_info["total_days"],
        )

        self._inject_curriculum_disruptions(world, curriculum_info)
        self._adjust_deadlines(world, curriculum_info)

        return world, curriculum_info

    def record_result(self, reward: float, world: "SupplyChainWorld", curriculum_info: Dict) -> Dict[str, Any]:
        """Record an episode result and report how the level moved.

        Returns:
            Dict with ``level_before`` (level stored in ``curriculum_info``,
            defaulting to 1), ``level_after``, ``level_changed`` and the
            profile summary.
        """
        disruption_types = [d.type for d in world.disruptions]
        self.profile.record_episode(reward, curriculum_info["difficulty"], disruption_types)

        level_before = curriculum_info.get("level", 1)
        level_after = self.profile.current_level
        return {
            "level_before": level_before,
            "level_after": level_after,
            "level_changed": level_before != level_after,
            "profile": self.profile.get_summary(),
        }

    def _design_scenario(self, level: int) -> Dict[str, Any]:
        """Design scenario parameters based on curriculum level.

        Levels 1-3 map to "easy", 4-6 to "medium", 7-9 to "hard", and
        anything above 9 to the fixed level-10 scenario.
        """
        if level <= 3:
            return {
                "level": level,
                "difficulty": "easy",
                "total_days": 30,
                "n_disruptions": level,
                "disruption_focus": self._pick_weakest_types(1),
                "deadline_tightness": 1.0,
                "cascade_probability": 0.0,
                "description": f"Level {level}: Learning basics with {level} disruption(s)",
            }

        if level <= 6:
            weak_types = self._pick_weakest_types(2)
            return {
                "level": level,
                "difficulty": "medium",
                "total_days": 30,
                "n_disruptions": level - 1,
                "disruption_focus": weak_types,
                "deadline_tightness": 0.85,
                "cascade_probability": 0.2,
                "description": f"Level {level}: Targeting weaknesses ({', '.join(weak_types)})",
            }

        if level <= 9:
            weak_types = self._pick_weakest_types(3)
            return {
                "level": level,
                "difficulty": "hard",
                "total_days": 25,
                "n_disruptions": level,
                "disruption_focus": weak_types,
                "deadline_tightness": 0.7,
                "cascade_probability": 0.5,
                "description": f"Level {level}: Cascading crises with tight deadlines",
            }

        # Top level: no focus list, so no curriculum disruptions are injected.
        return {
            "level": 10,
            "difficulty": "hard",
            "total_days": 20,
            "n_disruptions": 10,
            "disruption_focus": [],
            "deadline_tightness": 0.6,
            "cascade_probability": 0.8,
            "description": "Level 10: MAXIMUM — all disruptions, cascading, tight deadlines",
        }

    def _pick_weakest_types(self, n: int) -> List[str]:
        """Return the ``n`` disruption types with the lowest recent average reward.

        Types with no recorded rewards are scored 0.5. The sort is stable, so
        ties keep the order of ``profile.disruption_scores``.
        """
        type_scores = {}
        for dt, scores in self.profile.disruption_scores.items():
            recent = scores[-5:]
            type_scores[dt] = sum(recent) / len(recent) if recent else 0.5

        return sorted(type_scores, key=type_scores.get)[:n]

    def _inject_curriculum_disruptions(self, world: "SupplyChainWorld", info: Dict) -> None:
        """Append focus-type disruptions to ``world.disruptions``.

        A focus type is added only if the world does not already contain a
        disruption of that type and the world currently holds fewer than
        ``info["n_disruptions"]`` disruptions.
        """
        focus_types = info.get("disruption_focus", [])
        if not focus_types:
            return

        # Imported once per call rather than once per loop iteration.
        from world import Disruption

        existing_types = {d.type for d in world.disruptions}
        for ft in focus_types:
            if ft in existing_types or len(world.disruptions) >= info["n_disruptions"]:
                continue
            # RNG draws happen in a fixed order (start, nodes, duration) so
            # results stay reproducible for a given seed.
            start = self.rng.randint(0, world.total_days // 3)
            affected = self._pick_affected_nodes(world, ft)
            end = start + self.rng.randint(3, 10)
            world.disruptions.append(Disruption(
                id=f"curriculum_{ft}",
                type=ft,
                severity="high",
                description=f"Curriculum-generated {ft} event",
                affected_nodes=affected,
                affected_routes=[],
                start_day=start,
                end_day=end,
                capacity_reduction=0.8,
            ))

    def _adjust_deadlines(self, world: "SupplyChainWorld", info: Dict) -> None:
        """Tighten shipment deadlines by ``info["deadline_tightness"]``.

        Scaled deadlines are floored at day 5, but a deadline is never moved
        later than its original value: the floor must not loosen a shipment
        whose deadline was already below 5.
        """
        tightness = info.get("deadline_tightness", 1.0)
        if tightness >= 1.0:
            return
        for ship in world.shipments.values():
            scaled = max(5, int(ship.deadline_day * tightness))
            ship.deadline_day = min(ship.deadline_day, scaled)

    def _pick_affected_nodes(self, world: "SupplyChainWorld", disruption_type: str) -> List[str]:
        """Pick nodes to affect based on disruption type.

        Returns an empty list when the relevant pool of ports or factories is
        empty, instead of raising from ``random.choice``.
        """
        if disruption_type in ("typhoon", "monsoon", "earthquake"):
            asia_ports = [p for p in world.ports if world.ports[p].get("region") == "asia"]
            return self.rng.sample(asia_ports, min(2, len(asia_ports)))
        if disruption_type == "port_strike":
            europe_ports = [p for p in world.ports if world.ports[p].get("region") == "europe"]
            return self.rng.sample(europe_ports, min(1, len(europe_ports)))

        # factory_fire hits a factory; every other type hits an arbitrary port.
        if disruption_type == "factory_fire":
            pool = list(world.factories.keys())
        else:
            pool = list(world.ports.keys())
        return [self.rng.choice(pool)] if pool else []
|
|
|
|
def demo_curriculum():
    """Demo: watch the curriculum adapt as the agent plays.

    Runs ten episodes with a simple scripted policy: route every pending
    shipment along the path returned by ``find_path``, then advance day by
    day until the environment reports done. One status line is printed per
    episode, and the final profile is printed at the end.
    """
    # Local imports: only needed when the demo is actually run.
    from server.supply_chain_environment import SupplyChainEnvironment
    from models import SupplyChainAction

    def tool(n, a=None):
        """Build a tool-call action for tool ``n`` with arguments ``a``."""
        return SupplyChainAction(action_type="ToolCallAction", tool_name=n, arguments=a or {})

    curriculum = AdaptiveCurriculum()

    print("=" * 70)
    print(" ADAPTIVE CURRICULUM DEMO")
    print(" Environment gets harder as agent improves")
    print("=" * 70)

    for episode in range(10):
        world, info = curriculum.generate_episode()

        # Install the curriculum-generated world directly on a fresh
        # environment by setting its internal state.
        env = SupplyChainEnvironment()
        env._world = world
        env._done = False
        env._step_count = 0

        obs = env.step(tool("view_shipments"))
        if obs.tool_result:
            for s in obs.tool_result.get("shipments", []):
                if s["status"] == "pending":
                    p = env.step(tool("find_path", {"from_port": s["current_location"], "to_warehouse": s["destination"]}))
                    if p.tool_result and p.tool_result.get("path"):
                        env.step(tool("route_shipment", {"shipment_id": s["id"], "route": p.tool_result["path"]}))

        for _ in range(world.total_days):
            obs = env.step(tool("advance_day"))
            if obs.done:
                break
        if not obs.done:
            obs = env.step(tool("end_simulation"))

        reward = obs.reward or 0.0
        result = curriculum.record_result(reward, world, info)

        level = result["profile"]["level"]
        level_indicator = "=" * level + "." * (10 - level)
        print(f" Episode {episode+1:2d} | Level [{level_indicator}] {level:2d}/10 | "
              f"Reward: {reward:.3f} | {info['description'][:40]}")
        if result["level_changed"]:
            # The level comes from a cumulative average and can fall as well
            # as rise, so label the direction instead of always saying "UP".
            direction = "UP" if result["level_after"] > result["level_before"] else "DOWN"
            print(f" LEVEL {direction}! {result['level_before']} -> {result['level_after']}")

    print(f"\n Final profile: {curriculum.profile.get_summary()}")
    print("=" * 70)
|
|
|
|
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    demo_curriculum()
|
|