| | from io import IOBase |
| | from pydantic import BaseModel, Field, model_validator, ConfigDict |
| | from typing import List, Dict, Optional, Union, Tuple, Literal, Any |
| | from log_config import logger |
| |
|
| | class FunctionParameter(BaseModel): |
| | type: str |
| | properties: Dict[str, Dict[str, Any]] |
| | required: List[str] |
| |
|
| | class Function(BaseModel): |
| | name: str |
| | description: str |
| | parameters: Optional[FunctionParameter] = Field(default=None, exclude=None) |
| |
|
| | class Tool(BaseModel): |
| | type: str |
| | function: Function |
| |
|
| | class FunctionCall(BaseModel): |
| | name: str |
| | arguments: str |
| |
|
| | class ToolCall(BaseModel): |
| | id: str |
| | type: str |
| | function: FunctionCall |
| |
|
| | class ImageUrl(BaseModel): |
| | url: str |
| |
|
| | class ContentItem(BaseModel): |
| | type: str |
| | text: Optional[str] = None |
| | image_url: Optional[ImageUrl] = None |
| |
|
| | class Message(BaseModel): |
| | role: str |
| | name: Optional[str] = None |
| | arguments: Optional[str] = None |
| | content: Optional[Union[str, List[ContentItem]]] = None |
| | tool_calls: Optional[List[ToolCall]] = None |
| |
|
| | class Message(BaseModel): |
| | role: str |
| | name: Optional[str] = None |
| | content: Optional[Union[str, List[ContentItem]]] = None |
| | tool_calls: Optional[List[ToolCall]] = None |
| | tool_call_id: Optional[str] = None |
| |
|
| | class Config: |
| | extra = "allow" |
| |
|
| | class FunctionChoice(BaseModel): |
| | name: str |
| |
|
| | class ToolChoice(BaseModel): |
| | type: str |
| | function: Optional[FunctionChoice] = None |
| |
|
| | class BaseRequest(BaseModel): |
| | request_type: Optional[Literal["chat", "image", "audio", "moderation"]] = Field(default=None, exclude=True) |
| |
|
| | def create_json_schema_class(): |
| | class JsonSchema(BaseModel): |
| | name: str |
| |
|
| | model_config = ConfigDict(protected_namespaces=()) |
| |
|
| | JsonSchema.__annotations__['schema'] = Dict[str, Any] |
| | return JsonSchema |
| |
|
| | JsonSchema = create_json_schema_class() |
| | class ResponseFormat(BaseModel): |
| | type: Literal["text", "json_object", "json_schema"] |
| | json_schema: Optional[JsonSchema] = None |
| |
|
| | class RequestModel(BaseRequest): |
| | model: str |
| | messages: List[Message] |
| | logprobs: Optional[bool] = None |
| | top_logprobs: Optional[int] = None |
| | stream: Optional[bool] = None |
| | include_usage: Optional[bool] = None |
| | temperature: Optional[float] = 0.5 |
| | top_p: Optional[float] = 1.0 |
| | max_tokens: Optional[int] = None |
| | presence_penalty: Optional[float] = 0.0 |
| | frequency_penalty: Optional[float] = 0.0 |
| | n: Optional[int] = 1 |
| | user: Optional[str] = None |
| | tool_choice: Optional[Union[str, ToolChoice]] = None |
| | tools: Optional[List[Tool]] = None |
| | response_format: Optional[ResponseFormat] = None |
| |
|
| | def get_last_text_message(self) -> Optional[str]: |
| | for message in reversed(self.messages): |
| | if message.content: |
| | if isinstance(message.content, str): |
| | return message.content |
| | elif isinstance(message.content, list): |
| | for item in reversed(message.content): |
| | if item.type == "text" and item.text: |
| | return item.text |
| | return "" |
| |
|
| | class ImageGenerationRequest(BaseRequest): |
| | prompt: str |
| | model: Optional[str] = "dall-e-3" |
| | n: Optional[int] = 1 |
| | size: Optional[str] = "1024x1024" |
| | stream: bool = False |
| |
|
| | class EmbeddingRequest(BaseRequest): |
| | input: str |
| | model: str |
| | encoding_format: Optional[str] = "float" |
| | stream: bool = False |
| |
|
| | class AudioTranscriptionRequest(BaseRequest): |
| | file: Tuple[str, IOBase, str] |
| | model: str |
| | language: Optional[str] = None |
| | prompt: Optional[str] = None |
| | response_format: Optional[str] = None |
| | temperature: Optional[float] = None |
| | stream: bool = False |
| |
|
| | class Config: |
| | arbitrary_types_allowed = True |
| |
|
| | class ModerationRequest(BaseRequest): |
| | input: str |
| | model: Optional[str] = "text-moderation-latest" |
| | stream: bool = False |
| |
|
| | class UnifiedRequest(BaseModel): |
| | data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest] |
| |
|
| | @model_validator(mode='before') |
| | @classmethod |
| | def set_request_type(cls, values): |
| | if isinstance(values, dict): |
| | if "messages" in values: |
| | values["data"] = RequestModel(**values) |
| | values["data"].request_type = "chat" |
| | elif "prompt" in values: |
| | values["data"] = ImageGenerationRequest(**values) |
| | values["data"].request_type = "image" |
| | elif "file" in values: |
| | values["data"] = AudioTranscriptionRequest(**values) |
| | values["data"].request_type = "audio" |
| | elif "input" in values: |
| | values["data"] = ModerationRequest(**values) |
| | values["data"].request_type = "moderation" |
| | elif "input" in values: |
| | values["data"] = EmbeddingRequest(**values) |
| | values["data"].request_type = "embedding" |
| | else: |
| | raise ValueError("无法确定请求类型") |
| | return values |