| """SentinelAI FastAPI application — autonomous SOC control plane.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| import os | |
| import sys | |
| import time | |
| from contextlib import asynccontextmanager | |
| import threading | |
| from pathlib import Path | |
| from typing import Annotated, Any | |
| from fastapi import Depends, FastAPI, Request, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import RedirectResponse | |
| from fastapi.staticfiles import StaticFiles | |

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

try:
    from dotenv import load_dotenv

    load_dotenv(ROOT / ".env")
except ImportError:
    pass
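
# Example .env (a sketch; variable names are the ones read via os.getenv in
# this module, values are illustrative):
#     SKIP_DB=0
#     SKIP_LANGGRAPH_WARMUP=0
#     LANGGRAPH_WARMUP_TIMEOUT_SEC=120
#     ENABLE_MOCK_CLOUD_POLL=1
#     CORS_ORIGINS=http://localhost:3000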

from models.schemas import (  # noqa: E402
    AlertPayload,
    DashboardMetrics,
    IncidentActionBody,
    RawLogIngest,
    ReplayStartBody,
    WorkflowState,
)
from services.event_hub import EventHub  # noqa: E402
from services.metrics_store import MetricsStore  # noqa: E402

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sentinelai.api")

hub = EventHub()
metrics = MetricsStore()

# Heavy imports (LangChain, SQLAlchemy models, agents) live inside
# services.pipeline — defer so uvicorn can bind immediately.
_wire_lock = threading.Lock()


class _Services:
    __slots__ = ("pipeline", "collector")

    def __init__(self) -> None:
        self.pipeline: Any = None
        self.collector: Any = None


services = _Services()


def _wire_pipeline_and_collector_sync() -> None:
    """Idempotent; safe across threads."""
    if services.pipeline is not None:
        return
    with _wire_lock:
        if services.pipeline is not None:
            return
        from collectors.collector_agent import CollectorAgent
        from services.pipeline import SentinelPipeline

        logger.info("Loading SentinelPipeline module (first load can take 10–40s on cold start)")
        services.pipeline = SentinelPipeline(hub, metrics)
        services.collector = CollectorAgent(services.pipeline.ingest_from_collector)


async def get_pipeline_dep() -> Any:
    """Dependency for routes that need the SOC pipeline."""
    if services.pipeline is None:
        await asyncio.to_thread(_wire_pipeline_and_collector_sync)
    return services.pipeline


PipelineDep = Annotated[Any, Depends(get_pipeline_dep)]


async def _noop(_: dict) -> None:
    return None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Yield immediately so Uvicorn finishes startup and accepts HTTP
    (avoids browser ERR_CONNECTION_TIMED_OUT). Redis/DB/LangGraph/pipeline
    wiring runs in the background — /health works before collectors attach.
    """

    async def background_startup() -> None:
        try:
            await metrics.connect_redis()
            if os.getenv("SKIP_DB", "").lower() in {"1", "true", "yes"}:
                logger.info("SKIP_DB set — skipping PostgreSQL init")
            else:
                from database.session import init_db  # defer heavy SQLAlchemy/asyncpg import

                try:
                    await init_db()
                    logger.info("PostgreSQL schema ready")
                except Exception as e:  # noqa: BLE001
                    logger.warning("Database init skipped: %s", e)

            async def langgraph_warmup() -> None:
                """Compile + dry-run off the critical path — importing LangGraph can take minutes on cold start."""
                await asyncio.sleep(0)
                if os.getenv("SKIP_LANGGRAPH_WARMUP", "").lower() in {"1", "true", "yes"}:
                    logger.info("SKIP_LANGGRAPH_WARMUP set — skipping LangGraph compile dry-run")
                    return
                try:
                    from workflows.langgraph_flow import build_soc_graph  # defer LangGraph import

                    soc_graph = build_soc_graph({"enrich": _noop, "detect": _noop, "correlate": _noop})
                    if soc_graph:
                        timeout = float(os.getenv("LANGGRAPH_WARMUP_TIMEOUT_SEC", "120"))
                        await asyncio.wait_for(
                            soc_graph.ainvoke({"notes": [], "bootstrap": True}),
                            timeout=timeout,
                        )
                        logger.info("LangGraph SOC workflow compiled and dry-run complete")
                except asyncio.TimeoutError:
                    logger.warning(
                        "LangGraph dry-run timed out after %ss — API is up; graph may compile on first use",
                        os.getenv("LANGGRAPH_WARMUP_TIMEOUT_SEC", "120"),
                    )
                except Exception as e:  # noqa: BLE001
                    logger.warning("LangGraph dry-run skipped: %s", e)

            asyncio.create_task(langgraph_warmup())

            async def wire_and_run_collectors() -> None:
                await asyncio.sleep(0)
                await asyncio.to_thread(_wire_pipeline_and_collector_sync)
                if services.collector is None:
                    return
                services.collector.start_all_tails()
                if os.getenv("ENABLE_MOCK_CLOUD_POLL", "1") == "1":
                    services.collector.start_mock_cloud_poll()

            asyncio.create_task(wire_and_run_collectors())

            async def metrics_tick() -> None:
                while True:
                    await asyncio.sleep(60)
                    metrics.tick_frequency()

            asyncio.create_task(metrics_tick())
            logger.info("Background SOC wiring scheduled (Redis, DB, LangGraph, collectors)")
        except Exception:
            logger.exception("Background startup failed")

    asyncio.create_task(background_startup())
    logger.info(
        "HTTP layer ready — GET /health while Redis, PostgreSQL, LangGraph, and collectors initialize in the background"
    )
    yield
    if services.collector is not None:
        services.collector.stop()


app = FastAPI(title="SentinelAI SOC API", version="1.0.0", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=os.getenv("CORS_ORIGINS", "*").split(","),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

_UI_STATIC = ROOT / "frontend" / "out"
if _UI_STATIC.is_dir():
    app.mount("/ui", StaticFiles(directory=str(_UI_STATIC), html=True), name="ui")


async def get_session():
    """FastAPI yield-dependency: yields None when SKIP_DB is set so DB-optional
    routes still work; callers guard with `if session is not None`."""
    if os.getenv("SKIP_DB", "").lower() in {"1", "true", "yes"}:
        yield None
        return
    from database.session import async_session_factory  # defer heavy SQLAlchemy/asyncpg import

    async with async_session_factory() as session:
        yield session


# NOTE: the route paths below are assumptions inferred from the handler names;
# only "/" and "/health" are confirmed elsewhere in this module.
@app.post("/ingest")
async def ingest_logs(body: RawLogIngest, pipeline: PipelineDep, session: Any = Depends(get_session)):
    return await pipeline.ingest(body, session)


@app.websocket("/ws/live")
async def live_events(ws: WebSocket) -> None:
    await hub.connect(ws)
    try:
        # Send a snapshot of buffered feed rows, then keep the socket alive
        # with heartbeats while waiting for client messages.
        for row in list(hub.live_feed)[:80]:
            await ws.send_json(row)
        while True:
            try:
                await asyncio.wait_for(ws.receive_text(), timeout=20.0)
            except asyncio.TimeoutError:
                await ws.send_json({"type": "heartbeat", "ts": time.time()})
    except WebSocketDisconnect:
        pass  # the finally clause below already unregisters the socket
    finally:
        hub.disconnect(ws)
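

# Minimal client sketch (assumes the third-party `websockets` package and the
# assumed /ws/live path above; the server address is illustrative):
#
#     import asyncio
#     import websockets
#
#     async def tail_live_feed() -> None:
#         async with websockets.connect("ws://localhost:8000/ws/live") as ws:
#             while True:
#                 print(await ws.recv())  # feed rows, then periodic heartbeats
#
#     asyncio.run(tail_live_feed())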


@app.post("/detect")
async def detect_threats(body: RawLogIngest, pipeline: PipelineDep, session: Any = Depends(get_session)):
    return await pipeline.ingest(body, session)


@app.post("/correlate")
async def correlate_incidents(pipeline: PipelineDep):
    from agents.incident_correlation_agent import correlate

    incidents = correlate(pipeline._events, pipeline._findings)  # noqa: SLF001
    return {"incidents": [i.model_dump(mode="json") for i in incidents]}


@app.post("/summary")
async def generate_summary(body: IncidentActionBody, pipeline: PipelineDep, session: Any = Depends(get_session)):
    return await pipeline.run_full_workflow_on_incident(body.incident_id, session)


@app.post("/remediate")
async def remediation(body: IncidentActionBody, pipeline: PipelineDep, session: Any = Depends(get_session)):
    payload = await pipeline.run_full_workflow_on_incident(body.incident_id, session)
    return {"remediation": payload.get("remediation")}


@app.post("/alerts/send")
async def send_alert_endpoint(body: AlertPayload, session: Any = Depends(get_session)):
    from agents.alerting_agent import send_alert as _send
    from database.models import AlertRecord

    result = await _send(body)
    if session is not None:
        session.add(
            AlertRecord(
                channel=body.channel,
                title=body.title,
                body=body.body,
                severity=body.severity.value,
            )
        )
        await session.commit()
    return result


@app.get("/metrics/dashboard")
async def dashboard_metrics() -> DashboardMetrics:
    snap = metrics.snapshot()
    return DashboardMetrics(**snap)


@app.get("/metrics/rocm")
async def rocm_panel():
    """AMD ROCm story + demo inference/agent load (simulated GPU sway for the UI)."""
    return metrics.rocm_panel()


@app.get("/agents/activity")
async def agent_activity():
    return {"items": list(hub.agent_log)[:200]}


@app.post("/replay/start")
async def replay_start(body: ReplayStartBody = ReplayStartBody()):
    """Replay buffered threat_feed / detection / incident frames to all WebSocket clients."""
    hub.schedule_replay(delay_ms=body.delay_ms)
    return {"status": "scheduled", "delay_ms": body.delay_ms, "buffered": len(hub.replay_buffer)}


@app.get("/replay/buffer")
async def replay_buffer():
    return {"count": len(hub.replay_buffer), "items": list(hub.replay_buffer)}


@app.get("/")
async def root(request: Request):
    """Browsers get the Next dashboard at `/ui` when the static export is baked in; API clients get JSON."""
    accept = request.headers.get("accept") or ""
    if _UI_STATIC.is_dir() and accept.startswith("text/html"):
        return RedirectResponse(url="/ui/", status_code=302)
    return {
        "service": "SentinelAI SOC API",
        "dashboard": "/ui/",
        "docs": "/docs",
        "health": "/health",
        "openapi_json": "/openapi.json",
    }


@app.get("/health")
async def health():
    return {"status": "ok", "service": "sentinelai"}


@app.get("/workflow/state")
async def workflow_state(pipeline: PipelineDep) -> WorkflowState:
    return WorkflowState(
        events=pipeline._events[-50:],  # noqa: SLF001
        findings=pipeline._findings[-100:],  # noqa: SLF001
        incidents=pipeline._incidents[-20:],  # noqa: SLF001
    )
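

# Local dev entry point (a sketch): the `uvicorn module:app` import path depends
# on where this file lives, so run the file directly instead. PORT is an assumed
# environment variable; 8000 is the uvicorn default.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))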