Spaces:
Running
Running
# -*- coding: utf-8 -*-
"""
Asynchronous GitHub client.

Design principles:
1. Async and non-blocking - built on httpx.AsyncClient
2. Connection-pool reuse - a singleton manages the client lifecycle
3. Automatic retries - tenacity integration for transient errors
4. Type safety - complete type annotations
5. Extensible - easy to add new API endpoints
"""
import asyncio
import base64
import logging
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Set
from urllib.parse import quote

import httpx

from app.core.config import settings
from app.utils.retry import llm_retry  # reuse the existing retry decorator
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# ============================================================
# Data models
# ============================================================
@dataclass
class GitHubFile:
    """One entry of a GitHub git tree listing.

    Attributes:
        path: Path of the entry as reported by the tree listing.
        type: Entry kind; ``"blob"`` for a file, ``"tree"`` for a directory.
        size: Size value taken from the listing (0 when absent).
        sha: Object SHA taken from the listing ("" when absent).
    """

    path: str
    type: str  # "blob" | "tree"
    size: int = 0
    sha: str = ""

    # Exposed as properties because callers read them as plain attributes
    # (e.g. ``if not file.is_file``); as ordinary methods that test would
    # always see a truthy bound method.
    @property
    def is_file(self) -> bool:
        """True when the entry is a blob (regular file)."""
        return self.type == "blob"

    @property
    def is_directory(self) -> bool:
        """True when the entry is a tree (directory)."""
        return self.type == "tree"
@dataclass
class GitHubRepo:
    """Basic metadata describing a GitHub repository.

    Attributes:
        owner: Account or organisation that owns the repository.
        name: Repository name.
        default_branch: Default branch name ("main" when not supplied).
        description: Repository description ("" when not supplied).
        stars: Star count (0 when not supplied).
    """

    owner: str
    name: str
    default_branch: str = "main"
    description: str = ""
    stars: int = 0

    # A property, because callers interpolate ``repo.full_name`` directly
    # into strings without calling it.
    @property
    def full_name(self) -> str:
        """Return the ``owner/name`` identifier of the repository."""
        return f"{self.owner}/{self.name}"
@dataclass
class FileFilter:
    """Configuration deciding which repository files are worth keeping.

    Attributes:
        ignored_extensions: File extensions (with leading dot) and whole
            dotfile names that are rejected; matched case-insensitively.
        ignored_directories: Path components that cause a rejection when
            they appear anywhere in the file's path.
        max_file_size: Largest accepted ``size`` value; larger files are
            rejected.
    """

    ignored_extensions: Set[str] = field(default_factory=lambda: {
        '.png', '.jpg', '.jpeg', '.gif', '.svg', '.ico', '.mp4', '.webp',
        '.pyc', '.pyo', '.lock', '.zip', '.tar', '.gz', '.pdf', '.woff', '.woff2',
        '.DS_Store', '.gitignore', '.gitattributes', '.editorconfig'
    })
    ignored_directories: Set[str] = field(default_factory=lambda: {
        '.git', '.github', '.vscode', '.idea', '__pycache__',
        'node_modules', 'venv', 'env', '.env', 'build', 'dist',
        'site-packages', 'migrations', '.next', '.nuxt', 'coverage',
        'vendor', 'target', 'out', 'bin', 'obj'
    })
    max_file_size: int = 500_000  # 500KB

    def should_include(self, file: "GitHubFile") -> bool:
        """Return True when *file* passes every filter rule.

        Args:
            file: Tree entry to test; its ``is_file``, ``path`` and ``size``
                attributes are read.

        Returns:
            False for non-files, for paths containing an ignored component,
            for ignored extensions / dotfile names, and for files larger
            than ``max_file_size``; True otherwise.
        """
        if not file.is_file:
            return False

        # Directory rule: any path component (the file name included, as in
        # the original logic) found in the ignore set rejects the file.
        path_parts = file.path.split("/")
        if any(part in self.ignored_directories for part in path_parts):
            return False

        # Extension rule. os.path.splitext() yields an empty extension for
        # dotfiles such as ".gitignore", so the whole file name is compared
        # as well, and both sides are lower-cased so that entries like
        # ".DS_Store" can actually match.
        filename = path_parts[-1]
        ext = os.path.splitext(filename)[1].lower()
        ignored = {entry.lower() for entry in self.ignored_extensions}
        if ext in ignored or filename.lower() in ignored:
            return False

        # Size rule.
        return file.size <= self.max_file_size
