Spaces:
Paused
Paused
| from __future__ import annotations | |
| from fastapi import APIRouter, Header, HTTPException | |
| from fastapi.concurrency import run_in_threadpool | |
| from pydantic import BaseModel, Field | |
| from services.auth_service import auth_service | |
| from api.support import ( | |
| require_admin, | |
| sanitize_cpa_pool, | |
| sanitize_cpa_pools, | |
| sanitize_sub2api_server, | |
| sanitize_sub2api_servers, | |
| ) | |
| from services.account_service import account_service | |
| from services.cpa_service import cpa_config, cpa_import_service, list_remote_files | |
| from services.sub2api_service import ( | |
| list_remote_accounts as sub2api_list_remote_accounts, | |
| list_remote_groups as sub2api_list_remote_groups, | |
| sub2api_config, | |
| sub2api_import_service, | |
| ) | |
class UserKeyCreateRequest(BaseModel):
    """Request body accepted by ``create_user_key``."""

    # Label for the new key; forwarded as ``name`` to ``auth_service.create_key``.
    name: str = Field(default="")
class UserKeyUpdateRequest(BaseModel):
    """Request body accepted by ``update_user_key``.

    Every field is optional; ``update_user_key`` drops the ones left as
    ``None`` and rejects the request when nothing remains.
    """

    # New label for the key.
    name: str | None = Field(default=None)
    # New enabled flag for the key.
    enabled: bool | None = Field(default=None)
    # Replacement key value.
    key: str | None = Field(default=None)
class AccountCreateRequest(BaseModel):
    """Request body accepted by ``create_accounts``."""

    # Tokens to register; blank entries are discarded by the handler.
    tokens: list[str] = Field(default_factory=list)
class AccountDeleteRequest(BaseModel):
    """Request body accepted by ``delete_accounts``."""

    # Tokens of the accounts to remove; blank entries are discarded by the handler.
    tokens: list[str] = Field(default_factory=list)
class AccountRefreshRequest(BaseModel):
    """Request body accepted by ``refresh_accounts``."""

    # Access tokens to refresh. When none survive blank-stripping, the handler
    # falls back to ``account_service.list_tokens()``.
    access_tokens: list[str] = Field(default_factory=list)
class AccountUpdateRequest(BaseModel):
    """Request body accepted by ``update_account``.

    ``access_token`` selects the account; the remaining fields are optional
    and the handler only forwards those that are not ``None``.
    """

    # Identifies the account to update; the handler rejects a blank value.
    access_token: str = Field(default="")
    # Optional new values, passed through to ``account_service.update_account``.
    type: str | None = Field(default=None)
    status: str | None = Field(default=None)
    quota: int | None = Field(default=None)
class CPAPoolCreateRequest(BaseModel):
    """Request body accepted by ``create_cpa_pool``."""

    # Display name of the pool.
    name: str = Field(default="")
    # Base URL of the pool; the handler rejects a blank value.
    base_url: str = Field(default="")
    # Secret key for the pool; the handler rejects a blank value.
    secret_key: str = Field(default="")
class CPAPoolUpdateRequest(BaseModel):
    """Request body accepted by ``update_cpa_pool``.

    Fields left as ``None`` are excluded from the update, because the handler
    dumps the model with ``exclude_none=True``.
    """

    name: str | None = Field(default=None)
    base_url: str | None = Field(default=None)
    secret_key: str | None = Field(default=None)
class CPAImportRequest(BaseModel):
    """Request body accepted by ``cpa_pool_import``."""

    # Names handed to ``cpa_import_service.start_import`` together with the pool.
    names: list[str] = Field(default_factory=list)
class Sub2APIServerCreateRequest(BaseModel):
    """Request body accepted by ``create_sub2api_server``.

    The handler requires a non-blank ``base_url`` and either a non-blank
    ``email`` + ``password`` pair or a non-blank ``api_key``.
    """

    # Display name of the server.
    name: str = Field(default="")
    # Base URL of the server; must not be blank.
    base_url: str = Field(default="")
    # Login credentials (both must be non-blank to count as a login).
    email: str = Field(default="")
    password: str = Field(default="")
    # Alternative to the email/password pair.
    api_key: str = Field(default="")
    # Forwarded unchanged to ``sub2api_config.add_server``.
    group_id: str = Field(default="")
class Sub2APIServerUpdateRequest(BaseModel):
    """Request body accepted by ``update_sub2api_server``.

    Fields left as ``None`` are excluded from the update, because the handler
    dumps the model with ``exclude_none=True``.
    """

    name: str | None = Field(default=None)
    base_url: str | None = Field(default=None)
    email: str | None = Field(default=None)
    password: str | None = Field(default=None)
    api_key: str | None = Field(default=None)
    group_id: str | None = Field(default=None)
class Sub2APIImportRequest(BaseModel):
    """Request body accepted by ``sub2api_server_import``."""

    # Account ids handed to ``sub2api_import_service.start_import`` with the server.
    account_ids: list[str] = Field(default_factory=list)
