SwapnilPatil28 commited on
Commit
58af620
·
verified ·
1 Parent(s): a403b80

Add LLM policy, SFT saving & LLM evaluation

Browse files
Files changed (4) hide show
  1. README.md +26 -4
  2. inference.py +38 -9
  3. llm_policy.py +184 -0
  4. train_trl.py +189 -38
README.md CHANGED
@@ -236,9 +236,9 @@ Expected output: **21 passing** (domain rubric, incident catalog, environment in
236
  [`train_trl.py`](./train_trl.py) orchestrates the end-to-end training & evaluation pipeline:
237
 
238
  1. **Rollout** — the `HeuristicCoordinator` drives the live environment to collect `(prompt, completion)` pairs. Prompts include customer tier, revenue impact, visible signals and investigation targets; completions are structured JSON actions.
239
- 2. **SFT** — the dataset is collapsed into a single `text` column (robust across TRL ≥ 0.20) and fed to `SFTTrainer`.
240
- 3. **Evaluation** — the trained model is not yet wired as the acting policy (to stay CPU-friendly), but heuristic vs random are evaluated under identical seeds so the judges can see an observable gap.
241
- 4. **Artifacts** — `artifacts/reward_curve.png` and `artifacts/summary_metrics.json` are written.
242
 
243
  ### Local run (small model)
244
 
@@ -274,7 +274,29 @@ Environment variables you can tune before running `train_trl.py`:
274
  | `TRAIN_EPOCHS` | `1` | SFT epochs |
275
  | `TRAIN_MAX_LENGTH` | `768` | Max sequence length |
276
  | `TRAIN_BATCH_SIZE` / `TRAIN_GRAD_ACCUM` | `1` / `2` | Effective batch size |
277
- | `MAX_ROLLOUT_STEPS` | `120` | Safety cap per episode |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
  ---
280
 
 
236
  [`train_trl.py`](./train_trl.py) orchestrates the end-to-end training & evaluation pipeline:
237
 
238
  1. **Rollout** — the `HeuristicCoordinator` drives the live environment to collect `(prompt, completion)` pairs. Prompts include customer tier, revenue impact, visible signals and investigation targets; completions are structured JSON actions.
239
+ 2. **SFT** — the dataset is collapsed into a single `text` column (robust across TRL ≥ 0.20) and fed to `SFTTrainer`. The fine-tuned weights + tokenizer are saved to `artifacts/sft_model/`.
240
+ 3. **Evaluation** — four policies are rolled out under identical seeds: `random`, `heuristic`, `base_model` (raw `BASE_MODEL` HF checkpoint), and `sft_model` (the fine-tuned checkpoint just saved). LLM evaluation auto-enables on a CUDA GPU; force it with `EVAL_LLM_MODELS=true` or disable with `EVAL_LLM_MODELS=false`.
241
+ 4. **Artifacts** — `artifacts/reward_curve.png` (4 lines) and `artifacts/summary_metrics.json` (random / heuristic / base / SFT rewards + per-task SFT-over-base improvements) are written.
242
 
243
  ### Local run (small model)
244
 
 
274
  | `TRAIN_EPOCHS` | `1` | SFT epochs |
275
  | `TRAIN_MAX_LENGTH` | `768` | Max sequence length |
276
  | `TRAIN_BATCH_SIZE` / `TRAIN_GRAD_ACCUM` | `1` / `2` | Effective batch size |
277
+ | `MAX_ROLLOUT_STEPS` | `120` | Safety cap per episode (data collection + baselines) |
278
+ | `MAX_LLM_EVAL_STEPS` | `60` | Safety cap per episode when an LLM policy is acting |
279
+ | `EVAL_LLM_MODELS` | `auto` | `auto` ⇒ eval LLMs only if CUDA is available; `true`/`false` to force |
280
+
281
+ ### Running a base vs fine-tuned comparison
282
+
283
+ After `train_trl.py` finishes, the fine-tuned checkpoint lives at
284
+ `artifacts/sft_model/`. You can re-run just the LLM rollouts against the
285
+ running environment without retraining:
286
+
287
+ ```python
288
+ # Colab / local
289
+ import os
290
+ os.environ["POLICY_MODEL"] = "Qwen/Qwen2.5-0.5B-Instruct" # base model
291
+ !python inference.py
292
+
293
+ os.environ["POLICY_MODEL"] = "artifacts/sft_model" # fine-tuned
294
+ !python inference.py
295
+ ```
296
+
297
+ `inference.py` picks up `POLICY_MODEL` and routes every step through the
298
+ LLM via `llm_policy.LLMPolicy`, falling back to a safe action only when
299
+ the model emits invalid JSON.
300
 
301
  ---
302
 
inference.py CHANGED
@@ -26,6 +26,9 @@ from models import IncidentAction, IncidentObservation
26
  ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:8000")
27
  BENCHMARK = "incident_command_center_env"
28
  RANDOM_BASELINE = os.getenv("RANDOM_BASELINE", "false").lower() == "true"
 
 
 
29
 
30
 
31
  # ---------------------------------------------------------------------------
@@ -297,9 +300,14 @@ def random_action(observation: IncidentObservation) -> IncidentAction:
297
  # ---------------------------------------------------------------------------
298
 
299
 
300
- async def run_task(task_name: str) -> None:
301
  env = IncidentCommandEnvClient(base_url=ENV_URL).sync()
302
- policy_name = "random_baseline" if RANDOM_BASELINE else "heuristic_coordinator"
 
 
 
 
 
303
  coordinator = HeuristicCoordinator()
304
 
305
  log_start(task=task_name, env=BENCHMARK, policy=policy_name)
@@ -313,11 +321,12 @@ async def run_task(task_name: str) -> None:
313
  res = env.reset(task_name=task_name)
314
  while not res.done:
315
  steps_taken += 1
316
- action = (
317
- random_action(res.observation)
318
- if RANDOM_BASELINE
319
- else coordinator.select_action(res.observation)
320
- )
 
321
  res = env.step(action)
322
  reward = float(res.reward or 0.0)
323
  rewards.append(reward)
@@ -340,19 +349,39 @@ async def run_task(task_name: str) -> None:
340
 
341
 
342
  def main() -> None:
 
 
 
 
 
 
343
  for task in ["easy", "medium", "hard"]:
344
- asyncio.run(run_task(task))
 
 
 
 
 
 
 
 
345
  print(
346
  json.dumps(
347
  {
348
  "benchmark": BENCHMARK,
349
- "policy": "random_baseline" if RANDOM_BASELINE else "heuristic_coordinator",
350
  "env_url": ENV_URL,
351
  },
352
  indent=2,
353
  )
354
  )
355
 
 
 
 
 
 
 
356
 
357
  if __name__ == "__main__":
358
  main()
 
26
  ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:8000")
27
  BENCHMARK = "incident_command_center_env"
28
  RANDOM_BASELINE = os.getenv("RANDOM_BASELINE", "false").lower() == "true"
29
+ # When set, run an LLM-backed policy (base or fine-tuned checkpoint) instead
30
+ # of the heuristic / random ones. Point this at a HF hub id or a local dir.
31
+ POLICY_MODEL = os.getenv("POLICY_MODEL", "").strip()
32
 
33
 
34
  # ---------------------------------------------------------------------------
 
300
  # ---------------------------------------------------------------------------
301
 
302
 
303
+ async def run_task(task_name: str, llm_policy=None) -> None:
304
  env = IncidentCommandEnvClient(base_url=ENV_URL).sync()
305
+ if llm_policy is not None:
306
+ policy_name = f"llm:{getattr(llm_policy, 'label', POLICY_MODEL)}"
307
+ elif RANDOM_BASELINE:
308
+ policy_name = "random_baseline"
309
+ else:
310
+ policy_name = "heuristic_coordinator"
311
  coordinator = HeuristicCoordinator()
312
 
313
  log_start(task=task_name, env=BENCHMARK, policy=policy_name)
 
321
  res = env.reset(task_name=task_name)
322
  while not res.done:
323
  steps_taken += 1
324
+ if llm_policy is not None:
325
+ action = llm_policy.select_action(res.observation)
326
+ elif RANDOM_BASELINE:
327
+ action = random_action(res.observation)
328
+ else:
329
+ action = coordinator.select_action(res.observation)
330
  res = env.step(action)
331
  reward = float(res.reward or 0.0)
332
  rewards.append(reward)
 
349
 
350
 
351
  def main() -> None:
352
+ llm_policy = None
353
+ if POLICY_MODEL:
354
+ from llm_policy import LLMPolicy
355
+
356
+ llm_policy = LLMPolicy(POLICY_MODEL, label=POLICY_MODEL)
357
+
358
  for task in ["easy", "medium", "hard"]:
359
+ asyncio.run(run_task(task, llm_policy=llm_policy))
360
+
361
+ if llm_policy is not None:
362
+ policy_label = f"llm:{POLICY_MODEL}"
363
+ elif RANDOM_BASELINE:
364
+ policy_label = "random_baseline"
365
+ else:
366
+ policy_label = "heuristic_coordinator"
367
+
368
  print(
369
  json.dumps(
370
  {
371
  "benchmark": BENCHMARK,
372
+ "policy": policy_label,
373
  "env_url": ENV_URL,
374
  },
375
  indent=2,
376
  )
377
  )
378
 
379
+ if llm_policy is not None:
380
+ try:
381
+ llm_policy.release()
382
+ except Exception:
383
+ pass
384
+
385
 
386
  if __name__ == "__main__":
387
  main()
llm_policy.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LLM-backed policy for the Incident Command Center environment.
2
+
3
+ Wraps any Hugging Face causal-LM (a base model OR a fine-tuned checkpoint)
4
+ into a callable that takes an ``IncidentObservation`` and returns a typed
5
+ ``IncidentAction``. This is what turns a raw language model into an agent
6
+ that can act inside the environment.
7
+
8
+ Usage::
9
+
10
+ from llm_policy import LLMPolicy
11
+ policy = LLMPolicy("Qwen/Qwen2.5-0.5B-Instruct")
12
+ action = policy.select_action(observation)
13
+
14
+ If the model emits invalid JSON, the policy degrades gracefully to a safe
15
+ default action (inspect the first log target) so one bad generation never
16
+ crashes a whole rollout.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import json
22
+ import logging
23
+ import re
24
+ from typing import Any, Dict, Optional
25
+
26
+ from models import IncidentAction, IncidentObservation
27
+
28
+ _LOG = logging.getLogger("icc.llm_policy")
29
+
30
+ # Regex for the first balanced-ish JSON object in the model output.
31
+ # (Greedy `.*` inside `{...}` keeps nested braces intact for our tiny JSON.)
32
+ _JSON_RE = re.compile(r"\{[\s\S]*\}")
33
+
34
+
35
+ class LLMPolicy:
36
+ """Policy that calls a HF causal-LM and parses its JSON action."""
37
+
38
+ def __init__(
39
+ self,
40
+ model_name_or_path: str,
41
+ *,
42
+ device: Optional[str] = None,
43
+ max_new_tokens: int = 160,
44
+ temperature: float = 0.0,
45
+ dtype: Optional[str] = None,
46
+ label: Optional[str] = None,
47
+ ) -> None:
48
+ try:
49
+ import torch
50
+ from transformers import AutoModelForCausalLM, AutoTokenizer
51
+ except ImportError as exc: # pragma: no cover - runtime dep
52
+ raise RuntimeError(
53
+ "LLMPolicy requires `transformers` and `torch` installed. "
54
+ "Run: pip install transformers torch"
55
+ ) from exc
56
+
57
+ self._torch = torch
58
+ self.label = label or model_name_or_path
59
+ self.max_new_tokens = max_new_tokens
60
+ self.temperature = temperature
61
+
62
+ resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
63
+ if dtype is None:
64
+ torch_dtype = torch.float16 if resolved_device == "cuda" else torch.float32
65
+ else:
66
+ torch_dtype = getattr(torch, dtype)
67
+
68
+ _LOG.info(
69
+ "Loading LLM policy %s on %s (dtype=%s)",
70
+ model_name_or_path,
71
+ resolved_device,
72
+ torch_dtype,
73
+ )
74
+
75
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
76
+ if self.tokenizer.pad_token is None:
77
+ self.tokenizer.pad_token = self.tokenizer.eos_token
78
+
79
+ self.model = AutoModelForCausalLM.from_pretrained(
80
+ model_name_or_path,
81
+ torch_dtype=torch_dtype,
82
+ ).to(resolved_device)
83
+ self.model.eval()
84
+ self.device = resolved_device
85
+
86
+ # ------------------------------------------------------------------
87
+ # Public API
88
+ # ------------------------------------------------------------------
89
+
90
+ def select_action(self, observation: IncidentObservation) -> IncidentAction:
91
+ prompt_text = self._build_prompt_text(observation)
92
+ response_text = self._generate(prompt_text)
93
+ return self._parse_action(response_text, observation)
94
+
95
+ # ------------------------------------------------------------------
96
+ # Internals
97
+ # ------------------------------------------------------------------
98
+
99
+ def _build_prompt_text(self, observation: IncidentObservation) -> str:
100
+ # Keep this import here to avoid importing the trainer stack when the
101
+ # module is used for inference only.
102
+ from train_trl import obs_to_prompt
103
+
104
+ user_prompt = obs_to_prompt(observation)
105
+ if getattr(self.tokenizer, "chat_template", None):
106
+ messages = [{"role": "user", "content": user_prompt}]
107
+ return self.tokenizer.apply_chat_template(
108
+ messages,
109
+ tokenize=False,
110
+ add_generation_prompt=True,
111
+ )
112
+ return f"User: {user_prompt}\n\nAssistant:"
113
+
114
+ def _generate(self, prompt_text: str) -> str:
115
+ torch = self._torch
116
+ inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)
117
+ gen_kwargs: Dict[str, Any] = {
118
+ "max_new_tokens": self.max_new_tokens,
119
+ "pad_token_id": self.tokenizer.pad_token_id,
120
+ }
121
+ if self.temperature > 0:
122
+ gen_kwargs.update(
123
+ do_sample=True,
124
+ temperature=self.temperature,
125
+ top_p=0.9,
126
+ )
127
+ else:
128
+ gen_kwargs["do_sample"] = False
129
+
130
+ with torch.no_grad():
131
+ output = self.model.generate(**inputs, **gen_kwargs)
132
+ generated_ids = output[0][inputs["input_ids"].shape[1]:]
133
+ return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
134
+
135
+ def _parse_action(
136
+ self,
137
+ response_text: str,
138
+ observation: IncidentObservation,
139
+ ) -> IncidentAction:
140
+ json_match = _JSON_RE.search(response_text)
141
+ if json_match:
142
+ raw = json_match.group(0)
143
+ # Qwen / Llama sometimes add trailing commentary; strip past the
144
+ # last closing brace to give JSON parser a clean slice.
145
+ last_close = raw.rfind("}")
146
+ if last_close != -1:
147
+ raw = raw[: last_close + 1]
148
+ try:
149
+ data = json.loads(raw)
150
+ return IncidentAction.model_validate(data)
151
+ except Exception as exc:
152
+ _LOG.debug(
153
+ "LLM JSON parse failed: %s :: raw=%s",
154
+ exc,
155
+ raw[:200],
156
+ )
157
+
158
+ return self._safe_fallback(observation)
159
+
160
+ def _safe_fallback(self, observation: IncidentObservation) -> IncidentAction:
161
+ logs = (observation.investigation_targets or {}).get("logs", []) or []
162
+ target = logs[0] if logs else "payments-api"
163
+ return IncidentAction(
164
+ actor="triage_agent",
165
+ action_type="inspect_logs",
166
+ target=target,
167
+ reason="LLM output invalid; using safe fallback action.",
168
+ )
169
+
170
+ # ------------------------------------------------------------------
171
+ # Resource cleanup
172
+ # ------------------------------------------------------------------
173
+
174
+ def release(self) -> None:
175
+ """Free GPU memory so a second model can be loaded after this one."""
176
+ try:
177
+ import gc
178
+ self.model = None # type: ignore[assignment]
179
+ self.tokenizer = None # type: ignore[assignment]
180
+ gc.collect()
181
+ if self._torch.cuda.is_available():
182
+ self._torch.cuda.empty_cache()
183
+ except Exception:
184
+ pass
train_trl.py CHANGED
@@ -1,17 +1,22 @@
1
  """Hugging Face TRL training + evaluation pipeline.
2
 
3
- What this script does end-to-end:
4
-
5
- 1. Rolls out the `HeuristicCoordinator` against a running Incident Command
6
- Center environment to produce `(prompt, completion)` training rows.
7
- 2. Fine-tunes a small instruction-tuned LLM using TRL's `SFTTrainer` with a
8
- single `text` column that works reliably across TRL >= 0.20.
9
- 3. Evaluates the heuristic and random baseline policies post-training and
10
- writes a reward curve + JSON metrics into `artifacts/` — exactly the
11
- evidence the hackathon judges look for.
12
-
13
- Designed to run equally well on CPU (for smoke checks) and on a Colab T4 /
14
- HF Spaces GPU (for the real run).
 
 
 
 
 
15
  """
