Spaces:
Sleeping
Sleeping
"""
training/train_qwen3_grpo.py — ShadowOps GRPO Training on Qwen3-4B-Base
=========================================================================
v4.1 — All 4 actions fire at meaningful rates; Q-aware > Heuristic

Key fixes from v4.0:
  - episode_length: 40 (was 20) — the QUARANTINE hold + resolution fits in one episode
  - QuarantineAwarePolicy: tiered threshold ladder so BLOCK is reachable
    (it was sandwiched between FORK and QUARANTINE with no reachable band)
  - TierAwareOptimalPolicy: oracle ceiling added for blog comparison
  - evaluate() now logs A/B/F/Q% per policy so BLOCK > 0 is visible
  - run_preflight(): automated checks before training starts
  - All _compute_reward callers unpack as: outcome, reward = ... (correct order)

Expected baseline output:
  [Random                      ] rew=+400–500  B≈25%  Q≈25%
  [Heuristic (no quarantine)   ] rew=+300–350  B≈25%  Q=0%
  [Quarantine-aware            ] rew=+450–550  B≈18%  Q≈14%  ← beats heuristic
  [Oracle (tier-aware ceiling) ] rew=+600–700  B≈20%  Q≈15%  ← upper bound

Run:
  # CPU baseline only (no GPU):
  python training/train_qwen3_grpo.py --skip-training
  # Full GRPO on GPU:
  python training/train_qwen3_grpo.py
  # Use 1.7B if a T4 is tight (<12 GB free VRAM):
  python training/train_qwen3_grpo.py --model 1.7b
"""
| import sys | |
| import os | |
| import json | |
| import random | |
| import statistics | |
| import argparse | |
| from pathlib import Path | |
| from collections import defaultdict | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from shadowops_env import ( | |
| UniversalShadowEnv, | |
| ATTACK_SCENARIOS, | |
| ACTIONS, | |
| OBS_DIM, | |
| extract_features, | |
| build_llama_prompt, | |
| compute_ambiguity, | |
| R, | |
| DOMAINS, | |
| ) | |
# ─────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────
# CLI --model choice → Hugging Face checkpoint id.
# NOTE(review): "1.7b" is the only entry without a "-Base" suffix; confirm
# that the non-base checkpoint is intended.
MODEL_OPTIONS = {
    "4b": "unsloth/Qwen3-4B-Base",
    "1.7b": "unsloth/Qwen3-1.7B",
    "8b": "unsloth/Qwen3-8B-Base",
}

CFG = dict(
    model_name=MODEL_OPTIONS["4b"],
    output_dir="./shadowops_qwen3_model",
    n_train_episodes=500,
    n_val_episodes=100,
    # CRITICAL: 40 steps so a 3-step QUARANTINE hold resolves in-episode.
    episode_length=40,
    malicious_rate=0.5,
    seed=42,
    # --- GRPO hyper-parameters ---
    learning_rate=1e-5,
    max_steps=300,
    batch_size=2,
    grad_accum=8,
    max_seq_len=512,
    num_generations=4,
    max_completion_len=8,
    lora_r=16,
    lora_alpha=32,
    patience=8,
    val_check_every=30,
)

# System instruction substituted for the first "You are ShadowOps Supervisor."
# in each environment prompt when building GRPO samples.
SYSTEM_PREFIX = "".join((
    "You are ShadowOps Supervisor. /no_think\n",
    "Analyse payload + risk vector. ",
    "Output exactly one word: ALLOW, BLOCK, FORK, or QUARANTINE.\n",
    "QUARANTINE = hold action 3 steps, gather SIEM context, then resolve.",
))

# Every action name the environment's ACTIONS mapping can produce.
VALID_ACTIONS = {action_name for action_name in ACTIONS.values()}
# ─────────────────────────────────────────────────────────────
# 1. Policies
# ─────────────────────────────────────────────────────────────
class RandomPolicy:
    """Lower-bound baseline: picks one of the 4 action ids uniformly at random."""

    def __init__(self, seed=0):
        # A private generator keeps runs reproducible and independent of the
        # module-level ``random`` state.
        self.rng = random.Random(seed)

    def act(self, obs_text, obs_vec):
        """Ignore the observation and draw an action id from 0..3."""
        return self.rng.randrange(4)
