# supply-chain-env / curriculum.py
# ragavrida's picture
# feat: adaptive curriculum — environment learns from agent and gets harder
# c63ea5a
"""
Adaptive Curriculum — The environment gets harder as the agent improves.
This is what makes SupplyChainEnv different from EVERY other submission.
Standard envs are static. This one learns from the agent's performance
and generates increasingly challenging scenarios.
Inspired by:
- OpenAI's Automatic Domain Randomization (ADR) from Rubik's Cube paper
- Meta FAIR's PAIRED (Protagonist Antagonist Induced Regret Environment Design)
- Jiang et al., "Prioritized Level Replay" (ICML 2021)
How it works:
1. Track agent performance per disruption type
2. When agent masters a disruption type (recent average reward >= 0.4), increase its frequency
3. When agent struggles (recent average reward < 0.2), keep frequency but add hints
4. Introduce NEW disruption combinations the agent hasn't seen
5. Scale network complexity: more routes, tighter deadlines, higher stakes
The result: the environment is always at the agent's skill boundary.
Too easy = boring. Too hard = no learning signal. Curriculum = optimal.
"""
import random
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

from world import SupplyChainWorld
@dataclass
class AgentProfile:
    """Running record of how the agent has performed across episodes.

    Every episode reward is stored per difficulty and per disruption type.
    From that history the profile derives the ``mastered`` / ``struggling``
    lists and the curriculum level (1-10).
    """

    episodes_played: int = 0
    total_reward: float = 0.0
    # Reward history per disruption type (at most one entry per episode).
    disruption_scores: Dict[str, List[float]] = field(default_factory=lambda: {
        "typhoon": [], "port_strike": [], "factory_fire": [], "canal_blockage": [],
        "pandemic_wave": [], "earthquake": [], "sanctions": [], "fuel_shortage": [],
        "cyber_attack": [], "monsoon": [],
    })
    # Reward history per difficulty label.
    difficulty_scores: Dict[str, List[float]] = field(default_factory=lambda: {
        "easy": [], "medium": [], "hard": [],
    })
    # Disruption types whose recent average reward is >= _MASTERED_AT.
    mastered: List[str] = field(default_factory=list)
    # Disruption types whose recent average reward is < _STRUGGLING_BELOW.
    struggling: List[str] = field(default_factory=list)

    # Tuning constants. They are deliberately left unannotated so that
    # @dataclass does not turn them into constructor fields.
    # NOTE(review): earlier comments quoted ">80%" / "<40%" delivery, but the
    # code has always compared the reward against 0.4 / 0.2 — confirm which
    # scale is intended.
    _MIN_SAMPLES = 3          # samples needed before a type is classified
    _WINDOW = 5               # how many most-recent samples are averaged
    _MASTERED_AT = 0.4
    _STRUGGLING_BELOW = 0.2
    _WARMUP_EPISODES = 3      # episodes played before the level can rise
    # Upper (exclusive) average-reward bound for levels 1..9; anything at or
    # above the last bound is level 10.
    _LEVEL_BOUNDS = (0.1, 0.2, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6)

    def record_episode(self, reward: float, difficulty: str, disruption_types: List[str]) -> None:
        """Fold one finished episode into the profile.

        Args:
            reward: Final reward of the episode.
            difficulty: Difficulty label of the episode. Labels that are not
                tracked yet get a new bucket instead of raising ``KeyError``
                after the counters were already updated.
            disruption_types: Types of the disruptions present in the episode.
                Duplicates count once per episode; types that are not keys of
                ``disruption_scores`` are ignored.
        """
        self.episodes_played += 1
        self.total_reward += reward
        self.difficulty_scores.setdefault(difficulty, []).append(reward)

        # dict.fromkeys de-duplicates while keeping first-seen order, so an
        # episode with several disruptions of one type adds a single sample.
        for dt in dict.fromkeys(disruption_types):
            if dt in self.disruption_scores:
                self.disruption_scores[dt].append(reward)

        # Rebuild both classifications from scratch on every episode.
        mastered: List[str] = []
        struggling: List[str] = []
        for dt, scores in self.disruption_scores.items():
            if len(scores) < self._MIN_SAMPLES:
                continue
            recent = scores[-self._WINDOW:]
            avg = sum(recent) / len(recent)
            if avg >= self._MASTERED_AT:
                mastered.append(dt)
            elif avg < self._STRUGGLING_BELOW:
                struggling.append(dt)
        self.mastered = mastered
        self.struggling = struggling

    @property
    def current_level(self) -> int:
        """Agent's current curriculum level (1-10).

        The level is 1 until ``_WARMUP_EPISODES`` episodes have been played;
        after that it is derived from the lifetime average reward using
        ``_LEVEL_BOUNDS``.
        """
        if self.episodes_played < self._WARMUP_EPISODES:
            return 1
        avg = self.total_reward / self.episodes_played
        for level, upper_bound in enumerate(self._LEVEL_BOUNDS, start=1):
            if avg < upper_bound:
                return level
        return len(self._LEVEL_BOUNDS) + 1

    def get_summary(self) -> Dict[str, Any]:
        """Return a compact dict view of the profile.

        ``avg_reward`` is 0.0 before any episode has been recorded.
        """
        return {
            "episodes": self.episodes_played,
            "avg_reward": self.total_reward / max(1, self.episodes_played),
            "level": self.current_level,
            "mastered": self.mastered,
            "struggling": self.struggling,
        }
