cve-cwe-qwen3-8b / evaluate.py
exploitintel's picture
Update default dataset id to exploitintel
dbe2ea6 verified
#!/usr/bin/env python3
"""Evaluate a fine-tuned CVE -> CWE model on the held-out test split.
Reports exact-match accuracy plus micro/macro multi-label F1, stratified into
"easy" (the weakness is named in the description) vs "hard" (it must be inferred),
so you see real-world performance instead of one flattered average.
Loads with plain transformers. Newer architectures (e.g. model_type ``gemma4``,
used by gemma-4-E4B) need **transformers >= 5.5** -- older versions raise
``KeyError: 'gemma4'``. Note: do NOT load gemma4 through unsloth in a Studio env
whose transformers was upgraded -- the upgrade pulls ``huggingface_hub`` 1.x,
which breaks ``unsloth_zoo``'s config lookup. Plain transformers is the clean path.
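Usage:
    python evaluate.py --model "C:\\path\\to\\exported\\merged_model" --limit 500   # quick check on the first 500 rows
    python evaluate.py --model "C:\\path\\to\\exported\\merged_model"               # full test split
Needs: transformers (>=5.5 for newer architectures such as gemma4), torch, datasets, accelerate.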
"""
from __future__ import annotations
import argparse
import re
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
# Matches CWE identifiers of the form "CWE-<digits>" (uppercase prefix only).
CWE_RE = re.compile(r"CWE-\d+")

# Difficulty stratification: a row is "easy" when its description contains one of
# these weakness names verbatim (keyword matching would be enough to answer it);
# every other row is "hard" because the CWE has to be inferred from the prose.
# Entries are lowercase; alternative spellings of one weakness share a line.
EASY_KW = [
    "sql injection",
    "cross-site scripting", "cross site scripting", "xss",
    "buffer overflow",
    "use after free", "use-after-free",
    "path traversal",
    "command injection",
    "out-of-bounds", "out of bounds",
    "race condition",
    "deserialization",
    "ssrf", "server-side request forgery",
    "csrf", "cross-site request forgery",
    "open redirect",
    "integer overflow",
]
def parse_cwes(text: str) -> set[str]:
    """Collect the distinct identifiers in *text* that match ``CWE_RE``."""
    return {match.group(0) for match in CWE_RE.finditer(text)}
def is_easy(description: str) -> bool:
    """Return True when *description* contains any ``EASY_KW`` entry.

    The description is lowercased before the substring test, so the match is
    case-insensitive for the (lowercase) keywords. Because it is a plain
    substring test, short keywords such as "xss" also match inside longer words.
    """
    lowered = description.lower()  # lowercase once instead of once per keyword
    return any(keyword in lowered for keyword in EASY_KW)
def prf(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Return ``(precision, recall, F1)`` for the given counts.

    Any ratio whose denominator is zero is reported as 0.0 rather than raising.
    """
    predicted = tp + fp  # everything the model emitted
    actual = tp + fn  # everything it should have emitted
    precision = tp / predicted if predicted else 0.0
    recall = tp / actual if actual else 0.0
    denom = precision + recall
    if not denom:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / denom
def build_prompt(tok, messages: list[dict]) -> str:
    """Render the generation prompt for one example.

    The last entry of *messages* (the reference answer) is left out; the
    remaining turns go through the tokenizer's chat template as text, with the
    generation prompt appended.
    """
    history = messages[:-1]

    def first_content(role: str) -> str:
        # Content of the first turn with the given role, or "" if there is none.
        for turn in history:
            if turn["role"] == role:
                return turn["content"]
        return ""

    try:
        rendered = tok.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Some chat templates (e.g. Gemma) refuse a standalone "system" turn.
        # Retry with the system text prepended to the user text as a single
        # user message.
        system_text = first_content("system")
        user_text = first_content("user")
        merged_turn = {"role": "user", "content": f"{system_text}\n\n{user_text}".strip()}
        rendered = tok.apply_chat_template(
            [merged_turn], tokenize=False, add_generation_prompt=True
        )
    return rendered
def score(truths: list[set[str]], preds: list[set[str]], easies: list[bool]) -> None:
    """Print exact-match accuracy and micro/macro F1, overall and per difficulty.

    Args:
        truths: reference CWE set for each example.
        preds: predicted CWE set for each example, aligned with ``truths``.
        easies: per-example flag; True counts the row as "easy", False as "hard".

    Returns:
        None. The report is written to stdout. With no examples, every metric is
        reported as 0.000 instead of raising ``ZeroDivisionError``.
    """
    micro = [0, 0, 0]  # tp, fp, fn summed over all examples
    per_label: dict[str, list[int]] = {}  # CWE id -> [tp, fp, fn]
    exact = 0
    strata = {"easy": [0, 0, 0, 0, 0], "hard": [0, 0, 0, 0, 0]}  # tp,fp,fn,exact,n

    for true, pred, easy in zip(truths, preds, easies):
        tp, fp, fn = len(pred & true), len(pred - true), len(true - pred)
        micro[0] += tp
        micro[1] += fp
        micro[2] += fn

        # Exact match means the predicted set equals the reference set.
        ex = int(pred == true)
        exact += ex

        # Per-label counts feed the macro average.
        for lab in true | pred:
            d = per_label.setdefault(lab, [0, 0, 0])
            if lab in true and lab in pred:
                d[0] += 1
            elif lab in pred:
                d[1] += 1
            else:
                d[2] += 1

        s = strata["easy" if easy else "hard"]
        s[0] += tp
        s[1] += fp
        s[2] += fn
        s[3] += ex
        s[4] += 1

    n = len(truths)
    # Guard the division the same way prf() guards its own denominators, so an
    # empty split produces a report rather than a traceback.
    exact_acc = exact / n if n else 0.0
    micro_f1 = prf(*micro)[2]
    macro_f1 = sum(prf(*v)[2] for v in per_label.values()) / len(per_label) if per_label else 0.0

    print("\n=== CVE -> CWE evaluation ===")
    print(f"examples : {n}")
    print(f"exact-match accuracy : {exact_acc:.3f} (predicted CWE set == true set)")
    print(f"micro-F1 : {micro_f1:.3f}")
    print(f"macro-F1 : {macro_f1:.3f} (unweighted mean over {len(per_label)} CWEs)")
    print("\n-- by difficulty --")
    for name, label in (("easy", "easy (weakness named)"), ("hard", "hard (must infer) ")):
        tp, fp, fn, ex, m = strata[name]
        if m:  # a stratum with no rows is left out of the report
            print(f" {label:22s} n={m:5d} exact={ex / m:.3f} micro-F1={prf(tp, fp, fn)[2]:.3f}")
def main() -> None:
    """Run the evaluation end to end from the command line.

    Parses the CLI options, loads the tokenizer and model, renders one prompt
    per dataset row, generates greedily in batches, extracts the CWE ids from
    each completion and hands everything to ``score`` for reporting.
    """
    ap = argparse.ArgumentParser(description="Evaluate a CVE->CWE model on the test split.")
    ap.add_argument("--model", required=True, help="path or HF id of the fine-tuned (merged) model")
    ap.add_argument("--dataset", default="exploitintel/cve-cwe-consensus")
    ap.add_argument("--split", default="test")
    ap.add_argument(
        "--limit", type=int, default=None, help="evaluate only the first N rows (quick check)"
    )
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--max-new-tokens", type=int, default=32)
    args = ap.parse_args()

    print(f"loading model: {args.model}")
    try:
        tok = AutoTokenizer.from_pretrained(args.model)
    except (AttributeError, TypeError):
        # Some Gemma tokenizer configs store `extra_special_tokens` as a list, which
        # trips a transformers bug ('list' object has no attribute 'keys').
        tok = AutoTokenizer.from_pretrained(args.model, extra_special_tokens={})
    tok.padding_side = "left"  # decoder-only batched generation needs left padding
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        model = AutoModelForCausalLM.from_pretrained(args.model, dtype="auto").to(device)
    except TypeError:
        # `dtype` is the transformers 5.x name; older releases use `torch_dtype`.
        model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype="auto").to(device)
    model.eval()

    ds = load_dataset(args.dataset, split=args.split)
    if args.limit:  # note: `--limit 0` is falsy and therefore means "no limit"
        ds = ds.select(range(min(args.limit, len(ds))))

    prompts, truths, easies = [], [], []
    for ex in ds:
        msgs = ex["messages"]
        prompts.append(build_prompt(tok, msgs))
        # The last message holds the reference answer.
        truths.append(parse_cwes(msgs[-1]["content"]))
        # Difficulty is judged on the first user turn.
        usr = next((m["content"] for m in msgs if m["role"] == "user"), "")
        easies.append(is_easy(usr))

    total = len(prompts)
    preds: list[set[str]] = []
    for i in range(0, total, args.batch_size):
        batch = prompts[i : i + args.batch_size]
        # The prompts were already rendered by the chat template, which is
        # responsible for emitting BOS and other special tokens. Letting the
        # tokenizer add them again would duplicate BOS on tokenizers that
        # prepend one, so special-token insertion is switched off here.
        # NOTE(review): truncation presumably drops the tail of over-long
        # prompts (default truncation side), which would cut the generation
        # prompt -- confirm no row exceeds max_length.
        enc = tok(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
            add_special_tokens=False,
        ).to(device)
        with torch.no_grad():
            out = model.generate(
                **enc,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,  # greedy = deterministic
                pad_token_id=tok.pad_token_id,
            )
        new_tokens = out[:, enc["input_ids"].shape[1] :]  # drop the prompt, keep the answer
        for row in new_tokens:
            preds.append(parse_cwes(tok.decode(row, skip_special_tokens=True)))
        # "\r" keeps the progress counter on a single console line.
        print(f" {min(i + args.batch_size, total)}/{total}", end="\r")
    print()

    score(truths, preds, easies)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()