"""GRPO training against Dispatch Arena (normal mode, catalog-driven). End-to-end pipeline: 1. Load `catalog.json` -> stratified 70/30 train/eval split. 2. Spin up the FastAPI server in-process; one DispatchToolEnv per rollout. 3. Each env is configured per-row from the catalog spec (mode=normal, plus rolling_arrivals / traffic_noise / scenario_bucket / etc.). 4. Tool methods exposed to the LLM: view_dashboard, assign, reposition, hold, prioritize, finish_shift 5. Reward function: reward_total (sum of env's per-step RewardBreakdown.total_reward over the rollout). Single function — the env already decomposes the reward; we report the scalar to GRPO. 6. Training: TRL GRPOTrainer + LoRA (r=16). Smoke-friendly defaults for a Tesla T4 (16 GB) with grad checkpointing on. Hyperparameters confirmed by user before run: max_steps=50, num_generations=2, max_completion_length=512, max_tool_calling_iterations=20, LoRA on, beta=0.0, fp16, per_device_train_batch_size=2, lr=1e-5. """ from __future__ import annotations import os import time from pathlib import Path from typing import Any, Dict, List, Optional # Silence the experimental-feature warning before importing TRL. os.environ.setdefault("TRL_EXPERIMENTAL_SILENCE", "1") # NOTE: Qwen2.5-Instruct required a chat-template patch (TRL 1.2.0 ships the # template but never wired it into add_response_schema). We dropped that path # in favor of Qwen3-0.6B, which TRL recognizes natively — simpler stack, no # monkey-patching, and the original smoke script used this model successfully. # The git history of train_grpo_smoke.py shows the patch logic if it's needed # again later. 
import torch from peft import LoraConfig from trl import GRPOConfig, GRPOTrainer from dispatch_arena.catalog.dataset import load_catalog_datasets from dispatch_arena.client import DispatchArenaClient from dispatch_arena.server.app import run_local_server_in_thread # --------------------------------------------------------------------------- # Model / paths # --------------------------------------------------------------------------- MODEL_NAME = "Qwen/Qwen3-1.7B" # Natively recognized by TRL 1.2.0 (no patch needed). Pre-flight verified: tools render into prompt, envelope identical to 0.6B, ~6.7 GB total VRAM on T4. CATALOG_PATH = Path(__file__).resolve().parents[1] / "catalog" / "catalog.json" OUTPUT_DIR = Path(__file__).resolve().parents[1] / "scripts" / "_grpo_normal_out" # --------------------------------------------------------------------------- # System prompt for normal-mode dispatcher # --------------------------------------------------------------------------- SYSTEM_PROMPT = """You are a real-time delivery dispatcher running one shift over a small fleet of couriers. Your job is to dispatch each order to the right courier and keep the shift moving so orders are delivered before their deadlines. # Tool calling Always reply with EXACTLY ONE tool call per turn, in this format (no other text): {"name": "", "arguments": {}} # Available tools ```json [ { "name": "view_dashboard", "description": "Refresh the dashboard. Returns courier statuses, order list, deadlines, and travel times.", "parameters": {"type": "object", "properties": {}, "required": []} }, { "name": "assign", "description": "Dispatch an idle courier to an unassigned order whose status is queued or ready.", "parameters": { "type": "object", "properties": { "courier_id": {"type": "string", "description": "e.g. courier_0, courier_1, ..."}, "order_id": {"type": "string", "description": "e.g. 
order_0, order_1, ..."} }, "required": ["courier_id", "order_id"] } }, { "name": "reposition", "description": "Pre-stage an idle courier near a busy store or upcoming dropoff.", "parameters": { "type": "object", "properties": { "courier_id": {"type": "string"}, "node_id": {"type": "string", "description": "hub, store_0..3, or customer_0..N"} }, "required": ["courier_id", "node_id"] } }, { "name": "hold", "description": "Wait one tick. Use when prep is not done and no good action exists.", "parameters": {"type": "object", "properties": {}, "required": []} }, { "name": "prioritize", "description": "Mark an order as priority. Safe even if not yet assigned.", "parameters": { "type": "object", "properties": {"order_id": {"type": "string"}}, "required": ["order_id"] } }, { "name": "finish_shift", "description": "End the shift early once all visible orders are delivered.", "parameters": {"type": "object", "properties": {}, "required": []} } ] ``` # Examples Refresh the dashboard: {"name": "view_dashboard", "arguments": {}} Dispatch courier_0 to order_1: {"name": "assign", "arguments": {"courier_id": "courier_0", "order_id": "order_1"}} # Rules - Prep time is hidden. Queued orders flip to "ready" when prep completes; the courier you dispatch may have to wait briefly at the store. - Travel times shown are BASE estimates. With traffic noise, real ETAs can run longer. - The shift ends automatically at max_ticks. Maximize on-time deliveries. - One tool per turn. Output the tool call in the format above and nothing else.""" USER_KICKOFF = "Begin the shift. Call view_dashboard first to see the state, then dispatch." 
# ---------------------------------------------------------------------------
# Server boot + helpers
# ---------------------------------------------------------------------------


