"""
main.py — FastAPI WebSocket hub with Redis-backed claim caching.

Architecture:
    Browser extension → WS connection → ConnectionManager
        ↓
    Redis cache check (xxhash key)
        ↓ miss
    Gatekeeper (Groq)
        ↓ fact
    RAG pipeline + Trust graph
        ↓
    Prefect multi-agent flow
        ↓
    AnalysisResult → WS push to extension
"""
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import hashlib |
| import json |
| import os |
| import time |
| from contextlib import asynccontextmanager |
| from typing import Any |
|
|
| import orjson |
| import redis.asyncio as aioredis |
| import structlog |
| import xxhash |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import HTMLResponse |
| from pydantic import BaseModel, Field |
| from pydantic_settings import BaseSettings |
|
|
| from agents import run_analysis_flow |
| from gatekeeper import classify_claim |
| from rag_pipeline import build_rag_context |
|
|
| |
# Configure structlog once at import time: plain-text console output (no
# ANSI colours), ISO timestamps, and a level filter that drops anything
# below INFO.
structlog.configure(
    processors=[
        # Processors run in order; the renderer must come last.
        structlog.stdlib.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.dev.ConsoleRenderer(colors=False),
    ],
    # 20 is the numeric value of logging.INFO.
    wrapper_class=structlog.make_filtering_bound_logger(min_level=20),
    context_class=dict,
    logger_factory=structlog.PrintLoggerFactory(),
)

# Module-wide logger used by every component below.
log = structlog.get_logger(__name__)
|
|
| |
class Settings(BaseSettings):
    """Application configuration loaded by pydantic-settings.

    Each field is read from the environment variable of the same name,
    matched case-insensitively (e.g. ``GROQ_API_KEY`` -> ``groq_api_key``).
    Values in ``.env`` come next, and the defaults declared here are the
    final fallback. Reading env vars is BaseSettings' own job. The former
    ``os.getenv(...)`` defaults duplicated it, froze values at import time,
    and parsed ``MEMGRAPH_PORT`` with a bare ``int()`` outside validation.
    """

    # API credentials; an empty string means "not configured".
    groq_api_key: str = ""
    anthropic_api_key: str = ""
    openai_api_key: str = ""
    x_bearer_token: str = ""

    # Backing-service locations.
    redis_url: str = "redis://localhost:6379"
    qdrant_url: str = "http://localhost:6333"
    memgraph_host: str = "localhost"
    memgraph_port: int = 7687
    redpanda_brokers: str = "localhost:9092"

    # Redis cache lifetimes in seconds, chosen per verdict colour.
    cache_ttl_green_red: int = 21600  # 6 hours
    cache_ttl_yellow: int = 900  # 15 minutes

    # pydantic v2 configuration (the nested ``class Config`` form is
    # deprecated). SettingsConfigDict is a TypedDict, so this plain dict is
    # exactly what SettingsConfigDict(env_file=".env") builds at runtime.
    model_config = {"env_file": ".env"}


settings = Settings()
|
|
| |
# Shared Redis client, published only after a successful PING.
redis_client: aioredis.Redis | None = None
# time.monotonic() of the most recent failed connection attempt, or None if
# the last attempt succeeded / none has failed yet.
_redis_last_failure: float | None = None
# Seconds to wait after a failed attempt before trying to connect again.
_REDIS_RETRY_INTERVAL_S = 10.0


