"""OpenRouter API compatible Pydantic models for ACE-Step. This module defines request/response models that conform to OpenRouter's chat completions API specification for audio generation. """ from __future__ import annotations from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field # ============================================================================= # Request Models # ============================================================================= class AudioInputContent(BaseModel): """Audio input content in base64 format.""" data: str = Field(..., description="Base64-encoded audio data") format: str = Field(default="mp3", description="Audio format (mp3, wav, flac, etc.)") class TextContent(BaseModel): """Text content block.""" type: Literal["text"] = "text" text: str = Field(..., description="Text content") class AudioContent(BaseModel): """Audio input content block.""" type: Literal["input_audio"] = "input_audio" input_audio: AudioInputContent # Union type for message content ContentPart = Union[TextContent, AudioContent, Dict[str, Any]] class ChatMessage(BaseModel): """A single message in the chat conversation.""" role: Literal["system", "user", "assistant"] = Field(..., description="Message role") content: Union[str, List[ContentPart]] = Field(..., description="Message content") name: Optional[str] = Field(default=None, description="Optional name for the message author") class AudioConfig(BaseModel): """Audio generation configuration.""" duration: Optional[float] = Field(default=None, description="Target audio duration in seconds") format: str = Field(default="mp3", description="Output audio format") # ACE-Step specific parameters bpm: Optional[int] = Field(default=None, description="Beats per minute") key_scale: Optional[str] = Field(default=None, description="Musical key and scale") time_signature: Optional[str] = Field(default=None, description="Time signature (e.g., 4/4)") vocal_language: Optional[str] = Field(default=None, description="Vocal language code") instrumental: Optional[bool] = Field(default=None, description="Generate instrumental only") class ChatCompletionRequest(BaseModel): """OpenRouter-compatible chat completion request.""" model: str = Field(..., description="Model ID to use") messages: List[ChatMessage] = Field(..., description="List of messages") # Modalities modalities: Optional[List[str]] = Field( default=None, description="Output modalities (e.g., ['audio', 'text'])" ) # Audio configuration audio_config: Optional[AudioConfig] = Field( default=None, description="Audio generation configuration" ) # Standard OpenAI parameters temperature: Optional[float] = Field(default=None, ge=0, le=2) top_p: Optional[float] = Field(default=None, ge=0, le=1) top_k: Optional[int] = Field(default=None, ge=0) max_tokens: Optional[int] = Field(default=None, ge=1) stream: bool = Field(default=False, description="Enable streaming response") stop: Optional[Union[str, List[str]]] = Field(default=None) seed: Optional[Union[int, str]] = Field(default=None, description="Seed(s) for reproducibility. Comma-separated for batch (e.g. '42,123,456')") # ACE-Step specific parameters (extended) thinking: Optional[bool] = Field(default=None, description="Use LM for audio code generation") guidance_scale: Optional[float] = Field(default=None, description="Classifier-free guidance scale") batch_size: Optional[int] = Field(default=None, description="Number of audio samples to generate") # ACE-Step direct fields (bypass message parsing / audio_config) lyrics: str = Field(default="", description="Direct lyrics input (bypass message parsing)") sample_mode: bool = Field(default=False, description="Auto-generate caption/lyrics/metas via LM; user message becomes the query") use_format: bool = Field(default=False, description="Use format_sample to enhance caption/lyrics") use_cot_caption: bool = Field(default=True, description="Use CoT for caption rewriting") use_cot_language: bool = Field(default=True, description="Use CoT for language detection") # Task type task_type: str = Field(default="text2music", description="Task type: text2music, cover, repaint, extract, lego, complete") # Audio editing parameters repainting_start: float = Field(default=0.0, description="Repainting region start (seconds)") repainting_end: Optional[float] = Field(default=None, description="Repainting region end (seconds)") audio_cover_strength: float = Field(default=1.0, description="Audio cover strength (0.0~1.0)") class Config: extra = "allow" # Allow additional fields for forward compatibility # ============================================================================= # Response Models # ============================================================================= class AudioOutputUrl(BaseModel): """Audio output URL (base64 data URL).""" url: str = Field(..., description="Base64 data URL of the audio") class AudioOutput(BaseModel): """Audio output content block.""" type: Literal["audio_url"] = "audio_url" audio_url: AudioOutputUrl class AssistantMessage(BaseModel): """Assistant response message.""" role: Literal["assistant"] = "assistant" content: Optional[str] = Field(default=None, description="Text content") audio: Optional[List[AudioOutput]] = Field(default=None, description="Generated audio files") class Choice(BaseModel): """A single completion choice.""" index: int = Field(default=0) message: AssistantMessage finish_reason: Literal["stop", "length", "content_filter", "error"] = "stop" class Usage(BaseModel): """Token usage statistics.""" prompt_tokens: int = 0 completion_tokens: int = 0 total_tokens: int = 0 class ChatCompletionResponse(BaseModel): """OpenRouter-compatible chat completion response.""" id: str = Field(..., description="Unique completion ID") object: Literal["chat.completion"] = "chat.completion" created: int = Field(..., description="Unix timestamp") model: str = Field(..., description="Model ID used") choices: List[Choice] = Field(..., description="Completion choices") usage: Usage = Field(default_factory=Usage) # Extended metadata system_fingerprint: Optional[str] = Field(default=None) # ============================================================================= # Streaming Response Models # ============================================================================= class DeltaContent(BaseModel): """Delta content for streaming.""" role: Optional[Literal["assistant"]] = None content: Optional[str] = None audio: Optional[List[AudioOutput]] = None class StreamChoice(BaseModel): """Streaming choice.""" index: int = 0 delta: DeltaContent finish_reason: Optional[Literal["stop", "length", "content_filter", "error"]] = None class ChatCompletionChunk(BaseModel): """Streaming chunk response.""" id: str object: Literal["chat.completion.chunk"] = "chat.completion.chunk" created: int model: str choices: List[StreamChoice] # ============================================================================= # Models Endpoint Response # ============================================================================= class ModelPricing(BaseModel): """Model pricing information.""" prompt: str = Field(default="0", description="Price per prompt token in USD") completion: str = Field(default="0", description="Price per completion token in USD") request: str = Field(default="0", description="Price per request in USD") image: str = Field(default="0", description="Price per image in USD") class ModelInfo(BaseModel): """OpenRouter-compatible model information.""" id: str = Field(..., description="Model identifier") name: str = Field(..., description="Display name") created: int = Field(..., description="Unix timestamp of creation") # Modalities input_modalities: List[str] = Field( default_factory=lambda: ["text"], description="Supported input modalities" ) output_modalities: List[str] = Field( default_factory=lambda: ["audio", "text"], description="Supported output modalities" ) # Limits context_length: int = Field(default=4096, description="Maximum context length") max_output_length: int = Field(default=300, description="Maximum output length in seconds") # Pricing pricing: ModelPricing = Field(default_factory=ModelPricing) # Metadata description: Optional[str] = Field(default=None) architecture: Optional[Dict[str, Any]] = Field(default=None) class ModelsResponse(BaseModel): """Response for /v1/models endpoint.""" object: Literal["list"] = "list" data: List[ModelInfo] = Field(default_factory=list) # ============================================================================= # Error Response # ============================================================================= class ErrorDetail(BaseModel): """Error detail information.""" message: str type: str = "invalid_request_error" param: Optional[str] = None code: Optional[str] = None class ErrorResponse(BaseModel): """OpenRouter-compatible error response.""" error: ErrorDetail