CHATSAM / services /sub2api_service.py
yiranxiaohui's picture
fix(sub2api): 正确解析响应的 {code,data:{items,total}} 包装
96921b5
Raw
History Blame Contribute Delete
18.6 kB
"""Sub2API integration for browsing and importing ChatGPT OAuth accounts from a sub2api admin."""
from __future__ import annotations

import json
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from pathlib import Path
from threading import Lock
from urllib.parse import quote

from curl_cffi.requests import Session

from services.account_service import account_service
from services.config import DATA_DIR
# JSON file holding the list of configured sub2api servers.
SUB2API_CONFIG_FILE = DATA_DIR / "sub2api_config.json"

# A login JWT is cached per server so list/import calls do not log in every time.
# Token lifetime on sub2api defaults to 24h; a cached token is treated as stale
# this many seconds before its reported expiry.
_TOKEN_REFRESH_SKEW = 300  # 5 minutes
def _new_id() -> str:
return uuid.uuid4().hex[:12]
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _clean(value: object) -> str:
return str(value or "").strip()
def _normalize_import_job(raw: object, *, fail_unfinished: bool) -> dict | None:
if not isinstance(raw, dict):
return None
status = _clean(raw.get("status")) or "failed"
if fail_unfinished and status in {"pending", "running"}:
status = "failed"
return {
"job_id": _clean(raw.get("job_id")) or uuid.uuid4().hex,
"status": status,
"created_at": _clean(raw.get("created_at")) or _now_iso(),
"updated_at": _clean(raw.get("updated_at")) or _clean(raw.get("created_at")) or _now_iso(),
"total": int(raw.get("total") or 0),
"completed": int(raw.get("completed") or 0),
"added": int(raw.get("added") or 0),
"skipped": int(raw.get("skipped") or 0),
"refreshed": int(raw.get("refreshed") or 0),
"failed": int(raw.get("failed") or 0),
"errors": raw.get("errors") if isinstance(raw.get("errors"), list) else [],
}
def _normalize_server(raw: dict) -> dict:
    """Build a canonical server record from ``raw``.

    Text fields become stripped strings ("" when missing), a missing id is
    replaced by a fresh random one, and ``import_job`` is normalized with any
    pending/running job marked as failed.
    """
    record = {"id": _clean(raw.get("id")) or _new_id()}
    for field in ("name", "base_url", "email", "password", "api_key", "group_id"):
        record[field] = _clean(raw.get(field))
    record["import_job"] = _normalize_import_job(raw.get("import_job"), fail_unfinished=True)
    return record