16
 
17
  from __future__ import annotations
@@ -19,9 +24,9 @@ from __future__ import annotations
19
  import json
20
  import os
21
  import random
22
- from dataclasses import dataclass, asdict
23
  from pathlib import Path
24
- from typing import Dict, List
25
 
26
  import matplotlib.pyplot as plt
27
  from datasets import Dataset
@@ -33,15 +38,18 @@ from models import IncidentAction, IncidentObservation
33
 
34
  ARTIFACT_DIR = Path("artifacts")
35
  ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
 
36
 
37
  ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:8000")
38
  BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
39
  MAX_ROLLOUT_STEPS = int(os.getenv("MAX_ROLLOUT_STEPS", "120"))
 
40
  EPISODES_PER_TASK = int(os.getenv("EPISODES_PER_TASK", "3"))
41
  TRAIN_EPOCHS = float(os.getenv("TRAIN_EPOCHS", "1"))
42
  TRAIN_BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "1"))
43
  TRAIN_GRAD_ACCUM = int(os.getenv("TRAIN_GRAD_ACCUM", "2"))
44
  TRAIN_MAX_LENGTH = int(os.getenv("TRAIN_MAX_LENGTH", "768"))
 
45
 
46
 
47
  @dataclass
@@ -99,18 +107,28 @@ def rollout(
99
  policy_name: str,
100
  task_name: str,
101
  collect_dataset: bool = False,
 
 
102
  ):
 
 
 
 
 
