Spaces:
Paused
Paused
| """Sub2API integration for browsing and importing ChatGPT OAuth accounts from a sub2api admin.""" | |
| from __future__ import annotations | |
| import json | |
| import threading | |
| import time | |
| import uuid | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from threading import Lock | |
| from curl_cffi.requests import Session | |
| from services.account_service import account_service | |
| from services.config import DATA_DIR | |
# Persistent store for configured sub2api servers (credentials plus last import job).
SUB2API_CONFIG_FILE = DATA_DIR / "sub2api_config.json"

# Admin JWTs are cached per server so list/import calls do not log in every time.
# A cached token is treated as stale this many seconds before the expiry computed
# from the login response. (The server-side default lifetime is reportedly 24h —
# confirm against the deployed sub2api version.)
_TOKEN_REFRESH_SKEW = 300
| def _new_id() -> str: | |
| return uuid.uuid4().hex[:12] | |
| def _now_iso() -> str: | |
| return datetime.now(timezone.utc).isoformat() | |
| def _clean(value: object) -> str: | |
| return str(value or "").strip() | |
def _coerce_count(value: object) -> int:
    """Best-effort int conversion for a persisted job counter; junk becomes 0."""
    try:
        return int(value or 0)
    except (TypeError, ValueError):
        return 0


def _normalize_import_job(raw: object, *, fail_unfinished: bool) -> dict | None:
    """Coerce an import-job record into the canonical job shape.

    Args:
        raw: candidate job record; anything other than a dict yields None.
        fail_unfinished: when True, a "pending"/"running" status is rewritten to
            "failed". ``_normalize_server`` passes True (it is used when loading
            records from disk, where presumably no worker is still running).
    Returns:
        A new dict with every expected key present, or None. Counters that
        cannot be parsed become 0 instead of raising: a raise here would abort
        loading of the entire server list. ``errors`` is always a fresh list,
        so the stored record never aliases the caller's list.
    """
    if not isinstance(raw, dict):
        return None
    status = _clean(raw.get("status")) or "failed"
    if fail_unfinished and status in {"pending", "running"}:
        status = "failed"
    created_at = _clean(raw.get("created_at"))
    errors = raw.get("errors")
    return {
        "job_id": _clean(raw.get("job_id")) or uuid.uuid4().hex,
        "status": status,
        "created_at": created_at or _now_iso(),
        "updated_at": _clean(raw.get("updated_at")) or created_at or _now_iso(),
        "total": _coerce_count(raw.get("total")),
        "completed": _coerce_count(raw.get("completed")),
        "added": _coerce_count(raw.get("added")),
        "skipped": _coerce_count(raw.get("skipped")),
        "refreshed": _coerce_count(raw.get("refreshed")),
        "failed": _coerce_count(raw.get("failed")),
        "errors": list(errors) if isinstance(errors, list) else [],
    }
def _normalize_server(raw: dict) -> dict:
    """Build a canonical server record from *raw*.

    Text fields are stripped (missing ones become ""), a fresh id is minted when
    *raw* has none, and any import job is normalized with unfinished jobs marked
    as failed. Key order is fixed because it determines the on-disk JSON layout.
    """
    record: dict = {"id": _clean(raw.get("id")) or _new_id()}
    for key in ("name", "base_url", "email", "password", "api_key", "group_id"):
        record[key] = _clean(raw.get(key))
    record["import_job"] = _normalize_import_job(raw.get("import_job"), fail_unfinished=True)
    return record