class HeuristicPolicy:
    """Rule-based 3-action baseline that never uses QUARANTINE.

    Weighted risk = 0.35*rv[0] + 0.25*rv[1] + 0.20*rv[3] + 0.20*rv[6];
    risk >= 0.65 → FORK, risk >= 0.35 → BLOCK, otherwise ALLOW.
    """

    def act(self, obs_text, obs_vec):
        """Return an action id (0 ALLOW, 1 BLOCK, 2 FORK) for *obs_vec*."""
        risk_vec = obs_vec[:16]
        score = (risk_vec[0] * 0.35 + risk_vec[1] * 0.25
                 + risk_vec[3] * 0.20 + risk_vec[6] * 0.20)
        # (minimum score, action id), checked from the highest floor down.
        ladder = ((0.65, 2), (0.35, 1))
        for floor, action_id in ladder:
            if score >= floor:
                return action_id
        return 0
class QuarantineAwarePolicy:
    """4-action policy with a tiered threshold ladder.

    Active quarantine hold (obs_vec[16] > 0.5):
        obs_vec[17] > 0          → QUARANTINE (keep holding)
        otherwise, risk > 0.45   → FORK, else ALLOW (resolve the hold)

    Fresh scenario (first match wins):
        risk >= 0.75             → FORK        (high-confidence hard malicious)
        risk >= 0.55             → BLOCK       (moderate-high risk)
        ambiguity > 0.35         → QUARANTINE  (genuinely unclear)
        risk >= 0.30             → BLOCK       (low-ambiguity mid risk)
        otherwise                → ALLOW

    Fix vs v4.0: the old policy tested ambiguity before the lower BLOCK band,
    so BLOCK was only reachable for risk in [0.35, 0.40). BLOCK now has two
    bands: [0.55, 0.75), and [0.30, 0.55) when ambiguity <= 0.35.
    """

    def act(self, obs_text, obs_vec):
        """Return an action id (0 ALLOW, 1 BLOCK, 2 FORK, 3 QUARANTINE)."""
        risk_vec = obs_vec[:16]
        in_hold = obs_vec[16] > 0.5
        hold_left = obs_vec[17]
        score = (risk_vec[0] * 0.35 + risk_vec[1] * 0.25
                 + risk_vec[3] * 0.20 + risk_vec[6] * 0.20)
        unclear = compute_ambiguity(risk_vec)

        if in_hold:
            # Stay in the hold window until it runs out, then resolve.
            if hold_left > 0.0:
                return 3
            return 2 if score > 0.45 else 0

        routing = (
            (score >= 0.75, 2),    # FORK: hard malicious
            (score >= 0.55, 1),    # BLOCK: upper mid-risk band
            (unclear > 0.35, 3),   # QUARANTINE: ambiguous
            (score >= 0.30, 1),    # BLOCK: lower mid-risk band
        )
        for matched, action_id in routing:
            if matched:
                return action_id
        return 0  # ALLOW: benign
class TierAwareOptimalPolicy:
    """Oracle-style ceiling that routes on feature-vector proxies for the tier.

    First match wins:
        bracket_override (rv[6]) > 0.6 OR exfil (rv[2]) > 0.5 → FORK
        weighted risk >= 0.55                                 → BLOCK
        ambiguity > 0.35                                      → QUARANTINE
        otherwise                                             → ALLOW

    Used for the HuggingFace blog to show the gap the GRPO model closes.
    """

    def act(self, obs_text, obs_vec):
        """Return an action id (0 ALLOW, 1 BLOCK, 2 FORK, 3 QUARANTINE)."""
        # NOTE(review): unlike QuarantineAwarePolicy this never reads the
        # quarantine-state slots obs_vec[16]/obs_vec[17]; confirm the env
        # does not expect a dedicated resolving action after a hold.
        risk_vec = obs_vec[:16]
        score = (risk_vec[0] * 0.35 + risk_vec[1] * 0.25
                 + risk_vec[3] * 0.20 + risk_vec[6] * 0.20)
        unclear = compute_ambiguity(risk_vec)
        looks_hard_malicious = risk_vec[6] > 0.6 or risk_vec[2] > 0.5

        if looks_hard_malicious:
            chosen = 2
        elif score >= 0.55:
            chosen = 1
        elif unclear > 0.35:
            chosen = 3
        else:
            chosen = 0
        return chosen