async def get_redis() -> aioredis.Redis | None:
    """Return the shared Redis client, connecting lazily on first use.

    Redis is optional. If it cannot be reached, a warning is logged and
    ``None`` is returned, so callers fall back to the uncached path. After a
    failure, new attempts are skipped for ``_REDIS_RETRY_INTERVAL_S`` seconds.
    Otherwise every cache lookup would pay the 2-second connect timeout while
    Redis is down.
    """
    global redis_client, _redis_last_failure

    if redis_client is not None:
        return redis_client
    if (
        _redis_last_failure is not None
        and time.monotonic() - _redis_last_failure < _REDIS_RETRY_INTERVAL_S
    ):
        return None

    client: aioredis.Redis | None = None
    try:
        client = aioredis.from_url(
            settings.redis_url,
            encoding="utf-8",
            decode_responses=True,
            socket_connect_timeout=2,
        )
        await client.ping()
    except Exception as exc:
        _redis_last_failure = time.monotonic()
        log.warning("redis.unavailable", error=str(exc))
        if client is not None:
            # Release the pool created by from_url(); it is never published.
            # Best-effort: a failure here must not mask the original error.
            try:
                await client.aclose()
            except Exception:
                pass
        return None

    if redis_client is not None:
        # Another coroutine connected while we awaited ping(); keep theirs.
        await client.aclose()
        return redis_client

    redis_client = client
    _redis_last_failure = None
    log.info("redis.connected", url=settings.redis_url)
    return redis_client
|
|
|
|
def claim_cache_key(claim_hash: str) -> str:
    """Build the Redis key under which the result for *claim_hash* is stored.

    The ``v1`` segment versions the cached payload layout. Bumping it
    orphans old entries instead of reading them back.
    """
    namespace, schema_version = "fact", "v1"
    return f"{namespace}:{schema_version}:{claim_hash}"
|
|
|
|
async def cache_get(claim_hash: str) -> dict[str, Any] | None:
    """Fetch the cached analysis result for *claim_hash*, if any.

    Best-effort: returns ``None`` on a miss, when Redis is unavailable, or
    when the stored value cannot be read or decoded. Read/decode failures are
    logged (matching ``cache_set``'s handling of write failures) rather than
    swallowed silently. A stored value that is valid JSON but not an object
    is treated as a miss.
    """
    try:
        r = await get_redis()
        if r is None:
            return None
        raw = await r.get(claim_cache_key(claim_hash))
        if not raw:
            return None
        value = orjson.loads(raw)
    except Exception as exc:
        log.warning("cache.get_failed", claim_hash=claim_hash, error=str(exc))
        return None
    return value if isinstance(value, dict) else None
|
|
|
|
async def cache_set(claim_hash: str, result: dict[str, Any]) -> None:
    """Store *result* in Redis with a TTL chosen by its verdict colour.

    "green"/"red" verdicts live for ``settings.cache_ttl_green_red`` seconds,
    "yellow" (the assumed colour when the key is missing) for
    ``settings.cache_ttl_yellow``. Any other colour is not cached. All
    failures are logged and otherwise ignored; caching is best-effort.
    """
    try:
        client = await get_redis()
        if client is None:
            return

        verdict_color = result.get("color", "yellow")
        if verdict_color in ("green", "red"):
            expiry = settings.cache_ttl_green_red
        elif verdict_color == "yellow":
            expiry = settings.cache_ttl_yellow
        else:
            return  # unrecognised colour: never cached

        key = claim_cache_key(claim_hash)
        payload = orjson.dumps(result).decode()
        await client.setex(key, expiry, payload)
    except Exception as exc:
        log.warning("cache.set_failed", error=str(exc))
