File size: 8,583 Bytes
4da4469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
"""VectraYX-Bench B1-B4 + B5 (conversational) benchmark runner.

Loads a checkpoint and an `eval_data/` directory of JSONL test files, runs them,
and prints a summary table compatible with the paper draft.

Expected files (any subset is fine):
    eval_data/b1_cveqa.jsonl         {"cve_id":..., "prompt":..., "expected_keywords":[...]}
    eval_data/b2_classification.jsonl {"prompt":..., "label":"phishing|malware|..."}
    eval_data/b3_commands.jsonl      {"prompt":..., "expected":"nmap -sV ...", "tool":"nmap"}
    eval_data/b4_tooluse.jsonl       {"prompt":..., "expected_tool":"nvd_get_cve"}
    eval_data/b5_conversational.jsonl {"prompt":"hola", "category":"saludo"}
"""

import argparse
import json
import re
import sys
from pathlib import Path

import sentencepiece as spm
import torch

# Put ROOT at the front of sys.path before the `training_v2` imports below.
# parents[2] is the directory two levels above the one containing this file
# (presumably the repository root -- confirm against the project layout).
ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(ROOT))

from training_v2.model.transformer import VectraYXNano, ModelConfig
from training_v2.train.utils import load_checkpoint


# System prompt shared by the B1, B2, B3 and B5 benchmarks.
SYSTEM_BASE = ("Eres VectraYX-Nano, asistente experto en ciberseguridad para "
               "América Latina. Responde en español de forma natural y concisa.")

# v2 (2026-05-05): extended system prompt for B4. The SFT tooluse_dataset.jsonl
# was trained with a description of each tool plus one example of the JSON
# format; the previous prompt (a flat list of names) scored 0/25 on B4 because
# it never triggered the <|tool_call|>{"name":...}<|/tool_call|> pattern.
SYSTEM_TOOL = (
    "Eres VectraYX, asistente experto en ciberseguridad para LATAM con acceso "
    "a las siguientes herramientas. Cuando una pregunta requiera datos en "
    "tiempo real (CVEs, IOCs, comandos), responde EXCLUSIVAMENTE con un "
    "bloque <|tool_call|>{...}<|/tool_call|> en formato JSON.\n\n"
    "Herramientas disponibles:\n"
    "- nvd_get_cve(cve_id): obtiene CVSS, descripción y referencias de un CVE.\n"
    "- nvd_search(keyword): busca CVEs recientes por palabra clave.\n"
    "- cisa_kev_check(cve_id): verifica si un CVE está en el catálogo KEV.\n"
    "- mitre_get_technique(technique_id): describe una técnica MITRE ATT&CK.\n"
    "- otx_check_ioc(ioc): verifica IP/dominio/hash en AlienVault OTX.\n"
    "- bash_exec(cmd): ejecuta un comando bash de análisis o forensics.\n\n"
    "Ejemplo:\n"
    "Usuario: ¿Está siendo explotada CVE-2021-44228?\n"
    "Asistente: <|tool_call|>{\"name\": \"cisa_kev_check\", "
    "\"args\": {\"cve_id\": \"CVE-2021-44228\"}}<|/tool_call|>"
)


def chat(user, system):
    """Build a single-turn chat prompt that ends on an open assistant tag.

    Args:
        user: Text of the user turn.
        system: Text of the system turn.

    Returns:
        The system turn, then the user turn, each closed with ``<|end|>``,
        followed by ``<|assistant|>`` so the model continues from there.
    """
    system_turn = f"<|system|>{system}<|end|>"
    user_turn = f"<|user|>{user}<|end|>"
    return system_turn + user_turn + "<|assistant|>"