# ─────────────────────────────────────────────────────────────
# 2. Trajectory collection
# ─────────────────────────────────────────────────────────────
def collect_episodes(n_episodes, policy, seed, malicious_rate=0.5,
                     episode_length=None) -> list:
    """Roll out *policy* in a fresh UniversalShadowEnv and record every step.

    Args:
        n_episodes: number of episodes to collect.
        policy: object exposing ``act(obs_text, obs_vec) -> int``.
        seed: environment seed.
        malicious_rate: forwarded to the environment constructor.
        episode_length: per-episode length; falsy values fall back to
            ``CFG["episode_length"]``.

    Returns:
        A list of episode dicts with ``episode_id``, ``steps``,
        ``total_reward`` and ``accuracy``. Each step records the observation
        the policy acted on (``prompt``, ``q_active``, ``q_steps``), the
        chosen action name, and the reward/info that env.step() returned.
    """
    ep_len = episode_length or CFG["episode_length"]
    env = UniversalShadowEnv(
        malicious_rate=malicious_rate,
        episode_max_length=ep_len,
        mode="training",
        seed=seed,
    )
    episodes = []
    for ep_i in range(n_episodes):
        obs_text, obs_vec = env.reset()
        done = False
        steps = []
        while not done:
            # Keep the observation the decision is based on: env.step()
            # returns the *next* observation, which must not be logged as
            # the prompt for this action (previous off-by-one bug).
            cur_text, cur_vec = obs_text, obs_vec
            action_int = policy.act(cur_text, cur_vec)
            obs_text, obs_vec, reward, done, info = env.step(action_int)
            steps.append({
                "prompt": cur_text,
                "action": ACTIONS[action_int],
                "reward": reward,
                # NOTE(review): assumes `info` describes the scenario that
                # was just acted on (standard step() semantics).
                "is_malicious": info["is_malicious"],
                "domain": info["domain"],
                "outcome": info["outcome"],
                "q_active": bool(cur_vec[16] > 0.5),
                "q_steps": int(round(cur_vec[17] * 3)),
            })
        # A step is "correct" if malicious traffic is not ALLOWed and benign
        # traffic is ALLOWed.
        n_correct = sum(
            1 for s in steps
            if (s["is_malicious"] and s["action"] in ("BLOCK", "FORK", "QUARANTINE"))
            or (not s["is_malicious"] and s["action"] == "ALLOW")
        )
        episodes.append({
            "episode_id": ep_i,
            "steps": steps,
            "total_reward": sum(s["reward"] for s in steps),
            "accuracy": n_correct / max(len(steps), 1),
        })
    return episodes
def episodes_to_grpo(episodes: list) -> list:
    """Flatten episodes into GRPO samples.

    Each step becomes ``{"query", "response", "reward"}``: the step prompt
    with its first "You are ShadowOps Supervisor." replaced by SYSTEM_PREFIX,
    the action name taken, and the step reward as a float.
    """
    marker = "You are ShadowOps Supervisor."
    return [
        {
            "query": step["prompt"].replace(marker, SYSTEM_PREFIX, 1),
            "response": step["action"],
            "reward": float(step["reward"]),
        }
        for episode in episodes
        for step in episode["steps"]
    ]
