File size: 10,760 Bytes
8b3905d
 
 
 
 
 
 
 
 
 
 
 
 
 
81fe24b
8b3905d
81fe24b
 
8b3905d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81fe24b
 
 
 
8b3905d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81fe24b
 
 
 
 
8b3905d
 
81fe24b
8b3905d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
"""SentinelAI FastAPI application β€” autonomous SOC control plane."""

from __future__ import annotations

import asyncio
import logging
import os
import sys
import time
from contextlib import asynccontextmanager
import threading
from pathlib import Path
from typing import Annotated, Any

from fastapi import Depends, FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles

# Repo root (two levels above this file) — prepend to sys.path so sibling
# packages (models/, services/, collectors/, ...) import without installation.
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# Load .env from the repo root when python-dotenv is available; the dependency
# is optional, so a missing package is silently tolerated.
try:
    from dotenv import load_dotenv

    load_dotenv(ROOT / ".env")
except ImportError:
    pass

from models.schemas import (  # noqa: E402
    AlertPayload,
    DashboardMetrics,
    IncidentActionBody,
    RawLogIngest,
    ReplayStartBody,
    WorkflowState,
)
from services.event_hub import EventHub  # noqa: E402
from services.metrics_store import MetricsStore  # noqa: E402

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sentinelai.api")

# Process-wide singletons: WebSocket fan-out hub and in-memory metrics store.
hub = EventHub()
metrics = MetricsStore()

# Heavy imports (LangChain, SQLAlchemy models, agents) live inside services.pipeline β€” defer so uvicorn can bind immediately.
# Lock guarding the one-time pipeline/collector wiring (see _wire_pipeline_and_collector_sync).
_wire_lock = threading.Lock()


class _Services:
    __slots__ = ("pipeline", "collector")

    def __init__(self) -> None:
        self.pipeline: Any = None
        self.collector: Any = None


services = _Services()


def _wire_pipeline_and_collector_sync() -> None:
    """Construct the pipeline and collector exactly once; safe from any thread.

    Uses a check / lock / re-check sequence so concurrent callers neither
    double-construct nor serialize on the fast path once wiring is done.
    """
    # Fast path: already wired — skip the lock entirely.
    if services.pipeline is not None:
        return
    with _wire_lock:
        # Re-check under the lock: another thread may have wired while we waited.
        if services.pipeline is not None:
            return
        from collectors.collector_agent import CollectorAgent  # noqa: E402
        from services.pipeline import SentinelPipeline  # noqa: E402

        logger.info("Loading SentinelPipeline module (first load can take 10–40s on cold start)")
        pipeline = SentinelPipeline(hub, metrics)
        services.pipeline = pipeline
        services.collector = CollectorAgent(pipeline.ingest_from_collector)


async def get_pipeline_dep() -> Any:
    """FastAPI dependency: return the SOC pipeline, wiring it lazily off-thread."""
    if services.pipeline is not None:
        return services.pipeline
    # First request pays the wiring cost in a worker thread so the event loop stays responsive.
    await asyncio.to_thread(_wire_pipeline_and_collector_sync)
    return services.pipeline


# Annotated alias so route handlers can declare `pipeline: PipelineDep`.
PipelineDep = Annotated[Any, Depends(get_pipeline_dep)]


