| import asyncio |
| import json |
|
|
| from fastapi import APIRouter, Request |
| from fastapi.responses import JSONResponse, StreamingResponse |
|
|
| from src.core.config import settings |
| from src.core.engine import engine |
| from src.utils.helpers import get_clean_text |
|
|
# Router that collects the routes declared in this module via ``@router.<method>``
# decorators; presumably mounted by the application with ``include_router`` — TODO confirm.
router: APIRouter = APIRouter()
|
|
|
|
def _read_param(data: dict, key: str, default, convert: type):
    """Return ``convert(data[key])``, or ``convert(default)`` when *key* is absent or null.

    An explicit JSON ``null`` is treated like a missing key, so a client sending
    e.g. ``"max_tokens": null`` gets the configured default instead of
    ``int(None)`` raising.

    Raises:
        TypeError, ValueError: if the value cannot be converted by *convert*.
    """
    value = data.get(key)
    return convert(default if value is None else value)


async def _wait_until_done(future: asyncio.Future) -> None:
    """Wait for *future* to finish even if the current task is cancelled meanwhile.

    Used to keep ``engine.lock`` held until a worker thread has actually stopped
    touching the model. A plain ``await future`` in cleanup code is not enough:
    a repeated cancellation (e.g. level-triggered cancel scopes as in anyio,
    which re-deliver cancellation on every await) would interrupt it and let the
    lock be released while the thread is still running. Any cancellation
    received while waiting is re-raised once the future is done. *future* itself
    is never cancelled here; its result or exception is left for the caller.
    """
    cancelled = False
    while not future.done():
        try:
            # asyncio.wait() does not cancel the awaited future when we are cancelled.
            await asyncio.wait({future})
        except asyncio.CancelledError:
            cancelled = True
    if cancelled:
        raise asyncio.CancelledError


@router.post("/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-style chat completion endpoint.

    Request body (JSON object):
        messages: list of objects; ``role`` defaults to ``"user"`` and ``content``
            is passed through ``get_clean_text``.
        stream: stream Server-Sent Events when truthy; defaults to ``True`` here.
        max_tokens / temperature: missing or null values fall back to
            ``settings.DEFAULT_MAX_TOKENS`` / ``settings.DEFAULT_TEMP``.

    Returns:
        - 500 JSON error if ``engine.llm`` is not loaded.
        - 400 JSON error for a malformed body or non-numeric generation parameters.
        - A ``text/event-stream`` response of ``data: <json>`` chunks terminated
          by ``data: [DONE]`` when streaming (an error is sent as
          ``data: {"error": ...}`` before ``[DONE]``).
        - Otherwise whatever ``engine.generate(..., stream=False)`` returns.

    Model access is serialised with ``engine.lock``; the lock is held until the
    generating thread has finished, even if the client disconnects.
    """
    import threading

    if not engine.llm:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)

    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError are ValueErrors
        return JSONResponse({"error": "Request body must be valid JSON"}, status_code=400)
    if not isinstance(data, dict):
        return JSONResponse({"error": "Request body must be a JSON object"}, status_code=400)

    raw_messages = data.get("messages") or []
    if not isinstance(raw_messages, list) or not all(
        isinstance(m, dict) for m in raw_messages
    ):
        return JSONResponse(
            {"error": "'messages' must be a list of objects"}, status_code=400
        )
    messages = [
        {"role": m.get("role", "user"), "content": get_clean_text(m.get("content"))}
        for m in raw_messages
    ]

    # Parse once, up front, so both paths see the same types and bad input is a 400.
    try:
        max_tokens = _read_param(data, "max_tokens", settings.DEFAULT_MAX_TOKENS, int)
        temperature = _read_param(data, "temperature", settings.DEFAULT_TEMP, float)
    except (TypeError, ValueError) as exc:
        return JSONResponse(
            {"error": f"Invalid generation parameter: {exc}"}, status_code=400
        )

    stream = data.get("stream", True)

    async def stream_generator():
        """Yield SSE lines produced by a blocking generator running in a thread."""
        async with engine.lock:
            loop = asyncio.get_running_loop()
            queue: asyncio.Queue = asyncio.Queue()
            stop_event = threading.Event()

            def worker():
                # Runs in an executor thread. Chunks are handed to the event loop;
                # an exception is forwarded as the exception object itself, and
                # ``None`` is always enqueued last to mark the end of the stream.
                try:
                    for chunk in engine.llm.create_chat_completion(
                        messages=messages,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        stream=True,
                    ):
                        # Set by the consumer when it stops reading (e.g. disconnect).
                        if stop_event.is_set():
                            break
                        loop.call_soon_threadsafe(queue.put_nowait, chunk)
                except Exception as exc:
                    loop.call_soon_threadsafe(queue.put_nowait, exc)
                finally:
                    loop.call_soon_threadsafe(queue.put_nowait, None)

            worker_future = loop.run_in_executor(None, worker)
            try:
                while True:
                    item = await queue.get()
                    if item is None:
                        break
                    # Detect errors by type, not by message truthiness, so an
                    # exception with an empty message is still reported as an error.
                    if isinstance(item, Exception):
                        message = str(item) or type(item).__name__
                        yield f"data: {json.dumps({'error': message})}\n\n"
                        break
                    yield f"data: {json.dumps(item)}\n\n"
                yield "data: [DONE]\n\n"
            finally:
                # Runs on normal completion, error, cancellation and aclose():
                # ask the thread to stop, then keep the lock until it really has.
                stop_event.set()
                await _wait_until_done(worker_future)

    if stream:
        return StreamingResponse(stream_generator(), media_type="text/event-stream")

    async with engine.lock:
        job = asyncio.ensure_future(
            asyncio.to_thread(
                engine.generate,
                messages,
                max_tokens,
                temperature,
                stream=False,
            )
        )
        # Hold the lock until the thread finishes, even if this request is cancelled.
        await _wait_until_done(job)
        return job.result()
|
|