Spaces:
Sleeping
Sleeping
| """OpenAI-compatible API request/response format handling.""" | |
| import time | |
| import uuid | |
| import json | |
| import logging | |
| from dataclasses import dataclass, field | |
| from typing import Optional, Generator, Literal | |
| from pydantic import BaseModel, Field | |
| logger = logging.getLogger(__name__) | |
| # --- Request Models --- | |
class ChatMessage(BaseModel):
    """A single message in the conversation (OpenAI chat schema)."""

    role: Literal["system", "user", "assistant"]  # message author
    content: str  # message text
class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request.

    Field ranges mirror the OpenAI API; out-of-range values are rejected
    by Pydantic validation before any inference runs.
    """

    model: str = Field(..., description="HuggingFace model ID")
    messages: list[ChatMessage]  # full conversation to complete
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)  # sampling temperature
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)  # nucleus-sampling cutoff
    max_tokens: Optional[int] = Field(default=512, ge=1, le=8192)  # completion length cap
    stream: bool = False  # True -> SSE streaming response
    stop: Optional[list[str]] = None  # sequences that terminate generation
    # Accepted for API compatibility; not forwarded by InferenceParams.from_request.
    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    n: int = Field(default=1, ge=1, le=1)  # Only support n=1 for now
    user: Optional[str] = None  # optional end-user identifier (unused here)
| # --- Response Models --- | |
class ChatCompletionChoice(BaseModel):
    """A single completion choice within a response."""

    index: int  # choice position (always 0 since only n=1 is supported)
    message: ChatMessage  # the assistant's full reply
    finish_reason: Literal["stop", "length", "content_filter"] = "stop"
class ChatCompletionUsage(BaseModel):
    """Token usage statistics for one completion."""

    prompt_tokens: int  # tokens consumed by the prompt
    completion_tokens: int  # tokens generated
    total_tokens: int  # prompt_tokens + completion_tokens
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible (non-streaming) chat completion response."""

    id: str  # "chatcmpl-..." identifier
    object: str = "chat.completion"  # fixed OpenAI object tag
    created: int  # Unix timestamp (seconds)
    model: str  # model that produced the completion
    choices: list[ChatCompletionChoice]  # single-element list (n=1 only)
    usage: ChatCompletionUsage
| # --- Streaming Response Models --- | |
class DeltaMessage(BaseModel):
    """Incremental content for streaming responses; unset fields are None."""

    role: Optional[str] = None  # sent only in the first chunk
    content: Optional[str] = None  # token text in content chunks
class StreamChoice(BaseModel):
    """A single choice within a streaming chunk."""

    index: int  # choice position (always 0 here)
    delta: DeltaMessage  # incremental payload for this chunk
    finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None  # set only on final chunk
class ChatCompletionChunk(BaseModel):
    """OpenAI-compatible streaming chunk (one SSE event's JSON payload)."""

    id: str  # same completion id across all chunks of one response
    object: str = "chat.completion.chunk"  # fixed OpenAI object tag
    created: int  # Unix timestamp (seconds)
    model: str
    choices: list[StreamChoice]
| # --- Helper Functions --- | |
def generate_completion_id() -> str:
    """Return a unique identifier in the OpenAI ``chatcmpl-<24 hex>`` form."""
    suffix = uuid.uuid4().hex[:24]
    return "chatcmpl-" + suffix
def create_chat_response(
    model: str,
    content: str,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
    finish_reason: str = "stop",
) -> ChatCompletionResponse:
    """Assemble a complete (non-streaming) chat completion response.

    Args:
        model: Model identifier echoed back in the response.
        content: Assistant reply text for the single choice.
        prompt_tokens: Token count of the prompt (0 if unknown).
        completion_tokens: Token count of the completion (0 if unknown).
        finish_reason: Why generation stopped; validated against the
            ChatCompletionChoice literal set.
    """
    usage = ChatCompletionUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    only_choice = ChatCompletionChoice(
        index=0,
        message=ChatMessage(role="assistant", content=content),
        finish_reason=finish_reason,
    )
    return ChatCompletionResponse(
        id=generate_completion_id(),
        created=int(time.time()),
        model=model,
        choices=[only_choice],
        usage=usage,
    )
