# File size: 15,911 bytes (commit c71bf62)
"""GRPO training against Dispatch Arena (normal mode, catalog-driven).

End-to-end pipeline:
  1. Load `catalog.json` -> stratified 70/30 train/eval split.
  2. Spin up the FastAPI server in-process; one DispatchToolEnv per rollout.
  3. Each env is configured per-row from the catalog spec (mode=normal,
     plus rolling_arrivals / traffic_noise / scenario_bucket / etc.).
  4. Tool methods exposed to the LLM:
        view_dashboard, assign, reposition, hold, prioritize, finish_shift
  5. Reward function: reward_total (sum of the env's per-step
     RewardBreakdown.total_reward over the rollout). A single function is
     enough: the env already decomposes the reward, so we report the scalar
     to GRPO.
  6. Training: TRL GRPOTrainer + LoRA (r=16). Smoke-friendly defaults for
     a Tesla T4 (16 GB) with grad checkpointing on.

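Hyperparameters as set in main(), revised from the originally confirmed smoke
run (max_steps=50, num_generations=2, max_completion_length=512,
per_device_train_batch_size=2) to avoid OOM and improve advantage variance:
  max_steps=80, num_generations=4, max_completion_length=384,
  max_tool_calling_iterations=20, LoRA on, beta=0.0, fp16,
  per_device_train_batch_size=1, gradient_accumulation_steps=4, lr=1e-5.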
"""

from __future__ import annotations

import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

# Silence the experimental-feature warning before importing TRL.
os.environ.setdefault("TRL_EXPERIMENTAL_SILENCE", "1")


# NOTE: Qwen2.5-Instruct required a chat-template patch (TRL 1.2.0 ships the
# template but never wired it into add_response_schema). We dropped that path
# in favor of the Qwen3 family, which TRL recognizes natively: a simpler stack
# with no monkey-patching. The original smoke script used Qwen3-0.6B
# successfully; this script uses Qwen3-1.7B (see MODEL_NAME). The git history
# of train_grpo_smoke.py shows the patch logic if it is needed again later.

import torch
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

from dispatch_arena.catalog.dataset import load_catalog_datasets
from dispatch_arena.client import DispatchArenaClient
from dispatch_arena.server.app import run_local_server_in_thread

# ---------------------------------------------------------------------------
# Model / paths
# ---------------------------------------------------------------------------

# Natively recognized by TRL 1.2.0 (no chat-template patch needed).
# Pre-flight verified: tools render into the prompt, the envelope is identical
# to Qwen3-0.6B's, and total VRAM use is ~6.7 GB on a T4.
MODEL_NAME = "Qwen/Qwen3-1.7B"
CATALOG_PATH = Path(__file__).resolve().parents[1] / "catalog" / "catalog.json"
OUTPUT_DIR = Path(__file__).resolve().parents[1] / "scripts" / "_grpo_normal_out"
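

# Sketch (hypothetical helper, not used by this script): one way a stratified
# 70/30 split like the one load_catalog_datasets() performs *might* work. The
# grouping key ("bucket") and this helper are assumptions; the real catalog
# loader (dispatch_arena, not shown) may stratify on another field or round
# group sizes differently. Splitting per group keeps every scenario bucket
# represented in both train and eval, as long as each group is large enough.
def _sketch_stratified_split(specs: list, key: str = "bucket",
                             eval_fraction: float = 0.30, seed: int = 0):
    """Split `specs` (a list of dicts) so each `key` group is split ~70/30."""
    import random
    from collections import defaultdict

    rng = random.Random(seed)  # local RNG: deterministic for a given seed
    groups = defaultdict(list)
    for spec in specs:
        groups[spec[key]].append(spec)

    train, evaluation = [], []
    for name in sorted(groups):  # sorted so iteration order is reproducible
        members = list(groups[name])
        rng.shuffle(members)
        n_eval = round(len(members) * eval_fraction)
        evaluation.extend(members[:n_eval])
        train.extend(members[n_eval:])
    return train, evaluation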

# ---------------------------------------------------------------------------
# System prompt for normal-mode dispatcher
# ---------------------------------------------------------------------------

SYSTEM_PROMPT = """You are a real-time delivery dispatcher running one shift over a small fleet of couriers. Your job is to dispatch each order to the right courier and keep the shift moving so orders are delivered before their deadlines.

# Tool calling

Always reply with EXACTLY ONE tool call per turn, in this format (no other text):
<tool_call>
{"name": "<tool_name>", "arguments": {<args-json>}}
</tool_call>

# Available tools

```json
[
  {
    "name": "view_dashboard",
    "description": "Refresh the dashboard. Returns courier statuses, order list, deadlines, and travel times.",
    "parameters": {"type": "object", "properties": {}, "required": []}
  },
  {
    "name": "assign",
    "description": "Dispatch an idle courier to an unassigned order whose status is queued or ready.",
    "parameters": {
      "type": "object",
      "properties": {
        "courier_id": {"type": "string", "description": "e.g. courier_0, courier_1, ..."},
        "order_id":   {"type": "string", "description": "e.g. order_0, order_1, ..."}
      },
      "required": ["courier_id", "order_id"]
    }
  },
  {
    "name": "reposition",
    "description": "Pre-stage an idle courier near a busy store or upcoming dropoff.",
    "parameters": {
      "type": "object",
      "properties": {
        "courier_id": {"type": "string"},
        "node_id":    {"type": "string", "description": "hub, store_0..3, or customer_0..N"}
      },
      "required": ["courier_id", "node_id"]
    }
  },
  {
    "name": "hold",
    "description": "Wait one tick. Use when prep is not done and no good action exists.",
    "parameters": {"type": "object", "properties": {}, "required": []}
  },
  {
    "name": "prioritize",
    "description": "Mark an order as priority. Safe even if not yet assigned.",
    "parameters": {
      "type": "object",
      "properties": {"order_id": {"type": "string"}},
      "required": ["order_id"]
    }
  },
  {
    "name": "finish_shift",
    "description": "End the shift early once all visible orders are delivered.",
    "parameters": {"type": "object", "properties": {}, "required": []}
  }
]
```

# Examples

Refresh the dashboard:
<tool_call>
{"name": "view_dashboard", "arguments": {}}
</tool_call>

Dispatch courier_0 to order_1:
<tool_call>
{"name": "assign", "arguments": {"courier_id": "courier_0", "order_id": "order_1"}}
</tool_call>

# Rules

- Prep time is hidden. Queued orders flip to "ready" when prep completes; the courier you dispatch may have to wait briefly at the store.
- Travel times shown are BASE estimates. With traffic noise, real ETAs can run longer.
- The shift ends automatically at max_ticks. Maximize on-time deliveries.
- One tool per turn. Output the tool call in the format above and nothing else."""

USER_KICKOFF = "Begin the shift. Call view_dashboard first to see the state, then dispatch."
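

# Sketch (hypothetical helper, not used by the trainer): parse the
# <tool_call>{...}</tool_call> envelope that SYSTEM_PROMPT asks the model to
# emit. TRL handles parsing itself during training; this is only a way to
# sanity-check logged completions offline, assuming one JSON object per tag.
# Returning every call (not just the first) makes "one tool per turn"
# violations easy to spot.
def _sketch_parse_tool_calls(text: str) -> list:
    """Return (name, arguments) pairs for each well-formed tool call."""
    import json
    import re

    calls = []
    pattern = r"<tool_call>\s*(.*?)\s*</tool_call>"
    for body in re.findall(pattern, text, flags=re.DOTALL):
        try:
            payload = json.loads(body)
        except json.JSONDecodeError:
            continue  # malformed JSON: skip it rather than raise
        if isinstance(payload, dict) and "name" in payload:
            calls.append((payload["name"], payload.get("arguments", {})))
    return calls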


# ---------------------------------------------------------------------------
# Server boot + helpers
# ---------------------------------------------------------------------------


def _start_shared_server() -> str:
    server, _thread = run_local_server_in_thread(port=0, max_concurrent_envs=64)
    host, port = server.server_address
    time.sleep(0.2)  # let uvicorn bind
    return f"http://{host}:{port}"


SERVER_URL = _start_shared_server()
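

# Sketch (standard library only; it does not reflect how dispatch_arena's
# run_local_server_in_thread() is implemented): what port=0 means. Binding to
# port 0 asks the OS for any free ephemeral port, and the real port is only
# known after binding, which is why SERVER_URL is built from
# server.server_address rather than from the requested port. Note the port is
# released again when this socket closes, so it is not reserved afterwards.
def _sketch_ephemeral_port(host: str = "127.0.0.1") -> int:
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))  # port 0: let the OS pick a free port
        return sock.getsockname()[1]  # read back the OS-assigned port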


def _render_dashboard(obs) -> str:
    """Compact textual dashboard rendered for the LLM.

    Lists couriers, orders (with deadline and status), the base travel-time
    matrix (every destination for every node), and the last few events.
    Intended to stay around 300 tokens on small maps so the agent has room
    for its tool-call output; note that the matrix grows quadratically with
    the number of nodes.
    """
    state = obs.state
    parts = [
        f"tick={state.tick}/{state.max_ticks}  verdict={obs.verifier_status.value}  "
        f"backlog={state.backlog}  sla_pressure={state.sla_pressure:.2f}",
    ]

    # Couriers
    parts.append("couriers:")
    for c in state.couriers:
        load = c.load or "none"
        target = f" -> {c.target_node_id}(eta {c.eta_remaining})" if c.target_node_id else ""
        parts.append(f"  {c.id} @ {c.node_id} {c.status.value}{target} carrying={load}")

    # Orders
    parts.append("orders:")
    for o in state.orders:
        assigned = o.assigned_courier_id or "-"
        parts.append(
            f"  {o.id} {o.kind} {o.pickup_node_id}->{o.dropoff_node_id} "
            f"status={o.status.value} deadline=t{o.deadline_tick} assigned={assigned}"
        )

    # Travel times: one line per source node listing every other destination.
    parts.append("travel_times (base, may run longer with traffic):")
    for node in state.nodes:
        src = node.id
        row = state.travel_time_matrix.get(src, {})
        edges = ", ".join(f"{dst}={t}" for dst, t in row.items() if dst != src)
        parts.append(f"  {src}: {edges}")

    # Last events
    if obs.info.get("events"):
        parts.append("last_events: " + " | ".join(obs.info["events"][-4:]))

    if obs.done:
        parts.append("DONE")
    return "\n".join(parts)
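

# Sketch (hypothetical alternative, not used by _render_dashboard): if the full
# travel matrix pushes the dashboard past its token budget, each row could be
# trimmed to the K nearest destinations. `row` is assumed to mirror one entry
# of state.travel_time_matrix as a plain dict (destination id -> base ticks).
def _sketch_top_k_edges(row: dict, src: str, k: int = 3) -> str:
    nearest = sorted(
        ((dst, t) for dst, t in row.items() if dst != src),
        key=lambda item: (item[1], item[0]),  # by travel time, ties by node id
    )[:k]
    return ", ".join(f"{dst}={t}" for dst, t in nearest)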


# ---------------------------------------------------------------------------
# Tool-calling environment (one per rollout via environment_factory)
# ---------------------------------------------------------------------------


class DispatchToolEnv:
    """Normal-mode dispatcher wrapper exposing 6 tools to the LLM.

    The TRL trainer instantiates one DispatchToolEnv per generation. Public
    methods become the LLM's callable tools (per TRL OpenEnv integration).

    `metrics` is read by reward_total after the rollout finishes.
    """

    def __init__(self) -> None:
        self.client = DispatchArenaClient(base_url=SERVER_URL, timeout_seconds=30)
        self.metrics: Dict[str, Any] = {
            "step_total": 0.0,
            "invalid_count": 0,
            "delivered": 0,
            "ticks": 0,
            "verdict": "in_progress",
            "rollout_done": False,
        }

    # The trainer passes dataset row fields here (seed, _config, ...).
    # The catch-all `**_` swallows _difficulty / _skill_focus / _name so they
    # do not leak into env state.
    def reset(
        self,
        seed: Optional[int] = None,
        _config: Optional[Dict[str, Any]] = None,
        **_: Any,
    ) -> str:
        seed_int = int(seed) if seed is not None else 0
        config = _config or {"mode": "normal", "max_ticks": 16, "num_couriers": 3, "num_orders": 5}
        obs = self.client.reset(seed=seed_int, config=config)
        self.metrics = {
            "step_total": float(obs.reward),
            "invalid_count": 0,
            "delivered": 0,
            "ticks": int(obs.state.tick),
            "verdict": obs.verifier_status.value,
            "rollout_done": False,
        }
        return "Initial dashboard:\n" + _render_dashboard(obs)

    def _step(self, action: Dict[str, Any]) -> str:
        if self.metrics.get("rollout_done"):
            return "Shift already finished. Call finish_shift or stop calling tools."
        obs = self.client.step(action)
        self.metrics["step_total"] += float(obs.reward)
        self.metrics["ticks"] = int(obs.state.tick)
        self.metrics["verdict"] = obs.verifier_status.value
        self.metrics["delivered"] = sum(
            1 for o in obs.state.orders if o.status.value == "delivered"
        )
        if obs.info.get("invalid_action"):
            self.metrics["invalid_count"] += 1
        if obs.done:
            self.metrics["rollout_done"] = True
        return _render_dashboard(obs)

    # ---- Tools (each is exposed to the LLM as a callable) -----------------

    def view_dashboard(self) -> str:
        """Refresh the dashboard with the latest courier/order state."""
        # Semantically a no-op, but the env does not separate "look" from
        # "act", so this issues a hold, which advances the clock by one tick.
        return self._step({"action_type": "hold"})

    def assign(self, courier_id: str, order_id: str) -> str:
        """Dispatch a courier to an order. Both must be valid + free.

        Args:
            courier_id: e.g. "courier_0".
            order_id:   e.g. "order_3".
        """
        return self._step(
            {"action_type": "assign", "courier_id": courier_id, "order_id": order_id}
        )

    def reposition(self, courier_id: str, node_id: str) -> str:
        """Move an idle courier to a node to pre-stage near a busy store.

        Args:
            courier_id: e.g. "courier_1".
            node_id:    e.g. "store_0", "hub", "customer_2".
        """
        return self._step(
            {"action_type": "reposition", "courier_id": courier_id, "node_id": node_id}
        )

    def hold(self) -> str:
        """Wait one tick. Use when prep is unfinished and no good move exists."""
        return self._step({"action_type": "hold"})

    def prioritize(self, order_id: str) -> str:
        """Signal that an order is priority. Safe even if not assigned.

        Args:
            order_id: e.g. "order_2".
        """
        return self._step({"action_type": "prioritize", "order_id": order_id})

    def finish_shift(self) -> str:
        """End the shift early. Returns the final summary."""
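        # Mark the rollout done so any further tool calls short-circuit in
        # _step. TRL does not read this flag itself; the tool-calling loop
        # ends when the model stops emitting tool calls or
        # max_tool_calling_iterations is reached. A hold is issued first to
        # advance the tick so the env can finalize.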
        if not self.metrics.get("rollout_done"):
            self._step({"action_type": "hold"})
        self.metrics["rollout_done"] = True
        return (
            f"Shift finished. tick={self.metrics['ticks']} delivered={self.metrics['delivered']} "
            f"verdict={self.metrics['verdict']} reward={self.metrics['step_total']:.2f}"
        )


# ---------------------------------------------------------------------------
# Reward functions
# ---------------------------------------------------------------------------


def reward_total(environments: List[DispatchToolEnv], **_: Any) -> List[float]:
    """Sum of the env's per-step RewardBreakdown.total_reward across the rollout.

    The scalar already includes step_cost, progress, success, invalid_penalty,
    on-time bonus, late penalty, timeout penalty, idle penalty, churn, and
    fairness, so additional reward functions would double-count components
    already inside it.
    """
    return [float(env.metrics.get("step_total", 0.0)) for env in environments]
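

# Sketch of the group-relative advantage GRPO derives from reward_total: each
# rollout's scalar is compared with the other num_generations rollouts of the
# same prompt. Simplified under assumptions: population std and a fixed eps;
# TRL's exact normalization (for example, whether it scales by std at all)
# may differ. A group whose rollouts all score the same gets zero advantage
# and contributes no learning signal, which a larger group makes less likely.
def _sketch_group_advantages(rewards: list, num_generations: int,
                             eps: float = 1e-4) -> list:
    import statistics

    if len(rewards) % num_generations:
        raise ValueError("len(rewards) must be a multiple of num_generations")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]  # one prompt's rollouts
        mean = statistics.fmean(group)
        std = statistics.pstdev(group)
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages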


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main() -> None:
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available; this trainer requires a GPU.")

    train_ds, eval_ds, train_specs, eval_specs = load_catalog_datasets(
        catalog_path=CATALOG_PATH,
        system_prompt=SYSTEM_PROMPT,
        eval_fraction=0.30,
        master_seed=0,
    )
    print(f"Catalog loaded: train={len(train_ds)} eval={len(eval_ds)}")

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
    )

    config = GRPOConfig(
        output_dir=str(OUTPUT_DIR),
        per_device_train_batch_size=1,    # OOM fix: one completion per micro-batch
        gradient_accumulation_steps=4,    # generation_batch_size = 1 * 4 * 1 GPU = 4, divisible by G=4
        num_generations=4,                # bumped from 2 for better advantage variance
        max_completion_length=384,        # OOM fix: was 512; smaller KV cache, caps rambling earlier
        max_tool_calling_iterations=20,
        learning_rate=1e-5,
        max_steps=80,                     # longer horizon so the policy can actually move
        beta=0.0,                         # no KL term -> skips the reference model
        log_completions=True,
        report_to=["tensorboard"],
        logging_dir=str(OUTPUT_DIR / "tb"),
        save_strategy="no",               # smoke-friendly; no mid-run checkpoints
        eval_strategy="no",               # post-training eval is a separate script
        logging_steps=1,
        bf16=False,
        fp16=True,
        gradient_checkpointing=True,
        model_init_kwargs={"torch_dtype": "float16"},
    )

    trainer = GRPOTrainer(
        model=MODEL_NAME,
        reward_funcs=[reward_total],
        args=config,
        train_dataset=train_ds,
        environment_factory=DispatchToolEnv,
        peft_config=lora_config,
    )

    print("Starting training...")
    train_output = trainer.train()
    print("\n=== TRAIN DONE ===")
    print("metrics:", train_output.metrics)

    # Persist the trained LoRA adapter so eval can load it later.
    # GRPOConfig uses save_strategy="no" to avoid mid-run checkpoints, so the
    # final state is saved explicitly here. With PEFT in use, save_model()
    # writes only the adapter; the base model weights are not duplicated.
    final_dir = OUTPUT_DIR / "final_lora"
    trainer.save_model(str(final_dir))
    print(f"LoRA adapter saved -> {final_dir}")
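

# Sketch (hypothetical helper, not called by main()): the batch arithmetic
# behind the GRPOConfig comments. Assumption: with steps_per_generation left
# at its default, completions per generation batch =
# per_device_train_batch_size * gradient_accumulation_steps * num_processes,
# and that total must be a multiple of num_generations so every prompt gets a
# full group. TRL is expected to validate this itself; this only reproduces
# the arithmetic. With main()'s values (1, 4, G=4) each generation batch
# holds 4 completions of a single prompt.
def _sketch_check_generation_batch(per_device_train_batch_size: int,
                                   gradient_accumulation_steps: int,
                                   num_generations: int,
                                   num_processes: int = 1) -> int:
    generation_batch_size = (
        per_device_train_batch_size * gradient_accumulation_steps * num_processes
    )
    if generation_batch_size % num_generations:
        raise ValueError(
            f"generation_batch_size={generation_batch_size} is not divisible "
            f"by num_generations={num_generations}"
        )
    return generation_batch_size // num_generations  # unique prompts per batch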


if __name__ == "__main__":
    main()