# hfproxydemo / huggingface_client.py
# OpenCode Deployer
# update
# 4ca5973
"""
HuggingFace Spaces API 客户端实现
负责与 HuggingFace API 的所有交互
"""
import aiohttp
import asyncio
import logging
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import asdict
import json
from core_system import HuggingFaceAPI, SpaceInfo, SpaceStatus, ErrorInfo
class HuggingFaceAPIClient(HuggingFaceAPI):
    """Asynchronous client for the HuggingFace Spaces HTTP API.

    All requests share one lazily created ``aiohttp.ClientSession`` that
    carries the bearer token. Call :meth:`close`, or use the client as an
    ``async with`` context manager, to release the session.

    The query methods are best-effort: HTTP and network failures are logged
    and reported through a fallback return value instead of being raised.
    """

    def __init__(self, token: str):
        """Store the token and prepare the request headers.

        Args:
            token: Access token sent as ``Authorization: Bearer <token>``.
        """
        self.token = token
        self.base_url = "https://huggingface.co/api"
        self.headers = {"Authorization": f"Bearer {token}"}
        self.logger = logging.getLogger(__name__)
        # Created lazily by _get_session().
        self.session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self) -> "HuggingFaceAPIClient":
        """Return the client itself so it can be used with ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the HTTP session when the ``async with`` block exits."""
        await self.close()

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it when necessary.

        A session that has already been closed is replaced, so the client
        stays usable after :meth:`close` has been called.
        """
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession(headers=self.headers)
        return self.session

    async def close(self) -> None:
        """Close the HTTP session, if one is open, and forget it."""
        # Drop the reference first so a later request opens a fresh session
        # instead of reusing the closed one.
        session, self.session = self.session, None
        if session is not None and not session.closed:
            await session.close()

    @staticmethod
    def _status_from_space_data(data: Any) -> SpaceStatus:
        """Map the ``runtime`` section of a Space payload to a SpaceStatus.

        Returns ``SpaceStatus.UNKNOWN`` when the payload or its ``runtime``
        entry is not a JSON object. A missing ``runtime`` entry is treated as
        an empty one and therefore maps to ``SpaceStatus.ERROR``.
        """
        runtime = data.get('runtime', {}) if isinstance(data, dict) else None
        if not isinstance(runtime, dict):
            return SpaceStatus.UNKNOWN

        stage = runtime.get('stage')
        if stage == 'BUILDING':
            return SpaceStatus.BUILDING
        if stage == 'RUNNING':
            # NOTE(review): RUNNING is only reported when a separate `state`
            # field is also 'RUNNING' - confirm the API returns that field,
            # otherwise running Spaces are reported as ERROR.
            if runtime.get('state') == 'RUNNING':
                return SpaceStatus.RUNNING
            return SpaceStatus.ERROR
        if stage == 'STOPPED':
            return SpaceStatus.STOPPED
        # Every other stage, including a missing one, counts as an error.
        return SpaceStatus.ERROR

    @staticmethod
    def _format_log_entries(data: Any) -> str:
        """Join the log entries of a decoded JSON body into one string.

        Entries are either plain strings or objects with a ``message`` key;
        anything else is skipped. A bare string body is returned unchanged and
        a single object body is treated as a one-entry list, instead of being
        iterated character by character or key by key.
        """
        if isinstance(data, str):
            return data
        if isinstance(data, dict):
            data = [data]

        log_lines = []
        for entry in data:
            if isinstance(entry, dict) and 'message' in entry:
                # str() keeps the join below from failing on non-text messages.
                log_lines.append(str(entry['message']))
            elif isinstance(entry, str):
                log_lines.append(entry)
        return '\n'.join(log_lines)

    @staticmethod
    def _default_space_info(space_id: str) -> SpaceInfo:
        """Build the placeholder SpaceInfo returned when the lookup fails."""
        return SpaceInfo(
            space_id=space_id,
            name=space_id,
            repository_url="",
            current_status=SpaceStatus.UNKNOWN,
            last_updated=datetime.now()
        )

    async def get_space_status(self, space_id: str) -> SpaceStatus:
        """Fetch a Space and return its status.

        Args:
            space_id: Identifier used in the ``/spaces/{space_id}`` URL.

        Returns:
            The status derived from the payload's ``runtime`` section, or
            ``SpaceStatus.UNKNOWN`` on a non-200 response or any exception.
        """
        try:
            session = await self._get_session()
            url = f"{self.base_url}/spaces/{space_id}"
            async with session.get(url) as response:
                if response.status != 200:
                    self.logger.error(f"获取 Space 状态失败: {response.status}")
                    return SpaceStatus.UNKNOWN
                data = await response.json()
            return self._status_from_space_data(data)
        except Exception as e:
            self.logger.error(f"获取 Space {space_id} 状态异常: {e}")
            return SpaceStatus.UNKNOWN

    async def get_space_logs(self, space_id: str, lines: int = 100) -> str:
        """Fetch the logs of a Space as newline-separated text.

        Args:
            space_id: Identifier used in the ``/spaces/{space_id}/logs`` URL.
            lines: Value sent as the ``lines`` query parameter.

        Returns:
            The joined log messages, or a string starting with ``ERROR:`` on a
            non-200 response or any exception.
        """
        try:
            session = await self._get_session()
            url = f"{self.base_url}/spaces/{space_id}/logs"
            params = {"lines": lines}
            async with session.get(url, params=params) as response:
                if response.status != 200:
                    error_msg = await response.text()
                    self.logger.error(f"获取日志失败: {response.status} - {error_msg}")
                    return f"ERROR: 无法获取日志 - {response.status}"
                data = await response.json()
            return self._format_log_entries(data)
        except Exception as e:
            self.logger.error(f"获取 Space {space_id} 日志异常: {e}")
            return f"ERROR: 获取日志异常 - {str(e)}"

    async def trigger_rebuild(self, space_id: str) -> bool:
        """Ask the API to restart a Space.

        Args:
            space_id: Identifier used in the ``/spaces/{space_id}/restart`` URL.

        Returns:
            True when the POST answers with HTTP 200; False on any other
            status or any exception.
        """
        try:
            session = await self._get_session()
            url = f"{self.base_url}/spaces/{space_id}/restart"
            async with session.post(url) as response:
                if response.status == 200:
                    self.logger.info(f"成功触发 Space {space_id} 重新构建")
                    return True
                error_msg = await response.text()
                self.logger.error(f"触发重新构建失败: {response.status} - {error_msg}")
                return False
        except Exception as e:
            self.logger.error(f"触发重新构建异常: {e}")
            return False

    async def get_space_info(self, space_id: str) -> SpaceInfo:
        """Fetch a Space and describe it as a SpaceInfo.

        The status is derived from the payload fetched here, so name, URL and
        status come from a single response.

        Args:
            space_id: Identifier used in the ``/spaces/{space_id}`` URL.

        Returns:
            A populated SpaceInfo, or a placeholder with
            ``SpaceStatus.UNKNOWN`` on a non-200 response or any exception.
        """
        try:
            session = await self._get_session()
            url = f"{self.base_url}/spaces/{space_id}"
            async with session.get(url) as response:
                if response.status != 200:
                    self.logger.error(
                        f"获取 Space {space_id} 信息异常: 无法获取 Space 信息: {response.status}"
                    )
                    return self._default_space_info(space_id)
                data = await response.json()
            return SpaceInfo(
                space_id=space_id,
                name=data.get('id', space_id),
                repository_url=data.get('url', ''),
                current_status=self._status_from_space_data(data),
                last_updated=datetime.now(),
                dockerfile_path="Dockerfile",  # default path
                local_path=""  # the local path needs separate configuration
            )
        except Exception as e:
            self.logger.error(f"获取 Space {space_id} 信息异常: {e}")
            return self._default_space_info(space_id)

    async def get_space_discussion(self, space_id: str) -> List[Dict]:
        """Fetch the discussions of a Space (used for extra context).

        Returns:
            The decoded JSON body on HTTP 200; an empty list on any other
            status or any exception.
        """
        try:
            session = await self._get_session()
            url = f"{self.base_url}/spaces/{space_id}/discussions"
            async with session.get(url) as response:
                if response.status == 200:
                    return await response.json()
                return []
        except Exception as e:
            self.logger.error(f"获取 Space {space_id} 讨论信息异常: {e}")
            return []

    async def get_space_runtime_info(self, space_id: str) -> Dict[str, Any]:
        """Fetch the runtime details of a Space.

        Returns:
            The decoded JSON body on HTTP 200; an empty dict on any other
            status or any exception.
        """
        try:
            session = await self._get_session()
            url = f"{self.base_url}/spaces/{space_id}/runtime"
            async with session.get(url) as response:
                if response.status == 200:
                    return await response.json()
                return {}
        except Exception as e:
            self.logger.error(f"获取 Space {space_id} 运行时信息异常: {e}")
            return {}
class HuggingFaceWebhookHandler:
    """Dispatch HuggingFace webhook payloads to per-event handlers."""

    def __init__(self, api_client: "HuggingFaceAPIClient"):
        """Register the handlers for the supported event names.

        Args:
            api_client: Client used to fetch Space logs after a build error.
        """
        self.api_client = api_client
        self.logger = logging.getLogger(__name__)
        self.event_handlers = {
            'space.status_updated': self._handle_status_update,
            'space.build_error': self._handle_build_error,
            'space.started': self._handle_space_started,
            'space.stopped': self._handle_space_stopped
        }

    @staticmethod
    def _space_section(payload: Dict[str, Any]) -> Dict[str, Any]:
        """Return the payload's ``space`` object, or ``{}`` if absent or malformed."""
        space = payload.get('space')
        return space if isinstance(space, dict) else {}

    async def handle_webhook(self, payload: Dict[str, Any]) -> None:
        """Route a webhook payload to the handler registered for its event.

        Unknown event names are logged as a warning. Exceptions raised by a
        handler are logged and swallowed, so this method does not raise.

        Args:
            payload: Decoded webhook body; its ``event`` key selects the handler.
        """
        try:
            event_type = payload.get('event')
            handler = self.event_handlers.get(event_type)
            if handler is None:
                self.logger.warning(f"未知的事件类型: {event_type}")
                return
            await handler(payload)
        except Exception as e:
            self.logger.error(f"处理 Webhook 事件失败: {e}")

    async def _handle_status_update(self, payload: Dict[str, Any]) -> None:
        """Log a status change and run build-error handling for stage ERROR."""
        space = self._space_section(payload)
        space_id = space.get('id')
        runtime = space.get('runtime')
        new_status = runtime.get('stage') if isinstance(runtime, dict) else None
        self.logger.info(f"Space {space_id} 状态更新为: {new_status}")
        # An ERROR stage is handled the same way as an explicit build-error event.
        if new_status == 'ERROR':
            await self._handle_build_error(payload)

    async def _handle_build_error(self, payload: Dict[str, Any]) -> None:
        """Fetch the latest 50 log lines of the Space that failed to build.

        Nothing is fetched when the payload carries no Space id, because the
        request would otherwise be sent for the literal id ``None``.
        """
        space_id = self._space_section(payload).get('id')
        if not space_id:
            self.logger.warning("构建错误事件缺少 Space ID, 跳过日志获取")
            return
        logs = await self.api_client.get_space_logs(space_id, lines=50)
        self.logger.info(f"已获取 Space {space_id} 的构建错误日志 ({len(logs)} 字符)")
        # TODO: hand `logs` to the error analyzer once that integration exists.

    async def _handle_space_started(self, payload: Dict[str, Any]) -> None:
        """Log that the Space has started."""
        space_id = self._space_section(payload).get('id')
        self.logger.info(f"Space {space_id} 启动成功")

    async def _handle_space_stopped(self, payload: Dict[str, Any]) -> None:
        """Log that the Space has stopped."""
        space_id = self._space_section(payload).get('id')
        self.logger.info(f"Space {space_id} 已停止")