# ─────────────────────────────────────────────────────────────
# 3. Evaluation helper
# ─────────────────────────────────────────────────────────────
def evaluate(policy, n_episodes=50, seed=99, label="", malicious_rate=None) -> dict:
    """Roll out *policy* and summarise reward, accuracy, FPR and action mix.

    Args:
        policy: object exposing ``act(obs_text, obs_vec) -> int``.
        n_episodes: number of evaluation episodes.
        seed: environment seed.
        label: name shown in the printed one-line summary.
        malicious_rate: environment malicious rate; defaults to
            ``CFG["malicious_rate"]`` so baselines follow the training config.

    Returns:
        Dict with mean/std episode reward, mean accuracy, false-positive rate
        on benign steps, the ALLOW/BLOCK/FORK/QUARANTINE step shares, and the
        per-episode rewards and accuracies. Means are 0.0 when no episodes
        were collected.
    """
    rate = CFG["malicious_rate"] if malicious_rate is None else malicious_rate
    eps = collect_episodes(n_episodes, policy, seed=seed, malicious_rate=rate)
    rews = [e["total_reward"] for e in eps]
    accs = [e["accuracy"] for e in eps]

    # Single pass over all steps: action histogram plus benign/FP counts.
    act_counts = defaultdict(int)
    total_steps = fp = benign = 0
    for e in eps:
        for s in e["steps"]:
            act_counts[s["action"]] += 1
            total_steps += 1
            if not s["is_malicious"]:
                benign += 1
                # Any non-ALLOW action on benign traffic is a false positive.
                if s["action"] in ("BLOCK", "FORK", "QUARANTINE"):
                    fp += 1

    fpr = fp / max(benign, 1)
    denom = max(total_steps, 1)
    a_pct = act_counts["ALLOW"] / denom
    b_pct = act_counts["BLOCK"] / denom
    f_pct = act_counts["FORK"] / denom
    q_pct = act_counts["QUARANTINE"] / denom
    res = {
        "label": label,
        "mean_reward": statistics.mean(rews) if rews else 0.0,
        "std_reward": statistics.stdev(rews) if len(rews) > 1 else 0.0,
        "mean_acc": statistics.mean(accs) if accs else 0.0,
        "fpr": fpr,
        "allow_pct": a_pct,
        "block_pct": b_pct,
        "fork_pct": f_pct,
        "quarantine_pct": q_pct,
        "rewards": rews,
        "accuracies": accs,
    }
    print(f" [{label:<35}] "
          f"rew={res['mean_reward']:+7.1f}±{res['std_reward']:.1f} "
          f"acc={res['mean_acc']:.1%} "
          f"fpr={fpr:.1%} "
          f"A={a_pct:.1%} B={b_pct:.1%} F={f_pct:.1%} Q={q_pct:.1%}")
    return res
# ─────────────────────────────────────────────────────────────
# 4. Reward function for GRPO
# ─────────────────────────────────────────────────────────────
def make_reward_fn(valid_tokens=None):
    """Build the GRPO reward function.

    The returned callable follows TRL's reward-function protocol: it receives
    the list of completions plus extra dataset columns as keyword arguments
    (one value per completion) and returns one float per completion.

    Per completion, the first word (surrounding punctuation stripped,
    upper-cased) is scored:
      * -8.0 if it is not in the valid action set (format penalty);
      * the logged environment reward if the ``response``/``reward`` columns
        are supplied and the word equals the logged action;
      * 0.0 otherwise.
    Without ``response``/``reward`` kwargs this is a pure format reward.

    Args:
        valid_tokens: allowed action words; defaults to VALID_ACTIONS.
    """
    valid = VALID_ACTIONS if valid_tokens is None else valid_tokens
    strip_chars = ".,:;!?\"'`*"

    def reward_fn(completions, **kwargs):
        logged_actions = kwargs.get("response")
        logged_rewards = kwargs.get("reward")
        use_logged = logged_actions is not None and logged_rewards is not None
        rewards = []
        for i, comp in enumerate(completions):
            words = comp.strip().split()
            # Tolerate trailing punctuation / markdown such as "ALLOW." or "**FORK**".
            text = words[0].strip(strip_chars).upper() if words else ""
            if text not in valid:
                rewards.append(-8.0)
            elif use_logged and text == logged_actions[i]:
                # Off-policy signal: the logged reward is only known for the
                # action the behaviour policy actually took.
                rewards.append(float(logged_rewards[i]))
            else:
                rewards.append(0.0)
        return rewards

    return reward_fn