def _start_shared_server() -> str:
    """Launch the in-process Dispatch Arena server and return its base URL.

    Port 0 is requested, and the address actually in use is read back from
    ``server.server_address``.

    Returns:
        The ``http://host:port`` URL that clients should connect to.
    """
    server, _worker = run_local_server_in_thread(port=0, max_concurrent_envs=64)
    bound_host, bound_port = server.server_address
    base_url = f"http://{bound_host}:{bound_port}"
    time.sleep(0.2)  # let uvicorn bind
    return base_url


# Started once, at import time; the rest of the module reads this URL.
SERVER_URL = _start_shared_server()


def _render_dashboard(obs) -> str:
    """Render an observation as a compact plain-text dashboard for the LLM.

    Sections, in order: a header line (tick, verdict, backlog, SLA pressure),
    one line per courier, one line per order, one travel-time line per node
    listing every destination except the node itself, the last four events
    (only when any are present), and a trailing ``DONE`` marker once the
    observation reports the episode as finished.

    The design target is roughly 300 tokens, leaving the agent room for its
    tool-call output.

    Args:
        obs: Observation exposing ``state``, ``verifier_status``, ``info``
            and ``done``.

    Returns:
        The dashboard as a single newline-joined string.
    """
    world = obs.state
    header = (
        f"tick={world.tick}/{world.max_ticks} verdict={obs.verifier_status.value} "
        f"backlog={world.backlog} sla_pressure={world.sla_pressure:.2f}"
    )
    lines = [header, "couriers:"]

    # Courier section: position, status, optional destination + ETA, cargo.
    for courier in world.couriers:
        cargo = courier.load or "none"
        if courier.target_node_id:
            heading = f" -> {courier.target_node_id}(eta {courier.eta_remaining})"
        else:
            heading = ""
        lines.append(
            f" {courier.id} @ {courier.node_id} {courier.status.value}{heading} carrying={cargo}"
        )

    # Order section: route, status, deadline and who (if anyone) has it.
    lines.append("orders:")
    for order in world.orders:
        owner = order.assigned_courier_id or "-"
        route = f" {order.id} {order.kind} {order.pickup_node_id}->{order.dropoff_node_id} "
        detail = f"status={order.status.value} deadline=t{order.deadline_tick} assigned={owner}"
        lines.append(route + detail)

    # Travel-time section: every destination is listed (no top-K truncation),
    # skipping only the zero-length self edge.
    lines.append("travel_times (base, may run longer with traffic):")
    for node in world.nodes:
        origin = node.id
        outgoing = world.travel_time_matrix.get(origin, {})
        pairs = [f"{dest}={cost}" for dest, cost in outgoing.items() if dest != origin]
        lines.append(f" {origin}: " + ", ".join(pairs))

    # Only the four most recent events are shown, to keep the prompt short.
    if obs.info.get("events"):
        recent = obs.info["events"][-4:]
        lines.append("last_events: " + " | ".join(recent))

    if obs.done:
        lines.append("DONE")

    return "\n".join(lines)
# ---------------------------------------------------------------------------
# Tool-calling environment (one per rollout via environment_factory)
# ---------------------------------------------------------------------------


def _new_metrics(
    step_total: float = 0.0,
    ticks: int = 0,
    verdict: str = "in_progress",
) -> Dict[str, Any]:
    """Build a fresh per-rollout metrics dict with counters zeroed."""
    return {
        "step_total": step_total,
        "invalid_count": 0,
        "delivered": 0,
        "ticks": ticks,
        "verdict": verdict,
        "rollout_done": False,
    }


