| """SFT (Supervised Fine-Tuning) on SecureReview ground-truth findings. |
| |
| Industry-standard pipeline: train the model to output the env's ground-truth |
| JSON exactly. Much faster + bigger improvements than GRPO alone, because |
| we're directly teaching the model the correct answer instead of waiting for |
| RL exploration to find it. |
| |
| Pipeline matches GRPO train.py format: |
| - Same env (live SecureReview Space) |
| - Same baseline + post-training evaluation |
| - Saves plots/reward_curve.png + plots/before_after.png + plots/results.json |
| """ |
| import os |
| import sys |
| import json |
| import re |
| import time |
| import requests |
| import functools |
|
|
| print = functools.partial(print, flush=True) |
|
|
| |
| ENV_URL = os.getenv("ENV_URL", "https://sam25kat-securereview-env-migration.hf.space") |
| ENV_REPO_ID = os.getenv("ENV_REPO_ID", "sam25kat/securereview-env-migration") |
| TASK_ID = "migration_review" |
| SCENARIO_FOLDER = "migration" |
| MODEL_NAME = "unsloth/Qwen2.5-7B-Instruct-bnb-4bit" |
| MAX_SEQ_LEN = 1536 |
| MAX_NEW_TOKENS = 600 |
| NUM_EPOCHS = 3 |
| LEARNING_RATE = 5e-5 |
| LORA_RANK = 16 |
| GRAD_ACCUM_STEPS = 2 |
| OUTPUT_DIR = "./securereview-sft" |
| PLOTS_DIR = "./plots" |
|
|
| os.makedirs(PLOTS_DIR, exist_ok=True) |
| os.makedirs(OUTPUT_DIR, exist_ok=True) |
|
|
| SYSTEM_PROMPT = """You are a senior database engineer reviewing SQL migration scripts for backward-incompatibility and production safety risks. |
| |
| Identify ALL issues including: |
| - Adding NOT NULL columns without default on populated tables |
| - Dropping columns/tables that may still be in use |
| - Renaming columns without compatibility shim |
| - Long-running ALTER TABLE / locking operations on large tables |
| - Missing indexes for frequent queries |
| - Foreign key changes that break referential integrity |
| - Type narrowing that loses data (e.g. INT to SMALLINT) |
| |
| Output ONLY a valid JSON array of findings. Each finding must have: |
| file, line (integer or null), rule_id (e.g. MIG-001), severity (critical/high/medium/low/info), description |
| |
| Output ONLY the JSON array. No explanations, no markdown prose.""" |
|
|
|
|
| |
|
|
| def env_reset(task_id, scenario_id=None): |
| payload = {"task_id": task_id} |
| if scenario_id: |
| payload["scenario_id"] = scenario_id |
| r = requests.post(f"{ENV_URL}/reset", json=payload, timeout=30) |
| r.raise_for_status() |
| return r.json() |
|
|
|
|
| def env_step(action): |
| r = requests.post(f"{ENV_URL}/step", json={"action": action}, timeout=30) |
| r.raise_for_status() |
| return r.json() |
|
|
|
|
| def parse_findings(text): |
| patterns = [ |
| r'```(?:json)?\s*(\[.*?\])\s*```', |
| r'(\[\s*\{.*?\}\s*\])', |
| ] |
| for pattern in patterns: |
| m = re.search(pattern, text, re.DOTALL) |
| if m: |
| try: |
| return json.loads(m.group(1)) |
| except json.JSONDecodeError: |
| continue |
| return [] |
|
|
|
|
| def run_episode(completion, scenario_id): |
| findings = parse_findings(completion) |
| try: |
| env_reset(TASK_ID, scenario_id) |
| valid_sev = {"critical", "high", "medium", "low", "info"} |
| for f in findings: |
| if not isinstance(f, dict): |
| continue |
| finding = { |
| "file": str(f.get("file", "requirements.txt")), |
| "line": f.get("line"), |
| "rule_id": str(f.get("rule_id", "DEP-001")), |
| "severity": f.get("severity", "medium") if f.get("severity") in valid_sev else "medium", |
| "description": str(f.get("description", ""))[:400], |
| } |
| env_step({"action_type": "report_finding", "finding": finding}) |
| result = env_step({"action_type": "mark_complete"}) |
| return float(result.get("reward", 0.01)) |
| except Exception as e: |
| print(f" [env error] {e}") |
| return 0.01 |
|
|
|
|
| def build_prompt(obs): |
| ctx = obs["observation"]["context"] |
| files = ctx["files"] |
| parts = [f"Task: {ctx['task_description']}\n"] |
| for f in files: |
| parts.append(f"\n--- {f['filename']} ---\n{f['content']}") |
| parts.append("\nList all security issues as a JSON array:") |
| return "".join(parts) |
|
|
|
|
| |
|
|
| def fetch_ground_truth(scenario_id): |
| """Download ground_truth.json for a scenario from the env Space repo.""" |
| from huggingface_hub import hf_hub_download |
| |
| num = scenario_id.split("_")[-1] |
| path = hf_hub_download( |
| repo_id=ENV_REPO_ID, |
| repo_type="space", |
| filename=f"app/tasks/scenarios/{SCENARIO_FOLDER}/scenario_{num}/ground_truth.json", |
| ) |
| with open(path) as f: |
| return json.load(f) |
|
|
|
|
| def gt_to_target_json(gt_data): |
| """Convert ground_truth.json's 'ground_truth' list into the JSON array |
| the model should output. Strips internal fields (match_key, category).""" |
| findings = [] |
| for f in gt_data["ground_truth"]: |
| findings.append({ |
| "file": f["file"], |
| "line": f.get("line"), |
| "rule_id": f["rule_id"], |
| "severity": f["severity"], |
| "description": f["description"], |
| }) |
| return json.dumps(findings, indent=2) |
|
|
|
|
| |
|
|
| def main(): |
| print("=" * 60) |
| print(" SecureReview SFT Training") |
| print(f" Model : {MODEL_NAME}") |
| print(f" Task : {TASK_ID}") |
| print(f" Epochs: {NUM_EPOCHS}") |
| print("=" * 60) |
|
|
| |
| print("\n[1/6] Checking environment connection...") |
| r = requests.get(f"{ENV_URL}/health", timeout=15) |
| print(f" Health: {r.json()}") |
|
|
| |
| print("\n[2/6] Loading model...") |
| from unsloth import FastLanguageModel |
| import torch |
|
|
| model, tokenizer = FastLanguageModel.from_pretrained( |
| model_name = MODEL_NAME, |
| max_seq_length = MAX_SEQ_LEN, |
| dtype = torch.float16, |
| load_in_4bit = True, |
| ) |
| model = FastLanguageModel.get_peft_model( |
| model, |
| r = LORA_RANK, |
| target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", |
| "gate_proj", "up_proj", "down_proj"], |
| lora_alpha = LORA_RANK, |
| lora_dropout = 0, |
| bias = "none", |
| use_gradient_checkpointing = "unsloth", |
| random_state = 42, |
| ) |
| model.print_trainable_parameters() |
|
|
| |
| print("\n[3/6] Building SFT dataset from ground-truth findings...") |
| from datasets import Dataset |
|
|
| scenario_ids = [ |
| "migration_002", "migration_006", "migration_007", |
| "migration_009", "migration_012", "migration_017", "migration_018", |
| "migration_022", "migration_023", |
| "migration_024", "migration_025", "migration_028", |
| ] |
| examples = [] |
| for sid in scenario_ids: |
| try: |
| obs = env_reset(TASK_ID, sid) |
| user = build_prompt(obs) |
| gt = fetch_ground_truth(sid) |
| target = gt_to_target_json(gt) |
|
|
| messages = [ |
| {"role": "system", "content": SYSTEM_PROMPT}, |
| {"role": "user", "content": user}, |
| {"role": "assistant", "content": target}, |
| ] |
| full_text = tokenizer.apply_chat_template( |
| messages, tokenize=False, add_generation_prompt=False |
| ) |
| examples.append({"text": full_text}) |
| print(f" Loaded {sid} ({len(gt['ground_truth'])} findings)") |
| except Exception as e: |
| print(f" Skipping {sid}: {e}") |
|
|
| dataset = Dataset.from_list(examples) |
| print(f" Dataset: {len(examples)} examples") |
|
|
| |
| print("\n[4/6] Baseline evaluation (before SFT)...") |
| FastLanguageModel.for_inference(model) |
|
|
| def evaluate(sids, label): |
| scores = {} |
| for sid in sids: |
| obs = env_reset(TASK_ID, sid) |
| prompt_text = build_prompt(obs) |
| messages = [ |
| {"role": "system", "content": SYSTEM_PROMPT}, |
| {"role": "user", "content": prompt_text}, |
| ] |
| inputs = tokenizer.apply_chat_template( |
| messages, tokenize=True, add_generation_prompt=True, return_tensors="pt" |
| ).to("cuda") |
| with torch.no_grad(): |
| outputs = model.generate( |
| inputs, max_new_tokens=MAX_NEW_TOKENS, |
| temperature=0.1, do_sample=True, |
| pad_token_id=tokenizer.eos_token_id, |
| ) |
| completion = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True) |
| score = run_episode(completion, sid) |
| scores[sid] = score |
| print(f" [{label}] {sid}: {score:.3f}") |
| time.sleep(0.3) |
| return scores |
|
|
| baseline_scores = evaluate(scenario_ids, "before") |
| print(f" Baseline mean: {sum(baseline_scores.values())/len(baseline_scores):.3f}") |
|
|
| |
| print("\n[5/6] SFT training...") |
| FastLanguageModel.for_training(model) |
|
|
| from trl import SFTTrainer, SFTConfig |
|
|
| sft_args = SFTConfig( |
| output_dir = OUTPUT_DIR, |
| max_seq_length = MAX_SEQ_LEN, |
| num_train_epochs = NUM_EPOCHS, |
| per_device_train_batch_size = 1, |
| gradient_accumulation_steps = GRAD_ACCUM_STEPS, |
| learning_rate = LEARNING_RATE, |
| lr_scheduler_type = "cosine", |
| warmup_ratio = 0.05, |
| logging_steps = 2, |
| save_steps = 50, |
| fp16 = True, |
| bf16 = False, |
| optim = "adamw_8bit", |
| weight_decay = 0.01, |
| report_to = "none", |
| seed = 42, |
| dataset_text_field = "text", |
| packing = False, |
| ) |
|
|
| def formatting_func(example): |
| |
| return example["text"] |
|
|
| trainer = SFTTrainer( |
| model = model, |
| processing_class = tokenizer, |
| args = sft_args, |
| train_dataset = dataset, |
| formatting_func = formatting_func, |
| ) |
| trainer.train() |
|
|
| |
| loss_log = [ |
| {"step": h.get("step", i), "loss": h["loss"]} |
| for i, h in enumerate(trainer.state.log_history) |
| if "loss" in h |
| ] |
|
|
| |
| print("\n[6/6] Post-SFT evaluation...") |
| FastLanguageModel.for_inference(model) |
| trained_scores = evaluate(scenario_ids, "after") |
| print(f" Trained mean: {sum(trained_scores.values())/len(trained_scores):.3f}") |
|
|
| print("\n=== Improvement Summary ===") |
| for sid in scenario_ids: |
| b = baseline_scores.get(sid, 0) |
| t = trained_scores.get(sid, 0) |
| arrow = "▲" if t > b else ("▼" if t < b else "—") |
| print(f" {sid}: {b:.3f} → {t:.3f} {arrow} {t-b:+.3f}") |
|
|
| |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
| plt.style.use("dark_background") |
|
|
| |
| if loss_log: |
| steps = [e["step"] for e in loss_log] |
| losses = [e["loss"] for e in loss_log] |
| fig, ax = plt.subplots(figsize=(11, 4)) |
| ax.plot(steps, losses, color="#ff6b35", linewidth=2) |
| ax.set_xlabel("Training Step"); ax.set_ylabel("Loss") |
| ax.set_title("SecureReview SFT — Training Loss", fontweight="bold") |
| ax.grid(True, alpha=0.2) |
| fig.tight_layout() |
| plt.savefig(f"{PLOTS_DIR}/reward_curve.png", dpi=150, bbox_inches="tight") |
| plt.close() |
| print(f" Saved {PLOTS_DIR}/reward_curve.png") |
|
|
| |
| b_vals = [baseline_scores.get(s, 0) for s in scenario_ids] |
| t_vals = [trained_scores.get(s, 0) for s in scenario_ids] |
| x = np.arange(len(scenario_ids)) |
| fig, ax = plt.subplots(figsize=(12, 5)) |
| ax.bar(x - 0.175, b_vals, 0.35, label="Before", color="#444444") |
| ax.bar(x + 0.175, t_vals, 0.35, label="After", color="#ff6b35") |
| for i, (b, t) in enumerate(zip(b_vals, t_vals)): |
| if abs(t - b) > 0.005: |
| ax.text(i + 0.175, t + 0.02, f"{t-b:+.2f}", ha="center", fontsize=8, |
| color="#22d3ee" if t >= b else "#ef4444") |
| ax.set_xticks(x) |
| label_prefix = SCENARIO_FOLDER[:3].capitalize() + " " |
| ax.set_xticklabels([s.split("_")[-1] for s in scenario_ids], rotation=15, fontsize=8) |
| ax.set_ylim(0, 1); ax.legend() |
| ax.set_title("SecureReview — Before vs After SFT", fontweight="bold") |
| mb = sum(b_vals) / len(b_vals); mt = sum(t_vals) / len(t_vals) |
| ax.text(0.98, 0.92, f"Mean: {mb:.2f} → {mt:.2f} ({mt-mb:+.2f})", |
| transform=ax.transAxes, ha="right", fontsize=11, color="white", |
| bbox=dict(boxstyle="round", facecolor="#1a1a1a", alpha=0.8)) |
| fig.tight_layout() |
| plt.savefig(f"{PLOTS_DIR}/before_after.png", dpi=150, bbox_inches="tight") |
| plt.close() |
| print(f" Saved {PLOTS_DIR}/before_after.png") |
|
|
| |
| results = { |
| "baseline_mean": sum(b_vals) / len(b_vals), |
| "trained_mean": sum(t_vals) / len(t_vals), |
| "improvement": sum(t_vals) / len(t_vals) - sum(b_vals) / len(b_vals), |
| "baseline_scores": baseline_scores, |
| "trained_scores": trained_scores, |
| } |
| with open(f"{PLOTS_DIR}/results.json", "w") as f: |
| json.dump(results, f, indent=2) |
|
|
| print("\n" + "=" * 60) |
| print(f" DONE — Mean {sum(b_vals)/len(b_vals):.3f} → {sum(t_vals)/len(t_vals):.3f}") |
| print("=" * 60) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|