Spaces:
Sleeping
Sleeping
| """OpenAI-compatible chat completion routes.""" | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import time | |
| from typing import Literal | |
| from fastapi import APIRouter, Request | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from pydantic import BaseModel, Field | |
| from app.utils.config import settings | |
# Module-level logger, named after this module so logging config can target it.
logger = logging.getLogger(name=__name__)
# Router holding the chat endpoints; routes registered on it carry the "chat" tag.
router = APIRouter(tags=["chat"])
class ChatMessage(BaseModel):
    """One chat turn in the OpenAI message shape: a ``role`` plus text ``content``."""

    # Restricted to these three roles; any other value fails validation.
    role: Literal["system", "user", "assistant"]
    # Required plain-text body of the message.
    content: str = Field(...)
class ChatCompletionRequest(BaseModel):
    """Chat completion request body: the subset of OpenAI fields this service accepts.

    Sampling fields (``temperature``, ``top_p``, ``max_tokens``) default to
    ``None``, meaning the client did not set them.
    """

    # Falls back to the configured model name, evaluated each time the field is omitted.
    model: str = Field(default_factory=lambda: settings.model_name)
    # Conversation so far, oldest message first (required).
    messages: list[ChatMessage] = Field(...)
    # When true, the response is sent as server-sent events.
    stream: bool = Field(default=False)
    temperature: float | None = Field(default=None)
    top_p: float | None = Field(default=None)
    # If given, must be a positive integer.
    max_tokens: int | None = Field(default=None, ge=1)
| def _sse_event(payload: dict) -> str: | |
| """Format one SSE data event.""" | |
| return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" | |
async def create_chat_completion(request: Request, body: ChatCompletionRequest):
    """Serve an OpenAI-compatible chat completion, optionally streamed over SSE.

    The client's messages are passed through the app's prompt service
    (``inject_system_prompt``) before reaching the model service. Sampling
    parameters the client left unset fall back to the configured defaults.

    Args:
        request: Incoming request; used to reach ``app.state.prompt_service``
            and ``app.state.model_service``.
        body: Validated chat completion payload.

    Returns:
        When ``body.stream`` is true, a ``StreamingResponse`` of
        ``chat.completion.chunk`` SSE events terminated by ``data: [DONE]``.
        Otherwise, a ``JSONResponse`` holding one ``chat.completion`` object.
    """
    # NOTE(review): no route decorator / registration on ``router`` is visible
    # for this handler in the reviewed source; confirm it is wired up.
    prompt_service = request.app.state.prompt_service
    model_service = request.app.state.model_service

    injected_messages = prompt_service.inject_system_prompt(
        [message.model_dump() for message in body.messages]
    )
    temperature = body.temperature if body.temperature is not None else settings.default_temperature
    top_p = body.top_p if body.top_p is not None else settings.default_top_p
    max_tokens = body.max_tokens if body.max_tokens is not None else settings.default_max_tokens
    created = int(time.time())

    if body.stream:

        def make_chunk(chunk_id, delta: dict, finish_reason: str | None = None) -> dict:
            """Build one ``chat.completion.chunk`` payload for this request."""
            return {
                "id": chunk_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": body.model,
                "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
            }

        async def event_generator():
            request_id = None
            role_sent = False
            try:
                async for stream_request_id, delta in model_service.stream_text(
                    injected_messages,
                    temperature=temperature,
                    top_p=top_p,
                    max_tokens=max_tokens,
                ):
                    if not role_sent:
                        # Pin the id from the first delta and send the assistant-role
                        # chunk with it, so every chunk in the stream shares one id
                        # (as OpenAI streaming does).
                        request_id = stream_request_id
                        yield _sse_event(make_chunk(request_id, {"role": "assistant"}))
                        role_sent = True
                    yield _sse_event(make_chunk(request_id, {"content": delta}))

                if not role_sent:
                    # The model produced no deltas. Still emit a well-formed
                    # stream, using a freshly generated (not constant) id.
                    import uuid

                    request_id = f"chatcmpl-{uuid.uuid4().hex}"
                    yield _sse_event(make_chunk(request_id, {"role": "assistant"}))

                yield _sse_event(make_chunk(request_id, {}, "stop"))
                yield "data: [DONE]\n\n"
            except Exception:
                # This generator is the outermost boundary of the stream. Once the
                # response has started, the status code can no longer change, so the
                # failure is reported in-band and the stream is terminated cleanly.
                # CancelledError / GeneratorExit (e.g. client disconnect) derive
                # from BaseException and still propagate.
                logger.exception("Failed to stream completion for request %s", request_id)
                error_payload = {
                    "error": {
                        "message": "Failed to stream completion for request",
                        "type": "server_error",
                    }
                }
                yield _sse_event(error_payload)
                yield "data: [DONE]\n\n"

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Ask nginx-style proxies not to buffer the event stream.
                "X-Accel-Buffering": "no",
            },
        )

    request_id, text = await model_service.complete_text(
        injected_messages,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )
    response_payload = {
        "id": request_id,
        "object": "chat.completion",
        "created": created,
        "model": body.model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
    }
    return JSONResponse(response_payload)