# -*- coding: utf-8 -*-
"""
Asynchronous GitHub client.

Design principles:
1. Async / non-blocking - built on ``httpx.AsyncClient``.
2. Connection-pool reuse - a module-level singleton manages the client lifecycle.
3. Automatic retries - transient errors are retried (tenacity-based retry decorator).
4. Type safety - full type annotations.
5. Extensibility - new API endpoints are easy to add.
"""

import asyncio
import base64
import logging
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
from urllib.parse import quote

import httpx

from app.core.config import settings
from app.utils.retry import llm_retry  # reuse the existing retry decorator

logger = logging.getLogger(__name__)


# ============================================================
# Data models
# ============================================================

@dataclass
class GitHubFile:
    """One entry of a GitHub git tree."""

    path: str
    type: str  # "blob" (file) | "tree" (directory)
    size: int = 0
    sha: str = ""

    @property
    def is_file(self) -> bool:
        """True when this entry is a blob (regular file)."""
        return self.type == "blob"

    @property
    def is_directory(self) -> bool:
        """True when this entry is a tree (directory)."""
        return self.type == "tree"


@dataclass
class GitHubRepo:
    """Basic GitHub repository information."""

    owner: str
    name: str
    default_branch: str = "main"
    description: str = ""
    stars: int = 0

    @property
    def full_name(self) -> str:
        """``owner/name`` form of the repository."""
        return f"{self.owner}/{self.name}"


@dataclass
class FileFilter:
    """
    File filtering configuration used when listing a repository tree.

    ``ignored_extensions`` may hold real extensions (``.png``) as well as
    whole dot-file names (``.gitignore``, ``.DS_Store``); both forms are
    honoured by :meth:`should_include`.
    """

    ignored_extensions: Set[str] = field(default_factory=lambda: {
        '.png', '.jpg', '.jpeg', '.gif', '.svg', '.ico', '.mp4', '.webp',
        '.pyc', '.pyo', '.lock', '.zip', '.tar', '.gz', '.pdf',
        '.woff', '.woff2', '.DS_Store', '.gitignore', '.gitattributes',
        '.editorconfig'
    })
    ignored_directories: Set[str] = field(default_factory=lambda: {
        '.git', '.github', '.vscode', '.idea', '__pycache__',
        'node_modules', 'venv', 'env', '.env', 'build', 'dist',
        'site-packages', 'migrations', '.next', '.nuxt', 'coverage',
        'vendor', 'target', 'out', 'bin', 'obj'
    })
    max_file_size: int = 500_000  # 500KB

    def should_include(self, file: GitHubFile) -> bool:
        """
        Decide whether *file* should be kept.

        A file is excluded when it is not a blob, when any of its parent
        directories is in ``ignored_directories``, when its extension (or,
        for dot-files, its whole name) is in ``ignored_extensions``, or when
        it is larger than ``max_file_size`` bytes.
        """
        if not file.is_file:
            return False

        # Only directory components are matched against ignored_directories,
        # so a plain file named e.g. "build" or "env" is not dropped.
        *dir_parts, name = file.path.split("/")
        if any(part in self.ignored_directories for part in dir_parts):
            return False

        # os.path.splitext(".gitignore") yields an empty extension (dot-files
        # have none), so the bare file name is compared as well.
        ignored = self.ignored_extensions
        ext = os.path.splitext(name)[1].lower()
        if ext in ignored or name in ignored or name.lower() in ignored:
            return False

        if file.size > self.max_file_size:
            return False

        return True


# ============================================================
# Exceptions
# ============================================================

class GitHubError(Exception):
    """Base class for GitHub API errors."""

    def __init__(self, message: str, status_code: int = 0):
        self.message = message
        self.status_code = status_code
        super().__init__(message)


class GitHubAuthError(GitHubError):
    """Authentication error (401)."""
    pass


class GitHubRateLimitError(GitHubError):
    """Rate-limit error (403 with rate-limit details, or 429)."""
    pass


