Meshyboi committed on
Commit
62e9983
·
verified ·
1 Parent(s): 349ac13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -34
app.py CHANGED
@@ -1,9 +1,13 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
 
 
 
3
  import uvicorn
4
- import importlib.util
5
  import os
6
- import sys
 
 
7
  from pathlib import Path
8
 
9
  import requests as _requests
@@ -11,11 +15,37 @@ import requests as _requests
11
  from core.pipeline_1.logic import PipelineLLMOnly
12
  from core.pipeline_2.logic import PipelineRAG
13
  from core.pipeline_3.logic import PipelineGraphRAG
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  app = FastAPI(
16
  title="SEC Dataset API",
17
  description="A simple FastAPI setup to fetch and interact with the PleIAs/SEC dataset.",
18
- version="1.0.0"
 
 
 
 
 
 
 
 
 
 
19
  )
20
 
21
  # Request model
@@ -23,16 +53,23 @@ class QueryRequest(BaseModel):
23
  query: str
24
  ground_truth: str = None
25
 
26
- # Initialize Pipeline (Scaling up for full context)
 
 
 
27
  pipeline_baseline = PipelineLLMOnly(top_n=20, max_full_text=20)
28
  pipeline_rag = PipelineRAG(retrieval_top_k=50, rerank_top_n=10, max_full_text=3)
29
  pipeline_graph = PipelineGraphRAG(rerank_top_n=10, max_full_text=3, retriever=pipeline_rag.retriever)
30
 
 
 
 
 
 
31
  @app.get("/")
32
  async def root():
33
  return {"message": "Welcome to the SEC Dataset API", "status": "running"}
34
 
35
-
36
  @app.get("/health")
37
  async def health():
38
  tg_host = os.environ.get("TG_HOST", "")
@@ -76,40 +113,34 @@ async def health():
76
  "qdrant": qdrant_status,
77
  }
78
 
 
 
 
 
 
 
 
 
79
  @app.post("/query/baseline")
80
  async def query_baseline(request: QueryRequest):
81
- """
82
- Pipeline 1: LLM-Only Baseline.
83
- Performs metadata filtering and LLM synthesis.
84
- """
85
- if not request.query:
86
- raise HTTPException(status_code=400, detail="Query cannot be empty")
87
-
88
- result = pipeline_baseline.run(request.query, ground_truth=request.ground_truth)
89
- return result
90
-
91
 
92
  @app.post("/query/rag")
93
  async def query_rag(request: QueryRequest):
94
- """
95
- Pipeline 2: Hybrid RAG.
96
- Hybrid vector search + cross-encoder rerank + Groq LLM synthesis.
97
- """
98
- if not request.query:
99
- raise HTTPException(status_code=400, detail="Query cannot be empty")
100
-
101
- result = pipeline_rag.run(request.query, ground_truth=request.ground_truth)
102
- return result
103
-
104
 
105
  @app.post("/query/graph")
106
  async def query_graph(request: QueryRequest):
107
- """
108
- Pipeline 3: GraphRAG.
109
- TigerGraph vector search + topic linking + citation expansion + Groq LLM synthesis.
110
- """
111
- if not request.query:
112
- raise HTTPException(status_code=400, detail="Query cannot be empty")
113
-
114
- result = pipeline_graph.run(request.query, ground_truth=request.ground_truth)
115
- return result
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from fastapi.responses import StreamingResponse
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from contextlib import asynccontextmanager
6
  import uvicorn
 
7
  import os
8
+ import json
9
+ import asyncio
10
+ import logging
11
  from pathlib import Path
12
 
13
  import requests as _requests
 
15
  from core.pipeline_1.logic import PipelineLLMOnly
16
  from core.pipeline_2.logic import PipelineRAG
17
  from core.pipeline_3.logic import PipelineGraphRAG
18
+ from services.metrics_service import MetricsService
19
+
20
# --- SILENCE NOISY LOGGERS ---
# Raise the threshold of chatty third-party loggers: everything listed here is
# limited to WARNING and above, except transformers, which is limited to ERROR.
for _logger_name, _threshold in (
    ("httpx", logging.WARNING),
    ("httpcore", logging.WARNING),
    ("transformers", logging.ERROR),
    ("sentence_transformers", logging.WARNING),
    ("pyTigerGraph", logging.WARNING),
    ("uvicorn.access", logging.WARNING),
):
    logging.getLogger(_logger_name).setLevel(_threshold)
27
+
28
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Warm up the shared metrics service before the app starts serving.

    Modern lifespan handler (replaces ``@on_event``). ``shared_metrics.warmup``
    is run in the event loop's default executor so a slow warm-up does not
    block the loop. Nothing is done on shutdown.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # Inside a coroutine the running loop is the one to use; get_running_loop()
    # is the supported accessor, whereas get_event_loop() is the legacy one.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, shared_metrics.warmup)
    yield
34
 
35
app = FastAPI(
    title="SEC Dataset API",
    description="A simple FastAPI setup to fetch and interact with the PleIAs/SEC dataset.",
    version="1.0.0",
    lifespan=lifespan,
)