def generate(model, sp, prompt, max_new, end_id, eos_id, device,
             temperature=0.7, top_k=40, top_p=0.9, repeat_penalty=1.3):
    """Sample a completion for ``prompt`` and return it as stripped text.

    Args:
        model: Object exposing ``generate(ids, max_new_tokens=..., ...)``. Its
            output is sliced as ``out[0, prompt_len:]``, so it is assumed to
            return the prompt tokens followed by the new ones.
        sp: Tokenizer exposing ``encode(text, out_type=int)`` and ``decode(ids)``.
        prompt: Full prompt text (already chat-formatted by the caller).
        max_new: Maximum number of tokens to generate.
        end_id: Id of the end-of-turn token; also passed to the model as its
            ``eos_id`` stop token.
        eos_id: Id of the tokenizer's EOS token; used only for truncation here.
        device: Device on which the prompt tensor is created.
        temperature, top_k, top_p, repeat_penalty: Sampling parameters
            forwarded unchanged to ``model.generate``.

    Returns:
        The decoded continuation, cut before the first ``end_id`` or
        ``eos_id`` token (whichever comes first) and whitespace-stripped.
    """
    prompt_ids = torch.tensor([sp.encode(prompt, out_type=int)], dtype=torch.long, device=device)
    # Evaluation never backpropagates; disabling autograd avoids building a
    # graph (and holding activations) for every sampled token.
    with torch.no_grad():
        out = model.generate(
            prompt_ids, max_new_tokens=max_new, temperature=temperature, top_k=top_k,
            top_p=top_p, eos_id=end_id, repeat_penalty=repeat_penalty,
        )
    # Keep only the newly generated tokens of the single batch row.
    gen = out[0, prompt_ids.size(1):].tolist()
    # Truncate at each stop token in turn; the result ends before whichever
    # of the two appears first.
    for stop_id in (end_id, eos_id):
        if stop_id in gen:
            gen = gen[: gen.index(stop_id)]
    return sp.decode(gen).strip()


def b1_cveqa(model, sp, data, ctx):
    """B1: mean fraction of expected keywords found in each CVE answer.

    Args:
        model: Model forwarded to ``generate``.
        sp: Tokenizer forwarded to ``generate``.
        data: List of example dicts. The question is read from ``prompt`` or
            ``question`` (falling back to ``"Resume <cve id>"``), and the
            keywords from ``expected_keywords``.
        ctx: Dict providing ``end_id``, ``eos_id`` and ``device``.

    Returns:
        ``None`` when ``data`` is empty; otherwise the average over examples
        of (keywords present in the lower-cased answer) / (number of
        keywords). An example without keywords contributes 0.
    """
    if not data:
        return None

    def keyword_recall(example):
        # Score one example: the share of its keywords that occur as
        # substrings of the generated answer, case-insensitively.
        cve = example.get("cve_id") or example.get("id", "")
        question = example.get("prompt") or example.get("question") or f"Resume {cve}"
        answer = generate(
            model, sp, chat(question, SYSTEM_BASE), 200,
            ctx["end_id"], ctx["eos_id"], ctx["device"],
        ).lower()
        expected = [kw.lower() for kw in example.get("expected_keywords", [])]
        found = sum(1 for kw in expected if kw in answer)
        return found / max(1, len(expected))

    return sum(keyword_recall(example) for example in data) / len(data)


def b2_classification(model, sp, data, ctx):
    """B2: single-word threat classification, scored by accuracy and macro-F1.

    The predicted label is the first entry of the fixed label list that occurs
    as a substring of the lower-cased model output, defaulting to ``"otro"``.

    Args:
        model: Model forwarded to ``generate``.
        sp: Tokenizer forwarded to ``generate``.
        data: List of example dicts. The text is read from ``prompt``,
            ``text`` or ``question``; the gold class from the required
            ``label`` key.
        ctx: Dict providing ``end_id``, ``eos_id`` and ``device``.

    Returns:
        ``None`` when ``data`` is empty; otherwise a dict with ``accuracy``
        and ``f1_macro``. ``f1_macro`` is the unweighted mean of per-class F1
        (``2*tp / (2*tp + fp + fn)``) over the classes that occur as a gold
        label at least once.

    Raises:
        KeyError: If an example has no ``label`` key.
    """
    if not data:
        return None
    labels = ["phishing", "malware", "ransomware", "apt", "otro"]
    correct = 0
    # class -> [true positives, false positives, false negatives]. Built lazily
    # so a gold label outside `labels` is scored instead of raising KeyError.
    counts = {}
    for ex in data:
        text = ex.get("prompt") or ex.get("text") or ex.get("question", "")
        prompt = chat(f"{text}\nClasifica en una palabra: phishing, malware, ransomware, apt, otro.",
                      SYSTEM_BASE)
        out = generate(model, sp, prompt, 16, ctx["end_id"], ctx["eos_id"], ctx["device"]).lower()
        pred = next((l for l in labels if l in out), "otro")
        gold = ex["label"].lower().strip()
        if pred == gold:
            correct += 1
            counts.setdefault(gold, [0, 0, 0])[0] += 1
        else:
            # A miss is a false positive for the predicted class and a false
            # negative for the gold class; F1 needs both, recall only the latter.
            counts.setdefault(pred, [0, 0, 0])[1] += 1
            counts.setdefault(gold, [0, 0, 0])[2] += 1
    f1s = []
    for tp, fp, fn in counts.values():
        if tp + fn == 0:
            # Class never appears as a gold label: leave it out of the average.
            continue
        f1s.append(2 * tp / (2 * tp + fp + fn))
    return {"accuracy": correct / len(data), "f1_macro": sum(f1s) / max(1, len(f1s))}