103
  env = IncidentCommandEnvClient(base_url=ENV_URL).sync()
104
  coordinator = HeuristicCoordinator()
105
  records: List[Dict[str, str]] = []
106
  rewards: List[float] = []
107
  steps = 0
 
108
 
109
  try:
110
  result = env.reset(task_name=task_name)
111
- while not result.done and steps < MAX_ROLLOUT_STEPS:
112
  steps += 1
113
- if policy_name == "heuristic":
 
 
114
  action = coordinator.select_action(result.observation)
115
  else:
116
  action = random_action(result.observation)
@@ -157,11 +175,7 @@ def build_training_dataset(episodes_per_task: int = EPISODES_PER_TASK) -> Datase
157
 
158
 
159
  def _dataset_to_sft_text_column(dataset: Dataset, tokenizer) -> Dataset:
160
- """Collapse (prompt, completion) pairs into a single `text` field.
161
-
162
- The ``text`` column path in TRL 0.20+ is the most version-robust option,
163
- side-stepping brittle prompt/completion tokenization across TRL releases.
164
- """
165
  from transformers import PreTrainedTokenizerBase
166
 
167
  if not isinstance(tokenizer, PreTrainedTokenizerBase):
@@ -172,7 +186,8 @@ def _dataset_to_sft_text_column(dataset: Dataset, tokenizer) -> Dataset:
172
  dataset = dataset.rename_column("response", "completion")