# Enable CORS for every origin, method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# very permissive policy — confirm credentialed cross-origin access is needed.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
50
 
51
  # Request model
 
53
  query: str
54
  ground_truth: str = None
55
 
56
# Initialize Shared Services
shared_metrics = MetricsService()

# Initialize Pipelines
pipeline_baseline = PipelineLLMOnly(top_n=20, max_full_text=20)
pipeline_rag = PipelineRAG(retrieval_top_k=50, rerank_top_n=10, max_full_text=3)
# The graph pipeline reuses the retriever instance owned by the RAG pipeline.
pipeline_graph = PipelineGraphRAG(
    rerank_top_n=10,
    max_full_text=3,
    retriever=pipeline_rag.retriever,
)

# Inject shared metrics: every pipeline reports to the same MetricsService.
for _pipeline in (pipeline_baseline, pipeline_rag, pipeline_graph):
    _pipeline.metrics = shared_metrics
68
+
69
@app.get("/")
async def root():
    """Return a static welcome payload reporting that the API is running."""
    return dict(message="Welcome to the SEC Dataset API", status="running")
72
 
 
73
  @app.get("/health")
74
  async def health():
75
  tg_host = os.environ.get("TG_HOST", "")
 
113
  "qdrant": qdrant_status,
114
  }
115
 
116
async def stream_pipeline(pipeline, query, ground_truth):
    """Yield a pipeline's streamed events as server-sent-event frames.

    Each event produced by ``pipeline.run_stream(query, ground_truth)`` is
    JSON-encoded and emitted as a ``data: <json>\\n\\n`` frame. The iterator is
    advanced in the event loop's default executor, so a slow pipeline step
    does not block the loop (and other requests) while it runs.

    Args:
        pipeline: Object exposing ``run_stream(query, ground_truth)``; expected
            to return a synchronous iterable of JSON-serializable events.
        query: Query forwarded to the pipeline.
        ground_truth: Ground truth forwarded to the pipeline (may be ``None``).

    Yields:
        One SSE frame per event. If anything raises, a final frame
        ``{"type": "error", "data": <message>}`` is emitted and the stream ends.
    """
    loop = asyncio.get_running_loop()
    # Sentinel handed to next() so exhaustion is signalled by a value rather
    # than StopIteration, which cannot be propagated through a Future.
    exhausted = object()
    try:
        events = iter(pipeline.run_stream(query, ground_truth))
        while True:
            # NOTE(review): pipeline steps now run on worker threads; confirm
            # the shared pipeline objects tolerate concurrent requests.
            event = await loop.run_in_executor(None, next, events, exhausted)
            if event is exhausted:
                break
            yield f"data: {json.dumps(event)}\n\n"
    except Exception as e:
        # Report the failure to the client as a final event instead of
        # dropping the connection mid-stream.
        yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n"
123
+
124
@app.post("/query/baseline")
async def query_baseline(request: QueryRequest):
    """Stream Pipeline 1 (LLM-only baseline) results as server-sent events.

    Args:
        request: Body carrying the query and an optional ground truth.

    Returns:
        A ``text/event-stream`` response fed by ``stream_pipeline``.

    Raises:
        HTTPException: 400 when the query is empty.
    """
    # Reject empty queries up front rather than opening a stream for them.
    if not request.query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    return StreamingResponse(
        stream_pipeline(pipeline_baseline, request.query, request.ground_truth),
        media_type="text/event-stream",
    )
 
 
 
 
 
 
130
 
131
@app.post("/query/rag")
async def query_rag(request: QueryRequest):
    """Stream Pipeline 2 (hybrid RAG) results as server-sent events.

    Args:
        request: Body carrying the query and an optional ground truth.

    Returns:
        A ``text/event-stream`` response fed by ``stream_pipeline``.

    Raises:
        HTTPException: 400 when the query is empty.
    """
    # Reject empty queries up front rather than opening a stream for them.
    if not request.query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    return StreamingResponse(
        stream_pipeline(pipeline_rag, request.query, request.ground_truth),
        media_type="text/event-stream",
    )
 
 
 
 
 
 
137
 
138
@app.post("/query/graph")
async def query_graph(request: QueryRequest):
    """Stream Pipeline 3 (GraphRAG) results as server-sent events.

    Args:
        request: Body carrying the query and an optional ground truth.

    Returns:
        A ``text/event-stream`` response fed by ``stream_pipeline``.

    Raises:
        HTTPException: 400 when the query is empty.
    """
    # Reject empty queries up front rather than opening a stream for them.
    if not request.query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    return StreamingResponse(
        stream_pipeline(pipeline_graph, request.query, request.ground_truth),
        media_type="text/event-stream",
    )
144
+
145
if __name__ == "__main__":
    # Run the app directly with uvicorn on all interfaces, port 8000, logging
    # only warnings and above.
    _server_options = {"host": "0.0.0.0", "port": 8000, "log_level": "warning"}
    uvicorn.run(app, **_server_options)