class DispatchToolEnv:
    """One normal-mode dispatch episode, driven through six LLM tools.

    The trainer creates one instance per generation. The public methods
    (view_dashboard, assign, reposition, hold, prioritize, finish_shift) are
    the tools the model may call; each forwards one action to the shared
    server through ``self.client``. ``self.metrics`` accumulates the rollout
    statistics that the reward function reads once the rollout is over.

    NOTE(review): the tool docstrings and signatures are presumably parsed
    into the tool schema shown to the model — keep them verbatim unless a
    prompt change is intended.
    """

    def __init__(self) -> None:
        self.client = DispatchArenaClient(base_url=SERVER_URL, timeout_seconds=30)
        self.metrics: Dict[str, Any] = _new_metrics()

    # The trainer passes dataset row fields here (seed + _config + ...).
    # Extra keyword arguments (_difficulty / _skill_focus / _name) are
    # swallowed so they never leak into env state.
    def reset(
        self,
        seed: Optional[int] = None,
        _config: Optional[Dict[str, Any]] = None,
        **_: Any,
    ) -> str:
        """Start a new episode and return the initial dashboard text.

        Args:
            seed: Episode seed; 0 is used when it is None.
            _config: Env configuration for this row; a small normal-mode
                default is used when it is missing or empty.

        Returns:
            The rendered initial dashboard, prefixed with a header line.
        """
        if seed is None:
            episode_seed = 0
        else:
            episode_seed = int(seed)

        if _config:
            episode_config = _config
        else:
            episode_config = {
                "mode": "normal",
                "max_ticks": 16,
                "num_couriers": 3,
                "num_orders": 5,
            }

        first_obs = self.client.reset(seed=episode_seed, config=episode_config)
        self.metrics = _new_metrics(
            float(first_obs.reward),
            int(first_obs.state.tick),
            first_obs.verifier_status.value,
        )
        return "Initial dashboard:\n" + _render_dashboard(first_obs)

    def _step(self, action: Dict[str, Any]) -> str:
        """Send one action to the env, update metrics, return the dashboard.

        Once the rollout is flagged as done, no request is sent and a fixed
        notice is returned instead.
        """
        tally = self.metrics
        if tally.get("rollout_done"):
            return "Shift already finished — call finish_shift to stop or stop calling tools."

        result = self.client.step(action)

        tally["step_total"] += float(result.reward)
        tally["ticks"] = int(result.state.tick)
        tally["verdict"] = result.verifier_status.value
        delivered_orders = [
            order for order in result.state.orders if order.status.value == "delivered"
        ]
        tally["delivered"] = len(delivered_orders)

        if result.info.get("invalid_action"):
            tally["invalid_count"] += 1
        if result.done:
            tally["rollout_done"] = True

        return _render_dashboard(result)

    # ---- Tools (each is exposed to the LLM as a callable) -----------------

    def view_dashboard(self) -> str:
        """Refresh the dashboard with the latest courier/order state."""
        # The env has no separate "look" action, so a refresh is sent as a
        # hold — which means it also advances the shift by one tick.
        look_action = {"action_type": "hold"}
        return self._step(look_action)

    def assign(self, courier_id: str, order_id: str) -> str:
        """Dispatch a courier to an order. Both must be valid + free.

        Args:
            courier_id: e.g. "courier_0".
            order_id: e.g. "order_3".
        """
        assign_action = dict(
            action_type="assign",
            courier_id=courier_id,
            order_id=order_id,
        )
        return self._step(assign_action)

    def reposition(self, courier_id: str, node_id: str) -> str:
        """Move an idle courier to a node to pre-stage near a busy store.

        Args:
            courier_id: e.g. "courier_1".
            node_id: e.g. "store_0", "hub", "customer_2".
        """
        move_action = dict(
            action_type="reposition",
            courier_id=courier_id,
            node_id=node_id,
        )
        return self._step(move_action)

    def hold(self) -> str:
        """Wait one tick. Use when prep is unfinished and no good move exists."""
        wait_action = {"action_type": "hold"}
        return self._step(wait_action)

    def prioritize(self, order_id: str) -> str:
        """Signal that an order is priority. Safe even if not assigned.

        Args:
            order_id: e.g. "order_2".
        """
        flag_action = dict(action_type="prioritize", order_id=order_id)
        return self._step(flag_action)

    def finish_shift(self) -> str:
        """End the shift early. Returns the final summary."""
        # If the episode is still running, one last hold advances the tick so
        # the env can finalize. The rollout flag is then set unconditionally,
        # which turns every later tool call into the "already finished" notice.
        still_running = not self.metrics.get("rollout_done")
        if still_running:
            self._step({"action_type": "hold"})
        self.metrics["rollout_done"] = True

        final = self.metrics
        return (
            f"Shift finished. tick={final['ticks']} delivered={final['delivered']} "
            f"verdict={final['verdict']} reward={final['step_total']:.2f}"
        )


