| import asyncio |
| import json |
| import logging |
| import threading |
|
|
| from fastapi import APIRouter, HTTPException, Request |
| from fastapi.responses import StreamingResponse |
|
|
| from src.core.config import settings |
| from src.core.engine import engine |
| from src.utils.helpers import get_clean_text |
|
|
# FastAPI router that owns the endpoint(s) defined in this module.
router = APIRouter()

# Configure the root logger with an INFO threshold and obtain this module's
# named logger.
# NOTE(review): basicConfig at import time affects the whole process, and it is
# a no-op if the root logger already has handlers.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
# Stop sequences always appended to the caller's list (special end-of-turn /
# end-of-text markers; presumably matching the loaded model's chat template —
# confirm against the engine/model).
_DEFAULT_STOPS = ("<|im_end|>", "<|endoftext|>", "<|file_sep|>")

# Seconds to wait for the next chunk before re-checking for a client disconnect.
_POLL_INTERVAL = 0.1


def _parse_messages(raw):
    """Normalise the request's ``messages`` field into ``{"role", "content"}`` dicts.

    ``None`` (missing field or JSON ``null``) yields an empty list. Each message's
    content is passed through ``get_clean_text``; a missing role defaults to
    ``"user"``.

    Raises:
        HTTPException: 400 if ``raw`` is not a list or contains a non-object item.
    """
    if raw is None:
        return []
    if not isinstance(raw, list):
        raise HTTPException(status_code=400, detail="'messages' must be a list")
    messages = []
    for m in raw:
        if not isinstance(m, dict):
            raise HTTPException(
                status_code=400, detail="Each message must be a JSON object"
            )
        messages.append(
            {"role": m.get("role", "user"), "content": get_clean_text(m.get("content"))}
        )
    return messages


def _parse_stop(raw):
    """Build the stop-sequence list: the caller's ``stop`` plus ``_DEFAULT_STOPS``.

    Accepts ``None`` (missing or JSON ``null``), a single string, or a list of
    strings. Returns a new list, so the request payload is not mutated. Each
    default is appended only if not already present.

    Raises:
        HTTPException: 400 for any other shape.
    """
    if raw is None:
        stop = []
    elif isinstance(raw, str):
        stop = [raw]
    elif isinstance(raw, list) and all(isinstance(s, str) for s in raw):
        stop = list(raw)
    else:
        raise HTTPException(
            status_code=400, detail="'stop' must be a string or a list of strings"
        )
    for s in _DEFAULT_STOPS:
        if s not in stop:
            stop.append(s)
    return stop


def _parse_number(data, key, default, cast):
    """Read ``data[key]`` and convert it with ``cast`` (``int`` or ``float``).

    A missing key or an explicit JSON ``null`` falls back to ``default``.

    Raises:
        HTTPException: 400 if the value cannot be converted.
    """
    value = data.get(key)
    if value is None:
        value = default
    try:
        return cast(value)
    except (TypeError, ValueError, OverflowError):
        raise HTTPException(
            status_code=400, detail=f"Invalid value for '{key}'"
        ) from None


@router.post("/chat/completions")
async def chat_completions(request: Request):
    """Stream a chat completion as Server-Sent Events.

    Reads an OpenAI-style JSON body (``messages``, ``max_tokens``,
    ``temperature``, ``top_p``, ``stop``). It runs ``engine.generate_stream`` in
    a worker thread and forwards each chunk as a ``data: <json>`` event. A
    normal end is signalled with ``data: [DONE]``. A generation failure is sent
    as a single ``{"error": {"message": ..., "type": "internal_error"}}`` event.

    Raises:
        HTTPException: 500 if the model is not loaded; 400 if the body is not
            valid JSON, is not a JSON object, or has a malformed parameter.
    """
    if not engine.llm:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        data = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON") from None
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    # Validate everything up front. Malformed input then becomes a 400 instead
    # of an error event in the middle of an already-started 200 stream.
    messages = _parse_messages(data.get("messages"))
    gen_kwargs = {
        "max_tokens": _parse_number(
            data, "max_tokens", settings.DEFAULT_MAX_TOKENS, int
        ),
        "temperature": _parse_number(data, "temperature", settings.DEFAULT_TEMP, float),
        "top_p": _parse_number(data, "top_p", 0.95, float),
        "stop": _parse_stop(data.get("stop")),
    }

    # Shared with the worker thread and set to request early termination.
    # Presumably polled inside engine.generate_stream — confirm in the engine.
    abort_event = threading.Event()
    gen_kwargs["abort_event"] = abort_event

    async def stream_generator():
        queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def emit(item):
            # asyncio.Queue is not thread-safe, so hand the item to the loop.
            loop.call_soon_threadsafe(queue.put_nowait, item)

        def worker():
            # Runs in the loop's default executor. None marks normal completion.
            try:
                for chunk in engine.generate_stream(messages, **gen_kwargs):
                    emit(chunk)
                emit(None)
            except Exception as e:
                # Failures after an abort are expected (consumer is gone).
                if not abort_event.is_set():
                    logger.exception("Generation error: %s", e)
                    emit({"error": str(e)})

        loop.run_in_executor(None, worker)

        # Keep a single pending queue.get() across polls. This avoids
        # cancelling and recreating it every interval, which wait_for would do.
        pending_get = None
        try:
            while True:
                if await request.is_disconnected():
                    logger.info("Client disconnected! Aborting generation...")
                    break

                if pending_get is None:
                    pending_get = asyncio.create_task(queue.get())
                done, _ = await asyncio.wait({pending_get}, timeout=_POLL_INTERVAL)
                if not done:
                    continue
                chunk = pending_get.result()
                pending_get = None

                if chunk is None:
                    yield "data: [DONE]\n\n"
                    break

                if isinstance(chunk, dict) and "error" in chunk:
                    if not abort_event.is_set():
                        err_json = json.dumps(
                            {"error": {"message": chunk["error"], "type": "internal_error"}}
                        )
                        yield f"data: {err_json}\n\n"
                    break

                yield f"data: {json.dumps(chunk)}\n\n"
        except asyncio.CancelledError:
            logger.info("Task cancelled. Stopping worker.")
            raise
        finally:
            # Runs on every exit path: normal end, disconnect, cancellation, and
            # aclose()/GeneratorExit when the response is torn down. The worker
            # is therefore always told to stop; this is harmless if it has
            # already finished.
            abort_event.set()
            if pending_get is not None:
                pending_get.cancel()

    headers = {
        # Ask reverse proxies (e.g. nginx) not to buffer the event stream.
        "X-Accel-Buffering": "no",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Content-Type": "text/event-stream",
    }
    return StreamingResponse(
        stream_generator(), media_type="text/event-stream", headers=headers
    )
|
|