"""SFT (Supervised Fine-Tuning) on SecureReview ground-truth findings. Industry-standard pipeline: train the model to output the env's ground-truth JSON exactly. Much faster + bigger improvements than GRPO alone, because we're directly teaching the model the correct answer instead of waiting for RL exploration to find it. Pipeline matches GRPO train.py format: - Same env (live SecureReview Space) - Same baseline + post-training evaluation - Saves plots/reward_curve.png + plots/before_after.png + plots/results.json """ import os import sys import json import re import time import requests import functools print = functools.partial(print, flush=True) # ── Config ──────────────────────────────────────────────────────────────────── ENV_URL = os.getenv("ENV_URL", "https://sam25kat-securereview.hf.space") ENV_REPO_ID = os.getenv("ENV_REPO_ID", "sam25kat/securereview") TASK_ID = "dependency_review" SCENARIO_FOLDER = "dependency" # subdir under app/tasks/scenarios/ MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct" MAX_SEQ_LEN = 1536 MAX_NEW_TOKENS = 600 NUM_EPOCHS = 3 LEARNING_RATE = 5e-5 LORA_RANK = 16 GRAD_ACCUM_STEPS = 2 OUTPUT_DIR = "./securereview-sft" PLOTS_DIR = "./plots" os.makedirs(PLOTS_DIR, exist_ok=True) os.makedirs(OUTPUT_DIR, exist_ok=True) SYSTEM_PROMPT = """You are a senior security engineer reviewing dependency files for vulnerabilities. Identify ALL security issues including: - Typosquatted packages (names that misspell popular libraries, e.g. 'reqeusts' instead of 'requests') - Known CVE-vulnerable versions (e.g. requests<2.20.0 has CVE-2018-18074) - Hallucinated / non-existent packages that don't exist on PyPI or npm - Suspicious or malicious packages Output ONLY a valid JSON array of findings. Each finding must have: file, line (integer or null), rule_id (e.g. DEP-001), severity (critical/high/medium/low/info), description Output ONLY the JSON array. No explanations, no markdown prose.""" # ── Environment helpers (same as GRPO version) ──────────────────────────────── def env_reset(task_id, scenario_id=None): payload = {"task_id": task_id} if scenario_id: payload["scenario_id"] = scenario_id r = requests.post(f"{ENV_URL}/reset", json=payload, timeout=30) r.raise_for_status() return r.json() def env_step(action): r = requests.post(f"{ENV_URL}/step", json={"action": action}, timeout=30) r.raise_for_status() return r.json() def parse_findings(text): patterns = [ r'```(?:json)?\s*(\[.*?\])\s*```', r'(\[\s*\{.*?\}\s*\])', ] for pattern in patterns: m = re.search(pattern, text, re.DOTALL) if m: try: return json.loads(m.group(1)) except json.JSONDecodeError: continue return [] def run_episode(completion, scenario_id): findings = parse_findings(completion) try: env_reset(TASK_ID, scenario_id) valid_sev = {"critical", "high", "medium", "low", "info"} for f in findings: if not isinstance(f, dict): continue finding = { "file": str(f.get("file", "requirements.txt")), "line": f.get("line"), "rule_id": str(f.get("rule_id", "DEP-001")), "severity": f.get("severity", "medium") if f.get("severity") in valid_sev else "medium", "description": str(f.get("description", ""))[:400], } env_step({"action_type": "report_finding", "finding": finding}) result = env_step({"action_type": "mark_complete"}) return float(result.get("reward", 0.01)) except Exception as e: print(f" [env error] {e}") return 0.01 def build_prompt(obs): ctx = obs["observation"]["context"] files = ctx["files"] parts = [f"Task: {ctx['task_description']}\n"] for f in files: parts.append(f"\n--- {f['filename']} ---\n{f['content']}") parts.append("\nList all security issues as a JSON array:") return "".join(parts) # ── Ground-truth fetching (the SFT-specific bit) ────────────────────────────── def fetch_ground_truth(scenario_id): """Download ground_truth.json for a scenario from the env Space repo.""" from huggingface_hub import hf_hub_download # scenario_id like "dep_001" or "migration_002" -> scenario_NNN folder num = scenario_id.split("_")[-1] path = hf_hub_download( repo_id=ENV_REPO_ID, repo_type="space", filename=f"app/tasks/scenarios/{SCENARIO_FOLDER}/scenario_{num}/ground_truth.json", ) with open(path) as f: return json.load(f) def gt_to_target_json(gt_data): """Convert ground_truth.json's 'ground_truth' list into the JSON array the model should output. Strips internal fields (match_key, category).""" findings = [] for f in gt_data["ground_truth"]: findings.append({ "file": f["file"], "line": f.get("line"), "rule_id": f["rule_id"], "severity": f["severity"], "description": f["description"], }) return json.dumps(findings, indent=2) # ── Main ────────────────────────────────────────────────────────────────────── def main(): print("=" * 60) print(" SecureReview SFT Training") print(f" Model : {MODEL_NAME}") print(f" Task : {TASK_ID}") print(f" Epochs: {NUM_EPOCHS}") print("=" * 60) # Verify env print("\n[1/6] Checking environment connection...") r = requests.get(f"{ENV_URL}/health", timeout=15) print(f" Health: {r.json()}") # Load model print("\n[2/6] Loading model...") from unsloth import FastLanguageModel import torch model, tokenizer = FastLanguageModel.from_pretrained( model_name = MODEL_NAME, max_seq_length = MAX_SEQ_LEN, dtype = torch.float16, load_in_4bit = True, ) model = FastLanguageModel.get_peft_model( model, r = LORA_RANK, target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], lora_alpha = LORA_RANK, lora_dropout = 0, bias = "none", use_gradient_checkpointing = "unsloth", random_state = 42, ) model.print_trainable_parameters() # Build SFT dataset from ground truth print("\n[3/6] Building SFT dataset from ground-truth findings...") from datasets import Dataset scenario_ids = [f"dep_{i:03d}" for i in range(1, 25)] # dep override per task examples = [] for sid in scenario_ids: try: obs = env_reset(TASK_ID, sid) user = build_prompt(obs) gt = fetch_ground_truth(sid) target = gt_to_target_json(gt) messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user}, {"role": "assistant", "content": target}, ] full_text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=False ) examples.append({"text": full_text}) print(f" Loaded {sid} ({len(gt['ground_truth'])} findings)") except Exception as e: print(f" Skipping {sid}: {e}") dataset = Dataset.from_list(examples) print(f" Dataset: {len(examples)} examples") # Baseline eval print("\n[4/6] Baseline evaluation (before SFT)...") FastLanguageModel.for_inference(model) def evaluate(sids, label): scores = {} for sid in sids: obs = env_reset(TASK_ID, sid) prompt_text = build_prompt(obs) messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt_text}, ] inputs = tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_tensors="pt" ).to("cuda") with torch.no_grad(): outputs = model.generate( inputs, max_new_tokens=MAX_NEW_TOKENS, temperature=0.1, do_sample=True, pad_token_id=tokenizer.eos_token_id, ) completion = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True) score = run_episode(completion, sid) scores[sid] = score print(f" [{label}] {sid}: {score:.3f}") time.sleep(0.3) return scores baseline_scores = evaluate(scenario_ids, "before") print(f" Baseline mean: {sum(baseline_scores.values())/len(baseline_scores):.3f}") # SFT training print("\n[5/6] SFT training...") FastLanguageModel.for_training(model) from trl import SFTTrainer, SFTConfig sft_args = SFTConfig( output_dir = OUTPUT_DIR, max_seq_length = MAX_SEQ_LEN, num_train_epochs = NUM_EPOCHS, per_device_train_batch_size = 1, gradient_accumulation_steps = GRAD_ACCUM_STEPS, learning_rate = LEARNING_RATE, lr_scheduler_type = "cosine", warmup_ratio = 0.05, logging_steps = 2, save_steps = 50, fp16 = True, bf16 = False, optim = "adamw_8bit", weight_decay = 0.01, report_to = "none", seed = 42, dataset_text_field = "text", packing = False, ) def formatting_func(example): # SFTTrainer with Unsloth requires this even when dataset_text_field is set. return example["text"] trainer = SFTTrainer( model = model, processing_class = tokenizer, args = sft_args, train_dataset = dataset, formatting_func = formatting_func, ) trainer.train() # Capture loss history for plot loss_log = [ {"step": h.get("step", i), "loss": h["loss"]} for i, h in enumerate(trainer.state.log_history) if "loss" in h ] # Post-training eval print("\n[6/6] Post-SFT evaluation...") FastLanguageModel.for_inference(model) trained_scores = evaluate(scenario_ids, "after") print(f" Trained mean: {sum(trained_scores.values())/len(trained_scores):.3f}") print("\n=== Improvement Summary ===") for sid in scenario_ids: b = baseline_scores.get(sid, 0) t = trained_scores.get(sid, 0) arrow = "▲" if t > b else ("▼" if t < b else "—") print(f" {sid}: {b:.3f} → {t:.3f} {arrow} {t-b:+.3f}") # Plots import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np plt.style.use("dark_background") # Loss curve (instead of reward curve) if loss_log: steps = [e["step"] for e in loss_log] losses = [e["loss"] for e in loss_log] fig, ax = plt.subplots(figsize=(11, 4)) ax.plot(steps, losses, color="#ff6b35", linewidth=2) ax.set_xlabel("Training Step"); ax.set_ylabel("Loss") ax.set_title("SecureReview SFT — Training Loss", fontweight="bold") ax.grid(True, alpha=0.2) fig.tight_layout() plt.savefig(f"{PLOTS_DIR}/reward_curve.png", dpi=150, bbox_inches="tight") plt.close() print(f" Saved {PLOTS_DIR}/reward_curve.png") # Before/after bar chart b_vals = [baseline_scores.get(s, 0) for s in scenario_ids] t_vals = [trained_scores.get(s, 0) for s in scenario_ids] x = np.arange(len(scenario_ids)) fig, ax = plt.subplots(figsize=(12, 5)) ax.bar(x - 0.175, b_vals, 0.35, label="Before", color="#444444") ax.bar(x + 0.175, t_vals, 0.35, label="After", color="#ff6b35") for i, (b, t) in enumerate(zip(b_vals, t_vals)): if abs(t - b) > 0.005: ax.text(i + 0.175, t + 0.02, f"{t-b:+.2f}", ha="center", fontsize=8, color="#22d3ee" if t >= b else "#ef4444") ax.set_xticks(x) label_prefix = SCENARIO_FOLDER[:3].capitalize() + " " ax.set_xticklabels([s.split("_")[-1] for s in scenario_ids], rotation=15, fontsize=8) ax.set_ylim(0, 1); ax.legend() ax.set_title("SecureReview — Before vs After SFT", fontweight="bold") mb = sum(b_vals) / len(b_vals); mt = sum(t_vals) / len(t_vals) ax.text(0.98, 0.92, f"Mean: {mb:.2f} → {mt:.2f} ({mt-mb:+.2f})", transform=ax.transAxes, ha="right", fontsize=11, color="white", bbox=dict(boxstyle="round", facecolor="#1a1a1a", alpha=0.8)) fig.tight_layout() plt.savefig(f"{PLOTS_DIR}/before_after.png", dpi=150, bbox_inches="tight") plt.close() print(f" Saved {PLOTS_DIR}/before_after.png") # Save results JSON results = { "baseline_mean": sum(b_vals) / len(b_vals), "trained_mean": sum(t_vals) / len(t_vals), "improvement": sum(t_vals) / len(t_vals) - sum(b_vals) / len(b_vals), "baseline_scores": baseline_scores, "trained_scores": trained_scores, } with open(f"{PLOTS_DIR}/results.json", "w") as f: json.dump(results, f, indent=2) print("\n" + "=" * 60) print(f" DONE — Mean {sum(b_vals)/len(b_vals):.3f} → {sum(t_vals)/len(t_vals):.3f}") print("=" * 60) if __name__ == "__main__": main()