File size: 8,799 Bytes
2542dd7
 
 
 
 
 
 
 
 
7a4aa97
2542dd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a4aa97
 
 
 
2542dd7
 
 
 
2fdf5a0
 
2542dd7
 
 
 
 
 
 
 
2fdf5a0
2542dd7
 
 
 
 
 
 
 
 
 
fdbfcae
 
 
 
 
 
 
 
 
2542dd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a6f75b
 
 
 
 
 
 
 
 
 
 
2542dd7
 
4a6f75b
 
 
 
 
 
 
 
 
2542dd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a6f75b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a4aa97
 
 
 
 
 
 
 
 
2542dd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a6f75b
2542dd7
 
 
 
 
 
 
4a6f75b
2542dd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a4aa97
 
 
 
 
 
2542dd7
 
 
7a4aa97
2542dd7
 
7a4aa97
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
"""Core VERIS classification logic β€” dual-mode inference.

Supports two backends:
1. Fine-tuned HF model (primary) β€” runs on ZeroGPU in HF Spaces
2. OpenAI API (fallback) β€” for local dev or if HF model not available
"""

import json
import logging
import re

logger = logging.getLogger(__name__)

# ── System prompts ────────────────────────────────────────────────────────

# Sent verbatim as the system message for classification requests.
# Instructs the model to emit a bare-JSON VERIS classification covering the
# four A4 axes (actor / action / asset / attribute). Do not reformat: this
# string is runtime behavior, not documentation.
CLASSIFY_SYSTEM_PROMPT = (
    "You are a VERIS (Vocabulary for Event Recording and Incident Sharing) classifier. "
    "Given a security incident description, output a JSON classification using the VERIS framework. "
    "Include actor (external/internal/partner with variety and motive), "
    "action (malware/hacking/social/misuse/physical/error/environmental with variety and vector), "
    "asset (with variety like 'S - Web application', 'U - Laptop'), "
    "and attribute (confidentiality/integrity/availability with relevant sub-fields). "
    "Return ONLY valid JSON."
)

# Sent verbatim as the system message for Q&A requests. The trailing three
# directives suppress appended follow-up-question chains; _clean_qa_response
# performs the same cleanup post hoc when the model ignores them.
QA_SYSTEM_PROMPT = (
    "You are a VERIS (Vocabulary for Event Recording and Incident Sharing) expert. "
    "Answer questions about the VERIS framework accurately and thoroughly. "
    "Reference specific VERIS terminology, enumeration values, and concepts. "
    "Be helpful and educational. "
    "Answer only the user's question. "
    "Do not ask follow-up questions. "
    "Do not append additional Q&A prompts."
)

# ── HF Model Backend ─────────────────────────────────────────────────────

# Hub repo holding only the fine-tuned LoRA adapter weights (per
# load_hf_model's docstring, ~162 MB — not a full model checkpoint).
HF_MODEL_ID = "vibesecurityguy/veris-classifier-v2"   # LoRA adapter repo
# Base model the adapter is applied to and merged into.
BASE_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"   # Base model
# Lazily-initialized module-level singletons; populated by load_hf_model on
# the first request and reused afterwards.
_hf_pipeline = None
_hf_tokenizer = None


def load_hf_model():
    """Lazily construct and cache the text-generation pipeline.

    First call: downloads the Mistral base model, applies the LoRA adapter
    from HF_MODEL_ID, merges it into the base weights, and caches the
    resulting pipeline + tokenizer in module globals. Later calls return
    the cached pair immediately.

    Returns:
        tuple: (transformers text-generation pipeline, tokenizer)

    Raises:
        RuntimeError: when no CUDA device is available (CPU-only runtime).
    """
    global _hf_pipeline, _hf_tokenizer

    # Fast path: a previous call already built everything.
    if _hf_pipeline is not None:
        return _hf_pipeline, _hf_tokenizer

    # Heavy imports are deferred so the module stays importable without GPU deps.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    # Fail fast with an actionable message: CPU-only transformers runs tend
    # to die later with opaque disk-offload errors instead.
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Fine-tuned model requires GPU. This Space appears to be on CPU-only "
            "(no CUDA device available). Request ZeroGPU (A10G) or provide an "
            "OpenAI API key to use fallback inference."
        )

    logger.info(f"Loading base model: {BASE_MODEL_ID}")
    _hf_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
    if _hf_tokenizer.pad_token is None:
        # Mistral ships without a pad token; reuse EOS so batching works.
        _hf_tokenizer.pad_token = _hf_tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    logger.info(f"Applying LoRA adapter: {HF_MODEL_ID}")
    # Merging the adapter into the base weights avoids per-token LoRA
    # overhead during inference.
    merged = PeftModel.from_pretrained(base_model, HF_MODEL_ID).merge_and_unload()
    logger.info("Adapter merged successfully")

    _hf_pipeline = pipeline(
        "text-generation",
        model=merged,
        tokenizer=_hf_tokenizer,
        return_full_text=False,
    )

    logger.info("Model loaded and ready for inference")
    return _hf_pipeline, _hf_tokenizer