class Sub2APIConfig:
    """JSON-file-backed registry of sub2api servers.

    The in-memory list is guarded by ``self._lock``, and every mutation is
    written back to the store file immediately. Records handed to callers are
    shallow copies of the stored dicts.
    """

    def __init__(self, store_file: Path):
        self._store_file = store_file
        self._lock = Lock()
        self._servers: list[dict] = self._load()

    def _load(self) -> list[dict]:
        """Read and normalize the persisted server list.

        Returns [] when the file does not exist. When the file exists but cannot
        be read or parsed (or its root is not a JSON list), it is moved aside
        before returning []. Otherwise the next ``_save`` would silently
        overwrite the only copy of the configured servers.
        """
        if not self._store_file.exists():
            return []
        servers: list[dict] | None
        try:
            raw = json.loads(self._store_file.read_text(encoding="utf-8"))
            if isinstance(raw, list):
                servers = [_normalize_server(item) for item in raw if isinstance(item, dict)]
            else:
                servers = None
        except Exception:
            # Deliberately broad: a bad config file must not crash startup.
            servers = None
        if servers is None:
            self._quarantine_store_file()
            return []
        return servers

    def _quarantine_store_file(self) -> None:
        """Best-effort rename of the store file to ``<name>.corrupt-<epoch seconds>``."""
        backup = self._store_file.with_name(f"{self._store_file.name}.corrupt-{int(time.time())}")
        try:
            self._store_file.replace(backup)
        except OSError:
            # Nothing more to do; loading continues with an empty list.
            pass

    def _save(self) -> None:
        """Persist the server list atomically. Callers must hold ``self._lock``."""
        self._store_file.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(self._servers, ensure_ascii=False, indent=2) + "\n"
        tmp_file = self._store_file.with_name(f"{self._store_file.name}.tmp")
        tmp_file.touch()
        if self._store_file.exists():
            # The file holds passwords/API keys: carry over whatever permissions
            # were set on it before any content lands in the temp file.
            tmp_file.chmod(self._store_file.stat().st_mode & 0o777)
        tmp_file.write_text(payload, encoding="utf-8")
        # Rename over the target so a crash mid-write never leaves a truncated file.
        tmp_file.replace(self._store_file)

    def list_servers(self) -> list[dict]:
        """Return shallow copies of all server records."""
        with self._lock:
            return [dict(server) for server in self._servers]

    def get_server(self, server_id: str) -> dict | None:
        """Return a copy of the server with *server_id*, or None."""
        with self._lock:
            for server in self._servers:
                if server["id"] == server_id:
                    return dict(server)
        return None

    def add_server(
        self,
        *,
        name: str,
        base_url: str,
        email: str,
        password: str,
        api_key: str,
        group_id: str = "",
    ) -> dict:
        """Create, persist and return (a copy of) a new server record with a fresh id."""
        server = _normalize_server({
            "id": _new_id(),
            "name": name,
            "base_url": base_url,
            "email": email,
            "password": password,
            "api_key": api_key,
            "group_id": group_id,
        })
        with self._lock:
            self._servers.append(server)
            self._save()
        _token_cache.pop(server["id"], None)
        return dict(server)

    def update_server(self, server_id: str, updates: dict) -> dict | None:
        """Merge the non-None entries of *updates* into a server and persist it.

        Returns a copy of the updated record, or None for an unknown id. The
        server's cached login token is dropped so new credentials take effect.
        """
        with self._lock:
            for index, server in enumerate(self._servers):
                if server["id"] != server_id:
                    continue
                merged = {**server, **{k: v for k, v in updates.items() if v is not None}, "id": server_id}
                updated = _normalize_server(merged)
                # _normalize_server fails pending/running jobs (intended for records
                # read back from disk); editing a server while its import is still
                # running must not flip that live job to "failed".
                updated["import_job"] = _normalize_import_job(merged.get("import_job"), fail_unfinished=False)
                self._servers[index] = updated
                self._save()
                result = dict(updated)
                break
            else:
                return None
        _token_cache.pop(server_id, None)
        return result

    def delete_server(self, server_id: str) -> bool:
        """Remove a server; return True if one was removed (its cached token is dropped too)."""
        with self._lock:
            before = len(self._servers)
            self._servers = [server for server in self._servers if server["id"] != server_id]
            removed = len(self._servers) < before
            if removed:
                self._save()
        if removed:
            _token_cache.pop(server_id, None)
        return removed

    def set_import_job(self, server_id: str, import_job: dict | None) -> dict | None:
        """Replace the server's import job (unfinished statuses are kept) and persist.

        Returns a copy of the updated server record, or None for an unknown id.
        """
        with self._lock:
            for index, server in enumerate(self._servers):
                if server["id"] != server_id:
                    continue
                next_server = dict(server)
                next_server["import_job"] = _normalize_import_job(import_job, fail_unfinished=False)
                self._servers[index] = next_server
                self._save()
                return dict(next_server)
        return None

    def get_import_job(self, server_id: str) -> dict | None:
        """Return a copy of the server's import job, or None if absent/unknown server."""
        with self._lock:
            for server in self._servers:
                if server["id"] == server_id:
                    job = server.get("import_job")
                    return dict(job) if isinstance(job, dict) else None
        return None