# Do-nothing async callback: stands in for real agent handlers during the
# LangGraph warmup dry-run (see lifespan).
async def _noop(_: dict) -> None:
    return None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Yield immediately so Uvicorn finishes startup and accepts HTTP (avoids browser ERR_CONNECTION_TIMED_OUT).

    Redis/DB/LangGraph/pipeline wiring run in the background β€” /health works before collectors attach.
    """
    # The event loop keeps only weak references to tasks, so a fire-and-forget
    # asyncio.create_task() result can be garbage-collected mid-run. Hold strong
    # refs here and drop each one when its task finishes.
    background_tasks: set[asyncio.Task] = set()

    def _spawn(coro) -> None:
        """Schedule *coro* as a task with a strong reference until it completes."""
        task = asyncio.create_task(coro)
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    async def background_startup() -> None:
        try:
            await metrics.connect_redis()
            if os.getenv("SKIP_DB", "").lower() in {"1", "true", "yes"}:
                logger.info("SKIP_DB set β€” skipping PostgreSQL init")
            else:
                from database.session import init_db  # defer heavy SQLAlchemy/asyncpg import

                try:
                    await init_db()
                    logger.info("PostgreSQL schema ready")
                except Exception as e:  # noqa: BLE001
                    # Best-effort: the API stays up without a database.
                    logger.warning("Database init skipped: %s", e)

            async def langgraph_warmup() -> None:
                """Compile + dry-run off the critical path β€” importing LangGraph can take minutes on cold start."""
                await asyncio.sleep(0)
                if os.getenv("SKIP_LANGGRAPH_WARMUP", "").lower() in {"1", "true", "yes"}:
                    logger.info("SKIP_LANGGRAPH_WARMUP set β€” skipping LangGraph compile dry-run")
                    return
                try:
                    from workflows.langgraph_flow import build_soc_graph  # defer LangGraph import

                    soc_graph = build_soc_graph({"enrich": _noop, "detect": _noop, "correlate": _noop})
                    if soc_graph:
                        timeout = float(os.getenv("LANGGRAPH_WARMUP_TIMEOUT_SEC", "120"))
                        await asyncio.wait_for(
                            soc_graph.ainvoke({"notes": [], "bootstrap": True}),
                            timeout=timeout,
                        )
                        logger.info("LangGraph SOC workflow compiled and dry-run complete")
                except asyncio.TimeoutError:
                    logger.warning(
                        "LangGraph dry-run timed out after %ss β€” API is up; graph may compile on first use",
                        os.getenv("LANGGRAPH_WARMUP_TIMEOUT_SEC", "120"),
                    )
                except Exception as e:  # noqa: BLE001
                    logger.warning("LangGraph dry-run skipped: %s", e)

            _spawn(langgraph_warmup())

            async def wire_and_run_collectors() -> None:
                await asyncio.sleep(0)
                await asyncio.to_thread(_wire_pipeline_and_collector_sync)
                if services.collector is None:
                    return
                services.collector.start_all_tails()
                if os.getenv("ENABLE_MOCK_CLOUD_POLL", "1") == "1":
                    services.collector.start_mock_cloud_poll()

            _spawn(wire_and_run_collectors())

            async def metrics_tick() -> None:
                # Periodic frequency rollup; runs until cancelled at shutdown.
                while True:
                    await asyncio.sleep(60)
                    metrics.tick_frequency()

            _spawn(metrics_tick())
            logger.info("Background SOC wiring scheduled (Redis, DB, LangGraph, collectors)")
        except Exception:
            logger.exception("Background startup failed")

    _spawn(background_startup())
    logger.info(
        "HTTP layer ready β€” GET /health while Redis, PostgreSQL, LangGraph, and collectors initialize in the background"
    )
    yield
    if services.collector is not None:
        services.collector.stop()
    # Cancel still-running background work (notably the metrics_tick loop) so
    # shutdown does not leave orphaned tasks behind.
    for task in list(background_tasks):
        task.cancel()


app = FastAPI(title="SentinelAI SOC API", version="1.0.0", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): the "*" default combined with allow_credentials=True is
    # rejected by browsers per the CORS spec β€” confirm CORS_ORIGINS is set to an
    # explicit origin list in production.
    allow_origins=os.getenv("CORS_ORIGINS", "*").split(","),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the exported Next.js dashboard at /ui when the static build exists.
_UI_STATIC = ROOT / "frontend" / "out"
if _UI_STATIC.is_dir():
    app.mount("/ui", StaticFiles(directory=str(_UI_STATIC), html=True), name="ui")


async def get_session():
    """FastAPI dependency yielding an async DB session, or None when SKIP_DB is set.

    The SQLAlchemy/asyncpg import is deferred so SKIP_DB deployments never pay
    for it.
    """
    skip_db = os.getenv("SKIP_DB", "").lower() in {"1", "true", "yes"}
    if skip_db:
        yield None
        return
    from database.session import async_session_factory  # defer heavy SQLAlchemy/asyncpg import

    async with async_session_factory() as session:
        yield session


@app.post("/ingest-logs")
async def ingest_logs(body: RawLogIngest, pipeline: PipelineDep, session: Any = Depends(get_session)):
    """Feed raw log lines into the SOC pipeline; `session` is None under SKIP_DB."""
    return await pipeline.ingest(body, session)


@app.websocket("/live-events")
async def live_events(ws: WebSocket) -> None:
    """Stream live SOC events over a WebSocket.

    On connect, replays up to 80 buffered live-feed rows, then keeps the socket
    alive by sending a heartbeat frame after each 20s of client silence.
    """
    await hub.connect(ws)
    try:
        for row in list(hub.live_feed)[:80]:
            await ws.send_json(row)
        while True:
            try:
                await asyncio.wait_for(ws.receive_text(), timeout=20.0)
            except asyncio.TimeoutError:
                # No client traffic within the window — send a keepalive frame.
                await ws.send_json({"type": "heartbeat", "ts": time.time()})
    except WebSocketDisconnect:
        # Fix: the original also called hub.disconnect() here, unregistering the
        # socket a second time in `finally`. Cleanup now happens exactly once.
        pass
    finally:
        hub.disconnect(ws)


@app.post("/detect-threats")
async def detect_threats(body: RawLogIngest, pipeline: PipelineDep, session: Any = Depends(get_session)):
    """Alias of /ingest-logs: runs the identical pipeline.ingest() path."""
    return await pipeline.ingest(body, session)


@app.post("/correlate-incidents")
async def correlate_incidents(pipeline: PipelineDep):
    """Run the correlation agent over the pipeline's buffered events and findings.

    Reaches into pipeline private state (noqa SLF001); a public accessor on
    SentinelPipeline would be cleaner.
    """
    from agents.incident_correlation_agent import correlate

    incidents = correlate(pipeline._events, pipeline._findings)  # noqa: SLF001
    return {"incidents": [i.model_dump(mode="json") for i in incidents]}


@app.post("/generate-summary")
async def generate_summary(body: IncidentActionBody, pipeline: PipelineDep, session: Any = Depends(get_session)):
    """Run the full agent workflow on one incident and return its payload (summary included)."""
    return await pipeline.run_full_workflow_on_incident(body.incident_id, session)


@app.post("/remediation")
async def remediation(body: IncidentActionBody, pipeline: PipelineDep, session: Any = Depends(get_session)):
    """Run the full workflow on one incident and return only its remediation section."""
    payload = await pipeline.run_full_workflow_on_incident(body.incident_id, session)
    return {"remediation": payload.get("remediation")}


@app.post("/send-alert")
async def send_alert_endpoint(body: AlertPayload, session: Any = Depends(get_session)):
    """Dispatch an alert via the alerting agent; persist an audit record when a DB session exists."""
    from agents.alerting_agent import send_alert as _send
    from database.models import AlertRecord

    result = await _send(body)
    # session is None under SKIP_DB — alert is sent but not recorded.
    if session is not None:
        session.add(
            AlertRecord(
                channel=body.channel,
                title=body.title,
                body=body.body,
                severity=body.severity.value,
            )
        )
        await session.commit()
    return result


@app.get("/dashboard-metrics")
async def dashboard_metrics() -> DashboardMetrics:
    """Return the current metrics snapshot shaped as a DashboardMetrics model."""
    return DashboardMetrics(**metrics.snapshot())


@app.get("/rocm-panel")
async def rocm_panel():
    """AMD ROCm story + demo inference/agent load (simulated GPU sway for UI).

    Delegates entirely to MetricsStore.rocm_panel().
    """
    return metrics.rocm_panel()


@app.get("/agent-activity")
async def agent_activity():
    """Return up to the first 200 entries of the hub's agent activity log."""
    return {"items": list(hub.agent_log)[:200]}


