# sam25kat's picture
# Upload train.py with huggingface_hub
# ea39a50 verified
"""SFT (Supervised Fine-Tuning) on SecureReview ground-truth findings.
Standard supervised pipeline: train the model to reproduce the env's ground-truth
JSON exactly. This is much faster than GRPO alone and typically yields bigger
gains, because we teach the model the correct answer directly instead of
waiting for RL exploration to discover it.
Pipeline matches GRPO train.py format:
- Same env (live SecureReview Space)
- Same baseline + post-training evaluation
- Saves plots/reward_curve.png + plots/before_after.png + plots/results.json
"""
import os
import sys
import json
import re
import time
import requests
import functools
# Make every print flush immediately instead of sitting in the stdout buffer.
print = functools.partial(print, flush=True)

# ── Config ────────────────────────────────────────────────────────────────────
# Live environment endpoint and the Space repo that holds the scenario files.
ENV_URL = os.environ.get("ENV_URL", "https://sam25kat-securereview.hf.space")
ENV_REPO_ID = os.environ.get("ENV_REPO_ID", "sam25kat/securereview")
TASK_ID = "dependency_review"
SCENARIO_FOLDER = "dependency"  # subdir under app/tasks/scenarios/

# Model and sequence/generation limits.
MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct"
MAX_SEQ_LEN = 1536
MAX_NEW_TOKENS = 600

# Optimisation and LoRA hyper-parameters.
NUM_EPOCHS = 3
LEARNING_RATE = 5e-5
LORA_RANK = 16
GRAD_ACCUM_STEPS = 2

# Output locations; created eagerly at import time.
OUTPUT_DIR = "./securereview-sft"
PLOTS_DIR = "./plots"
for _out_dir in (PLOTS_DIR, OUTPUT_DIR):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

SYSTEM_PROMPT = """You are a senior security engineer reviewing dependency files for vulnerabilities.
Identify ALL security issues including:
- Typosquatted packages (names that misspell popular libraries, e.g. 'reqeusts' instead of 'requests')
- Known CVE-vulnerable versions (e.g. requests<2.20.0 has CVE-2018-18074)
- Hallucinated / non-existent packages that don't exist on PyPI or npm
- Suspicious or malicious packages
Output ONLY a valid JSON array of findings. Each finding must have:
file, line (integer or null), rule_id (e.g. DEP-001), severity (critical/high/medium/low/info), description
Output ONLY the JSON array. No explanations, no markdown prose."""
# ── Environment helpers (same as GRPO version) ────────────────────────────────
def env_reset(task_id, scenario_id=None):
    """Start a new episode via POST {ENV_URL}/reset and return the decoded JSON.

    ``scenario_id`` is included in the request body only when it is truthy;
    otherwise it is omitted. Raises ``requests.HTTPError`` on a non-2xx reply.
    """
    body = {"task_id": task_id}
    if scenario_id:
        body["scenario_id"] = scenario_id
    response = requests.post(f"{ENV_URL}/reset", json=body, timeout=30)
    response.raise_for_status()
    return response.json()
def env_step(action):
    """Send one action to POST {ENV_URL}/step and return the decoded JSON.

    The action dict is wrapped as ``{"action": action}``. Raises
    ``requests.HTTPError`` on a non-2xx reply.
    """
    response = requests.post(f"{ENV_URL}/step", json={"action": action}, timeout=30)
    response.raise_for_status()
    return response.json()
# A fenced ```json [...]``` (or bare ```) block whose body is a JSON array.
_FENCED_ARRAY_RE = re.compile(r'```(?:json)?\s*(\[.*?\])\s*```', re.DOTALL)
# Candidate start of a bare JSON array of objects: "[" followed by "{".
_BARE_ARRAY_START_RE = re.compile(r'\[\s*\{')
_JSON_DECODER = json.JSONDecoder()


def parse_findings(text):
    """Extract the first JSON array of findings from a model completion.

    Candidates are tried in this order:
      1. every fenced code block whose body is a JSON array;
      2. every bare ``[{`` in the text, decoded with ``JSONDecoder.raw_decode``.
         Using the real decoder (rather than a non-greedy regex) correctly
         handles nested arrays/objects and ``}]`` appearing inside string values.

    Args:
        text: raw completion text; ``None`` or empty yields ``[]``.

    Returns:
        The decoded list, or ``[]`` if no candidate parses as JSON.
    """
    if not text:
        return []
    for match in _FENCED_ARRAY_RE.finditer(text):
        try:
            return json.loads(match.group(1))
        except json.JSONDecodeError:
            continue
    for match in _BARE_ARRAY_START_RE.finditer(text):
        try:
            # raw_decode stops at the end of the first complete JSON value and
            # ignores any trailing prose after it.
            value, _end = _JSON_DECODER.raw_decode(text, match.start())
        except json.JSONDecodeError:
            continue
        if isinstance(value, list):
            return value
    return []