class Sub2APIConfig:
def __init__(self, store_file: Path):
self._store_file = store_file
self._lock = Lock()
self._servers: list[dict] = self._load()
def _load(self) -> list[dict]:
if not self._store_file.exists():
return []
try:
raw = json.loads(self._store_file.read_text(encoding="utf-8"))
if isinstance(raw, list):
return [_normalize_server(item) for item in raw if isinstance(item, dict)]
except Exception:
pass
return []
def _save(self) -> None:
self._store_file.parent.mkdir(parents=True, exist_ok=True)
self._store_file.write_text(
json.dumps(self._servers, ensure_ascii=False, indent=2) + "\n",
encoding="utf-8",
)
def list_servers(self) -> list[dict]:
with self._lock:
return [dict(server) for server in self._servers]
def get_server(self, server_id: str) -> dict | None:
with self._lock:
for server in self._servers:
if server["id"] == server_id:
return dict(server)
return None
def add_server(
self,
*,
name: str,
base_url: str,
email: str,
password: str,
api_key: str,
group_id: str = "",
) -> dict:
server = _normalize_server({
"id": _new_id(),
"name": name,
"base_url": base_url,
"email": email,
"password": password,
"api_key": api_key,
"group_id": group_id,
})
with self._lock:
self._servers.append(server)
self._save()
_token_cache.pop(server["id"], None)
return dict(server)
def update_server(self, server_id: str, updates: dict) -> dict | None:
with self._lock:
for index, server in enumerate(self._servers):
if server["id"] != server_id:
continue
merged = {**server, **{k: v for k, v in updates.items() if v is not None}, "id": server_id}
self._servers[index] = _normalize_server(merged)
self._save()
result = dict(self._servers[index])
break
else:
return None
_token_cache.pop(server_id, None)
return result
def delete_server(self, server_id: str) -> bool:
with self._lock:
before = len(self._servers)
self._servers = [server for server in self._servers if server["id"] != server_id]
removed = len(self._servers) < before
if removed:
self._save()
if removed:
_token_cache.pop(server_id, None)
return removed
def set_import_job(self, server_id: str, import_job: dict | None) -> dict | None:
with self._lock:
for index, server in enumerate(self._servers):
if server["id"] != server_id:
continue
next_server = dict(server)
next_server["import_job"] = _normalize_import_job(import_job, fail_unfinished=False)
self._servers[index] = next_server
self._save()
return dict(next_server)
return None
def get_import_job(self, server_id: str) -> dict | None:
with self._lock:
for server in self._servers:
if server["id"] == server_id:
job = server.get("import_job")
return dict(job) if isinstance(job, dict) else None
return None
# Per-server cached access token: {server_id: (jwt, expires_at_epoch)}
_token_cache: dict[str, tuple[str, float]] = {}
_token_cache_lock = Lock()
def _login(base_url: str, email: str, password: str) -> tuple[str, float]:
    """Log in to a sub2api admin with email/password.

    Returns:
        ``(access_token, expires_at)``, where ``expires_at`` is the epoch time after
        which the cached token should be treated as stale (refresh skew already
        applied).

    Raises:
        RuntimeError: on a non-2xx response, a non-JSON body, or a payload that
            does not contain ``access_token``.
    """
    url = f"{base_url.rstrip('/')}/api/v1/auth/login"
    session = Session(verify=True)
    try:
        response = session.post(
            url,
            json={"email": email, "password": password},
            headers={"Accept": "application/json", "Content-Type": "application/json"},
            timeout=30,
        )
        if not response.ok:
            raise RuntimeError(f"sub2api login failed: HTTP {response.status_code} {response.text[:200]}")
        try:
            payload = response.json()
        except ValueError as exc:
            raise RuntimeError("sub2api login returned a non-JSON response") from exc
    finally:
        session.close()
    body = _unwrap_envelope(payload)
    if not isinstance(body, dict):
        raise RuntimeError("sub2api login payload is invalid")
    token = _clean(body.get("access_token"))
    if not token:
        raise RuntimeError("sub2api login did not return access_token")
    try:
        expires_in = int(body.get("expires_in") or 3600)
    except (TypeError, ValueError, OverflowError):
        expires_in = 3600
    expires_in = max(0, expires_in)
    # Refresh early, but never by more than half the lifetime. Subtracting the full
    # skew from a short-lived token put its deadline in the past, so it was never
    # reused and every call logged in again.
    skew = min(_TOKEN_REFRESH_SKEW, expires_in // 2)
    expires_at = time.time() + expires_in - skew
    return token, expires_at
def _auth_headers(server: dict) -> dict[str, str]:
    """Return request headers for the sub2api admin API.

    An ``api_key`` on the server takes precedence and is sent as ``x-api-key``.
    Otherwise the server's email/password is used to obtain a JWT via
    ``_login``, which is cached per server id until its stale deadline.

    Raises:
        RuntimeError: if neither an api_key nor a complete email+password is set.
    """
    api_key = _clean(server.get("api_key"))
    if api_key:
        return {"x-api-key": api_key, "Accept": "application/json"}

    email, password = _clean(server.get("email")), _clean(server.get("password"))
    if not (email and password):
        raise RuntimeError("sub2api server requires email+password or api_key")

    cache_key = _clean(server.get("id"))
    login_url = _clean(server.get("base_url"))

    with _token_cache_lock:
        entry = _token_cache.get(cache_key)
        fresh = bool(entry) and entry[1] > time.time()
    if fresh:
        jwt = entry[0]
    else:
        jwt, deadline = _login(login_url, email, password)
        with _token_cache_lock:
            _token_cache[cache_key] = (jwt, deadline)
    return {"Authorization": f"Bearer {jwt}", "Accept": "application/json"}
def _extract_access_token(credentials: object) -> str:
if not isinstance(credentials, dict):
return ""
for key in ("access_token", "accessToken", "token"):
value = _clean(credentials.get(key))
if value:
return value
return ""
def _unwrap_envelope(payload: object) -> object:
"""Peel sub2api's `{code, message, data}` envelope, returning the inner `data` field
when present. Also handles unwrapped responses from older/alt versions."""
if isinstance(payload, dict) and "data" in payload and "code" in payload:
return payload.get("data")
return payload
def _extract_paged_items(payload: object) -> tuple[list, int]:
    """Return ``(items, total)`` from a paginated sub2api response.

    Accepts the wrapped ``{code, data: {items, total, ...}}`` shape plus looser
    variants: a bare list, or a dict whose ``items``/``data``/``list`` key holds
    the records. ``total`` falls back to the page length when absent or zero.
    Unrecognized shapes yield ``([], 0)``.
    """
    body = _unwrap_envelope(payload)
    if isinstance(body, list):
        return body, len(body)
    if not isinstance(body, dict):
        return [], 0
    records = next(
        (body.get(key) for key in ("items", "data", "list") if isinstance(body.get(key), list)),
        None,
    )
    if records is None:
        return [], 0
    return records, int(body.get("total") or len(records))
_PAGE_SIZE = 200


def _fetch_all_pages(
    session: "Session",
    url: str,
    headers: dict[str, str],
    params: dict[str, object],
    label: str,
) -> list:
    """Fetch every page of a paginated sub2api admin list endpoint; return the raw records.

    Stops when a page comes back empty, when the number of records collected
    reaches the ``total`` reported by the server, or when the server returns the
    same page twice.

    Raises:
        RuntimeError: on a non-2xx response ("sub2api {label} failed: ...").
    """
    records: list = []
    previous: list | None = None
    page = 1
    while True:
        response = session.get(
            url,
            headers=headers,
            params={**params, "page": page, "page_size": _PAGE_SIZE},
            timeout=30,
        )
        if not response.ok:
            raise RuntimeError(f"sub2api {label} failed: HTTP {response.status_code} {response.text[:200]}")
        data, total = _extract_paged_items(response.json())
        # A repeated page means the server ignores ``page``; stop instead of
        # re-fetching (and duplicating) it until ``total`` is reached.
        if not data or data == previous:
            break
        records.extend(data)
        # Don't stop just because a page is shorter than _PAGE_SIZE: a server that
        # caps page_size lower would otherwise be truncated after page one.
        if len(records) >= total:
            break
        previous = data
        page += 1
    return records


def list_remote_accounts(server: dict) -> list[dict]:
    """Return a flat list of OpenAI OAuth accounts from a sub2api server.

    Filters by the server's ``group_id`` when set. Accounts whose credentials
    carry no access token are skipped. Each item has ``id``, ``name``, ``email``,
    ``plan_type``, ``status``, ``expires_at`` and ``has_refresh_token``. Returns
    [] when the server has no base_url.

    Raises:
        RuntimeError: on authentication problems or a non-2xx list response.
    """
    base_url = _clean(server.get("base_url"))
    if not base_url:
        return []
    headers = _auth_headers(server)
    params: dict[str, object] = {"platform": "openai", "type": "oauth"}
    group_id = _clean(server.get("group_id"))
    if group_id:
        params["group"] = group_id
    session = Session(verify=True)
    try:
        records = _fetch_all_pages(
            session, f"{base_url.rstrip('/')}/api/v1/admin/accounts", headers, params, "list"
        )
    finally:
        session.close()
    items: list[dict] = []
    for account in records:
        if not isinstance(account, dict):
            continue
        credentials = account.get("credentials") if isinstance(account.get("credentials"), dict) else {}
        if not _extract_access_token(credentials):
            continue
        account_id = account.get("id")
        items.append({
            "id": str(account_id) if account_id is not None else _clean(credentials.get("chatgpt_account_id")),
            "name": _clean(account.get("name")),
            "email": _clean(credentials.get("email")) or _clean(account.get("name")),
            "plan_type": _clean(credentials.get("plan_type")),
            "status": _clean(account.get("status")),
            "expires_at": _clean(credentials.get("expires_at")),
            "has_refresh_token": bool(_clean(credentials.get("refresh_token"))),
        })
    return items


def list_remote_groups(server: dict) -> list[dict]:
    """Return account groups from a sub2api server.

    Groups without an ``id`` are skipped. Each item has ``id``, ``name``,
    ``description``, ``platform``, ``status``, ``account_count`` and
    ``active_account_count``. Returns [] when the server has no base_url.

    Raises:
        RuntimeError: on authentication problems or a non-2xx groups response.
    """
    base_url = _clean(server.get("base_url"))
    if not base_url:
        return []
    headers = _auth_headers(server)
    session = Session(verify=True)
    try:
        records = _fetch_all_pages(
            session, f"{base_url.rstrip('/')}/api/v1/admin/groups", headers, {}, "groups"
        )
    finally:
        session.close()
    items: list[dict] = []
    for group in records:
        if not isinstance(group, dict):
            continue
        group_id = group.get("id")
        if group_id is None:
            continue
        items.append({
            "id": str(group_id),
            "name": _clean(group.get("name")),
            "description": _clean(group.get("description")),
            "platform": _clean(group.get("platform")),
            "status": _clean(group.get("status")),
            "account_count": int(group.get("account_count") or 0),
            "active_account_count": int(group.get("active_account_count") or 0),
        })
    return items
def _fetch_access_token_for_account(server: dict, account_id: str) -> tuple[str, dict]:
    """Fetch one sub2api account and return ``(access_token, meta)``.

    ``meta`` contains the credential ``email`` and ``plan_type`` (strings, "" when
    absent).

    Raises:
        RuntimeError: when base_url is empty, the id is a bare "."/"..", the
            request is not 2xx, or the account has no access token.
    """
    base_url = _clean(server.get("base_url"))
    if not base_url:
        raise RuntimeError("sub2api server base_url is empty")
    if account_id in {".", ".."}:
        raise RuntimeError("invalid account id")
    headers = _auth_headers(server)
    # The id comes from the import request. Percent-encode it (including "/") so it
    # cannot address a different admin endpoint via path segments.
    url = f"{base_url.rstrip('/')}/api/v1/admin/accounts/{quote(account_id, safe='')}"
    session = Session(verify=True)
    try:
        response = session.get(url, headers=headers, timeout=30)
        if not response.ok:
            raise RuntimeError(f"HTTP {response.status_code} {response.text[:200]}")
        payload = response.json()
    finally:
        session.close()
    account = _unwrap_envelope(payload)
    if not isinstance(account, dict):
        # Fall back to the raw payload when the envelope's data is not a dict.
        account = payload if isinstance(payload, dict) else {}
    credentials = account.get("credentials") if isinstance(account.get("credentials"), dict) else {}
    access_token = _extract_access_token(credentials)
    if not access_token:
        raise RuntimeError("missing access_token")
    return access_token, {
        "email": _clean(credentials.get("email")),
        "plan_type": _clean(credentials.get("plan_type")),
    }
class Sub2APIImportService:
    """Imports sub2api accounts into the local account pool in a background thread.

    Job progress (counts, errors, status) is written to the server's
    ``import_job`` record through ``Sub2APIConfig.set_import_job``.
    """

    def __init__(self, sub2api_config: Sub2APIConfig):
        self._config = sub2api_config

    def start_import(self, server: dict, account_ids: list[str]) -> dict:
        """Create a pending job and import ``account_ids`` in a daemon thread.

        Blank ids are dropped and duplicates collapsed (first occurrence wins),
        so each account is fetched once and ``total`` counts distinct ids.

        Returns:
            A copy of the freshly stored job record.

        Raises:
            ValueError: if no usable ids remain or the server is not registered.
        """
        cleaned = (_clean(item) for item in account_ids)
        ids = list(dict.fromkeys(item for item in cleaned if item))
        if not ids:
            raise ValueError("account ids is required")
        server_id = _clean(server.get("id"))
        now = _now_iso()
        job = {
            "job_id": uuid.uuid4().hex,
            "status": "pending",
            "created_at": now,
            "updated_at": now,
            "total": len(ids),
            "completed": 0,
            "added": 0,
            "skipped": 0,
            "refreshed": 0,
            "failed": 0,
            "errors": [],
        }
        saved = self._config.set_import_job(server_id, job)
        if saved is None:
            raise ValueError("server not found")
        thread = threading.Thread(
            target=self._run_import,
            args=(server_id, server, ids),
            name=f"sub2api-import-{server_id}",
            daemon=True,
        )
        thread.start()
        return dict(saved.get("import_job") or job)

    def _update_job(self, server_id: str, **updates) -> None:
        """Merge ``updates`` into the stored job and bump ``updated_at``; no-op if the job is gone."""
        current = self._config.get_import_job(server_id)
        if current is None:
            return
        next_job = {**current, **updates, "updated_at": _now_iso()}
        self._config.set_import_job(server_id, next_job)

    def _append_error(self, server_id: str, account_id: str, message: str) -> None:
        """Record a per-account failure and set ``failed`` to the error count."""
        current = self._config.get_import_job(server_id)
        if current is None:
            return
        errors = list(current.get("errors") or [])
        errors.append({"name": account_id, "error": message})
        self._update_job(server_id, errors=errors, failed=len(errors))

    def _mark_failed(self, server_id: str, message: str) -> None:
        """Mark the job failed with a job-level error (empty ``name``) and keep ``failed`` as-is."""
        current = self._config.get_import_job(server_id)
        if current is None:
            return
        errors = list(current.get("errors") or [])
        errors.append({"name": "", "error": message})
        self._update_job(server_id, status="failed", errors=errors)

    def _run_import(self, server_id: str, server: dict, account_ids: list[str]) -> None:
        """Thread entry point: run the import and always leave the job in a terminal state."""
        try:
            self._import_accounts(server_id, server, account_ids)
        except Exception as exc:
            # Without this, an error from account_service (or from persisting
            # progress) would end the thread and leave the job "running" until the
            # config is reloaded.
            self._mark_failed(server_id, str(exc) or exc.__class__.__name__)

    def _import_accounts(self, server_id: str, server: dict, account_ids: list[str]) -> None:
        """Fetch tokens concurrently, then add and refresh them via account_service."""
        self._update_job(server_id, status="running")
        tokens: list[str] = []
        max_workers = min(8, max(1, len(account_ids)))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_map = {
                executor.submit(_fetch_access_token_for_account, server, account_id): account_id
                for account_id in account_ids
            }
            for future in as_completed(future_map):
                account_id = future_map[future]
                try:
                    token, _meta = future.result()
                    tokens.append(token)
                except Exception as exc:
                    self._append_error(server_id, account_id, str(exc) or "unknown error")
                current = self._config.get_import_job(server_id) or {}
                self._update_job(
                    server_id,
                    completed=int(current.get("completed") or 0) + 1,
                    failed=len(current.get("errors") or []),
                )
        if not tokens:
            current = self._config.get_import_job(server_id) or {}
            self._update_job(
                server_id,
                status="failed",
                completed=int(current.get("total") or 0),
                failed=len(current.get("errors") or []),
            )
            return
        add_result = account_service.add_accounts(tokens)
        added = int(add_result.get("added") or 0)
        skipped = int(add_result.get("skipped") or 0)
        # Persist add results before refreshing so they survive a refresh failure.
        self._update_job(server_id, added=added, skipped=skipped)
        refresh_result = account_service.refresh_accounts(tokens)
        current = self._config.get_import_job(server_id) or {}
        self._update_job(
            server_id,
            status="completed",
            completed=len(account_ids),
            added=added,
            skipped=skipped,
            refreshed=int(refresh_result.get("refreshed") or 0),
            failed=len(current.get("errors") or []),
        )
# Module-level singletons: the persisted server registry and the import
# service that records job progress on it.
sub2api_config: Sub2APIConfig = Sub2APIConfig(SUB2API_CONFIG_FILE)
sub2api_import_service: Sub2APIImportService = Sub2APIImportService(sub2api_config)