chatgpt2api / services /storage /git_storage.py
tx1538's picture
Upload 179 files
9d7ddb9 verified
Raw
History Blame
6.37 kB
from __future__ import annotations
import json
import shutil
import tempfile
from pathlib import Path
from typing import Any
from git import Repo
from git.exc import GitCommandError
from services.storage.base import StorageBackend
class GitStorageBackend(StorageBackend):
    """Storage backend that keeps JSON data files in a private Git repository.

    A working copy of the repository is cached under ``local_cache_dir / "repo"``.
    Every load or save first refreshes that copy (pull, or clone when the pull
    is not possible); saves additionally commit and push the changed file.
    """

    def __init__(
        self,
        repo_url: str,
        token: str,
        branch: str = "main",
        file_path: str = "accounts.json",
        auth_keys_file_path: str = "auth_keys.json",
        local_cache_dir: Path | None = None,
    ):
        """Store the configuration and make sure the local cache directory exists.

        Args:
            repo_url: Repository URL, either ``https://host/path`` or the
                scp-like ``git@host:path`` form.
            token: Access token embedded into the HTTPS URL; when empty the
                URL is used unchanged.
            branch: Branch that is cloned, pulled and pushed.
            file_path: Repository-relative path of the accounts JSON file.
            auth_keys_file_path: Repository-relative path of the auth-keys
                JSON file.
            local_cache_dir: Directory holding the cached clone. Defaults to
                ``<system temp dir>/chatgpt2api_git_cache``.
        """
        self.repo_url = repo_url
        self.token = token
        self.branch = branch
        self.file_path = file_path
        self.auth_keys_file_path = auth_keys_file_path

        # Local cache directory for the working copy.
        # NOTE(review): the clone is created from the token-bearing URL, so the
        # token presumably ends up in the clone's Git config inside this
        # directory; the default lives in the shared temp dir -- confirm that
        # is acceptable for the deployment.
        if local_cache_dir is None:
            local_cache_dir = Path(tempfile.gettempdir()) / "chatgpt2api_git_cache"
        # Coerce so that a plain string path is accepted as well.
        self.local_cache_dir = Path(local_cache_dir)
        self.local_cache_dir.mkdir(parents=True, exist_ok=True)

        # Git URL with the token embedded for authentication.
        self.auth_repo_url = self._build_auth_url(repo_url, token)

    @staticmethod
    def _build_auth_url(repo_url: str, token: str) -> str:
        """Return ``repo_url`` with ``token`` embedded as HTTPS user info.

        ``https://host/path`` becomes ``https://<token>@host/path``.
        ``git@host:path`` is converted to ``https://<token>@host/path`` for
        any host, not only ``*.com``. URLs in any other form, or an empty
        token, leave the URL unchanged.
        """
        if not token:
            return repo_url

        https_prefix = "https://"
        if repo_url.startswith(https_prefix):
            # Touch only the leading scheme, never a later occurrence.
            return f"{https_prefix}{token}@{repo_url[len(https_prefix):]}"

        ssh_prefix = "git@"
        if repo_url.startswith(ssh_prefix):
            remainder = repo_url[len(ssh_prefix):]
            # scp-like syntax separates host and path with the first ":".
            host, colon, path = remainder.partition(":")
            location = f"{host}/{path}" if colon else remainder
            return f"{https_prefix}{token}@{location}"

        return repo_url

    def _redact(self, text: str) -> str:
        """Return ``text`` with every occurrence of the token replaced by ``****``."""
        if not self.token:
            return text
        return text.replace(self.token, "****")

    def _clone_or_pull(self) -> Repo:
        """Return an up-to-date working copy, pulling if cached and cloning otherwise.

        If the cached copy cannot be opened or pulled it is discarded and the
        repository is cloned again. Errors from the clone itself propagate.
        """
        # Same package as the module-level GitCommandError import; imported
        # here so this method does not depend on the file's import block.
        from git.exc import InvalidGitRepositoryError

        repo_path = self.local_cache_dir / "repo"

        if (repo_path / ".git").exists():
            # Cached repository present: try to bring it up to date.
            try:
                repo = Repo(repo_path)
                repo.remote("origin").pull(self.branch)
                return repo
            except (GitCommandError, InvalidGitRepositoryError):
                # Unusable cache: fall through, discard it and clone afresh.
                pass

        # Remove any leftover directory (failed pull, corrupted or partial
        # clone) so the clone below starts from a clean target.
        if repo_path.is_dir():
            shutil.rmtree(repo_path)

        return Repo.clone_from(
            self.auth_repo_url,
            repo_path,
            branch=self.branch,
        )

    def load_accounts(self) -> list[dict[str, Any]]:
        """Load the account list from the repository.

        Returns an empty list when the file is missing or does not contain a
        JSON list. Any error is printed (token redacted) and re-raised.
        """
        try:
            return self._load_json_file(self.file_path)
        except Exception as e:
            print(f"[git-storage] load failed: {self._redact(str(e))}")
            raise

    def save_accounts(self, accounts: list[dict[str, Any]]) -> None:
        """Write the account list to the repository, then commit and push.

        Any error is printed (token redacted) and re-raised.
        """
        try:
            self._save_json_file(self.file_path, accounts, "Update accounts data")
        except Exception as e:
            print(f"[git-storage] save failed: {self._redact(str(e))}")
            raise

    def load_auth_keys(self) -> list[dict[str, Any]]:
        """Load the auth-key list from the repository.

        The file may hold either a bare JSON list or an object whose
        ``"items"`` entry is the list. Anything else, or a missing file,
        yields an empty list. Any error is printed (token redacted) and
        re-raised.
        """
        try:
            data = self._load_json_value(self.auth_keys_file_path)
            if isinstance(data, dict):
                data = data.get("items")
            return data if isinstance(data, list) else []
        except Exception as e:
            print(f"[git-storage] load failed: {self._redact(str(e))}")
            raise

    def save_auth_keys(self, auth_keys: list[dict[str, Any]]) -> None:
        """Write the auth keys as ``{"items": [...]}``, then commit and push.

        Any error is printed (token redacted) and re-raised.
        """
        try:
            self._save_json_file(
                self.auth_keys_file_path,
                {"items": auth_keys},
                "Update auth keys data",
            )
        except Exception as e:
            print(f"[git-storage] save failed: {self._redact(str(e))}")
            raise

    def _load_json_file(self, file_path: str) -> list[dict[str, Any]]:
        """Return the JSON list stored at ``file_path``, or ``[]`` if it is not a list."""
        data = self._load_json_value(file_path)
        return data if isinstance(data, list) else []

    def _load_json_value(self, file_path: str) -> Any:
        """Refresh the working copy and return the parsed JSON at ``file_path``.

        Returns ``None`` when the file does not exist in the working copy.
        """
        repo = self._clone_or_pull()
        file_full_path = Path(repo.working_dir) / file_path
        if not file_full_path.exists():
            return None
        return json.loads(file_full_path.read_text(encoding="utf-8"))

    def _save_json_file(self, file_path: str, items: Any, message: str) -> None:
        """Write ``items`` as JSON to ``file_path``, commit if changed, and push.

        Args:
            file_path: Repository-relative path of the file to write.
            items: JSON-serialisable value to store.
            message: Commit message used when the write changed anything.

        Raises:
            RuntimeError: If the push returned no results or the remote
                rejected the update.
        """
        repo = self._clone_or_pull()
        file_full_path = Path(repo.working_dir) / file_path
        file_full_path.parent.mkdir(parents=True, exist_ok=True)
        file_full_path.write_text(
            json.dumps(items, ensure_ascii=False, indent=2) + "\n",
            encoding="utf-8",
        )

        repo.index.add([file_path])
        if repo.is_dirty():
            repo.index.commit(message)

        # GitPython reports rejected pushes through PushInfo flags rather than
        # by raising, so the result has to be inspected explicitly; otherwise
        # a rejected push would look like a successful save.
        push_results = repo.remote("origin").push(self.branch)
        if not push_results:
            raise RuntimeError(f"git push to '{self.branch}' returned no results")
        for info in push_results:
            if info.flags & info.ERROR:
                detail = self._redact(str(info.summary).strip())
                raise RuntimeError(f"git push to '{self.branch}' failed: {detail}")

    def health_check(self) -> dict[str, Any]:
        """Refresh the working copy and report the backend status.

        Returns a dict with ``"status": "healthy"`` plus configuration and
        the short hash of the current commit, or ``"status": "unhealthy"``
        with the (token-redacted) error text. Never raises for ordinary
        exceptions.
        """
        try:
            repo = self._clone_or_pull()
            return {
                "status": "healthy",
                "backend": "git",
                "repo_url": self._mask_token(self.repo_url),
                "branch": self.branch,
                "file_path": self.file_path,
                "auth_keys_file_path": self.auth_keys_file_path,
                "last_commit": repo.head.commit.hexsha[:8],
            }
        except Exception as e:
            return {
                "status": "unhealthy",
                "backend": "git",
                # The error text may embed the authenticated URL.
                "error": self._redact(str(e)),
            }

    def get_backend_info(self) -> dict[str, Any]:
        """Return static information describing this storage backend."""
        return {
            "type": "git",
            "description": "Git 私有仓库存储",
            "repo_url": self._mask_token(self.repo_url),
            "branch": self.branch,
            "file_path": self.file_path,
            "auth_keys_file_path": self.auth_keys_file_path,
        }

    @staticmethod
    def _mask_token(url: str) -> str:
        """Replace the user-info part of ``url`` with ``****``.

        Only an ``@`` inside the authority (between ``://`` and the first
        ``/``) counts as credentials; an ``@`` later in the path is left
        alone. URLs without a scheme or without credentials are returned
        unchanged.
        """
        scheme, separator, rest = url.partition("://")
        if not separator:
            return url
        authority, slash, path = rest.partition("/")
        if "@" not in authority:
            return url
        # Everything before the last "@" of the authority is user info.
        host = authority.rpartition("@")[2]
        return f"{scheme}://****@{host}{slash}{path}"