#!/usr/bin/env python3
"""Evaluate a fine-tuned CVE -> CWE model on the held-out test split.

Reports exact-match accuracy plus micro/macro multi-label F1, stratified into
"easy" (the weakness is named in the description) vs "hard" (it must be inferred),
so you see real-world performance instead of one flattering average.

Loads with plain transformers. Newer architectures (e.g. model_type ``gemma4``,
used by gemma-4-E4B) need **transformers >= 5.5** -- older versions raise
``KeyError: 'gemma4'``. Note: do NOT load gemma4 through unsloth in a Studio env
whose transformers was upgraded -- the upgrade pulls ``huggingface_hub`` 1.x,
which breaks ``unsloth_zoo``'s config lookup. Plain transformers is the clean path.

    python evaluate.py --model "C:\\path\\to\\exported\\merged_model" --limit 500
    python evaluate.py --model "C:\\path\\to\\exported\\merged_model"

Needs: transformers>=5.5, torch, datasets, accelerate.
"""

from __future__ import annotations

import argparse
import re

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Matches CWE identifiers such as "CWE-79". The pattern is case-sensitive, so a
# lowercase "cwe-79" is not recognised.
CWE_RE = re.compile(r"CWE-\d+")

# A row is "easy" if the description literally names the weakness (the model can
# keyword-match); "hard" rows require inferring the CWE from the prose.
# Entries are lowercase because is_easy() substring-matches them against the
# lowercased description; spelling variants (hyphenated / spaced) are listed
# separately for that reason.
EASY_KW = [
    "sql injection",
    "cross-site scripting",
    "cross site scripting",
    "xss",
    "buffer overflow",
    "use after free",
    "use-after-free",
    "path traversal",
    "command injection",
    "out-of-bounds",
    "out of bounds",
    "race condition",
    "deserialization",
    "ssrf",
    "server-side request forgery",
    "csrf",
    "cross-site request forgery",
    "open redirect",
    "integer overflow",
]


def parse_cwes(text: str) -> set[str]:
    """Collect the distinct CWE identifiers that ``CWE_RE`` finds in ``text``.

    Returns an empty set when nothing matches.
    """
    found: set[str] = set()
    for match in CWE_RE.finditer(text):
        # The pattern has no capture groups, so group(0) is the whole "CWE-<n>" token.
        found.add(match.group(0))
    return found


def is_easy(description: str) -> bool:
    """Return True when ``description`` contains any ``EASY_KW`` phrase, ignoring case.

    The check is a plain substring test against the lowercased description.
    """
    # Lowercase once up front instead of once per keyword.
    lowered = description.lower()
    return any(keyword in lowered for keyword in EASY_KW)


