Spaces:
Paused
Paused
from __future__ import annotations

import json
import shutil
import tempfile
from pathlib import Path
from typing import Any

from git import Repo
from git.exc import GitCommandError, InvalidGitRepositoryError

from services.storage.base import StorageBackend
class GitStorageBackend(StorageBackend):
    """Storage backend that keeps JSON data files in a (private) Git repository.

    All reads and writes go through a local clone under ``local_cache_dir``.
    Before every operation the clone is re-synchronised with the remote, and
    every write is committed and pushed immediately.

    NOTE(review): instances sharing one ``local_cache_dir`` operate on the same
    working tree, and nothing in this class serialises concurrent calls —
    confirm that callers do.
    """

    def __init__(
        self,
        repo_url: str,
        token: str,
        branch: str = "main",
        file_path: str = "accounts.json",
        auth_keys_file_path: str = "auth_keys.json",
        local_cache_dir: Path | None = None,
    ):
        """Configure the backend and make sure the cache directory exists.

        Args:
            repo_url: Repository URL, either ``https://host/owner/repo.git`` or
                the scp-like ``git@host:owner/repo.git`` form.
            token: Access token embedded into the HTTPS URL. If empty, the URL
                is used as given.
            branch: Branch that is read from and pushed to.
            file_path: Path of the accounts JSON file inside the repository.
            auth_keys_file_path: Path of the auth-keys JSON file inside the
                repository.
            local_cache_dir: Directory that holds the local clone. Defaults to
                ``<system tmp>/chatgpt2api_git_cache``; a ``str`` is accepted too.
        """
        self.repo_url = repo_url
        self.token = token
        self.branch = branch
        self.file_path = file_path
        self.auth_keys_file_path = auth_keys_file_path
        if local_cache_dir is None:
            local_cache_dir = Path(tempfile.gettempdir()) / "chatgpt2api_git_cache"
        self.local_cache_dir = Path(local_cache_dir)
        self.local_cache_dir.mkdir(parents=True, exist_ok=True)
        # This URL carries the token. It is used as the clone's "origin" URL,
        # so the token is also stored in the clone's git config.
        self.auth_repo_url = self._build_auth_url(repo_url, token)

    @staticmethod
    def _build_auth_url(repo_url: str, token: str) -> str:
        """Return *repo_url* as an HTTPS URL that carries *token* as credentials.

        * ``https://host/owner/repo.git`` -> ``https://<token>@host/owner/repo.git``
        * ``git@host:owner/repo.git``     -> ``https://<token>@host/owner/repo.git``

        An empty token, or any other URL form, returns *repo_url* unchanged.
        """
        if not token:
            return repo_url
        if repo_url.startswith("https://"):
            rest = repo_url[len("https://"):]
            return f"https://{token}@{rest}"
        if repo_url.startswith("git@"):
            # In scp-like syntax the first ":" separates host from path, so
            # splitting there works for any host, not only *.com ones.
            host, sep, path = repo_url[len("git@"):].partition(":")
            location = f"{host}/{path}" if sep else host
            return f"https://{token}@{location}"
        return repo_url

    def _redact(self, text: str) -> str:
        """Return *text* with every occurrence of the access token replaced by ``****``.

        Git error messages can contain the authenticated URL built by
        ``_build_auth_url``, so anything logged or returned goes through here.
        """
        return text.replace(self.token, "****") if self.token else text

    def _clone_or_pull(self) -> Repo:
        """Return the local clone, synchronised so that it mirrors ``origin/<branch>``.

        An existing clone is reused:

        * its ``origin`` URL is refreshed, because the token may have changed;
        * the remote is fetched;
        * the branch is force-reset to ``origin/<branch>``.

        Resetting instead of pulling means no merge commits are created. It
        also means a commit left behind by a rejected push, or a stray local
        edit, is never served as if it were persisted. If the cached clone
        cannot be reused, or the cache directory is not a git repository, the
        directory is deleted and the repository is cloned again.

        Raises:
            GitCommandError: if the fresh clone fails.
        """
        repo_path = self.local_cache_dir / "repo"
        if (repo_path / ".git").exists():
            try:
                repo = Repo(repo_path)
                origin = repo.remote("origin")  # ValueError if the remote is missing
                repo.git.remote("set-url", "origin", self.auth_repo_url)
                origin.fetch()
                repo.git.checkout("-f", "-B", self.branch, f"origin/{self.branch}")
                return repo
            except (GitCommandError, InvalidGitRepositoryError, ValueError):
                pass  # unusable cache: fall through and clone afresh
        if repo_path.exists():
            # clone_from refuses to clone into a non-empty directory.
            shutil.rmtree(repo_path)
        return Repo.clone_from(
            self.auth_repo_url,
            repo_path,
            branch=self.branch,
        )

    def load_accounts(self) -> list[dict[str, Any]]:
        """Load the account list from the repository.

        Returns ``[]`` when the file is missing or does not contain a JSON list.
        """
        try:
            return self._load_json_file(self.file_path)
        except Exception as e:
            print(f"[git-storage] load failed: {self._redact(str(e))}")
            raise

    def save_accounts(self, accounts: list[dict[str, Any]]) -> None:
        """Write the account list to the repository, then commit and push it."""
        try:
            self._save_json_file(self.file_path, accounts, "Update accounts data")
        except Exception as e:
            print(f"[git-storage] save failed: {self._redact(str(e))}")
            raise

    def load_auth_keys(self) -> list[dict[str, Any]]:
        """Load the auth-key list from the repository.

        Accepts either a bare JSON list or an object of the form
        ``{"items": [...]}``. Anything else, or a missing file, yields ``[]``.
        """
        try:
            data = self._load_json_value(self.auth_keys_file_path)
            if isinstance(data, dict):
                data = data.get("items")
            return data if isinstance(data, list) else []
        except Exception as e:
            print(f"[git-storage] load failed: {self._redact(str(e))}")
            raise

    def save_auth_keys(self, auth_keys: list[dict[str, Any]]) -> None:
        """Write the auth keys as ``{"items": [...]}``, then commit and push them."""
        try:
            self._save_json_file(self.auth_keys_file_path, {"items": auth_keys}, "Update auth keys data")
        except Exception as e:
            print(f"[git-storage] save failed: {self._redact(str(e))}")
            raise

    def _load_json_file(self, file_path: str) -> list[dict[str, Any]]:
        """Load *file_path* and return it if it is a JSON list, else ``[]``."""
        data = self._load_json_value(file_path)
        return data if isinstance(data, list) else []

    def _load_json_value(self, file_path: str) -> Any:
        """Sync the clone and return the parsed JSON in *file_path*.

        Returns ``None`` if the file does not exist in the repository.
        """
        repo = self._clone_or_pull()
        file_full_path = Path(repo.working_dir) / file_path
        if not file_full_path.exists():
            return None
        return json.loads(file_full_path.read_text(encoding="utf-8"))

    def _save_json_file(self, file_path: str, items: Any, message: str) -> None:
        """Write *items* as JSON to *file_path*; commit and push only if it changed.

        Raises:
            GitCommandError: if a git command fails or the remote rejects the push.
        """
        repo = self._clone_or_pull()
        file_full_path = Path(repo.working_dir) / file_path
        file_full_path.parent.mkdir(parents=True, exist_ok=True)
        file_full_path.write_text(
            json.dumps(items, ensure_ascii=False, indent=2) + "\n",
            encoding="utf-8",
        )
        repo.index.add([file_path])
        # The sync above left the tree identical to origin, so a clean repo
        # means the content is unchanged and there is nothing to commit or push.
        if not repo.is_dirty():
            return
        repo.index.commit(message)
        results = repo.remote("origin").push(self.branch)
        # GitPython reports rejected refs through PushInfo flags instead of
        # raising, so check them explicitly; an empty result means total failure.
        rejected = [info for info in results if info.flags & info.ERROR]
        if not results or rejected:
            detail = "; ".join(info.summary.strip() for info in rejected) or "push reported no result"
            raise GitCommandError(["git", "push", "origin", self.branch], 1, detail)

    def health_check(self) -> dict[str, Any]:
        """Sync the repository and report its status, never raising.

        On failure, returns an ``"unhealthy"`` dict whose error text has the
        access token removed.
        """
        try:
            repo = self._clone_or_pull()
            return {
                "status": "healthy",
                "backend": "git",
                "repo_url": self._mask_token(self.repo_url),
                "branch": self.branch,
                "file_path": self.file_path,
                "auth_keys_file_path": self.auth_keys_file_path,
                "last_commit": repo.head.commit.hexsha[:8],
            }
        except Exception as e:
            return {
                "status": "unhealthy",
                "backend": "git",
                "error": self._redact(str(e)),
            }

    def get_backend_info(self) -> dict[str, Any]:
        """Return static configuration info, with credentials in the URL masked."""
        return {
            "type": "git",
            "description": "Git 私有仓库存储",
            "repo_url": self._mask_token(self.repo_url),
            "branch": self.branch,
            "file_path": self.file_path,
            "auth_keys_file_path": self.auth_keys_file_path,
        }

    @staticmethod
    def _mask_token(url: str) -> str:
        """Replace the credentials (user-info part) of *url* with ``****``.

        Only an ``@`` inside the authority (between ``://`` and the first
        ``/``) counts as a credential separator. URLs without a scheme, such as
        ``git@host:owner/repo.git``, and URLs without credentials are returned
        unchanged.
        """
        scheme, sep, rest = url.partition("://")
        if not sep:
            return url
        authority, slash, path = rest.partition("/")
        if "@" not in authority:
            return url
        host = authority.rpartition("@")[2]
        return f"{scheme}://****@{host}{slash}{path}"