def b3_commands(model, sp, data, ctx):
    """B3: command generation, scored by command and tool-name containment.

    Args:
        model: Model forwarded to ``generate``.
        sp: Tokenizer forwarded to ``generate``.
        data: List of example dicts. The question is read from ``prompt`` or
            ``question``, the gold command from ``expected`` or
            ``expected_command``, and the gold tool from ``tool`` (falling
            back to the first word of the gold command).
        ctx: Dict providing ``end_id``, ``eos_id`` and ``device``.

    Returns:
        ``None`` when ``data`` is empty; otherwise a dict with
        ``exact_match`` (fraction of outputs containing the gold command
        verbatim) and ``tool_match`` (fraction containing the gold tool name,
        case-insensitively). An example with no gold command or tool scores
        a miss on that metric.
    """
    if not data:
        return None
    exact = 0
    tool_match = 0
    for ex in data:
        prompt_text = ex.get("prompt") or ex.get("question", "")
        prompt = chat(prompt_text, SYSTEM_BASE)
        out = generate(model, sp, prompt, 80, ctx["end_id"], ctx["eos_id"], ctx["device"])
        # `or ""` tolerates keys that are present but null/empty in the JSONL.
        gold_cmd = (ex.get("expected") or ex.get("expected_command") or "").strip()
        gold_tool = ex.get("tool") or (gold_cmd.split()[0] if gold_cmd else "")
        # The empty string is a substring of everything, so an example with
        # missing gold data must be guarded or it would always count as a hit.
        if gold_cmd and gold_cmd in out:
            exact += 1
        if gold_tool and gold_tool.lower() in out.lower():
            tool_match += 1
    return {"exact_match": exact / len(data), "tool_match": tool_match / len(data)}


def b4_tooluse(model, sp, data, ctx):
    """B4: fraction of prompts for which the model selects the expected tool.

    The predicted tool is the value of the first ``"name": "..."`` field found
    in the output. If there is no such field, it is the first known tool name
    that appears anywhere in the output, or ``None`` if none does.

    Args:
        model: Model forwarded to ``generate``.
        sp: Tokenizer forwarded to ``generate``.
        data: List of example dicts. The question is read from ``prompt`` or
            ``question``; the gold tool from the required ``expected_tool``.
        ctx: Dict providing ``end_id``, ``eos_id`` and ``device``.

    Returns:
        ``None`` when ``data`` is empty; otherwise the accuracy as a float.

    Raises:
        KeyError: If an example has no ``expected_tool`` key.
    """
    if not data:
        return None
    known_tools = ("nvd_get_cve", "nvd_search", "cisa_kev_check",
                   "mitre_get_technique", "otx_check_ioc", "bash_exec")
    name_field = re.compile(r'"name"\s*:\s*"([^"]+)"')
    hits = 0
    for example in data:
        question = example.get("prompt") or example.get("question", "")
        reply = generate(
            model, sp, chat(question, SYSTEM_TOOL), 120,
            ctx["end_id"], ctx["eos_id"], ctx["device"],
        )
        found = name_field.search(reply)
        if found:
            predicted = found.group(1)
        else:
            # No JSON-style name field: accept a bare mention of a known tool.
            predicted = None
            for tool in known_tools:
                if tool in reply:
                    predicted = tool
                    break
        if predicted == example["expected_tool"]:
            hits += 1
    return hits / len(data)


