"""OpenAI-compatible API request/response format handling."""

import time
import uuid
import json
import logging
from dataclasses import dataclass, field
from typing import Optional, Generator, Literal

from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)


# --- Request Models ---


class ChatMessage(BaseModel):
    """A single message in the conversation."""

    # Only the three core OpenAI roles are accepted (no "tool"/"function").
    role: Literal["system", "user", "assistant"]
    content: str


class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request."""

    model: str = Field(..., description="HuggingFace model ID")
    messages: list[ChatMessage]
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    max_tokens: Optional[int] = Field(default=512, ge=1, le=8192)
    stream: bool = False
    stop: Optional[list[str]] = None
    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    n: int = Field(default=1, ge=1, le=1)  # Only support n=1 for now
    user: Optional[str] = None


# --- Response Models ---


class ChatCompletionChoice(BaseModel):
    """A single completion choice."""

    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "content_filter"] = "stop"


class ChatCompletionUsage(BaseModel):
    """Token usage statistics."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible chat completion response."""

    id: str
    object: str = "chat.completion"
    created: int  # Unix timestamp (seconds)
    model: str
    choices: list[ChatCompletionChoice]
    usage: ChatCompletionUsage


# --- Streaming Response Models ---


class DeltaMessage(BaseModel):
    """Delta content for streaming responses."""

    role: Optional[str] = None
    content: Optional[str] = None


class StreamChoice(BaseModel):
    """A single streaming choice."""

    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None


class ChatCompletionChunk(BaseModel):
    """OpenAI-compatible streaming chunk."""

    id: str
    object: str = "chat.completion.chunk"
    created: int  # Unix timestamp (seconds)
    model: str
    choices: list[StreamChoice]


# --- Helper Functions ---


def generate_completion_id() -> str:
    """Generate a unique completion ID.

    Returns a string of the form ``chatcmpl-<24 hex chars>``, mirroring
    the OpenAI ID format.
    """
    return f"chatcmpl-{uuid.uuid4().hex[:24]}"


def create_chat_response(
    model: str,
    content: str,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
    finish_reason: str = "stop",
) -> ChatCompletionResponse:
    """Create a complete (non-streaming) chat completion response.

    Args:
        model: Model identifier echoed back in the response.
        content: Assistant message text for the single choice.
        prompt_tokens: Token count of the prompt (0 if unknown).
        completion_tokens: Token count of the completion (0 if unknown).
        finish_reason: One of "stop", "length", "content_filter";
            validated by the pydantic model.

    Returns:
        A ``ChatCompletionResponse`` with exactly one choice and a usage
        block whose total is the sum of prompt and completion tokens.
    """
    return ChatCompletionResponse(
        id=generate_completion_id(),
        created=int(time.time()),
        model=model,
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatMessage(role="assistant", content=content),
                finish_reason=finish_reason,
            )
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )


def create_stream_chunk(
    completion_id: str,
    model: str,
    content: Optional[str] = None,
    role: Optional[str] = None,
    finish_reason: Optional[str] = None,
    created: Optional[int] = None,
) -> ChatCompletionChunk:
    """Create a single streaming chunk.

    Args:
        completion_id: Shared ID tying the chunk to its completion.
        model: Model identifier echoed back in the chunk.
        content: Delta text, if any.
        role: Delta role (normally only set on the first chunk).
        finish_reason: Set on the final chunk; None otherwise.
        created: Unix timestamp to stamp on the chunk. Pass the same value
            for every chunk of one completion (as the OpenAI API does);
            defaults to the current time when omitted.

    Returns:
        A ``ChatCompletionChunk`` with a single ``StreamChoice``.
    """
    return ChatCompletionChunk(
        id=completion_id,
        created=int(time.time()) if created is None else created,
        model=model,
        choices=[
            StreamChoice(
                index=0,
                delta=DeltaMessage(role=role, content=content),
                finish_reason=finish_reason,
            )
        ],
    )


def stream_response_generator(
    model: str,
    token_generator: Generator[str, None, None],
) -> Generator[str, None, None]:
    """
    Convert a token generator to SSE-formatted streaming response.

    Emits, in order: a role chunk, one chunk per token, a final chunk
    carrying ``finish_reason="stop"``, and the ``[DONE]`` sentinel.
    All chunks share one completion ID and one ``created`` timestamp,
    matching OpenAI streaming semantics.

    Yields SSE-formatted strings ready for HTTP streaming.
    """
    completion_id = generate_completion_id()
    # One timestamp for the whole stream; per-chunk time.time() calls could
    # otherwise drift across chunk boundaries.
    created = int(time.time())

    # First chunk: role
    first_chunk = create_stream_chunk(
        completion_id=completion_id,
        model=model,
        role="assistant",
        created=created,
    )
    yield f"data: {first_chunk.model_dump_json()}\n\n"

    # Content chunks
    for token in token_generator:
        chunk = create_stream_chunk(
            completion_id=completion_id,
            model=model,
            content=token,
            created=created,
        )
        yield f"data: {chunk.model_dump_json()}\n\n"

    # Final chunk: finish reason
    final_chunk = create_stream_chunk(
        completion_id=completion_id,
        model=model,
        finish_reason="stop",
        created=created,
    )
    yield f"data: {final_chunk.model_dump_json()}\n\n"

    # End marker
    yield "data: [DONE]\n\n"


def messages_to_dicts(messages: list[ChatMessage]) -> list[dict[str, str]]:
    """Convert Pydantic ChatMessage objects to simple dicts."""
    return [{"role": msg.role, "content": msg.content} for msg in messages]


def estimate_tokens(text: str) -> int:
    """
    Rough token count estimation.

    This is a simple approximation - actual token count depends on
    the tokenizer. Rule of thumb: ~4 characters per token for English text.
    Always returns at least 1, even for empty input.
    """
    return max(1, len(text) // 4)


@dataclass
class InferenceParams:
    """Extracted inference parameters from request."""

    model_id: str
    messages: list[dict[str, str]]
    max_new_tokens: int
    temperature: float
    top_p: float
    stop_sequences: Optional[list[str]]
    stream: bool

    @classmethod
    def from_request(cls, request: ChatCompletionRequest) -> "InferenceParams":
        """Extract inference parameters from an OpenAI-compatible request.

        ``max_tokens=None`` in the request falls back to 512 new tokens.
        """
        return cls(
            model_id=request.model,
            messages=messages_to_dicts(request.messages),
            max_new_tokens=request.max_tokens or 512,
            temperature=request.temperature,
            top_p=request.top_p,
            stop_sequences=request.stop,
            stream=request.stream,
        )


# --- Error Responses ---


class ErrorDetail(BaseModel):
    """Error detail for API error responses."""

    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None


class ErrorResponse(BaseModel):
    """OpenAI-compatible error response."""

    error: ErrorDetail


def create_error_response(
    message: str,
    error_type: str = "invalid_request_error",
    param: Optional[str] = None,
    code: Optional[str] = None,
) -> ErrorResponse:
    """Create an error response.

    Args:
        message: Human-readable error description.
        error_type: OpenAI error category (default "invalid_request_error").
        param: Name of the offending request parameter, if applicable.
        code: Machine-readable error code, if applicable.
    """
    return ErrorResponse(
        error=ErrorDetail(
            message=message,
            type=error_type,
            param=param,
            code=code,
        )
    )