| # Per-server cached access token: {server_id: (jwt, expires_at_epoch)} | |
| _token_cache: dict[str, tuple[str, float]] = {} | |
| _token_cache_lock = Lock() | |
def _login(base_url: str, email: str, password: str) -> tuple[str, float]:
    """Exchange admin email/password for a sub2api access token.

    Returns:
        ``(access_token, expires_at)``, where ``expires_at`` is a ``time.time()``
        epoch after which the cached token should no longer be reused.
    Raises:
        RuntimeError: on a non-2xx response, a payload that is not an object,
            or a payload without ``access_token``.
    """
    url = f"{base_url.rstrip('/')}/api/v1/auth/login"
    session = Session(verify=True)
    try:
        response = session.post(
            url,
            json={"email": email, "password": password},
            headers={"Accept": "application/json", "Content-Type": "application/json"},
            timeout=30,
        )
        if not response.ok:
            raise RuntimeError(f"sub2api login failed: HTTP {response.status_code} {response.text[:200]}")
        payload = response.json()
    finally:
        session.close()
    body = _unwrap_envelope(payload)
    if not isinstance(body, dict):
        raise RuntimeError("sub2api login payload is invalid")
    token = _clean(body.get("access_token"))
    if not token:
        raise RuntimeError("sub2api login did not return access_token")
    try:
        expires_in = int(body.get("expires_in") or 3600)
    except (TypeError, ValueError):
        # An unparseable lifetime should not fail an otherwise good login.
        expires_in = 3600
    lifetime = max(60, expires_in)
    # Cap the refresh margin at half the lifetime. A flat _TOKEN_REFRESH_SKEW
    # (5 minutes) would make any shorter-lived token "expired" on arrival and
    # force a fresh login on every request.
    margin = min(_TOKEN_REFRESH_SKEW, lifetime // 2)
    return token, time.time() + lifetime - margin
def _auth_headers(server: dict) -> dict[str, str]:
    """Build request headers for *server*.

    A configured ``api_key`` wins and is sent as ``x-api-key``. Otherwise the
    server's email/password are exchanged for a bearer token via ``_login``;
    the token is reused from ``_token_cache`` until its recorded expiry.

    Raises:
        RuntimeError: if neither an api_key nor both email and password are set.
    """
    accept = {"Accept": "application/json"}
    api_key = _clean(server.get("api_key"))
    if api_key:
        return {"x-api-key": api_key, **accept}

    email, password = _clean(server.get("email")), _clean(server.get("password"))
    if not (email and password):
        raise RuntimeError("sub2api server requires email+password or api_key")

    server_id = _clean(server.get("id"))
    with _token_cache_lock:
        entry = _token_cache.get(server_id)
    if entry is not None and entry[1] > time.time():
        bearer = entry[0]
    else:
        bearer, expires_at = _login(_clean(server.get("base_url")), email, password)
        with _token_cache_lock:
            _token_cache[server_id] = (bearer, expires_at)
    return {"Authorization": f"Bearer {bearer}", **accept}
| def _extract_access_token(credentials: object) -> str: | |
| if not isinstance(credentials, dict): | |
| return "" | |
| for key in ("access_token", "accessToken", "token"): | |
| value = _clean(credentials.get(key)) | |
| if value: | |
| return value | |
| return "" | |
| def _unwrap_envelope(payload: object) -> object: | |
| """Peel sub2api's `{code, message, data}` envelope, returning the inner `data` field | |
| when present. Also handles unwrapped responses from older/alt versions.""" | |
| if isinstance(payload, dict) and "data" in payload and "code" in payload: | |
| return payload.get("data") | |
| return payload | |
def _extract_paged_items(payload: object) -> tuple[list, int]:
    """Return ``(items, total)`` from a paginated sub2api response.

    Accepts the enveloped form ``{code, data: {items, total, ...}}`` as well as
    looser shapes: a bare list, or an object holding the list under ``items``,
    ``data`` or ``list`` (first match wins). When ``total`` is missing or falsy,
    the page length is used instead. Unrecognised shapes give ``([], 0)``.
    """
    inner = _unwrap_envelope(payload)
    if isinstance(inner, list):
        return inner, len(inner)
    if not isinstance(inner, dict):
        return [], 0
    for key in ("items", "data", "list"):
        page_items = inner.get(key)
        if isinstance(page_items, list):
            break
    else:
        return [], 0
    return page_items, int(inner.get("total") or len(page_items))
def _summarize_remote_account(account: object) -> dict | None:
    """Project one raw sub2api account onto the listing shape; None if unusable."""
    if not isinstance(account, dict):
        return None
    raw_credentials = account.get("credentials")
    credentials = raw_credentials if isinstance(raw_credentials, dict) else {}
    if not _extract_access_token(credentials):
        return None
    account_id = account.get("id")
    return {
        "id": str(account_id) if account_id is not None else _clean(credentials.get("chatgpt_account_id")),
        "name": _clean(account.get("name")),
        "email": _clean(credentials.get("email")) or _clean(account.get("name")),
        "plan_type": _clean(credentials.get("plan_type")),
        "status": _clean(account.get("status")),
        "expires_at": _clean(credentials.get("expires_at")),
        "has_refresh_token": bool(_clean(credentials.get("refresh_token"))),
    }


def list_remote_accounts(server: dict) -> list[dict]:
    """List the server's OpenAI OAuth accounts that carry an access token.

    Pages through ``/api/v1/admin/accounts`` 200 at a time, filtered by the
    server's ``group_id`` when set. Returns [] when ``base_url`` is not set.

    Raises:
        RuntimeError: if any page request returns a non-2xx status.
    """
    base_url = _clean(server.get("base_url"))
    if not base_url:
        return []
    headers = _auth_headers(server)
    group_id = _clean(server.get("group_id"))
    endpoint = f"{base_url.rstrip('/')}/api/v1/admin/accounts"
    page_size = 200
    results: list[dict] = []
    session = Session(verify=True)
    try:
        page = 0
        while True:
            page += 1
            params: dict[str, object] = {
                "platform": "openai",
                "type": "oauth",
                "page": page,
                "page_size": page_size,
            }
            if group_id:
                params["group"] = group_id
            response = session.get(endpoint, headers=headers, params=params, timeout=30)
            if not response.ok:
                raise RuntimeError(f"sub2api list failed: HTTP {response.status_code} {response.text[:200]}")
            batch, total = _extract_paged_items(response.json())
            if not batch:
                break
            for account in batch:
                summary = _summarize_remote_account(account)
                if summary is not None:
                    results.append(summary)
            # Stop once the reported total is covered or a short page arrives.
            if page * page_size >= total or len(batch) < page_size:
                break
    finally:
        session.close()
    return results
def list_remote_groups(server: dict) -> list[dict]:
    """List the account groups defined on a sub2api server.

    Pages through ``/api/v1/admin/groups`` 200 at a time and skips entries
    without an ``id``. Returns [] when ``base_url`` is not set.

    Raises:
        RuntimeError: if any page request returns a non-2xx status.
    """
    base_url = _clean(server.get("base_url"))
    if not base_url:
        return []
    headers = _auth_headers(server)
    endpoint = f"{base_url.rstrip('/')}/api/v1/admin/groups"
    page_size = 200
    groups: list[dict] = []
    session = Session(verify=True)
    try:
        page = 0
        while True:
            page += 1
            response = session.get(
                endpoint,
                headers=headers,
                params={"page": page, "page_size": page_size},
                timeout=30,
            )
            if not response.ok:
                raise RuntimeError(f"sub2api groups failed: HTTP {response.status_code} {response.text[:200]}")
            batch, total = _extract_paged_items(response.json())
            if not batch:
                break
            groups.extend(
                {
                    "id": str(entry.get("id")),
                    "name": _clean(entry.get("name")),
                    "description": _clean(entry.get("description")),
                    "platform": _clean(entry.get("platform")),
                    "status": _clean(entry.get("status")),
                    "account_count": int(entry.get("account_count") or 0),
                    "active_account_count": int(entry.get("active_account_count") or 0),
                }
                for entry in batch
                if isinstance(entry, dict) and entry.get("id") is not None
            )
            # Stop once the reported total is covered or a short page arrives.
            if page * page_size >= total or len(batch) < page_size:
                break
    finally:
        session.close()
    return groups
def _fetch_access_token_for_account(server: dict, account_id: str) -> tuple[str, dict]:
    """Fetch one sub2api account and return ``(access_token, meta)``.

    ``meta`` holds the account credentials' ``email`` and ``plan_type``
    ("" when absent).

    Raises:
        RuntimeError: if the server has no base_url, the request is non-2xx,
            or the account's credentials contain no access token.
    """
    from urllib.parse import quote

    base_url = _clean(server.get("base_url"))
    if not base_url:
        raise RuntimeError("sub2api base_url is not configured")
    headers = _auth_headers(server)
    # account_id is caller-supplied (see start_import). Percent-encode it as a
    # single path segment so values such as "../groups" or "1?x=y" cannot steer
    # this authenticated admin request to a different endpoint.
    safe_id = quote(str(account_id), safe="")
    session = Session(verify=True)
    try:
        response = session.get(
            f"{base_url.rstrip('/')}/api/v1/admin/accounts/{safe_id}",
            headers=headers,
            timeout=30,
        )
        if not response.ok:
            raise RuntimeError(f"HTTP {response.status_code}")
        payload = response.json()
    finally:
        session.close()
    account = _unwrap_envelope(payload)
    if not isinstance(account, dict):
        account = payload if isinstance(payload, dict) else {}
    credentials = account.get("credentials") if isinstance(account.get("credentials"), dict) else {}
    access_token = _extract_access_token(credentials)
    if not access_token:
        raise RuntimeError("missing access_token")
    return access_token, {
        "email": _clean(credentials.get("email")),
        "plan_type": _clean(credentials.get("plan_type")),
    }
class Sub2APIImportService:
    """Runs background imports of sub2api accounts.

    Access tokens are fetched from the sub2api server and handed to
    ``account_service.add_accounts`` / ``refresh_accounts``. Job progress is
    persisted on the server record via ``Sub2APIConfig.set_import_job`` so it
    can be polled with ``get_import_job``.
    """

    def __init__(self, sub2api_config: Sub2APIConfig):
        self._config = sub2api_config
        # Makes the "already running?" check and job creation in start_import atomic.
        self._start_lock = threading.Lock()

    def start_import(self, server: dict, account_ids: list[str]) -> dict:
        """Create an import job for *server* and run it on a daemon thread.

        Args:
            server: the server record (its ``id`` selects where the job is stored).
            account_ids: sub2api account ids; blank entries are ignored.
        Returns:
            A copy of the newly stored job record.
        Raises:
            ValueError: no usable ids, unknown server, or an import for this
                server is already pending/running.
        """
        ids = [_clean(item) for item in account_ids if _clean(item)]
        if not ids:
            raise ValueError("account ids is required")
        server_id = _clean(server.get("id"))
        now = _now_iso()
        job = {
            "job_id": uuid.uuid4().hex,
            "status": "pending",
            "created_at": now,
            "updated_at": now,
            "total": len(ids),
            "completed": 0,
            "added": 0,
            "skipped": 0,
            "refreshed": 0,
            "failed": 0,
            "errors": [],
        }
        with self._start_lock:
            current = self._config.get_import_job(server_id)
            if current is not None and current.get("status") in {"pending", "running"}:
                # A second worker on the same job record would interleave its
                # read-modify-write progress updates with the first one's.
                raise ValueError("import already in progress")
            saved = self._config.set_import_job(server_id, job)
        if saved is None:
            raise ValueError("server not found")
        thread = threading.Thread(
            target=self._run_import,
            args=(server_id, server, ids),
            name=f"sub2api-import-{server_id}",
            daemon=True,
        )
        thread.start()
        return dict(saved.get("import_job") or job)

    def _update_job(self, server_id: str, **updates) -> None:
        """Merge *updates* into the server's current job and stamp ``updated_at``.

        No-op when the server has no job (or no longer exists).
        """
        current = self._config.get_import_job(server_id)
        if current is None:
            return
        next_job = {**current, **updates, "updated_at": _now_iso()}
        self._config.set_import_job(server_id, next_job)

    def _append_error(self, server_id: str, account_id: str, message: str) -> None:
        """Record a per-account failure; ``failed`` is kept equal to the error count."""
        current = self._config.get_import_job(server_id)
        if current is None:
            return
        errors = list(current.get("errors") or [])
        errors.append({"name": account_id, "error": message})
        self._update_job(server_id, errors=errors, failed=len(errors))

    def _run_import(self, server_id: str, server: dict, account_ids: list[str]) -> None:
        """Worker-thread entry point: run the import and never leave it "running".

        Any unexpected exception (e.g. from ``account_service``) is recorded as a
        job-level error and the job is marked "failed". Otherwise the thread would
        die silently with the job stuck in "running".
        """
        try:
            self._import_accounts(server_id, server, account_ids)
        except Exception as exc:
            current = self._config.get_import_job(server_id) or {}
            errors = list(current.get("errors") or [])
            errors.append({"name": "import", "error": str(exc) or "unknown error"})
            self._update_job(server_id, status="failed", errors=errors)

    def _import_accounts(self, server_id: str, server: dict, account_ids: list[str]) -> None:
        """Fetch tokens concurrently (up to 8 workers), then add and refresh them."""
        self._update_job(server_id, status="running")
        tokens: list[str] = []
        max_workers = min(8, max(1, len(account_ids)))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_map = {
                executor.submit(_fetch_access_token_for_account, server, account_id): account_id
                for account_id in account_ids
            }
            # Progress updates happen only on this thread, so the job record's
            # read-modify-write cycles do not race with each other here.
            for future in as_completed(future_map):
                account_id = future_map[future]
                try:
                    token, _meta = future.result()
                    tokens.append(token)
                except Exception as exc:
                    self._append_error(server_id, account_id, str(exc) or "unknown error")
                current = self._config.get_import_job(server_id) or {}
                failed = len(current.get("errors") or [])
                self._update_job(
                    server_id,
                    completed=int(current.get("completed") or 0) + 1,
                    failed=failed,
                )
        if not tokens:
            current = self._config.get_import_job(server_id) or {}
            self._update_job(
                server_id,
                status="failed",
                completed=int(current.get("total") or 0),
                failed=len(current.get("errors") or []),
            )
            return
        add_result = account_service.add_accounts(tokens)
        refresh_result = account_service.refresh_accounts(tokens)
        current = self._config.get_import_job(server_id) or {}
        self._update_job(
            server_id,
            status="completed",
            completed=len(account_ids),
            added=int(add_result.get("added") or 0),
            skipped=int(add_result.get("skipped") or 0),
            refreshed=int(refresh_result.get("refreshed") or 0),
            failed=len(current.get("errors") or []),
        )
# Module-level singletons. Constructing the config reads SUB2API_CONFIG_FILE at import time.
sub2api_config = Sub2APIConfig(SUB2API_CONFIG_FILE)
sub2api_import_service = Sub2APIImportService(sub2api_config=sub2api_config)