class AdaptiveCurriculum:
    """Generates scenarios tuned to the agent's current curriculum level.

    Uses the agent's performance history (``self.profile``) to:
    1. Focus on the disruption types with the lowest recent average reward
    2. Tighten shipment deadlines as the level rises
    3. Shorten the episode and raise the disruption count at higher levels
    """

    def __init__(self):
        self.profile = AgentProfile()
        # Fixed seed: curriculum-injected disruptions are reproducible.
        self.rng = random.Random(42)
        self._episode_counter = 0

    def generate_episode(self, seed: Optional[int] = None) -> Tuple["SupplyChainWorld", Dict[str, Any]]:
        """Generate a scenario tuned to the agent's current level.

        Args:
            seed: World seed. When ``None`` the running episode counter is
                used, so successive calls produce different worlds.

        Returns:
            ``(world, curriculum_info)`` where ``curriculum_info`` is the dict
            built by ``_design_scenario``.
        """
        self._episode_counter += 1
        actual_seed = seed if seed is not None else self._episode_counter
        curriculum_info = self._design_scenario(self.profile.current_level)
        world = SupplyChainWorld(
            seed=actual_seed,
            difficulty=curriculum_info["difficulty"],
            total_days=curriculum_info["total_days"],
        )
        # Make sure the focused disruption types appear in the world.
        self._inject_curriculum_disruptions(world, curriculum_info)
        # Tighten deadlines based on level.
        self._adjust_deadlines(world, curriculum_info)
        return world, curriculum_info

    def record_result(self, reward: float, world: "SupplyChainWorld", curriculum_info: Dict) -> Dict[str, Any]:
        """Record an episode result and report how the level moved.

        Returns:
            Dict with ``level_before`` (the level stored in
            ``curriculum_info``, default 1), ``level_after``,
            ``level_changed`` and the profile summary.
        """
        disruption_types = [d.type for d in world.disruptions]
        self.profile.record_episode(reward, curriculum_info["difficulty"], disruption_types)
        level_before = curriculum_info.get("level", 1)
        level_after = self.profile.current_level
        return {
            "level_before": level_before,
            "level_after": level_after,
            "level_changed": level_before != level_after,
            "profile": self.profile.get_summary(),
        }

    def _design_scenario(self, level: int) -> Dict[str, Any]:
        """Design scenario parameters based on curriculum level.

        ``cascade_probability`` is only stored in the returned dict; nothing
        in this class reads it.
        """
        # Level 1-3: Easy with a single focused disruption type
        if level <= 3:
            return {
                "level": level,
                "difficulty": "easy",
                "total_days": 30,
                "n_disruptions": level,
                "disruption_focus": self._pick_weakest_types(1),
                "deadline_tightness": 1.0,  # normal
                "cascade_probability": 0.0,
                "description": f"Level {level}: Learning basics with {level} disruption(s)",
            }
        # Level 4-6: Medium with agent's weak spots
        if level <= 6:
            weak_types = self._pick_weakest_types(2)
            return {
                "level": level,
                "difficulty": "medium",
                "total_days": 30,
                "n_disruptions": level - 1,
                "disruption_focus": weak_types,
                "deadline_tightness": 0.85,  # 15% tighter
                "cascade_probability": 0.2,
                "description": f"Level {level}: Targeting weaknesses ({', '.join(weak_types)})",
            }
        # Level 7-9: Hard with cascading failures
        if level <= 9:
            weak_types = self._pick_weakest_types(3)
            return {
                "level": level,
                "difficulty": "hard",
                "total_days": 25,  # less time
                "n_disruptions": level,
                "disruption_focus": weak_types,
                "deadline_tightness": 0.7,  # 30% tighter
                "cascade_probability": 0.5,
                "description": f"Level {level}: Cascading crises with tight deadlines",
            }
        # Level 10: Maximum challenge — everything at once
        return {
            "level": 10,
            "difficulty": "hard",
            "total_days": 20,  # very short
            "n_disruptions": 10,
            "disruption_focus": [],  # empty focus: nothing extra is injected
            "deadline_tightness": 0.6,  # 40% tighter
            "cascade_probability": 0.8,
            "description": "Level 10: MAXIMUM — all disruptions, cascading, tight deadlines",
        }

    def _pick_weakest_types(self, n: int) -> List[str]:
        """Return the ``n`` disruption types with the lowest recent average reward.

        The average covers the last five recorded rewards. Types with no
        history score 0.5 so they rank in the middle. Ties keep the order of
        ``profile.disruption_scores``.
        """
        type_scores: Dict[str, float] = {}
        for dt, scores in self.profile.disruption_scores.items():
            recent = scores[-5:]
            # unknown = mid priority
            type_scores[dt] = sum(recent) / len(recent) if recent else 0.5
        # Sort by score ascending (worst first), pick top N
        return sorted(type_scores, key=type_scores.get)[:n]

    def _inject_curriculum_disruptions(self, world: "SupplyChainWorld", info: Dict) -> None:
        """Append curriculum disruptions for focused types missing from the world.

        Existing disruptions are kept. A focused type is added only when the
        world has no disruption of that type and holds fewer than
        ``info["n_disruptions"]`` disruptions.
        """
        focus_types = info.get("disruption_focus", [])
        if not focus_types:
            return
        existing_types = {d.type for d in world.disruptions}
        for ft in focus_types:
            if ft in existing_types or len(world.disruptions) >= info["n_disruptions"]:
                continue
            # Imported lazily, only when a disruption is actually created.
            from world import Disruption
            # Start within the first third of the episode.
            start = self.rng.randint(0, world.total_days // 3)
            world.disruptions.append(Disruption(
                id=f"curriculum_{ft}",
                type=ft,
                severity="high",
                description=f"Curriculum-generated {ft} event",
                affected_nodes=self._pick_affected_nodes(world, ft),
                affected_routes=[],
                start_day=start,
                end_day=start + self.rng.randint(3, 10),
                capacity_reduction=0.8,
            ))

    def _adjust_deadlines(self, world: "SupplyChainWorld", info: Dict) -> None:
        """Tighten shipment deadlines by ``info["deadline_tightness"]``.

        A factor of 1.0 or more leaves deadlines untouched. Scaled deadlines
        are floored at day 5, and a deadline is never moved later than it
        already was, so deadlines below the floor stay as they are.
        """
        tightness = info.get("deadline_tightness", 1.0)
        if tightness >= 1.0:
            return
        for ship in world.shipments.values():
            tightened = max(5, int(ship.deadline_day * tightness))
            # min() keeps the floor of 5 from extending a shorter deadline.
            ship.deadline_day = min(ship.deadline_day, tightened)

    def _pick_affected_nodes(self, world: "SupplyChainWorld", disruption_type: str) -> List[str]:
        """Pick nodes to affect based on disruption type.

        Weather and seismic types hit up to two ports whose region is
        "asia", port strikes hit one "europe" port, and factory fires hit one
        factory. Every other type hits one port chosen from all ports. An
        empty candidate pool gives an empty list instead of raising.
        """
        if disruption_type in ("typhoon", "monsoon", "earthquake"):
            asia_ports = [p for p in world.ports if world.ports[p].get("region") == "asia"]
            return self.rng.sample(asia_ports, min(2, len(asia_ports)))
        if disruption_type == "port_strike":
            europe_ports = [p for p in world.ports if world.ports[p].get("region") == "europe"]
            return self.rng.sample(europe_ports, min(1, len(europe_ports)))
        if disruption_type == "factory_fire":
            candidates = list(world.factories.keys())
        else:
            # cyber_attack and every remaining type: any port.
            candidates = list(world.ports.keys())
        return [self.rng.choice(candidates)] if candidates else []
def demo_curriculum():
    """Demo: watch the curriculum adapt as a simple greedy agent plays.

    Runs ten episodes. In each one, every pending shipment is routed along
    the path returned by ``find_path``, then days are advanced until the
    environment reports ``done``. The result is recorded in the curriculum
    and one status line is printed per episode.
    """
    from server.supply_chain_environment import SupplyChainEnvironment
    from models import SupplyChainAction

    def tool(n, a=None):
        # Shorthand for building a tool-call action.
        return SupplyChainAction(action_type="ToolCallAction", tool_name=n, arguments=a or {})

    curriculum = AdaptiveCurriculum()
    print("=" * 70)
    print(" ADAPTIVE CURRICULUM DEMO")
    print(" Environment gets harder as agent improves")
    print("=" * 70)

    for episode in range(10):
        world, info = curriculum.generate_episode()

        # Simple greedy agent.
        # NOTE(review): the generated world is installed by writing the
        # environment's private attributes directly — confirm there is no
        # public reset/load hook to use instead.
        env = SupplyChainEnvironment()
        env._world = world
        env._done = False
        env._step_count = 0

        obs = env.step(tool("view_shipments"))
        if obs.tool_result:
            for s in obs.tool_result.get("shipments", []):
                if s["status"] != "pending":
                    continue
                p = env.step(tool("find_path", {"from_port": s["current_location"], "to_warehouse": s["destination"]}))
                if p.tool_result and p.tool_result.get("path"):
                    env.step(tool("route_shipment", {"shipment_id": s["id"], "route": p.tool_result["path"]}))

        for _ in range(world.total_days):
            obs = env.step(tool("advance_day"))
            if obs.done:
                break
        if not obs.done:
            # Ran out of days without the environment finishing: end explicitly.
            obs = env.step(tool("end_simulation"))

        reward = obs.reward or 0.0
        result = curriculum.record_result(reward, world, info)
        level = result["profile"]["level"]
        level_indicator = "=" * level + "." * (10 - level)
        print(f" Episode {episode+1:2d} | Level [{level_indicator}] {level:2d}/10 | "
              f"Reward: {reward:.3f} | {info['description'][:40]}")
        if result["level_changed"]:
            # The level follows the lifetime average reward, so it can fall as
            # well as rise; report the actual direction.
            direction = "UP" if result["level_after"] > result["level_before"] else "DOWN"
            print(f" LEVEL {direction}! {result['level_before']} -> {result['level_after']}")

    print(f"\n Final profile: {curriculum.profile.get_summary()}")
    print("=" * 70)


# Run the demo only when the file is executed as a script.
if __name__ == "__main__":
    demo_curriculum()