class GitHubNotFoundError(GitHubError):
    """Resource not found (404)."""
    pass


# ============================================================
# GitHub async client
# ============================================================

class GitHubClient:
    """
    Asynchronous GitHub REST API client.

    Example:
        ```python
        async with GitHubClient() as client:
            repo = await client.get_repo("owner", "repo")
            files = await client.get_repo_tree(repo)
            content = await client.get_file_content(repo, "README.md")
        ```
    """

    BASE_URL = "https://api.github.com"

    def __init__(
        self,
        token: Optional[str] = None,
        timeout: float = 30.0,
        max_concurrent_requests: int = 10
    ):
        """
        Args:
            token: GitHub token; falls back to ``settings.GITHUB_TOKEN``.
            timeout: Per-request timeout in seconds.
            max_concurrent_requests: Upper bound on in-flight requests made
                through ``_request`` / ``_request_raw``.

        Raises:
            ValueError: if ``max_concurrent_requests`` is smaller than 1
                (a zero-slot semaphore would block every request forever).
        """
        if max_concurrent_requests < 1:
            raise ValueError("max_concurrent_requests must be >= 1")
        self.token = token or settings.GITHUB_TOKEN
        self.timeout = timeout
        self._max_concurrent_requests = max_concurrent_requests
        self._client: Optional[httpx.AsyncClient] = None
        self._semaphore = asyncio.Semaphore(max_concurrent_requests)

    @property
    def _headers(self) -> Dict[str, str]:
        """Build the default request headers (adds auth only if a token is set)."""
        headers = {
            "Accept": "application/vnd.github.v3+json",
            "User-Agent": "GitHub-Agent-Demo/1.0"
        }
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        return headers

    async def _ensure_client(self) -> httpx.AsyncClient:
        """Lazily create the shared httpx client, recreating it after close."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self.BASE_URL,
                headers=self._headers,
                timeout=httpx.Timeout(self.timeout),
                follow_redirects=True,
                limits=httpx.Limits(
                    max_keepalive_connections=20,
                    max_connections=50
                )
            )
        return self._client

    async def close(self):
        """Close the underlying HTTP client, if open."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    async def __aenter__(self) -> "GitHubClient":
        await self._ensure_client()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    def _handle_error(self, response: httpx.Response, context: str = "") -> None:
        """
        Translate an error response into the matching ``GitHubError`` subclass.

        Always raises; never returns normally.

        Raises:
            GitHubAuthError: on 401.
            GitHubRateLimitError: on 429, or on 403 when the body mentions a
                rate limit or ``X-RateLimit-Remaining`` is ``0``.
            GitHubNotFoundError: on 404.
            GitHubError: for every other status.
        """
        status = response.status_code
        try:
            data = response.json()
        except ValueError:  # body is not JSON (json decode / charset errors)
            data = None
        message = data.get("message") if isinstance(data, dict) else None
        if not message:
            message = response.text
        message = str(message)

        error_msg = f"{context}: {message}" if context else message

        if status == 401:
            raise GitHubAuthError(
                "GitHub Token 无效或已过期,请检查 .env 配置",
                status
            )

        is_rate_limited = status == 429 or (
            status == 403 and (
                "rate limit" in message.lower()
                or response.headers.get("x-ratelimit-remaining") == "0"
            )
        )
        if is_rate_limited:
            raise GitHubRateLimitError(
                "GitHub API 请求已达上限,请稍后重试或添加 Token",
                status
            )

        if status == 404:
            raise GitHubNotFoundError(error_msg, status)

        raise GitHubError(error_msg, status)

    @llm_retry
    async def _request(
        self,
        method: str,
        endpoint: str,
        **kwargs
    ) -> Any:
        """
        Send an API request (wrapped by the retry decorator) and decode JSON.

        Args:
            method: HTTP method.
            endpoint: API path, e.g. ``/repos/{owner}/{repo}``.
            **kwargs: Forwarded to ``httpx.AsyncClient.request``.

        Returns:
            The decoded JSON body (usually a dict; a list for some endpoints,
            e.g. contents of a directory).

        Raises:
            GitHubError (or a subclass) for any status >= 400.
        """
        async with self._semaphore:
            client = await self._ensure_client()
            response = await client.request(method, endpoint, **kwargs)

        if response.status_code >= 400:
            self._handle_error(response, endpoint)

        return response.json()

    async def _request_raw(
        self,
        method: str,
        endpoint: str,
        **kwargs
    ) -> httpx.Response:
        """Send a request and return the raw response (no retry, no status check)."""
        async with self._semaphore:
            client = await self._ensure_client()
            return await client.request(method, endpoint, **kwargs)

    # --------------------------------------------------------
    # Repository API
    # --------------------------------------------------------

    async def get_repo(self, owner: str, name: str) -> GitHubRepo:
        """Fetch repository metadata."""
        data = await self._request("GET", f"/repos/{owner}/{name}")

        # GitHub sends JSON null for e.g. an empty description; `or` keeps
        # the dataclass fields at their declared types.
        return GitHubRepo(
            owner=owner,
            name=name,
            default_branch=data.get("default_branch") or "main",
            description=data.get("description") or "",
            stars=data.get("stargazers_count") or 0
        )

    async def get_repo_tree(
        self,
        repo: GitHubRepo,
        file_filter: Optional[FileFilter] = None
    ) -> List[GitHubFile]:
        """
        Fetch the repository's recursive file tree on its default branch.

        Args:
            repo: Repository information.
            file_filter: Filter to apply (defaults to ``FileFilter()``).

        Returns:
            The files that pass the filter.
        """
        filter_config = file_filter or FileFilter()
        branch = quote(repo.default_branch, safe="/")

        data = await self._request(
            "GET",
            f"/repos/{repo.owner}/{repo.name}/git/trees/{branch}",
            params={"recursive": "1"}
        )

        tree = data.get("tree", [])
        if data.get("truncated"):
            # The API caps recursive listings; the result is then incomplete.
            logger.warning(f"仓库 {repo.full_name} 文件树被截断 (truncated), 结果可能不完整")

        files = []
        for item in tree:
            file = GitHubFile(
                path=item["path"],
                type=item["type"],
                size=item.get("size", 0),
                sha=item.get("sha", "")
            )
            if filter_config.should_include(file):
                files.append(file)

        logger.info(f"📂 仓库 {repo.full_name}: 共 {len(tree)} 项, 过滤后 {len(files)} 文件")
        return files

    # --------------------------------------------------------
    # File content API
    # --------------------------------------------------------

    async def get_file_content(
        self,
        repo: GitHubRepo,
        path: str
    ) -> Optional[str]:
        """
        Fetch a single file's content on the repository's default branch.

        Best-effort: every failure is logged and turned into ``None``.

        Args:
            repo: Repository information.
            path: File path inside the repository.

        Returns:
            The UTF-8 decoded content; a textual listing if *path* is a
            directory; ``None`` if the file is missing, not UTF-8, not
            returned inline by the API, or the request failed.
        """
        try:
            # Quote the path so '#', '?', spaces etc. stay part of the path.
            data = await self._request(
                "GET",
                f"/repos/{repo.owner}/{repo.name}/contents/{quote(path, safe='/')}",
                params={"ref": repo.default_branch}
            )

            # The contents endpoint returns a list when path is a directory.
            if isinstance(data, list):
                file_names = [f["name"] for f in data]
                return f"Directory '{path}' contains:\n" + "\n".join(
                    f"- {name}" for name in file_names
                )

            content = data.get("content") or ""
            encoding = data.get("encoding", "base64")

            if encoding == "base64":
                return base64.b64decode(content).decode("utf-8")
            if encoding == "none":
                # Content is not inlined (e.g. large files).
                logger.warning(f"文件过大, API 未返回内容: {path}")
                return None
            return content

        except GitHubNotFoundError:
            logger.warning(f"文件不存在: {path}")
            return None
        except UnicodeDecodeError:
            logger.warning(f"文件无法解码为 UTF-8: {path}")
            return None
        except Exception as e:
            logger.error(f"获取文件失败 {path}: {e}")
            return None

    async def get_files_content(
        self,
        repo: GitHubRepo,
        paths: List[str],
        show_progress: bool = False
    ) -> Dict[str, Optional[str]]:
        """
        Fetch many files concurrently.

        Concurrency is bounded by the client's semaphore (used inside
        ``_request``).

        Args:
            repo: Repository information.
            paths: File paths to fetch.
            show_progress: Log start/finish progress messages.

        Returns:
            ``{path: content}``; content is ``None`` for failed files.
        """
        if not paths:
            return {}

        if show_progress:
            logger.info(f"📥 开始下载 {len(paths)} 个文件 (并发: {self._max_concurrent_requests})")

        results = await asyncio.gather(
            *(self.get_file_content(repo, path) for path in paths),
            return_exceptions=True
        )

        content_map: Dict[str, Optional[str]] = {}
        success_count = 0

        for path, result in zip(paths, results):
            if isinstance(result, Exception):
                logger.error(f"下载失败 {path}: {result}")
                content_map[path] = None
            else:
                content_map[path] = result
                if result is not None:
                    success_count += 1

        if show_progress:
            logger.info(f"✅ 文件下载完成: {success_count}/{len(paths)} 成功")

        return content_map


