# opencode-zerogpu / openai_compat.py
"""OpenAI-compatible API request/response format handling."""
import logging
import time
import uuid
from dataclasses import dataclass
from typing import Generator, Literal, Optional, Union

from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)


# --- Request Models ---
class ChatMessage(BaseModel):
"""A single message in the conversation."""
role: Literal["system", "user", "assistant"]
content: str


class ChatCompletionRequest(BaseModel):
"""OpenAI-compatible chat completion request."""
model: str = Field(..., description="HuggingFace model ID")
messages: list[ChatMessage]
temperature: float = Field(default=0.7, ge=0.0, le=2.0)
top_p: float = Field(default=0.95, ge=0.0, le=1.0)
max_tokens: Optional[int] = Field(default=512, ge=1, le=8192)
stream: bool = False
    stop: Optional[Union[str, list[str]]] = None  # OpenAI accepts a string or a list
presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
n: int = Field(default=1, ge=1, le=1) # Only support n=1 for now
user: Optional[str] = None
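# Illustrative request body (field names follow the OpenAI chat completions
# schema; the model ID is only an example):
#
#     {
#         "model": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
#         "messages": [
#             {"role": "system", "content": "You are a helpful assistant."},
#             {"role": "user", "content": "Hello!"}
#         ],
#         "temperature": 0.7,
#         "stream": true
#     }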


# --- Response Models ---
class ChatCompletionChoice(BaseModel):
"""A single completion choice."""
index: int
message: ChatMessage
finish_reason: Literal["stop", "length", "content_filter"] = "stop"


class ChatCompletionUsage(BaseModel):
"""Token usage statistics."""
prompt_tokens: int
completion_tokens: int
total_tokens: int


class ChatCompletionResponse(BaseModel):
"""OpenAI-compatible chat completion response."""
id: str
object: str = "chat.completion"
created: int
model: str
choices: list[ChatCompletionChoice]
usage: ChatCompletionUsage
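# Example serialized response (values illustrative):
#
#     {
#         "id": "chatcmpl-...",
#         "object": "chat.completion",
#         "created": 1700000000,
#         "model": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
#         "choices": [{"index": 0,
#                      "message": {"role": "assistant", "content": "Hi there!"},
#                      "finish_reason": "stop"}],
#         "usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12}
#     }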


# --- Streaming Response Models ---
class DeltaMessage(BaseModel):
"""Delta content for streaming responses."""
role: Optional[str] = None
content: Optional[str] = None


class StreamChoice(BaseModel):
"""A single streaming choice."""
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None


class ChatCompletionChunk(BaseModel):
"""OpenAI-compatible streaming chunk."""
id: str
object: str = "chat.completion.chunk"
created: int
model: str
choices: list[StreamChoice]
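# Each SSE event carries one serialized chunk, e.g. (values illustrative):
#
#     data: {"id": "chatcmpl-...", "object": "chat.completion.chunk",
#            "created": 1700000000, "model": "...",
#            "choices": [{"index": 0, "delta": {"content": "Hi"},
#                         "finish_reason": null}]}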


# --- Helper Functions ---
def generate_completion_id() -> str:
"""Generate a unique completion ID."""
return f"chatcmpl-{uuid.uuid4().hex[:24]}"


def create_chat_response(
model: str,
content: str,
prompt_tokens: int = 0,
completion_tokens: int = 0,
finish_reason: str = "stop",
) -> ChatCompletionResponse:
"""Create a complete chat completion response."""
return ChatCompletionResponse(
id=generate_completion_id(),
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatMessage(role="assistant", content=content),
finish_reason=finish_reason,
)
],
usage=ChatCompletionUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)


def create_stream_chunk(
completion_id: str,
model: str,
content: Optional[str] = None,
role: Optional[str] = None,
finish_reason: Optional[str] = None,
) -> ChatCompletionChunk:
"""Create a single streaming chunk."""
return ChatCompletionChunk(
id=completion_id,
created=int(time.time()),
model=model,
choices=[
StreamChoice(
index=0,
delta=DeltaMessage(role=role, content=content),
finish_reason=finish_reason,
)
],
)


def stream_response_generator(
    model: str,
    token_generator: Generator[str, None, None],
) -> Generator[str, None, None]:
    """
    Convert a token generator into an OpenAI-style SSE stream.

    Yields `data: {json}` events ready for HTTP streaming, followed by the
    `data: [DONE]` terminator.
    """
completion_id = generate_completion_id()
# First chunk: role
first_chunk = create_stream_chunk(
completion_id=completion_id,
model=model,
role="assistant",
)
yield f"data: {first_chunk.model_dump_json()}\n\n"
# Content chunks
for token in token_generator:
chunk = create_stream_chunk(
completion_id=completion_id,
model=model,
content=token,
)
yield f"data: {chunk.model_dump_json()}\n\n"
# Final chunk: finish reason
final_chunk = create_stream_chunk(
completion_id=completion_id,
model=model,
finish_reason="stop",
)
yield f"data: {final_chunk.model_dump_json()}\n\n"
# End marker
yield "data: [DONE]\n\n"


def messages_to_dicts(messages: list[ChatMessage]) -> list[dict[str, str]]:
"""Convert Pydantic ChatMessage objects to simple dicts."""
return [{"role": msg.role, "content": msg.content} for msg in messages]


def estimate_tokens(text: str) -> int:
    """
    Rough token-count estimate.

    A simple approximation; the true count depends on the model's tokenizer.
    Rule of thumb: ~4 characters per token for English text.
    """
return max(1, len(text) // 4)


@dataclass
class InferenceParams:
"""Extracted inference parameters from request."""
model_id: str
messages: list[dict[str, str]]
max_new_tokens: int
temperature: float
top_p: float
stop_sequences: Optional[list[str]]
stream: bool

    @classmethod
def from_request(cls, request: ChatCompletionRequest) -> "InferenceParams":
"""Extract inference parameters from an OpenAI-compatible request."""
return cls(
model_id=request.model,
messages=messages_to_dicts(request.messages),
max_new_tokens=request.max_tokens or 512,
temperature=request.temperature,
top_p=request.top_p,
            # Normalize a bare string stop into the list form the backend expects.
            stop_sequences=[request.stop] if isinstance(request.stop, str) else request.stop,
stream=request.stream,
)
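# Usage sketch (`generate` is a hypothetical backend entry point):
#
#     params = InferenceParams.from_request(request)
#     tokens = generate(params.messages, max_new_tokens=params.max_new_tokens,
#                       temperature=params.temperature, top_p=params.top_p)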


# --- Error Responses ---
class ErrorDetail(BaseModel):
"""Error detail for API error responses."""
message: str
type: str
param: Optional[str] = None
code: Optional[str] = None


class ErrorResponse(BaseModel):
"""OpenAI-compatible error response."""
error: ErrorDetail


def create_error_response(
message: str,
error_type: str = "invalid_request_error",
param: Optional[str] = None,
code: Optional[str] = None,
) -> ErrorResponse:
"""Create an error response."""
return ErrorResponse(
error=ErrorDetail(
message=message,
type=error_type,
param=param,
code=code,
)
)
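# Example serialized error (shape follows OpenAI's error envelope; values
# illustrative):
#
#     {"error": {"message": "Unknown model", "type": "invalid_request_error",
#                "param": "model", "code": "model_not_found"}}


if __name__ == "__main__":
    # Minimal smoke test (illustrative only): exercises the non-streaming and
    # streaming helpers with canned data.
    response = create_chat_response(
        model="example/model",
        content="Hello!",
        prompt_tokens=5,
        completion_tokens=2,
    )
    print(response.model_dump_json(indent=2))

    def _fake_tokens():
        yield from ["Hel", "lo", "!"]

    for sse_event in stream_response_generator("example/model", _fake_tokens()):
        print(sse_event, end="")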