# SentinelAI / backend / app / main.py
# (header residue from upload page: "iitian's picture", commit 81fe24b)
# Serve Next.js SOC dashboard at /ui with FastAPI redirect from /.
"""SentinelAI FastAPI application — autonomous SOC control plane."""
from __future__ import annotations
import asyncio
import logging
import os
import sys
import time
from contextlib import asynccontextmanager
import threading
from pathlib import Path
from typing import Annotated, Any
from fastapi import Depends, FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
try:
from dotenv import load_dotenv
load_dotenv(ROOT / ".env")
except ImportError:
pass
from models.schemas import ( # noqa: E402
AlertPayload,
DashboardMetrics,
IncidentActionBody,
RawLogIngest,
ReplayStartBody,
WorkflowState,
)
from services.event_hub import EventHub # noqa: E402
from services.metrics_store import MetricsStore # noqa: E402
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sentinelai.api")
hub = EventHub()
metrics = MetricsStore()
# Heavy imports (LangChain, SQLAlchemy models, agents) live inside services.pipeline — defer so uvicorn can bind immediately.
_wire_lock = threading.Lock()
class _Services:
__slots__ = ("pipeline", "collector")
def __init__(self) -> None:
self.pipeline: Any = None
self.collector: Any = None
services = _Services()
def _wire_pipeline_and_collector_sync() -> None:
"""Idempotent; safe across threads."""
if services.pipeline is not None:
return
with _wire_lock:
if services.pipeline is not None:
return
from collectors.collector_agent import CollectorAgent # noqa: E402
from services.pipeline import SentinelPipeline # noqa: E402
logger.info("Loading SentinelPipeline module (first load can take 10–40s on cold start)")
services.pipeline = SentinelPipeline(hub, metrics)
services.collector = CollectorAgent(services.pipeline.ingest_from_collector)
async def get_pipeline_dep() -> Any:
"""Dependency for routes that need the SOC pipeline."""
if services.pipeline is None:
await asyncio.to_thread(_wire_pipeline_and_collector_sync)
return services.pipeline
PipelineDep = Annotated[Any, Depends(get_pipeline_dep)]
async def _noop(_: dict) -> None:
return None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Yield immediately so Uvicorn finishes startup and accepts HTTP (avoids browser ERR_CONNECTION_TIMED_OUT).
Redis/DB/LangGraph/pipeline wiring run in the background — /health works before collectors attach.
"""
async def background_startup() -> None:
try:
await metrics.connect_redis()
if os.getenv("SKIP_DB", "").lower() in {"1", "true", "yes"}:
logger.info("SKIP_DB set — skipping PostgreSQL init")
else:
from database.session import init_db # defer heavy SQLAlchemy/asyncpg import
try:
await init_db()
logger.info("PostgreSQL schema ready")
except Exception as e: # noqa: BLE001
logger.warning("Database init skipped: %s", e)
async def langgraph_warmup() -> None:
"""Compile + dry-run off the critical path — importing LangGraph can take minutes on cold start."""
await asyncio.sleep(0)
if os.getenv("SKIP_LANGGRAPH_WARMUP", "").lower() in {"1", "true", "yes"}:
logger.info("SKIP_LANGGRAPH_WARMUP set — skipping LangGraph compile dry-run")
return
try:
from workflows.langgraph_flow import build_soc_graph # defer LangGraph import
soc_graph = build_soc_graph({"enrich": _noop, "detect": _noop, "correlate": _noop})
if soc_graph:
timeout = float(os.getenv("LANGGRAPH_WARMUP_TIMEOUT_SEC", "120"))
await asyncio.wait_for(
soc_graph.ainvoke({"notes": [], "bootstrap": True}),
timeout=timeout,
)
logger.info("LangGraph SOC workflow compiled and dry-run complete")
except asyncio.TimeoutError:
logger.warning(
"LangGraph dry-run timed out after %ss — API is up; graph may compile on first use",
os.getenv("LANGGRAPH_WARMUP_TIMEOUT_SEC", "120"),
)
except Exception as e: # noqa: BLE001
logger.warning("LangGraph dry-run skipped: %s", e)
asyncio.create_task(langgraph_warmup())
async def wire_and_run_collectors() -> None:
await asyncio.sleep(0)
await asyncio.to_thread(_wire_pipeline_and_collector_sync)
if services.collector is None:
return
services.collector.start_all_tails()
if os.getenv("ENABLE_MOCK_CLOUD_POLL", "1") == "1":
services.collector.start_mock_cloud_poll()
asyncio.create_task(wire_and_run_collectors())
async def metrics_tick() -> None:
while True:
await asyncio.sleep(60)
metrics.tick_frequency()
asyncio.create_task(metrics_tick())
logger.info("Background SOC wiring scheduled (Redis, DB, LangGraph, collectors)")
except Exception:
logger.exception("Background startup failed")
asyncio.create_task(background_startup())
logger.info(
"HTTP layer ready — GET /health while Redis, PostgreSQL, LangGraph, and collectors initialize in the background"
)
yield
if services.collector is not None:
services.collector.stop()
app = FastAPI(title="SentinelAI SOC API", version="1.0.0", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=os.getenv("CORS_ORIGINS", "*").split(","),
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
_UI_STATIC = ROOT / "frontend" / "out"
if _UI_STATIC.is_dir():
app.mount("/ui", StaticFiles(directory=str(_UI_STATIC), html=True), name="ui")
async def get_session():
if os.getenv("SKIP_DB", "").lower() in {"1", "true", "yes"}:
yield None
return
from database.session import async_session_factory # defer heavy SQLAlchemy/asyncpg import
async with async_session_factory() as session:
yield session
@app.post("/ingest-logs")
async def ingest_logs(body: RawLogIngest, pipeline: PipelineDep, session: Any = Depends(get_session)):
return await pipeline.ingest(body, session)
@app.websocket("/live-events")
async def live_events(ws: WebSocket) -> None:
await hub.connect(ws)
try:
for row in list(hub.live_feed)[:80]:
await ws.send_json(row)
while True:
try:
await asyncio.wait_for(ws.receive_text(), timeout=20.0)
except asyncio.TimeoutError:
await ws.send_json({"type": "heartbeat", "ts": time.time()})
except WebSocketDisconnect:
hub.disconnect(ws)
finally:
hub.disconnect(ws)
@app.post("/detect-threats")
async def detect_threats(body: RawLogIngest, pipeline: PipelineDep, session: Any = Depends(get_session)):
return await pipeline.ingest(body, session)
@app.post("/correlate-incidents")
async def correlate_incidents(pipeline: PipelineDep):
from agents.incident_correlation_agent import correlate
incidents = correlate(pipeline._events, pipeline._findings) # noqa: SLF001
return {"incidents": [i.model_dump(mode="json") for i in incidents]}
@app.post("/generate-summary")
async def generate_summary(body: IncidentActionBody, pipeline: PipelineDep, session: Any = Depends(get_session)):
return await pipeline.run_full_workflow_on_incident(body.incident_id, session)
@app.post("/remediation")
async def remediation(body: IncidentActionBody, pipeline: PipelineDep, session: Any = Depends(get_session)):
payload = await pipeline.run_full_workflow_on_incident(body.incident_id, session)
return {"remediation": payload.get("remediation")}
@app.post("/send-alert")
async def send_alert_endpoint(body: AlertPayload, session: Any = Depends(get_session)):
from agents.alerting_agent import send_alert as _send
from database.models import AlertRecord
result = await _send(body)
if session is not None:
session.add(
AlertRecord(
channel=body.channel,
title=body.title,
body=body.body,
severity=body.severity.value,
)
)
await session.commit()
return result
@app.get("/dashboard-metrics")
async def dashboard_metrics() -> DashboardMetrics:
snap = metrics.snapshot()
return DashboardMetrics(**snap)
@app.get("/rocm-panel")
async def rocm_panel():
"""AMD ROCm story + demo inference/agent load (simulated GPU sway for UI)."""
return metrics.rocm_panel()
@app.get("/agent-activity")
async def agent_activity():
return {"items": list(hub.agent_log)[:200]}
@app.post("/replay/start")
async def replay_start(body: ReplayStartBody = ReplayStartBody()):
"""Replay buffered threat_feed / detection / incident frames to all WebSocket clients."""
hub.schedule_replay(delay_ms=body.delay_ms)
return {"status": "scheduled", "delay_ms": body.delay_ms, "buffered": len(hub.replay_buffer)}
@app.get("/replay-buffer")
async def replay_buffer():
return {"count": len(hub.replay_buffer), "items": list(hub.replay_buffer)}
@app.get("/")
async def root(request: Request):
"""Browsers get the Next dashboard at `/ui` when static export is baked in; API clients keep JSON."""
accept = request.headers.get("accept") or ""
if _UI_STATIC.is_dir() and accept.startswith("text/html"):
return RedirectResponse(url="/ui/", status_code=302)
return {
"service": "SentinelAI SOC API",
"dashboard": "/ui/",
"docs": "/docs",
"health": "/health",
"openapi_json": "/openapi.json",
}
@app.get("/health")
async def health():
return {"status": "ok", "service": "sentinelai"}
@app.get("/workflow-state")
async def workflow_state(pipeline: PipelineDep) -> WorkflowState:
return WorkflowState(
events=pipeline._events[-50:], # noqa: SLF001
findings=pipeline._findings[-100:], # noqa: SLF001
incidents=pipeline._incidents[-20:], # noqa: SLF001
)