# SentinelAI — services/pipeline.py
# Provenance: synced from Hugging Face Space, commit 8b3905d
# ("Sync SentinelAI project and add Hugging Face Docker Space layout.")
"""End-to-end async pipeline wiring parsers, agents, persistence, and hub."""
from __future__ import annotations
import asyncio
import logging
import os
import time
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from agents.ai_analyst_agent import generate_analyst_report
from agents.alerting_agent import send_alert
from agents.incident_correlation_agent import correlate
from agents.normalization_agent import normalize_event
from agents.remediation_agent import build_remediation
from agents.risk_scoring_agent import score_incident
from agents.threat_detection_agent import detect_threats
from agents.threat_enrichment_agent import enrich_event
from database.models import AlertRecord, EventRecord, IncidentRecord
from models.schemas import AlertPayload, RawLogIngest, Severity
from parsers.parser_agent import parse_raw
from services.chroma_memory import remember_incident
from services.event_hub import EventHub
from services.metrics_store import MetricsStore
# Module-level logger; handlers/levels are expected to be configured by the app.
logger = logging.getLogger("sentinelai.pipeline")
class SentinelPipeline:
    """Orchestrates the end-to-end event pipeline.

    Flow per raw log: parse -> normalize -> enrich -> detect -> correlate ->
    risk-score, with optional persistence (when an ``AsyncSession`` is given)
    and an optional auto-triggered AI analyst stage. Progress and results are
    broadcast through the :class:`EventHub`.
    """

    def __init__(self, hub: EventHub, metrics: MetricsStore) -> None:
        self.hub = hub
        self.metrics = metrics
        # In-memory working sets fed to the correlation agent on every ingest.
        # NOTE(review): these grow without bound for the process lifetime —
        # consider capping/expiring entries if ingest volume is high.
        self._events: list[Any] = []
        self._findings: list[Any] = []
        self._incidents: list[Any] = []
        # Monotonic timestamp of the last auto-triggered analyst run (rate limit).
        self._last_auto_ai_at: float = 0.0
        # Strong references to fire-and-forget tasks. The event loop keeps only
        # a weak reference to tasks, so an otherwise-unreferenced task created
        # with asyncio.create_task() may be garbage-collected before it runs to
        # completion (documented asyncio pitfall). Tasks remove themselves on
        # completion via a done-callback.
        self._bg_tasks: set[asyncio.Task[Any]] = set()

    async def ingest_from_collector(self, ingest: RawLogIngest) -> None:
        """Fire-and-forget path for autonomous collectors (no DB session)."""
        try:
            await self.ingest(ingest, None)
        except Exception:  # noqa: BLE001 — collector tail loop must never die
            logger.exception("Collector ingest failed — continuing tail")

    async def _auto_analyst(self, incident_id: UUID) -> None:
        """Background wrapper around the full analyst workflow; swallows and
        logs all errors so a failed background run cannot crash the loop."""
        try:
            await self.run_full_workflow_on_incident(incident_id, None)
        except Exception:  # noqa: BLE001
            logger.exception("Auto AI analyst failed for incident %s", incident_id)

    async def ingest(self, ingest: RawLogIngest, session: AsyncSession | None) -> dict[str, Any]:
        """Run one raw log through the whole pipeline.

        Args:
            ingest: the raw log payload from a collector or the API.
            session: optional DB session; when ``None`` nothing is persisted.

        Returns:
            A dict with ``event_id`` and ``findings``, plus ``incident`` and
            ``risk`` when correlation produced at least one incident.
        """
        await self.hub.log_agent("collector", "running", f"ingest {ingest.source}")
        await self.hub.log_agent("parser", "running", "parse raw")
        event = parse_raw(ingest)
        event = normalize_event(event)
        await self.hub.log_agent("normalization", "running", "schema unify")
        await self.hub.log_agent("threat_enrichment", "running", "intel overlay")
        enriched = await enrich_event(event)
        geo = enriched.enrichment.get("geo") or {}
        self.metrics.record_country(geo.get("countryCode"))
        await self.hub.broadcast(
            {
                "type": "threat_feed",
                "severity": enriched.severity.value,
                "message": enriched.message,
                "source_ip": enriched.source_ip,
                "event_type": enriched.event_type,
                "ts": enriched.timestamp.isoformat(),
            }
        )
        await self.hub.log_agent("threat_detection", "running", "rules + heuristics")
        findings = detect_threats(enriched)
        for f in findings:
            self.metrics.bump_threat()
            await self.hub.broadcast(
                {
                    "type": "detection",
                    "technique": f.technique,
                    "severity": f.severity.value,
                    "confidence": f.confidence,
                    "description": f.description,
                }
            )
        self._events.append(enriched)
        self._findings.extend(findings)
        self.metrics.inc_frequency()
        if session is not None:
            await self._persist_event(session, enriched)
        # Correlation & downstream when graph warrants
        incidents = correlate(self._events, self._findings)
        self._incidents = incidents
        self.metrics.set_active_incidents(len(incidents))
        out: dict[str, Any] = {"event_id": str(enriched.id), "findings": [f.model_dump() for f in findings]}
        if incidents:
            latest = incidents[-1]
            risk = score_incident(latest, self._events, self._findings)
            self.metrics.record_risk(risk.risk_score)
            await self.hub.broadcast(
                {
                    "type": "incident",
                    "title": latest.title,
                    "summary": latest.summary,
                    "risk": risk.model_dump(),
                    "timeline": latest.timeline,
                }
            )
            remember_incident(latest.summary, {"severity": risk.severity.value})
            if session is not None:
                await self._persist_incident(session, latest, risk)
            out["incident"] = latest.model_dump(mode="json")
            out["risk"] = risk.model_dump()
            # Auto-run the AI analyst unless disabled; rate-limited so a burst
            # of incidents does not stampede the LLM backend.
            if os.getenv("AUTO_AI_ON_INCIDENT", "1").lower() not in {"0", "false", "no", "off"}:
                min_gap = float(os.getenv("AUTO_AI_MIN_SEC", "75"))
                now = time.monotonic()
                if now - self._last_auto_ai_at >= min_gap:
                    self._last_auto_ai_at = now
                    # Keep a strong reference so the task is not GC'd mid-flight;
                    # discard it automatically once the task completes.
                    task = asyncio.create_task(self._auto_analyst(latest.id))
                    self._bg_tasks.add(task)
                    task.add_done_callback(self._bg_tasks.discard)
        return out

    async def run_full_workflow_on_incident(self, incident_id: UUID, session: AsyncSession | None) -> dict[str, Any]:
        """Score an in-memory incident, generate the LLM analyst report and a
        remediation playbook, broadcast the report, and optionally persist an
        alert. Returns ``{"error": ...}`` when the incident id is unknown."""
        inc = next((i for i in self._incidents if i.id == incident_id), None)
        if not inc:
            return {"error": "incident not found"}
        risk = score_incident(inc, self._events, self._findings)
        await self.hub.log_agent("ai_analyst", "running", "LLM / ROCm inference")
        report = await generate_analyst_report(inc, risk)
        await self.hub.log_agent("remediation", "running", "playbook synthesis")
        rem = build_remediation(inc, risk, report)
        await self.hub.broadcast(
            {
                "type": "ai_report",
                "incident_id": str(inc.id),
                "executive": report.executive_summary,
                "technical": report.technical_analysis,
                "investigation_notes": report.investigation_notes,
                "recommended_actions": report.recommended_actions,
            }
        )
        payload = {
            "incident": inc.model_dump(mode="json"),
            "risk": risk.model_dump(),
            "report": report.model_dump(),
            "remediation": rem.model_dump(),
        }
        if session is not None:
            await self._persist_alert(session, "stored", f"Analyst report {inc.id}", report.executive_summary, risk.severity, inc.id)
        return payload

    async def _persist_event(self, session: AsyncSession, enriched: Any) -> None:
        """Insert one enriched event row and commit."""
        rec = EventRecord(
            id=enriched.id,
            timestamp=enriched.timestamp,
            event_type=enriched.event_type,
            source_ip=enriched.source_ip,
            host=enriched.host,
            severity=enriched.severity.value,
            payload=enriched.model_dump(mode="json"),
        )
        session.add(rec)
        await session.commit()

    async def _persist_incident(self, session: AsyncSession, incident: Any, risk: Any) -> None:
        """Insert one incident row (graph serialized as JSON) and commit."""
        rec = IncidentRecord(
            id=incident.id,
            title=incident.title,
            summary=incident.summary,
            graph={"nodes": [n.model_dump(mode="json") for n in incident.nodes], "edges": [e.model_dump(mode="json") for e in incident.edges]},
            risk_score=risk.risk_score,
            severity=risk.severity.value,
        )
        session.add(rec)
        await session.commit()

    async def _persist_alert(self, session: AsyncSession, channel: str, title: str, body: str, severity: Severity, incident_id: UUID) -> None:
        """Insert one alert row tied to *incident_id* and commit."""
        session.add(
            AlertRecord(
                channel=channel,
                title=title,
                body=body,
                severity=severity.value,
                incident_id=incident_id,
            )
        )
        await session.commit()