"""
Inference module for report generation.
Supports UMSR and Google Flan model routes with graceful fallback behavior.
"""
from __future__ import annotations
import gc
import os
from typing import Any
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
)
from .preprocessing import extract_iocs, parse_input, preprocess_for_model
from .risk_scoring import compute_risk_score, extract_severities
_model = None
_tokenizer = None
_model_source: str | None = None
_is_encoder_decoder = False
_loaded_preference: str | None = None
DEFAULT_MODEL_ID = os.environ.get("MODEL_ID", "NorthernTribe-Research/UMSR-Reasoner-7B")
MODEL_DIR = os.environ.get("MODEL_DIR", "models/flan_t5_report")
FALLBACK_MODEL_ID = os.environ.get("FALLBACK_MODEL_ID", "google/flan-t5-base")
MAX_INPUT_TOKENS = int(os.environ.get("MAX_INPUT_TOKENS", "1024"))
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "220"))
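# Every value above can be overridden via environment variables, e.g.
# (the launcher name here is illustrative, not part of this module):
#   MODEL_ID=google/flan-t5-large MAX_NEW_TOKENS=512 python app.py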
_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SYSTEM_PROMPT = (
"You are a cybersecurity incident analyst. "
"Write a concise executive summary of the incident, attack flow, and response actions."
)
_AUTO_SENTINEL = "__auto__"
MODEL_PRESETS: dict[str, str] = {
"Auto (Best Available)": _AUTO_SENTINEL,
"UMSR-Reasoner-7B": "NorthernTribe-Research/UMSR-Reasoner-7B",
"Google Flan-T5 Base": "google/flan-t5-base",
}
LOCAL_MODEL_LABEL = "Local Fine-Tuned (MODEL_DIR)"
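# Route resolution quick reference (see _resolve_model_preference below):
#   "Auto (Best Available)"  -> None (auto-routing)
#   "UMSR-Reasoner-7B"       -> "NorthernTribe-Research/UMSR-Reasoner-7B"
#   "Google Flan-T5 Base"    -> "google/flan-t5-base"
#   LOCAL_MODEL_LABEL        -> MODEL_DIR
#   any other string         -> treated as a direct model id/path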
def get_model_choices() -> list[str]:
"""UI-safe list of model route options."""
choices = list(MODEL_PRESETS.keys())
if _local_model_available():
choices.append(LOCAL_MODEL_LABEL)
return choices
def get_default_model_choice() -> str:
"""Default dropdown value for model route selection."""
return "Auto (Best Available)"
def get_runtime_status(model_preference: str | None = None, active_model: str | None = None) -> str:
"""Human-readable runtime status for UI diagnostics."""
requested = _resolve_model_preference(model_preference) or DEFAULT_MODEL_ID
active = active_model or _model_source or "Not loaded yet"
return f"Device: `{_DEVICE.type}` | Requested: `{requested}` | Active: `{active}`"
def _local_model_available() -> bool:
    """True if MODEL_DIR exists as a non-empty directory."""
    return os.path.isdir(MODEL_DIR) and bool(os.listdir(MODEL_DIR))
def _resolve_model_preference(model_preference: str | None) -> str | None:
"""Resolve UI route to a concrete model id/path. None means auto-routing."""
if not model_preference:
return None
if model_preference == LOCAL_MODEL_LABEL:
return MODEL_DIR
preset = MODEL_PRESETS.get(model_preference)
if preset == _AUTO_SENTINEL:
return None
if preset:
return preset
# Treat unknown value as a direct model id/path.
return model_preference
def _is_heavy_model(model_ref: str) -> bool:
"""Heuristic to avoid trying very large models first on CPU auto-route."""
key = model_ref.lower()
heavy_markers = ("7b", "8b", "13b", "34b", "70b")
return any(marker in key for marker in heavy_markers)
def _dedupe(items: list[str]) -> list[str]:
seen: set[str] = set()
ordered: list[str] = []
for item in items:
if item and item not in seen:
ordered.append(item)
seen.add(item)
return ordered
def _model_candidates(model_preference: str | None = None) -> list[str]:
"""Ordered model candidates based on route selection and environment."""
explicit = _resolve_model_preference(model_preference)
if explicit:
candidates = [explicit]
if _local_model_available() and explicit != MODEL_DIR:
candidates.append(MODEL_DIR)
if FALLBACK_MODEL_ID != explicit:
candidates.append(FALLBACK_MODEL_ID)
return _dedupe(candidates)
# Auto-route: keep app responsive on CPU by trying smaller fallback first.
if _DEVICE.type == "cpu" and _is_heavy_model(DEFAULT_MODEL_ID):
candidates = [FALLBACK_MODEL_ID, DEFAULT_MODEL_ID]
else:
candidates = [DEFAULT_MODEL_ID]
if _local_model_available():
candidates.append(MODEL_DIR)
candidates.append(FALLBACK_MODEL_ID)
return _dedupe(candidates)
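# Worked example, assuming default env values, CPU, and a populated MODEL_DIR:
#   _model_candidates(None) == [
#       "google/flan-t5-base",                     # small fallback tried first
#       "NorthernTribe-Research/UMSR-Reasoner-7B",
#       "models/flan_t5_report",
#   ]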
def _clear_loaded_model() -> None:
"""Release model memory when switching route choices."""
global _model, _tokenizer, _model_source, _is_encoder_decoder, _loaded_preference
_model = None
_tokenizer = None
_model_source = None
_is_encoder_decoder = False
_loaded_preference = None
gc.collect()
if _DEVICE.type == "cuda":
torch.cuda.empty_cache()
def _load_model(model_preference: str | None = None):
"""Lazy load model/tokenizer with route-aware cache and fallback."""
global _model, _tokenizer, _model_source, _is_encoder_decoder, _loaded_preference
route_key = model_preference or get_default_model_choice()
if _model is not None and _loaded_preference == route_key:
return _model, _tokenizer, _is_encoder_decoder
    # Reaching this point with a loaded model means a different route was
    # requested than the one cached, so drop the old model first.
    if _model is not None:
        _clear_loaded_model()
    load_kwargs: dict[str, Any] = {
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
        # fp16 halves memory on GPU; CPU kernels generally require fp32.
        "torch_dtype": torch.float16 if _DEVICE.type == "cuda" else torch.float32,
    }
last_error: Exception | None = None
candidates = _model_candidates(model_preference)
for candidate in candidates:
try:
print(f"Loading model from {candidate} on {_DEVICE}...")
config = AutoConfig.from_pretrained(candidate, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(candidate, trust_remote_code=True)
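            # Many decoder-only checkpoints ship without a pad token; reuse
            # EOS so padded/truncated generation does not error out.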
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
tokenizer.pad_token = tokenizer.eos_token
if config.is_encoder_decoder:
model = AutoModelForSeq2SeqLM.from_pretrained(candidate, **load_kwargs)
is_encoder_decoder = True
else:
model = AutoModelForCausalLM.from_pretrained(candidate, **load_kwargs)
is_encoder_decoder = False
model.to(_DEVICE)
model.eval()
_model = model
_tokenizer = tokenizer
_model_source = candidate
_is_encoder_decoder = is_encoder_decoder
_loaded_preference = route_key
print(f"Loaded model: {_model_source}")
return _model, _tokenizer, _is_encoder_decoder
except Exception as exc:
last_error = exc
print(f"Error loading model from {candidate}: {exc}")
if _DEVICE.type == "cuda":
torch.cuda.empty_cache()
raise RuntimeError(f"Failed to load any model candidate: {candidates}") from last_error
def _build_prompt(model_input_text: str, tokenizer, is_encoder_decoder: bool) -> str:
"""Build route-appropriate prompt with chat template support."""
if is_encoder_decoder:
return (
"Summarize this cybersecurity incident for operations leadership. "
"Include notable entities, likely attack flow, and immediate containment actions.\n\n"
f"Incident data:\n{model_input_text}"
)
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": model_input_text},
]
if hasattr(tokenizer, "apply_chat_template"):
try:
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
except Exception:
pass
return (
f"{SYSTEM_PROMPT}\n\n"
f"Incident logs and alerts:\n{model_input_text}\n\n"
"Executive summary:"
)
def _generate_narrative(model_input_text: str, model_preference: str | None = None) -> tuple[str, str]:
"""Generate incident narrative from the selected route."""
if not model_input_text.strip():
return "No usable incident text was provided.", "unavailable"
model, tokenizer, is_encoder_decoder = _load_model(model_preference)
prompt_text = _build_prompt(model_input_text, tokenizer, is_encoder_decoder)
model_inputs = tokenizer(
prompt_text,
return_tensors="pt",
truncation=True,
max_length=MAX_INPUT_TOKENS,
)
input_len = model_inputs["input_ids"].shape[-1]
model_inputs = {k: v.to(_DEVICE) for k, v in model_inputs.items()}
gen_kwargs: dict[str, Any] = {
"max_new_tokens": MAX_NEW_TOKENS,
"do_sample": False,
}
if tokenizer.pad_token_id is not None:
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
if tokenizer.eos_token_id is not None:
gen_kwargs["eos_token_id"] = tokenizer.eos_token_id
if is_encoder_decoder:
gen_kwargs["num_beams"] = 2
with torch.inference_mode():
output_ids = model.generate(**model_inputs, **gen_kwargs)
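    # Seq2seq models return only newly generated tokens, while decoder-only
    # models echo the prompt, so the prompt's input_len tokens are stripped.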
if is_encoder_decoder:
generated_ids = output_ids[0]
else:
generated_ids = output_ids[0][input_len:]
text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
if not text:
text = "Unable to generate a model summary from this incident input."
return text, _model_source or "unavailable"
def generate_report(raw_input: str, model_preference: str | None = None) -> dict[str, Any]:
"""Generate a cybersecurity report from logs/text/JSON using the selected model route."""
parsed = parse_input(raw_input)
requested_model = _resolve_model_preference(model_preference) or DEFAULT_MODEL_ID
if parsed["type"] == "empty":
return {
"executive_summary": "No input provided. Paste logs or upload JSON to generate a report.",
"technical_summary": "",
"ttps": [],
"iocs": [],
"cves": [],
"threat_actors": [],
"risk_score": 0,
"confidence": 0.0,
"model_used": "n/a",
"model_requested": requested_model,
"device": _DEVICE.type,
}
full_text = parsed["content"]
iocs_raw = extract_iocs(full_text)
ttps: list[str] = []
cves: list[str] = []
iocs: list[str] = []
for indicator in iocs_raw:
if indicator.startswith("T") and indicator[1:].replace(".", "").isdigit():
ttps.append(indicator)
elif indicator.upper().startswith("CVE-"):
cves.append(indicator)
else:
iocs.append(indicator)
model_input_text = preprocess_for_model(raw_input)
try:
generated_text, model_used = _generate_narrative(model_input_text, model_preference=model_preference)
except Exception as exc:
model_used = _model_source or "unavailable"
generated_text = (
"Model generation is temporarily unavailable. "
f"Fallback summary: processed {len(parsed['lines'])} event(s), "
f"detected {len(ttps)} TTP(s), {len(cves)} CVE(s), and {len(iocs)} IOC(s). "
f"Error type: {type(exc).__name__}."
)
num_events = len(parsed["lines"])
tech_summary = (
f"Input type: {parsed['type']}. {num_events} events processed. "
f"Detected {len(ttps)} TTPs, {len(cves)} CVEs, and {len(iocs)} IOCs."
)
severities = extract_severities(full_text)
risk, confidence = compute_risk_score(ttps, cves, iocs, severities, num_events)
return {
"executive_summary": generated_text,
"technical_summary": tech_summary,
"ttps": ttps,
"iocs": iocs,
"cves": cves,
"threat_actors": [],
"risk_score": risk,
"confidence": round(confidence, 2),
"model_used": model_used,
"model_requested": requested_model,
"device": _DEVICE.type,
}
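# Minimal usage sketch (the log line is illustrative, not real telemetry):
#   report = generate_report("ALERT high: T1059.001 via CVE-2024-3094 from 203.0.113.7")
#   report["risk_score"], report["model_used"], report["ttps"]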
def generate_stub_report(raw_input: str) -> dict[str, Any]:
"""Backward-compatible alias used by batch/evaluation modules."""
return generate_report(raw_input)
def _tag_entity(text: str, tag: str) -> str:
    """Render an entity as Markdown inline code for highlighting.

    The ``tag`` argument (ttp/ioc/cve) is accepted for call-site clarity but
    does not currently affect the rendering.
    """
    return f"`{text}`"
def format_report_markdown(report: dict[str, Any]) -> str:
"""Format report dict as Markdown with tagging, tables, and model metadata."""
sections = [
"## Executive Summary",
report["executive_summary"],
"",
"## Technical Summary",
report["technical_summary"],
"",
"## Extracted Entities",
"| Type | Count | Sample |",
"|------|-------|--------|",
f"| **TTPs** | {len(report['ttps'])} | {', '.join(_tag_entity(t, 'ttp') for t in report['ttps'][:5]) or '-'} |",
f"| **IOCs** | {len(report['iocs'])} | {', '.join(_tag_entity(i, 'ioc') for i in report['iocs'][:5]) or '-'} |",
f"| **CVEs** | {len(report['cves'])} | {', '.join(_tag_entity(c, 'cve') for c in report['cves'][:5]) or '-'} |",
"",
"## Risk Assessment",
f"- **Risk Score:** {report['risk_score']}/100",
f"- **Confidence:** {report['confidence']}",
f"- **Model Used:** {report.get('model_used', 'n/a')}",
f"- **Runtime Device:** {report.get('device', 'n/a')}",
"",
]
if report["ttps"]:
sections.extend(["### TTPs (tagged)", ", ".join(_tag_entity(t, "ttp") for t in report["ttps"]), ""])
if report["iocs"]:
sections.extend(["### IOCs (highlighted)", ", ".join(_tag_entity(i, "ioc") for i in report["iocs"][:25]), ""])
if report["cves"]:
sections.extend(["### CVEs", ", ".join(_tag_entity(c, "cve") for c in report["cves"]), ""])
contrib = report.get("ttps", []) + report.get("cves", [])[:5] + report.get("iocs", [])[:5]
if contrib:
sections.extend(
["## Explainability", "Contributing entities (driving risk score):", ", ".join(_tag_entity(x, "") for x in contrib), ""]
)
return "\n".join(sections)