class RateLimiter:
    """Limit how many permits are granted within any 60-second span.

    The timestamps of recently granted permits are kept in ``requests``.
    ``acquire`` waits until fewer than ``max_requests`` of them are younger
    than 60 seconds, then records a new one.
    """

    def __init__(self, max_requests_per_minute: int = 60):
        """Create a limiter.

        Args:
            max_requests_per_minute: Permits allowed per 60 seconds.

        Raises:
            ValueError: If the limit is smaller than 1; such a limit could
                never grant a permit.
        """
        if max_requests_per_minute < 1:
            raise ValueError("max_requests_per_minute must be at least 1")
        self.max_requests = max_requests_per_minute
        self.requests = []
        self.lock = asyncio.Lock()

    async def acquire(self) -> None:
        """Wait until a request is allowed, then record it.

        The lock is held while waiting, so concurrent callers are admitted
        one at a time.
        """
        async with self.lock:
            while True:
                now = datetime.now()
                # Forget permits that are 60 seconds old or older.
                self.requests = [
                    granted_at for granted_at in self.requests
                    if (now - granted_at).total_seconds() < 60
                ]
                if len(self.requests) < self.max_requests:
                    break
                # Sleep until the oldest permit leaves the window, then
                # re-evaluate with a fresh clock reading.
                oldest_request = min(self.requests)
                wait_time = 60 - (now - oldest_request).total_seconds()
                await asyncio.sleep(wait_time)
            # `now` was read after any wait, so the recorded time is the moment
            # the permit was actually granted, not the moment of the call.
            self.requests.append(now)