|
|
|
|
| |
class ConnectionManager:
    """Registry of active WebSocket connections, keyed by client id.

    The registry dict is guarded by an ``asyncio.Lock``. That makes it safe
    for concurrent coroutines on one event loop; it is not a thread lock.
    The lock is held only while reading or mutating the dict, never across
    network I/O. As a result one slow client cannot stall sends to all the
    others, and failure paths can call ``disconnect`` (which takes the lock)
    without deadlocking.
    """

    def __init__(self) -> None:
        self._connections: dict[str, WebSocket] = {}
        self._lock = asyncio.Lock()

    async def connect(self, ws: WebSocket, client_id: str) -> None:
        """Accept the handshake and register *ws* under *client_id*.

        If *client_id* is already registered, the new socket replaces the
        old registry entry.
        """
        await ws.accept()
        async with self._lock:
            self._connections[client_id] = ws
            total = len(self._connections)
        log.info("ws.connected", client_id=client_id, total=total)

    async def disconnect(
        self, client_id: str, ws: WebSocket | None = None
    ) -> None:
        """Unregister *client_id*; a no-op if it is not registered.

        Args:
            client_id: Registry key to remove.
            ws: Optional socket identity check. When given, the entry is
                removed only if it still refers to this exact socket, so
                tearing down a stale connection cannot evict a newer one
                that reused the same client id.
        """
        async with self._lock:
            current = self._connections.get(client_id)
            if current is not None and (ws is None or current is ws):
                del self._connections[client_id]
            total = len(self._connections)
        log.info("ws.disconnected", client_id=client_id, total=total)

    async def send(self, client_id: str, payload: dict[str, Any]) -> None:
        """Send *payload* as a JSON text frame to *client_id*, if connected.

        On failure the error is logged and the socket is unregistered.
        """
        async with self._lock:
            ws = self._connections.get(client_id)
        if ws is None:
            return
        try:
            await ws.send_text(orjson.dumps(payload).decode())
        except Exception as exc:
            log.warning("ws.send_failed", client_id=client_id, error=str(exc))
            # The lock is released here. asyncio.Lock is not re-entrant and
            # disconnect() acquires it; calling it under the lock deadlocks.
            await self.disconnect(client_id, ws)

    async def broadcast(self, payload: dict[str, Any]) -> None:
        """Send *payload* to every registered client concurrently.

        The payload is serialized once for all targets. Clients whose send
        fails are logged and unregistered, as in ``send``.
        """
        async with self._lock:
            targets = list(self._connections.items())
        if not targets:
            return
        message = orjson.dumps(payload).decode()
        outcomes = await asyncio.gather(
            *(ws.send_text(message) for _, ws in targets),
            return_exceptions=True,
        )
        for (client_id, ws), outcome in zip(targets, outcomes):
            if isinstance(outcome, Exception):
                log.warning(
                    "ws.send_failed", client_id=client_id, error=str(outcome)
                )
                await self.disconnect(client_id, ws)

    @property
    def count(self) -> int:
        """Number of currently registered connections."""
        return len(self._connections)


# Process-wide connection registry shared by all WebSocket handlers.
manager = ConnectionManager()
|
|
|
|
| |
class AnalysisBatch(BaseModel):
    """A batch of claims sent by the browser extension in one WebSocket frame."""

    # Identifier supplied by the client inside the payload.
    client_id: str
    # Required; validation rejects batches with fewer than 1 or more than
    # 20 claims.
    claims: list[str] = Field(min_length=1, max_length=20)
    # Source platform label; "web" when omitted.
    platform: str = "web"
    # Client send time as epoch seconds; when omitted, time.time() is taken
    # at validation.
    timestamp: float = Field(default_factory=time.time)
|
|
|
|
class AnalysisResult(BaseModel):
    """Verdict for a single claim, as pushed back to the extension."""

    # Claim identity: xxh64 hex digest of the UTF-8 text, plus the text.
    claim_hash: str
    claim_text: str

    # Verdict payload. The cache layer recognises the colours "green",
    # "red" and "yellow".
    color: str
    confidence: int
    verdict: str
    explanation: str
    sources: list[str]
    trust_score: float

    # Bookkeeping: whether this came from the Redis cache, and elapsed time
    # in milliseconds.
    cached: bool = Field(default=False)
    processing_ms: float = Field(default=0.0)
