"""Catalog builders. Two backends, same output schema: AnchorScenarioBuilder: deterministic, samples N variations per anchor with seeded jitter on knobs. No external dependencies. LLMScenarioBuilder: calls Claude via the anthropic SDK. Produces richer descriptions and more diverse themes. Falls back to the anchor builder if the SDK / API key is missing. Every produced ScenarioSpec is round-trip validated against the simulator before being added to the catalog: we actually call `env.reset()` with the spec's Config + seed and ensure no exception. Specs that fail validation are discarded and replaced. """ from __future__ import annotations import json import logging import os import random from typing import Iterable, List, Optional, Sequence from pydantic import ValidationError from dispatch_arena.catalog.anchors import ( ALL_ANCHORS, Anchor, anchors_by_difficulty, ) from dispatch_arena.catalog.prompts import ( DESIGNER_SYSTEM_PROMPT, render_user_prompt, ) from dispatch_arena.catalog.spec import Difficulty, ScenarioSpec, SkillTag from dispatch_arena.models import OrderStatus from dispatch_arena.server.env import DispatchArenaEnvironment logger = logging.getLogger(__name__) # --- Shared validation ------------------------------------------------------- def round_trip_validate(spec: ScenarioSpec) -> None: """Make sure the simulator can actually run this scenario. Calls reset() with the spec's Config + seed and verifies the resulting state has at least one courier, at least one order (visible or pending), and the episode is not already done. Raises on failure. """ env = DispatchArenaEnvironment(config=spec.to_config()) obs = env.reset(seed=spec.seed) if obs.done: raise ValueError(f"scenario {spec.name!r} terminates immediately on reset") if not obs.state.couriers: raise ValueError(f"scenario {spec.name!r} has no couriers after reset") total_visible = len(obs.state.orders) total_pending = len(env._pending_arrivals) if total_visible + total_pending != spec.num_orders: raise ValueError( f"scenario {spec.name!r}: expected {spec.num_orders} orders, " f"got {total_visible} visible + {total_pending} pending" ) def heuristic_solvable(spec: ScenarioSpec, min_deliveries: int = 1) -> bool: """Lightweight smoke check: does a greedy heuristic deliver ≥ N orders? Filters out unsolvable scenarios where even the easy heuristic times out with zero deliveries — those would just inject noise into training. """ env = DispatchArenaEnvironment(config=spec.to_config()) obs = env.reset(seed=spec.seed) safety_cap = spec.max_ticks * 3 steps = 0 while not obs.done and steps < safety_cap: steps += 1 # Greedy: assign the first idle courier to the first unassigned order. courier = next( (c for c in obs.state.couriers if c.status.value == "idle" and c.load is None), None, ) order = next( ( o for o in obs.state.orders if o.status.value in {"queued", "ready"} and o.assigned_courier_id is None ), None, ) if courier and order: action = {"action_type": "assign", "courier_id": courier.id, "order_id": order.id} else: action = {"action_type": "hold"} obs = env.step(action) delivered = sum(1 for o in obs.state.orders if o.status == OrderStatus.DELIVERED) return delivered >= min_deliveries # --- Anchor (deterministic) builder ------------------------------------------ class AnchorScenarioBuilder: """Sample N variations per anchor with seeded jitter on knobs. Determinism: same `master_seed` => same catalog, every time. """ def __init__(self, master_seed: int = 0) -> None: self._rng = random.Random(master_seed) def build_batch(self, count_per_difficulty: dict) -> List[ScenarioSpec]: accepted: List[ScenarioSpec] = [] used_names: set = set() used_seeds: set = set() for difficulty, count in count_per_difficulty.items(): anchors = anchors_by_difficulty(difficulty) if not anchors: raise ValueError(f"no anchors for difficulty {difficulty!r}") attempts = 0 target = count while sum(1 for s in accepted if s.difficulty == difficulty) < target: attempts += 1 if attempts > target * 10: raise RuntimeError( f"too many failed attempts producing {difficulty} scenarios" ) anchor = self._rng.choice(anchors) spec = self._instantiate(anchor, used_names, used_seeds) if spec is None: continue try: round_trip_validate(spec) except (ValueError, ValidationError) as exc: logger.warning("dropping %s: %s", spec.name, exc) continue if not heuristic_solvable(spec): logger.warning("dropping %s: heuristic delivered 0 orders", spec.name) continue accepted.append(spec) used_names.add(spec.name) used_seeds.add(spec.seed) return accepted def _instantiate( self, anchor: Anchor, used_names: set, used_seeds: set, ) -> Optional[ScenarioSpec]: # Pick a unique seed in [1, 99_999] not used yet. for _ in range(64): seed = self._rng.randint(1, 99_999) if seed not in used_seeds: break else: return None name = f"{anchor.slug}_seed{seed}" if name in used_names: return None max_ticks = self._rng.randint(*anchor.max_ticks_range) num_couriers = self._rng.randint(*anchor.num_couriers_range) num_orders = self._rng.randint(*anchor.num_orders_range) scenario_bucket = self._rng.choice(anchor.scenario_buckets) traffic_lo, traffic_hi = anchor.traffic_noise_range traffic_noise = round( traffic_lo + self._rng.random() * (traffic_hi - traffic_lo), 2 ) try: spec = ScenarioSpec( name=name, difficulty=anchor.difficulty, theme=anchor.theme, description=anchor.description, skill_focus=list(anchor.skill_focus), seed=seed, mode="normal", max_ticks=max_ticks, num_couriers=num_couriers, num_orders=num_orders, scenario_bucket=scenario_bucket, rolling_arrivals=anchor.rolling_arrivals, traffic_noise=traffic_noise, visible_prep=False, expected_failure_modes=list(anchor.expected_failure_modes), success_criteria=anchor.success_criteria, ) except ValidationError as exc: logger.warning("schema rejection for %s: %s", name, exc) return None return spec # --- LLM builder (Claude) ---------------------------------------------------- class LLMScenarioBuilder: """Use Claude to design scenarios. Falls back to anchors on any failure. Uses the anthropic SDK if available + ANTHROPIC_API_KEY is set. The LLM output is validated against the same Pydantic schema and simulator round-trip as the anchor builder, so a misbehaving model can never inject invalid data into the catalog. """ def __init__( self, model: str = "claude-sonnet-4-5", master_seed: int = 0, max_retries_per_scenario: int = 3, ) -> None: self._rng = random.Random(master_seed) self._model = model self._max_retries = max_retries_per_scenario self._client = self._maybe_load_client() self._anchor_fallback = AnchorScenarioBuilder(master_seed=master_seed + 1) @staticmethod def _maybe_load_client(): if not os.environ.get("ANTHROPIC_API_KEY"): logger.info("ANTHROPIC_API_KEY not set; LLM builder will fall back to anchors") return None try: import anthropic # type: ignore except ImportError: logger.info("anthropic SDK not installed; LLM builder will fall back to anchors") return None return anthropic.Anthropic() def build_batch(self, count_per_difficulty: dict) -> List[ScenarioSpec]: if self._client is None: logger.warning( "LLM builder unavailable, delegating entire build to anchor fallback" ) return self._anchor_fallback.build_batch(count_per_difficulty) accepted: List[ScenarioSpec] = [] for difficulty, count in count_per_difficulty.items(): target = count while sum(1 for s in accepted if s.difficulty == difficulty) < target: spec = self._design_one(difficulty, accepted) if spec is None: # Bail to anchor for this slot rather than infinite-loop. fill = self._anchor_fallback.build_batch({difficulty: 1}) accepted.extend(fill) continue accepted.append(spec) return accepted def _design_one( self, difficulty: Difficulty, prior: Sequence[ScenarioSpec], ) -> Optional[ScenarioSpec]: skill_hint = self._pick_skill_hint(difficulty) used_seeds = {s.seed for s in prior} for attempt in range(self._max_retries): seed_lo = attempt * 10_000 + 1 seed_hi = seed_lo + 9_998 user_prompt = render_user_prompt( difficulty=difficulty, seed_lo=seed_lo, seed_hi=seed_hi, prior_specs=list(prior), skill_hint=skill_hint, ) try: raw = self._chat_json(DESIGNER_SYSTEM_PROMPT, user_prompt) except Exception as exc: # API errors, JSON errors, anything logger.warning("LLM call failed (attempt %d): %s", attempt + 1, exc) continue try: spec = ScenarioSpec.model_validate(raw) except ValidationError as exc: logger.warning("schema rejection: %s", exc) continue if spec.seed in used_seeds or any(s.name == spec.name for s in prior): logger.warning("duplicate seed/name from LLM, retrying") continue try: round_trip_validate(spec) except ValueError as exc: logger.warning("simulator rejection: %s", exc) continue if not heuristic_solvable(spec): logger.warning("dropping %s: heuristic delivers 0 orders", spec.name) continue return spec return None def _chat_json(self, system: str, user: str) -> dict: resp = self._client.messages.create( model=self._model, max_tokens=2048, temperature=0.9, system=system, messages=[{"role": "user", "content": user}], ) text = "".join(block.text for block in resp.content if hasattr(block, "text")) # Tolerate accidental markdown fences just in case. text = text.strip() if text.startswith("```"): text = text.split("```", 2)[1] if text.startswith("json"): text = text[4:] return json.loads(text.strip()) def _pick_skill_hint(self, difficulty: Difficulty) -> str: # Bias toward skill tags appropriate for the difficulty. easy = ["prep_uncertainty", "courier_load_balance"] medium = ["rolling_arrivals", "prep_uncertainty", "deadline_pressure", "courier_load_balance"] hard = [ "traffic_noise", "rolling_arrivals", "long_tail_routing", "shifted_distribution", "deadline_pressure", ] pool = {"easy": easy, "medium": medium, "hard": hard}[difficulty] return self._rng.choice(pool) # --- Catalog I/O ------------------------------------------------------------- def save_catalog(specs: Iterable[ScenarioSpec], path) -> None: payload = [s.model_dump(mode="json") for s in specs] with open(path, "w") as f: json.dump(payload, f, indent=2) def load_catalog(path) -> List[ScenarioSpec]: with open(path) as f: payload = json.load(f) return [ScenarioSpec.model_validate(item) for item in payload]