File size: 4,315 Bytes
6338f31
 
 
215df55
6338f31
d62044b
6338f31
215df55
0e040d4
 
fed7eb0
 
d62044b
fed7eb0
 
 
6338f31
d62044b
 
 
 
6338f31
d62044b
 
 
 
 
 
 
215df55
 
 
 
 
 
d62044b
 
215df55
 
0e040d4
d62044b
 
6338f31
d62044b
 
6338f31
 
d62044b
fed7eb0
d62044b
fed7eb0
d62044b
 
 
 
4740c16
215df55
6338f31
 
4740c16
 
 
 
d62044b
6338f31
d62044b
 
 
4740c16
6338f31
 
4740c16
 
 
 
0e040d4
6338f31
0e040d4
 
 
 
6338f31
0e040d4
 
215df55
 
 
 
0e040d4
 
215df55
 
0e040d4
215df55
0e040d4
215df55
 
0e040d4
215df55
0e040d4
 
 
 
 
 
215df55
 
0e040d4
6338f31
0e040d4
 
 
6338f31
0e040d4
6338f31
0e040d4
215df55
 
 
 
 
 
0e040d4
6338f31
0e040d4
 
 
 
 
 
215df55
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import annotations

import json
from typing import Any, Iterator, List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
from starlette.responses import StreamingResponse

from ..deps import get_settings
from ..core.config import Settings
from ..services.chat_service import ChatService

# Router holding this module's /chat endpoints; presumably mounted on the app
# via include_router elsewhere — confirm.
router = APIRouter()


class ChatMessage(BaseModel):
    """One message of a chat transcript sent by the client."""

    # Speaker of the message; ChatRequest.as_text treats "user" (any letter case)
    # as the user's turn.
    role: str
    # Message text.
    content: str


class ChatRequest(BaseModel):
    """Chat request body accepting several common client payload shapes.

    The user's text may arrive under ``query``, ``question`` or ``prompt``,
    or as a chat-style ``messages`` list.
    """

    query: Optional[str] = None
    question: Optional[str] = None
    prompt: Optional[str] = None
    messages: Optional[List[ChatMessage]] = None

    def as_text(self) -> str:
        """Return the user's text from the first populated field.

        Precedence is ``query``, then ``question``, then ``prompt``; empty
        strings are skipped. Failing those, the content of the most recent
        message whose role is ``"user"`` (case-insensitive) is returned, or
        the last message's content when no user message exists.

        Raises:
            ValueError: if no field carries usable input (including an empty
                ``messages`` list).
        """
        for candidate in (self.query, self.question, self.prompt):
            if candidate:
                return candidate

        history = self.messages
        if not history:
            raise ValueError("Body must include 'query'/'question'/'prompt' or 'messages'")

        user_turns = [m.content for m in history if m.role.lower() == "user"]
        return user_turns[-1] if user_turns else history[-1].content


class ChatResponse(BaseModel):
    """Response body of the non-streaming /chat endpoints."""

    # Answer text as returned by ChatService.answer_with_sources.
    answer: str
    # Sources returned alongside the answer (presumably document ids/URLs —
    # confirm); defaults to a fresh empty list.
    sources: List[str] = Field(default_factory=list)


async def _answer_via_service(settings: Settings, text: str) -> ChatResponse:
    """Answer ``text`` with ChatService and translate failures into HTTP errors.

    ``ChatService.answer_with_sources`` is a blocking call, so it runs in
    Starlette's threadpool to keep the event loop free. Building the response
    model stays inside the ``try`` so a malformed service result is also
    reported as an upstream (502) failure.

    Raises:
        HTTPException: 403 if the service raises PermissionError, 502 for any
            other exception raised while answering.
    """
    svc = ChatService(settings)
    try:
        answer, sources = await run_in_threadpool(svc.answer_with_sources, text)
        return ChatResponse(answer=answer, sources=sources)
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Inference error: {e}") from e


@router.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest, settings: Settings = Depends(get_settings)):
    """Answer a chat request supplied as a JSON body.

    The prompt is taken from the body via ``ChatRequest.as_text``.

    Raises:
        HTTPException: 422 if the body has no usable prompt; 403 / 502 when
            the chat service fails.
    """
    try:
        text = req.as_text()
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    return await _answer_via_service(settings, text)


@router.get("/chat", response_model=ChatResponse)
async def chat_get(query: str = Query(...), settings: Settings = Depends(get_settings)):
    """Answer a chat request supplied as the ``query`` URL parameter.

    Raises:
        HTTPException: 403 / 502 when the chat service fails.
    """
    return await _answer_via_service(settings, query)


def _sse_line(obj: Any) -> str:
    payload = obj if isinstance(obj, str) else json.dumps(obj, ensure_ascii=False)
    return f"data: {payload}\n\n"


@router.get("/chat/stream")
async def chat_stream(query: str = Query(...), settings: Settings = Depends(get_settings)):
    """Stream the answer to ``query`` as Server-Sent Events.

    The stream opens with a 2 KiB comment line (padding against proxy
    buffering), a ``retry: 1500`` reconnection hint and a ``ping`` event.
    Every non-empty token from ``ChatService.stream_answer`` is then sent as
    ``{"delta": <token>}``; if no token was produced a single empty delta is
    sent instead, and the stream finishes with ``[DONE]``. If the service
    raises, an ``{"error": <message>}`` event is sent and the stream ends
    without ``[DONE]``.

    The service iterator is synchronous, so the generator is driven from
    Starlette's threadpool to keep the event loop free.
    """
    svc = ChatService(settings)
    preamble = (
        ":" + " " * 2048 + "\n\n",   # comment-only padding to flush proxy buffers
        "retry: 1500\n\n",           # client reconnect delay, in milliseconds
        "event: ping\ndata: 0\n\n",  # immediate first event
    )

    def sync_stream() -> Iterator[str]:
        yield from preamble

        produced_text = False
        try:
            for piece in svc.stream_answer(query):
                if not piece:
                    continue
                produced_text = True
                yield _sse_line({"delta": piece})
            if not produced_text:
                yield _sse_line({"delta": ""})
            yield _sse_line("[DONE]")
        except GeneratorExit:
            # The generator is being closed (consumer stopped reading); stop quietly.
            return
        except Exception as exc:
            yield _sse_line({"error": str(exc)})

    return StreamingResponse(
        iterate_in_threadpool(sync_stream()),
        media_type="text/event-stream; charset=utf-8",
        headers={
            "Cache-Control": "no-cache, no-transform",
            "X-Accel-Buffering": "no",  # ask nginx not to buffer the response
            "Connection": "keep-alive",
            "Content-Encoding": "identity",
        },
    )


@router.post("/chat/stream")
async def chat_stream_post(req: ChatRequest, settings: Settings = Depends(get_settings)):
    """POST variant of ``/chat/stream`` that reads the prompt from a JSON body.

    The text is extracted with ``ChatRequest.as_text`` and then streamed
    exactly as the GET endpoint does.

    Raises:
        HTTPException: 422 when the body carries no usable prompt.
    """
    try:
        q = req.as_text()
    except ValueError as e:
        # Chain the cause so the original validation error stays in tracebacks.
        raise HTTPException(status_code=422, detail=str(e)) from e
    # Calling the GET handler directly bypasses FastAPI's Query/Depends
    # resolution, so both arguments are passed explicitly.
    return await chat_stream(query=q, settings=settings)