#!/usr/bin/env python3
"""Score a fine-tuned CVE -> CWE model against the held-out test split.

Overall metrics are exact-match accuracy (the predicted CWE set must equal the
reference set) and micro/macro multi-label F1. Exact match and micro-F1 are
additionally broken out for two strata: "easy" rows, whose description names
the weakness outright, and "hard" rows, where the CWE has to be inferred. That
split keeps one flattering average from hiding weak real-world performance.

The model is loaded with plain ``transformers``. Newer architectures (e.g.
model_type ``gemma4``, used by gemma-4-E4B) require **transformers >= 5.5**;
older releases fail with ``KeyError: 'gemma4'``.

Caveat: do NOT load gemma4 through unsloth in a Studio env whose transformers
was upgraded -- the upgrade pulls in ``huggingface_hub`` 1.x, which breaks
``unsloth_zoo``'s config lookup. Plain transformers is the clean path.

Usage:
    python evaluate.py --model "C:\\path\\to\\exported\\merged_model" --limit 500
    python evaluate.py --model "C:\\path\\to\\exported\\merged_model"

Requirements: transformers>=5.5, torch, datasets, accelerate.
"""

from __future__ import annotations

import argparse
import re

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Finds CWE identifiers such as "CWE-79" or "CWE-1321" anywhere in free text.
CWE_RE = re.compile(r"CWE-\d+")

# A row is "easy" if the description literally names the weakness (the model can
# keyword-match); "hard" rows require inferring the CWE from the prose.
EASY_KW = [
    "sql injection",
    "cross-site scripting",
    "cross site scripting",
    "xss",
    "buffer overflow",
    "use after free",
    "use-after-free",
    "path traversal",
    "command injection",
    "out-of-bounds",
    "out of bounds",
    "race condition",
    "deserialization",
    "ssrf",
    "server-side request forgery",
    "csrf",
    "cross-site request forgery",
    "open redirect",
    "integer overflow",
]


def parse_cwes(text: str) -> set[str]:
    """Return the distinct ``CWE-<digits>`` identifiers that appear in *text*."""
    return set(CWE_RE.findall(text))


def is_easy(description: str) -> bool:
    """Return True if *description* contains any ``EASY_KW`` phrase (case-insensitive)."""
    lowered = description.lower()  # lower once, not once per keyword
    return any(k in lowered for k in EASY_KW)