173
  if "prompt" not in dataset.column_names or "completion" not in dataset.column_names:
174
  raise ValueError(
175
- f"Expected columns 'prompt' and 'completion' (or 'response'). Got: {dataset.column_names}"
 
176
  )
177
 
178
  has_template = bool(getattr(tokenizer, "chat_template", None))
@@ -200,7 +215,11 @@ def _dataset_to_sft_text_column(dataset: Dataset, tokenizer) -> Dataset:
200
  return dataset.map(to_text_batched, batched=True, remove_columns=to_drop)
201
 
202
 
203
- def run_trl_sft(dataset: Dataset) -> None:
 
 
 
 
204
  try:
205
  from transformers import AutoModelForCausalLM, AutoTokenizer
206
  from trl import SFTConfig, SFTTrainer
@@ -237,36 +256,161 @@ def run_trl_sft(dataset: Dataset) -> None:
237
  )
238
  trainer.train()
239
 
 
 
 
 
 
 
 
 
 
240
 
241
  # ---------------------------------------------------------------------------
242
  # Evaluation + reporting
243
  # ---------------------------------------------------------------------------
244
 
245
 
246
- def evaluate_policies(seed: int = 7) -> Dict[str, List[float]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  random.seed(seed)
248
- random_scores: List[float] = []
249
- heuristic_scores: List[float] = []
 
 
 
 
 
250
 
251
  for task in ["easy", "medium", "hard"]:
252
  random_stats, _, _ = rollout("random", task)
253
  heuristic_stats, _, _ = rollout("heuristic", task)
254
- random_scores.append(random_stats.total_reward)
255
- heuristic_scores.append(heuristic_stats.total_reward)
 
 
 
 
 
 
 
 
 
 
 
256
 
257
- return {"random": random_scores, "heuristic": heuristic_scores}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
 
260
  def plot_rewards(score_map: Dict[str, List[float]]) -> None:
261
  labels = ["easy", "medium", "hard"]
262
  x = list(range(len(labels)))
263
- plt.figure(figsize=(8, 4.5))
264
- plt.plot(x, score_map["random"], marker="o", label="Random baseline")
265
- plt.plot(x, score_map["heuristic"], marker="o", label="Heuristic coordinator")
 
 
 
 
 
 
 
 
 
 
 
 
266
  plt.xticks(x, labels)
267
  plt.xlabel("Task difficulty")
268
  plt.ylabel("Episode total reward")
269
- plt.title("Incident Command Center — baseline comparison")
 
270
  plt.grid(alpha=0.3)
271
  plt.legend()
272
  plt.tight_layout()
@@ -276,7 +420,7 @@ def plot_rewards(score_map: Dict[str, List[float]]) -> None:
276
 
277
  def main() -> None:
278
  dataset = build_training_dataset(episodes_per_task=EPISODES_PER_TASK)
279
- dataset.save_to_disk("artifacts/trl_dataset")
280
 
281
  run_trl_sft(dataset)
282
  scores = evaluate_policies()
@@ -286,10 +430,17 @@ def main() -> None:
286
  "base_model": BASE_MODEL,
287
  "dataset_rows": len(dataset),
288
  "episodes_per_task": EPISODES_PER_TASK,
289
- "random_rewards": scores["random"],
290
- "heuristic_rewards": scores["heuristic"],
291
- "improvement_absolute": [
292
- round(h - r, 4) for h, r in zip(scores["heuristic"], scores["random"])
 
 
 
 
 
 
 
293
  ],
294
  }
