| """Sub2API integration for browsing and importing ChatGPT OAuth accounts from a sub2api admin.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import threading |
| import time |
| import uuid |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from datetime import datetime, timezone |
| from pathlib import Path |
| from threading import Lock |
|
|
| from curl_cffi.requests import Session |
|
|
| from services.account_service import account_service |
| from services.config import DATA_DIR |
|
|
|
|
# JSON file, kept under the shared data directory, in which the configured
# sub2api servers (and each server's last import job) are persisted.
SUB2API_CONFIG_FILE = DATA_DIR / "sub2api_config.json"
|
|
| |
| |
# Safety margin, in seconds (5 minutes), taken off a login token's reported
# lifetime when `_login` computes how long the token may stay cached.
_TOKEN_REFRESH_SKEW = 5 * 60
|
|
|
|
| def _new_id() -> str: |
| return uuid.uuid4().hex[:12] |
|
|
|
|
| def _now_iso() -> str: |
| return datetime.now(timezone.utc).isoformat() |
|
|
|
|
| def _clean(value: object) -> str: |
| return str(value or "").strip() |
|
|
|
|
def _normalize_import_job(raw: object, *, fail_unfinished: bool) -> dict | None:
    """Coerce an import-job record into its canonical dict shape.

    Args:
        raw: The stored or caller-supplied job record. Anything that is not a
            dict yields ``None`` ("no job").
        fail_unfinished: When true, a job whose status is ``"pending"`` or
            ``"running"`` is rewritten to ``"failed"``.

    Returns:
        A new dict with the keys ``job_id``, ``status``, ``created_at``,
        ``updated_at``, the integer counters ``total``/``completed``/``added``/
        ``skipped``/``refreshed``/``failed`` and an ``errors`` list, or
        ``None`` when *raw* is not a dict.
    """
    if not isinstance(raw, dict):
        return None

    def _counter(key: str) -> int:
        # Counters may come from a hand-edited or damaged file. A bad value
        # must not raise: an exception here would make `Sub2APIConfig._load`
        # discard stored servers.
        try:
            return int(raw.get(key) or 0)
        except (TypeError, ValueError, OverflowError):
            return 0

    status = _clean(raw.get("status")) or "failed"
    if fail_unfinished and status in {"pending", "running"}:
        status = "failed"

    created_at = _clean(raw.get("created_at")) or _now_iso()
    errors = raw.get("errors")
    return {
        "job_id": _clean(raw.get("job_id")) or uuid.uuid4().hex,
        "status": status,
        "created_at": created_at,
        # Fall back to the creation time when no update time was recorded.
        "updated_at": _clean(raw.get("updated_at")) or created_at,
        "total": _counter("total"),
        "completed": _counter("completed"),
        "added": _counter("added"),
        "skipped": _counter("skipped"),
        "refreshed": _counter("refreshed"),
        "failed": _counter("failed"),
        # Copy the list so the normalized job never aliases the caller's list.
        "errors": list(errors) if isinstance(errors, list) else [],
    }
|
|
|
|
def _normalize_server(raw: dict) -> dict:
    """Build a canonical server record from a raw (stored or submitted) dict.

    A missing id is replaced with a freshly generated one, every text field is
    trimmed, and the embedded import job is normalized with unfinished jobs
    marked as failed.
    """
    server = {"id": _clean(raw.get("id")) or _new_id()}
    for field in ("name", "base_url", "email", "password", "api_key", "group_id"):
        server[field] = _clean(raw.get(field))
    server["import_job"] = _normalize_import_job(raw.get("import_job"), fail_unfinished=True)
    return server
|
|
|
|
class Sub2APIConfig:
    """Lock-protected, JSON-file-backed registry of sub2api server records.

    Every mutation is written back to ``store_file`` while ``self._lock`` is
    held. Accessors return shallow copies of the stored dicts. Note that the
    ``password`` and ``api_key`` fields are persisted as plain text.
    """

    def __init__(self, store_file: Path):
        """Remember the backing file and load the servers it currently holds."""
        self._store_file = store_file
        self._lock = Lock()
        self._servers: list[dict] = self._load()

    def _load(self) -> list[dict]:
        """Read the server list from disk, tolerating missing or damaged data.

        Returns an empty list when the file is absent, unreadable, not valid
        JSON, or not a JSON array. Entries that are not objects, or that fail
        to normalize, are skipped individually so one bad record does not
        discard the remaining servers.
        """
        if not self._store_file.exists():
            return []
        try:
            raw = json.loads(self._store_file.read_text(encoding="utf-8"))
        except Exception:
            # Deliberate best-effort: this runs while the module is imported,
            # so an unreadable config must not prevent start-up.
            return []
        if not isinstance(raw, list):
            return []

        servers: list[dict] = []
        for item in raw:
            if not isinstance(item, dict):
                continue
            try:
                servers.append(_normalize_server(item))
            except Exception:
                # Skip only the malformed entry instead of the whole file.
                continue
        return servers

    def _save(self) -> None:
        """Write the current server list to the backing file (caller holds the lock)."""
        self._store_file.parent.mkdir(parents=True, exist_ok=True)
        self._store_file.write_text(
            json.dumps(self._servers, ensure_ascii=False, indent=2) + "\n",
            encoding="utf-8",
        )

    def list_servers(self) -> list[dict]:
        """Return a shallow copy of every stored server record."""
        with self._lock:
            return [dict(server) for server in self._servers]

    def get_server(self, server_id: str) -> dict | None:
        """Return a shallow copy of the server with *server_id*, or ``None``."""
        with self._lock:
            for server in self._servers:
                if server["id"] == server_id:
                    return dict(server)
        return None

    def add_server(
        self,
        *,
        name: str,
        base_url: str,
        email: str,
        password: str,
        api_key: str,
        group_id: str = "",
    ) -> dict:
        """Create, persist and return a new server record with a generated id."""
        server = _normalize_server({
            "id": _new_id(),
            "name": name,
            "base_url": base_url,
            "email": email,
            "password": password,
            "api_key": api_key,
            "group_id": group_id,
        })
        with self._lock:
            self._servers.append(server)
            self._save()
        # The token cache is guarded by its own lock everywhere else it is used.
        with _token_cache_lock:
            _token_cache.pop(server["id"], None)
        return dict(server)

    def update_server(self, server_id: str, updates: dict) -> dict | None:
        """Merge *updates* into the server with *server_id* and persist it.

        Keys whose value is ``None`` are ignored, and the id cannot be changed.
        Returns the updated record, or ``None`` when no such server exists.
        Any cached login token for the server is dropped, since credentials or
        the base URL may have changed.
        """
        with self._lock:
            for index, server in enumerate(self._servers):
                if server["id"] != server_id:
                    continue
                merged = {**server, **{k: v for k, v in updates.items() if v is not None}, "id": server_id}
                normalized = _normalize_server(merged)
                # `_normalize_server` marks pending/running jobs as failed,
                # which is only right when loading after a restart. Editing a
                # server must not fail an import that is still in progress.
                normalized["import_job"] = _normalize_import_job(
                    merged.get("import_job"), fail_unfinished=False
                )
                self._servers[index] = normalized
                self._save()
                result = dict(normalized)
                break
            else:
                return None
        with _token_cache_lock:
            _token_cache.pop(server_id, None)
        return result

    def delete_server(self, server_id: str) -> bool:
        """Remove the server with *server_id*; return whether anything was removed."""
        with self._lock:
            before = len(self._servers)
            self._servers = [server for server in self._servers if server["id"] != server_id]
            removed = len(self._servers) < before
            if removed:
                self._save()
        if removed:
            with _token_cache_lock:
                _token_cache.pop(server_id, None)
        return removed

    def set_import_job(self, server_id: str, import_job: dict | None) -> dict | None:
        """Replace the import job of a server and persist the change.

        The job is normalized without forcing unfinished jobs to ``"failed"``.
        Returns a shallow copy of the updated server, or ``None`` when the
        server does not exist.
        """
        with self._lock:
            for index, server in enumerate(self._servers):
                if server["id"] != server_id:
                    continue
                next_server = dict(server)
                next_server["import_job"] = _normalize_import_job(import_job, fail_unfinished=False)
                self._servers[index] = next_server
                self._save()
                return dict(next_server)
        return None

    def get_import_job(self, server_id: str) -> dict | None:
        """Return a shallow copy of a server's import job, or ``None`` if absent."""
        with self._lock:
            for server in self._servers:
                if server["id"] == server_id:
                    job = server.get("import_job")
                    return dict(job) if isinstance(job, dict) else None
        return None
