# TigerGraph-Hack / app.py
# Meshyboi's picture
# Update app.py
# fbe97a2 verified
# Raw
# History Blame Contribute Delete
# 5.4 kB
import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

import anyio
import requests as _requests
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from core.pipeline_1.logic import PipelineLLMOnly
from core.pipeline_2.logic import PipelineRAG
from core.pipeline_3.logic import PipelineGraphRAG
from services.metrics_service import MetricsService
# --- SILENCE NOISY LOGGERS ---
# Raise the threshold of chatty third-party loggers so the application's own
# output stays readable: "transformers" only reports errors, the others
# report warnings and above.
_QUIET_LOGGER_LEVELS = {
    "httpx": logging.WARNING,
    "httpcore": logging.WARNING,
    "transformers": logging.ERROR,
    "sentence_transformers": logging.WARNING,
    "pyTigerGraph": logging.WARNING,
    "uvicorn.access": logging.WARNING,
}
for _logger_name, _logger_level in _QUIET_LOGGER_LEVELS.items():
    logging.getLogger(_logger_name).setLevel(_logger_level)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Modern lifespan handler (replaces @on_event).

    On startup, runs ``shared_metrics.warmup`` in the event loop's default
    executor and waits for it to finish, then yields so the application can
    start serving. Nothing is done on shutdown.

    Args:
        app: The FastAPI application being started (not used by the handler).
    """
    # We are inside a coroutine, so a loop is guaranteed to be running;
    # get_running_loop() is the documented way to obtain it here.
    loop = asyncio.get_running_loop()
    # Off-load the warm-up to a worker thread (presumably it is blocking).
    await loop.run_in_executor(None, shared_metrics.warmup)
    yield
# ASGI application object. Startup work is delegated to `lifespan`; the
# title/description/version are passed to FastAPI as API metadata.
app = FastAPI(
    title="SEC Dataset API",
    description="A simple FastAPI setup to fetch and interact with the PleIAs/SEC dataset.",
    version="1.0.0",
    lifespan=lifespan
)
# Enable CORS for every origin, method and header.
# NOTE(review): FastAPI's CORS documentation says wildcard ("*") values for
# allow_origins / allow_methods / allow_headers must not be combined with
# allow_credentials=True. No endpoint in this file reads cookies or auth
# headers, so credentials are presumably unnecessary -- confirm with the
# frontend, then either set allow_credentials=False or list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request model
class QueryRequest(BaseModel):
    """Request body shared by the ``/query/*`` endpoints.

    Attributes:
        query: The query string handed to the selected pipeline.
        ground_truth: Optional reference answer handed to the pipeline together
            with the query. May be omitted or sent explicitly as ``null``.
    """

    query: str
    # Declared Optional so the annotation matches the None default and an
    # explicit JSON null validates (a bare `str` annotation rejects it under
    # Pydantic v2).
    ground_truth: Optional[str] = None
# Initialize Shared Services
# A single MetricsService instance is shared by every pipeline below.
shared_metrics = MetricsService()

# Initialize Pipelines
pipeline_baseline = PipelineLLMOnly(top_n=20, max_full_text=20)
pipeline_rag = PipelineRAG(
    retrieval_top_k=50,
    rerank_top_n=10,
    max_full_text=3,
)
# The graph pipeline is handed the RAG pipeline's retriever rather than
# being left to build its own.
pipeline_graph = PipelineGraphRAG(
    rerank_top_n=10,
    max_full_text=3,
    retriever=pipeline_rag.retriever,
)

# Inject shared metrics
for _pipeline in (pipeline_baseline, pipeline_rag, pipeline_graph):
    _pipeline.metrics = shared_metrics
@app.get("/")
async def root():
    """Return a static welcome payload indicating the API is running."""
    welcome = {"message": "Welcome to the SEC Dataset API", "status": "running"}
    return welcome
def _check_tigergraph():
    """Probe TigerGraph and return a ``(status, paper_count)`` tuple.

    Reads ``TG_HOST``, ``TG_SECRET`` and ``TG_PORT`` (default ``"443"``) from
    the environment. Returns ``("skipped", None)`` when host or secret is
    unset, ``("ok", count)`` on success and ``("error: <exception>", None)``
    on any failure. Performs blocking network I/O, so call it from a worker
    thread rather than directly on the event loop.
    """
    tg_host = os.environ.get("TG_HOST", "")
    tg_secret = os.environ.get("TG_SECRET", "")
    if not (tg_host and tg_secret):
        return "skipped", None
    try:
        # Imported here so the package is only needed when TigerGraph is
        # actually configured.
        import pyTigerGraph as tg

        tg_port = os.environ.get("TG_PORT", "443")
        conn = tg.TigerGraphConnection(
            host=tg_host, graphname="PaperGraph",
            gsqlSecret=tg_secret,
            restppPort=tg_port,
            gsPort=tg_port,
        )
        token = conn.getToken(tg_secret)[0]
        resp = _requests.get(
            f"{tg_host}/restpp/graph/PaperGraph/vertices/Paper",
            params={"count_only": True},
            headers={"Authorization": f"Bearer {token}"},
            timeout=10,
        )
        # Report HTTP failures as such instead of a confusing KeyError or
        # JSON decoding error further down.
        resp.raise_for_status()
        data = resp.json()
        return "ok", data["results"][0]["count"]
    except Exception as e:
        # Health checks are best-effort: report the failure, never raise.
        return f"error: {e}", None


def _check_qdrant():
    """Return a status string for the RAG retriever's client.

    Lists the client's collections (blocking call) and returns
    ``"ok (<n> collection(s))"`` on success or ``"error: <exception>"`` on
    any failure.
    """
    try:
        cols = pipeline_rag.retriever.client.get_collections().collections
        return f"ok ({len(cols)} collection(s))"
    except Exception as e:
        # Best-effort, same as the TigerGraph probe.
        return f"error: {e}"


@app.get("/health")
async def health():
    """Report the status of the app and its backing stores.

    Returns:
        A dict with the keys ``app`` (always ``"ok"``), ``tigergraph``
        (``"skipped"``, ``"ok"`` or ``"error: ..."``),
        ``tigergraph_paper_count`` (the count, or ``None``) and ``qdrant``
        (``"ok (...)"`` or ``"error: ..."``).
    """
    # Both probes do blocking network I/O; run them in worker threads so the
    # event loop (and any active SSE streams) keeps running meanwhile.
    tg_status, tg_paper_count = await anyio.to_thread.run_sync(_check_tigergraph)
    qdrant_status = await anyio.to_thread.run_sync(_check_qdrant)
    return {
        "app": "ok",
        "tigergraph": tg_status,
        "tigergraph_paper_count": tg_paper_count,
        "qdrant": qdrant_status,
    }
# Unique marker returned by the worker thread when the wrapped iterator is
# exhausted.
_SENTINEL = object()


class AsyncIteratorWrapper:
    """Async adapter around a synchronous iterator.

    Each step of the wrapped iterator runs in a worker thread. Exhaustion is
    signalled with a sentinel value so that ``StopIteration`` never has to
    cross the thread/await boundary.
    """

    def __init__(self, it):
        self.it = it

    def __aiter__(self):
        return self

    async def __anext__(self):
        # next() with a default hands back the sentinel instead of raising
        # StopIteration once the iterator is exhausted.
        item = await anyio.to_thread.run_sync(next, self.it, _SENTINEL)
        if item is not _SENTINEL:
            return item
        raise StopAsyncIteration
async def stream_pipeline(pipeline, query, ground_truth):
    """Yield a pipeline's events as Server-Sent Events ``data:`` frames.

    Calls ``pipeline.run_stream(query, ground_truth)``, consumes the resulting
    synchronous iterator through ``AsyncIteratorWrapper`` and yields each event
    JSON-encoded as ``"data: <json>\\n\\n"``. If anything raises along the way,
    a final frame ``{"type": "error", "data": "<message>"}`` is yielded
    instead of propagating the exception.
    """

    def _frame(payload):
        # One SSE message: a "data:" line terminated by a blank line.
        return f"data: {json.dumps(payload)}\n\n"

    try:
        # Wrap the synchronous generator so it can be consumed with async-for.
        events = AsyncIteratorWrapper(pipeline.run_stream(query, ground_truth))
        async for item in events:
            yield _frame(item)
            # Short pause between frames, yielding control to the event loop.
            await asyncio.sleep(0.01)
    except Exception as exc:
        yield _frame({'type': 'error', 'data': str(exc)})
@app.post("/query/baseline")
async def query_baseline(request: QueryRequest):
    """Stream the LLM-only baseline pipeline's events for ``request`` as SSE."""
    event_stream = stream_pipeline(
        pipeline_baseline, request.query, request.ground_truth
    )
    return StreamingResponse(event_stream, media_type="text/event-stream")
@app.post("/query/rag")
async def query_rag(request: QueryRequest):
    """Stream the RAG pipeline's events for ``request`` as SSE."""
    event_stream = stream_pipeline(
        pipeline_rag, request.query, request.ground_truth
    )
    return StreamingResponse(event_stream, media_type="text/event-stream")
@app.post("/query/graph")
async def query_graph(request: QueryRequest):
    """Stream the graph-RAG pipeline's events for ``request`` as SSE."""
    event_stream = stream_pipeline(
        pipeline_graph, request.query, request.ground_truth
    )
    return StreamingResponse(event_stream, media_type="text/event-stream")
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 8000, with
    # uvicorn's own logging limited to warnings and above.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="warning",
    )