295
  with open(ARTIFACT_DIR / "summary_metrics.json", "w", encoding="utf-8") as f:
 
1
  """Hugging Face TRL training + evaluation pipeline.
2
 
3
+ Pipeline:
4
+
5
+ 1. **Rollout**: run the ``HeuristicCoordinator`` against the live Incident
6
+ Command Center environment to collect ``(prompt, completion)`` pairs.
7
+ 2. **SFT**: fine-tune a small instruction-tuned LLM on those pairs using
8
+ TRL's ``SFTTrainer`` with a single ``text`` column (robust across TRL
9
+ 0.20).
10
+ 3. **Save**: persist the fine-tuned weights + tokenizer to
11
+ ``artifacts/sft_model`` so the same script can later load them as an
12
+ agent policy.
13
+ 4. **Evaluate**: play the environment with four policies
14
+ ``random / heuristic / base_model / sft_model`` under identical seeds
15
+ and write a reward curve + metrics JSON into ``artifacts/``.
16
+
17
+ Designed to work on CPU for smoke checks and on Colab T4 / HF Spaces GPUs
18
+ for full runs. LLM evaluation auto-enables on CUDA and can be forced with
19
+ ``EVAL_LLM_MODELS=true``.
20
  """
21
 
22
  from __future__ import annotations
 