|
|
|
|
| |
# Login tokens obtained with email/password, keyed by server id. Each value is
# (bearer token, deadline as epoch seconds after which `_auth_headers` logs in
# again). `_token_cache_lock` guards the lookups and stores in `_auth_headers`.
_token_cache: dict[str, tuple[str, float]] = {}
_token_cache_lock = Lock()
|
|
|
|
def _login(base_url: str, email: str, password: str) -> tuple[str, float]:
    """Log in to a sub2api server with email/password.

    Args:
        base_url: Root URL of the sub2api server (a trailing slash is ignored).
        email: Admin account email.
        password: Admin account password.

    Returns:
        ``(access_token, expires_at)`` where ``expires_at`` is the epoch time
        after which the token should no longer be served from the cache.

    Raises:
        RuntimeError: On a non-2xx response, a payload that is not an object,
            or a payload without an ``access_token``.
    """
    url = f"{base_url.rstrip('/')}/api/v1/auth/login"
    session = Session(verify=True)
    try:
        response = session.post(
            url,
            json={"email": email, "password": password},
            headers={"Accept": "application/json", "Content-Type": "application/json"},
            timeout=30,
        )
        if not response.ok:
            raise RuntimeError(f"sub2api login failed: HTTP {response.status_code} {response.text[:200]}")
        payload = response.json()
    finally:
        session.close()

    body = _unwrap_envelope(payload)
    if not isinstance(body, dict):
        raise RuntimeError("sub2api login payload is invalid")

    token = _clean(body.get("access_token"))
    if not token:
        raise RuntimeError("sub2api login did not return access_token")

    # A malformed `expires_in` should not fail an otherwise successful login.
    try:
        expires_in = int(body.get("expires_in") or 3600)
    except (TypeError, ValueError, OverflowError):
        expires_in = 3600

    lifetime = max(60, expires_in)
    # Subtracting the full skew from a short-lived token would put the
    # deadline in the past, so the token would never be reused. Cap the skew
    # at half the lifetime; tokens of 10 minutes or more are unaffected.
    skew = min(_TOKEN_REFRESH_SKEW, lifetime // 2)
    expires_at = time.time() + lifetime - skew
    return token, expires_at
|
|
|
|
def _auth_headers(server: dict) -> dict[str, str]:
    """Build the request headers that authenticate against *server*.

    An API key, when configured, is sent as ``x-api-key``. Otherwise a bearer
    token is used: it is taken from the token cache while still valid, or
    obtained through a fresh login and then cached.

    Raises:
        RuntimeError: When neither an API key nor an email/password pair is
            configured (login failures from ``_login`` propagate as well).
    """
    api_key = _clean(server.get("api_key"))
    if api_key:
        return {"x-api-key": api_key, "Accept": "application/json"}

    email = _clean(server.get("email"))
    password = _clean(server.get("password"))
    if not (email and password):
        raise RuntimeError("sub2api server requires email+password or api_key")

    server_id = _clean(server.get("id"))
    base_url = _clean(server.get("base_url"))

    token = None
    with _token_cache_lock:
        entry = _token_cache.get(server_id)
        if entry and entry[1] > time.time():
            token = entry[0]

    if token is None:
        # The login request runs outside the lock; only the store is guarded.
        token, deadline = _login(base_url, email, password)
        with _token_cache_lock:
            _token_cache[server_id] = (token, deadline)

    return {"Authorization": f"Bearer {token}", "Accept": "application/json"}
|
|
|
|
def _extract_access_token(credentials: object) -> str:
    """Return the first non-empty token found in *credentials*.

    The keys ``access_token``, ``accessToken`` and ``token`` are tried in that
    order; an empty string is returned when *credentials* is not a dict or
    none of the keys holds a value.
    """
    if not isinstance(credentials, dict):
        return ""
    candidates = (
        _clean(credentials.get(key))
        for key in ("access_token", "accessToken", "token")
    )
    return next((token for token in candidates if token), "")
|
|
|
|
| def _unwrap_envelope(payload: object) -> object: |
| """Peel sub2api's `{code, message, data}` envelope, returning the inner `data` field |
| when present. Also handles unwrapped responses from older/alt versions.""" |
| if isinstance(payload, dict) and "data" in payload and "code" in payload: |
| return payload.get("data") |
| return payload |
|
|
|
|
def _extract_paged_items(payload: object) -> tuple[list, int]:
    """Pull ``(items, total)`` out of a paginated sub2api response.

    Accepts the wrapped shape ``{code, data: {items, total, ...}}`` as well as
    looser variants (``{data: [...]}``, ``[...]``, ``{items: [...], total: N}``).
    When no ``total`` is reported, the length of the item list is used.
    Unrecognized shapes yield ``([], 0)``.
    """
    body = _unwrap_envelope(payload)
    if isinstance(body, list):
        return body, len(body)
    if not isinstance(body, dict):
        return [], 0

    candidates = (body.get(key) for key in ("items", "data", "list"))
    rows = next((value for value in candidates if isinstance(value, list)), None)
    if rows is None:
        return [], 0
    return rows, int(body.get("total") or len(rows))
|
|
|
|
def list_remote_accounts(server: dict) -> list[dict]:
    """List the OpenAI OAuth accounts exposed by a sub2api server.

    Pages through ``/api/v1/admin/accounts`` 200 entries at a time, optionally
    restricted to the server's configured group. Accounts without an access
    token in their credentials are left out. Returns an empty list when the
    server has no base URL.

    Raises:
        RuntimeError: When a page request returns a non-2xx response, or when
            authentication headers cannot be built.
    """
    base_url = _clean(server.get("base_url"))
    if not base_url:
        return []

    headers = _auth_headers(server)
    group_id = _clean(server.get("group_id"))
    endpoint = f"{base_url.rstrip('/')}/api/v1/admin/accounts"
    page_size = 200

    accounts: list[dict] = []
    session = Session(verify=True)
    try:
        page = 1
        while True:
            query: dict[str, object] = {
                "platform": "openai",
                "type": "oauth",
                "page": page,
                "page_size": page_size,
            }
            if group_id:
                query["group"] = group_id

            response = session.get(endpoint, headers=headers, params=query, timeout=30)
            if not response.ok:
                raise RuntimeError(f"sub2api list failed: HTTP {response.status_code} {response.text[:200]}")

            batch, total = _extract_paged_items(response.json())
            if not batch:
                break

            for entry in batch:
                if not isinstance(entry, dict):
                    continue
                raw_credentials = entry.get("credentials")
                creds = raw_credentials if isinstance(raw_credentials, dict) else {}
                # Entries without a usable token cannot be imported; skip them.
                if not _extract_access_token(creds):
                    continue

                remote_id = entry.get("id")
                if remote_id is None:
                    identifier = _clean(creds.get("chatgpt_account_id"))
                else:
                    identifier = str(remote_id)

                accounts.append({
                    "id": identifier,
                    "name": _clean(entry.get("name")),
                    "email": _clean(creds.get("email")) or _clean(entry.get("name")),
                    "plan_type": _clean(creds.get("plan_type")),
                    "status": _clean(entry.get("status")),
                    "expires_at": _clean(creds.get("expires_at")),
                    "has_refresh_token": bool(_clean(creds.get("refresh_token"))),
                })

            # Stop on a short page or once the reported total has been covered.
            is_last_page = len(batch) < page_size or page * page_size >= total
            if is_last_page:
                break
            page += 1
    finally:
        session.close()

    return accounts
|
|
|
|
def list_remote_groups(server: dict) -> list[dict]:
    """List the account groups defined on a sub2api server.

    Pages through ``/api/v1/admin/groups`` 200 entries at a time. Entries that
    are not objects or carry no id are ignored. Returns an empty list when the
    server has no base URL.

    Raises:
        RuntimeError: When a page request returns a non-2xx response, or when
            authentication headers cannot be built.
    """
    base_url = _clean(server.get("base_url"))
    if not base_url:
        return []

    headers = _auth_headers(server)
    endpoint = f"{base_url.rstrip('/')}/api/v1/admin/groups"
    page_size = 200

    groups: list[dict] = []
    session = Session(verify=True)
    try:
        page = 1
        while True:
            response = session.get(
                endpoint,
                headers=headers,
                params={"page": page, "page_size": page_size},
                timeout=30,
            )
            if not response.ok:
                raise RuntimeError(f"sub2api groups failed: HTTP {response.status_code} {response.text[:200]}")

            batch, total = _extract_paged_items(response.json())
            if not batch:
                break

            for entry in batch:
                if not isinstance(entry, dict):
                    continue
                remote_id = entry.get("id")
                if remote_id is None:
                    continue
                groups.append({
                    "id": str(remote_id),
                    "name": _clean(entry.get("name")),
                    "description": _clean(entry.get("description")),
                    "platform": _clean(entry.get("platform")),
                    "status": _clean(entry.get("status")),
                    "account_count": int(entry.get("account_count") or 0),
                    "active_account_count": int(entry.get("active_account_count") or 0),
                })

            # Stop on a short page or once the reported total has been covered.
            is_last_page = len(batch) < page_size or page * page_size >= total
            if is_last_page:
                break
            page += 1
    finally:
        session.close()

    return groups
|
|
|
|
def _fetch_access_token_for_account(server: dict, account_id: str) -> tuple[str, dict]:
    """Fetch one account from a sub2api server and return its access token.

    Args:
        server: Server record providing ``base_url`` and credentials.
        account_id: Identifier of the remote account. It is percent-encoded
            before being placed in the URL path.

    Returns:
        ``(access_token, meta)`` where ``meta`` holds the account's ``email``
        and ``plan_type`` as found in its credentials.

    Raises:
        RuntimeError: When the server has no base URL, the request returns a
            non-2xx response, or the account carries no access token.
    """
    # Local import keeps this block self-contained.
    from urllib.parse import quote

    base_url = _clean(server.get("base_url"))
    if not base_url:
        raise RuntimeError("sub2api server base_url is missing")
    headers = _auth_headers(server)

    # The id comes from the caller. Encode it so characters such as "/", "?"
    # or "#" cannot redirect this authenticated request to another endpoint.
    safe_id = quote(str(account_id), safe="")

    session = Session(verify=True)
    try:
        response = session.get(
            f"{base_url.rstrip('/')}/api/v1/admin/accounts/{safe_id}",
            headers=headers,
            timeout=30,
        )
        if not response.ok:
            raise RuntimeError(f"HTTP {response.status_code}")
        payload = response.json()
    finally:
        session.close()

    account = _unwrap_envelope(payload)
    if not isinstance(account, dict):
        # Fall back to the raw payload when the envelope's data is not an object.
        account = payload if isinstance(payload, dict) else {}
    raw_credentials = account.get("credentials")
    credentials = raw_credentials if isinstance(raw_credentials, dict) else {}
    access_token = _extract_access_token(credentials)
    if not access_token:
        raise RuntimeError("missing access_token")
    return access_token, {
        "email": _clean(credentials.get("email")),
        "plan_type": _clean(credentials.get("plan_type")),
    }
|
|
|
|
class Sub2APIImportService:
    """Imports accounts from a sub2api server on a background thread.

    Progress is recorded as the server's ``import_job`` in the supplied
    ``Sub2APIConfig``, so it can be polled while the import runs.
    """

    def __init__(self, sub2api_config: Sub2APIConfig):
        """Keep a reference to the config store used for job bookkeeping."""
        self._config = sub2api_config

    def start_import(self, server: dict, account_ids: list[str]) -> dict:
        """Register a new import job and start it on a daemon thread.

        Args:
            server: Server record the accounts are fetched from.
            account_ids: Remote account ids to import; blank entries are dropped.

        Returns:
            The stored job record in its initial ``"pending"`` state.

        Raises:
            ValueError: When no usable account id is given, or the server is
                not present in the config store.
        """
        ids = [_clean(item) for item in account_ids if _clean(item)]
        if not ids:
            raise ValueError("account ids is required")

        server_id = _clean(server.get("id"))
        job = {
            "job_id": uuid.uuid4().hex,
            "status": "pending",
            "created_at": _now_iso(),
            "updated_at": _now_iso(),
            "total": len(ids),
            "completed": 0,
            "added": 0,
            "skipped": 0,
            "refreshed": 0,
            "failed": 0,
            "errors": [],
        }
        saved = self._config.set_import_job(server_id, job)
        if saved is None:
            raise ValueError("server not found")

        thread = threading.Thread(
            target=self._run_import,
            args=(server_id, server, ids),
            name=f"sub2api-import-{server_id}",
            daemon=True,
        )
        thread.start()
        return dict(saved.get("import_job") or job)

    def _update_job(self, server_id: str, **updates) -> None:
        """Merge *updates* into the stored job and refresh its ``updated_at``.

        Does nothing when the server has no job (e.g. it was deleted meanwhile).
        """
        current = self._config.get_import_job(server_id)
        if current is None:
            return
        next_job = {**current, **updates, "updated_at": _now_iso()}
        self._config.set_import_job(server_id, next_job)

    def _append_error(self, server_id: str, account_id: str, message: str) -> None:
        """Record a per-account failure and sync the ``failed`` counter with it."""
        current = self._config.get_import_job(server_id)
        if current is None:
            return
        errors = list(current.get("errors") or [])
        errors.append({"name": account_id, "error": message})
        self._update_job(server_id, errors=errors, failed=len(errors))

    def _run_import(self, server_id: str, server: dict, account_ids: list[str]) -> None:
        """Thread entry point: run the import and never leave the job unfinished.

        Without this guard an unexpected exception (for instance from the
        account service) would end the thread silently and the job would stay
        in ``"running"``.
        """
        try:
            self._import_accounts(server_id, server, account_ids)
        except Exception as exc:
            current = self._config.get_import_job(server_id) or {}
            errors = list(current.get("errors") or [])
            # "*" marks a failure of the job as a whole rather than of one
            # account, so the per-account ``failed`` counter is left alone.
            errors.append({"name": "*", "error": str(exc) or type(exc).__name__})
            self._update_job(server_id, status="failed", errors=errors)

    def _import_accounts(self, server_id: str, server: dict, account_ids: list[str]) -> None:
        """Fetch each account's token concurrently, then add and refresh them."""
        self._update_job(server_id, status="running")

        tokens: list[str] = []
        max_workers = min(8, max(1, len(account_ids)))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_map = {
                executor.submit(_fetch_access_token_for_account, server, account_id): account_id
                for account_id in account_ids
            }
            # Results are consumed on this thread only, so job updates are
            # applied one at a time.
            for future in as_completed(future_map):
                account_id = future_map[future]
                try:
                    token, _meta = future.result()
                    tokens.append(token)
                except Exception as exc:
                    self._append_error(server_id, account_id, str(exc) or "unknown error")

                current = self._config.get_import_job(server_id) or {}
                failed = len(current.get("errors") or [])
                self._update_job(
                    server_id,
                    completed=int(current.get("completed") or 0) + 1,
                    failed=failed,
                )

        if not tokens:
            # Nothing could be fetched: the whole job counts as failed.
            current = self._config.get_import_job(server_id) or {}
            self._update_job(
                server_id,
                status="failed",
                completed=int(current.get("total") or 0),
                failed=len(current.get("errors") or []),
            )
            return

        add_result = account_service.add_accounts(tokens)
        refresh_result = account_service.refresh_accounts(tokens)
        current = self._config.get_import_job(server_id) or {}
        self._update_job(
            server_id,
            status="completed",
            completed=len(account_ids),
            added=int(add_result.get("added") or 0),
            skipped=int(add_result.get("skipped") or 0),
            refreshed=int(refresh_result.get("refreshed") or 0),
            failed=len(current.get("errors") or []),
        )
|
|
|
|
# Module-level singletons shared by the application. Constructing the config
# reads SUB2API_CONFIG_FILE immediately, i.e. when this module is imported.
sub2api_config = Sub2APIConfig(SUB2API_CONFIG_FILE)
sub2api_import_service = Sub2APIImportService(sub2api_config)
|
|