# ============================================================
# Exception definitions
# ============================================================
class GitHubError(Exception):
    """Base class for errors raised while talking to the GitHub API.

    Attributes:
        message: Human-readable error description.
        status_code: HTTP status code associated with the error (0 if none
            was given).
    """

    def __init__(self, message: str, status_code: int = 0):
        super().__init__(message)
        self.status_code = status_code
        self.message = message
class GitHubAuthError(GitHubError):
    """Authentication error (HTTP 401)."""
class GitHubRateLimitError(GitHubError):
    """Rate-limit error (HTTP 403)."""
class GitHubNotFoundError(GitHubError):
    """Requested resource does not exist (HTTP 404)."""
# ============================================================
# Asynchronous GitHub client
# ============================================================
class GitHubClient:
    """
    Asynchronous GitHub API client.

    Usage example:
    ```python
    async with GitHubClient() as client:
        repo = await client.get_repo("owner", "repo")
        files = await client.get_repo_tree(repo)
        content = await client.get_file_content(repo, "README.md")
    ```
    """

    BASE_URL = "https://api.github.com"

    def __init__(
        self,
        token: Optional[str] = None,
        timeout: float = 30.0,
        max_concurrent_requests: int = 10
    ):
        """Store configuration; the HTTP client itself is created lazily.

        Args:
            token: GitHub token; falls back to ``settings.GITHUB_TOKEN``.
            timeout: Timeout value handed to ``httpx.Timeout``.
            max_concurrent_requests: Upper bound on requests in flight at
                the same time (enforced with a semaphore).
        """
        self.token = token or settings.GITHUB_TOKEN
        self.timeout = timeout
        # Kept so the configured limit can be reported without reaching
        # into the semaphore's private state.
        self.max_concurrent_requests = max_concurrent_requests
        self._client: Optional[httpx.AsyncClient] = None
        self._semaphore = asyncio.Semaphore(max_concurrent_requests)

    # A property: _ensure_client passes ``self._headers`` straight to httpx
    # as the headers mapping.
    @property
    def _headers(self) -> Dict[str, str]:
        """Build the default request headers (with auth when a token is set)."""
        headers = {
            "Accept": "application/vnd.github.v3+json",
            "User-Agent": "GitHub-Agent-Demo/1.0"
        }
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        return headers

    async def _ensure_client(self) -> httpx.AsyncClient:
        """Return the underlying client, creating it if missing or closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self.BASE_URL,
                headers=self._headers,
                timeout=httpx.Timeout(self.timeout),
                follow_redirects=True,
                limits=httpx.Limits(
                    max_keepalive_connections=20,
                    max_connections=50
                )
            )
        return self._client

    async def close(self):
        """Close the underlying client connection, if one is open."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
            self._client = None

    async def __aenter__(self):
        await self._ensure_client()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    def _handle_error(self, response: httpx.Response, context: str = ""):
        """Translate an error response into the matching exception.

        Always raises; it never returns normally.

        Args:
            response: The HTTP response to translate.
            context: Optional prefix for the error message (callers pass
                the endpoint).

        Raises:
            GitHubAuthError: For status 401.
            GitHubRateLimitError: For status 403 whose message mentions
                "rate limit".
            GitHubNotFoundError: For status 404.
            GitHubError: For every other status.
        """
        status = response.status_code
        try:
            data = response.json()
            message = data.get("message", response.text)
        except Exception:
            # Body is not JSON (or not a JSON object): use the raw text.
            message = response.text

        error_msg = f"{context}: {message}" if context else message

        if status == 401:
            raise GitHubAuthError(
                "GitHub Token 无效或已过期,请检查 .env 配置",
                status
            )
        elif status == 403:
            if "rate limit" in message.lower():
                raise GitHubRateLimitError(
                    "GitHub API 请求已达上限,请稍后重试或添加 Token",
                    status
                )
            raise GitHubError(error_msg, status)
        elif status == 404:
            raise GitHubNotFoundError(error_msg, status)
        else:
            raise GitHubError(error_msg, status)

    # NOTE(review): the original docstring promised retries and the module
    # imports `llm_retry`, but no retry decorator is applied in the visible
    # code - confirm whether one should be attached here.
    async def _request(
        self,
        method: str,
        endpoint: str,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Send an API request and return the decoded JSON body.

        Args:
            method: HTTP method
            endpoint: API endpoint (e.g. /repos/{owner}/{repo})
            **kwargs: Extra arguments forwarded to httpx

        Returns:
            The JSON response (a list for endpoints that return arrays).

        Raises:
            GitHubError: (or a subclass) when the status code is >= 400.
        """
        async with self._semaphore:
            client = await self._ensure_client()
            response = await client.request(method, endpoint, **kwargs)

            if response.status_code >= 400:
                self._handle_error(response, endpoint)

            return response.json()

    async def _request_raw(
        self,
        method: str,
        endpoint: str,
        **kwargs
    ) -> httpx.Response:
        """Send a request and return the raw response without error handling."""
        async with self._semaphore:
            client = await self._ensure_client()
            return await client.request(method, endpoint, **kwargs)

    # --------------------------------------------------------
    # Repository API
    # --------------------------------------------------------

    async def get_repo(self, owner: str, name: str) -> GitHubRepo:
        """Fetch repository metadata.

        Args:
            owner: Repository owner.
            name: Repository name.

        Returns:
            A GitHubRepo filled from the API response.
        """
        data = await self._request("GET", f"/repos/{owner}/{name}")
        return GitHubRepo(
            owner=owner,
            name=name,
            default_branch=data.get("default_branch", "main"),
            # The key can be present with a JSON null value, in which case
            # a .get() default would not apply; normalise to "".
            description=data.get("description") or "",
            stars=data.get("stargazers_count", 0)
        )

    async def get_repo_tree(
        self,
        repo: GitHubRepo,
        file_filter: Optional[FileFilter] = None
    ) -> List[GitHubFile]:
        """
        Fetch the repository file tree.

        Args:
            repo: Repository information
            file_filter: File filter (a default FileFilter when omitted)

        Returns:
            The list of entries accepted by the filter.
        """
        filter_config = file_filter or FileFilter()

        data = await self._request(
            "GET",
            f"/repos/{repo.owner}/{repo.name}/git/trees/{repo.default_branch}",
            params={"recursive": "1"}
        )

        tree = data.get("tree", [])
        if data.get("truncated"):
            # The response flags an incomplete listing; surface that
            # instead of silently returning a partial tree.
            logger.warning(f"⚠️ 仓库 {repo.full_name} 文件树过大, GitHub 返回的结果已被截断")

        files = []
        for item in tree:
            file = GitHubFile(
                path=item["path"],
                type=item["type"],
                size=item.get("size", 0),
                sha=item.get("sha", "")
            )
            if filter_config.should_include(file):
                files.append(file)

        logger.info(f"📂 仓库 {repo.full_name}: 共 {len(tree)} 项, 过滤后 {len(files)} 文件")
        return files

    # --------------------------------------------------------
    # File content API
    # --------------------------------------------------------

    async def get_file_content(
        self,
        repo: GitHubRepo,
        path: str
    ) -> Optional[str]:
        """
        Fetch the content of a single file.

        Args:
            repo: Repository information
            path: File path

        Returns:
            The file content decoded as UTF-8; a plain-text listing when
            the path is a directory; None on any failure.
        """
        try:
            data = await self._request(
                "GET",
                # Percent-encode the path ("/" is kept) so characters such
                # as "#" or "?" in a file name are not parsed as a URL
                # fragment or query string.
                f"/repos/{repo.owner}/{repo.name}/contents/{quote(path)}",
                params={"ref": repo.default_branch}
            )

            # Directory case: the endpoint returns a list of entries.
            if isinstance(data, list):
                file_names = [f["name"] for f in data]
                return f"Directory '{path}' contains:\n" + "\n".join(
                    f"- {name}" for name in file_names
                )

            # Decode the file content.
            content = data.get("content", "")
            encoding = data.get("encoding", "base64")

            if encoding == "base64":
                return base64.b64decode(content).decode("utf-8")
            return content

        except GitHubNotFoundError:
            logger.warning(f"文件不存在: {path}")
            return None
        except UnicodeDecodeError:
            logger.warning(f"文件无法解码为 UTF-8: {path}")
            return None
        except Exception as e:
            # Deliberate best-effort: any other failure is logged and
            # reported to the caller as None.
            logger.error(f"获取文件失败 {path}: {e}")
            return None

    async def get_files_content(
        self,
        repo: GitHubRepo,
        paths: List[str],
        show_progress: bool = False
    ) -> Dict[str, Optional[str]]:
        """
        Fetch several files concurrently.

        Args:
            repo: Repository information
            paths: List of file paths
            show_progress: Whether to log start/finish progress messages

        Returns:
            A {path: content} dict; content is None for files that failed.
        """
        if not paths:
            return {}

        if show_progress:
            # Report the configured limit; the semaphore's private counter
            # only reflects how many slots happen to be free right now.
            logger.info(f"📥 开始下载 {len(paths)} 个文件 (并发: {self.max_concurrent_requests})")

        # Fetch every file concurrently; _request's semaphore bounds the
        # number of requests actually in flight.
        tasks = [
            self.get_file_content(repo, path)
            for path in paths
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Assemble the results.
        content_map = {}
        success_count = 0

        for path, result in zip(paths, results):
            if isinstance(result, Exception):
                logger.error(f"下载失败 {path}: {result}")
                content_map[path] = None
            else:
                content_map[path] = result
                if result is not None:
                    success_count += 1

        if show_progress:
            logger.info(f"✅ 文件下载完成: {success_count}/{len(paths)} 成功")

        return content_map
# ============================================================
# Global singleton management
# ============================================================
# Process-wide client instance; stays None until get_github_client() creates it.
_github_client: Optional[GitHubClient] = None
def get_github_client() -> GitHubClient:
    """Return the shared GitHubClient, creating it on first use."""
    global _github_client
    client = _github_client
    if client is None:
        client = _github_client = GitHubClient()
    return client
async def close_github_client():
    """Close and discard the shared client (call on application shutdown)."""
    global _github_client
    if not _github_client:
        return
    await _github_client.close()
    _github_client = None
# ============================================================
# Convenience functions (compatible with the old interface)
# ============================================================
def parse_repo_url(url: str) -> Optional[tuple[str, str]]:
    """
    Parse a GitHub repository reference.

    Supported forms:
        https://github.com/owner/repo   (also http://, www., trailing "/",
                                         ".git", extra path segments,
                                         "?query" and "#fragment")
        github.com/owner/repo
        git@github.com:owner/repo.git
        owner/repo

    Args:
        url: GitHub repository URL or ``owner/repo`` shorthand

    Returns:
        An (owner, repo) tuple, or None when the input is not recognised
        or either component would be empty.
    """
    cleaned = url.strip()

    # Query strings and fragments are never part of the repository name.
    for separator in ("?", "#"):
        cleaned = cleaned.split(separator, 1)[0]

    # Strip trailing slashes first so a ".git" suffix is still detected in
    # inputs like ".../repo.git/".
    cleaned = cleaned.rstrip("/")
    if cleaned.endswith(".git"):
        cleaned = cleaned[:-4]

    lowered = cleaned.lower()
    for scheme in ("https://", "http://", "ssh://", "git://"):
        if lowered.startswith(scheme):
            cleaned = cleaned[len(scheme):]
            break

    if cleaned.startswith("git@"):
        # scp-like SSH syntax: git@github.com:owner/repo
        cleaned = cleaned[len("git@"):].replace(":", "/", 1)

    parts = cleaned.split("/")
    hosts = ("github.com", "www.github.com")
    host_idx = next(
        (i for i, part in enumerate(parts) if part.lower() in hosts),
        None,
    )

    if host_idx is not None:
        # Owner and repo are the two segments right after the host.
        candidate = parts[host_idx + 1:host_idx + 3]
    elif len(parts) == 2:
        # Plain owner/repo shorthand.
        candidate = parts
    else:
        return None

    # Reject missing or empty components (e.g. "github.com/owner/").
    if len(candidate) == 2 and all(candidate):
        return (candidate[0], candidate[1])
    return None