| def create_router() -> APIRouter: | |
| router = APIRouter() | |
| async def list_user_keys(authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| return {"items": auth_service.list_keys(role="user")} | |
| async def create_user_key(body: UserKeyCreateRequest, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| try: | |
| item, raw_key = auth_service.create_key(role="user", name=body.name) | |
| except ValueError as exc: | |
| raise HTTPException(status_code=400, detail={"error": str(exc)}) from exc | |
| return {"item": item, "key": raw_key, "items": auth_service.list_keys(role="user")} | |
| async def update_user_key( | |
| key_id: str, | |
| body: UserKeyUpdateRequest, | |
| authorization: str | None = Header(default=None), | |
| ): | |
| require_admin(authorization) | |
| updates = { | |
| key: value | |
| for key, value in { | |
| "name": body.name, | |
| "enabled": body.enabled, | |
| "key": body.key, | |
| }.items() | |
| if value is not None | |
| } | |
| if not updates: | |
| raise HTTPException(status_code=400, detail={"error": "还没有检测到改动,请修改后再保存"}) | |
| try: | |
| item = auth_service.update_key(key_id, updates, role="user") | |
| except ValueError as exc: | |
| raise HTTPException(status_code=400, detail={"error": str(exc)}) from exc | |
| if item is None: | |
| raise HTTPException(status_code=404, detail={"error": "这条用户密钥不存在,可能已经被删除"}) | |
| return {"item": item, "items": auth_service.list_keys(role="user")} | |
| async def delete_user_key(key_id: str, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| if not auth_service.delete_key(key_id, role="user"): | |
| raise HTTPException(status_code=404, detail={"error": "这条用户密钥不存在,可能已经被删除"}) | |
| return {"items": auth_service.list_keys(role="user")} | |
| async def get_accounts(authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| return {"items": account_service.list_accounts()} | |
| async def create_accounts(body: AccountCreateRequest, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| tokens = [str(token or "").strip() for token in body.tokens if str(token or "").strip()] | |
| if not tokens: | |
| raise HTTPException(status_code=400, detail={"error": "tokens is required"}) | |
| result = account_service.add_accounts(tokens) | |
| refresh_result = account_service.refresh_accounts(tokens) | |
| return { | |
| **result, | |
| "refreshed": refresh_result.get("refreshed", 0), | |
| "errors": refresh_result.get("errors", []), | |
| "items": refresh_result.get("items", result.get("items", [])), | |
| } | |
| async def delete_accounts(body: AccountDeleteRequest, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| tokens = [str(token or "").strip() for token in body.tokens if str(token or "").strip()] | |
| if not tokens: | |
| raise HTTPException(status_code=400, detail={"error": "tokens is required"}) | |
| return account_service.delete_accounts(tokens) | |
| async def refresh_accounts(body: AccountRefreshRequest, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| access_tokens = [str(token or "").strip() for token in body.access_tokens if str(token or "").strip()] | |
| if not access_tokens: | |
| access_tokens = account_service.list_tokens() | |
| if not access_tokens: | |
| raise HTTPException(status_code=400, detail={"error": "access_tokens is required"}) | |
| return account_service.refresh_accounts(access_tokens) | |
| async def update_account(body: AccountUpdateRequest, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| access_token = str(body.access_token or "").strip() | |
| if not access_token: | |
| raise HTTPException(status_code=400, detail={"error": "access_token is required"}) | |
| updates = {key: value for key, value in {"type": body.type, "status": body.status, "quota": body.quota}.items() if value is not None} | |
| if not updates: | |
| raise HTTPException(status_code=400, detail={"error": "还没有检测到改动,请修改后再保存"}) | |
| account = account_service.update_account(access_token, updates) | |
| if account is None: | |
| raise HTTPException(status_code=404, detail={"error": "account not found"}) | |
| return {"item": account, "items": account_service.list_accounts()} | |
| async def list_cpa_pools(authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| return {"pools": sanitize_cpa_pools(cpa_config.list_pools())} | |
| async def create_cpa_pool(body: CPAPoolCreateRequest, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| if not body.base_url.strip(): | |
| raise HTTPException(status_code=400, detail={"error": "base_url is required"}) | |
| if not body.secret_key.strip(): | |
| raise HTTPException(status_code=400, detail={"error": "secret_key is required"}) | |
| pool = cpa_config.add_pool(name=body.name, base_url=body.base_url, secret_key=body.secret_key) | |
| return {"pool": sanitize_cpa_pool(pool), "pools": sanitize_cpa_pools(cpa_config.list_pools())} | |
| async def update_cpa_pool(pool_id: str, body: CPAPoolUpdateRequest, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| pool = cpa_config.update_pool(pool_id, body.model_dump(exclude_none=True)) | |
| if pool is None: | |
| raise HTTPException(status_code=404, detail={"error": "pool not found"}) | |
| return {"pool": sanitize_cpa_pool(pool), "pools": sanitize_cpa_pools(cpa_config.list_pools())} | |
| async def delete_cpa_pool(pool_id: str, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| if not cpa_config.delete_pool(pool_id): | |
| raise HTTPException(status_code=404, detail={"error": "pool not found"}) | |
| return {"pools": sanitize_cpa_pools(cpa_config.list_pools())} | |
| async def cpa_pool_files(pool_id: str, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| pool = cpa_config.get_pool(pool_id) | |
| if pool is None: | |
| raise HTTPException(status_code=404, detail={"error": "pool not found"}) | |
| return {"pool_id": pool_id, "files": await run_in_threadpool(list_remote_files, pool)} | |
| async def cpa_pool_import(pool_id: str, body: CPAImportRequest, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| pool = cpa_config.get_pool(pool_id) | |
| if pool is None: | |
| raise HTTPException(status_code=404, detail={"error": "pool not found"}) | |
| try: | |
| job = cpa_import_service.start_import(pool, body.names) | |
| except ValueError as exc: | |
| raise HTTPException(status_code=400, detail={"error": str(exc)}) from exc | |
| return {"import_job": job} | |
| async def cpa_pool_import_progress(pool_id: str, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| pool = cpa_config.get_pool(pool_id) | |
| if pool is None: | |
| raise HTTPException(status_code=404, detail={"error": "pool not found"}) | |
| return {"import_job": pool.get("import_job")} | |
| async def list_sub2api_servers(authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| return {"servers": sanitize_sub2api_servers(sub2api_config.list_servers())} | |
| async def create_sub2api_server(body: Sub2APIServerCreateRequest, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| if not body.base_url.strip(): | |
| raise HTTPException(status_code=400, detail={"error": "base_url is required"}) | |
| has_login = body.email.strip() and body.password.strip() | |
| has_api_key = bool(body.api_key.strip()) | |
| if not has_login and not has_api_key: | |
| raise HTTPException(status_code=400, detail={"error": "email+password or api_key is required"}) | |
| server = sub2api_config.add_server( | |
| name=body.name, | |
| base_url=body.base_url, | |
| email=body.email, | |
| password=body.password, | |
| api_key=body.api_key, | |
| group_id=body.group_id, | |
| ) | |
| return {"server": sanitize_sub2api_server(server), "servers": sanitize_sub2api_servers(sub2api_config.list_servers())} | |
| async def update_sub2api_server(server_id: str, body: Sub2APIServerUpdateRequest, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| server = sub2api_config.update_server(server_id, body.model_dump(exclude_none=True)) | |
| if server is None: | |
| raise HTTPException(status_code=404, detail={"error": "server not found"}) | |
| return {"server": sanitize_sub2api_server(server), "servers": sanitize_sub2api_servers(sub2api_config.list_servers())} | |
| async def delete_sub2api_server(server_id: str, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| if not sub2api_config.delete_server(server_id): | |
| raise HTTPException(status_code=404, detail={"error": "server not found"}) | |
| return {"servers": sanitize_sub2api_servers(sub2api_config.list_servers())} | |
| async def sub2api_server_groups(server_id: str, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| server = sub2api_config.get_server(server_id) | |
| if server is None: | |
| raise HTTPException(status_code=404, detail={"error": "server not found"}) | |
| try: | |
| groups = await run_in_threadpool(sub2api_list_remote_groups, server) | |
| except Exception as exc: | |
| raise HTTPException(status_code=502, detail={"error": str(exc)}) from exc | |
| return {"server_id": server_id, "groups": groups} | |
| async def sub2api_server_accounts(server_id: str, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| server = sub2api_config.get_server(server_id) | |
| if server is None: | |
| raise HTTPException(status_code=404, detail={"error": "server not found"}) | |
| try: | |
| accounts = await run_in_threadpool(sub2api_list_remote_accounts, server) | |
| except Exception as exc: | |
| raise HTTPException(status_code=502, detail={"error": str(exc)}) from exc | |
| return {"server_id": server_id, "accounts": accounts} | |
| async def sub2api_server_import(server_id: str, body: Sub2APIImportRequest, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| server = sub2api_config.get_server(server_id) | |
| if server is None: | |
| raise HTTPException(status_code=404, detail={"error": "server not found"}) | |
| try: | |
| job = sub2api_import_service.start_import(server, body.account_ids) | |
| except ValueError as exc: | |
| raise HTTPException(status_code=400, detail={"error": str(exc)}) from exc | |
| return {"import_job": job} | |
| async def sub2api_server_import_progress(server_id: str, authorization: str | None = Header(default=None)): | |
| require_admin(authorization) | |
| server = sub2api_config.get_server(server_id) | |
| if server is None: | |
| raise HTTPException(status_code=404, detail={"error": "server not found"}) | |
| return {"import_job": server.get("import_job")} | |
| return router | |