@app.post("/replay/start")
async def replay_start(body: ReplayStartBody = ReplayStartBody()):
    """Replay buffered threat_feed / detection / incident frames to all WebSocket clients.

    The replay is scheduled asynchronously on the hub; the response reports the
    configured delay and how many frames are buffered.
    """
    hub.schedule_replay(delay_ms=body.delay_ms)
    return {"status": "scheduled", "delay_ms": body.delay_ms, "buffered": len(hub.replay_buffer)}


@app.get("/replay-buffer")
async def replay_buffer():
    """Expose the hub's replay buffer contents and count (debug/inspection aid)."""
    return {"count": len(hub.replay_buffer), "items": list(hub.replay_buffer)}


@app.get("/")
async def root(request: Request):
    """Browsers get the Next dashboard at `/ui` when static export is baked in; API clients keep JSON."""
    accept_header = request.headers.get("accept") or ""
    wants_html = accept_header.startswith("text/html")
    if wants_html and _UI_STATIC.is_dir():
        return RedirectResponse(url="/ui/", status_code=302)
    return {
        "service": "SentinelAI SOC API",
        "dashboard": "/ui/",
        "docs": "/docs",
        "health": "/health",
        "openapi_json": "/openapi.json",
    }


@app.get("/health")
async def health():
    """Liveness probe — responds immediately, before background wiring completes."""
    return {"status": "ok", "service": "sentinelai"}


@app.get("/workflow-state")
async def workflow_state(pipeline: PipelineDep) -> WorkflowState:
    """Snapshot of recent pipeline state: last 50 events, 100 findings, 20 incidents.

    Reads pipeline private buffers directly (noqa SLF001).
    """
    return WorkflowState(
        events=pipeline._events[-50:],  # noqa: SLF001
        findings=pipeline._findings[-100:],  # noqa: SLF001
        incidents=pipeline._incidents[-20:],  # noqa: SLF001
    )