| import os |
| import mimetypes |
| import datetime as dt |
| import aiohttp |
| import requests |
| import asyncio |
| from oss2.utils import http_date |
| from oss2.utils import content_type_by_name |
| from oss2 import Auth, Bucket |
| from typing import ( |
| AsyncGenerator, |
| Dict, |
| Generator, |
| List, |
| Optional, |
| Union, |
| Iterable, |
| overload, |
| Literal, |
| ) |
| from ..core.types.upload_file import FileResult |
| from ..core.exceptions import QwenAPIError, RateLimitError |
| from ..core.types.chat import ( |
| ChatResponseStream, |
| ChatResponse, |
| ChatMessage, |
| Choice, |
| Message, |
| ) |
| from ..core.types.chat_model import ChatModel |
| from ..core.types.endpoint_api import EndpointAPI |
| from ..core.types.response.tool_param import ToolParam |
| from .tool_handle import using_tools, async_using_tools |
|
|
|
|
| class Completion: |
| def __init__(self, client): |
| self._client = client |
|
|
| @overload |
| def create( |
| self, |
| messages: List[ChatMessage], |
| model: ChatModel = "qwen-max-latest", |
| stream: Literal[False] = False, |
| temperature: float = 0.7, |
| max_tokens: Optional[int] = 2048, |
| tools: Optional[Iterable[ToolParam]] | Optional[List[Dict]] = None, |
| ) -> ChatResponse: ... |
|
|
| @overload |
| def create( |
| self, |
| messages: List[ChatMessage], |
| model: ChatModel = "qwen-max-latest", |
| stream: Literal[True] = True, |
| temperature: float = 0.7, |
| max_tokens: Optional[int] = 2048, |
| tools: Optional[Iterable[ToolParam]] | Optional[List[Dict]] = None, |
| ) -> Generator[ChatResponseStream, None, None]: ... |
|
|
| def create( |
| self, |
| messages: List[ChatMessage], |
| model: ChatModel = "qwen-max-latest", |
| stream: bool = False, |
| temperature: float = 0.7, |
| max_tokens: Optional[int] = 2048, |
| tools: Optional[Iterable[ToolParam]] | Optional[List[Dict]] = None, |
| ) -> Union[ChatResponse, Generator[ChatResponseStream, None, None], None]: |
|
|
| if tools: |
| |
| tool_response = using_tools( |
| messages, tools, model, temperature, max_tokens, stream, self._client |
| ) |
|
|
| if stream: |
| |
| def tool_stream_generator(): |
| from ..core.types.chat import ChoiceStream, Delta, Usage |
|
|
| |
| delta = Delta( |
| role=tool_response.choices.message.role, |
| content=tool_response.choices.message.content, |
| extra=tool_response.choices.extra, |
| ) |
|
|
| choice_stream = ChoiceStream(delta=delta) |
|
|
| |
| stream_message = ChatMessage( |
| role=tool_response.choices.message.role, |
| content=tool_response.choices.message.content or "", |
| tool_calls=tool_response.choices.message.tool_calls, |
| ) |
|
|
| |
| usage = Usage() |
|
|
| stream_response = ChatResponseStream( |
| choices=[choice_stream], usage=usage, message=stream_message |
| ) |
| yield stream_response |
|
|
| return tool_stream_generator() |
| else: |
| return tool_response |
| |
| |
| chat_id = self._client.new_chat.create( |
| models=[model] |
| ) |
|
|
| |
| |
| payload = self._client._build_payload( |
| chat_id=chat_id, |
| messages=messages, |
| model=model, |
| temperature=temperature, |
| stream=stream, |
| max_tokens=max_tokens, |
| ) |
| |
| params = {"chat_id": chat_id} |
|
|
| response = requests.post( |
| url=self._client.base_url + EndpointAPI.completions, |
| headers=self._client._build_headers(), |
| params=params, |
| json=payload, |
| timeout=self._client.timeout, |
| stream=stream, |
| ) |
| |
| self._client.logger.debug(f"response OK: {response.ok}") |
| |
| if not stream: |
| self._client.logger.debug(f"response Body: {response.json()}") |
|
|
| if not response.ok: |
| error_text = response.json() |
| self._client.logger.error(f"API Error: {response.status_code} {error_text}") |
| raise QwenAPIError(f"API Error: {response.status_code} {error_text}") |
|
|
| if response.status_code == 429: |
| self._client.logger.error("Too many requests") |
| raise RateLimitError("Too many requests") |
|
|
| self._client.logger.info(f"Response: {response.status_code}") |
|
|
| if stream: |
| return self._client._process_stream(response) |
| try: |
| return self._client._process_response(response) |
| except Exception as e: |
| self._client.logger.error(f"Error: {e}") |
|
|
| @overload |
| async def acreate( |
| self, |
| messages: List[ChatMessage], |
| model: ChatModel = "qwen-max-latest", |
| stream: Literal[False] = False, |
| temperature: float = 0.7, |
| max_tokens: Optional[int] = 2048, |
| tools: Optional[Iterable[ToolParam]] | List[Dict] = None, |
| ) -> ChatResponse: ... |
|
|
| @overload |
| async def acreate( |
| self, |
| messages: List[ChatMessage], |
| model: ChatModel = "qwen-max-latest", |
| stream: Literal[True] = True, |
| temperature: float = 0.7, |
| max_tokens: Optional[int] = 2048, |
| tools: Optional[Iterable[ToolParam]] | List[Dict] = None, |
| ) -> AsyncGenerator[ChatResponseStream, None]: ... |
|
|
| async def acreate( |
| self, |
| messages: List[ChatMessage], |
| model: ChatModel = "qwen-max-latest", |
| stream: bool = False, |
| temperature: float = 0.7, |
| max_tokens: Optional[int] = 2048, |
| tools: Optional[Iterable[ToolParam]] | List[Dict] = None, |
| ) -> Union[ChatResponse, AsyncGenerator[ChatResponseStream, None], None]: |
| session = None |
| try: |
| if tools: |
| tool_response = await async_using_tools( |
| messages, |
| tools, |
| model, |
| temperature, |
| max_tokens, |
| self._client, |
| ) |
|
|
| if stream: |
| |
| async def tool_astream_generator(): |
| from ..core.types.chat import ChoiceStream, Delta, Usage |
|
|
| |
| delta = Delta( |
| role=tool_response.choices.message.role, |
| content=tool_response.choices.message.content, |
| extra=tool_response.choices.extra, |
| ) |
|
|
| choice_stream = ChoiceStream(delta=delta) |
|
|
| |
| stream_message = ChatMessage( |
| role=tool_response.choices.message.role, |
| content=tool_response.choices.message.content or "", |
| tool_calls=tool_response.choices.message.tool_calls, |
| ) |
|
|
| |
| usage = Usage() |
|
|
| stream_response = ChatResponseStream( |
| choices=[choice_stream], usage=usage, message=stream_message |
| ) |
| yield stream_response |
|
|
| return tool_astream_generator() |
| else: |
| return tool_response |
| else: |
| |
| |
| chat_id = await self._client.chat_manager.anew_chat( |
| models=[model] |
| ) |
|
|
| payload = self._client._build_payload( |
| chat_id=chat_id, |
| messages=messages, |
| model=model, |
| temperature=temperature, |
| stream=stream, |
| max_tokens=max_tokens, |
| ) |
|
|
| session = aiohttp.ClientSession() |
| |
| self._client._active_sessions.append(session) |
|
|
| response = await session.post( |
| url=self._client.base_url + EndpointAPI.completions, |
| headers=self._client._build_headers(), |
| json=payload, |
| timeout=aiohttp.ClientTimeout(total=self._client.timeout), |
| ) |
|
|
| if not response.ok: |
| error_text = await response.text() |
| self._client.logger.error( |
| f"API Error: {response.status} {error_text}" |
| ) |
| raise QwenAPIError(f"API Error: {response.status} {error_text}") |
|
|
| if response.status == 429: |
| self._client.logger.error("Too many requests") |
| raise RateLimitError("Too many requests") |
|
|
| self._client.logger.info(f"Response status: {response.status}") |
|
|
| if stream: |
| return self._client._process_astream(response, session) |
| try: |
| return await self._client._process_aresponse(response, session) |
| except Exception as e: |
| self._client.logger.error(f"Error: {e}") |
|
|
| except Exception as e: |
| self._client.logger.error(f"Error in acreate: {e}") |
| if session and not session.closed: |
| |
| if session in self._client._active_sessions: |
| self._client._active_sessions.remove(session) |
| await session.close() |
| raise |
|
|
| def upload_file( |
| self, file_path: Optional[str] = None, base64_data: Optional[str] = None |
| ): |
| if not file_path and not base64_data: |
| raise QwenAPIError("Either file_path or base64_data must be provided") |
|
|
| |
| if base64_data: |
| |
| import base64 |
| from io import BytesIO |
|
|
| |
| if base64_data.startswith("data:image/"): |
| try: |
| header, data = base64_data.split(",", 1) |
| mime_type = header.split(";")[0].split(":")[1] |
| is_base64 = True |
| except QwenAPIError: |
| |
| mime_type = "image/png" |
| data = base64_data |
| is_base64 = False |
| else: |
| data = base64_data |
| mime_type = "image/png" |
| is_base64 = True |
|
|
| |
| try: |
| file_content = base64.b64decode(data) |
| except Exception as e: |
| raise QwenAPIError(f"Invalid base64 data: {e}") |
|
|
| |
| filename = "uploaded_image.png" |
| if ";" in mime_type: |
| mime_type = mime_type.split(";")[0] |
|
|
| if "/" in mime_type: |
| ext = mime_type.split("/")[-1].lower() |
| if ext in ["jpeg", "jpg"]: |
| filename = f"uploaded_image.jpg" |
| elif ext == "png": |
| filename = f"uploaded_image.png" |
| elif ext == "gif": |
| filename = f"uploaded_image.gif" |
| elif ext == "webp": |
| filename = f"uploaded_image.webp" |
|
|
| |
| file_size = len(file_content) |
|
|
| elif file_path and os.path.isfile(file_path): |
| |
| with open(file_path, "rb") as file: |
| file_content = file.read() |
|
|
| file_size = os.path.getsize(file_path) |
| filename = os.path.basename(file_path) |
|
|
| else: |
| raise QwenAPIError(f"File {file_path} does not exist") |
|
|
| detected_mime_type = None |
| if not base64_data and file_path: |
| detected_mime_type, _ = mimetypes.guess_type(file_path) |
|
|
| mime_type = detected_mime_type |
| if base64_data: |
| mime_type = mime_type or "image/png" |
|
|
| payload = { |
| "filename": filename, |
| "filesize": file_size, |
| "filetype": mime_type.split("/")[0] if mime_type else "image", |
| } |
|
|
| headers = self._client._build_headers() |
| headers["Content-Type"] = "application/json" |
| response = requests.post( |
| url=self._client.base_url + EndpointAPI.upload_file, |
| headers=headers, |
| json=payload, |
| timeout=self._client.timeout, |
| ) |
|
|
| if not response.ok: |
| try: |
| error_text = response.json() |
| except Exception: |
| error_text = response.text |
| self._client.logger.error(f"API Error: {response.status_code} {error_text}") |
| raise QwenAPIError(f"API Error: {response.status_code} {error_text}") |
|
|
| if response.status_code == 429: |
| self._client.logger.error("Too many requests") |
| raise RateLimitError("Too many requests") |
|
|
| try: |
| response_data = response.json() |
| except Exception: |
| response_data = response.text |
|
|
| if not isinstance(response_data, dict): |
| raise QwenAPIError(f"Invalid response format: {response_data}") |
|
|
| |
| access_key_id = response_data["access_key_id"] |
| access_key_secret = response_data["access_key_secret"] |
| region = response_data["region"] |
| bucket_name = response_data.get("bucketname", "qwen-webui-prod") |
|
|
| |
| if not access_key_id: |
| raise QwenAPIError("AccessKey ID cannot be empty") |
| if not access_key_secret: |
| raise QwenAPIError("AccessKey Secret cannot be empty") |
|
|
| |
| security_token = response_data.get("security_token") |
| if not security_token: |
| raise QwenAPIError("Security token cannot be empty") |
|
|
| |
| request_datetime = dt.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") |
|
|
| |
| endpoint = f"https://{region}.aliyuncs.com" |
| auth = Auth(access_key_id, access_key_secret) |
| bucket = Bucket(auth, endpoint, response_data["bucketname"]) |
|
|
| |
| date_str = http_date() |
|
|
| |
| oss_headers = { |
| "Content-Type": ( |
| mime_type or content_type_by_name(file_path) |
| if not base64_data |
| else mime_type |
| ), |
| "Date": date_str, |
| "x-oss-security-token": security_token, |
| "x-oss-content-sha256": "UNSIGNED-PAYLOAD", |
| } |
|
|
| |
| request_datetime = dt.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") |
| oss_headers["date"] = request_datetime |
|
|
| |
| oss_response = bucket.put_object( |
| key=response_data["file_path"], data=file_content, headers=oss_headers |
| ) |
|
|
| |
| oss_headers.update( |
| { |
| "x-oss-date": request_datetime, |
| "Host": f"{bucket_name}.{region}.aliyuncs.com", |
| } |
| ) |
|
|
| |
| if oss_response.status != 200 and oss_response.status != 203: |
| error_text = str(oss_response) |
| self._client.logger.error(f"API Error: {oss_response.status} {error_text}") |
| raise QwenAPIError(f"API Error: {oss_response.status} {error_text}") |
|
|
| if oss_response.status == 429: |
| self._client.logger.error("Too many requests") |
| raise RateLimitError("Too many requests") |
|
|
| result = { |
| "file_url": response_data["file_url"], |
| "file_id": response_data["file_id"], |
| "image_mimetype": mime_type, |
| } |
| return FileResult(**result) |
|
|
| async def async_upload_file( |
| self, file_path: Optional[str] = None, base64_data: Optional[str] = None |
| ): |
| if not file_path and not base64_data: |
| raise QwenAPIError("Either file_path or base64_data must be provided") |
|
|
| |
| if base64_data: |
| |
| import base64 |
| from io import BytesIO |
|
|
| |
| if base64_data.startswith("data:image/"): |
| try: |
| header, data = base64_data.split(",", 1) |
| mime_type = header.split(";")[0].split(":")[1] |
| is_base64 = True |
| except QwenAPIError: |
| |
| mime_type = "image/png" |
| data = base64_data |
| is_base64 = False |
| else: |
| data = base64_data |
| mime_type = "image/png" |
| is_base64 = True |
|
|
| |
| try: |
| file_content = base64.b64decode(data) |
| except Exception as e: |
| raise QwenAPIError(f"Invalid base64 data: {e}") |
|
|
| |
| filename = "uploaded_image.png" |
| if ";" in mime_type: |
| mime_type = mime_type.split(";")[0] |
|
|
| if "/" in mime_type: |
| ext = mime_type.split("/")[-1].lower() |
| if ext in ["jpeg", "jpg"]: |
| filename = f"uploaded_image.jpg" |
| elif ext == "png": |
| filename = f"uploaded_image.png" |
| elif ext == "gif": |
| filename = f"uploaded_image.gif" |
| elif ext == "webp": |
| filename = f"uploaded_image.webp" |
|
|
| |
| file_size = len(file_content) |
|
|
| elif file_path and os.path.isfile(file_path): |
| |
| with open(file_path, "rb") as file: |
| file_content = file.read() |
|
|
| file_size = os.path.getsize(file_path) |
| filename = os.path.basename(file_path) |
|
|
| detected_mime_type = None |
| if not base64_data and file_path: |
| detected_mime_type, _ = mimetypes.guess_type(file_path) |
|
|
| mime_type = detected_mime_type |
| if base64_data: |
| mime_type = mime_type or "image/png" |
|
|
| payload = { |
| "filename": filename, |
| "filesize": file_size, |
| "filetype": mime_type.split("/")[0] if mime_type else "image", |
| } |
|
|
| headers = self._client._build_headers() |
| headers["Content-Type"] = "application/json" |
|
|
| async with aiohttp.ClientSession() as session: |
| response = await session.post( |
| url=self._client.base_url + EndpointAPI.upload_file, |
| headers=headers, |
| json=payload, |
| timeout=aiohttp.ClientTimeout(total=self._client.timeout), |
| ) |
|
|
| if not response.ok: |
| error_text = response.json() |
| self._client.logger.error(f"API Error: {response.status} {error_text}") |
| raise QwenAPIError(f"API Error: {response.status} {error_text}") |
|
|
| if response.status == 429: |
| self._client.logger.error("Too many requests") |
| raise RateLimitError("Too many requests") |
|
|
| response_data = await response.json() |
|
|
| |
| access_key_id = response_data["access_key_id"] |
| access_key_secret = response_data["access_key_secret"] |
| region = response_data["region"] |
| bucket_name = response_data.get("bucketname", "qwen-webui-prod") |
|
|
| |
| if not access_key_id: |
| raise QwenAPIError("AccessKey ID cannot be empty") |
| if not access_key_secret: |
| raise QwenAPIError("AccessKey Secret cannot be empty") |
|
|
| |
| security_token = response_data.get("security_token") |
| if not security_token: |
| raise QwenAPIError("Security token cannot be empty") |
|
|
| |
| request_datetime = dt.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") |
|
|
| |
| endpoint = f"https://{region}.aliyuncs.com" |
| auth = Auth(access_key_id, access_key_secret) |
| bucket = Bucket(auth, endpoint, response_data["bucketname"]) |
|
|
| |
| date_str = http_date() |
|
|
| |
| oss_headers = { |
| "Content-Type": ( |
| mime_type or content_type_by_name(file_path) |
| if not base64_data |
| else mime_type |
| ), |
| "Date": date_str, |
| "x-oss-security-token": security_token, |
| "x-oss-content-sha256": "UNSIGNED-PAYLOAD", |
| } |
|
|
| |
| request_datetime = dt.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") |
| oss_headers["date"] = request_datetime |
|
|
| |
| loop = asyncio.get_event_loop() |
| |
|
|
| try: |
| |
| oss_response = await loop.run_in_executor( |
| None, |
| lambda: bucket.put_object( |
| key=response_data["file_path"], |
| data=file_content if not base64_data else file_content, |
| headers=oss_headers, |
| ), |
| ) |
|
|
| |
| oss_headers.update( |
| { |
| "x-oss-date": request_datetime, |
| "Host": f"{bucket_name}.{region}.aliyuncs.com", |
| } |
| ) |
|
|
| |
| if oss_response.status != 200 and oss_response.status != 203: |
| error_text = str(oss_response) |
| self._client.logger.error( |
| f"API Error: {oss_response.status} {error_text}" |
| ) |
| raise QwenAPIError(f"API Error: {oss_response.status} {error_text}") |
|
|
| if oss_response.status == 429: |
| self._client.logger.error("Too many requests") |
| raise RateLimitError("Too many requests") |
|
|
| result = { |
| "file_url": response_data["file_url"], |
| "file_id": response_data["file_id"], |
| "image_mimetype": mime_type, |
| } |
| return FileResult(**result) |
|
|
| except Exception as e: |
| self._client.logger.error(f"Error: {e}") |
| raise |
| finally: |
| |
| self._client.logger.debug("Closing session") |
|
|