"""Core VERIS classification logic — dual-mode inference.

Supports two backends:
1. Fine-tuned HF model (primary) — runs on ZeroGPU in HF Spaces
2. OpenAI API (fallback) — for local dev or if HF model not available
"""
import json
import logging
import re

# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
# ── System prompts ──────────────────────────────────────────────────────────
# Prompt for the classification task: instructs the model to emit a VERIS
# incident classification (actor / action / asset / attribute) as bare JSON
# with no surrounding prose. Parsed downstream by _parse_json_response.
CLASSIFY_SYSTEM_PROMPT = (
    "You are a VERIS (Vocabulary for Event Recording and Incident Sharing) classifier. "
    "Given a security incident description, output a JSON classification using the VERIS framework. "
    "Include actor (external/internal/partner with variety and motive), "
    "action (malware/hacking/social/misuse/physical/error/environmental with variety and vector), "
    "asset (with variety like 'S - Web application', 'U - Laptop'), "
    "and attribute (confidentiality/integrity/availability with relevant sub-fields). "
    "Return ONLY valid JSON."
)
# Prompt for the Q&A task. The trailing "answer only" / "do not append"
# instructions discourage the model from generating extra Q&A turns;
# _clean_qa_response strips any such chains that slip through anyway.
QA_SYSTEM_PROMPT = (
    "You are a VERIS (Vocabulary for Event Recording and Incident Sharing) expert. "
    "Answer questions about the VERIS framework accurately and thoroughly. "
    "Reference specific VERIS terminology, enumeration values, and concepts. "
    "Be helpful and educational. "
    "Answer only the user's question. "
    "Do not ask follow-up questions. "
    "Do not append additional Q&A prompts."
)
# ── HF Model Backend ────────────────────────────────────────────────────────
# HF Hub repos: the classifier ships as a small LoRA adapter that is applied
# on top of the base instruct model at load time (see load_hf_model).
HF_MODEL_ID = "vibesecurityguy/veris-classifier-v2"  # LoRA adapter repo
BASE_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"  # Base model