# Severities accepted for report_finding; anything else is coerced to "medium".
_VALID_SEVERITIES = frozenset({"critical", "high", "medium", "low", "info"})


def _coerce_line(value):
    """Return *value* as an int line number, or None if it is not one.

    Models sometimes emit line numbers as strings ("12") or whole floats (12.0).
    Booleans are rejected even though ``bool`` subclasses ``int``.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    if isinstance(value, str):
        try:
            return int(value.strip())
        except ValueError:
            return None
    return None


def _normalize_finding(raw):
    """Map one model-emitted finding dict onto the report_finding payload.

    Missing or null fields fall back to defaults instead of becoming the
    literal string "None"; severity is matched case-insensitively.
    """
    severity = str(raw.get("severity") or "").strip().lower()
    return {
        "file": str(raw.get("file") or "requirements.txt"),
        "line": _coerce_line(raw.get("line")),
        "rule_id": str(raw.get("rule_id") or "DEP-001"),
        "severity": severity if severity in _VALID_SEVERITIES else "medium",
        "description": str(raw.get("description") or "")[:400],
    }


def run_episode(completion, scenario_id):
    """Score one model completion against the live environment.

    Resets the env to ``scenario_id``, reports every dict finding parsed from
    ``completion``, then sends ``mark_complete``.

    Returns:
        The env's ``reward`` as a float (0.01 if the key is absent), or 0.01
        on any environment/transport error.
    """
    findings = parse_findings(completion)
    try:
        env_reset(TASK_ID, scenario_id)
        for raw in findings:
            if not isinstance(raw, dict):
                continue
            try:
                env_step({"action_type": "report_finding", "finding": _normalize_finding(raw)})
            except requests.HTTPError as e:
                status = e.response.status_code if e.response is not None else None
                if status is None or not 400 <= status < 500:
                    raise
                # A 4xx rejects this one finding; keep scoring the rest of the
                # episode rather than collapsing the whole score to 0.01.
                print(f" [env rejected finding] {e}")
        result = env_step({"action_type": "mark_complete"})
        return float(result.get("reward", 0.01))
    except Exception as e:
        # Deliberate best-effort: any env failure scores the floor value.
        print(f" [env error] {e}")
        return 0.01
def build_prompt(obs):
    """Render an environment observation into the user-turn prompt text.

    Reads ``obs["observation"]["context"]`` and emits the task description,
    then each file as a ``--- filename ---`` header followed by its content,
    then a closing instruction asking for a JSON array.
    """
    context = obs["observation"]["context"]
    file_sections = "".join(
        f"\n--- {entry['filename']} ---\n{entry['content']}"
        for entry in context["files"]
    )
    header = f"Task: {context['task_description']}\n"
    return header + file_sections + "\nList all security issues as a JSON array:"
# ── Ground-truth fetching (the SFT-specific bit) ──────────────────────────────
def fetch_ground_truth(scenario_id):
    """Download and parse ground_truth.json for one scenario from the env Space repo.

    ``scenario_id`` is expected to end in ``_<NNN>`` (e.g. "dep_001" or
    "migration_002"). That suffix selects
    ``app/tasks/scenarios/<SCENARIO_FOLDER>/scenario_<NNN>/ground_truth.json``.

    Returns:
        The parsed JSON document; callers read its "ground_truth" list.

    Errors from ``hf_hub_download`` (missing file, network) propagate.
    """
    from huggingface_hub import hf_hub_download

    num = scenario_id.rsplit("_", 1)[-1]
    path = hf_hub_download(
        repo_id=ENV_REPO_ID,
        repo_type="space",
        filename=f"app/tasks/scenarios/{SCENARIO_FOLDER}/scenario_{num}/ground_truth.json",
    )
    # JSON is UTF-8; don't depend on the platform's default locale encoding
    # (e.g. cp1252 on Windows would garble non-ASCII descriptions).
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def gt_to_target_json(gt_data):
    """Serialise the scenario's ground truth as the assistant's target answer.

    Keeps only file, line, rule_id, severity and description from each entry
    of ``gt_data["ground_truth"]``, in that key order; any other fields
    (e.g. match_key, category) are dropped. ``line`` may be absent and
    becomes null. Returns a JSON string indented by 2 spaces.
    """
    kept = [
        {
            "file": entry["file"],
            "line": entry.get("line"),
            "rule_id": entry["rule_id"],
            "severity": entry["severity"],
            "description": entry["description"],
        }
        for entry in gt_data["ground_truth"]
    ]
    return json.dumps(kept, indent=2)
# ── Main ──────────────────────────────────────────────────────────────────────
def _mean(values):
    """Arithmetic mean of *values*, or 0.0 for an empty sequence."""
    values = list(values)
    return sum(values) / len(values) if values else 0.0


def _build_sft_dataset(tokenizer, scenario_ids):
    """Build the SFT dataset: one chat-formatted example per scenario.

    Each example is the system prompt, the env-rendered user prompt and the
    ground-truth JSON as the assistant turn, rendered to a single "text" field
    via the tokenizer's chat template. Scenarios that fail (env reset, ground
    truth download, templating) are logged and skipped.

    Raises:
        RuntimeError: if no scenario could be loaded; an empty dataset would
            otherwise fail later with an opaque error inside the trainer.
    """
    from datasets import Dataset

    examples = []
    for sid in scenario_ids:
        try:
            obs = env_reset(TASK_ID, sid)
            user = build_prompt(obs)
            gt = fetch_ground_truth(sid)
            target = gt_to_target_json(gt)
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user},
                {"role": "assistant", "content": target},
            ]
            full_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            examples.append({"text": full_text})
            print(f" Loaded {sid} ({len(gt['ground_truth'])} findings)")
        except Exception as e:
            print(f" Skipping {sid}: {e}")
    if not examples:
        raise RuntimeError(
            "No SFT examples could be built; check ENV_URL / ENV_REPO_ID access."
        )
    dataset = Dataset.from_list(examples)
    print(f" Dataset: {len(examples)} examples")
    return dataset


def _evaluate(model, tokenizer, sids, label):
    """Generate one completion per scenario and score it with run_episode.

    Returns a dict mapping scenario id -> reward. If the env reset for a
    scenario fails, that scenario is logged and scored 0.01 (the same floor
    run_episode uses for env errors) instead of aborting the whole run.
    """
    import torch

    scores = {}
    for sid in sids:
        try:
            obs = env_reset(TASK_ID, sid)
        except Exception as e:
            print(f" [{label}] {sid}: env reset failed ({e}); scoring 0.01")
            scores[sid] = 0.01
            continue
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_prompt(obs)},
        ]
        inputs = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=0.1,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        completion = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        score = run_episode(completion, sid)
        scores[sid] = score
        print(f" [{label}] {sid}: {score:.3f}")
        time.sleep(0.3)  # presumably to avoid hammering the env Space between episodes
    return scores


def _save_plots(loss_log, scenario_ids, b_vals, t_vals):
    """Write the training-loss curve and the before/after bar chart to PLOTS_DIR.

    The loss curve is saved as ``reward_curve.png`` so the output file names
    match the GRPO pipeline's (see module docstring).
    """
    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend: works without a display
    import matplotlib.pyplot as plt
    import numpy as np

    plt.style.use("dark_background")

    if loss_log:
        steps = [e["step"] for e in loss_log]
        losses = [e["loss"] for e in loss_log]
        fig, ax = plt.subplots(figsize=(11, 4))
        ax.plot(steps, losses, color="#ff6b35", linewidth=2)
        ax.set_xlabel("Training Step")
        ax.set_ylabel("Loss")
        ax.set_title("SecureReview SFT — Training Loss", fontweight="bold")
        ax.grid(True, alpha=0.2)
        fig.tight_layout()
        plt.savefig(f"{PLOTS_DIR}/reward_curve.png", dpi=150, bbox_inches="tight")
        plt.close()
        print(f" Saved {PLOTS_DIR}/reward_curve.png")

    x = np.arange(len(scenario_ids))
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.bar(x - 0.175, b_vals, 0.35, label="Before", color="#444444")
    ax.bar(x + 0.175, t_vals, 0.35, label="After", color="#ff6b35")
    for i, (b, t) in enumerate(zip(b_vals, t_vals)):
        if abs(t - b) > 0.005:
            ax.text(i + 0.175, t + 0.02, f"{t-b:+.2f}", ha="center", fontsize=8,
                    color="#22d3ee" if t >= b else "#ef4444")
    ax.set_xticks(x)
    ax.set_xticklabels([s.split("_")[-1] for s in scenario_ids], rotation=15, fontsize=8)
    ax.set_ylim(0, 1)
    ax.legend()
    ax.set_title("SecureReview — Before vs After SFT", fontweight="bold")
    mb, mt = _mean(b_vals), _mean(t_vals)
    ax.text(0.98, 0.92, f"Mean: {mb:.2f} → {mt:.2f} ({mt-mb:+.2f})",
            transform=ax.transAxes, ha="right", fontsize=11, color="white",
            bbox=dict(boxstyle="round", facecolor="#1a1a1a", alpha=0.8))
    fig.tight_layout()
    plt.savefig(f"{PLOTS_DIR}/before_after.png", dpi=150, bbox_inches="tight")
    plt.close()
    print(f" Saved {PLOTS_DIR}/before_after.png")


def main():
    """Run the SFT pipeline end to end.

    Steps: env health check, load the 4-bit model with LoRA, build the SFT
    dataset from ground truth, score a baseline, train, save the adapter,
    score again, then write plots and ``results.json`` into PLOTS_DIR.
    """
    print("=" * 60)
    print(" SecureReview SFT Training")
    print(f" Model : {MODEL_NAME}")
    print(f" Task : {TASK_ID}")
    print(f" Epochs: {NUM_EPOCHS}")
    print("=" * 60)

    # Verify env
    print("\n[1/6] Checking environment connection...")
    r = requests.get(f"{ENV_URL}/health", timeout=15)
    r.raise_for_status()  # fail fast with an HTTP error, not a JSON decode error
    print(f" Health: {r.json()}")

    # Load model
    print("\n[2/6] Loading model...")
    from unsloth import FastLanguageModel
    import torch
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LEN,
        dtype=torch.float16,
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_RANK,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=LORA_RANK,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
    model.print_trainable_parameters()

    # Build SFT dataset from ground truth
    print("\n[3/6] Building SFT dataset from ground-truth findings...")
    scenario_ids = [f"dep_{i:03d}" for i in range(1, 25)]  # dep_001 .. dep_024
    dataset = _build_sft_dataset(tokenizer, scenario_ids)

    # NOTE(review): baseline and post-SFT evaluation use the same scenarios the
    # model is trained on, so the "after" score measures fit to the training
    # set rather than generalisation. Hold out scenarios for a fair comparison.

    # Baseline eval
    print("\n[4/6] Baseline evaluation (before SFT)...")
    FastLanguageModel.for_inference(model)
    baseline_scores = _evaluate(model, tokenizer, scenario_ids, "before")
    print(f" Baseline mean: {_mean(baseline_scores.values()):.3f}")

    # SFT training
    print("\n[5/6] SFT training...")
    FastLanguageModel.for_training(model)
    from trl import SFTTrainer, SFTConfig
    sft_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        max_seq_length=MAX_SEQ_LEN,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        logging_steps=2,
        save_steps=50,
        fp16=True,
        bf16=False,
        optim="adamw_8bit",
        weight_decay=0.01,
        report_to="none",
        seed=42,
        dataset_text_field="text",
        packing=False,
    )

    def formatting_func(example):
        # SFTTrainer with Unsloth requires this even when dataset_text_field is set.
        return example["text"]

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        args=sft_args,
        train_dataset=dataset,
        formatting_func=formatting_func,
    )
    trainer.train()

    # Persist the LoRA adapter explicitly. With at most 24 examples, batch 1,
    # grad-accum 2 and 3 epochs the run is <= 36 optimizer steps, so
    # save_steps=50 never writes a checkpoint on its own. Saving before the
    # post-training eval also keeps the weights if that eval crashes.
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print(f" Saved LoRA adapter + tokenizer to {OUTPUT_DIR}")

    # Capture loss history for plot
    loss_log = [
        {"step": h.get("step", i), "loss": h["loss"]}
        for i, h in enumerate(trainer.state.log_history)
        if "loss" in h
    ]

    # Post-training eval
    print("\n[6/6] Post-SFT evaluation...")
    FastLanguageModel.for_inference(model)
    trained_scores = _evaluate(model, tokenizer, scenario_ids, "after")
    print(f" Trained mean: {_mean(trained_scores.values()):.3f}")

    print("\n=== Improvement Summary ===")
    for sid in scenario_ids:
        b = baseline_scores.get(sid, 0)
        t = trained_scores.get(sid, 0)
        arrow = "▲" if t > b else ("▼" if t < b else "—")
        print(f" {sid}: {b:.3f} → {t:.3f} {arrow} {t-b:+.3f}")

    # Plots
    b_vals = [baseline_scores.get(s, 0) for s in scenario_ids]
    t_vals = [trained_scores.get(s, 0) for s in scenario_ids]
    _save_plots(loss_log, scenario_ids, b_vals, t_vals)

    # Save results JSON
    baseline_mean = _mean(b_vals)
    trained_mean = _mean(t_vals)
    results = {
        "baseline_mean": baseline_mean,
        "trained_mean": trained_mean,
        "improvement": trained_mean - baseline_mean,
        "baseline_scores": baseline_scores,
        "trained_scores": trained_scores,
    }
    with open(f"{PLOTS_DIR}/results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print("\n" + "=" * 60)
    print(f" DONE — Mean {baseline_mean:.3f} → {trained_mean:.3f}")
    print("=" * 60)


if __name__ == "__main__":
    main()