File size: 3,662 Bytes
42fa16e a8a31d7 010db11 42fa16e f9aca5d 42fa16e fca7a73 42fa16e a8a31d7 42fa16e f9aca5d 42fa16e a8a31d7 42fa16e a8a31d7 f9aca5d a8a31d7 010db11 a8a31d7 010db11 42fa16e f9aca5d fca7a73 f9aca5d fca7a73 a8a31d7 010db11 a8a31d7 010db11 a8a31d7 f9aca5d 010db11 f9aca5d 010db11 f9aca5d 010db11 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 | import asyncio
import json
import logging
import threading
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from src.core.config import settings
from src.core.engine import engine
from src.utils.helpers import get_clean_text
# Configure root logging at INFO so the stream lifecycle messages emitted by
# this module (disconnects, cancellations, generation errors) are visible.
logging.basicConfig(level="INFO")

# Router collecting the OpenAI-compatible endpoints defined in this module.
router = APIRouter()

# Module-scoped logger.
logger = logging.getLogger(__name__)
# Stop sequences always appended to the client-supplied ``stop`` list.
_DEFAULT_STOP_SEQUENCES = ("<|im_end|>", "<|endoftext|>", "<|file_sep|>")


def _param(data, key, default):
    """Return ``data[key]``, or *default* when the key is missing or JSON null."""
    value = data.get(key)
    return default if value is None else value


def _build_stop_list(stop):
    """Normalize the client ``stop`` field to a list and append the default stops.

    Accepts ``None`` (no client stops), a single string, or an iterable of
    strings. Raises ``TypeError`` for non-iterable values.
    """
    if stop is None:
        stops = []
    elif isinstance(stop, str):
        stops = [stop]
    else:
        stops = list(stop)
    for s in _DEFAULT_STOP_SEQUENCES:
        if s not in stops:
            stops.append(s)
    return stops


@router.post("/chat/completions")
async def chat_completions(request: Request):
    """Stream a chat completion as Server-Sent Events.

    Reads ``messages``, ``max_tokens``, ``temperature``, ``top_p`` and ``stop``
    from the JSON body; a missing or null value falls back to its default.
    Generation runs in ``worker`` via the loop's default executor, and every
    chunk yielded by ``engine.generate_stream`` is forwarded as a
    ``data: <json>`` event, followed by ``data: [DONE]`` on normal completion.
    A generation failure is reported as a single ``data: {"error": ...}`` event.

    Raises:
        HTTPException: 500 if ``engine.llm`` is not set; 400 if the body is not
            valid JSON, is not a JSON object, or a sampling parameter cannot be
            converted to the expected type.
    """
    if not engine.llm:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        data = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON") from None
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    messages = [
        {"role": m.get("role", "user"), "content": get_clean_text(m.get("content"))}
        for m in _param(data, "messages", [])
    ]

    # Validate/convert sampling parameters up front so bad input is a 400
    # rather than an error event in the middle of a 200 stream.
    try:
        gen_kwargs = {
            "max_tokens": int(_param(data, "max_tokens", settings.DEFAULT_MAX_TOKENS)),
            "temperature": float(_param(data, "temperature", settings.DEFAULT_TEMP)),
            "top_p": float(_param(data, "top_p", 0.95)),
            "stop": _build_stop_list(data.get("stop")),
        }
    except (TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=400, detail=f"Invalid sampling parameter: {exc}"
        ) from exc

    # Shared with the worker thread; setting it asks generation to stop.
    abort_event = threading.Event()
    gen_kwargs["abort_event"] = abort_event

    async def stream_generator():
        queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def worker():
            # Runs in an executor thread: hand each chunk to the event loop
            # thread-safely, then a None sentinel to mark completion.
            try:
                # Start generation
                for chunk in engine.generate_stream(messages, **gen_kwargs):
                    loop.call_soon_threadsafe(queue.put_nowait, chunk)
                loop.call_soon_threadsafe(queue.put_nowait, None)
            except Exception as e:
                # Failures after an abort are expected and not worth reporting.
                if not abort_event.is_set():
                    logger.exception("Generation error: %s", e)
                loop.call_soon_threadsafe(queue.put_nowait, {"error": str(e)})

        loop.run_in_executor(None, worker)

        # One outstanding queue.get() is kept across polls instead of wrapping
        # it in wait_for: wait_for cancels the get on every timeout, and on
        # older Python versions it could discard an item dequeued just as the
        # timeout fired (losing the None sentinel would hang the stream).
        pending_get = None
        try:
            while True:
                if await request.is_disconnected():
                    logger.info("Client disconnected! Aborting generation...")
                    abort_event.set()
                    break
                if pending_get is None:
                    pending_get = asyncio.ensure_future(queue.get())
                done, _ = await asyncio.wait({pending_get}, timeout=0.1)
                if not done:
                    continue
                chunk = pending_get.result()
                pending_get = None

                if chunk is None:
                    yield "data: [DONE]\n\n"
                    break
                if isinstance(chunk, dict) and "error" in chunk:
                    if abort_event.is_set():
                        break
                    err_json = json.dumps(
                        {"error": {"message": chunk["error"], "type": "internal_error"}}
                    )
                    yield f"data: {err_json}\n\n"
                    break
                yield f"data: {json.dumps(chunk)}\n\n"
        except asyncio.CancelledError:
            logger.info("Task cancelled. Stopping worker.")
            abort_event.set()
            raise
        finally:
            # Covers every other way this generator can end (e.g. closed via
            # aclose()/GeneratorExit) so the worker never keeps generating for
            # a consumer that is gone. Harmless after normal completion.
            abort_event.set()
            if pending_get is not None:
                pending_get.cancel()

    # Return the stream; disable proxy buffering so events flush immediately.
    headers = {
        "X-Accel-Buffering": "no",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Content-Type": "text/event-stream",
    }
    return StreamingResponse(
        stream_generator(), media_type="text/event-stream", headers=headers
    )
|