def _generate_hf(messages: list[dict], max_new_tokens: int = 1024) -> str:
    """Generate a response from the fine-tuned HF model with default sampling."""
    # Delegate to the full-options variant, keeping its sampling defaults.
    options = {"max_new_tokens": max_new_tokens}
    return _generate_hf_with_options(messages, **options)


def _generate_hf_with_options(
    messages: list[dict],
    max_new_tokens: int = 1024,
    do_sample: bool = True,
    temperature: float = 0.2,
    top_p: float = 0.9,
) -> str:
    """Generate a response using the fine-tuned HF model.

    Args:
        messages: Chat-style message dicts (role/content).
        max_new_tokens: Generation length cap.
        do_sample: When False, decode greedily; sampling knobs are ignored.
        temperature: Sampling temperature (only when do_sample=True).
        top_p: Nucleus-sampling cutoff (only when do_sample=True).

    Returns:
        str: The stripped generated text.
    """
    pipe, _tokenizer = load_hf_model()

    options = dict(max_new_tokens=max_new_tokens, do_sample=do_sample)
    if do_sample:
        # Temperature/top_p are only meaningful when sampling is enabled;
        # passing them alongside greedy decoding triggers warnings.
        options.update(temperature=temperature, top_p=top_p)

    result = pipe(messages, **options)
    return result[0]["generated_text"].strip()


# ── OpenAI Backend ────────────────────────────────────────────────────────


def _generate_openai(
    client,
    messages: list[dict],
    model: str = "gpt-4o",
    temperature: float = 0.2,
    max_tokens: int = 1000,
    json_mode: bool = False,
) -> str:
    """Generate a response using the OpenAI API."""
    kwargs = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    if json_mode:
        kwargs["response_format"] = {"type": "json_object"}

    response = client.chat.completions.create(**kwargs)
    return response.choices[0].message.content.strip()


def _parse_json_response(raw: str) -> dict:
    """Parse model output into JSON with light recovery for wrapped text."""
    text = raw.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    if text.startswith("```"):
        lines = text.split("\n")
        text = "\n".join(lines[1:-1]) if len(lines) > 2 else text
        text = text.strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass

    # Recover when the model prepends/appends prose around a JSON object.
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        return json.loads(text[start : end + 1])

    raise json.JSONDecodeError("No JSON object found in model output", text, 0)


def _clean_qa_response(answer: str) -> str:
    """Remove model-appended follow-up question chains from QA output."""
    text = answer.strip()
    match = re.search(r"(?:\n|[.!?]\s+)(What|How|Why|When|Where|Who)\b", text)
    if match and match.start() > 0:
        text = text[: match.start()].rstrip()
    return text


# ── Public API ────────────────────────────────────────────────────────────


def classify_incident(
    client=None,
    description: str = "",
    model: str = "gpt-4o",
    use_hf: bool = False,
) -> dict:
    """Classify a security incident into the VERIS framework.

    Args:
        client: OpenAI client (required if use_hf=False)
        description: Plain-text incident description
        model: OpenAI model name (only used if use_hf=False)
        use_hf: If True, use the fine-tuned HF model instead of OpenAI

    Returns:
        dict: VERIS classification JSON

    Raises:
        ValueError: if use_hf is False and no OpenAI client is provided.
    """
    user_prompt = f"Classify this security incident:\n\n{description}"
    messages = [
        {"role": "system", "content": CLASSIFY_SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    if use_hf:
        # Greedy decoding: classification should be deterministic JSON.
        raw = _generate_hf_with_options(messages, max_new_tokens=1024, do_sample=False)
    else:
        if client is None:
            raise ValueError("OpenAI client required when use_hf=False")
        raw = _generate_openai(
            client, messages, model=model, temperature=0.2, json_mode=True
        )

    return _parse_json_response(raw)


def answer_question(
    client=None,
    question: str = "",
    model: str = "gpt-4o",
    use_hf: bool = False,
) -> str:
    """Answer a question about the VERIS framework.

    Args:
        client: OpenAI client (required if use_hf=False)
        question: User's question about VERIS
        model: OpenAI model name (only used if use_hf=False)
        use_hf: If True, use the fine-tuned HF model instead of OpenAI

    Returns:
        str: Answer text

    Raises:
        ValueError: if use_hf is False and no OpenAI client is provided.
    """
    messages = [
        {"role": "system", "content": QA_SYSTEM_PROMPT},
        {"role": "user", "content": question},
    ]

    if use_hf:
        # Short, greedy generation for a focused answer.
        raw = _generate_hf_with_options(
            messages,
            max_new_tokens=320,
            do_sample=False,
        )
    else:
        if client is None:
            raise ValueError("OpenAI client required when use_hf=False")
        raw = _generate_openai(
            client, messages, model=model, temperature=0.3, max_tokens=800
        )

    # Both backends can append follow-up question chains; strip them.
    return _clean_qa_response(raw)