File size: 3,662 Bytes
42fa16e
 
a8a31d7
010db11
42fa16e
f9aca5d
 
42fa16e
 
 
fca7a73
42fa16e
a8a31d7
 
 
42fa16e
 
 
 
 
 
f9aca5d
42fa16e
a8a31d7
 
 
 
 
42fa16e
 
 
 
a8a31d7
f9aca5d
 
a8a31d7
 
 
 
 
010db11
a8a31d7
 
 
 
010db11
42fa16e
 
f9aca5d
 
fca7a73
f9aca5d
fca7a73
a8a31d7
 
 
 
 
010db11
a8a31d7
 
010db11
a8a31d7
f9aca5d
010db11
 
f9aca5d
010db11
 
f9aca5d
 
 
 
010db11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import asyncio
import json
import logging
import threading

from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse

from src.core.config import settings
from src.core.engine import engine
from src.utils.helpers import get_clean_text

# Configure root logging at import time so this module's INFO messages are shown.
# NOTE(review): logging.basicConfig does nothing if the root logger already has
# handlers (e.g. installed by the ASGI server) — confirm this is the intended place.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Router exposing the chat completion endpoint; presumably included by the app elsewhere.
router = APIRouter()


def _param_or_default(data: dict, key: str, default):
    """Return ``data[key]``, or *default* when the key is absent or JSON ``null``.

    Clients often send an explicit ``null`` for optional sampling fields; this
    treats ``null`` like "not provided". Falsy-but-valid values such as ``0``
    are returned unchanged.
    """
    value = data.get(key)
    return default if value is None else value


def _build_stop_list(raw_stop) -> list:
    """Normalize the request's ``stop`` field into a new list of stop strings.

    Accepts ``None`` (or a missing field), a single string, or a list of
    strings, then appends the fixed marker strings this endpoint always stops
    on (presumably the model's chat-template special tokens).

    Raises:
        HTTPException: 400 if ``stop`` has any other type or holds non-strings.
    """
    if raw_stop is None:
        stops = []
    elif isinstance(raw_stop, str):
        stops = [raw_stop]
    elif isinstance(raw_stop, list) and all(isinstance(s, str) for s in raw_stop):
        # Copy so the parsed request body is never mutated.
        stops = list(raw_stop)
    else:
        raise HTTPException(
            status_code=400, detail="'stop' must be a string or a list of strings"
        )
    for token in ("<|im_end|>", "<|endoftext|>", "<|file_sep|>"):
        if token not in stops:
            stops.append(token)
    return stops


@router.post("/chat/completions")
async def chat_completions(request: Request):
    """Stream a chat completion as Server-Sent Events.

    The request body must be a JSON object with ``messages`` and optional
    ``max_tokens``, ``temperature``, ``top_p`` and ``stop``. Generation runs in
    a worker thread via ``engine.generate_stream``; every chunk it yields is
    sent as a ``data: <json>`` event, followed by ``data: [DONE]`` on success
    or a single ``{"error": {...}}`` event if generation raises.

    Raises:
        HTTPException: 500 if no model is loaded; 400 if the body is not valid
            JSON, is not a JSON object, or has malformed fields.
    """
    if not engine.llm:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        data = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON") from None
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    raw_messages = _param_or_default(data, "messages", [])
    if not isinstance(raw_messages, list) or not all(
        isinstance(m, dict) for m in raw_messages
    ):
        raise HTTPException(
            status_code=400, detail="'messages' must be a list of objects"
        )
    messages = [
        {"role": m.get("role", "user"), "content": get_clean_text(m.get("content"))}
        for m in raw_messages
    ]

    # Parse sampling parameters here rather than in the worker thread, so a bad
    # value (or explicit null) yields a 400 instead of an error event inside an
    # already-started 200 stream.
    try:
        max_tokens = int(
            _param_or_default(data, "max_tokens", settings.DEFAULT_MAX_TOKENS)
        )
        temperature = float(
            _param_or_default(data, "temperature", settings.DEFAULT_TEMP)
        )
        top_p = float(_param_or_default(data, "top_p", 0.95))
    except (TypeError, ValueError, OverflowError):
        raise HTTPException(
            status_code=400,
            detail="'max_tokens', 'temperature' and 'top_p' must be numbers",
        ) from None

    stop = _build_stop_list(data.get("stop"))

    # Shared with the worker thread: set when the client goes away or the
    # stream ends, so generation can stop early.
    abort_event = threading.Event()
    gen_kwargs = {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stop": stop,
        "abort_event": abort_event,
    }

    async def stream_generator():
        queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def worker():
            # Runs in the default executor; hands chunks to the event loop via
            # call_soon_threadsafe. A None item marks normal completion.
            try:
                for chunk in engine.generate_stream(messages, **gen_kwargs):
                    if abort_event.is_set():
                        # Nobody is reading anymore; stop pulling from the engine.
                        return
                    loop.call_soon_threadsafe(queue.put_nowait, chunk)

                loop.call_soon_threadsafe(queue.put_nowait, None)
            except Exception as e:
                if not abort_event.is_set():
                    logger.exception("Generation error: %s", e)
                loop.call_soon_threadsafe(queue.put_nowait, {"error": str(e)})

        loop.run_in_executor(None, worker)

        try:
            while True:
                if await request.is_disconnected():
                    logger.info("Client disconnected! Aborting generation...")
                    abort_event.set()
                    break

                try:
                    # Short timeout so disconnects are noticed even while the
                    # engine is slow to produce the next chunk.
                    chunk = await asyncio.wait_for(queue.get(), timeout=0.1)
                except asyncio.TimeoutError:
                    continue

                if chunk is None:
                    yield "data: [DONE]\n\n"
                    break

                if isinstance(chunk, dict) and "error" in chunk:
                    if abort_event.is_set():
                        break
                    err_json = json.dumps(
                        {"error": {"message": chunk["error"], "type": "internal_error"}}
                    )
                    yield f"data: {err_json}\n\n"
                    break

                yield f"data: {json.dumps(chunk)}\n\n"

        except asyncio.CancelledError:
            logger.info("Task cancelled. Stopping worker.")
            raise
        finally:
            # Covers every exit path, including GeneratorExit when the stream is
            # closed early; harmless once the worker has already finished.
            abort_event.set()

    # Return the SSE stream; disable proxy buffering so chunks flush promptly.
    headers = {
        "X-Accel-Buffering": "no",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Content-Type": "text/event-stream",
    }
    return StreamingResponse(
        stream_generator(), media_type="text/event-stream", headers=headers
    )