File size: 5,402 Bytes
90645a4
 
62e9983
 
 
90645a4
fbe97a2
ffa34a9
62e9983
 
 
90645a4
 
ffa34a9
 
90645a4
 
 
62e9983
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90645a4
 
 
 
62e9983
 
 
 
 
 
 
 
 
 
 
90645a4
 
 
 
 
349ac13
90645a4
62e9983
 
 
 
90645a4
 
2d8e940
90645a4
62e9983
 
 
 
 
90645a4
 
 
 
ffa34a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbe97a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62e9983
 
fbe97a2
 
 
62e9983
 
 
 
 
90645a4
 
62e9983
 
 
 
90645a4
 
 
62e9983
 
 
 
90645a4
 
 
62e9983
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

import anyio
import requests as _requests
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from core.pipeline_1.logic import PipelineLLMOnly
from core.pipeline_2.logic import PipelineRAG
from core.pipeline_3.logic import PipelineGraphRAG
from services.metrics_service import MetricsService

# --- SILENCE NOISY LOGGERS ---
# Each entry pairs a third-party logger name with the minimum level it is
# allowed to emit; anything below that threshold is dropped.
for _logger_name, _min_level in (
    ("httpx", logging.WARNING),
    ("httpcore", logging.WARNING),
    ("transformers", logging.ERROR),
    ("sentence_transformers", logging.WARNING),
    ("pyTigerGraph", logging.WARNING),
    ("uvicorn.access", logging.WARNING),
):
    logging.getLogger(_logger_name).setLevel(_min_level)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Modern lifespan handler (replaces @on_event).

    Runs the shared metrics service's ``warmup`` once before the application
    starts serving. Nothing is executed after the ``yield``, so there is no
    shutdown work.

    Args:
        app: The FastAPI application this handler is attached to (unused).
    """
    # The asyncio docs recommend get_running_loop() over get_event_loop()
    # inside coroutines: it always returns the loop that is executing this
    # coroutine and never creates one implicitly.
    loop = asyncio.get_running_loop()
    # warmup is a plain synchronous callable, so it is run in the loop's
    # default executor to keep the event loop free while it executes.
    await loop.run_in_executor(None, shared_metrics.warmup)
    yield

# Application object; startup work is delegated to the lifespan handler.
app = FastAPI(
    lifespan=lifespan,
    title="SEC Dataset API",
    description="A simple FastAPI setup to fetch and interact with the PleIAs/SEC dataset.",
    version="1.0.0",
)

# Enable CORS
# NOTE(review): every origin, method and header is allowed while credentials
# are enabled. FastAPI's CORS docs advise listing origins explicitly when
# credentials are allowed; confirm this wide-open policy is intended.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Request model
class QueryRequest(BaseModel):
    """Request body accepted by the ``/query/*`` endpoints.

    Attributes:
        query: The question text handed to the selected pipeline.
        ground_truth: Optional reference answer forwarded to the pipeline
            together with the query. May be omitted or sent as ``null``.
    """

    query: str
    # Optional[...] makes the field nullable as well as omittable. A bare
    # `str = None` only supplies a default; under pydantic v2 an explicit
    # JSON null would then fail validation.
    ground_truth: Optional[str] = None

# Initialize Shared Services
# One MetricsService instance is shared by every pipeline below.
shared_metrics = MetricsService()

# Initialize Pipelines
pipeline_baseline = PipelineLLMOnly(
    top_n=20,
    max_full_text=20,
)
pipeline_rag = PipelineRAG(
    retrieval_top_k=50,
    rerank_top_n=10,
    max_full_text=3,
)
# The graph pipeline is handed the RAG pipeline's retriever instance.
pipeline_graph = PipelineGraphRAG(
    rerank_top_n=10,
    max_full_text=3,
    retriever=pipeline_rag.retriever,
)

# Inject shared metrics
for _pipeline in (pipeline_baseline, pipeline_rag, pipeline_graph):
    _pipeline.metrics = shared_metrics

@app.get("/")
async def root():
    """Return a static welcome payload indicating the API is running."""
    welcome = {
        "message": "Welcome to the SEC Dataset API",
        "status": "running",
    }
    return welcome

def _check_tigergraph(tg_host, tg_secret):
    """Probe TigerGraph and count Paper vertices (blocking; run off-loop).

    Args:
        tg_host: Base URL of the TigerGraph host.
        tg_secret: GSQL secret used to obtain a REST++ token.

    Returns:
        A ``(status, paper_count)`` tuple: ``("ok", count)`` on success, or
        ``("error: <reason>", None)`` if any step raises.
    """
    try:
        # Imported lazily so the module still loads when the probe is skipped.
        import pyTigerGraph as tg
        conn = tg.TigerGraphConnection(
            host=tg_host, graphname="PaperGraph",
            gsqlSecret=tg_secret,
            restppPort=os.environ.get("TG_PORT", "443"),
            gsPort=os.environ.get("TG_PORT", "443"),
        )
        token = conn.getToken(tg_secret)[0]
        resp = _requests.get(
            f"{tg_host}/restpp/graph/PaperGraph/vertices/Paper",
            params={"count_only": True},
            headers={"Authorization": f"Bearer {token}"},
            timeout=10,
        )
        data = resp.json()
        return "ok", data["results"][0]["count"]
    except Exception as e:
        # Health probes are best-effort: report the failure, never raise.
        return f"error: {e}", None


def _check_qdrant():
    """Probe Qdrant by listing collections (blocking; run off-loop).

    Returns:
        ``"ok (<n> collection(s))"`` on success, or ``"error: <reason>"`` if
        the call raises.
    """
    try:
        cols = pipeline_rag.retriever.client.get_collections().collections
        return f"ok ({len(cols)} collection(s))"
    except Exception as e:
        return f"error: {e}"


@app.get("/health")
async def health():
    """Report the status of the app and its TigerGraph / Qdrant backends.

    The TigerGraph probe runs only when both ``TG_HOST`` and ``TG_SECRET``
    are set in the environment; otherwise it is reported as ``"skipped"``.
    Both probes make synchronous network calls, so they are executed in
    worker threads to keep the event loop responsive while they wait.

    Returns:
        A dict with keys ``app``, ``tigergraph``, ``tigergraph_paper_count``
        and ``qdrant``.
    """
    tg_host = os.environ.get("TG_HOST", "")
    tg_secret = os.environ.get("TG_SECRET", "")
    tg_status = "skipped"
    tg_paper_count = None

    if tg_host and tg_secret:
        tg_status, tg_paper_count = await anyio.to_thread.run_sync(
            _check_tigergraph, tg_host, tg_secret
        )

    qdrant_status = await anyio.to_thread.run_sync(_check_qdrant)

    return {
        "app": "ok",
        "tigergraph": tg_status,
        "tigergraph_paper_count": tg_paper_count,
        "qdrant": qdrant_status,
    }

# Unique marker returned from the worker thread when the iterator is exhausted.
_SENTINEL = object()


class AsyncIteratorWrapper:
    """Expose a synchronous iterator through the async-iterator protocol.

    Every item is fetched with ``next()`` inside a worker thread. Exhaustion
    is signalled back with a sentinel value instead of letting
    ``StopIteration`` propagate out of the thread call.
    """

    def __init__(self, it):
        self.it = it

    def __aiter__(self):
        return self

    async def __anext__(self):
        # next(iterator, default) returns the default rather than raising
        # StopIteration once the iterator has run dry.
        item = await anyio.to_thread.run_sync(next, self.it, _SENTINEL)
        if item is not _SENTINEL:
            return item
        raise StopAsyncIteration

async def stream_pipeline(pipeline, query, ground_truth):
    """Yield ``data: <json>`` frames for each event from ``pipeline.run_stream``.

    Args:
        pipeline: Object exposing a synchronous ``run_stream(query, ground_truth)``.
        query: Query text passed through to the pipeline.
        ground_truth: Reference answer (or None) passed through to the pipeline.

    Yields:
        One frame string per pipeline event. If anything raises, a final
        frame of type ``error`` carrying the exception text is yielded
        instead of propagating the exception.
    """
    try:
        # Adapt the synchronous generator so it can be consumed with `async for`.
        events = AsyncIteratorWrapper(pipeline.run_stream(query, ground_truth))
        async for event in events:
            payload = json.dumps(event)
            yield f"data: {payload}\n\n"
            # Short pause between frames.
            await asyncio.sleep(0.01)
    except Exception as exc:
        failure = {'type': 'error', 'data': str(exc)}
        yield f"data: {json.dumps(failure)}\n\n"

@app.post("/query/baseline")
async def query_baseline(request: QueryRequest):
    """Stream the baseline (LLM-only) pipeline's events as an event stream."""
    event_stream = stream_pipeline(
        pipeline_baseline, request.query, request.ground_truth
    )
    return StreamingResponse(event_stream, media_type="text/event-stream")

@app.post("/query/rag")
async def query_rag(request: QueryRequest):
    """Stream the RAG pipeline's events as an event stream."""
    event_stream = stream_pipeline(
        pipeline_rag, request.query, request.ground_truth
    )
    return StreamingResponse(event_stream, media_type="text/event-stream")

@app.post("/query/graph")
async def query_graph(request: QueryRequest):
    """Stream the GraphRAG pipeline's events as an event stream."""
    event_stream = stream_pipeline(
        pipeline_graph, request.query, request.ground_truth
    )
    return StreamingResponse(event_stream, media_type="text/event-stream")

if __name__ == "__main__":
    # Run as a script: serve the app with uvicorn on all interfaces,
    # port 8000, logging warnings and above.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="warning",
    )