def prf(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Return ``(precision, recall, f1)`` computed from raw counts.

    Args:
        tp: number of true positives.
        fp: number of false positives.
        fn: number of false negatives.

    Any ratio whose denominator is zero is reported as 0.0 rather than raising.
    """
    precision = recall = f1 = 0.0
    if tp + fp:
        precision = tp / (tp + fp)
    if tp + fn:
        recall = tp / (tp + fn)
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


def build_prompt(tok, messages: list[dict]) -> str:
    """Render every message except the last one as a generation-ready prompt string.

    The last message is left out (it holds the assistant answer) and the chat
    template is asked to append its generation prompt. If rendering raises, the
    first "system" and first "user" contents are merged into a single user turn
    and rendering is retried once.
    """
    history = messages[:-1]
    try:
        return tok.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Some chat templates (e.g. Gemma) refuse a standalone "system" role, so
        # the system text is prepended to the user text instead.

        def first_content(role: str) -> str:
            for turn in history:
                if turn["role"] == role:
                    return turn["content"]
            return ""

        system_text = first_content("system")
        user_text = first_content("user")
        merged_turn = {"role": "user", "content": f"{system_text}\n\n{user_text}".strip()}
        return tok.apply_chat_template([merged_turn], tokenize=False, add_generation_prompt=True)


def score(truths: list[set[str]], preds: list[set[str]], easies: list[bool]) -> None:
    """Print exact-match accuracy and micro/macro F1, overall and per difficulty stratum.

    Args:
        truths: reference CWE set for each example.
        preds: predicted CWE set for each example, in the same order as ``truths``.
        easies: one flag per example; truthy rows are tallied under "easy", the
            rest under "hard".

    Raises:
        ValueError: if the three lists do not all have the same length. Zipping
            them would otherwise drop rows silently while still dividing by
            ``len(truths)``, which skews every reported number.
    """
    n = len(truths)
    if not (n == len(preds) == len(easies)):
        raise ValueError(
            f"length mismatch: {n} truths, {len(preds)} preds, {len(easies)} easies"
        )

    micro = [0, 0, 0]  # tp, fp, fn
    per_label: dict[str, list[int]] = {}  # label -> [tp, fp, fn]
    exact = 0
    strata = {"easy": [0, 0, 0, 0, 0], "hard": [0, 0, 0, 0, 0]}  # tp,fp,fn,exact,n

    for true, pred, easy in zip(truths, preds, easies):
        hits, spurious, missed = pred & true, pred - true, true - pred
        tp, fp, fn = len(hits), len(spurious), len(missed)
        ex = int(pred == true)
        exact += ex

        # Per-label tallies feed the macro average.
        for slot, labels in enumerate((hits, spurious, missed)):
            for lab in labels:
                per_label.setdefault(lab, [0, 0, 0])[slot] += 1

        s = strata["easy" if easy else "hard"]
        for totals in (micro, s):
            totals[0] += tp
            totals[1] += fp
            totals[2] += fn
        s[3] += ex
        s[4] += 1

    print("\n=== CVE -> CWE evaluation ===")
    print(f"examples             : {n}")
    if n == 0:
        # Nothing to average over; bail out before dividing by zero.
        print("no examples to score")
        return

    micro_f1 = prf(*micro)[2]
    macro_f1 = sum(prf(*v)[2] for v in per_label.values()) / len(per_label) if per_label else 0.0

    print(f"exact-match accuracy : {exact / n:.3f}   (predicted CWE set == true set)")
    print(f"micro-F1             : {micro_f1:.3f}")
    print(f"macro-F1             : {macro_f1:.3f}   (unweighted mean over {len(per_label)} CWEs)")
    print("\n-- by difficulty --")
    for name, label in (("easy", "easy (weakness named)"), ("hard", "hard (must infer) ")):
        tp, fp, fn, ex, m = strata[name]
        if m:  # skip a stratum that received no examples
            print(f"  {label:22s} n={m:5d}  exact={ex / m:.3f}  micro-F1={prf(tp, fp, fn)[2]:.3f}")


def main() -> None:
    """Run the evaluation end to end from the command line.

    Parses the CLI options, loads the tokenizer and model, builds one prompt per
    dataset row, generates greedily in batches, parses the CWE ids out of each
    completion, and passes the results to ``score``. Invalid size options exit
    through ``argparse``'s error path before anything is loaded.
    """
    ap = argparse.ArgumentParser(description="Evaluate a CVE->CWE model on the test split.")
    ap.add_argument("--model", required=True, help="path or HF id of the fine-tuned (merged) model")
    ap.add_argument("--dataset", default="exploitintel/cve-cwe-consensus")
    ap.add_argument("--split", default="test")
    ap.add_argument(
        "--limit", type=int, default=None, help="evaluate only the first N rows (quick check)"
    )
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--max-new-tokens", type=int, default=32)
    args = ap.parse_args()

    # Fail fast on sizes that would otherwise crash mid-run or, worse, silently
    # produce zero predictions (a negative batch size yields an empty range).
    if args.limit is not None and args.limit < 0:
        ap.error("--limit must be >= 0")
    if args.batch_size < 1:
        ap.error("--batch-size must be >= 1")
    if args.max_new_tokens < 1:
        ap.error("--max-new-tokens must be >= 1")

    print(f"loading model: {args.model}")
    try:
        tok = AutoTokenizer.from_pretrained(args.model)
    except (AttributeError, TypeError):
        # Some Gemma tokenizer configs store `extra_special_tokens` as a list, which
        # trips a transformers bug ('list' object has no attribute 'keys').
        tok = AutoTokenizer.from_pretrained(args.model, extra_special_tokens={})
    tok.padding_side = "left"  # decoder-only batched generation needs left padding
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        model = AutoModelForCausalLM.from_pretrained(args.model, dtype="auto").to(device)
    except TypeError:
        # `dtype` is the transformers 5.x name; older releases use `torch_dtype`.
        model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype="auto").to(device)
    model.eval()

    ds = load_dataset(args.dataset, split=args.split)
    if args.limit:  # 0 (like omitting the flag) means "no limit"
        ds = ds.select(range(min(args.limit, len(ds))))
    if len(ds) == 0:
        print(f"no rows in {args.dataset!r} split {args.split!r}; nothing to evaluate")
        return

    prompts, truths, easies = [], [], []
    for ex in ds:
        msgs = ex["messages"]
        prompts.append(build_prompt(tok, msgs))
        # The last message holds the reference answer; its CWE ids are the truth set.
        truths.append(parse_cwes(msgs[-1]["content"]))
        usr = next((m["content"] for m in msgs if m["role"] == "user"), "")
        easies.append(is_easy(usr))

    preds: list[set[str]] = []
    for i in range(0, len(prompts), args.batch_size):
        batch = prompts[i : i + args.batch_size]
        # The prompts were already rendered by the chat template, which supplies
        # its own special tokens; adding them again here would duplicate BOS.
        # NOTE(review): truncation uses the tokenizer's configured side. If that
        # is the right side, an over-long prompt could lose the trailing
        # generation prompt -- confirm prompts stay under max_length.
        enc = tok(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
            add_special_tokens=False,
        ).to(device)
        with torch.no_grad():
            out = model.generate(
                **enc,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,  # greedy = deterministic
                pad_token_id=tok.pad_token_id,
            )
        new_tokens = out[:, enc["input_ids"].shape[1] :]  # drop the prompt, keep the answer
        for row in new_tokens:
            preds.append(parse_cwes(tok.decode(row, skip_special_tokens=True)))
        print(f"  {min(i + args.batch_size, len(prompts))}/{len(prompts)}", end="\r")
    print()

    score(truths, preds, easies)


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()