# ============================================================
# Global singleton management
# ============================================================

_github_client: Optional[GitHubClient] = None


def get_github_client() -> GitHubClient:
    """Return the process-wide GitHub client, creating it on first use."""
    global _github_client
    if _github_client is None:
        _github_client = GitHubClient()
    return _github_client


async def close_github_client():
    """Close and discard the global client (call on application shutdown)."""
    global _github_client
    if _github_client:
        await _github_client.close()
        _github_client = None


# ============================================================
# Convenience functions (legacy-compatible interface)
# ============================================================

def parse_repo_url(url: str) -> Optional[tuple[str, str]]:
    """
    Parse a GitHub repository reference into ``(owner, repo)``.

    Accepted forms include::

        https://github.com/owner/repo
        https://www.github.com/owner/repo.git
        github.com/owner/repo/tree/main
        git@github.com:owner/repo.git
        owner/repo

    Surrounding whitespace, trailing slashes, query strings and fragments
    are ignored; a trailing ``.git`` on the repository name is removed.

    Args:
        url: GitHub repository URL or ``owner/repo`` shorthand.

    Returns:
        ``(owner, repo)`` tuple, or ``None`` if *url* is not recognised.
    """
    if not url:
        return None

    url = url.strip()
    # Drop fragment / query string ("...?tab=readme", "...#readme").
    url = url.split("#", 1)[0].split("?", 1)[0].rstrip("/")

    # SSH scp-like syntax: git@github.com:owner/repo.git
    ssh_prefix = "git@github.com:"
    if url.startswith(ssh_prefix):
        url = "github.com/" + url[len(ssh_prefix):]

    for scheme in ("https://", "http://"):
        if url.startswith(scheme):
            url = url[len(scheme):]
            break

    parts = url.split("/")
    hosts = ("github.com", "www.github.com")
    host_idx = next(
        (i for i, part in enumerate(parts) if part.lower() in hosts), None
    )

    if host_idx is not None:
        if len(parts) < host_idx + 3:
            return None
        owner, repo = parts[host_idx + 1], parts[host_idx + 2]
    elif len(parts) == 2:
        # Plain owner/repo shorthand.
        owner, repo = parts
    else:
        return None

    repo = repo.removesuffix(".git")
    if not owner or not repo:
        return None
    return (owner, repo)