File size: 7,551 Bytes
c3648b5 8cbdbde c3648b5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 | """Before/after demo: base model vs fine-tuned model on the SAME incident.
Runs both policies against the same task under the same seed, prints a clean
side-by-side trace, and writes ``artifacts/before_after_demo.md`` which you
can paste into the blog post or other writeups.
Usage (after ``train_trl.py`` has saved ``artifacts/sft_model``)::
ENV_URL=http://127.0.0.1:8000 python scripts/before_after_demo.py
Env variables:
ENV_URL β URL of a running Incident Command Center server
BASE_MODEL β HF hub id of the base model
SFT_MODEL_DIR β path to the fine-tuned checkpoint (default: artifacts/sft_model)
DEMO_TASK β task difficulty to demo (default: hard)
DEMO_MAX_STEPS β per-episode step cap (default: 120)
"""
from __future__ import annotations
import json
import os
import random
import sys
from pathlib import Path
from typing import Dict, List, Optional
# Ensure repo root on sys.path when invoked from subdirectory
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
sys.path.insert(0, str(_REPO_ROOT))
from client import IncidentCommandEnvClient # noqa: E402
from models import IncidentAction, IncidentObservation # noqa: E402
ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:8000")
BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
SFT_MODEL_DIR = os.getenv("SFT_MODEL_DIR", "artifacts/sft_model")
DEMO_TASK = os.getenv("DEMO_TASK", "hard")
DEMO_MAX_STEPS = int(os.getenv("DEMO_MAX_STEPS", "120"))
DEMO_SEED = int(os.getenv("DEMO_SEED", "2026"))
def _format_obs_summary(obs: IncidentObservation) -> str:
return (
f"{obs.incident_title} "
f"(tier={obs.customer_tier}, users={obs.affected_users_estimate}, "
f"$/min={obs.revenue_impact_usd_per_min})"
)
def _format_action(action: IncidentAction) -> str:
target = action.target or "-"
bits = [f"{action.actor}:{action.action_type}:{target}"]
if action.reason:
bits.append(f"reason={action.reason[:80]}")
return " | ".join(bits)
def _format_components(components: Optional[Dict[str, float]]) -> str:
if not components:
return "-"
return ", ".join(f"{k}={v:+.2f}" for k, v in components.items())
def _rollout_with_policy(policy_name: str, select_fn) -> Dict:
env = IncidentCommandEnvClient(base_url=ENV_URL).sync()
random.seed(DEMO_SEED)
steps_log: List[Dict] = []
total_reward = 0.0
components_sum: Dict[str, float] = {}
closed_incidents = 0
incident_seen: List[str] = []
try:
result = env.reset(task_name=DEMO_TASK)
step_idx = 0
while not result.done and step_idx < DEMO_MAX_STEPS:
step_idx += 1
obs = result.observation
if obs.incident_id not in incident_seen:
incident_seen.append(obs.incident_id)
action = select_fn(obs)
result = env.step(action)
reward = float(result.reward or 0.0)
total_reward += reward
new_obs = result.observation
step_components = getattr(new_obs, "reward_components", None) or {}
for k, v in step_components.items():
components_sum[k] = components_sum.get(k, 0.0) + float(v)
if action.action_type == "close_incident" and reward > 0:
closed_incidents += 1
steps_log.append(
{
"step": step_idx,
"incident": obs.incident_id,
"summary": _format_obs_summary(obs),
"action": _format_action(action),
"reward": round(reward, 3),
"components": _format_components(step_components),
}
)
finally:
try:
env.close()
except Exception:
pass
return {
"policy": policy_name,
"task": DEMO_TASK,
"steps": len(steps_log),
"total_reward": round(total_reward, 3),
"incidents_seen": incident_seen,
"incidents_closed": closed_incidents,
"components_sum": {k: round(v, 3) for k, v in components_sum.items()},
"trace": steps_log,
}
def _write_markdown(base_run: Dict, sft_run: Dict, out_path: Path) -> None:
lines: List[str] = []
lines.append(f"# Before vs After β {DEMO_TASK.title()} task demo\n")
lines.append(f"Both policies ran against the same seeded task (`{DEMO_TASK}`, seed {DEMO_SEED}) ")
lines.append("on an identical Incident Command Center server. Each sees the same incident ")
lines.append("queue in the same order.\n")
lines.append("## Headline\n")
lines.append(f"| Policy | Total reward | Steps | Incidents closed |")
lines.append(f"|---|---:|---:|---:|")
lines.append(
f"| Base `{BASE_MODEL}` | {base_run['total_reward']:+.2f} | "
f"{base_run['steps']} | {base_run['incidents_closed']} |"
)
lines.append(
f"| **Fine-tuned (SFT)** | **{sft_run['total_reward']:+.2f}** | "
f"{sft_run['steps']} | {sft_run['incidents_closed']} |"
)
delta = sft_run["total_reward"] - base_run["total_reward"]
lines.append(f"\n**Reward delta: {delta:+.2f}** in favor of fine-tuned.\n")
lines.append("## Reward sources (summed across the episode)\n")
lines.append("| Component | Base | Fine-tuned |")
lines.append("|---|---:|---:|")
all_keys = sorted(set(base_run["components_sum"]) | set(sft_run["components_sum"]))
for k in all_keys:
lines.append(
f"| `{k}` | {base_run['components_sum'].get(k, 0.0):+.2f} | "
f"{sft_run['components_sum'].get(k, 0.0):+.2f} |"
)
def _trace_block(run: Dict, title: str) -> None:
lines.append(f"\n## Trace β {title}\n")
lines.append("```")
for row in run["trace"]:
lines.append(
f"step {row['step']:>3} | incident={row['incident']} | "
f"{row['action']} | reward={row['reward']:+.2f} | {row['components']}"
)
lines.append("```")
_trace_block(base_run, f"Base model ({BASE_MODEL})")
_trace_block(sft_run, "Fine-tuned (SFT) model")
out_path.write_text("\n".join(lines), encoding="utf-8")
def main() -> None:
from llm_policy import LLMPolicy
print(f"[demo] task={DEMO_TASK} seed={DEMO_SEED} env={ENV_URL}")
print(f"[demo] Loading base model: {BASE_MODEL}")
base_policy = LLMPolicy(BASE_MODEL, label="base_model")
base_run = _rollout_with_policy("base_model", base_policy.select_action)
base_policy.release()
print(f"[demo] Loading SFT model: {SFT_MODEL_DIR}")
sft_policy = LLMPolicy(SFT_MODEL_DIR, label="sft_model")
sft_run = _rollout_with_policy("sft_model", sft_policy.select_action)
sft_policy.release()
art_dir = Path("artifacts")
art_dir.mkdir(exist_ok=True)
md_path = art_dir / "before_after_demo.md"
json_path = art_dir / "before_after_demo.json"
_write_markdown(base_run, sft_run, md_path)
with json_path.open("w", encoding="utf-8") as f:
json.dump({"base": base_run, "sft": sft_run}, f, indent=2)
print(f"[demo] Base total={base_run['total_reward']:+.2f} "
f"steps={base_run['steps']} closed={base_run['incidents_closed']}")
print(f"[demo] SFT total={sft_run['total_reward']:+.2f} "
f"steps={sft_run['steps']} closed={sft_run['incidents_closed']}")
print(f"[demo] Wrote {md_path} and {json_path}")
if __name__ == "__main__":
main()
|