""" Inference module for report generation. Supports UMSR and Google Flan model routes with graceful fallback behavior. """ from __future__ import annotations import gc import os from typing import Any import torch from transformers import ( AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, ) from .preprocessing import extract_iocs, parse_input, preprocess_for_model from .risk_scoring import compute_risk_score, extract_severities _model = None _tokenizer = None _model_source: str | None = None _is_encoder_decoder = False _loaded_preference: str | None = None DEFAULT_MODEL_ID = os.environ.get("MODEL_ID", "NorthernTribe-Research/UMSR-Reasoner-7B") MODEL_DIR = os.environ.get("MODEL_DIR", "models/flan_t5_report") FALLBACK_MODEL_ID = os.environ.get("FALLBACK_MODEL_ID", "google/flan-t5-base") MAX_INPUT_TOKENS = int(os.environ.get("MAX_INPUT_TOKENS", "1024")) MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "220")) _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") SYSTEM_PROMPT = ( "You are a cybersecurity incident analyst. " "Write a concise executive summary of the incident, attack flow, and response actions." ) _AUTO_SENTINEL = "__auto__" MODEL_PRESETS: dict[str, str] = { "Auto (Best Available)": _AUTO_SENTINEL, "UMSR-Reasoner-7B": "NorthernTribe-Research/UMSR-Reasoner-7B", "Google Flan-T5 Base": "google/flan-t5-base", } LOCAL_MODEL_LABEL = "Local Fine-Tuned (MODEL_DIR)" def get_model_choices() -> list[str]: """UI-safe list of model route options.""" choices = list(MODEL_PRESETS.keys()) if _local_model_available(): choices.append(LOCAL_MODEL_LABEL) return choices def get_default_model_choice() -> str: """Default dropdown value for model route selection.""" return "Auto (Best Available)" def get_runtime_status(model_preference: str | None = None, active_model: str | None = None) -> str: """Human-readable runtime status for UI diagnostics.""" requested = _resolve_model_preference(model_preference) or DEFAULT_MODEL_ID active = active_model or _model_source or "Not loaded yet" return f"Device: `{_DEVICE.type}` | Requested: `{requested}` | Active: `{active}`" def _local_model_available() -> bool: return os.path.exists(MODEL_DIR) and bool(os.listdir(MODEL_DIR)) def _resolve_model_preference(model_preference: str | None) -> str | None: """Resolve UI route to a concrete model id/path. None means auto-routing.""" if not model_preference: return None if model_preference == LOCAL_MODEL_LABEL: return MODEL_DIR preset = MODEL_PRESETS.get(model_preference) if preset == _AUTO_SENTINEL: return None if preset: return preset # Treat unknown value as a direct model id/path. return model_preference def _is_heavy_model(model_ref: str) -> bool: """Heuristic to avoid trying very large models first on CPU auto-route.""" key = model_ref.lower() heavy_markers = ("7b", "8b", "13b", "34b", "70b") return any(marker in key for marker in heavy_markers) def _dedupe(items: list[str]) -> list[str]: seen: set[str] = set() ordered: list[str] = [] for item in items: if item and item not in seen: ordered.append(item) seen.add(item) return ordered def _model_candidates(model_preference: str | None = None) -> list[str]: """Ordered model candidates based on route selection and environment.""" explicit = _resolve_model_preference(model_preference) if explicit: candidates = [explicit] if _local_model_available() and explicit != MODEL_DIR: candidates.append(MODEL_DIR) if FALLBACK_MODEL_ID != explicit: candidates.append(FALLBACK_MODEL_ID) return _dedupe(candidates) # Auto-route: keep app responsive on CPU by trying smaller fallback first. if _DEVICE.type == "cpu" and _is_heavy_model(DEFAULT_MODEL_ID): candidates = [FALLBACK_MODEL_ID, DEFAULT_MODEL_ID] else: candidates = [DEFAULT_MODEL_ID] if _local_model_available(): candidates.append(MODEL_DIR) candidates.append(FALLBACK_MODEL_ID) return _dedupe(candidates) def _clear_loaded_model() -> None: """Release model memory when switching route choices.""" global _model, _tokenizer, _model_source, _is_encoder_decoder, _loaded_preference _model = None _tokenizer = None _model_source = None _is_encoder_decoder = False _loaded_preference = None gc.collect() if _DEVICE.type == "cuda": torch.cuda.empty_cache() def _load_model(model_preference: str | None = None): """Lazy load model/tokenizer with route-aware cache and fallback.""" global _model, _tokenizer, _model_source, _is_encoder_decoder, _loaded_preference route_key = model_preference or get_default_model_choice() if _model is not None and _loaded_preference == route_key: return _model, _tokenizer, _is_encoder_decoder if _model is not None and _loaded_preference != route_key: _clear_loaded_model() load_kwargs: dict[str, Any] = { "trust_remote_code": True, "low_cpu_mem_usage": True, } load_kwargs["torch_dtype"] = torch.float16 if _DEVICE.type == "cuda" else torch.float32 last_error: Exception | None = None candidates = _model_candidates(model_preference) for candidate in candidates: try: print(f"Loading model from {candidate} on {_DEVICE}...") config = AutoConfig.from_pretrained(candidate, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(candidate, trust_remote_code=True) if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: tokenizer.pad_token = tokenizer.eos_token if config.is_encoder_decoder: model = AutoModelForSeq2SeqLM.from_pretrained(candidate, **load_kwargs) is_encoder_decoder = True else: model = AutoModelForCausalLM.from_pretrained(candidate, **load_kwargs) is_encoder_decoder = False model.to(_DEVICE) model.eval() _model = model _tokenizer = tokenizer _model_source = candidate _is_encoder_decoder = is_encoder_decoder _loaded_preference = route_key print(f"Loaded model: {_model_source}") return _model, _tokenizer, _is_encoder_decoder except Exception as exc: last_error = exc print(f"Error loading model from {candidate}: {exc}") if _DEVICE.type == "cuda": torch.cuda.empty_cache() raise RuntimeError(f"Failed to load any model candidate: {candidates}") from last_error def _build_prompt(model_input_text: str, tokenizer, is_encoder_decoder: bool) -> str: """Build route-appropriate prompt with chat template support.""" if is_encoder_decoder: return ( "Summarize this cybersecurity incident for operations leadership. " "Include notable entities, likely attack flow, and immediate containment actions.\n\n" f"Incident data:\n{model_input_text}" ) messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": model_input_text}, ] if hasattr(tokenizer, "apply_chat_template"): try: return tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) except Exception: pass return ( f"{SYSTEM_PROMPT}\n\n" f"Incident logs and alerts:\n{model_input_text}\n\n" "Executive summary:" ) def _generate_narrative(model_input_text: str, model_preference: str | None = None) -> tuple[str, str]: """Generate incident narrative from the selected route.""" if not model_input_text.strip(): return "No usable incident text was provided.", "unavailable" model, tokenizer, is_encoder_decoder = _load_model(model_preference) prompt_text = _build_prompt(model_input_text, tokenizer, is_encoder_decoder) model_inputs = tokenizer( prompt_text, return_tensors="pt", truncation=True, max_length=MAX_INPUT_TOKENS, ) input_len = model_inputs["input_ids"].shape[-1] model_inputs = {k: v.to(_DEVICE) for k, v in model_inputs.items()} gen_kwargs: dict[str, Any] = { "max_new_tokens": MAX_NEW_TOKENS, "do_sample": False, } if tokenizer.pad_token_id is not None: gen_kwargs["pad_token_id"] = tokenizer.pad_token_id if tokenizer.eos_token_id is not None: gen_kwargs["eos_token_id"] = tokenizer.eos_token_id if is_encoder_decoder: gen_kwargs["num_beams"] = 2 with torch.inference_mode(): output_ids = model.generate(**model_inputs, **gen_kwargs) if is_encoder_decoder: generated_ids = output_ids[0] else: generated_ids = output_ids[0][input_len:] text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip() if not text: text = "Unable to generate a model summary from this incident input." return text, _model_source or "unavailable" def generate_report(raw_input: str, model_preference: str | None = None) -> dict[str, Any]: """Generate a cybersecurity report from logs/text/JSON using the selected model route.""" parsed = parse_input(raw_input) requested_model = _resolve_model_preference(model_preference) or DEFAULT_MODEL_ID if parsed["type"] == "empty": return { "executive_summary": "No input provided. Paste logs or upload JSON to generate a report.", "technical_summary": "", "ttps": [], "iocs": [], "cves": [], "threat_actors": [], "risk_score": 0, "confidence": 0.0, "model_used": "n/a", "model_requested": requested_model, "device": _DEVICE.type, } full_text = parsed["content"] iocs_raw = extract_iocs(full_text) ttps: list[str] = [] cves: list[str] = [] iocs: list[str] = [] for indicator in iocs_raw: if indicator.startswith("T") and indicator[1:].replace(".", "").isdigit(): ttps.append(indicator) elif indicator.upper().startswith("CVE-"): cves.append(indicator) else: iocs.append(indicator) model_input_text = preprocess_for_model(raw_input) try: generated_text, model_used = _generate_narrative(model_input_text, model_preference=model_preference) except Exception as exc: model_used = _model_source or "unavailable" generated_text = ( "Model generation is temporarily unavailable. " f"Fallback summary: processed {len(parsed['lines'])} event(s), " f"detected {len(ttps)} TTP(s), {len(cves)} CVE(s), and {len(iocs)} IOC(s). " f"Error type: {type(exc).__name__}." ) num_events = len(parsed["lines"]) tech_summary = ( f"Input type: {parsed['type']}. {num_events} events processed. " f"Detected {len(ttps)} TTPs, {len(cves)} CVEs, and {len(iocs)} IOCs." ) severities = extract_severities(full_text) risk, confidence = compute_risk_score(ttps, cves, iocs, severities, num_events) return { "executive_summary": generated_text, "technical_summary": tech_summary, "ttps": ttps, "iocs": iocs, "cves": cves, "threat_actors": [], "risk_score": risk, "confidence": round(confidence, 2), "model_used": model_used, "model_requested": requested_model, "device": _DEVICE.type, } def generate_stub_report(raw_input: str) -> dict[str, Any]: """Backward-compatible alias used by batch/evaluation modules.""" return generate_report(raw_input) def _tag_entity(text: str, tag: str) -> str: """Tag entity for highlighting (Markdown inline code).""" return f"`{text}`" def format_report_markdown(report: dict[str, Any]) -> str: """Format report dict as Markdown with tagging, tables, and model metadata.""" sections = [ "## Executive Summary", report["executive_summary"], "", "## Technical Summary", report["technical_summary"], "", "## Extracted Entities", "| Type | Count | Sample |", "|------|-------|--------|", f"| **TTPs** | {len(report['ttps'])} | {', '.join(_tag_entity(t, 'ttp') for t in report['ttps'][:5]) or '-'} |", f"| **IOCs** | {len(report['iocs'])} | {', '.join(_tag_entity(i, 'ioc') for i in report['iocs'][:5]) or '-'} |", f"| **CVEs** | {len(report['cves'])} | {', '.join(_tag_entity(c, 'cve') for c in report['cves'][:5]) or '-'} |", "", "## Risk Assessment", f"- **Risk Score:** {report['risk_score']}/100", f"- **Confidence:** {report['confidence']}", f"- **Model Used:** {report.get('model_used', 'n/a')}", f"- **Runtime Device:** {report.get('device', 'n/a')}", "", ] if report["ttps"]: sections.extend(["### TTPs (tagged)", ", ".join(_tag_entity(t, "ttp") for t in report["ttps"]), ""]) if report["iocs"]: sections.extend(["### IOCs (highlighted)", ", ".join(_tag_entity(i, "ioc") for i in report["iocs"][:25]), ""]) if report["cves"]: sections.extend(["### CVEs", ", ".join(_tag_entity(c, "cve") for c in report["cves"]), ""]) contrib = report.get("ttps", []) + report.get("cves", [])[:5] + report.get("iocs", [])[:5] if contrib: sections.extend( ["## Explainability", "Contributing entities (driving risk score):", ", ".join(_tag_entity(x, "") for x in contrib), ""] ) return "\n".join(sections)