def create_stream_chunk(
    completion_id: str,
    model: str,
    content: Optional[str] = None,
    role: Optional[str] = None,
    finish_reason: Optional[str] = None,
) -> ChatCompletionChunk:
    """Build one streaming chunk for the given completion.

    Callers typically set exactly one of the optional fields per chunk:
    ``role`` for the opening chunk, ``content`` for token chunks, and
    ``finish_reason`` for the closing chunk.
    """
    delta = DeltaMessage(role=role, content=content)
    only_choice = StreamChoice(index=0, delta=delta, finish_reason=finish_reason)
    return ChatCompletionChunk(
        id=completion_id,
        created=int(time.time()),
        model=model,
        choices=[only_choice],
    )
def stream_response_generator(
    model: str,
    token_generator: Generator[str, None, None],
) -> Generator[str, None, None]:
    """Convert a token generator into SSE-formatted streaming events.

    Emits, in order: a role chunk, one content chunk per token, a final
    chunk carrying ``finish_reason="stop"``, and the ``[DONE]`` marker.
    All chunks share one completion id.
    """
    completion_id = generate_completion_id()

    def as_sse(chunk: ChatCompletionChunk) -> str:
        # SSE framing: "data: <json>" terminated by a blank line.
        return f"data: {chunk.model_dump_json()}\n\n"

    # Opening chunk announces the assistant role.
    yield as_sse(create_stream_chunk(completion_id=completion_id, model=model, role="assistant"))

    # One chunk per generated token.
    for piece in token_generator:
        yield as_sse(create_stream_chunk(completion_id=completion_id, model=model, content=piece))

    # Closing chunk carries the finish reason, then the end marker.
    yield as_sse(create_stream_chunk(completion_id=completion_id, model=model, finish_reason="stop"))
    yield "data: [DONE]\n\n"
def messages_to_dicts(messages: list[ChatMessage]) -> list[dict[str, str]]:
    """Convert Pydantic ChatMessage objects into plain role/content dicts."""
    return [{"role": message.role, "content": message.content} for message in messages]
def estimate_tokens(text: str) -> int:
    """Roughly estimate the token count of *text*.

    Uses the ~4-characters-per-token rule of thumb for English text; the
    true count depends on the tokenizer. Always returns at least 1.
    """
    approx = len(text) // 4
    return approx if approx > 0 else 1
@dataclass
class InferenceParams:
    """Inference parameters extracted from an OpenAI-compatible request.

    BUG FIX: the class had annotation-only fields but no ``@dataclass``
    decorator, so ``cls(model_id=..., ...)`` in ``from_request`` raised
    TypeError and no attributes were ever set; ``from_request`` also took
    ``cls`` without ``@classmethod``, so calling it on the class bound the
    request object to ``cls``. Both decorators are now applied.
    """

    model_id: str  # HuggingFace model ID from the request
    messages: list[dict[str, str]]  # plain role/content dicts
    max_new_tokens: int  # generation budget
    temperature: float
    top_p: float
    stop_sequences: Optional[list[str]]  # None when the request set no stops
    stream: bool  # whether the caller asked for SSE streaming

    @classmethod
    def from_request(cls, request: "ChatCompletionRequest") -> "InferenceParams":
        """Extract inference parameters from a validated request.

        Falls back to 512 new tokens when ``max_tokens`` is None.
        """
        return cls(
            model_id=request.model,
            messages=messages_to_dicts(request.messages),
            max_new_tokens=request.max_tokens or 512,
            temperature=request.temperature,
            top_p=request.top_p,
            stop_sequences=request.stop,
            stream=request.stream,
        )
| # --- Error Responses --- | |
class ErrorDetail(BaseModel):
    """Error detail for API error responses (OpenAI error schema)."""

    message: str  # human-readable description
    type: str  # error category, e.g. "invalid_request_error"
    param: Optional[str] = None  # offending request parameter, if any
    code: Optional[str] = None  # machine-readable error code, if any
class ErrorResponse(BaseModel):
    """OpenAI-compatible error envelope: a single ``error`` object."""

    error: ErrorDetail
def create_error_response(
    message: str,
    error_type: str = "invalid_request_error",
    param: Optional[str] = None,
    code: Optional[str] = None,
) -> ErrorResponse:
    """Wrap error details in an OpenAI-style error envelope.

    Args:
        message: Human-readable error description.
        error_type: Error category string.
        param: Offending request parameter, if applicable.
        code: Machine-readable error code, if applicable.
    """
    detail = ErrorDetail(message=message, type=error_type, param=param, code=code)
    return ErrorResponse(error=detail)