def b5_conversational(model, sp, data, ctx):
    """B5: fraction of conversational prompts that get an acceptable reply.

    Acceptance depends on the example's ``category``:

    * ``saludo``: the first 80 characters of the lower-cased reply contain a
      greeting marker.
    * ``agradecimiento``: the first 80 characters contain an acknowledgement
      marker.
    * anything else: the reply is longer than 5 characters and does not start
      with ``"cve"``.

    Args:
        model: Model forwarded to ``generate``.
        sp: Tokenizer forwarded to ``generate``.
        data: List of example dicts with a required ``prompt`` key and an
            optional ``category`` key.
        ctx: Dict providing ``end_id``, ``eos_id`` and ``device``.

    Returns:
        ``None`` when ``data`` is empty; otherwise the pass rate as a float.

    Raises:
        KeyError: If an example has no ``prompt`` key.
    """
    if not data:
        return None
    opening_markers = (
        ("saludo", ["hola", "buen", "qué tal", "encantad"]),
        ("agradecimiento", ["nada", "gusto", "ayud"]),
    )
    passed = 0
    for example in data:
        reply = generate(
            model, sp, chat(example["prompt"], SYSTEM_BASE), 80,
            ctx["end_id"], ctx["eos_id"], ctx["device"],
        ).lower()
        category = example.get("category", "")
        markers = next((words for name, words in opening_markers if category == name), None)
        if markers is None:
            # Uncategorised prompt: any non-trivial reply that is not a CVE dump.
            accepted = len(reply) > 5 and not reply.startswith("cve")
        else:
            head = reply[:80]
            accepted = any(word in head for word in markers)
        passed += int(accepted)
    return passed / len(data)


def load_jsonl(path):
    """Read a JSON Lines file into a list of parsed objects.

    Args:
        path: Path of the file (``str`` or ``Path``).

    Returns:
        One parsed object per non-blank line, in file order; an empty list if
        the file does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    path = Path(path)
    if not path.exists():
        return []
    # The context manager closes the handle even when a line fails to parse.
    with path.open("r", encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


def main():
    """Command-line entry point: load the model, run B1-B5, report results.

    Parses ``--config``, ``--tokenizer``, ``--checkpoint`` and ``--data-dir``
    (all required) plus optional ``--device`` and ``--out``. Each benchmark
    reads its JSONL file from the data directory; a benchmark whose file is
    missing or empty is reported as ``None``. Results are printed to stdout
    and, when ``--out`` is given, also written there as JSON.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    p.add_argument("--tokenizer", required=True)
    p.add_argument("--checkpoint", required=True)
    p.add_argument("--data-dir", required=True, help="folder with bN_*.jsonl files")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--out", default=None, help="optional JSON output path")
    args = p.parse_args()

    cfg = ModelConfig.from_json(args.config)
    model = VectraYXNano(cfg).to(args.device).eval()
    load_checkpoint(args.checkpoint, model, map_location=args.device)
    sp = spm.SentencePieceProcessor()
    sp.load(args.tokenizer)

    end_id = sp.piece_to_id("<|end|>")
    if end_id == sp.unk_id():
        # An unknown piece maps to the unk id, which would make every unknown
        # token act as a stop token. Warn instead of silently skewing scores.
        print("warning: tokenizer has no '<|end|>' piece; generations will be "
              "truncated at the first unknown token", file=sys.stderr)

    ctx = {
        "device": args.device,
        "end_id": end_id,
        "eos_id": sp.eos_id(),
    }

    d = Path(args.data_dir)
    res = {
        "B1_cveqa_keyword":   b1_cveqa(model, sp, load_jsonl(d / "b1_cveqa.jsonl"), ctx),
        "B2_classification":  b2_classification(model, sp, load_jsonl(d / "b2_classification.jsonl"), ctx),
        "B3_commands":        b3_commands(model, sp, load_jsonl(d / "b3_commands.jsonl"), ctx),
        "B4_tooluse":         b4_tooluse(model, sp, load_jsonl(d / "b4_tooluse.jsonl"), ctx),
        "B5_conversational":  b5_conversational(model, sp, load_jsonl(d / "b5_conversational.jsonl"), ctx),
    }
    print("\n=== VectraYX-Bench ===")
    for k, v in res.items():
        print(f"  {k}: {v}")

    if args.out:
        # ensure_ascii=False can emit non-ASCII text, so fix the encoding
        # rather than depending on the platform default.
        Path(args.out).write_text(json.dumps(res, indent=2, ensure_ascii=False), encoding="utf-8")


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()