# ruslanmv's picture
# VectorDB
# 215df55
from __future__ import annotations
import json
from typing import Any, Iterator, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
from starlette.responses import StreamingResponse
from ..deps import get_settings
from ..core.config import Settings
from ..services.chat_service import ChatService
# Router on which the chat endpoints in this module are registered.
router = APIRouter()
class ChatMessage(BaseModel):
    """A single chat message made of a ``role`` and its ``content``."""

    role: str
    content: str
class ChatRequest(BaseModel):
    """Request body for the chat endpoints.

    The text to answer may be supplied through any one of ``query``,
    ``question``, ``prompt`` or a list of ``messages``.
    """

    query: Optional[str] = None
    question: Optional[str] = None
    prompt: Optional[str] = None
    messages: Optional[List[ChatMessage]] = None

    def as_text(self) -> str:
        """Return the text to answer.

        Precedence is ``query``, then ``question``, then ``prompt`` (the first
        non-empty one wins).  Otherwise the content of the most recent message
        whose role is ``"user"`` (case-insensitive) is used, falling back to
        the content of the last message when no user message exists.

        Raises:
            ValueError: if none of the fields provides any text.
        """
        for candidate in (self.query, self.question, self.prompt):
            if candidate:
                return candidate

        if not self.messages:
            raise ValueError("Body must include 'query'/'question'/'prompt' or 'messages'")

        # Walk the history backwards so the newest user message is picked.
        latest_user = next(
            (msg for msg in reversed(self.messages) if msg.role.lower() == "user"),
            None,
        )
        chosen = latest_user if latest_user is not None else self.messages[-1]
        return chosen.content
class ChatResponse(BaseModel):
    """Response body of the non-streaming chat endpoints.

    ``sources`` defaults to a fresh empty list when not provided.
    """

    answer: str
    sources: List[str] = Field(default_factory=list)
@router.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest, settings: Settings = Depends(get_settings)):
    """Answer a chat request sent as a JSON body.

    The text is extracted with ``req.as_text()`` and passed to
    ``ChatService.answer_with_sources``, which is executed in a threadpool
    because the underlying client is blocking.

    Returns:
        A ``ChatResponse`` carrying the answer and its sources.

    Raises:
        HTTPException: 422 when the body carries no usable text, 403 when the
            service raises ``PermissionError``, and 502 for any other error
            raised while producing the answer.
    """
    try:
        text = req.as_text()
    except ValueError as e:
        # Chain the cause so logs show the original error, not a
        # "during handling of the above exception" secondary failure.
        raise HTTPException(status_code=422, detail=str(e)) from e

    svc = ChatService(settings)
    try:
        # run blocking client in a threadpool
        answer, sources = await run_in_threadpool(svc.answer_with_sources, text)
        return ChatResponse(answer=answer, sources=sources)
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Inference error: {e}") from e
@router.get("/chat", response_model=ChatResponse)
async def chat_get(query: str = Query(...), settings: Settings = Depends(get_settings)):
    """Answer a chat request whose text is given by the required ``query`` parameter.

    ``ChatService.answer_with_sources`` is executed in a threadpool so the
    call does not run on the event loop.

    Returns:
        A ``ChatResponse`` carrying the answer and its sources.

    Raises:
        HTTPException: 403 when the service raises ``PermissionError``, and
            502 for any other error raised while producing the answer.
    """
    svc = ChatService(settings)
    try:
        answer, sources = await run_in_threadpool(svc.answer_with_sources, query)
        return ChatResponse(answer=answer, sources=sources)
    except PermissionError as e:
        # Chain the cause so the original error stays attached in tracebacks.
        raise HTTPException(status_code=403, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Inference error: {e}") from e
def _sse_line(obj: Any) -> str:
payload = obj if isinstance(obj, str) else json.dumps(obj, ensure_ascii=False)
return f"data: {payload}\n\n"
@router.get("/chat/stream")
async def chat_stream(query: str = Query(...), settings: Settings = Depends(get_settings)):
    """
    Stream the answer to ``query`` as Server-Sent Events of token deltas.

    The synchronous streaming client is iterated in a threadpool so the event
    loop stays free. Each non-empty token is sent as ``{"delta": token}``; a
    single empty delta is sent when the service produced no tokens, followed
    by a ``[DONE]`` marker. A failure while streaming is reported in-band as
    ``{"error": ...}``.
    """
    service = ChatService(settings)

    def event_source() -> Iterator[str]:
        # Sent immediately, before any model output: a comment line padded
        # with 2048 spaces (anti-buffering), the retry delay, and a ping.
        preamble = (
            ":" + (" " * 2048) + "\n\n",
            "retry: 1500\n\n",
            "event: ping\ndata: 0\n\n",
        )
        yield from preamble

        emitted = False
        try:
            for token in service.stream_answer(query):
                if not token:
                    continue
                emitted = True
                yield _sse_line({"delta": token})
            if not emitted:
                yield _sse_line({"delta": ""})
            yield _sse_line("[DONE]")
        except GeneratorExit:
            # The consumer closed the generator; stop without emitting more.
            return
        except Exception as e:
            yield _sse_line({"error": str(e)})

    # iterate the sync generator in a threadpool (non-blocking for the loop)
    return StreamingResponse(
        iterate_in_threadpool(event_source()),
        media_type="text/event-stream; charset=utf-8",
        headers={
            "Cache-Control": "no-cache, no-transform",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
            "Content-Encoding": "identity",
        },
    )
@router.post("/chat/stream")
async def chat_stream_post(req: ChatRequest, settings: Settings = Depends(get_settings)):
    """POST variant of the streaming endpoint.

    Extracts the text with ``req.as_text()`` and delegates to ``chat_stream``,
    returning its streaming response unchanged.

    Raises:
        HTTPException: 422 when the body carries no usable text.
    """
    try:
        q = req.as_text()
    except ValueError as e:
        # Chain the cause so the original error stays attached in tracebacks.
        raise HTTPException(status_code=422, detail=str(e)) from e
    return await chat_stream(query=q, settings=settings)