class HuggingFaceAPIClientWithRateLimit(HuggingFaceAPIClient):
    """HuggingFace API client that takes a rate-limiter permit before each call.

    Every public request method first awaits ``rate_limiter.acquire()`` and
    then delegates to ``base_client``, a plain ``HuggingFaceAPIClient`` built
    with the same token.
    """

    def __init__(self, token: str, rate_limit: int = 60):
        """Create the limiter and the wrapped client.

        Args:
            token: Access token passed to both this client and ``base_client``.
            rate_limit: Permits per minute handed to ``RateLimiter``.
        """
        super().__init__(token)
        self.rate_limiter = RateLimiter(rate_limit)
        self.base_client = HuggingFaceAPIClient(token)

    async def close(self) -> None:
        """Close the wrapped client's session as well as this client's own.

        Requests are sent through ``base_client``, so its session is the one
        that actually holds connections and must not be left open.
        """
        try:
            await self.base_client.close()
        finally:
            await super().close()

    async def get_space_status(self, space_id: str) -> SpaceStatus:
        """Return the Space status after taking a rate-limiter permit."""
        await self.rate_limiter.acquire()
        return await self.base_client.get_space_status(space_id)

    async def get_space_logs(self, space_id: str, lines: int = 100) -> str:
        """Return the Space logs after taking a rate-limiter permit."""
        await self.rate_limiter.acquire()
        return await self.base_client.get_space_logs(space_id, lines)

    async def trigger_rebuild(self, space_id: str) -> bool:
        """Trigger a rebuild after taking a rate-limiter permit."""
        await self.rate_limiter.acquire()
        return await self.base_client.trigger_rebuild(space_id)

    async def get_space_info(self, space_id: str) -> SpaceInfo:
        """Return the Space details after taking a rate-limiter permit."""
        await self.rate_limiter.acquire()
        return await self.base_client.get_space_info(space_id)

    async def get_space_discussion(self, space_id: str) -> List[Dict]:
        """Return the Space discussions after taking a rate-limiter permit."""
        await self.rate_limiter.acquire()
        return await self.base_client.get_space_discussion(space_id)

    async def get_space_runtime_info(self, space_id: str) -> Dict[str, Any]:
        """Return the Space runtime details after taking a rate-limiter permit."""
        await self.rate_limiter.acquire()
        return await self.base_client.get_space_runtime_info(space_id)