# Process-wide cache for the loaded pipeline/tokenizer; populated lazily on
# the first call to load_hf_model so import stays cheap.
_hf_pipeline = None
_hf_tokenizer = None
def load_hf_model():
    """Lazily load the fine-tuned classifier (base Mistral + LoRA adapter).

    The adapter repo holds only the LoRA weights (162 MB), not a full model:
    the base Mistral-7B-Instruct checkpoint is fetched separately and the
    adapter is merged on top. The resulting pipeline/tokenizer pair is cached
    in module globals, so the expensive load runs once per process.

    Returns:
        tuple: (text-generation pipeline, tokenizer)

    Raises:
        RuntimeError: if no CUDA device is available (GPU-only path).
    """
    global _hf_pipeline, _hf_tokenizer

    # Fast path: already loaded earlier in this process.
    if _hf_pipeline is not None:
        return _hf_pipeline, _hf_tokenizer

    # Heavy third-party imports stay function-local so the module imports
    # cleanly on machines without torch/peft installed.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    # GPU-only path (ZeroGPU on Spaces). On CPU-only runtimes, transformers
    # can fail with opaque disk-offload errors, so fail fast and loud.
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Fine-tuned model requires GPU. This Space appears to be on CPU-only "
            "(no CUDA device available). Request ZeroGPU (A10G) or provide an "
            "OpenAI API key to use fallback inference."
        )

    logger.info(f"Loading base model: {BASE_MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
    _hf_tokenizer = tokenizer
    if tokenizer.pad_token is None:
        # Mistral tokenizers ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    logger.info(f"Applying LoRA adapter: {HF_MODEL_ID}")
    # Merge the adapter into the base weights for faster inference.
    merged = PeftModel.from_pretrained(base, HF_MODEL_ID).merge_and_unload()
    logger.info("Adapter merged successfully")

    _hf_pipeline = pipeline(
        "text-generation",
        model=merged,
        tokenizer=tokenizer,
        return_full_text=False,
    )
    logger.info("Model loaded and ready for inference")
    return _hf_pipeline, _hf_tokenizer
def _generate_hf(messages: list[dict], max_new_tokens: int = 1024) -> str:
    """Generate with the fine-tuned HF model using default sampling settings.

    Thin convenience wrapper over ``_generate_hf_with_options`` for callers
    that only care about the token budget.
    """
    return _generate_hf_with_options(messages, max_new_tokens)
def _generate_hf_with_options(
    messages: list[dict],
    max_new_tokens: int = 1024,
    do_sample: bool = True,
    temperature: float = 0.2,
    top_p: float = 0.9,
) -> str:
    """Run the fine-tuned HF pipeline with explicit sampling controls.

    Args:
        messages: Chat-style messages (role/content dicts).
        max_new_tokens: Generation budget.
        do_sample: Greedy decoding when False.
        temperature: Sampling temperature; only forwarded when sampling.
        top_p: Nucleus-sampling cutoff; only forwarded when sampling.

    Returns:
        str: Generated text, stripped of surrounding whitespace.
    """
    pipe, _tokenizer = load_hf_model()
    options = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        # temperature/top_p are meaningless (and noisy) under greedy decoding.
        options.update(temperature=temperature, top_p=top_p)
    result = pipe(messages, **options)
    return result[0]["generated_text"].strip()
# ── OpenAI Backend ──────────────────────────────────────────────────────────
def _generate_openai(
client,
messages: list[dict],
model: str = "gpt-4o",
temperature: float = 0.2,
max_tokens: int = 1000,
json_mode: bool = False,
) -> str:
"""Generate a response using the OpenAI API."""
kwargs = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
}
if json_mode:
kwargs["response_format"] = {"type": "json_object"}
response = client.chat.completions.create(**kwargs)
return response.choices[0].message.content.strip()
def _parse_json_response(raw: str) -> dict:
"""Parse model output into JSON with light recovery for wrapped text."""
text = raw.strip()
try:
return json.loads(text)
except json.JSONDecodeError:
pass
if text.startswith("```"):
lines = text.split("\n")
text = "\n".join(lines[1:-1]) if len(lines) > 2 else text
text = text.strip()
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# Recover when the model prepends/appends prose around a JSON object.
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1 and end > start:
return json.loads(text[start : end + 1])
raise json.JSONDecodeError("No JSON object found in model output", text, 0)
def _clean_qa_response(answer: str) -> str:
"""Remove model-appended follow-up question chains from QA output."""
text = answer.strip()
match = re.search(r"(?:\n|[.!?]\s+)(What|How|Why|When|Where|Who)\b", text)
if match and match.start() > 0:
text = text[: match.start()].rstrip()
return text
# ── Public API ──────────────────────────────────────────────────────────────
def classify_incident(
    client=None,
    description: str = "",
    model: str = "gpt-4o",
    use_hf: bool = False,
) -> dict:
    """Classify a security incident into the VERIS framework.

    Args:
        client: OpenAI client (required if use_hf=False)
        description: Plain-text incident description
        model: OpenAI model name (only used if use_hf=False)
        use_hf: If True, use the fine-tuned HF model instead of OpenAI

    Returns:
        dict: VERIS classification JSON

    Raises:
        ValueError: if use_hf is False and no client was supplied.
    """
    user_prompt = f"Classify this security incident:\n\n{description}"
    messages = [
        {"role": "system", "content": CLASSIFY_SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    if not use_hf:
        if client is None:
            raise ValueError("OpenAI client required when use_hf=False")
        raw = _generate_openai(
            client, messages, model=model, temperature=0.2, json_mode=True
        )
    else:
        # Greedy decoding: classification output should be deterministic.
        raw = _generate_hf_with_options(messages, max_new_tokens=1024, do_sample=False)

    return _parse_json_response(raw)
def answer_question(
    client=None,
    question: str = "",
    model: str = "gpt-4o",
    use_hf: bool = False,
) -> str:
    """Answer a question about the VERIS framework.

    Args:
        client: OpenAI client (required if use_hf=False)
        question: User's question about VERIS
        model: OpenAI model name (only used if use_hf=False)
        use_hf: If True, use the fine-tuned HF model instead of OpenAI

    Returns:
        str: Answer text

    Raises:
        ValueError: if use_hf is False and no client was supplied.
    """
    messages = [
        {"role": "system", "content": QA_SYSTEM_PROMPT},
        {"role": "user", "content": question},
    ]

    if not use_hf:
        if client is None:
            raise ValueError("OpenAI client required when use_hf=False")
        raw = _generate_openai(
            client, messages, model=model, temperature=0.3, max_tokens=800
        )
    else:
        # Greedy decoding with a tighter budget keeps QA answers focused.
        raw = _generate_hf_with_options(messages, max_new_tokens=320, do_sample=False)

    # Strip any model-appended follow-up question chains before returning.
    return _clean_qa_response(raw)
|