def prf(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Return (precision, recall, F1) computed from raw counts.

    Each value falls back to 0.0 when its denominator would be zero instead of
    raising ``ZeroDivisionError``.
    """
    p = tp / (tp + fp) if (tp + fp) else 0.0
    r = tp / (tp + fn) if (tp + fn) else 0.0
    f = 2 * p * r / (p + r) if (p + r) else 0.0
    return p, r, f


def _fold_system_into_user(convo: list[dict]) -> list[dict]:
    """Merge all system text into the first user turn, keeping every other turn in order."""
    sys_txt = "\n\n".join(m["content"] for m in convo if m["role"] == "system")
    rest = [m for m in convo if m["role"] != "system"]
    for i, m in enumerate(rest):
        if m["role"] == "user":
            rest[i] = {"role": "user", "content": f"{sys_txt}\n\n{m['content']}".strip()}
            return rest
    # No user turn at all: the system text becomes the (only) user turn.
    return [{"role": "user", "content": sys_txt.strip()}, *rest]


def build_prompt(tok, messages: list[dict]) -> str:
    """Render the chat template for everything before the final (answer) message.

    The generation prompt is appended so the model answers next. If the
    template rejects the conversation and it contains a system turn, the system
    text is folded into the first user turn and rendering is retried.

    Raises:
        Whatever ``tok.apply_chat_template`` raised, when there is no system
        turn to fold (folding could not help, so the real error is surfaced).
    """
    convo = messages[:-1]
    try:
        return tok.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Some chat templates (e.g. Gemma) reject a separate "system" role;
        # only that case is recoverable by folding.
        if not any(m["role"] == "system" for m in convo):
            raise
    return tok.apply_chat_template(
        _fold_system_into_user(convo), tokenize=False, add_generation_prompt=True
    )


def score(
    truths: list[set[str]], preds: list[set[str]], easies: list[bool]
) -> dict[str, float]:
    """Print exact-match and micro/macro F1 (overall and by difficulty) and return them.

    Args:
        truths: reference CWE set per example.
        preds: predicted CWE set per example, aligned with *truths*.
        easies: per-example flag, True for the "easy" stratum.

    Returns:
        Metric name -> value. Always contains "examples", "exact_match",
        "micro_f1", "macro_f1" and "labels" (how many distinct CWEs macro-F1
        averages over). Each non-empty stratum adds "<name>_examples",
        "<name>_exact_match" and "<name>_micro_f1" for name in ("easy", "hard").

    Raises:
        ValueError: if the three lists differ in length (``zip`` would
            otherwise silently drop the tail).
    """
    if not len(truths) == len(preds) == len(easies):
        raise ValueError(
            f"length mismatch: {len(truths)} truths, {len(preds)} preds, {len(easies)} easies"
        )
    n = len(truths)
    print("\n=== CVE -> CWE evaluation ===")
    print(f"examples : {n}")
    if n == 0:
        # Empty split or --limit 0: every rate below would divide by zero.
        print("nothing to score")
        return {"examples": 0, "exact_match": 0.0, "micro_f1": 0.0, "macro_f1": 0.0, "labels": 0}

    micro = [0, 0, 0]  # tp, fp, fn summed over every example
    per_label: dict[str, list[int]] = {}  # CWE id -> [tp, fp, fn]
    exact = 0
    strata = {"easy": [0, 0, 0, 0, 0], "hard": [0, 0, 0, 0, 0]}  # tp, fp, fn, exact, n
    for true, pred, easy in zip(truths, preds, easies):
        tp, fp, fn = len(pred & true), len(pred - true), len(true - pred)
        ex = int(pred == true)
        exact += ex
        micro = [acc + v for acc, v in zip(micro, (tp, fp, fn))]
        for lab in true | pred:
            d = per_label.setdefault(lab, [0, 0, 0])
            if lab in true and lab in pred:
                d[0] += 1
            elif lab in pred:
                d[1] += 1
            else:
                d[2] += 1
        bucket = "easy" if easy else "hard"
        strata[bucket] = [acc + v for acc, v in zip(strata[bucket], (tp, fp, fn, ex, 1))]

    micro_f1 = prf(*micro)[2]
    # Unweighted mean over every CWE seen in either truth or prediction, so
    # predicted-but-never-true CWEs (F1 = 0) pull the macro average down.
    macro_f1 = sum(prf(*v)[2] for v in per_label.values()) / len(per_label) if per_label else 0.0

    print(f"exact-match accuracy : {exact / n:.3f} (predicted CWE set == true set)")
    print(f"micro-F1 : {micro_f1:.3f}")
    print(f"macro-F1 : {macro_f1:.3f} (unweighted mean over {len(per_label)} CWEs)")
    metrics: dict[str, float] = {
        "examples": n,
        "exact_match": exact / n,
        "micro_f1": micro_f1,
        "macro_f1": macro_f1,
        "labels": len(per_label),
    }

    print("\n-- by difficulty --")
    for name, label in (("easy", "easy (weakness named)"), ("hard", "hard (must infer) ")):
        tp, fp, fn, ex, m = strata[name]
        if not m:
            continue
        f1 = prf(tp, fp, fn)[2]
        print(f" {label:22s} n={m:5d} exact={ex / m:.3f} micro-F1={f1:.3f}")
        metrics[f"{name}_examples"] = m
        metrics[f"{name}_exact_match"] = ex / m
        metrics[f"{name}_micro_f1"] = f1
    return metrics


def _load_tokenizer(model_id: str):
    """Load the tokenizer and configure it for batched decoder-only generation."""
    try:
        tok = AutoTokenizer.from_pretrained(model_id)
    except (AttributeError, TypeError):
        # Some Gemma tokenizer configs store `extra_special_tokens` as a list, which
        # trips a transformers bug ('list' object has no attribute 'keys').
        tok = AutoTokenizer.from_pretrained(model_id, extra_special_tokens={})
    tok.padding_side = "left"  # decoder-only batched generation needs left padding
    # If a prompt must be truncated, drop its beginning rather than its end: the
    # end holds the generation prompt (assistant header) the model answers after.
    tok.truncation_side = "left"
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok


def _load_model(model_id: str, device: str):
    """Load the causal LM with its checkpoint dtype, move it to *device*, set eval mode."""
    try:
        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
    except TypeError:
        # `dtype` is the transformers 5.x name; older releases use `torch_dtype`.
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
    return model.to(device).eval()


def _warn_on_truncation(tok, prompts: list[str], max_length: int) -> None:
    """Print how many prompts exceed *max_length* tokens and will lose their beginning."""
    if not prompts:
        return
    lengths = [len(ids) for ids in tok(prompts, add_special_tokens=False)["input_ids"]]
    n_long = sum(length > max_length for length in lengths)
    if n_long:
        print(
            f"warning: {n_long}/{len(prompts)} prompts exceed {max_length} tokens "
            "and will be truncated from the left"
        )


def _generate_predictions(
    model,
    tok,
    prompts: list[str],
    *,
    device: str,
    batch_size: int,
    max_new_tokens: int,
    max_length: int,
) -> list[set[str]]:
    """Greedy-decode every prompt in batches and parse the CWE set from each answer."""
    preds: list[set[str]] = []
    total = len(prompts)
    for start in range(0, total, batch_size):
        batch = prompts[start : start + batch_size]
        # The chat template already emitted BOS and other special tokens; letting
        # the tokenizer add them again would duplicate them (e.g. a doubled <bos>).
        enc = tok(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            add_special_tokens=False,
        ).to(device)
        with torch.no_grad():
            out = model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy = deterministic
                pad_token_id=tok.pad_token_id,
            )
        new_tokens = out[:, enc["input_ids"].shape[1] :]  # drop the prompt, keep the answer
        preds.extend(parse_cwes(text) for text in tok.batch_decode(new_tokens, skip_special_tokens=True))
        print(f" {min(start + batch_size, total)}/{total}", end="\r")
    print()
    return preds


def main() -> None:
    """Parse CLI flags, greedily decode the chosen split, and print the scores."""
    ap = argparse.ArgumentParser(description="Evaluate a CVE->CWE model on the test split.")
    ap.add_argument("--model", required=True, help="path or HF id of the fine-tuned (merged) model")
    ap.add_argument("--dataset", default="exploitintel/cve-cwe-consensus")
    ap.add_argument("--split", default="test")
    ap.add_argument(
        "--limit", type=int, default=None, help="evaluate only the first N rows (quick check)"
    )
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--max-new-tokens", type=int, default=32)
    ap.add_argument(
        "--max-length",
        type=int,
        default=1024,
        help="prompt token budget; longer prompts are truncated from the left",
    )
    args = ap.parse_args()
    if args.limit is not None and args.limit < 0:
        ap.error("--limit must be >= 0")
    for flag, value in (
        ("--batch-size", args.batch_size),
        ("--max-new-tokens", args.max_new_tokens),
        ("--max-length", args.max_length),
    ):
        if value < 1:
            ap.error(f"{flag} must be >= 1")

    print(f"loading model: {args.model}")
    tok = _load_tokenizer(args.model)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _load_model(args.model, device)

    ds = load_dataset(args.dataset, split=args.split)
    if args.limit is not None:  # `if args.limit:` would treat --limit 0 as "no limit"
        ds = ds.select(range(min(args.limit, len(ds))))

    prompts, truths, easies = [], [], []
    for ex in ds:
        msgs = ex["messages"]
        prompts.append(build_prompt(tok, msgs))
        truths.append(parse_cwes(msgs[-1]["content"]))
        usr = next((m["content"] for m in msgs if m["role"] == "user"), "")
        easies.append(is_easy(usr))

    _warn_on_truncation(tok, prompts, args.max_length)
    preds = _generate_predictions(
        model,
        tok,
        prompts,
        device=device,
        batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
        max_length=args.max_length,
    )
    score(truths, preds, easies)


if __name__ == "__main__":
    main()