Add LLM policy, SFT saving & LLM evaluation
Browse files- README.md +26 -4
- inference.py +38 -9
- llm_policy.py +184 -0
- train_trl.py +189 -38
README.md
CHANGED
|
@@ -236,9 +236,9 @@ Expected output: **21 passing** (domain rubric, incident catalog, environment in
|
|
| 236 |
[`train_trl.py`](./train_trl.py) orchestrates the end-to-end training & evaluation pipeline:
|
| 237 |
|
| 238 |
1. **Rollout** — the `HeuristicCoordinator` drives the live environment to collect `(prompt, completion)` pairs. Prompts include customer tier, revenue impact, visible signals and investigation targets; completions are structured JSON actions.
|
| 239 |
-
2. **SFT** — the dataset is collapsed into a single `text` column (robust across TRL ≥ 0.20) and fed to `SFTTrainer`.
|
| 240 |
-
3. **Evaluation** —
|
| 241 |
-
4. **Artifacts** — `artifacts/reward_curve.png` and `artifacts/summary_metrics.json` are written.
|
| 242 |
|
| 243 |
### Local run (small model)
|
| 244 |
|
|
@@ -274,7 +274,29 @@ Environment variables you can tune before running `train_trl.py`:
|
|
| 274 |
| `TRAIN_EPOCHS` | `1` | SFT epochs |
|
| 275 |
| `TRAIN_MAX_LENGTH` | `768` | Max sequence length |
|
| 276 |
| `TRAIN_BATCH_SIZE` / `TRAIN_GRAD_ACCUM` | `1` / `2` | Effective batch size |
|
| 277 |
-
| `MAX_ROLLOUT_STEPS` | `120` | Safety cap per episode |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
---
|
| 280 |
|
|
|
|
| 236 |
[`train_trl.py`](./train_trl.py) orchestrates the end-to-end training & evaluation pipeline:
|
| 237 |
|
| 238 |
1. **Rollout** — the `HeuristicCoordinator` drives the live environment to collect `(prompt, completion)` pairs. Prompts include customer tier, revenue impact, visible signals and investigation targets; completions are structured JSON actions.
|
| 239 |
+
2. **SFT** — the dataset is collapsed into a single `text` column (robust across TRL ≥ 0.20) and fed to `SFTTrainer`. The fine-tuned weights + tokenizer are saved to `artifacts/sft_model/`.
|
| 240 |
+
3. **Evaluation** — four policies are rolled out under identical seeds: `random`, `heuristic`, `base_model` (raw `BASE_MODEL` HF checkpoint), and `sft_model` (the fine-tuned checkpoint just saved). LLM evaluation auto-enables on a CUDA GPU; force it with `EVAL_LLM_MODELS=true` or disable with `EVAL_LLM_MODELS=false`.
|
| 241 |
+
4. **Artifacts** — `artifacts/reward_curve.png` (4 lines) and `artifacts/summary_metrics.json` (random / heuristic / base / SFT rewards + per-task SFT-over-base improvements) are written.
|
| 242 |
|
| 243 |
### Local run (small model)
|
| 244 |
|
|
|
|
| 274 |
| `TRAIN_EPOCHS` | `1` | SFT epochs |
|
| 275 |
| `TRAIN_MAX_LENGTH` | `768` | Max sequence length |
|
| 276 |
| `TRAIN_BATCH_SIZE` / `TRAIN_GRAD_ACCUM` | `1` / `2` | Effective batch size |
|
| 277 |
+
| `MAX_ROLLOUT_STEPS` | `120` | Safety cap per episode (data collection + baselines) |
|
| 278 |
+
| `MAX_LLM_EVAL_STEPS` | `60` | Safety cap per episode when an LLM policy is acting |
|
| 279 |
+
| `EVAL_LLM_MODELS` | `auto` | `auto` ⇒ eval LLMs only if CUDA is available; `true`/`false` to force |
|
| 280 |
+
|
| 281 |
+
### Running a base vs fine-tuned comparison
|
| 282 |
+
|
| 283 |
+
After `train_trl.py` finishes, the fine-tuned checkpoint lives at
|
| 284 |
+
`artifacts/sft_model/`. You can re-run just the LLM rollouts against the
|
| 285 |
+
running environment without retraining:
|
| 286 |
+
|
| 287 |
+
```python
|
| 288 |
+
# Colab / local
|
| 289 |
+
import os
|
| 290 |
+
os.environ["POLICY_MODEL"] = "Qwen/Qwen2.5-0.5B-Instruct" # base model
|
| 291 |
+
!python inference.py
|
| 292 |
+
|
| 293 |
+
os.environ["POLICY_MODEL"] = "artifacts/sft_model" # fine-tuned
|
| 294 |
+
!python inference.py
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
`inference.py` picks up `POLICY_MODEL` and routes every step through the
|
| 298 |
+
LLM via `llm_policy.LLMPolicy`, falling back to a safe action only when
|
| 299 |
+
the model emits invalid JSON.
|
| 300 |
|
| 301 |
---
|
| 302 |
|
inference.py
CHANGED
|
@@ -26,6 +26,9 @@ from models import IncidentAction, IncidentObservation
|
|
| 26 |
ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:8000")
|
| 27 |
BENCHMARK = "incident_command_center_env"
|
| 28 |
RANDOM_BASELINE = os.getenv("RANDOM_BASELINE", "false").lower() == "true"
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
# ---------------------------------------------------------------------------
|
|
@@ -297,9 +300,14 @@ def random_action(observation: IncidentObservation) -> IncidentAction:
|
|
| 297 |
# ---------------------------------------------------------------------------
|
| 298 |
|
| 299 |
|
| 300 |
-
async def run_task(task_name: str) -> None:
|
| 301 |
env = IncidentCommandEnvClient(base_url=ENV_URL).sync()
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
coordinator = HeuristicCoordinator()
|
| 304 |
|
| 305 |
log_start(task=task_name, env=BENCHMARK, policy=policy_name)
|
|
@@ -313,11 +321,12 @@ async def run_task(task_name: str) -> None:
|
|
| 313 |
res = env.reset(task_name=task_name)
|
| 314 |
while not res.done:
|
| 315 |
steps_taken += 1
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
|
|
|
| 321 |
res = env.step(action)
|
| 322 |
reward = float(res.reward or 0.0)
|
| 323 |
rewards.append(reward)
|
|
@@ -340,19 +349,39 @@ async def run_task(task_name: str) -> None:
|
|
| 340 |
|
| 341 |
|
| 342 |
def main() -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
for task in ["easy", "medium", "hard"]:
|
| 344 |
-
asyncio.run(run_task(task))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
print(
|
| 346 |
json.dumps(
|
| 347 |
{
|
| 348 |
"benchmark": BENCHMARK,
|
| 349 |
-
"policy":
|
| 350 |
"env_url": ENV_URL,
|
| 351 |
},
|
| 352 |
indent=2,
|
| 353 |
)
|
| 354 |
)
|
| 355 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
if __name__ == "__main__":
|
| 358 |
main()
|
|
|
|
| 26 |
ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:8000")
|
| 27 |
BENCHMARK = "incident_command_center_env"
|
| 28 |
RANDOM_BASELINE = os.getenv("RANDOM_BASELINE", "false").lower() == "true"
|
| 29 |
+
# When set, run an LLM-backed policy (base or fine-tuned checkpoint) instead
|
| 30 |
+
# of the heuristic / random ones. Point this at a HF hub id or a local dir.
|
| 31 |
+
POLICY_MODEL = os.getenv("POLICY_MODEL", "").strip()
|
| 32 |
|
| 33 |
|
| 34 |
# ---------------------------------------------------------------------------
|
|
|
|
| 300 |
# ---------------------------------------------------------------------------
|
| 301 |
|
| 302 |
|
| 303 |
+
async def run_task(task_name: str, llm_policy=None) -> None:
|
| 304 |
env = IncidentCommandEnvClient(base_url=ENV_URL).sync()
|
| 305 |
+
if llm_policy is not None:
|
| 306 |
+
policy_name = f"llm:{getattr(llm_policy, 'label', POLICY_MODEL)}"
|
| 307 |
+
elif RANDOM_BASELINE:
|
| 308 |
+
policy_name = "random_baseline"
|
| 309 |
+
else:
|
| 310 |
+
policy_name = "heuristic_coordinator"
|
| 311 |
coordinator = HeuristicCoordinator()
|
| 312 |
|
| 313 |
log_start(task=task_name, env=BENCHMARK, policy=policy_name)
|
|
|
|
| 321 |
res = env.reset(task_name=task_name)
|
| 322 |
while not res.done:
|
| 323 |
steps_taken += 1
|
| 324 |
+
if llm_policy is not None:
|
| 325 |
+
action = llm_policy.select_action(res.observation)
|
| 326 |
+
elif RANDOM_BASELINE:
|
| 327 |
+
action = random_action(res.observation)
|
| 328 |
+
else:
|
| 329 |
+
action = coordinator.select_action(res.observation)
|
| 330 |
res = env.step(action)
|
| 331 |
reward = float(res.reward or 0.0)
|
| 332 |
rewards.append(reward)
|
|
|
|
| 349 |
|
| 350 |
|
| 351 |
def main() -> None:
|
| 352 |
+
llm_policy = None
|
| 353 |
+
if POLICY_MODEL:
|
| 354 |
+
from llm_policy import LLMPolicy
|
| 355 |
+
|
| 356 |
+
llm_policy = LLMPolicy(POLICY_MODEL, label=POLICY_MODEL)
|
| 357 |
+
|
| 358 |
for task in ["easy", "medium", "hard"]:
|
| 359 |
+
asyncio.run(run_task(task, llm_policy=llm_policy))
|
| 360 |
+
|
| 361 |
+
if llm_policy is not None:
|
| 362 |
+
policy_label = f"llm:{POLICY_MODEL}"
|
| 363 |
+
elif RANDOM_BASELINE:
|
| 364 |
+
policy_label = "random_baseline"
|
| 365 |
+
else:
|
| 366 |
+
policy_label = "heuristic_coordinator"
|
| 367 |
+
|
| 368 |
print(
|
| 369 |
json.dumps(
|
| 370 |
{
|
| 371 |
"benchmark": BENCHMARK,
|
| 372 |
+
"policy": policy_label,
|
| 373 |
"env_url": ENV_URL,
|
| 374 |
},
|
| 375 |
indent=2,
|
| 376 |
)
|
| 377 |
)
|
| 378 |
|
| 379 |
+
if llm_policy is not None:
|
| 380 |
+
try:
|
| 381 |
+
llm_policy.release()
|
| 382 |
+
except Exception:
|
| 383 |
+
pass
|
| 384 |
+
|
| 385 |
|
| 386 |
if __name__ == "__main__":
|
| 387 |
main()
|
llm_policy.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM-backed policy for the Incident Command Center environment.
|
| 2 |
+
|
| 3 |
+
Wraps any Hugging Face causal-LM (a base model OR a fine-tuned checkpoint)
|
| 4 |
+
into a callable that takes an ``IncidentObservation`` and returns a typed
|
| 5 |
+
``IncidentAction``. This is what turns a raw language model into an agent
|
| 6 |
+
that can act inside the environment.
|
| 7 |
+
|
| 8 |
+
Usage::
|
| 9 |
+
|
| 10 |
+
from llm_policy import LLMPolicy
|
| 11 |
+
policy = LLMPolicy("Qwen/Qwen2.5-0.5B-Instruct")
|
| 12 |
+
action = policy.select_action(observation)
|
| 13 |
+
|
| 14 |
+
If the model emits invalid JSON, the policy degrades gracefully to a safe
|
| 15 |
+
default action (inspect the first log target) so one bad generation never
|
| 16 |
+
crashes a whole rollout.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import logging
|
| 23 |
+
import re
|
| 24 |
+
from typing import Any, Dict, Optional
|
| 25 |
+
|
| 26 |
+
from models import IncidentAction, IncidentObservation
|
| 27 |
+
|
| 28 |
+
_LOG = logging.getLogger("icc.llm_policy")
|
| 29 |
+
|
| 30 |
+
# Regex for the first balanced-ish JSON object in the model output.
|
| 31 |
+
# (Greedy `.*` inside `{...}` keeps nested braces intact for our tiny JSON.)
|
| 32 |
+
_JSON_RE = re.compile(r"\{[\s\S]*\}")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class LLMPolicy:
|
| 36 |
+
"""Policy that calls a HF causal-LM and parses its JSON action."""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
model_name_or_path: str,
|
| 41 |
+
*,
|
| 42 |
+
device: Optional[str] = None,
|
| 43 |
+
max_new_tokens: int = 160,
|
| 44 |
+
temperature: float = 0.0,
|
| 45 |
+
dtype: Optional[str] = None,
|
| 46 |
+
label: Optional[str] = None,
|
| 47 |
+
) -> None:
|
| 48 |
+
try:
|
| 49 |
+
import torch
|
| 50 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 51 |
+
except ImportError as exc: # pragma: no cover - runtime dep
|
| 52 |
+
raise RuntimeError(
|
| 53 |
+
"LLMPolicy requires `transformers` and `torch` installed. "
|
| 54 |
+
"Run: pip install transformers torch"
|
| 55 |
+
) from exc
|
| 56 |
+
|
| 57 |
+
self._torch = torch
|
| 58 |
+
self.label = label or model_name_or_path
|
| 59 |
+
self.max_new_tokens = max_new_tokens
|
| 60 |
+
self.temperature = temperature
|
| 61 |
+
|
| 62 |
+
resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 63 |
+
if dtype is None:
|
| 64 |
+
torch_dtype = torch.float16 if resolved_device == "cuda" else torch.float32
|
| 65 |
+
else:
|
| 66 |
+
torch_dtype = getattr(torch, dtype)
|
| 67 |
+
|
| 68 |
+
_LOG.info(
|
| 69 |
+
"Loading LLM policy %s on %s (dtype=%s)",
|
| 70 |
+
model_name_or_path,
|
| 71 |
+
resolved_device,
|
| 72 |
+
torch_dtype,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
| 76 |
+
if self.tokenizer.pad_token is None:
|
| 77 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 78 |
+
|
| 79 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 80 |
+
model_name_or_path,
|
| 81 |
+
torch_dtype=torch_dtype,
|
| 82 |
+
).to(resolved_device)
|
| 83 |
+
self.model.eval()
|
| 84 |
+
self.device = resolved_device
|
| 85 |
+
|
| 86 |
+
# ------------------------------------------------------------------
|
| 87 |
+
# Public API
|
| 88 |
+
# ------------------------------------------------------------------
|
| 89 |
+
|
| 90 |
+
def select_action(self, observation: IncidentObservation) -> IncidentAction:
|
| 91 |
+
prompt_text = self._build_prompt_text(observation)
|
| 92 |
+
response_text = self._generate(prompt_text)
|
| 93 |
+
return self._parse_action(response_text, observation)
|
| 94 |
+
|
| 95 |
+
# ------------------------------------------------------------------
|
| 96 |
+
# Internals
|
| 97 |
+
# ------------------------------------------------------------------
|
| 98 |
+
|
| 99 |
+
def _build_prompt_text(self, observation: IncidentObservation) -> str:
|
| 100 |
+
# Keep this import here to avoid importing the trainer stack when the
|
| 101 |
+
# module is used for inference only.
|
| 102 |
+
from train_trl import obs_to_prompt
|
| 103 |
+
|
| 104 |
+
user_prompt = obs_to_prompt(observation)
|
| 105 |
+
if getattr(self.tokenizer, "chat_template", None):
|
| 106 |
+
messages = [{"role": "user", "content": user_prompt}]
|
| 107 |
+
return self.tokenizer.apply_chat_template(
|
| 108 |
+
messages,
|
| 109 |
+
tokenize=False,
|
| 110 |
+
add_generation_prompt=True,
|
| 111 |
+
)
|
| 112 |
+
return f"User: {user_prompt}\n\nAssistant:"
|
| 113 |
+
|
| 114 |
+
def _generate(self, prompt_text: str) -> str:
|
| 115 |
+
torch = self._torch
|
| 116 |
+
inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)
|
| 117 |
+
gen_kwargs: Dict[str, Any] = {
|
| 118 |
+
"max_new_tokens": self.max_new_tokens,
|
| 119 |
+
"pad_token_id": self.tokenizer.pad_token_id,
|
| 120 |
+
}
|
| 121 |
+
if self.temperature > 0:
|
| 122 |
+
gen_kwargs.update(
|
| 123 |
+
do_sample=True,
|
| 124 |
+
temperature=self.temperature,
|
| 125 |
+
top_p=0.9,
|
| 126 |
+
)
|
| 127 |
+
else:
|
| 128 |
+
gen_kwargs["do_sample"] = False
|
| 129 |
+
|
| 130 |
+
with torch.no_grad():
|
| 131 |
+
output = self.model.generate(**inputs, **gen_kwargs)
|
| 132 |
+
generated_ids = output[0][inputs["input_ids"].shape[1]:]
|
| 133 |
+
return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
| 134 |
+
|
| 135 |
+
def _parse_action(
|
| 136 |
+
self,
|
| 137 |
+
response_text: str,
|
| 138 |
+
observation: IncidentObservation,
|
| 139 |
+
) -> IncidentAction:
|
| 140 |
+
json_match = _JSON_RE.search(response_text)
|
| 141 |
+
if json_match:
|
| 142 |
+
raw = json_match.group(0)
|
| 143 |
+
# Qwen / Llama sometimes add trailing commentary; strip past the
|
| 144 |
+
# last closing brace to give JSON parser a clean slice.
|
| 145 |
+
last_close = raw.rfind("}")
|
| 146 |
+
if last_close != -1:
|
| 147 |
+
raw = raw[: last_close + 1]
|
| 148 |
+
try:
|
| 149 |
+
data = json.loads(raw)
|
| 150 |
+
return IncidentAction.model_validate(data)
|
| 151 |
+
except Exception as exc:
|
| 152 |
+
_LOG.debug(
|
| 153 |
+
"LLM JSON parse failed: %s :: raw=%s",
|
| 154 |
+
exc,
|
| 155 |
+
raw[:200],
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
return self._safe_fallback(observation)
|
| 159 |
+
|
| 160 |
+
def _safe_fallback(self, observation: IncidentObservation) -> IncidentAction:
|
| 161 |
+
logs = (observation.investigation_targets or {}).get("logs", []) or []
|
| 162 |
+
target = logs[0] if logs else "payments-api"
|
| 163 |
+
return IncidentAction(
|
| 164 |
+
actor="triage_agent",
|
| 165 |
+
action_type="inspect_logs",
|
| 166 |
+
target=target,
|
| 167 |
+
reason="LLM output invalid; using safe fallback action.",
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# ------------------------------------------------------------------
|
| 171 |
+
# Resource cleanup
|
| 172 |
+
# ------------------------------------------------------------------
|
| 173 |
+
|
| 174 |
+
def release(self) -> None:
|
| 175 |
+
"""Free GPU memory so a second model can be loaded after this one."""
|
| 176 |
+
try:
|
| 177 |
+
import gc
|
| 178 |
+
self.model = None # type: ignore[assignment]
|
| 179 |
+
self.tokenizer = None # type: ignore[assignment]
|
| 180 |
+
gc.collect()
|
| 181 |
+
if self._torch.cuda.is_available():
|
| 182 |
+
self._torch.cuda.empty_cache()
|
| 183 |
+
except Exception:
|
| 184 |
+
pass
|
train_trl.py
CHANGED
|
@@ -1,17 +1,22 @@
|
|
| 1 |
"""Hugging Face TRL training + evaluation pipeline.
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
1.
|
| 6 |
-
Center environment to
|
| 7 |
-
2.
|
| 8 |
-
single `text` column
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
"""
|
| 16 |
|
| 17 |
from __future__ import annotations
|
|
@@ -19,9 +24,9 @@ from __future__ import annotations
|
|
| 19 |
import json
|
| 20 |
import os
|
| 21 |
import random
|
| 22 |
-
from dataclasses import dataclass
|
| 23 |
from pathlib import Path
|
| 24 |
-
from typing import Dict, List
|
| 25 |
|
| 26 |
import matplotlib.pyplot as plt
|
| 27 |
from datasets import Dataset
|
|
@@ -33,15 +38,18 @@ from models import IncidentAction, IncidentObservation
|
|
| 33 |
|
| 34 |
ARTIFACT_DIR = Path("artifacts")
|
| 35 |
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 36 |
|
| 37 |
ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:8000")
|
| 38 |
BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
|
| 39 |
MAX_ROLLOUT_STEPS = int(os.getenv("MAX_ROLLOUT_STEPS", "120"))
|
|
|
|
| 40 |
EPISODES_PER_TASK = int(os.getenv("EPISODES_PER_TASK", "3"))
|
| 41 |
TRAIN_EPOCHS = float(os.getenv("TRAIN_EPOCHS", "1"))
|
| 42 |
TRAIN_BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "1"))
|
| 43 |
TRAIN_GRAD_ACCUM = int(os.getenv("TRAIN_GRAD_ACCUM", "2"))
|
| 44 |
TRAIN_MAX_LENGTH = int(os.getenv("TRAIN_MAX_LENGTH", "768"))
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
@dataclass
|
|
@@ -99,18 +107,28 @@ def rollout(
|
|
| 99 |
policy_name: str,
|
| 100 |
task_name: str,
|
| 101 |
collect_dataset: bool = False,
|
|
|
|
|
|
|
| 102 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
env = IncidentCommandEnvClient(base_url=ENV_URL).sync()
|
| 104 |
coordinator = HeuristicCoordinator()
|
| 105 |
records: List[Dict[str, str]] = []
|
| 106 |
rewards: List[float] = []
|
| 107 |
steps = 0
|
|
|
|
| 108 |
|
| 109 |
try:
|
| 110 |
result = env.reset(task_name=task_name)
|
| 111 |
-
while not result.done and steps <
|
| 112 |
steps += 1
|
| 113 |
-
if
|
|
|
|
|
|
|
| 114 |
action = coordinator.select_action(result.observation)
|
| 115 |
else:
|
| 116 |
action = random_action(result.observation)
|
|
@@ -157,11 +175,7 @@ def build_training_dataset(episodes_per_task: int = EPISODES_PER_TASK) -> Datase
|
|
| 157 |
|
| 158 |
|
| 159 |
def _dataset_to_sft_text_column(dataset: Dataset, tokenizer) -> Dataset:
|
| 160 |
-
"""Collapse (prompt, completion) pairs into a single `text` field.
|
| 161 |
-
|
| 162 |
-
The ``text`` column path in TRL 0.20+ is the most version-robust option,
|
| 163 |
-
side-stepping brittle prompt/completion tokenization across TRL releases.
|
| 164 |
-
"""
|
| 165 |
from transformers import PreTrainedTokenizerBase
|
| 166 |
|
| 167 |
if not isinstance(tokenizer, PreTrainedTokenizerBase):
|
|
@@ -172,7 +186,8 @@ def _dataset_to_sft_text_column(dataset: Dataset, tokenizer) -> Dataset:
|
|
| 172 |
dataset = dataset.rename_column("response", "completion")
|
| 173 |
if "prompt" not in dataset.column_names or "completion" not in dataset.column_names:
|
| 174 |
raise ValueError(
|
| 175 |
-
f"Expected columns 'prompt' and 'completion' (or 'response').
|
|
|
|
| 176 |
)
|
| 177 |
|
| 178 |
has_template = bool(getattr(tokenizer, "chat_template", None))
|
|
@@ -200,7 +215,11 @@ def _dataset_to_sft_text_column(dataset: Dataset, tokenizer) -> Dataset:
|
|
| 200 |
return dataset.map(to_text_batched, batched=True, remove_columns=to_drop)
|
| 201 |
|
| 202 |
|
| 203 |
-
def run_trl_sft(dataset: Dataset) ->
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
try:
|
| 205 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 206 |
from trl import SFTConfig, SFTTrainer
|
|
@@ -237,36 +256,161 @@ def run_trl_sft(dataset: Dataset) -> None:
|
|
| 237 |
)
|
| 238 |
trainer.train()
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
# ---------------------------------------------------------------------------
|
| 242 |
# Evaluation + reporting
|
| 243 |
# ---------------------------------------------------------------------------
|
| 244 |
|
| 245 |
|
| 246 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
random.seed(seed)
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
for task in ["easy", "medium", "hard"]:
|
| 252 |
random_stats, _, _ = rollout("random", task)
|
| 253 |
heuristic_stats, _, _ = rollout("heuristic", task)
|
| 254 |
-
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
|
| 260 |
def plot_rewards(score_map: Dict[str, List[float]]) -> None:
|
| 261 |
labels = ["easy", "medium", "hard"]
|
| 262 |
x = list(range(len(labels)))
|
| 263 |
-
plt.figure(figsize=(
|
| 264 |
-
|
| 265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
plt.xticks(x, labels)
|
| 267 |
plt.xlabel("Task difficulty")
|
| 268 |
plt.ylabel("Episode total reward")
|
| 269 |
-
plt.title("Incident Command Center —
|
|
|
|
| 270 |
plt.grid(alpha=0.3)
|
| 271 |
plt.legend()
|
| 272 |
plt.tight_layout()
|
|
@@ -276,7 +420,7 @@ def plot_rewards(score_map: Dict[str, List[float]]) -> None:
|
|
| 276 |
|
| 277 |
def main() -> None:
|
| 278 |
dataset = build_training_dataset(episodes_per_task=EPISODES_PER_TASK)
|
| 279 |
-
dataset.save_to_disk(
|
| 280 |
|
| 281 |
run_trl_sft(dataset)
|
| 282 |
scores = evaluate_policies()
|
|
@@ -286,10 +430,17 @@ def main() -> None:
|
|
| 286 |
"base_model": BASE_MODEL,
|
| 287 |
"dataset_rows": len(dataset),
|
| 288 |
"episodes_per_task": EPISODES_PER_TASK,
|
| 289 |
-
"random_rewards": scores
|
| 290 |
-
"heuristic_rewards": scores
|
| 291 |
-
"
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
],
|
| 294 |
}
|
| 295 |
with open(ARTIFACT_DIR / "summary_metrics.json", "w", encoding="utf-8") as f:
|
|
|
|
| 1 |
"""Hugging Face TRL training + evaluation pipeline.
|
| 2 |
|
| 3 |
+
Pipeline:
|
| 4 |
+
|
| 5 |
+
1. **Rollout**: run the ``HeuristicCoordinator`` against the live Incident
|
| 6 |
+
Command Center environment to collect ``(prompt, completion)`` pairs.
|
| 7 |
+
2. **SFT**: fine-tune a small instruction-tuned LLM on those pairs using
|
| 8 |
+
TRL's ``SFTTrainer`` with a single ``text`` column (robust across TRL
|
| 9 |
+
≥ 0.20).
|
| 10 |
+
3. **Save**: persist the fine-tuned weights + tokenizer to
|
| 11 |
+
``artifacts/sft_model`` so the same script can later load them as an
|
| 12 |
+
agent policy.
|
| 13 |
+
4. **Evaluate**: play the environment with four policies
|
| 14 |
+
``random / heuristic / base_model / sft_model`` under identical seeds
|
| 15 |
+
and write a reward curve + metrics JSON into ``artifacts/``.
|
| 16 |
+
|
| 17 |
+
Designed to work on CPU for smoke checks and on Colab T4 / HF Spaces GPUs
|
| 18 |
+
for full runs. LLM evaluation auto-enables on CUDA and can be forced with
|
| 19 |
+
``EVAL_LLM_MODELS=true``.
|
| 20 |
"""
|
| 21 |
|
| 22 |
from __future__ import annotations
|
|
|
|
| 24 |
import json
|
| 25 |
import os
|
| 26 |
import random
|
| 27 |
+
from dataclasses import dataclass
|
| 28 |
from pathlib import Path
|
| 29 |
+
from typing import Callable, Dict, List, Optional
|
| 30 |
|
| 31 |
import matplotlib.pyplot as plt
|
| 32 |
from datasets import Dataset
|
|
|
|
| 38 |
|
| 39 |
ARTIFACT_DIR = Path("artifacts")
|
| 40 |
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
|
| 41 |
+
SFT_MODEL_DIR = ARTIFACT_DIR / "sft_model"
|
| 42 |
|
| 43 |
ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:8000")
|
| 44 |
BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
|
| 45 |
MAX_ROLLOUT_STEPS = int(os.getenv("MAX_ROLLOUT_STEPS", "120"))
|
| 46 |
+
MAX_LLM_EVAL_STEPS = int(os.getenv("MAX_LLM_EVAL_STEPS", "60"))
|
| 47 |
EPISODES_PER_TASK = int(os.getenv("EPISODES_PER_TASK", "3"))
|
| 48 |
TRAIN_EPOCHS = float(os.getenv("TRAIN_EPOCHS", "1"))
|
| 49 |
TRAIN_BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "1"))
|
| 50 |
TRAIN_GRAD_ACCUM = int(os.getenv("TRAIN_GRAD_ACCUM", "2"))
|
| 51 |
TRAIN_MAX_LENGTH = int(os.getenv("TRAIN_MAX_LENGTH", "768"))
|
| 52 |
+
_EVAL_LLM_ENV = os.getenv("EVAL_LLM_MODELS", "auto").strip().lower()
|
| 53 |
|
| 54 |
|
| 55 |
@dataclass
|
|
|
|
| 107 |
policy_name: str,
|
| 108 |
task_name: str,
|
| 109 |
collect_dataset: bool = False,
|
| 110 |
+
policy_callable: Optional[Callable[[IncidentObservation], IncidentAction]] = None,
|
| 111 |
+
max_steps: Optional[int] = None,
|
| 112 |
):
|
| 113 |
+
"""Play one episode and return (stats, rows, rewards).
|
| 114 |
+
|
| 115 |
+
If ``policy_callable`` is provided it takes precedence over
|
| 116 |
+
``policy_name`` — this is how the LLM policies plug in.
|
| 117 |
+
"""
|
| 118 |
env = IncidentCommandEnvClient(base_url=ENV_URL).sync()
|
| 119 |
coordinator = HeuristicCoordinator()
|
| 120 |
records: List[Dict[str, str]] = []
|
| 121 |
rewards: List[float] = []
|
| 122 |
steps = 0
|
| 123 |
+
step_cap = max_steps if max_steps is not None else MAX_ROLLOUT_STEPS
|
| 124 |
|
| 125 |
try:
|
| 126 |
result = env.reset(task_name=task_name)
|
| 127 |
+
while not result.done and steps < step_cap:
|
| 128 |
steps += 1
|
| 129 |
+
if policy_callable is not None:
|
| 130 |
+
action = policy_callable(result.observation)
|
| 131 |
+
elif policy_name == "heuristic":
|
| 132 |
action = coordinator.select_action(result.observation)
|
| 133 |
else:
|
| 134 |
action = random_action(result.observation)
|
|
|
|
| 175 |
|
| 176 |
|
| 177 |
def _dataset_to_sft_text_column(dataset: Dataset, tokenizer) -> Dataset:
|
| 178 |
+
"""Collapse (prompt, completion) pairs into a single `text` field."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
from transformers import PreTrainedTokenizerBase
|
| 180 |
|
| 181 |
if not isinstance(tokenizer, PreTrainedTokenizerBase):
|
|
|
|
| 186 |
dataset = dataset.rename_column("response", "completion")
|
| 187 |
if "prompt" not in dataset.column_names or "completion" not in dataset.column_names:
|
| 188 |
raise ValueError(
|
| 189 |
+
f"Expected columns 'prompt' and 'completion' (or 'response'). "
|
| 190 |
+
f"Got: {dataset.column_names}"
|
| 191 |
)
|
| 192 |
|
| 193 |
has_template = bool(getattr(tokenizer, "chat_template", None))
|
|
|
|
| 215 |
return dataset.map(to_text_batched, batched=True, remove_columns=to_drop)
|
| 216 |
|
| 217 |
|
| 218 |
+
def run_trl_sft(dataset: Dataset) -> Path:
|
| 219 |
+
"""Fine-tune ``BASE_MODEL`` on the collected dataset and save the model.
|
| 220 |
+
|
| 221 |
+
Returns the directory of the saved SFT checkpoint (``artifacts/sft_model``).
|
| 222 |
+
"""
|
| 223 |
try:
|
| 224 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 225 |
from trl import SFTConfig, SFTTrainer
|
|
|
|
| 256 |
)
|
| 257 |
trainer.train()
|
| 258 |
|
| 259 |
+
SFT_MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
| 260 |
+
trainer.save_model(str(SFT_MODEL_DIR))
|
| 261 |
+
tokenizer.save_pretrained(str(SFT_MODEL_DIR))
|
| 262 |
+
print(f"[train] Saved SFT checkpoint to {SFT_MODEL_DIR}")
|
| 263 |
+
|
| 264 |
+
del trainer, model, tokenizer
|
| 265 |
+
_free_gpu_memory()
|
| 266 |
+
return SFT_MODEL_DIR
|
| 267 |
+
|
| 268 |
|
| 269 |
# ---------------------------------------------------------------------------
|
| 270 |
# Evaluation + reporting
|
| 271 |
# ---------------------------------------------------------------------------
|
| 272 |
|
| 273 |
|
| 274 |
+
def _free_gpu_memory() -> None:
|
| 275 |
+
try:
|
| 276 |
+
import gc
|
| 277 |
+
gc.collect()
|
| 278 |
+
import torch
|
| 279 |
+
|
| 280 |
+
if torch.cuda.is_available():
|
| 281 |
+
torch.cuda.empty_cache()
|
| 282 |
+
except Exception:
|
| 283 |
+
pass
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def _cuda_available() -> bool:
|
| 287 |
+
try:
|
| 288 |
+
import torch
|
| 289 |
+
|
| 290 |
+
return torch.cuda.is_available()
|
| 291 |
+
except Exception:
|
| 292 |
+
return False
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _should_evaluate_llms() -> bool:
|
| 296 |
+
if _EVAL_LLM_ENV in {"1", "true", "yes", "on"}:
|
| 297 |
+
return True
|
| 298 |
+
if _EVAL_LLM_ENV in {"0", "false", "no", "off"}:
|
| 299 |
+
return False
|
| 300 |
+
# "auto" / empty: enable only when a CUDA GPU is available so CPU runs
|
| 301 |
+
# stay fast.
|
| 302 |
+
return _cuda_available()
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _evaluate_single_policy(
|
| 306 |
+
policy_name: str,
|
| 307 |
+
select_fn: Callable[[IncidentObservation], IncidentAction],
|
| 308 |
+
max_steps: Optional[int] = None,
|
| 309 |
+
) -> List[float]:
|
| 310 |
+
scores: List[float] = []
|
| 311 |
+
for task in ["easy", "medium", "hard"]:
|
| 312 |
+
stats, _, _ = rollout(
|
| 313 |
+
policy_name=policy_name,
|
| 314 |
+
task_name=task,
|
| 315 |
+
policy_callable=select_fn,
|
| 316 |
+
max_steps=max_steps,
|
| 317 |
+
)
|
| 318 |
+
print(
|
| 319 |
+
f"[eval] policy={policy_name} task={task} "
|
| 320 |
+
f"reward={stats.total_reward:+.2f} steps={stats.steps}"
|
| 321 |
+
)
|
| 322 |
+
scores.append(round(stats.total_reward, 4))
|
| 323 |
+
return scores
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def evaluate_policies(
|
| 327 |
+
seed: int = 7,
|
| 328 |
+
evaluate_llms: Optional[bool] = None,
|
| 329 |
+
) -> Dict[str, List[float]]:
|
| 330 |
+
"""Run each policy once per task under the same seed.
|
| 331 |
+
|
| 332 |
+
The random policy is seeded for reproducibility. The heuristic policy is
|
| 333 |
+
deterministic already. LLM policies are evaluated with greedy decoding.
|
| 334 |
+
"""
|
| 335 |
random.seed(seed)
|
| 336 |
+
|
| 337 |
+
scores: Dict[str, List[float]] = {
|
| 338 |
+
"random": [],
|
| 339 |
+
"heuristic": [],
|
| 340 |
+
"base_model": [],
|
| 341 |
+
"sft_model": [],
|
| 342 |
+
}
|
| 343 |
|
| 344 |
for task in ["easy", "medium", "hard"]:
|
| 345 |
random_stats, _, _ = rollout("random", task)
|
| 346 |
heuristic_stats, _, _ = rollout("heuristic", task)
|
| 347 |
+
scores["random"].append(round(random_stats.total_reward, 4))
|
| 348 |
+
scores["heuristic"].append(round(heuristic_stats.total_reward, 4))
|
| 349 |
+
|
| 350 |
+
should_eval_llms = _should_evaluate_llms() if evaluate_llms is None else evaluate_llms
|
| 351 |
+
if not should_eval_llms:
|
| 352 |
+
print("[eval] Skipping LLM evaluation (no GPU or EVAL_LLM_MODELS=false).")
|
| 353 |
+
return scores
|
| 354 |
+
|
| 355 |
+
try:
|
| 356 |
+
from llm_policy import LLMPolicy
|
| 357 |
+
except Exception as exc: # pragma: no cover - import-time safety
|
| 358 |
+
print(f"[eval] Could not import LLMPolicy ({exc}); skipping LLM eval.")
|
| 359 |
+
return scores
|
| 360 |
|
| 361 |
+
# Base model
|
| 362 |
+
try:
|
| 363 |
+
print(f"[eval] Loading BASE model: {BASE_MODEL}")
|
| 364 |
+
base = LLMPolicy(BASE_MODEL, label="base_model")
|
| 365 |
+
scores["base_model"] = _evaluate_single_policy(
|
| 366 |
+
"base_model", base.select_action, max_steps=MAX_LLM_EVAL_STEPS
|
| 367 |
+
)
|
| 368 |
+
base.release()
|
| 369 |
+
_free_gpu_memory()
|
| 370 |
+
except Exception as exc:
|
| 371 |
+
print(f"[eval] Base-model evaluation failed: {exc}")
|
| 372 |
+
|
| 373 |
+
# SFT model
|
| 374 |
+
if SFT_MODEL_DIR.exists():
|
| 375 |
+
try:
|
| 376 |
+
print(f"[eval] Loading SFT model: {SFT_MODEL_DIR}")
|
| 377 |
+
sft = LLMPolicy(str(SFT_MODEL_DIR), label="sft_model")
|
| 378 |
+
scores["sft_model"] = _evaluate_single_policy(
|
| 379 |
+
"sft_model", sft.select_action, max_steps=MAX_LLM_EVAL_STEPS
|
| 380 |
+
)
|
| 381 |
+
sft.release()
|
| 382 |
+
_free_gpu_memory()
|
| 383 |
+
except Exception as exc:
|
| 384 |
+
print(f"[eval] SFT-model evaluation failed: {exc}")
|
| 385 |
+
else:
|
| 386 |
+
print(f"[eval] No SFT checkpoint found at {SFT_MODEL_DIR}; skipping SFT eval.")
|
| 387 |
+
|
| 388 |
+
return scores
|
| 389 |
|
| 390 |
|
| 391 |
def plot_rewards(score_map: Dict[str, List[float]]) -> None:
|
| 392 |
labels = ["easy", "medium", "hard"]
|
| 393 |
x = list(range(len(labels)))
|
| 394 |
+
plt.figure(figsize=(9, 5))
|
| 395 |
+
|
| 396 |
+
style = {
|
| 397 |
+
"random": ("x", "tab:red", "Random baseline"),
|
| 398 |
+
"heuristic": ("o", "tab:blue", "Heuristic coordinator"),
|
| 399 |
+
"base_model": ("^", "tab:orange", "Base LLM (untrained)"),
|
| 400 |
+
"sft_model": ("D", "tab:green", "Fine-tuned LLM (SFT)"),
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
for key, (marker, color, label) in style.items():
|
| 404 |
+
values = score_map.get(key) or []
|
| 405 |
+
if not values or len(values) != len(labels):
|
| 406 |
+
continue
|
| 407 |
+
plt.plot(x, values, marker=marker, color=color, label=label, linewidth=2)
|
| 408 |
+
|
| 409 |
plt.xticks(x, labels)
|
| 410 |
plt.xlabel("Task difficulty")
|
| 411 |
plt.ylabel("Episode total reward")
|
| 412 |
+
plt.title("Incident Command Center — policy comparison")
|
| 413 |
+
plt.axhline(0, linestyle="--", color="gray", alpha=0.5)
|
| 414 |
plt.grid(alpha=0.3)
|
| 415 |
plt.legend()
|
| 416 |
plt.tight_layout()
|
|
|
|
| 420 |
|
| 421 |
def main() -> None:
|
| 422 |
dataset = build_training_dataset(episodes_per_task=EPISODES_PER_TASK)
|
| 423 |
+
dataset.save_to_disk(str(ARTIFACT_DIR / "trl_dataset"))
|
| 424 |
|
| 425 |
run_trl_sft(dataset)
|
| 426 |
scores = evaluate_policies()
|
|
|
|
| 430 |
"base_model": BASE_MODEL,
|
| 431 |
"dataset_rows": len(dataset),
|
| 432 |
"episodes_per_task": EPISODES_PER_TASK,
|
| 433 |
+
"random_rewards": scores.get("random", []),
|
| 434 |
+
"heuristic_rewards": scores.get("heuristic", []),
|
| 435 |
+
"base_model_rewards": scores.get("base_model", []),
|
| 436 |
+
"sft_model_rewards": scores.get("sft_model", []),
|
| 437 |
+
"improvement_sft_over_base": [
|
| 438 |
+
round(s - b, 4)
|
| 439 |
+
for s, b in zip(scores.get("sft_model", []), scores.get("base_model", []))
|
| 440 |
+
] if scores.get("sft_model") and scores.get("base_model") else [],
|
| 441 |
+
"improvement_heuristic_over_random": [
|
| 442 |
+
round(h - r, 4)
|
| 443 |
+
for h, r in zip(scores.get("heuristic", []), scores.get("random", []))
|
| 444 |
],
|
| 445 |
}
|
| 446 |
with open(ARTIFACT_DIR / "summary_metrics.json", "w", encoding="utf-8") as f:
|