# ─────────────────────────────────────────────────────────────
# 5. Unsloth + TRL GRPO training
# ─────────────────────────────────────────────────────────────
def train_grpo(train_samples: list, val_episodes: list, model_name: str):
    """Fine-tune *model_name* with 4-bit QLoRA (Unsloth) and TRL's GRPOTrainer.

    Validation runs every ``CFG["val_check_every"]`` optimizer steps on the
    first 20 validation episodes. The adapter with the best validation score
    is saved to ``CFG["output_dir"]``, and training stops early after
    ``CFG["patience"]`` checks without improvement.

    Args:
        train_samples: dicts from ``episodes_to_grpo`` ("query", "response",
            "reward").
        val_episodes: episodes from ``collect_episodes``.
        model_name: Hugging Face model id passed to Unsloth.

    Returns:
        ``{"val_curve": [...], "best_val_reward": float}``, or ``None`` when
        a training dependency is not installed.
    """
    try:
        from unsloth import FastModel  # import before trl/transformers so patches apply
        from trl import GRPOConfig, GRPOTrainer
        from transformers import TrainerCallback
        from datasets import Dataset
        import torch
    except ImportError as e:
        print(f"\n Missing dependency: {e}")
        print(" Install with: pip install unsloth trl datasets torch")
        return None

    print(f"\n[3] Loading {model_name} with Unsloth (4-bit QLoRA) …")
    model, tokenizer = FastModel.from_pretrained(
        model_name=model_name,
        max_seq_length=CFG["max_seq_len"],
        load_in_4bit=True,
        load_in_8bit=False,
        full_finetuning=False,
    )
    model = FastModel.get_peft_model(
        model,
        r=CFG["lora_r"],
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=CFG["lora_alpha"],
        lora_dropout=0.05,
        bias="none",
        use_gradient_checkpointing="unsloth",
    )

    # GRPOTrainer generates from the "prompt" column; the remaining columns
    # are forwarded to the reward function as keyword arguments.
    grpo_data = Dataset.from_list([
        {"prompt": s["query"], "query": s["query"],
         "response": s["response"], "reward": s["reward"]}
        for s in train_samples
    ])
    reward_fn = make_reward_fn()

    best_val = float("-inf")
    patience = 0
    val_curve = []
    marker = "You are ShadowOps Supervisor."
    strip_chars = ".,:;!?\"'`*"

    def _parse_action(text):
        """Map generated text to an action name, falling back to BLOCK."""
        words = text.strip().split()
        word = words[0].strip(strip_chars).upper() if words else ""
        return word if word in VALID_ACTIONS else "BLOCK"

    def _val_check(step):
        """Score the current model on held-out steps; return True to stop."""
        nonlocal best_val, patience
        FastModel.for_inference(model)
        correct = total = 0
        total_r = 0.0
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            for ep in val_episodes[:20]:
                for s in ep["steps"]:
                    # Use the same system prompt the training queries were built with.
                    prompt = s["prompt"].replace(marker, SYSTEM_PREFIX, 1)
                    inp = tokenizer(prompt, return_tensors="pt",
                                    truncation=True,
                                    max_length=CFG["max_seq_len"]).to(device)
                    with torch.no_grad():
                        out = model.generate(
                            **inp,
                            max_new_tokens=CFG["max_completion_len"],
                            do_sample=False,
                        )
                    # Decode only the generated continuation, not the prompt.
                    new_tokens = out[0][inp["input_ids"].shape[1]:]
                    action = _parse_action(
                        tokenizer.decode(new_tokens, skip_special_tokens=True))
                    is_mal = s["is_malicious"]
                    if ((is_mal and action in ("BLOCK", "FORK", "QUARANTINE")) or
                            (not is_mal and action == "ALLOW")):
                        correct += 1
                    total += 1
                    # Off-policy proxy: the logged reward is only known for
                    # the action the behaviour policy took, so credit it when
                    # the model agrees and score 0 otherwise.
                    if action == s["action"]:
                        total_r += s["reward"]
        finally:
            FastModel.for_training(model)

        val_r = total_r / max(total, 1)
        val_acc = correct / max(total, 1)
        val_curve.append({"step": step, "val_reward": val_r, "val_acc": val_acc})
        print(f" Val step {step:3d}: reward={val_r:+.2f} acc={val_acc:.2%}")
        if val_r > best_val:
            best_val = val_r
            patience = 0
            model.save_pretrained(CFG["output_dir"])
            tokenizer.save_pretrained(CFG["output_dir"])
            print(f" ✓ New best: {best_val:.2f} → saved to {CFG['output_dir']}")
        else:
            patience += 1
            print(f" No improvement ({patience}/{CFG['patience']})")
        return patience >= CFG["patience"]

    class _ValidationCallback(TrainerCallback):
        """Validate every CFG['val_check_every'] optimizer steps; stop early on patience."""

        def on_step_end(self, args, state, control, **kwargs):
            step = state.global_step
            if step > 0 and step % CFG["val_check_every"] == 0 and _val_check(step):
                print(f"\n Early stopping at step {step}")
                control.should_training_stop = True
            return control

    grpo_cfg = GRPOConfig(
        learning_rate=CFG["learning_rate"],
        max_steps=CFG["max_steps"],
        per_device_train_batch_size=CFG["batch_size"],
        gradient_accumulation_steps=CFG["grad_accum"],
        warmup_ratio=0.05,
        output_dir=CFG["output_dir"],
        logging_steps=10,
        save_steps=CFG["val_check_every"],
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        report_to="none",
        seed=CFG["seed"],
        num_generations=CFG["num_generations"],
        max_completion_length=CFG["max_completion_len"],
        temperature=0.7,
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=grpo_cfg,
        train_dataset=grpo_data,
        reward_funcs=reward_fn,
        callbacks=[_ValidationCallback()],
    )

    print("\n[4] GRPO training starts …")
    print(f" Model : {model_name}")
    print(f" OBS_DIM : {OBS_DIM} (16 risk + 2 quarantine state)")
    print(f" Actions : {sorted(VALID_ACTIONS)}")
    print(f" Train samples : {len(train_samples)}")
    print(f" Val episodes : {len(val_episodes)}")
    print(f" Episode length : {CFG['episode_length']}")
    print(f" Max steps : {CFG['max_steps']}")
    print(f" Batch (eff.) : {CFG['batch_size'] * CFG['grad_accum']}")
    print()

    # One continuous run: the callback handles periodic validation and
    # early stopping, so the optimizer/scheduler state is never reset.
    trainer.train()
    return {"val_curve": val_curve, "best_val_reward": best_val}