|
|
|
|
| |
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Application lifespan: warm the Redis connection, then close it on exit.

    Redis being unreachable at startup is not fatal: ``get_redis()`` logs a
    warning and returns ``None``. Cleanup runs in ``finally``, so the client
    is closed even if the application exits via an exception. The global is
    cleared before closing, so nothing can pick up a closed client afterwards.
    """
    global redis_client
    log.info("startup", version="1.0.0")
    await get_redis()
    try:
        yield
    finally:
        log.info("shutdown")
        client, redis_client = redis_client, None
        if client is not None:
            try:
                await client.aclose()
            except Exception as exc:
                log.warning("redis.close_failed", error=str(exc))
|
|
|
|
| |
app = FastAPI(
    title="Fact & Hallucination Intelligence Engine",
    version="1.0.0",
    lifespan=lifespan,
)


# CORS: any origin may call the HTTP endpoints, but without credentials.
# Combining allow_origins=["*"] with allow_credentials=True makes Starlette
# echo the caller's Origin back together with
# Access-Control-Allow-Credentials. That lets any website make credentialed
# cross-origin requests, and FastAPI's docs state credentials require
# explicit origins. No endpoint in this module reads cookies or HTTP auth.
# Note that CORS does not apply to the WebSocket route.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve a static HTML status card listing the WebSocket and health URLs.

    Non-ASCII glyphs are written as HTML character references (``&mdash;``,
    ``&#9889;``) so the page cannot be corrupted by re-encoding this source
    file. The literal characters previously here had been garbled.
    """
    return HTMLResponse("""<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <title>Fact Engine &mdash; Status</title>
  <style>
    body { font-family: 'Courier New', monospace; background: #0a0a0f; color: #00ff88;
           display: flex; align-items: center; justify-content: center; height: 100vh; margin: 0; }
    .card { text-align: center; border: 1px solid #00ff88; padding: 2rem 3rem; }
    h1 { font-size: 1.5rem; letter-spacing: .2em; margin: 0 0 .5rem; }
    p { margin: .25rem 0; color: #88ffcc; font-size: .85rem; }
    .dot { display: inline-block; width: 8px; height: 8px; border-radius: 50%;
           background: #00ff88; margin-right: 8px; animation: pulse 1.5s infinite; }
    @keyframes pulse { 0%,100% { opacity:1 } 50% { opacity:.3 } }
  </style>
</head>
<body>
  <div class="card">
    <h1>&#9889; FACT ENGINE</h1>
    <p><span class="dot"></span>System Online</p>
    <p>WebSocket: <strong>ws://[host]/ws/{client_id}</strong></p>
    <p>Health: <strong>/health</strong></p>
  </div>
</body>
</html>
""")
|
|
|
|
async def _redis_ping_ok() -> bool:
    """Return True if a Redis client is available and answers PING; never raises."""
    try:
        client = await get_redis()
        if not client:
            return False
        await client.ping()
    except Exception:
        return False
    return True


@app.get("/health")
async def health():
    """Health probe: always reports ``status: ok`` plus live counters.

    ``redis`` is a best-effort reachability flag. Redis is optional for this
    service (the cache layer degrades to a no-op without it), so its absence
    does not change ``status``.
    """
    redis_reachable = await _redis_ping_ok()
    snapshot = {
        "status": "ok",
        "connections": manager.count,
        "redis": redis_reachable,
        "timestamp": time.time(),
    }
    return snapshot
|
|
|
|
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(ws: WebSocket, client_id: str):
    """Serve one extension connection: receive claim batches, reply with verdicts.

    Protocol (one JSON text frame per message):
      * request  -> an ``AnalysisBatch`` object;
      * response -> ``{"type": "analysis_batch", "results": [...],
        "request_timestamp": ...}``. Claims whose processing failed are
        omitted from ``results`` and logged;
      * malformed JSON -> ``{"error": "invalid_json"}``;
      * JSON failing ``AnalysisBatch`` validation -> ``{"error": "invalid_payload"}``.

    Bad input is reported to the client and the connection stays open. The
    client is unregistered however the loop ends, including on cancellation.
    """
    # Per-batch cap on concurrently running claim pipelines. A new semaphore
    # is created for each batch, so this does not limit work across clients.
    max_concurrent_claims = 5

    await manager.connect(ws, client_id)
    try:
        while True:
            raw = await ws.receive_text()
            try:
                data = orjson.loads(raw)
            except orjson.JSONDecodeError:
                await manager.send(client_id, {"error": "invalid_json"})
                continue

            try:
                batch = AnalysisBatch.model_validate(data)
            except ValueError:
                # pydantic's ValidationError subclasses ValueError. Without
                # this, one bad payload escaped to the generic handler below
                # and tore down the whole connection.
                await manager.send(client_id, {"error": "invalid_payload"})
                continue

            sem = asyncio.Semaphore(max_concurrent_claims)
            results = await asyncio.gather(
                *(process_claim(sem, claim, batch.platform) for claim in batch.claims),
                return_exceptions=True,
            )

            response_items = []
            for res in results:
                # BaseException, not Exception: gather() also returns
                # CancelledError instances, which are not Exception subclasses.
                if isinstance(res, BaseException):
                    log.error(
                        "claim.process_error",
                        error=str(res),
                        error_type=type(res).__name__,
                    )
                else:
                    response_items.append(res.model_dump())

            await manager.send(client_id, {
                "type": "analysis_batch",
                "results": response_items,
                "request_timestamp": batch.timestamp,
            })
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        log.error("ws.error", client_id=client_id, error=str(exc))
    finally:
        # NOTE(review): if the same client_id has reconnected meanwhile, this
        # removes the newer registry entry as well.
        await manager.disconnect(client_id)
