Spaces:
Sleeping
Sleeping
File size: 5,402 Bytes
90645a4 62e9983 90645a4 fbe97a2 ffa34a9 62e9983 90645a4 ffa34a9 90645a4 62e9983 90645a4 62e9983 90645a4 349ac13 90645a4 62e9983 90645a4 2d8e940 90645a4 62e9983 90645a4 ffa34a9 fbe97a2 62e9983 fbe97a2 62e9983 90645a4 62e9983 90645a4 62e9983 90645a4 62e9983 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 | from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import uvicorn
import anyio
import os
import json
import asyncio
import logging
from pathlib import Path
import requests as _requests
from core.pipeline_1.logic import PipelineLLMOnly
from core.pipeline_2.logic import PipelineRAG
from core.pipeline_3.logic import PipelineGraphRAG
from services.metrics_service import MetricsService
# --- SILENCE NOISY LOGGERS ---
# Cap chatty third-party loggers at a quieter level so application output stays readable.
_NOISY_LOGGER_LEVELS = {
    "httpx": logging.WARNING,
    "httpcore": logging.WARNING,
    "transformers": logging.ERROR,
    "sentence_transformers": logging.WARNING,
    "pyTigerGraph": logging.WARNING,
    "uvicorn.access": logging.WARNING,
}
for _logger_name, _level in _NOISY_LOGGER_LEVELS.items():
    logging.getLogger(_logger_name).setLevel(_level)
