Spaces:
Sleeping
Sleeping
| """SFT (Supervised Fine-Tuning) on SecureReview ground-truth findings. | |
| Industry-standard pipeline: train the model to output the env's ground-truth | |
| JSON exactly. Much faster + bigger improvements than GRPO alone, because | |
| we're directly teaching the model the correct answer instead of waiting for | |
| RL exploration to find it. | |
| Pipeline matches GRPO train.py format: | |
| - Same env (live SecureReview Space) | |
| - Same baseline + post-training evaluation | |
| - Saves plots/reward_curve.png + plots/before_after.png + plots/results.json | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import re | |
| import time | |
| import requests | |
| import functools | |
| print = functools.partial(print, flush=True) | |
| # ββ Config ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| ENV_URL = os.getenv("ENV_URL", "https://sam25kat-securereview.hf.space") | |
| ENV_REPO_ID = os.getenv("ENV_REPO_ID", "sam25kat/securereview") | |
| TASK_ID = "dependency_review" | |
| SCENARIO_FOLDER = "dependency" # subdir under app/tasks/scenarios/ | |
| MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct" | |
| MAX_SEQ_LEN = 1536 | |
| MAX_NEW_TOKENS = 600 | |
| NUM_EPOCHS = 3 | |
| LEARNING_RATE = 5e-5 | |
| LORA_RANK = 16 | |
| GRAD_ACCUM_STEPS = 2 | |
| OUTPUT_DIR = "./securereview-sft" | |
| PLOTS_DIR = "./plots" | |
| os.makedirs(PLOTS_DIR, exist_ok=True) | |
| os.makedirs(OUTPUT_DIR, exist_ok=True) | |
| SYSTEM_PROMPT = """You are a senior security engineer reviewing dependency files for vulnerabilities. | |
| Identify ALL security issues including: | |
| - Typosquatted packages (names that misspell popular libraries, e.g. 'reqeusts' instead of 'requests') | |
| - Known CVE-vulnerable versions (e.g. requests<2.20.0 has CVE-2018-18074) | |
| - Hallucinated / non-existent packages that don't exist on PyPI or npm | |
| - Suspicious or malicious packages | |
| Output ONLY a valid JSON array of findings. Each finding must have: | |
| file, line (integer or null), rule_id (e.g. DEP-001), severity (critical/high/medium/low/info), description | |
| Output ONLY the JSON array. No explanations, no markdown prose.""" | |
| # ββ Environment helpers (same as GRPO version) ββββββββββββββββββββββββββββββββ | |
| def env_reset(task_id, scenario_id=None): | |
| payload = {"task_id": task_id} | |
| if scenario_id: | |
| payload["scenario_id"] = scenario_id | |
| r = requests.post(f"{ENV_URL}/reset", json=payload, timeout=30) | |
| r.raise_for_status() | |
| return r.json() | |
| def env_step(action): | |
| r = requests.post(f"{ENV_URL}/step", json={"action": action}, timeout=30) | |
| r.raise_for_status() | |
| return r.json() | |
| def parse_findings(text): | |
| patterns = [ | |
| r'```(?:json)?\s*(\[.*?\])\s*```', | |
| r'(\[\s*\{.*?\}\s*\])', | |
| ] | |
| for pattern in patterns: | |
| m = re.search(pattern, text, re.DOTALL) | |
| if m: | |
| try: | |
| return json.loads(m.group(1)) | |
| except json.JSONDecodeError: | |
| continue | |
| return [] | |
| def run_episode(completion, scenario_id): | |
| findings = parse_findings(completion) | |
| try: | |
| env_reset(TASK_ID, scenario_id) | |
| valid_sev = {"critical", "high", "medium", "low", "info"} | |
| for f in findings: | |
| if not isinstance(f, dict): | |
| continue | |
| finding = { | |
| "file": str(f.get("file", "requirements.txt")), | |
| "line": f.get("line"), | |
| "rule_id": str(f.get("rule_id", "DEP-001")), | |
| "severity": f.get("severity", "medium") if f.get("severity") in valid_sev else "medium", | |
| "description": str(f.get("description", ""))[:400], | |
| } | |
| env_step({"action_type": "report_finding", "finding": finding}) | |
| result = env_step({"action_type": "mark_complete"}) | |
| return float(result.get("reward", 0.01)) | |
| except Exception as e: | |
| print(f" [env error] {e}") | |
| return 0.01 | |
| def build_prompt(obs): | |
| ctx = obs["observation"]["context"] | |
| files = ctx["files"] | |
| parts = [f"Task: {ctx['task_description']}\n"] | |
| for f in files: | |
| parts.append(f"\n--- {f['filename']} ---\n{f['content']}") | |
| parts.append("\nList all security issues as a JSON array:") | |
| return "".join(parts) | |
| # ββ Ground-truth fetching (the SFT-specific bit) ββββββββββββββββββββββββββββββ | |
| def fetch_ground_truth(scenario_id): | |
| """Download ground_truth.json for a scenario from the env Space repo.""" | |
| from huggingface_hub import hf_hub_download | |
| # scenario_id like "dep_001" or "migration_002" -> scenario_NNN folder | |
| num = scenario_id.split("_")[-1] | |
| path = hf_hub_download( | |
| repo_id=ENV_REPO_ID, | |
| repo_type="space", | |
| filename=f"app/tasks/scenarios/{SCENARIO_FOLDER}/scenario_{num}/ground_truth.json", | |
| ) | |
| with open(path) as f: | |
| return json.load(f) | |
| def gt_to_target_json(gt_data): | |
| """Convert ground_truth.json's 'ground_truth' list into the JSON array | |
| the model should output. Strips internal fields (match_key, category).""" | |
| findings = [] | |
| for f in gt_data["ground_truth"]: | |
| findings.append({ | |
| "file": f["file"], | |
| "line": f.get("line"), | |
| "rule_id": f["rule_id"], | |
| "severity": f["severity"], | |
| "description": f["description"], | |
| }) | |
| return json.dumps(findings, indent=2) | |
| # ββ Main ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def main(): | |
| print("=" * 60) | |
| print(" SecureReview SFT Training") | |
| print(f" Model : {MODEL_NAME}") | |
| print(f" Task : {TASK_ID}") | |
| print(f" Epochs: {NUM_EPOCHS}") | |
| print("=" * 60) | |
| # Verify env | |
| print("\n[1/6] Checking environment connection...") | |
| r = requests.get(f"{ENV_URL}/health", timeout=15) | |
| print(f" Health: {r.json()}") | |
| # Load model | |
| print("\n[2/6] Loading model...") | |
| from unsloth import FastLanguageModel | |
| import torch | |
| model, tokenizer = FastLanguageModel.from_pretrained( | |
| model_name = MODEL_NAME, | |
| max_seq_length = MAX_SEQ_LEN, | |
| dtype = torch.float16, | |
| load_in_4bit = True, | |
| ) | |
| model = FastLanguageModel.get_peft_model( | |
| model, | |
| r = LORA_RANK, | |
| target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", | |
| "gate_proj", "up_proj", "down_proj"], | |
| lora_alpha = LORA_RANK, | |
| lora_dropout = 0, | |
| bias = "none", | |
| use_gradient_checkpointing = "unsloth", | |
| random_state = 42, | |
| ) | |
| model.print_trainable_parameters() | |
| # Build SFT dataset from ground truth | |
| print("\n[3/6] Building SFT dataset from ground-truth findings...") | |
| from datasets import Dataset | |
| scenario_ids = [f"dep_{i:03d}" for i in range(1, 25)] # dep override per task | |
| examples = [] | |
| for sid in scenario_ids: | |
| try: | |
| obs = env_reset(TASK_ID, sid) | |
| user = build_prompt(obs) | |
| gt = fetch_ground_truth(sid) | |
| target = gt_to_target_json(gt) | |
| messages = [ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": user}, | |
| {"role": "assistant", "content": target}, | |
| ] | |
| full_text = tokenizer.apply_chat_template( | |
| messages, tokenize=False, add_generation_prompt=False | |
| ) | |
| examples.append({"text": full_text}) | |
| print(f" Loaded {sid} ({len(gt['ground_truth'])} findings)") | |
| except Exception as e: | |
| print(f" Skipping {sid}: {e}") | |
| dataset = Dataset.from_list(examples) | |
| print(f" Dataset: {len(examples)} examples") | |
| # Baseline eval | |
| print("\n[4/6] Baseline evaluation (before SFT)...") | |
| FastLanguageModel.for_inference(model) | |
| def evaluate(sids, label): | |
| scores = {} | |
| for sid in sids: | |
| obs = env_reset(TASK_ID, sid) | |
| prompt_text = build_prompt(obs) | |
| messages = [ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": prompt_text}, | |
| ] | |
| inputs = tokenizer.apply_chat_template( | |
| messages, tokenize=True, add_generation_prompt=True, return_tensors="pt" | |
| ).to("cuda") | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| inputs, max_new_tokens=MAX_NEW_TOKENS, | |
| temperature=0.1, do_sample=True, | |
| pad_token_id=tokenizer.eos_token_id, | |
| ) | |
| completion = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True) | |
| score = run_episode(completion, sid) | |
| scores[sid] = score | |
| print(f" [{label}] {sid}: {score:.3f}") | |
| time.sleep(0.3) | |
| return scores | |
| baseline_scores = evaluate(scenario_ids, "before") | |
| print(f" Baseline mean: {sum(baseline_scores.values())/len(baseline_scores):.3f}") | |
| # SFT training | |
| print("\n[5/6] SFT training...") | |
| FastLanguageModel.for_training(model) | |
| from trl import SFTTrainer, SFTConfig | |
| sft_args = SFTConfig( | |
| output_dir = OUTPUT_DIR, | |
| max_seq_length = MAX_SEQ_LEN, | |
| num_train_epochs = NUM_EPOCHS, | |
| per_device_train_batch_size = 1, | |
| gradient_accumulation_steps = GRAD_ACCUM_STEPS, | |
| learning_rate = LEARNING_RATE, | |
| lr_scheduler_type = "cosine", | |
| warmup_ratio = 0.05, | |
| logging_steps = 2, | |
| save_steps = 50, | |
| fp16 = True, | |
| bf16 = False, | |
| optim = "adamw_8bit", | |
| weight_decay = 0.01, | |
| report_to = "none", | |
| seed = 42, | |
| dataset_text_field = "text", | |
| packing = False, | |
| ) | |
| def formatting_func(example): | |
| # SFTTrainer with Unsloth requires this even when dataset_text_field is set. | |
| return example["text"] | |
| trainer = SFTTrainer( | |
| model = model, | |
| processing_class = tokenizer, | |
| args = sft_args, | |
| train_dataset = dataset, | |
| formatting_func = formatting_func, | |
| ) | |
| trainer.train() | |
| # Capture loss history for plot | |
| loss_log = [ | |
| {"step": h.get("step", i), "loss": h["loss"]} | |
| for i, h in enumerate(trainer.state.log_history) | |
| if "loss" in h | |
| ] | |
| # Post-training eval | |
| print("\n[6/6] Post-SFT evaluation...") | |
| FastLanguageModel.for_inference(model) | |
| trained_scores = evaluate(scenario_ids, "after") | |
| print(f" Trained mean: {sum(trained_scores.values())/len(trained_scores):.3f}") | |
| print("\n=== Improvement Summary ===") | |
| for sid in scenario_ids: | |
| b = baseline_scores.get(sid, 0) | |
| t = trained_scores.get(sid, 0) | |
| arrow = "β²" if t > b else ("βΌ" if t < b else "β") | |
| print(f" {sid}: {b:.3f} β {t:.3f} {arrow} {t-b:+.3f}") | |
| # Plots | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| plt.style.use("dark_background") | |
| # Loss curve (instead of reward curve) | |
| if loss_log: | |
| steps = [e["step"] for e in loss_log] | |
| losses = [e["loss"] for e in loss_log] | |
| fig, ax = plt.subplots(figsize=(11, 4)) | |
| ax.plot(steps, losses, color="#ff6b35", linewidth=2) | |
| ax.set_xlabel("Training Step"); ax.set_ylabel("Loss") | |
| ax.set_title("SecureReview SFT β Training Loss", fontweight="bold") | |
| ax.grid(True, alpha=0.2) | |
| fig.tight_layout() | |
| plt.savefig(f"{PLOTS_DIR}/reward_curve.png", dpi=150, bbox_inches="tight") | |
| plt.close() | |
| print(f" Saved {PLOTS_DIR}/reward_curve.png") | |
| # Before/after bar chart | |
| b_vals = [baseline_scores.get(s, 0) for s in scenario_ids] | |
| t_vals = [trained_scores.get(s, 0) for s in scenario_ids] | |
| x = np.arange(len(scenario_ids)) | |
| fig, ax = plt.subplots(figsize=(12, 5)) | |
| ax.bar(x - 0.175, b_vals, 0.35, label="Before", color="#444444") | |
| ax.bar(x + 0.175, t_vals, 0.35, label="After", color="#ff6b35") | |
| for i, (b, t) in enumerate(zip(b_vals, t_vals)): | |
| if abs(t - b) > 0.005: | |
| ax.text(i + 0.175, t + 0.02, f"{t-b:+.2f}", ha="center", fontsize=8, | |
| color="#22d3ee" if t >= b else "#ef4444") | |
| ax.set_xticks(x) | |
| label_prefix = SCENARIO_FOLDER[:3].capitalize() + " " | |
| ax.set_xticklabels([s.split("_")[-1] for s in scenario_ids], rotation=15, fontsize=8) | |
| ax.set_ylim(0, 1); ax.legend() | |
| ax.set_title("SecureReview β Before vs After SFT", fontweight="bold") | |
| mb = sum(b_vals) / len(b_vals); mt = sum(t_vals) / len(t_vals) | |
| ax.text(0.98, 0.92, f"Mean: {mb:.2f} β {mt:.2f} ({mt-mb:+.2f})", | |
| transform=ax.transAxes, ha="right", fontsize=11, color="white", | |
| bbox=dict(boxstyle="round", facecolor="#1a1a1a", alpha=0.8)) | |
| fig.tight_layout() | |
| plt.savefig(f"{PLOTS_DIR}/before_after.png", dpi=150, bbox_inches="tight") | |
| plt.close() | |
| print(f" Saved {PLOTS_DIR}/before_after.png") | |
| # Save results JSON | |
| results = { | |
| "baseline_mean": sum(b_vals) / len(b_vals), | |
| "trained_mean": sum(t_vals) / len(t_vals), | |
| "improvement": sum(t_vals) / len(t_vals) - sum(b_vals) / len(b_vals), | |
| "baseline_scores": baseline_scores, | |
| "trained_scores": trained_scores, | |
| } | |
| with open(f"{PLOTS_DIR}/results.json", "w") as f: | |
| json.dump(results, f, indent=2) | |
| print("\n" + "=" * 60) | |
| print(f" DONE β Mean {sum(b_vals)/len(b_vals):.3f} β {sum(t_vals)/len(t_vals):.3f}") | |
| print("=" * 60) | |
| if __name__ == "__main__": | |
| main() | |