24
  import json
25
  import os
26
  import random
27
+ from dataclasses import dataclass
28
  from pathlib import Path
29
+ from typing import Callable, Dict, List, Optional
30
 
31
  import matplotlib.pyplot as plt
32
  from datasets import Dataset
 
38
 
39
  ARTIFACT_DIR = Path("artifacts")
40
  ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
41
+ SFT_MODEL_DIR = ARTIFACT_DIR / "sft_model"
42
 
43
  ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:8000")
44
  BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
45
  MAX_ROLLOUT_STEPS = int(os.getenv("MAX_ROLLOUT_STEPS", "120"))
46
+ MAX_LLM_EVAL_STEPS = int(os.getenv("MAX_LLM_EVAL_STEPS", "60"))
47
  EPISODES_PER_TASK = int(os.getenv("EPISODES_PER_TASK", "3"))
48
  TRAIN_EPOCHS = float(os.getenv("TRAIN_EPOCHS", "1"))
49
  TRAIN_BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "1"))
50
  TRAIN_GRAD_ACCUM = int(os.getenv("TRAIN_GRAD_ACCUM", "2"))
51
  TRAIN_MAX_LENGTH = int(os.getenv("TRAIN_MAX_LENGTH", "768"))
52
+ _EVAL_LLM_ENV = os.getenv("EVAL_LLM_MODELS", "auto").strip().lower()
53
 
54
 
55
  @dataclass
 
107
  policy_name: str,
108
  task_name: str,
109
  collect_dataset: bool = False,
110
+ policy_callable: Optional[Callable[[IncidentObservation], IncidentAction]] = None,
111
+ max_steps: Optional[int] = None,
112
  ):
113
+ """Play one episode and return (stats, rows, rewards).
114
+
115
+ If ``policy_callable`` is provided it takes precedence over
116
+ ``policy_name`` — this is how the LLM policies plug in.
117
+ """
118
  env = IncidentCommandEnvClient(base_url=ENV_URL).sync()
119
  coordinator = HeuristicCoordinator()
120
  records: List[Dict[str, str]] = []
121
  rewards: List[float] = []
122
  steps = 0
123
+ step_cap = max_steps if max_steps is not None else MAX_ROLLOUT_STEPS
124
 
125
  try:
126
  result = env.reset(task_name=task_name)
127
+ while not result.done and steps < step_cap:
128
  steps += 1
129
+ if policy_callable is not None:
130
+ action = policy_callable(result.observation)
131
+ elif policy_name == "heuristic":
132
  action = coordinator.select_action(result.observation)
