SentinelAI / agents /ai_analyst_agent.py
iitian's picture
Sync SentinelAI project and add Hugging Face Docker Space layout.
8b3905d
"""AI Security Analyst — vLLM / OpenAI-compatible, Ollama, or cinematic fallback."""
from __future__ import annotations

import ipaddress
import json
import logging
import os
import re
from typing import Any

import httpx
logger = logging.getLogger(__name__)
from models.schemas import AnalystReport, Incident, RiskAssessment
async def generate_analyst_report(incident: Incident, risk: RiskAssessment) -> AnalystReport:
    """Build an ``AnalystReport`` for *incident* from the first text source that answers.

    Sources are tried in order:

    1. An OpenAI-compatible chat endpoint, used only when ``VLLM_BASE_URL`` or
       ``OPENAI_BASE_URL`` is set (model from ``SENTINEL_LLM_MODEL``).
    2. An Ollama server at ``OLLAMA_HOST`` (model from ``OLLAMA_MODEL``, then
       ``SENTINEL_LLM_MODEL``).
    3. The canned template from ``_cinematic_fallback_json``.

    Whatever text is obtained is parsed by ``_parse_analyst_json``. Indicators
    always come from ``_extract_iocs`` rather than from the model.
    """
    prompt = _build_prompt(incident, risk)
    raw: str | None = None

    openai_base = (os.getenv("VLLM_BASE_URL") or os.getenv("OPENAI_BASE_URL") or "").strip()
    if openai_base:
        openai_model = os.getenv("SENTINEL_LLM_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
        raw = await _openai_compatible_chat(openai_base, openai_model, prompt)

    # Empty or failed answers from the first backend fall through to Ollama.
    if not raw:
        ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434")
        ollama_model = os.getenv("OLLAMA_MODEL", os.getenv("SENTINEL_LLM_MODEL", "llama3"))
        raw = await _ollama_generate(ollama_host, ollama_model, prompt)

    # The template is only built when both model backends produced nothing.
    source_text = raw or _cinematic_fallback_json(incident, risk)
    sections = _parse_analyst_json(source_text, incident, risk)

    return AnalystReport(
        incident_id=incident.id,
        executive_summary=sections["executive_summary"],
        technical_analysis=sections["technical_analysis"],
        investigation_notes=sections["investigation_notes"],
        indicators=_extract_iocs(incident),
        recommended_actions=sections["recommended_actions"],
    )
def _build_prompt(incident: Incident, risk: RiskAssessment) -> str:
tl = json.dumps(incident.timeline[:20], default=str)
return f"""You are a senior SOC analyst writing an executive-ready incident briefing.
Output ONLY valid JSON (no markdown fences) with exactly these string keys:
- "narrative": 2-4 sentences. Opening line MUST start with "SentinelAI detected". Enterprise tone: technical, concise, security-focused. Reference SSH/auth abuse, suspicious IPs, privilege moves, or outbound retrieval when applicable.
- "progression": numbered step-by-step attack progression (use \\n between steps). Map what likely happened chronologically.
- "severity_rationale": 2-3 sentences explaining why severity is justified (risk score {risk.risk_score}, label {risk.severity.value}), confidence, and blast radius.
- "recommended_actions": array of 4-7 short imperative strings (e.g. "Block offending IP at perimeter", "Rotate credentials for affected accounts", "Inspect shell history and authorized_keys", "Enable MFA on privileged users").
Incident title: {incident.title}
Machine summary: {incident.summary}
Risk: score={risk.risk_score} severity={risk.severity.value}
Timeline JSON: {tl}
"""
async def _openai_compatible_chat(base_url: str, model: str, prompt: str) -> str | None:
key = os.getenv("VLLM_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
headers: dict[str, str] = {
"Accept": "application/json",
"Content-Type": "application/json",
}
if key:
headers["Authorization"] = f"Bearer {key}"
max_tokens = int(os.getenv("LLM_MAX_TOKENS", "4096"))
payload: dict[str, Any] = {
"model": model,
"max_tokens": max_tokens,
"temperature": float(os.getenv("LLM_TEMPERATURE", "0.2")),
"messages": [
{
"role": "system",
"content": "You write incident reports as strict JSON only. No markdown.",
},
{"role": "user", "content": prompt},
],
}
_top_p = os.getenv("LLM_TOP_P")
if _top_p not in (None, ""):
payload["top_p"] = float(_top_p)
_top_k = os.getenv("LLM_TOP_K")
if _top_k not in (None, ""):
payload["top_k"] = int(_top_k)
base = base_url.rstrip("/")
chat_url = f"{base}/chat/completions" if base.endswith("/v1") else f"{base}/v1/chat/completions"
try:
async with httpx.AsyncClient(timeout=120.0) as client:
r = await client.post(
chat_url,
headers=headers,
json=payload,
)
if r.status_code != 200:
logger.warning(
"OpenAI-compatible chat failed: %s %s",
r.status_code,
(r.text or "")[:800],
)
return None
data = r.json()
choice = (data.get("choices") or [{}])[0]
msg = choice.get("message") or {}
content = (msg.get("content") or "").strip()
return _normalize_llm_json(content)
except Exception: # noqa: BLE001
return None
def _normalize_llm_json(content: str) -> str:
s = content.strip()
fence = re.match(r"^```(?:json)?\s*([\s\S]*?)```$", s, re.IGNORECASE)
if fence:
s = fence.group(1).strip()
try:
json.loads(s)
return s
except json.JSONDecodeError:
m = re.search(r"\{[\s\S]*\}", s)
if m:
return m.group(0).strip()
return s
async def _ollama_generate(host: str, model: str, prompt: str) -> str | None:
    """Request a non-streaming completion from an Ollama server's ``/api/generate``.

    Args:
        host: Ollama base URL; a trailing slash is tolerated.
        model: Ollama model tag.
        prompt: Full prompt text.

    Returns:
        The ``response`` text, cleaned with ``_normalize_llm_json`` to match the
        OpenAI-compatible path. Cleaning unwraps code fences and extracts an
        embedded JSON object. Returns ``None`` on any failure or empty answer;
        errors are logged, never raised, so the caller can fall back to the
        template.
    """
    url = f"{host.rstrip('/')}/api/generate"
    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            r = await client.post(
                url,
                json={"model": model, "prompt": prompt, "stream": False},
            )
            if r.status_code != 200:
                logger.warning(
                    "Ollama generate failed: %s %s",
                    r.status_code,
                    (r.text or "")[:800],
                )
                return None
            text = (r.json().get("response") or "").strip()
    except Exception as exc:  # noqa: BLE001 - best effort; caller falls back to the template
        # Info level: the default localhost host is often simply not running.
        logger.info("Ollama unavailable at %s: %s", url, exc)
        return None
    return _normalize_llm_json(text) if text else None
def _parse_analyst_json(blob: str, incident: Incident, risk: RiskAssessment) -> dict[str, Any]:
try:
data = json.loads(blob)
except json.JSONDecodeError:
return _cinematic_fallback_dict(incident, risk)
narrative = str(data.get("narrative") or data.get("executive") or "").strip()
progression = str(data.get("progression") or data.get("technical") or "").strip()
sev = str(data.get("severity_rationale") or data.get("notes") or "").strip()
actions = data.get("recommended_actions") or data.get("actions") or []
if isinstance(actions, str):
actions = [x.strip("- •\t ") for x in actions.split("\n") if x.strip()]
if not isinstance(actions, list):
actions = []
actions = [str(a).strip() for a in actions if str(a).strip()][:12]
if not narrative:
return _cinematic_fallback_dict(incident, risk)
if not progression:
progression = _default_progression(incident)
if not sev:
sev = _default_severity_rationale(risk)
if not actions:
actions = _default_actions()
return {
"executive_summary": narrative,
"technical_analysis": progression,
"investigation_notes": sev,
"recommended_actions": actions,
}
def _cinematic_fallback_json(incident: Incident, risk: RiskAssessment) -> str:
    """Serialize the canned fallback report using the model-facing JSON key names.

    The keys mirror those the prompt asks a model to emit, so the result can be
    fed through ``_parse_analyst_json`` exactly like a real model answer.
    """
    report = _cinematic_fallback_dict(incident, risk)
    # model-facing key -> report field
    llm_keys = {
        "narrative": "executive_summary",
        "progression": "technical_analysis",
        "severity_rationale": "investigation_notes",
        "recommended_actions": "recommended_actions",
    }
    return json.dumps({llm_key: report[field] for llm_key, field in llm_keys.items()})
def _cinematic_fallback_dict(incident: Incident, risk: RiskAssessment) -> dict[str, Any]:
    """Return a template report used when no model output is usable.

    Only the incident title is interpolated into the summary. The remaining
    fields come from the ``_default_*`` helpers, which use the timeline length
    and the risk score/label.
    """
    summary = "".join(
        (
            "SentinelAI detected correlated authentication and host telemetry consistent with a targeted intrusion ",
            f"chain against assets tied to “{incident.title}”. ",
            "Repeated SSH authentication failures from a concentrated source were followed by successful session ",
            "establishment and privileged execution patterns indicative of post-compromise activity. ",
            "Outbound retrieval-style commands suggest possible payload staging or command-and-control preparation.",
        )
    )
    report: dict[str, Any] = {"executive_summary": summary}
    report["technical_analysis"] = _default_progression(incident)
    report["investigation_notes"] = _default_severity_rationale(risk)
    report["recommended_actions"] = _default_actions()
    return report
def _default_progression(incident: Incident) -> str:
lines = [
"1. Reconnaissance / credential spray against SSH surface from a high-velocity source IP.",
"2. Brute-force or password-spray phase producing clustered authentication failures.",
"3. Successful authentication — pivot from noise to confirmed access.",
"4. Privilege escalation via sudo or equivalent administrative channel.",
"5. Potential exfil or staging via scripted download utilities (e.g. curl/wget) to non-standard paths.",
]
if incident.timeline:
lines.append(f"6. Correlated timeline contains {len(incident.timeline)} normalized events for graph reconstruction.")
return "\n".join(lines)
def _default_severity_rationale(risk: RiskAssessment) -> str:
return (
f"Severity is driven by a composite risk score of {risk.risk_score}/100 with label {risk.severity.value}. "
f"The sequence combines authentication abuse with privilege boundary crossing, elevating impact beyond "
f"nuisance scanning. Confidence reflects rule-and-window correlation across multiple telemetry stages; "
f"treat as incident-grade until disproven by host forensics."
)
def _default_actions() -> list[str]:
return [
"Block offending IP at perimeter firewall and WAF allowlists",
"Rotate credentials and invalidate active sessions for implicated accounts",
"Inspect shell history, authorized_keys, and cron for persistence",
"Enable or enforce MFA on all break-glass and sudo-capable users",
"Isolate affected host to a quarantine VLAN for memory and disk capture",
"Review outbound DNS and proxy logs for matching IOC time windows",
]
def _extract_iocs(incident: Incident) -> list[str]:
iocs: list[str] = []
for row in incident.timeline:
msg = str(row.get("msg", ""))
for token in msg.split():
if token.count(".") == 3 and token.replace(".", "").isdigit():
iocs.append(token)
return list(dict.fromkeys(iocs))[:16]