# ─────────────────────────────────────────────────────────────
# 6. Pre-flight sanity check
# ─────────────────────────────────────────────────────────────
def run_preflight():
    """Run fast CPU-only sanity checks before GPU training.

    Evaluates the heuristic, quarantine-aware and oracle policies for 200
    episodes each. It then checks that the quarantine-aware policy uses
    BLOCK, QUARANTINE and FORK, beats the heuristic, and is beaten by the
    oracle.

    Returns:
        ``(all_pass, r_heur, r_qaw, r_oracle)``: True only if all 5 checks
        pass, followed by the three ``evaluate`` result dicts.
    """
    n = 200
    # Report the episode length actually used instead of a hard-coded 40.
    print(f"\n[PREFLIGHT] Sanity checks ({n} episodes @ "
          f"ep_len={CFG['episode_length']}) …")
    r_heur = evaluate(HeuristicPolicy(), n, seed=10, label="Heuristic (preflight)")
    r_qaw = evaluate(QuarantineAwarePolicy(), n, seed=11, label="Q-aware (preflight)")
    r_oracle = evaluate(TierAwareOptimalPolicy(), n, seed=12, label="Oracle (preflight)")
    checks = [
        ("BLOCK% > 0% for Q-aware", r_qaw["block_pct"] > 0.0),
        ("QUARANTINE% > 0% for Q-aware", r_qaw["quarantine_pct"] > 0.0),
        ("FORK% > 0% for Q-aware", r_qaw["fork_pct"] > 0.0),
        ("Q-aware reward > Heuristic", r_qaw["mean_reward"] > r_heur["mean_reward"]),
        ("Oracle reward > Q-aware", r_oracle["mean_reward"] > r_qaw["mean_reward"]),
    ]
    print()
    for desc, passed in checks:
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" {status} {desc}")
    all_pass = all(passed for _, passed in checks)
    if all_pass:
        print("\n ✓ All preflight checks passed — safe to start GRPO training.\n")
    else:
        print("\n ✗ One or more preflight checks FAILED — fix before training.\n")
    return all_pass, r_heur, r_qaw, r_oracle