# ---------------------------------------------------------------------------
# Reward functions
# ---------------------------------------------------------------------------


def reward_total(environments: List[DispatchToolEnv], **_: Any) -> List[float]:
    """Return each rollout's accumulated env reward as a float.

    The value is the running sum of per-step rewards kept in
    ``env.metrics["step_total"]`` (0.0 when the key is absent). The env's
    scalar already folds in step_cost, progress, success, invalid_penalty,
    on-time bonus, late penalty, timeout penalty, idle penalty, churn and
    fairness, so extra reward functions would double-count those components.

    Args:
        environments: The environments of the finished rollouts, in order.

    Returns:
        One reward per environment, in the same order.
    """
    rewards: List[float] = []
    for rollout_env in environments:
        rewards.append(float(rollout_env.metrics.get("step_total", 0.0)))
    return rewards


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main() -> None:
    """Train the LoRA adapter with GRPO and save it under OUTPUT_DIR.

    Loads the catalog split, builds the LoRA and GRPO configs, trains with
    ``reward_total`` as the only reward, prints the training metrics and
    writes the final adapter to ``OUTPUT_DIR / "final_lora"``.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available; this trainer requires a GPU.")

    # The two spec lists come back from the loader but are not used here.
    train_ds, eval_ds, _train_specs, _eval_specs = load_catalog_datasets(
        catalog_path=CATALOG_PATH,
        system_prompt=SYSTEM_PROMPT,
        eval_fraction=0.30,
        master_seed=0,
    )
    n_train, n_eval = len(train_ds), len(eval_ds)
    print(f"Catalog loaded: train={n_train} eval={n_eval}")

    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp_projections = ["gate_proj", "up_proj", "down_proj"]
    adapter_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=attention_projections + mlp_projections,
    )

    batching = {
        # OOM fix: only 1 prompt per micro-batch.
        "per_device_train_batch_size": 1,
        # generation_batch_size = 1*4*1 = 4, divisible by G=4.
        "gradient_accumulation_steps": 4,
        # Bumped from 2 for better advantage variance.
        "num_generations": 4,
        # OOM fix: was 512; less KV cache + caps rambling earlier.
        "max_completion_length": 384,
        "max_tool_calling_iterations": 20,
    }
    optimisation = {
        "learning_rate": 1e-5,
        # Longer horizon for the policy to actually move.
        "max_steps": 80,
        # No KL -> skips the reference model.
        "beta": 0.0,
    }
    reporting = {
        "log_completions": True,
        "report_to": ["tensorboard"],
        "logging_dir": str(OUTPUT_DIR / "tb"),
        # Smoke-friendly: no checkpoints to disk during the run.
        "save_strategy": "no",
        # Post-training eval is a separate script.
        "eval_strategy": "no",
        "logging_steps": 1,
    }
    precision = {
        "bf16": False,
        "fp16": True,
        "gradient_checkpointing": True,
        "model_init_kwargs": {"torch_dtype": "float16"},
    }
    grpo_config = GRPOConfig(
        output_dir=str(OUTPUT_DIR),
        **batching,
        **optimisation,
        **reporting,
        **precision,
    )

    trainer = GRPOTrainer(
        model=MODEL_NAME,
        reward_funcs=[reward_total],
        args=grpo_config,
        train_dataset=train_ds,
        environment_factory=DispatchToolEnv,
        peft_config=adapter_config,
    )

    print("Starting training...")
    outcome = trainer.train()
    print("\n=== TRAIN DONE ===")
    print("metrics:", outcome.metrics)

    # save_strategy="no" suppresses mid-run checkpoints, so the final adapter
    # is written explicitly here for the eval script to load. With peft in
    # use, save_model() writes the adapter only — base weights are not
    # duplicated.
    adapter_dir = OUTPUT_DIR / "final_lora"
    trainer.save_model(str(adapter_dir))
    print(f"LoRA adapter saved -> {adapter_dir}")


if __name__ == "__main__":
    main()