133
  else:
134
  action = random_action(result.observation)
 
175
 
176
 
177
  def _dataset_to_sft_text_column(dataset: Dataset, tokenizer) -> Dataset:
178
+ """Collapse (prompt, completion) pairs into a single `text` field."""
 
 
 
 
179
  from transformers import PreTrainedTokenizerBase
180
 
181
  if not isinstance(tokenizer, PreTrainedTokenizerBase):
 
186
  dataset = dataset.rename_column("response", "completion")
187
  if "prompt" not in dataset.column_names or "completion" not in dataset.column_names:
188
  raise ValueError(
189
+ f"Expected columns 'prompt' and 'completion' (or 'response'). "
190
+ f"Got: {dataset.column_names}"
191
  )
192
 
193
  has_template = bool(getattr(tokenizer, "chat_template", None))
 
215
  return dataset.map(to_text_batched, batched=True, remove_columns=to_drop)
216
 
217
 
218
+ def run_trl_sft(dataset: Dataset) -> Path:
219
+ """Fine-tune ``BASE_MODEL`` on the collected dataset and save the model.
220
+
221
+ Returns the directory of the saved SFT checkpoint (``artifacts/sft_model``).
222
+ """
223
  try:
224
  from transformers import AutoModelForCausalLM, AutoTokenizer
225
  from trl import SFTConfig, SFTTrainer
 
256
  )
257
  trainer.train()
258
 
259
+ SFT_MODEL_DIR.mkdir(parents=True, exist_ok=True)
260
+ trainer.save_model(str(SFT_MODEL_DIR))
261
+ tokenizer.save_pretrained(str(SFT_MODEL_DIR))
262
+ print(f"[train] Saved SFT checkpoint to {SFT_MODEL_DIR}")
263
+
264
+ del trainer, model, tokenizer
265
+ _free_gpu_memory()
266
+ return SFT_MODEL_DIR
267
+
268
 
269
  # ---------------------------------------------------------------------------
270
  # Evaluation + reporting
271
  # ---------------------------------------------------------------------------
272
 
273
 
274
+ def _free_gpu_memory() -> None:
275
+ try:
276
+ import gc
277
+ gc.collect()
278
+ import torch
279
+
280
+ if torch.cuda.is_available():
281
+ torch.cuda.empty_cache()
282
+ except Exception:
283
+ pass
284
+
285
+
286
+ def _cuda_available() -> bool:
287
+ try:
288
+ import torch
289
+
290
+ return torch.cuda.is_available()
291
+ except Exception:
292
+ return False
293
+
294
+
295
+ def _should_evaluate_llms() -> bool:
296
+ if _EVAL_LLM_ENV in {"1", "true", "yes", "on"}:
297
+ return True
298
+ if _EVAL_LLM_ENV in {"0", "false", "no", "off"}:
299
+ return False
300
+ # "auto" / empty: enable only when a CUDA GPU is available so CPU runs
301
+ # stay fast.
302
+ return _cuda_available()
303
+
304
+
305
+ def _evaluate_single_policy(
306
+ policy_name: str,
307
+ select_fn: Callable[[IncidentObservation], IncidentAction],
308
+ max_steps: Optional[int] = None,
309
+ ) -> List[float]:
310
+ scores: List[float] = []
311
+ for task in ["easy", "medium", "hard"]:
312
+ stats, _, _ = rollout(
313
+ policy_name=policy_name,
314
+ task_name=task,
315
+ policy_callable=select_fn,
316
+ max_steps=max_steps,
317
+ )
318
+ print(
319
+ f"[eval] policy={policy_name} task={task} "
320
+ f"reward={stats.total_reward:+.2f} steps={stats.steps}"
321
+ )
322
+ scores.append(round(stats.total_reward, 4))
323
+ return scores
324
+
325
+
326
+ def evaluate_policies(
327
+ seed: int = 7,
328
+ evaluate_llms: Optional[bool] = None,
329
+ ) -> Dict[str, List[float]]:
330
+ """Run each policy once per task under the same seed.
331
+
332
+ The random policy is seeded for reproducibility. The heuristic policy is
333
+ deterministic already. LLM policies are evaluated with greedy decoding.
334
+ """
335
  random.seed(seed)
336
+
337
+ scores: Dict[str, List[float]] = {
338
+ "random": [],
339
+ "heuristic": [],
340
+ "base_model": [],
341
+ "sft_model": [],
342
+ }
343
 
344
  for task in ["easy", "medium", "hard"]:
345
  random_stats, _, _ = rollout("random", task)
346
  heuristic_stats, _, _ = rollout("heuristic", task)
