"""Before/after demo: base model vs fine-tuned model on the SAME incident. Runs both policies against the same task under the same seed, prints a clean side-by-side trace, and writes ``artifacts/before_after_demo.md`` which you can paste into the blog post or other writeups. Usage (after ``train_trl.py`` has saved ``artifacts/sft_model``):: ENV_URL=http://127.0.0.1:8000 python scripts/before_after_demo.py Env variables: ENV_URL — URL of a running Incident Command Center server BASE_MODEL — HF hub id of the base model SFT_MODEL_DIR — path to the fine-tuned checkpoint (default: artifacts/sft_model) DEMO_TASK — task difficulty to demo (default: hard) DEMO_MAX_STEPS — per-episode step cap (default: 120) """ from __future__ import annotations import json import os import random import sys from pathlib import Path from typing import Dict, List, Optional # Ensure repo root on sys.path when invoked from subdirectory _REPO_ROOT = Path(__file__).resolve().parents[1] if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) from client import IncidentCommandEnvClient # noqa: E402 from models import IncidentAction, IncidentObservation # noqa: E402 ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:8000") BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct") SFT_MODEL_DIR = os.getenv("SFT_MODEL_DIR", "artifacts/sft_model") DEMO_TASK = os.getenv("DEMO_TASK", "hard") DEMO_MAX_STEPS = int(os.getenv("DEMO_MAX_STEPS", "120")) DEMO_SEED = int(os.getenv("DEMO_SEED", "2026")) def _format_obs_summary(obs: IncidentObservation) -> str: return ( f"{obs.incident_title} " f"(tier={obs.customer_tier}, users={obs.affected_users_estimate}, " f"$/min={obs.revenue_impact_usd_per_min})" ) def _format_action(action: IncidentAction) -> str: target = action.target or "-" bits = [f"{action.actor}:{action.action_type}:{target}"] if action.reason: bits.append(f"reason={action.reason[:80]}") return " | ".join(bits) def _format_components(components: Optional[Dict[str, float]]) -> str: if not components: return "-" return ", ".join(f"{k}={v:+.2f}" for k, v in components.items()) def _rollout_with_policy(policy_name: str, select_fn) -> Dict: env = IncidentCommandEnvClient(base_url=ENV_URL).sync() random.seed(DEMO_SEED) steps_log: List[Dict] = [] total_reward = 0.0 components_sum: Dict[str, float] = {} closed_incidents = 0 incident_seen: List[str] = [] try: result = env.reset(task_name=DEMO_TASK) step_idx = 0 while not result.done and step_idx < DEMO_MAX_STEPS: step_idx += 1 obs = result.observation if obs.incident_id not in incident_seen: incident_seen.append(obs.incident_id) action = select_fn(obs) result = env.step(action) reward = float(result.reward or 0.0) total_reward += reward new_obs = result.observation step_components = getattr(new_obs, "reward_components", None) or {} for k, v in step_components.items(): components_sum[k] = components_sum.get(k, 0.0) + float(v) if action.action_type == "close_incident" and reward > 0: closed_incidents += 1 steps_log.append( { "step": step_idx, "incident": obs.incident_id, "summary": _format_obs_summary(obs), "action": _format_action(action), "reward": round(reward, 3), "components": _format_components(step_components), } ) finally: try: env.close() except Exception: pass return { "policy": policy_name, "task": DEMO_TASK, "steps": len(steps_log), "total_reward": round(total_reward, 3), "incidents_seen": incident_seen, "incidents_closed": closed_incidents, "components_sum": {k: round(v, 3) for k, v in components_sum.items()}, "trace": steps_log, } def _write_markdown(base_run: Dict, sft_run: Dict, out_path: Path) -> None: lines: List[str] = [] lines.append(f"# Before vs After — {DEMO_TASK.title()} task demo\n") lines.append(f"Both policies ran against the same seeded task (`{DEMO_TASK}`, seed {DEMO_SEED}) ") lines.append("on an identical Incident Command Center server. Each sees the same incident ") lines.append("queue in the same order.\n") lines.append("## Headline\n") lines.append(f"| Policy | Total reward | Steps | Incidents closed |") lines.append(f"|---|---:|---:|---:|") lines.append( f"| Base `{BASE_MODEL}` | {base_run['total_reward']:+.2f} | " f"{base_run['steps']} | {base_run['incidents_closed']} |" ) lines.append( f"| **Fine-tuned (SFT)** | **{sft_run['total_reward']:+.2f}** | " f"{sft_run['steps']} | {sft_run['incidents_closed']} |" ) delta = sft_run["total_reward"] - base_run["total_reward"] lines.append(f"\n**Reward delta: {delta:+.2f}** in favor of fine-tuned.\n") lines.append("## Reward sources (summed across the episode)\n") lines.append("| Component | Base | Fine-tuned |") lines.append("|---|---:|---:|") all_keys = sorted(set(base_run["components_sum"]) | set(sft_run["components_sum"])) for k in all_keys: lines.append( f"| `{k}` | {base_run['components_sum'].get(k, 0.0):+.2f} | " f"{sft_run['components_sum'].get(k, 0.0):+.2f} |" ) def _trace_block(run: Dict, title: str) -> None: lines.append(f"\n## Trace — {title}\n") lines.append("```") for row in run["trace"]: lines.append( f"step {row['step']:>3} | incident={row['incident']} | " f"{row['action']} | reward={row['reward']:+.2f} | {row['components']}" ) lines.append("```") _trace_block(base_run, f"Base model ({BASE_MODEL})") _trace_block(sft_run, "Fine-tuned (SFT) model") out_path.write_text("\n".join(lines), encoding="utf-8") def main() -> None: from llm_policy import LLMPolicy print(f"[demo] task={DEMO_TASK} seed={DEMO_SEED} env={ENV_URL}") print(f"[demo] Loading base model: {BASE_MODEL}") base_policy = LLMPolicy(BASE_MODEL, label="base_model") base_run = _rollout_with_policy("base_model", base_policy.select_action) base_policy.release() print(f"[demo] Loading SFT model: {SFT_MODEL_DIR}") sft_policy = LLMPolicy(SFT_MODEL_DIR, label="sft_model") sft_run = _rollout_with_policy("sft_model", sft_policy.select_action) sft_policy.release() art_dir = Path("artifacts") art_dir.mkdir(exist_ok=True) md_path = art_dir / "before_after_demo.md" json_path = art_dir / "before_after_demo.json" _write_markdown(base_run, sft_run, md_path) with json_path.open("w", encoding="utf-8") as f: json.dump({"base": base_run, "sft": sft_run}, f, indent=2) print(f"[demo] Base total={base_run['total_reward']:+.2f} " f"steps={base_run['steps']} closed={base_run['incidents_closed']}") print(f"[demo] SFT total={sft_run['total_reward']:+.2f} " f"steps={sft_run['steps']} closed={sft_run['incidents_closed']}") print(f"[demo] Wrote {md_path} and {json_path}") if __name__ == "__main__": main()