# ─────────────────────────────────────────────────────────────
# 7. Main
# ─────────────────────────────────────────────────────────────
def main():
    """Command-line entry point.

    Runs the pipeline: optional preflight, four baseline evaluations,
    collection of training/validation episodes with QuarantineAwarePolicy,
    optional GRPO training, and the reward-curve export.

    Writes ``training/qwen3_train_dataset.json`` (first 100 samples) and
    ``reward_curves_qwen3.json``, both relative to the current directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--skip-training", action="store_true",
                        help="Baselines + dataset only (no GPU required)")
    parser.add_argument("--skip-preflight", action="store_true",
                        help="Skip preflight checks (not recommended)")
    # Choices come from MODEL_OPTIONS so the two can never drift apart.
    parser.add_argument("--model", choices=list(MODEL_OPTIONS), default="4b",
                        help="Model size (default: 4b = Qwen3-4B-Base)")
    args = parser.parse_args()
    model_name = MODEL_OPTIONS[args.model]
    rng = random.Random(CFG["seed"])

    print("=" * 70)
    print(" ShadowOps — GRPO Training Pipeline (Qwen3 Edition, 4-action RL)")
    print("=" * 70)
    print(f" episode_length : {CFG['episode_length']} "
          f"(40 required so QUARANTINE resolves in-episode)")
    print(f" n_train_episodes: {CFG['n_train_episodes']}")
    print(f" malicious_rate : {CFG['malicious_rate']}")
    print()

    # ── Preflight ──────────────────────────────────────────
    if not args.skip_preflight:
        ok, *_ = run_preflight()
        if not ok and not args.skip_training:
            print(" Aborting — fix preflight failures before training.")
            return
    else:
        print(" [PREFLIGHT] Skipped.\n")

    # ── Step 1: Full baselines ─────────────────────────────
    print("[1] Evaluating 4 baseline policies (50 episodes each) …")
    r_rand = evaluate(RandomPolicy(seed=0), 50, seed=1, label="Random (uniform 4-action)")
    r_heur = evaluate(HeuristicPolicy(), 50, seed=2, label="Heuristic (no quarantine)")
    r_qaw = evaluate(QuarantineAwarePolicy(), 50, seed=3, label="Quarantine-aware")
    r_oracle = evaluate(TierAwareOptimalPolicy(), 50, seed=4, label="Oracle (tier-aware ceiling)")

    # ── Step 2: Collect training data ──────────────────────
    n_total = CFG["n_train_episodes"] + CFG["n_val_episodes"]
    print(f"\n[2] Collecting {n_total} episodes with QuarantineAwarePolicy …")
    all_eps = collect_episodes(n_total, QuarantineAwarePolicy(), seed=CFG["seed"],
                               malicious_rate=CFG["malicious_rate"])
    rng.shuffle(all_eps)
    train_eps = all_eps[:CFG["n_train_episodes"]]
    val_eps = all_eps[CFG["n_train_episodes"]:]
    train_samples = episodes_to_grpo(train_eps)
    val_samples = episodes_to_grpo(val_eps)

    act_dist = defaultdict(int)
    for s in train_samples:
        act_dist[s["response"]] += 1
    total = len(train_samples)
    denom = max(total, 1)  # avoid ZeroDivisionError when there are no train steps
    print(f" Train steps : {total} | Val steps : {len(val_samples)}")
    print(" Action distribution (training data):")
    for act in ["ALLOW", "BLOCK", "FORK", "QUARANTINE"]:
        c = act_dist[act]
        share = c / denom
        # Non-zero counts get at least one block; zero counts get none.
        bar = "█" * (max(1, int(share * 40)) if c else 0)
        print(f" {act:<12} {c:6d} ({share:5.1%}) {bar}")

    Path("training").mkdir(exist_ok=True)
    with open("training/qwen3_train_dataset.json", "w", encoding="utf-8") as f:
        json.dump(train_samples[:100], f, indent=2)
    print(" Sample saved → training/qwen3_train_dataset.json")

    # ── Step 3: Train ──────────────────────────────────────
    grpo_results = None
    if not args.skip_training:
        grpo_results = train_grpo(train_samples, val_eps, model_name)
    else:
        print(f"\n[3] Skipping training (--skip-training). Model would be: {model_name}")

    # ── Step 4: Save reward curves ─────────────────────────
    print("\n[4] Saving reward curves …")
    baseline_results = {
        "random": r_rand,
        "heuristic": r_heur,
        "quarantine_aware": r_qaw,
        "oracle": r_oracle,
    }
    baseline_fields = ("mean_reward", "fpr", "block_pct", "quarantine_pct", "rewards")
    curves = {
        "model": model_name,
        "obs_dim": OBS_DIM,
        "n_actions": 4,
        "action_names": sorted(VALID_ACTIONS),
        "episode_length": CFG["episode_length"],
        "baselines": {
            name: {field: res[field] for field in baseline_fields}
            for name, res in baseline_results.items()
        },
        "train": {
            "rewards": [ep["total_reward"] for ep in train_eps],
            "accuracies": [ep["accuracy"] for ep in train_eps],
        },
        "val": {
            "rewards": [ep["total_reward"] for ep in val_eps],
            "accuracies": [ep["accuracy"] for ep in val_eps],
        },
        "grpo_val_curve": grpo_results["val_curve"] if grpo_results else [],
        "best_val_reward": grpo_results["best_val_reward"] if grpo_results else None,
        "action_dist": dict(act_dist),
        "improvements": {
            "heuristic_over_random": r_heur["mean_reward"] - r_rand["mean_reward"],
            "quarantine_over_heuristic": r_qaw["mean_reward"] - r_heur["mean_reward"],
            "oracle_over_quarantine": r_oracle["mean_reward"] - r_qaw["mean_reward"],
            "fpr_reduction_heur_vs_rand": r_rand["fpr"] - r_heur["fpr"],
            "fpr_reduction_qaw_vs_heur": r_heur["fpr"] - r_qaw["fpr"],
        },
    }
    with open("reward_curves_qwen3.json", "w", encoding="utf-8") as f:
        json.dump(curves, f, indent=2)

    # ── Summary ────────────────────────────────────────────
    qaw_delta = curves["improvements"]["quarantine_over_heuristic"]
    print("\n" + "=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    print(f" episode_length : {CFG['episode_length']}")
    print(f" Model : {model_name}")
    print()
    print(f" Random baseline : {r_rand['mean_reward']:+7.1f} "
          f"B={r_rand['block_pct']:.1%} Q={r_rand['quarantine_pct']:.1%}")
    print(f" Heuristic (no quarantine) : {r_heur['mean_reward']:+7.1f} "
          f"B={r_heur['block_pct']:.1%} Q={r_heur['quarantine_pct']:.1%}")
    print(f" Quarantine-aware heuristic : {r_qaw['mean_reward']:+7.1f} "
          f"B={r_qaw['block_pct']:.1%} Q={r_qaw['quarantine_pct']:.1%}")
    print(f" Oracle (tier-aware ceiling) : {r_oracle['mean_reward']:+7.1f} "
          f"B={r_oracle['block_pct']:.1%} Q={r_oracle['quarantine_pct']:.1%}")
    print()
    print(f" heuristic → q-aware delta : {qaw_delta:+.1f}"
          + (" ✓ Q-aware wins" if qaw_delta > 0 else " ✗ UNEXPECTED — check thresholds"))
    print(f" q-aware → oracle delta : {curves['improvements']['oracle_over_quarantine']:+.1f}")
    print(f" FPR reduction heur→q-aware : {curves['improvements']['fpr_reduction_qaw_vs_heur']:.1%}")
    if grpo_results:
        print(f"\n Best Qwen3 GRPO val reward : {grpo_results['best_val_reward']:+.1f}")
    print("\n reward_curves_qwen3.json saved")
    print(" Run plot_curves.py to generate graphs for your HuggingFace blog\n")


if __name__ == "__main__":
    main()