Spaces:
Sleeping
Sleeping
import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

import anyio
import requests as _requests
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from core.pipeline_1.logic import PipelineLLMOnly
from core.pipeline_2.logic import PipelineRAG
from core.pipeline_3.logic import PipelineGraphRAG
from services.metrics_service import MetricsService
# --- SILENCE NOISY LOGGERS ---
# Raise the threshold of chatty third-party loggers so that only messages at
# or above the listed level are emitted.
_QUIET_LOGGER_LEVELS = {
    "httpx": logging.WARNING,
    "httpcore": logging.WARNING,
    "transformers": logging.ERROR,
    "sentence_transformers": logging.WARNING,
    "pyTigerGraph": logging.WARNING,
    "uvicorn.access": logging.WARNING,
}
for _logger_name, _threshold in _QUIET_LOGGER_LEVELS.items():
    logging.getLogger(_logger_name).setLevel(_threshold)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler (replaces the older ``@app.on_event`` hooks).

    Startup: runs ``shared_metrics.warmup`` in the event loop's default
    executor and waits for it to finish before the application starts
    serving. No shutdown work is performed after the ``yield``.

    Args:
        app: The FastAPI application being started (not used in the body).

    Yields:
        None, once the warm-up call has completed.
    """
    # Inside a coroutine, get_running_loop() is the preferred way to obtain
    # the loop; get_event_loop() has policy-dependent behavior.
    loop = asyncio.get_running_loop()
    # warmup is executed off the event loop thread so the loop stays
    # responsive while it runs.
    await loop.run_in_executor(None, shared_metrics.warmup)
    yield
# The ASGI application object. Startup work is delegated to `lifespan`.
_APP_METADATA = {
    "title": "SEC Dataset API",
    "description": "A simple FastAPI setup to fetch and interact with the PleIAs/SEC dataset.",
    "version": "1.0.0",
}
app = FastAPI(lifespan=lifespan, **_APP_METADATA)
# Enable CORS for every origin, method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI CORS documentation — confirm this is intended
# or replace "*" with an explicit origin list.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Request model
class QueryRequest(BaseModel):
    """JSON body accepted by the query handlers in this module."""

    # The question text; handed to the selected pipeline's run_stream().
    query: str
    # Optional reference answer, forwarded to run_stream() together with the
    # query. Declared Optional so that an explicit JSON null validates as well
    # as an omitted field (a bare `str` annotation rejects null in Pydantic v2).
    ground_truth: Optional[str] = None
# Shared services: a single metrics object used by every pipeline.
shared_metrics = MetricsService()

# Pipelines. The graph pipeline is handed the RAG pipeline's retriever
# rather than being constructed with one of its own.
pipeline_baseline = PipelineLLMOnly(top_n=20, max_full_text=20)
pipeline_rag = PipelineRAG(retrieval_top_k=50, rerank_top_n=10, max_full_text=3)
pipeline_graph = PipelineGraphRAG(
    retriever=pipeline_rag.retriever,
    rerank_top_n=10,
    max_full_text=3,
)

# Inject the shared metrics object into each pipeline.
for _pipeline in (pipeline_baseline, pipeline_rag, pipeline_graph):
    _pipeline.metrics = shared_metrics
del _pipeline
async def root():
    """Return a static welcome payload indicating the service is running."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view — confirm it is registered on `app`.
    payload = {
        "message": "Welcome to the SEC Dataset API",
        "status": "running",
    }
    return payload
def _check_tigergraph():
    """Probe TigerGraph synchronously and return ``(status, paper_count)``.

    Returns:
        ``("skipped", None)`` when ``TG_HOST`` or ``TG_SECRET`` is unset,
        ``("ok", count)`` when the Paper vertex count was fetched, and
        ``("error: <exception>", None)`` when any step fails.

    This function performs blocking network I/O; call it from a worker thread.
    """
    tg_host = os.environ.get("TG_HOST", "")
    tg_secret = os.environ.get("TG_SECRET", "")
    if not (tg_host and tg_secret):
        return "skipped", None
    try:
        # Imported lazily: only needed when TigerGraph is configured.
        import pyTigerGraph as tg

        tg_port = os.environ.get("TG_PORT", "443")
        conn = tg.TigerGraphConnection(
            host=tg_host,
            graphname="PaperGraph",
            gsqlSecret=tg_secret,
            restppPort=tg_port,
            gsPort=tg_port,
        )
        token = conn.getToken(tg_secret)[0]
        resp = _requests.get(
            f"{tg_host}/restpp/graph/PaperGraph/vertices/Paper",
            params={"count_only": True},
            headers={"Authorization": f"Bearer {token}"},
            timeout=10,
        )
        # Surface HTTP failures as such instead of as an opaque
        # JSON/KeyError from the parsing below.
        resp.raise_for_status()
        return "ok", resp.json()["results"][0]["count"]
    except Exception as e:
        # Best-effort probe: any failure is reported in the status string.
        return f"error: {e}", None


