Spaces:
Paused
Paused
| """基础客户端类,定义通用接口""" | |
| from typing import AsyncGenerator, Any | |
| import aiohttp | |
| from app.utils.logger import logger | |
| from abc import ABC, abstractmethod | |
| class BaseClient(ABC): | |
| def __init__(self, api_key: str, api_url: str): | |
| """初始化基础客户端 | |
| Args: | |
| api_key: API密钥 | |
| api_url: API地址 | |
| """ | |
| self.api_key = api_key | |
| self.api_url = api_url | |
| async def _make_request(self, headers: dict, data: dict) -> AsyncGenerator[bytes, None]: | |
| """发送请求并处理响应 | |
| Args: | |
| headers: 请求头 | |
| data: 请求数据 | |
| Yields: | |
| bytes: 原始响应数据 | |
| """ | |
| try: | |
| async with aiohttp.ClientSession() as session: | |
| async with session.post(self.api_url, headers=headers, json=data) as response: | |
| if response.status != 200: | |
| error_text = await response.text() | |
| logger.error(f"API 请求失败: {error_text}") | |
| return | |
| async for chunk in response.content.iter_any(): | |
| yield chunk | |
| except Exception as e: | |
| logger.error(f"请求 API 时发生错误: {e}") | |
| async def stream_chat(self, messages: list, model: str) -> AsyncGenerator[tuple[str, str], None]: | |
| """流式对话,由子类实现 | |
| Args: | |
| messages: 消息列表 | |
| model: 模型名称 | |
| Yields: | |
| tuple[str, str]: (内容类型, 内容) | |
| """ | |
| pass | |