del _logger_name, _level
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler (the replacement for ``@app.on_event``).

    Before the app starts serving, runs ``shared_metrics.warmup`` in the
    event loop's default thread-pool executor, keeping the (presumably
    blocking) warm-up off the event loop. An exception raised by ``warmup``
    propagates out of the lifespan. Nothing is done on shutdown.
    """
    # get_running_loop() is the documented way to obtain the loop from inside
    # a coroutine; get_event_loop() is discouraged there.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, shared_metrics.warmup)
    yield
# The ASGI application; startup work is delegated to the `lifespan` handler.
app = FastAPI(
    title="SEC Dataset API",
    description="A simple FastAPI setup to fetch and interact with the PleIAs/SEC dataset.",
    version="1.0.0",
    lifespan=lifespan,
)

# Enable CORS: every origin, method and header is allowed, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any site
# make credentialed cross-origin requests; restrict allow_origins if the API ever
# relies on cookies or other browser credentials.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
from typing import Optional


# Request model for the /query/* endpoints.
#   query:        the question handed to the pipeline's run_stream().
#   ground_truth: optional reference answer, forwarded as run_stream()'s second
#                 argument; None when omitted or sent as JSON null.
# `Optional[str]` is required: a bare `str` with a None default rejects an explicit
# null under Pydantic v2 and misstates the type for type-checkers.
class QueryRequest(BaseModel):
    query: str
    ground_truth: Optional[str] = None
# Initialize Shared Services: one metrics service instance used by every pipeline
# (its warmup() is run by the lifespan handler at startup).
shared_metrics = MetricsService()

# Initialize Pipelines. The graph pipeline reuses the RAG pipeline's retriever
# object rather than constructing its own.
pipeline_baseline = PipelineLLMOnly(top_n=20, max_full_text=20)
pipeline_rag = PipelineRAG(retrieval_top_k=50, rerank_top_n=10, max_full_text=3)
pipeline_graph = PipelineGraphRAG(
    rerank_top_n=10,
    max_full_text=3,
    retriever=pipeline_rag.retriever,
)

# Inject shared metrics into each pipeline, in declaration order.
for _pipeline in (pipeline_baseline, pipeline_rag, pipeline_graph):
    _pipeline.metrics = shared_metrics
del _pipeline
# GET / -- static welcome payload. (A `#` comment rather than a docstring so the
# generated OpenAPI description is unchanged.)
@app.get("/")
async def root():
    payload = {"message": "Welcome to the SEC Dataset API", "status": "running"}
    return payload
def _probe_tigergraph():
    """Check TigerGraph connectivity and count ``Paper`` vertices in ``PaperGraph``.

    Blocking (network I/O): call it from a worker thread, not on the event loop.

    Returns:
        ``(status, paper_count)``. The result is ``("skipped", None)`` when
        ``TG_HOST`` or ``TG_SECRET`` is unset, ``("ok", <count>)`` on success,
        and ``("error: <exception text>", None)`` on any failure.
    """
    tg_host = os.environ.get("TG_HOST", "")
    tg_secret = os.environ.get("TG_SECRET", "")
    if not (tg_host and tg_secret):
        return "skipped", None
    try:
        # Imported lazily; an ImportError is reported as an error status
        # instead of failing the request.
        import pyTigerGraph as tg

        tg_port = os.environ.get("TG_PORT", "443")
        conn = tg.TigerGraphConnection(
            host=tg_host,
            graphname="PaperGraph",
            gsqlSecret=tg_secret,
            restppPort=tg_port,
            gsPort=tg_port,
        )
        # NOTE(review): assumes getToken() returns a sequence whose first item is
        # the token -- confirm against the pinned pyTigerGraph version.
        token = conn.getToken(tg_secret)[0]
        resp = _requests.get(
            f"{tg_host}/restpp/graph/PaperGraph/vertices/Paper",
            params={"count_only": True},
            headers={"Authorization": f"Bearer {token}"},
            timeout=10,
        )
        # Report HTTP failures (401/404/5xx) as such, instead of as an opaque
        # KeyError / JSON decode error from parsing an error body.
        resp.raise_for_status()
        paper_count = resp.json()["results"][0]["count"]
    except Exception as e:
        return f"error: {e}", None
    return "ok", paper_count


def _probe_qdrant():
    """Report Qdrant reachability through the RAG retriever's client.

    Blocking: call it from a worker thread. Returns ``"ok (<n> collection(s))"``
    or ``"error: <exception text>"``.
    """
    try:
        cols = pipeline_rag.retriever.client.get_collections().collections
        return f"ok ({len(cols)} collection(s))"
    except Exception as e:
        return f"error: {e}"


# GET /health -- status of the app, TigerGraph and Qdrant.
# The probes do blocking network I/O, so they run in worker threads; calling them
# directly in this async handler would stall the event loop (and every in-flight
# SSE stream) for up to the request timeouts.
# NOTE(review): raw exception text is returned to any caller; consider logging it
# server-side and returning a generic message if it may contain internal details.
@app.get("/health")
async def health():
    tg_status, tg_paper_count = await anyio.to_thread.run_sync(_probe_tigergraph)
    qdrant_status = await anyio.to_thread.run_sync(_probe_qdrant)
    return {
        "app": "ok",
        "tigergraph": tg_status,
        "tigergraph_paper_count": tg_paper_count,
        "qdrant": qdrant_status,
    }
# Private end-of-iteration marker returned from the worker thread.
_SENTINEL = object()


class AsyncIteratorWrapper:
    """Expose a blocking (synchronous) iterator as an async iterator.

    Every ``next()`` call on the wrapped iterator runs in a worker thread via
    ``anyio.to_thread.run_sync``. Exhaustion is signalled by returning the
    module-level ``_SENTINEL`` instead of letting ``StopIteration`` escape the
    worker thread, because a ``StopIteration`` cannot be propagated through an
    awaitable.
    """

    def __init__(self, it):
        self.it = it

    def __aiter__(self):
        return self

    async def __anext__(self):
        # next(iterator, default) yields `default` exactly when StopIteration
        # would have been raised.
        item = await anyio.to_thread.run_sync(next, self.it, _SENTINEL)
        if item is not _SENTINEL:
            return item
        raise StopAsyncIteration
async def stream_pipeline(pipeline, query, ground_truth):
    """Yield a pipeline's streamed events as Server-Sent Events ``data:`` frames.

    ``pipeline.run_stream(query, ground_truth)`` must return a synchronous
    iterator. Each item is pulled in a worker thread through
    ``AsyncIteratorWrapper``, serialized with ``json.dumps`` and emitted as
    ``"data: <json>\\n\\n"``.

    On any ``Exception`` (from the pipeline or from serialization), the full
    traceback is logged server-side and a final
    ``{"type": "error", "data": <message>}`` frame is sent to the client.
    """
    try:
        events = pipeline.run_stream(query, ground_truth)
        async for event in AsyncIteratorWrapper(events):
            yield f"data: {json.dumps(event)}\n\n"
            # Short pause between frames; presumably to let each chunk flush
            # to the client -- confirm before removing.
            await asyncio.sleep(0.01)
    except Exception as e:
        # Previously only str(e) reached the client and the traceback was lost;
        # keep a server-side record of the failure.
        logging.getLogger(__name__).exception("Pipeline stream failed")
        yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n"
def _sse_response(pipeline, request: QueryRequest) -> StreamingResponse:
    """Wrap ``stream_pipeline`` output for *request* in a text/event-stream response."""
    frames = stream_pipeline(pipeline, request.query, request.ground_truth)
    return StreamingResponse(frames, media_type="text/event-stream")


# POST /query/baseline -- stream events from the LLM-only pipeline.
@app.post("/query/baseline")
async def query_baseline(request: QueryRequest):
    return _sse_response(pipeline_baseline, request)


# POST /query/rag -- stream events from the RAG pipeline.
@app.post("/query/rag")
async def query_rag(request: QueryRequest):
    return _sse_response(pipeline_rag, request)


# POST /query/graph -- stream events from the GraphRAG pipeline.
@app.post("/query/graph")
async def query_graph(request: QueryRequest):
    return _sse_response(pipeline_graph, request)
if __name__ == "__main__":
    # Direct-execution entry point: serve on all interfaces, port 8000, with
    # uvicorn's own logging limited to warnings.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="warning",
    )