def _check_qdrant():
    """Probe the RAG retriever's client synchronously and return a status string.

    Returns ``"ok (<n> collection(s))"`` on success or ``"error: <exception>"``
    on failure. Blocking; call it from a worker thread.
    """
    try:
        # presumably a Qdrant client, given the "qdrant" key it is reported under
        cols = pipeline_rag.retriever.client.get_collections().collections
        return f"ok ({len(cols)} collection(s))"
    except Exception as e:
        return f"error: {e}"


async def health():
    """Report the status of the app and its backing stores.

    Both probes do blocking I/O, so they are executed on worker threads to
    keep the event loop free while they run.

    Returns:
        A dict with keys ``app``, ``tigergraph``, ``tigergraph_paper_count``
        and ``qdrant``.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view — confirm it is registered on `app`.
    tg_status, tg_paper_count = await anyio.to_thread.run_sync(_check_tigergraph)
    qdrant_status = await anyio.to_thread.run_sync(_check_qdrant)
    return {
        "app": "ok",
        "tigergraph": tg_status,
        "tigergraph_paper_count": tg_paper_count,
        "qdrant": qdrant_status,
    }
# Unique marker returned by the worker thread when the iterator is exhausted.
_SENTINEL = object()


class AsyncIteratorWrapper:
    """Async-iteration adapter over a synchronous iterator.

    Each item is pulled with ``next`` on a worker thread via
    ``anyio.to_thread.run_sync``. Exhaustion is reported with a private
    sentinel object, so ``StopIteration`` never propagates out of the
    threaded call.
    """

    def __init__(self, it):
        # The wrapped synchronous iterator.
        self.it = it

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Two-argument next() returns the sentinel instead of raising
        # StopIteration when the iterator is exhausted.
        item = await anyio.to_thread.run_sync(next, self.it, _SENTINEL)
        if item is not _SENTINEL:
            return item
        raise StopAsyncIteration
async def stream_pipeline(pipeline, query, ground_truth):
    """Yield a pipeline's streamed events as ``data: <json>`` frames.

    Args:
        pipeline: Object exposing ``run_stream(query, ground_truth)``, which
            returns a synchronous iterator of JSON-serializable events.
        query: Passed through to ``run_stream``.
        ground_truth: Passed through to ``run_stream``.

    Yields:
        One ``"data: ...\\n\\n"`` string per event. If anything raises while
        streaming, a final frame of the form
        ``{"type": "error", "data": <message>}`` is yielded instead of
        propagating the exception.
    """

    def _frame(payload):
        # Serialize one event into a single "data:" frame.
        return f"data: {json.dumps(payload)}\n\n"

    try:
        # Drive the synchronous generator through the thread-backed wrapper.
        events = AsyncIteratorWrapper(pipeline.run_stream(query, ground_truth))
        async for event in events:
            yield _frame(event)
            # Short pause between frames, as in the original pacing.
            await asyncio.sleep(0.01)
    except Exception as exc:
        yield _frame({'type': 'error', 'data': str(exc)})
async def query_baseline(request: QueryRequest):
    """Stream events from ``pipeline_baseline`` for the request as ``text/event-stream``."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view — confirm it is registered on `app`.
    event_stream = stream_pipeline(
        pipeline_baseline,
        request.query,
        request.ground_truth,
    )
    return StreamingResponse(event_stream, media_type="text/event-stream")
async def query_rag(request: QueryRequest):
    """Stream events from ``pipeline_rag`` for the request as ``text/event-stream``."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view — confirm it is registered on `app`.
    event_stream = stream_pipeline(
        pipeline_rag,
        request.query,
        request.ground_truth,
    )
    return StreamingResponse(event_stream, media_type="text/event-stream")
async def query_graph(request: QueryRequest):
    """Stream events from ``pipeline_graph`` for the request as ``text/event-stream``."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view — confirm it is registered on `app`.
    event_stream = stream_pipeline(
        pipeline_graph,
        request.query,
        request.ground_truth,
    )
    return StreamingResponse(event_stream, media_type="text/event-stream")
if __name__ == "__main__":
    # Run the app directly: all interfaces, port 8000, uvicorn logging
    # limited to warnings and above.
    _server_options = {"host": "0.0.0.0", "port": 8000, "log_level": "warning"}
    uvicorn.run(app, **_server_options)