"""SentinelAI FastAPI application — autonomous SOC control plane.""" from __future__ import annotations import asyncio import logging import os import sys import time from contextlib import asynccontextmanager import threading from pathlib import Path from typing import Annotated, Any from fastapi import Depends, FastAPI, Request, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse from fastapi.staticfiles import StaticFiles ROOT = Path(__file__).resolve().parents[2] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) try: from dotenv import load_dotenv load_dotenv(ROOT / ".env") except ImportError: pass from models.schemas import ( # noqa: E402 AlertPayload, DashboardMetrics, IncidentActionBody, RawLogIngest, ReplayStartBody, WorkflowState, ) from services.event_hub import EventHub # noqa: E402 from services.metrics_store import MetricsStore # noqa: E402 logging.basicConfig(level=logging.INFO) logger = logging.getLogger("sentinelai.api") hub = EventHub() metrics = MetricsStore() # Heavy imports (LangChain, SQLAlchemy models, agents) live inside services.pipeline — defer so uvicorn can bind immediately. 
# --- Lazy pipeline wiring --------------------------------------------------
# The SOC pipeline pulls in LangChain/SQLAlchemy/agent modules that can take
# tens of seconds to import on a cold start, so wiring is deferred and guarded
# by a lock: the first caller (HTTP dependency or startup task) pays the cost.
_wire_lock = threading.Lock()

# Strong references to fire-and-forget tasks.  asyncio keeps only a weak
# reference to tasks, so without this set a scheduled task can be
# garbage-collected mid-execution (see asyncio.create_task documentation).
_background_tasks: set[Any] = set()


def _spawn_background_task(coro: Any) -> None:
    """Schedule *coro* on the running loop while retaining a strong reference.

    The reference is dropped automatically when the task finishes; remaining
    tasks are cancelled at shutdown by ``lifespan``.
    """
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


class _Services:
    """Mutable holder for the lazily-created singletons (pipeline, collector)."""

    __slots__ = ("pipeline", "collector")

    def __init__(self) -> None:
        # Both stay None until _wire_pipeline_and_collector_sync runs.
        self.pipeline: Any = None
        self.collector: Any = None


services = _Services()


def _wire_pipeline_and_collector_sync() -> None:
    """Idempotent; safe across threads."""
    # Fast path: already wired — skip the lock entirely.
    if services.pipeline is not None:
        return
    with _wire_lock:
        # Double-checked: another thread may have wired while we waited.
        if services.pipeline is not None:
            return
        from collectors.collector_agent import CollectorAgent  # noqa: E402
        from services.pipeline import SentinelPipeline  # noqa: E402

        logger.info("Loading SentinelPipeline module (first load can take 10–40s on cold start)")
        services.pipeline = SentinelPipeline(hub, metrics)
        services.collector = CollectorAgent(services.pipeline.ingest_from_collector)


async def get_pipeline_dep() -> Any:
    """Dependency for routes that need the SOC pipeline."""
    if services.pipeline is None:
        # Run the slow imports in a worker thread so the event loop stays free.
        await asyncio.to_thread(_wire_pipeline_and_collector_sync)
    return services.pipeline


PipelineDep = Annotated[Any, Depends(get_pipeline_dep)]


async def _noop(_: dict) -> None:
    """Do-nothing async node used to drive the LangGraph warm-up dry-run."""
    return None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Yield immediately so Uvicorn finishes startup and accepts HTTP (avoids browser ERR_CONNECTION_TIMED_OUT). Redis/DB/LangGraph/pipeline wiring run in the background — /health works before collectors attach.
    """

    async def background_startup() -> None:
        """Connect Redis/DB, then schedule the slow warm-up tasks."""
        try:
            await metrics.connect_redis()
            if os.getenv("SKIP_DB", "").lower() in {"1", "true", "yes"}:
                logger.info("SKIP_DB set — skipping PostgreSQL init")
            else:
                from database.session import init_db  # defer heavy SQLAlchemy/asyncpg import

                try:
                    await init_db()
                    logger.info("PostgreSQL schema ready")
                except Exception as e:  # noqa: BLE001
                    # Best-effort: the API still serves without a database.
                    logger.warning("Database init skipped: %s", e)

            async def langgraph_warmup() -> None:
                """Compile + dry-run off the critical path — importing LangGraph can take minutes on cold start."""
                await asyncio.sleep(0)
                if os.getenv("SKIP_LANGGRAPH_WARMUP", "").lower() in {"1", "true", "yes"}:
                    logger.info("SKIP_LANGGRAPH_WARMUP set — skipping LangGraph compile dry-run")
                    return
                try:
                    from workflows.langgraph_flow import build_soc_graph  # defer LangGraph import

                    soc_graph = build_soc_graph({"enrich": _noop, "detect": _noop, "correlate": _noop})
                    if soc_graph:
                        timeout = float(os.getenv("LANGGRAPH_WARMUP_TIMEOUT_SEC", "120"))
                        await asyncio.wait_for(
                            soc_graph.ainvoke({"notes": [], "bootstrap": True}),
                            timeout=timeout,
                        )
                        logger.info("LangGraph SOC workflow compiled and dry-run complete")
                except asyncio.TimeoutError:
                    logger.warning(
                        "LangGraph dry-run timed out after %ss — API is up; graph may compile on first use",
                        os.getenv("LANGGRAPH_WARMUP_TIMEOUT_SEC", "120"),
                    )
                except Exception as e:  # noqa: BLE001
                    logger.warning("LangGraph dry-run skipped: %s", e)

            _spawn_background_task(langgraph_warmup())

            async def wire_and_run_collectors() -> None:
                """Wire the pipeline off-loop, then start log tails / cloud polling."""
                await asyncio.sleep(0)
                await asyncio.to_thread(_wire_pipeline_and_collector_sync)
                if services.collector is None:
                    return
                services.collector.start_all_tails()
                if os.getenv("ENABLE_MOCK_CLOUD_POLL", "1") == "1":
                    services.collector.start_mock_cloud_poll()

            _spawn_background_task(wire_and_run_collectors())

            async def metrics_tick() -> None:
                """Advance rolling frequency counters once per minute until cancelled."""
                while True:
                    await asyncio.sleep(60)
                    metrics.tick_frequency()

            _spawn_background_task(metrics_tick())
            logger.info("Background SOC wiring scheduled (Redis, DB, LangGraph, collectors)")
        except Exception:
            logger.exception("Background startup failed")

    _spawn_background_task(background_startup())
    logger.info(
        "HTTP layer ready — GET /health while Redis, PostgreSQL, LangGraph, and collectors initialize in the background"
    )
    yield
    # Shutdown: stop collectors, then cancel any still-running background
    # tasks (metrics_tick in particular loops forever and would otherwise
    # outlive the application on reload).
    if services.collector is not None:
        services.collector.stop()
    for task in list(_background_tasks):
        task.cancel()


app = FastAPI(title="SentinelAI SOC API", version="1.0.0", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): the wildcard default combined with allow_credentials=True
    # is rejected by browsers/Starlette — set CORS_ORIGINS explicitly in prod.
    allow_origins=os.getenv("CORS_ORIGINS", "*").split(","),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the exported Next.js dashboard when the static build is present.
_UI_STATIC = ROOT / "frontend" / "out"
if _UI_STATIC.is_dir():
    app.mount("/ui", StaticFiles(directory=str(_UI_STATIC), html=True), name="ui")


async def get_session():
    """Yield an async DB session, or None when SKIP_DB disables PostgreSQL."""
    if os.getenv("SKIP_DB", "").lower() in {"1", "true", "yes"}:
        yield None
        return
    from database.session import async_session_factory  # defer heavy SQLAlchemy/asyncpg import

    async with async_session_factory() as session:
        yield session


@app.post("/ingest-logs")
async def ingest_logs(body: RawLogIngest, pipeline: PipelineDep, session: Any = Depends(get_session)):
    """Feed raw log lines into the SOC pipeline (persisted when DB is up)."""
    return await pipeline.ingest(body, session)


@app.websocket("/live-events")
async def live_events(ws: WebSocket) -> None:
    """Stream live SOC frames; emits a heartbeat when the client goes quiet."""
    await hub.connect(ws)
    try:
        # Prime the new client with up to 80 buffered feed rows.
        for row in list(hub.live_feed)[:80]:
            await ws.send_json(row)
        while True:
            try:
                await asyncio.wait_for(ws.receive_text(), timeout=20.0)
            except asyncio.TimeoutError:
                # No client traffic for 20s — keep the connection alive.
                await ws.send_json({"type": "heartbeat", "ts": time.time()})
    except WebSocketDisconnect:
        pass  # normal client close; cleanup below
    finally:
        # Single disconnect (the original also disconnected in the except
        # branch, calling hub.disconnect twice on a normal close).
        hub.disconnect(ws)


@app.post("/detect-threats")
async def detect_threats(body: RawLogIngest, pipeline: PipelineDep, session: Any = Depends(get_session)):
    """Alias of /ingest-logs — detection runs as part of pipeline ingestion."""
    return await pipeline.ingest(body, session)


@app.post("/correlate-incidents")
async def correlate_incidents(pipeline: PipelineDep):
    """Correlate current events/findings into incidents and return them as JSON."""
    from agents.incident_correlation_agent import correlate

    incidents = correlate(pipeline._events, pipeline._findings)  # noqa: SLF001
    return {"incidents": [i.model_dump(mode="json") for i in incidents]}


@app.post("/generate-summary")
async def generate_summary(body: IncidentActionBody, pipeline: PipelineDep, session: Any = Depends(get_session)):
    """Run the full SOC workflow for one incident and return its payload."""
    return await pipeline.run_full_workflow_on_incident(body.incident_id, session)


@app.post("/remediation")
async def remediation(body: IncidentActionBody, pipeline: PipelineDep, session: Any = Depends(get_session)):
    """Run the full workflow but surface only the remediation section."""
    payload = await pipeline.run_full_workflow_on_incident(body.incident_id, session)
    return {"remediation": payload.get("remediation")}


@app.post("/send-alert")
async def send_alert_endpoint(body: AlertPayload, session: Any = Depends(get_session)):
    """Dispatch an alert and record it in the database when one is available."""
    from agents.alerting_agent import send_alert as _send
    from database.models import AlertRecord

    result = await _send(body)
    if session is not None:
        session.add(
            AlertRecord(
                channel=body.channel,
                title=body.title,
                body=body.body,
                severity=body.severity.value,
            )
        )
        await session.commit()
    return result


@app.get("/dashboard-metrics")
async def dashboard_metrics() -> DashboardMetrics:
    """Return the current metrics snapshot as a typed model."""
    snap = metrics.snapshot()
    return DashboardMetrics(**snap)


@app.get("/rocm-panel")
async def rocm_panel():
    """AMD ROCm story + demo inference/agent load (simulated GPU sway for UI)."""
    return metrics.rocm_panel()


@app.get("/agent-activity")
async def agent_activity():
    """Return the most recent (up to 200) agent-log entries."""
    return {"items": list(hub.agent_log)[:200]}


@app.post("/replay/start")
async def replay_start(body: ReplayStartBody = ReplayStartBody()):
    """Replay buffered threat_feed / detection / incident frames to all WebSocket clients."""
    hub.schedule_replay(delay_ms=body.delay_ms)
    return {"status": "scheduled", "delay_ms": body.delay_ms, "buffered": len(hub.replay_buffer)}


@app.get("/replay-buffer")
async def replay_buffer():
    """Expose the replay buffer contents for inspection."""
    return {"count": len(hub.replay_buffer), "items": list(hub.replay_buffer)}


@app.get("/")
async def root(request: Request):
    """Browsers get the Next dashboard at `/ui` when static export is baked in; API clients keep JSON."""
    accept = request.headers.get("accept") or ""
    if _UI_STATIC.is_dir() and accept.startswith("text/html"):
        return RedirectResponse(url="/ui/", status_code=302)
    return {
        "service": "SentinelAI SOC API",
        "dashboard": "/ui/",
        "docs": "/docs",
        "health": "/health",
        "openapi_json": "/openapi.json",
    }


@app.get("/health")
async def health():
    """Liveness probe — responds before any background wiring completes."""
    return {"status": "ok", "service": "sentinelai"}


@app.get("/workflow-state")
async def workflow_state(pipeline: PipelineDep) -> WorkflowState:
    """Return a bounded snapshot of the pipeline's events/findings/incidents."""
    return WorkflowState(
        events=pipeline._events[-50:],  # noqa: SLF001
        findings=pipeline._findings[-100:],  # noqa: SLF001
        incidents=pipeline._incidents[-20:],  # noqa: SLF001
    )