remdms commited on
Commit
4bb26eb
·
1 Parent(s): 475c2dd

feat(api): replace SSE chat with sync POST /api/search

Browse files
src/mediastorm/api.py CHANGED
@@ -6,16 +6,13 @@ from typing import Any
6
  import httpx
7
  from fastapi import FastAPI
8
  from pydantic import BaseModel
9
- from sse_starlette.sse import EventSourceResponse
10
-
11
  from mediastorm.config import CHROMADB_PATH, BM25_INDEX_PATH, INGEST_CACHE_DIR
12
  from mediastorm.vectorize.store import VectorStore
13
  from mediastorm.vectorize.embedder import Embedder
14
  from mediastorm.vectorize.bm25_store import BM25Store
15
  from mediastorm.rag.router import QueryRouter
16
  from mediastorm.rag.retriever import HybridRetriever
17
- from mediastorm.rag.generator import generate_response_stream
18
- from mediastorm.rag.rewriter import rewrite_if_needed
19
 
20
  store: VectorStore | None = None
21
  embedder: Embedder | None = None
@@ -155,10 +152,8 @@ async def lifespan(app: FastAPI):
155
  app = FastAPI(title="MediaStorm Archive Explorer", lifespan=lifespan)
156
 
157
 
158
- class ChatRequest(BaseModel):
159
  query: str
160
- history: list[dict] = []
161
- filters: dict = {}
162
 
163
 
164
  def _extract_story_cards(stories: list[dict]) -> list[dict]:
@@ -202,44 +197,19 @@ async def stories():
202
  return {"stories": enriched}
203
 
204
 
205
- @app.post("/api/chat")
206
- async def chat(req: ChatRequest):
207
- async def event_generator():
208
- # Convert history to Gemini format
209
- chat_history = []
210
- for msg in (req.history or []):
211
- role = msg.get("role", "user")
212
- content = msg.get("content", "")
213
- if role == "user":
214
- chat_history.append({"role": "user", "parts": [{"text": content}]})
215
- else:
216
- chat_history.append({"role": "model", "parts": [{"text": content}]})
217
-
218
- # Rewrite query for retrieval if it's a follow-up with implicit references
219
- rewrite_history = req.history or []
220
- search_query = await rewrite_if_needed(req.query, rewrite_history)
221
-
222
- ui_filters = req.filters if req.filters else None
223
- result = await retriever.retrieve(search_query, ui_filters=ui_filters)
224
-
225
- all_cards = _extract_story_cards(result.stories)
226
-
227
- # Stream text first, accumulate full response
228
- full_text = ""
229
- async for text in generate_response_stream(
230
- req.query, result, chat_history or None, link_lookup=_link_lookup,
231
- ):
232
- full_text = text
233
- yield {"event": "text", "data": json.dumps({"content": text})}
234
-
235
- # Only show cards for stories Gemini actually cited (by link URL in text)
236
- cited_cards = [
237
- c for c in all_cards
238
- if c.get("link") and c["link"] in full_text
239
- ]
240
- if cited_cards:
241
- yield {"event": "stories", "data": json.dumps({"stories": cited_cards})}
242
-
243
- yield {"event": "done", "data": "{}"}
244
-
245
- return EventSourceResponse(event_generator())
 
6
  import httpx
7
  from fastapi import FastAPI
8
  from pydantic import BaseModel
 
 
9
  from mediastorm.config import CHROMADB_PATH, BM25_INDEX_PATH, INGEST_CACHE_DIR
10
  from mediastorm.vectorize.store import VectorStore
11
  from mediastorm.vectorize.embedder import Embedder
12
  from mediastorm.vectorize.bm25_store import BM25Store
13
  from mediastorm.rag.router import QueryRouter
14
  from mediastorm.rag.retriever import HybridRetriever
15
+ from mediastorm.rag.generator import generate_response
 
16
 
17
  store: VectorStore | None = None
18
  embedder: Embedder | None = None
 
152
  app = FastAPI(title="MediaStorm Archive Explorer", lifespan=lifespan)
153
 
154
 
155
+ class SearchRequest(BaseModel):
156
  query: str
 
 
157
 
158
 
159
  def _extract_story_cards(stories: list[dict]) -> list[dict]:
 
197
  return {"stories": enriched}
198
 
199
 
200
+ @app.post("/api/search")
201
+ async def search(req: SearchRequest):
202
+ result = await retriever.retrieve(req.query)
203
+ all_cards = _extract_story_cards(result.stories)
204
+
205
+ # Use Gemini as relevance filter — generate text, keep only cited stories
206
+ full_text = await generate_response(
207
+ req.query, result, link_lookup=_link_lookup,
208
+ )
209
+
210
+ cited_cards = [
211
+ c for c in all_cards
212
+ if c.get("link") and c["link"] in full_text
213
+ ]
214
+
215
+ return {"stories": cited_cards}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/mediastorm/rag/retriever.py CHANGED
@@ -21,7 +21,7 @@ class HybridRetriever:
21
  embedder: Embedder,
22
  router: QueryRouter,
23
  top_k_retrieval: int = 50,
24
- top_k_final: int = 4,
25
  ):
26
  self._store = vector_store
27
  self._bm25 = bm25_store
 
21
  embedder: Embedder,
22
  router: QueryRouter,
23
  top_k_retrieval: int = 50,
24
+ top_k_final: int = 8,
25
  ):
26
  self._store = vector_store
27
  self._bm25 = bm25_store
tests/test_api.py CHANGED
@@ -6,13 +6,27 @@ from mediastorm.api import app
6
 
7
 
8
  @pytest.mark.asyncio
9
- async def test_chat_requires_query():
10
  transport = ASGITransport(app=app)
11
  async with AsyncClient(transport=transport, base_url="http://test") as client:
12
- resp = await client.post("/api/chat", json={})
13
  assert resp.status_code == 422
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  @pytest.mark.asyncio
17
  async def test_health():
18
  transport = ASGITransport(app=app)
 
6
 
7
 
8
  @pytest.mark.asyncio
9
+ async def test_search_requires_query():
10
  transport = ASGITransport(app=app)
11
  async with AsyncClient(transport=transport, base_url="http://test") as client:
12
+ resp = await client.post("/api/search", json={})
13
  assert resp.status_code == 422
14
 
15
 
16
+ @pytest.mark.asyncio
17
+ async def test_search_returns_stories():
18
+ """Integration test — requires GEMINI_API_KEY and populated ChromaDB."""
19
+ if not os.environ.get("GEMINI_API_KEY"):
20
+ pytest.skip("GEMINI_API_KEY not set")
21
+ transport = ASGITransport(app=app)
22
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
23
+ resp = await client.post("/api/search", json={"query": "Emmy winning documentaries"})
24
+ assert resp.status_code == 200
25
+ data = resp.json()
26
+ assert "stories" in data
27
+ assert isinstance(data["stories"], list)
28
+
29
+
30
  @pytest.mark.asyncio
31
  async def test_health():
32
  transport = ASGITransport(app=app)