| | """Token manager for Flow2API with AT auto-refresh""" |
| | import asyncio |
| | from datetime import datetime, timedelta, timezone |
| | from typing import Optional, List |
| | from ..core.database import Database |
| | from ..core.models import Token, Project |
| | from ..core.logger import debug_logger |
| | from .flow_client import FlowClient |
| | from .proxy_manager import ProxyManager |
| |
|
| |
|
| | class TokenManager: |
| | """Token lifecycle manager with AT auto-refresh""" |
| |
|
| | def __init__(self, db: Database, flow_client: FlowClient): |
| | self.db = db |
| | self.flow_client = flow_client |
| | self._lock = asyncio.Lock() |
| |
|
| | |
| |
|
| | async def get_all_tokens(self) -> List[Token]: |
| | """Get all tokens""" |
| | return await self.db.get_all_tokens() |
| |
|
| | async def get_active_tokens(self) -> List[Token]: |
| | """Get all active tokens""" |
| | return await self.db.get_active_tokens() |
| |
|
| | async def get_token(self, token_id: int) -> Optional[Token]: |
| | """Get token by ID""" |
| | return await self.db.get_token(token_id) |
| |
|
| | async def delete_token(self, token_id: int): |
| | """Delete token""" |
| | await self.db.delete_token(token_id) |
| |
|
| | async def enable_token(self, token_id: int): |
| | """Enable a token and reset error count""" |
| | |
| | await self.db.update_token(token_id, is_active=True) |
| | |
| | await self.db.reset_error_count(token_id) |
| |
|
| | async def disable_token(self, token_id: int): |
| | """Disable a token""" |
| | await self.db.update_token(token_id, is_active=False) |
| |
|
| | |
| |
|
| | async def add_token( |
| | self, |
| | st: str, |
| | project_id: Optional[str] = None, |
| | project_name: Optional[str] = None, |
| | remark: Optional[str] = None, |
| | image_enabled: bool = True, |
| | video_enabled: bool = True, |
| | image_concurrency: int = -1, |
| | video_concurrency: int = -1 |
| | ) -> Token: |
| | """Add a new token |
| | |
| | Args: |
| | st: Session Token (必需) |
| | project_id: 项目ID (可选,如果提供则直接使用,不创建新项目) |
| | project_name: 项目名称 (可选,如果不提供则自动生成) |
| | remark: 备注 |
| | image_enabled: 是否启用图片生成 |
| | video_enabled: 是否启用视频生成 |
| | image_concurrency: 图片并发限制 |
| | video_concurrency: 视频并发限制 |
| | |
| | Returns: |
| | Token object |
| | """ |
| | |
| | existing_token = await self.db.get_token_by_st(st) |
| | if existing_token: |
| | raise ValueError(f"Token 已存在(邮箱: {existing_token.email})") |
| |
|
| | |
| | debug_logger.log_info(f"[ADD_TOKEN] Converting ST to AT...") |
| | try: |
| | result = await self.flow_client.st_to_at(st) |
| | at = result["access_token"] |
| | expires = result.get("expires") |
| | user_info = result.get("user", {}) |
| | email = user_info.get("email", "") |
| | name = user_info.get("name", email.split("@")[0] if email else "") |
| |
|
| | |
| | at_expires = None |
| | if expires: |
| | try: |
| | at_expires = datetime.fromisoformat(expires.replace('Z', '+00:00')) |
| | except: |
| | pass |
| |
|
| | except Exception as e: |
| | raise ValueError(f"ST转AT失败: {str(e)}") |
| |
|
| | |
| | try: |
| | credits_result = await self.flow_client.get_credits(at) |
| | credits = credits_result.get("credits", 0) |
| | user_paygate_tier = credits_result.get("userPaygateTier") |
| | except: |
| | credits = 0 |
| | user_paygate_tier = None |
| |
|
| | |
| | if project_id: |
| | |
| | debug_logger.log_info(f"[ADD_TOKEN] Using provided project_id: {project_id}") |
| | if not project_name: |
| | |
| | now = datetime.now() |
| | project_name = now.strftime("%b %d - %H:%M") |
| | else: |
| | |
| | if not project_name: |
| | |
| | now = datetime.now() |
| | project_name = now.strftime("%b %d - %H:%M") |
| |
|
| | try: |
| | project_id = await self.flow_client.create_project(st, project_name) |
| | debug_logger.log_info(f"[ADD_TOKEN] Created new project: {project_name} (ID: {project_id})") |
| | except Exception as e: |
| | raise ValueError(f"创建项目失败: {str(e)}") |
| |
|
| | |
| | token = Token( |
| | st=st, |
| | at=at, |
| | at_expires=at_expires, |
| | email=email, |
| | name=name, |
| | remark=remark, |
| | is_active=True, |
| | credits=credits, |
| | user_paygate_tier=user_paygate_tier, |
| | current_project_id=project_id, |
| | current_project_name=project_name, |
| | image_enabled=image_enabled, |
| | video_enabled=video_enabled, |
| | image_concurrency=image_concurrency, |
| | video_concurrency=video_concurrency |
| | ) |
| |
|
| | |
| | token_id = await self.db.add_token(token) |
| | token.id = token_id |
| |
|
| | |
| | project = Project( |
| | project_id=project_id, |
| | token_id=token_id, |
| | project_name=project_name, |
| | tool_name="PINHOLE" |
| | ) |
| | await self.db.add_project(project) |
| |
|
| | debug_logger.log_info(f"[ADD_TOKEN] Token added successfully (ID: {token_id}, Email: {email})") |
| | return token |
| |
|
| | async def update_token( |
| | self, |
| | token_id: int, |
| | st: Optional[str] = None, |
| | at: Optional[str] = None, |
| | at_expires: Optional[datetime] = None, |
| | project_id: Optional[str] = None, |
| | project_name: Optional[str] = None, |
| | remark: Optional[str] = None, |
| | image_enabled: Optional[bool] = None, |
| | video_enabled: Optional[bool] = None, |
| | image_concurrency: Optional[int] = None, |
| | video_concurrency: Optional[int] = None |
| | ): |
| | """Update token (支持修改project_id和project_name) |
| | |
| | 当用户编辑保存token时,如果token未过期,自动清空429禁用状态 |
| | """ |
| | update_fields = {} |
| |
|
| | if st is not None: |
| | update_fields["st"] = st |
| | if at is not None: |
| | update_fields["at"] = at |
| | if at_expires is not None: |
| | update_fields["at_expires"] = at_expires |
| | if project_id is not None: |
| | update_fields["current_project_id"] = project_id |
| | if project_name is not None: |
| | update_fields["current_project_name"] = project_name |
| | if remark is not None: |
| | update_fields["remark"] = remark |
| | if image_enabled is not None: |
| | update_fields["image_enabled"] = image_enabled |
| | if video_enabled is not None: |
| | update_fields["video_enabled"] = video_enabled |
| | if image_concurrency is not None: |
| | update_fields["image_concurrency"] = image_concurrency |
| | if video_concurrency is not None: |
| | update_fields["video_concurrency"] = video_concurrency |
| |
|
| | |
| | token = await self.db.get_token(token_id) |
| | if token and token.ban_reason == "429_rate_limit": |
| | |
| | is_expired = False |
| | if token.at_expires: |
| | now = datetime.now(timezone.utc) |
| | if token.at_expires.tzinfo is None: |
| | at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc) |
| | else: |
| | at_expires_aware = token.at_expires |
| | is_expired = at_expires_aware <= now |
| |
|
| | |
| | if not is_expired: |
| | debug_logger.log_info(f"[UPDATE_TOKEN] Token {token_id} 编辑保存,清空429禁用状态") |
| | update_fields["ban_reason"] = None |
| | update_fields["banned_at"] = None |
| |
|
| | if update_fields: |
| | await self.db.update_token(token_id, **update_fields) |
| |
|
| | |
| |
|
| | async def is_at_valid(self, token_id: int) -> bool: |
| | """检查AT是否有效,如果无效或即将过期则自动刷新 |
| | |
| | Returns: |
| | True if AT is valid or refreshed successfully |
| | False if AT cannot be refreshed |
| | """ |
| | token = await self.db.get_token(token_id) |
| | if not token: |
| | return False |
| |
|
| | |
| | if not token.at: |
| | debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT不存在,需要刷新") |
| | return await self._refresh_at(token_id) |
| |
|
| | |
| | if not token.at_expires: |
| | debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT过期时间未知,尝试刷新") |
| | return await self._refresh_at(token_id) |
| |
|
| | |
| | now = datetime.now(timezone.utc) |
| | |
| | if token.at_expires.tzinfo is None: |
| | at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc) |
| | else: |
| | at_expires_aware = token.at_expires |
| |
|
| | time_until_expiry = at_expires_aware - now |
| |
|
| | if time_until_expiry.total_seconds() < 3600: |
| | debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT即将过期 (剩余 {time_until_expiry.total_seconds():.0f} 秒),需要刷新") |
| | return await self._refresh_at(token_id) |
| |
|
| | |
| | return True |
| |
|
| | async def _refresh_at(self, token_id: int) -> bool: |
| | """内部方法: 刷新AT |
| | |
| | Returns: |
| | True if refresh successful, False otherwise |
| | """ |
| | async with self._lock: |
| | token = await self.db.get_token(token_id) |
| | if not token: |
| | return False |
| |
|
| | try: |
| | debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: 开始刷新AT...") |
| |
|
| | |
| | result = await self.flow_client.st_to_at(token.st) |
| | new_at = result["access_token"] |
| | expires = result.get("expires") |
| |
|
| | |
| | new_at_expires = None |
| | if expires: |
| | try: |
| | new_at_expires = datetime.fromisoformat(expires.replace('Z', '+00:00')) |
| | except: |
| | pass |
| |
|
| | |
| | await self.db.update_token( |
| | token_id, |
| | at=new_at, |
| | at_expires=new_at_expires |
| | ) |
| |
|
| | debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: AT刷新成功") |
| | debug_logger.log_info(f" - 新过期时间: {new_at_expires}") |
| |
|
| | |
| | try: |
| | credits_result = await self.flow_client.get_credits(new_at) |
| | await self.db.update_token( |
| | token_id, |
| | credits=credits_result.get("credits", 0) |
| | ) |
| | except: |
| | pass |
| |
|
| | return True |
| |
|
| | except Exception as e: |
| | debug_logger.log_error(f"[AT_REFRESH] Token {token_id}: AT刷新失败 - {str(e)}") |
| | |
| | await self.disable_token(token_id) |
| | return False |
| |
|
| | async def ensure_project_exists(self, token_id: int) -> str: |
| | """确保Token有可用的Project |
| | |
| | Returns: |
| | project_id |
| | """ |
| | token = await self.db.get_token(token_id) |
| | if not token: |
| | raise ValueError("Token not found") |
| |
|
| | |
| | if token.current_project_id: |
| | return token.current_project_id |
| |
|
| | |
| | now = datetime.now() |
| | project_name = now.strftime("%b %d - %H:%M") |
| |
|
| | try: |
| | project_id = await self.flow_client.create_project(token.st, project_name) |
| | debug_logger.log_info(f"[PROJECT] Created project for token {token_id}: {project_name}") |
| |
|
| | |
| | await self.db.update_token( |
| | token_id, |
| | current_project_id=project_id, |
| | current_project_name=project_name |
| | ) |
| |
|
| | |
| | project = Project( |
| | project_id=project_id, |
| | token_id=token_id, |
| | project_name=project_name |
| | ) |
| | await self.db.add_project(project) |
| |
|
| | return project_id |
| |
|
| | except Exception as e: |
| | raise ValueError(f"Failed to create project: {str(e)}") |
| |
|
| | |
| |
|
| | async def record_usage(self, token_id: int, is_video: bool = False): |
| | """Record token usage""" |
| | await self.db.update_token(token_id, use_count=1, last_used_at=datetime.now()) |
| |
|
| | if is_video: |
| | await self.db.increment_token_stats(token_id, "video") |
| | else: |
| | await self.db.increment_token_stats(token_id, "image") |
| |
|
| | async def record_error(self, token_id: int): |
| | """Record token error and auto-disable if threshold reached""" |
| | await self.db.increment_token_stats(token_id, "error") |
| |
|
| | |
| | stats = await self.db.get_token_stats(token_id) |
| | admin_config = await self.db.get_admin_config() |
| |
|
| | if stats and stats.consecutive_error_count >= admin_config.error_ban_threshold: |
| | debug_logger.log_warning( |
| | f"[TOKEN_BAN] Token {token_id} consecutive error count ({stats.consecutive_error_count}) " |
| | f"reached threshold ({admin_config.error_ban_threshold}), auto-disabling" |
| | ) |
| | await self.disable_token(token_id) |
| |
|
| | async def record_success(self, token_id: int): |
| | """Record successful request (reset consecutive error count) |
| | |
| | This method resets error_count to 0, which is used for auto-disable threshold checking. |
| | Note: today_error_count and historical statistics are NOT reset. |
| | """ |
| | await self.db.reset_error_count(token_id) |
| |
|
| | async def ban_token_for_429(self, token_id: int): |
| | """因429错误立即禁用token |
| | |
| | Args: |
| | token_id: Token ID |
| | """ |
| | debug_logger.log_warning(f"[429_BAN] 禁用Token {token_id} (原因: 429 Rate Limit)") |
| | await self.db.update_token( |
| | token_id, |
| | is_active=False, |
| | ban_reason="429_rate_limit", |
| | banned_at=datetime.now(timezone.utc) |
| | ) |
| |
|
| | async def auto_unban_429_tokens(self): |
| | """自动解禁因429被禁用的token |
| | |
| | 规则: |
| | - 距离禁用时间12小时后自动解禁 |
| | - 仅解禁未过期的token |
| | - 仅解禁因429被禁用的token |
| | """ |
| | all_tokens = await self.db.get_all_tokens() |
| | now = datetime.now(timezone.utc) |
| |
|
| | for token in all_tokens: |
| | |
| | if token.ban_reason != "429_rate_limit": |
| | continue |
| |
|
| | |
| | if token.is_active: |
| | continue |
| |
|
| | |
| | if not token.banned_at: |
| | continue |
| |
|
| | |
| | if token.at_expires: |
| | |
| | if token.at_expires.tzinfo is None: |
| | at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc) |
| | else: |
| | at_expires_aware = token.at_expires |
| |
|
| | |
| | if at_expires_aware <= now: |
| | debug_logger.log_info(f"[AUTO_UNBAN] Token {token.id} 已过期,跳过解禁") |
| | continue |
| |
|
| | |
| | if token.banned_at.tzinfo is None: |
| | banned_at_aware = token.banned_at.replace(tzinfo=timezone.utc) |
| | else: |
| | banned_at_aware = token.banned_at |
| |
|
| | |
| | time_since_ban = now - banned_at_aware |
| | if time_since_ban.total_seconds() >= 12 * 3600: |
| | debug_logger.log_info( |
| | f"[AUTO_UNBAN] 解禁Token {token.id} (禁用时间: {banned_at_aware}, " |
| | f"已过 {time_since_ban.total_seconds() / 3600:.1f} 小时)" |
| | ) |
| | await self.db.update_token( |
| | token.id, |
| | is_active=True, |
| | ban_reason=None, |
| | banned_at=None |
| | ) |
| | |
| | await self.db.reset_error_count(token.id) |
| |
|
| | |
| |
|
| | async def refresh_credits(self, token_id: int) -> int: |
| | """刷新Token余额 |
| | |
| | Returns: |
| | credits |
| | """ |
| | token = await self.db.get_token(token_id) |
| | if not token: |
| | return 0 |
| |
|
| | |
| | if not await self.is_at_valid(token_id): |
| | return 0 |
| |
|
| | |
| | token = await self.db.get_token(token_id) |
| |
|
| | try: |
| | result = await self.flow_client.get_credits(token.at) |
| | credits = result.get("credits", 0) |
| |
|
| | |
| | await self.db.update_token(token_id, credits=credits) |
| |
|
| | return credits |
| | except Exception as e: |
| | debug_logger.log_error(f"Failed to refresh credits for token {token_id}: {str(e)}") |
| | return 0 |
| |
|