|
|
|
|
async def process_claim(
    sem: asyncio.Semaphore,
    claim_text: str,
    platform: str,
) -> AnalysisResult:
    """Run the full verification pipeline for one claim.

    Steps:
      1. xxhash → Redis cache check (skip the pipeline on a hit)
      2. Gatekeeper (Groq): fact vs. noise filter
      3. RAG pipeline: embed → Qdrant ANN → Memgraph trust score
      4. Prefect multi-agent flow: misinformation + hallucination tasks
      5. Cache the result and return it as an ``AnalysisResult``

    Claims the gatekeeper labels ``"noise"`` return at step 2 and are not
    cached. A cache entry that no longer validates against ``AnalysisResult``
    is treated as a miss, so the claim is recomputed and the entry replaced.

    Args:
        sem: Semaphore bounding how many claims sharing it run at once.
        claim_text: Raw claim text; its UTF-8 xxh64 digest is the cache key.
        platform: Source platform label, forwarded to the analysis flow.

    Returns:
        The result. ``processing_ms`` is this call's elapsed time in
        milliseconds, including on cache hits.
    """
    async with sem:
        t0 = time.perf_counter()

        def elapsed_ms() -> float:
            # Milliseconds since this claim acquired the semaphore.
            return (time.perf_counter() - t0) * 1000

        claim_hash = xxhash.xxh64(claim_text.encode()).hexdigest()

        # 1. Cache lookup.
        cached = await cache_get(claim_hash)
        if cached:
            try:
                return AnalysisResult(
                    **{**cached, "cached": True, "processing_ms": elapsed_ms()}
                )
            except ValueError as exc:
                # pydantic's ValidationError subclasses ValueError; fall
                # through and recompute instead of failing the claim.
                log.warning(
                    "cache.entry_invalid", claim_hash=claim_hash, error=str(exc)
                )

        # 2. Gatekeeper: cheap fact-vs-noise filter before the heavy pipeline.
        gate = await classify_claim(claim_text)
        if gate.label == "noise":
            return AnalysisResult(
                claim_hash=claim_hash,
                claim_text=claim_text,
                color="green",
                confidence=50,
                verdict="Opinion / Social noise",
                explanation=gate.reason,
                sources=[],
                trust_score=0.5,
                processing_ms=elapsed_ms(),
            )

        # 3. Retrieval context and trust score.
        rag_ctx = await build_rag_context(claim_text, claim_hash)

        # 4. Multi-agent analysis.
        analysis = await run_analysis_flow(
            claim_text=claim_text,
            claim_hash=claim_hash,
            platform=platform,
            rag_context=rag_ctx,
        )

        result = AnalysisResult(
            claim_hash=claim_hash,
            claim_text=claim_text,
            color=analysis.color,
            confidence=analysis.confidence,
            verdict=analysis.verdict,
            explanation=analysis.explanation,
            sources=analysis.sources,
            trust_score=rag_ctx.trust_score,
            processing_ms=elapsed_ms(),
        )

        # 5. Cache (the TTL is chosen from the colour by cache_set).
        await cache_set(claim_hash, result.model_dump())
        return result
|
|
|
|
if __name__ == "__main__":
    # Local/dev entry point; the listen port comes from $PORT (default 7860).
    import uvicorn

    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=listen_port,
        reload=False,
        log_level="info",
    )
|
|