347
+ scores["random"].append(round(random_stats.total_reward, 4))
348
+ scores["heuristic"].append(round(heuristic_stats.total_reward, 4))
349
+
350
+ should_eval_llms = _should_evaluate_llms() if evaluate_llms is None else evaluate_llms
351
+ if not should_eval_llms:
352
+ print("[eval] Skipping LLM evaluation (no GPU or EVAL_LLM_MODELS=false).")
353
+ return scores
354
+
355
+ try:
356
+ from llm_policy import LLMPolicy
357
+ except Exception as exc: # pragma: no cover - import-time safety
358
+ print(f"[eval] Could not import LLMPolicy ({exc}); skipping LLM eval.")
359
+ return scores
360
 
361
+ # Base model
362
+ try:
363
+ print(f"[eval] Loading BASE model: {BASE_MODEL}")
364
+ base = LLMPolicy(BASE_MODEL, label="base_model")
365
+ scores["base_model"] = _evaluate_single_policy(
366
+ "base_model", base.select_action, max_steps=MAX_LLM_EVAL_STEPS
367
+ )
368
+ base.release()
369
+ _free_gpu_memory()
370
+ except Exception as exc:
371
+ print(f"[eval] Base-model evaluation failed: {exc}")
372
+
373
+ # SFT model
374
+ if SFT_MODEL_DIR.exists():
375
+ try:
376
+ print(f"[eval] Loading SFT model: {SFT_MODEL_DIR}")
377
+ sft = LLMPolicy(str(SFT_MODEL_DIR), label="sft_model")
378
+ scores["sft_model"] = _evaluate_single_policy(
379
+ "sft_model", sft.select_action, max_steps=MAX_LLM_EVAL_STEPS
380
+ )
381
+ sft.release()
382
+ _free_gpu_memory()
383
+ except Exception as exc:
384
+ print(f"[eval] SFT-model evaluation failed: {exc}")
385
+ else:
386
+ print(f"[eval] No SFT checkpoint found at {SFT_MODEL_DIR}; skipping SFT eval.")
387
+
388
+ return scores
389
 
390
 
391
  def plot_rewards(score_map: Dict[str, List[float]]) -> None:
392
  labels = ["easy", "medium", "hard"]
393
  x = list(range(len(labels)))
394
+ plt.figure(figsize=(9, 5))
395
+
396
+ style = {
397
+ "random": ("x", "tab:red", "Random baseline"),
398
+ "heuristic": ("o", "tab:blue", "Heuristic coordinator"),
399
+ "base_model": ("^", "tab:orange", "Base LLM (untrained)"),
400
+ "sft_model": ("D", "tab:green", "Fine-tuned LLM (SFT)"),
401
+ }
402
+
403
+ for key, (marker, color, label) in style.items():
404
+ values = score_map.get(key) or []
405
+ if not values or len(values) != len(labels):
406
+ continue
407
+ plt.plot(x, values, marker=marker, color=color, label=label, linewidth=2)
408
+
409
  plt.xticks(x, labels)
410
  plt.xlabel("Task difficulty")
411
  plt.ylabel("Episode total reward")
412
+ plt.title("Incident Command Center — policy comparison")
413
+ plt.axhline(0, linestyle="--", color="gray", alpha=0.5)
414
  plt.grid(alpha=0.3)
415
  plt.legend()
416
  plt.tight_layout()
 
420
 
421
  def main() -> None:
422
  dataset = build_training_dataset(episodes_per_task=EPISODES_PER_TASK)
423
+ dataset.save_to_disk(str(ARTIFACT_DIR / "trl_dataset"))
424
 
425
  run_trl_sft(dataset)
426
  scores = evaluate_policies()
 
430
  "base_model": BASE_MODEL,
431
  "dataset_rows": len(dataset),
432
  "episodes_per_task": EPISODES_PER_TASK,
433
+ "random_rewards": scores.get("random", []),
434
+ "heuristic_rewards": scores.get("heuristic", []),
435
+ "base_model_rewards": scores.get("base_model", []),
436
+ "sft_model_rewards": scores.get("sft_model", []),
437
+ "improvement_sft_over_base": [
438
+ round(s - b, 4)
439
+ for s, b in zip(scores.get("sft_model", []), scores.get("base_model", []))
440
+ ] if scores.get("sft_model") and scores.get("base_model") else [],
441
+ "improvement_heuristic_over_random": [
442
+ round(h - r, 4)
443
+ for h, r in zip(scores.get("heuristic", []), scores.get("random", []))
444
  ],
445
  }
446
  with open(ARTIFACT_DIR / "summary_metrics.json", "w", encoding="utf-8") as f: