"""FastAPI application wiring the LLM-only, RAG and GraphRAG pipelines.

This part of the module builds the app, the shared services and pipelines,
the health check, and the helpers that turn a synchronous pipeline stream
into server-sent events.
"""

import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

import anyio
import requests as _requests
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from core.pipeline_1.logic import PipelineLLMOnly
from core.pipeline_2.logic import PipelineRAG
from core.pipeline_3.logic import PipelineGraphRAG
from services.metrics_service import MetricsService

# --- SILENCE NOISY LOGGERS ---
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
logging.getLogger("pyTigerGraph").setLevel(logging.WARNING)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Modern lifespan handler (replaces @on_event).

    Runs ``shared_metrics.warmup`` in the default executor before the
    application starts serving, so the call does not block the event loop.
    """
    # ``shared_metrics`` is defined further down the module; the name is only
    # resolved when the lifespan actually runs, i.e. after import finished.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, shared_metrics.warmup)
    yield


app = FastAPI(
    title="SEC Dataset API",
    description="A simple FastAPI setup to fetch and interact with the PleIAs/SEC dataset.",
    version="1.0.0",
    lifespan=lifespan,
)

# Enable CORS
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; consider restricting the origin list outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request model
class QueryRequest(BaseModel):
    """Body accepted by the query endpoints."""

    # The question forwarded to the selected pipeline.
    query: str
    # Optional reference answer; omitted or null means "not provided".
    ground_truth: Optional[str] = None


# Initialize Shared Services
shared_metrics = MetricsService()

# Initialize Pipelines
pipeline_baseline = PipelineLLMOnly(top_n=20, max_full_text=20)
pipeline_rag = PipelineRAG(retrieval_top_k=50, rerank_top_n=10, max_full_text=3)
# The graph pipeline reuses the retriever instance owned by the RAG pipeline.
pipeline_graph = PipelineGraphRAG(
    rerank_top_n=10,
    max_full_text=3,
    retriever=pipeline_rag.retriever,
)

# Inject shared metrics
pipeline_baseline.metrics = shared_metrics
pipeline_rag.metrics = shared_metrics
pipeline_graph.metrics = shared_metrics


@app.get("/")
async def root():
    """Return a static welcome payload."""
    return {"message": "Welcome to the SEC Dataset API", "status": "running"}


def _check_tigergraph():
    """Probe TigerGraph and return ``(status, paper_count)``.

    The status is ``"skipped"`` when ``TG_HOST`` or ``TG_SECRET`` is unset,
    ``"ok"`` on success, or ``"error: <message>"`` when any step fails.
    ``paper_count`` is ``None`` unless the probe succeeded.
    """
    tg_host = os.environ.get("TG_HOST", "")
    tg_secret = os.environ.get("TG_SECRET", "")
    if not (tg_host and tg_secret):
        return "skipped", None

    try:
        # Imported lazily so the module loads even when the probe is skipped.
        import pyTigerGraph as tg

        tg_port = os.environ.get("TG_PORT", "443")
        conn = tg.TigerGraphConnection(
            host=tg_host,
            graphname="PaperGraph",
            gsqlSecret=tg_secret,
            restppPort=tg_port,
            gsPort=tg_port,
        )
        token = conn.getToken(tg_secret)[0]
        resp = _requests.get(
            f"{tg_host}/restpp/graph/PaperGraph/vertices/Paper",
            params={"count_only": True},
            headers={"Authorization": f"Bearer {token}"},
            timeout=10,
        )
        # Report HTTP failures as such instead of an opaque KeyError below.
        resp.raise_for_status()
        data = resp.json()
        return "ok", data["results"][0]["count"]
    except Exception as e:
        # Health checks are best-effort: report the failure, never raise.
        return f"error: {e}", None


def _check_qdrant():
    """Probe Qdrant through the RAG retriever's client and return a status string."""
    try:
        cols = pipeline_rag.retriever.client.get_collections().collections
        return f"ok ({len(cols)} collection(s))"
    except Exception as e:
        # Health checks are best-effort: report the failure, never raise.
        return f"error: {e}"


@app.get("/health")
async def health():
    """Report the status of the app, TigerGraph and Qdrant.

    Both probes perform blocking network I/O, so they run in worker threads
    to keep the event loop (and any in-flight streams) responsive.
    """
    tg_status, tg_paper_count = await anyio.to_thread.run_sync(_check_tigergraph)
    qdrant_status = await anyio.to_thread.run_sync(_check_qdrant)

    return {
        "app": "ok",
        "tigergraph": tg_status,
        "tigergraph_paper_count": tg_paper_count,
        "qdrant": qdrant_status,
    }


# Unique marker returned by the worker thread when the iterator is exhausted.
_SENTINEL = object()


class AsyncIteratorWrapper:
    """Wraps a synchronous iterator using a sentinel to avoid StopIteration leakage.

    Each ``next()`` call runs in a worker thread, so a slow synchronous
    iterator does not block the event loop.
    """

    def __init__(self, it):
        self.it = it

    def __aiter__(self):
        return self

    async def __anext__(self):
        def _get_next():
            # Translate exhaustion into a plain return value inside the
            # thread; the async side then raises StopAsyncIteration itself.
            try:
                return next(self.it)
            except StopIteration:
                return _SENTINEL

        value = await anyio.to_thread.run_sync(_get_next)
        if value is _SENTINEL:
            raise StopAsyncIteration
        return value


async def stream_pipeline(pipeline, query, ground_truth):
    """Yield the events of ``pipeline.run_stream`` as server-sent event frames.

    Each event is JSON-encoded into a ``data: ...`` frame. Any exception
    raised while iterating or encoding is reported to the client as a final
    frame of type ``error`` instead of propagating.
    """
    try:
        # Wrap the synchronous generator with a sentinel-safe async iterator
        async_gen = AsyncIteratorWrapper(pipeline.run_stream(query, ground_truth))
        async for event in async_gen:
            yield f"data: {json.dumps(event)}\n\n"
            # Short pause between frames; presumably lets the server flush
            # each frame before the next one is produced - TODO confirm.
            await asyncio.sleep(0.01)
    except Exception as e:
        yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n"
@app.post("/query/baseline")
async def query_baseline(request: QueryRequest):
    """Stream the baseline pipeline's events for the request as SSE."""
    events = stream_pipeline(
        pipeline_baseline, request.query, request.ground_truth
    )
    return StreamingResponse(events, media_type="text/event-stream")


@app.post("/query/rag")
async def query_rag(request: QueryRequest):
    """Stream the RAG pipeline's events for the request as SSE."""
    events = stream_pipeline(
        pipeline_rag, request.query, request.ground_truth
    )
    return StreamingResponse(events, media_type="text/event-stream")


@app.post("/query/graph")
async def query_graph(request: QueryRequest):
    """Stream the graph pipeline's events for the request as SSE."""
    events = stream_pipeline(
        pipeline_graph, request.query, request.ground_truth
    )
    return StreamingResponse(events, media_type="text/event-stream")


if __name__ == "__main__":
    # Serve the app directly when the module is executed as a script.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="warning",
    )