File size: 21,884 Bytes
47258ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Token 池管理器 - 基于数据库的 Token 轮询和健康检查

核心功能:
1. Token 轮询机制 - 负载均衡和容错
2. Z.AI 官方认证接口验证 - 基于 role 字段区分用户类型
3. Token 健康度监控 - 自动禁用失败 Token
4. 数据库集成 - 与 TokenDAO 协同工作
"""

import asyncio
import time
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from threading import Lock
import httpx

from app.utils.logger import logger


# ==================== Token 状态管理 ====================


@dataclass
class TokenStatus:
    """Token 运行时状态(内存中)"""
    token: str
    token_id: int  # 数据库 ID,用于同步统计
    token_type: str = "unknown"  # "user", "guest", "unknown"
    is_available: bool = True
    failure_count: int = 0
    last_failure_time: float = 0.0
    last_success_time: float = 0.0
    total_requests: int = 0
    successful_requests: int = 0

    @property
    def success_rate(self) -> float:
        """成功率"""
        if self.total_requests == 0:
            return 1.0
        return self.successful_requests / self.total_requests

    @property
    def is_healthy(self) -> bool:
        """
        Token 健康状态判断

        健康标准:
        1. 必须是认证用户 Token (token_type = "user")
        2. 当前可用 (is_available = True)
        3. 成功率 >= 50% 或总请求数 <= 3(新 Token 容错)

        注意:
        - guest Token 永远不健康
        - unknown Token 永远不健康
        """
        # guest 和 unknown token 永远不健康
        if self.token_type != "user":
            return False

        # 不可用的 token 不健康
        if not self.is_available:
            return False

        # 新 token 容错:请求数很少时,只要没失败就健康
        if self.total_requests <= 3:
            return self.failure_count == 0

        # 基于成功率判断
        return self.success_rate >= 0.5


# ==================== Token 验证服务 ====================


class ZAITokenValidator:
    """Z.AI Token 验证器(使用官方认证接口)"""

    AUTH_URL = "https://chat.z.ai/api/v1/auths/"

    @staticmethod
    def get_headers(token: str) -> Dict[str, str]:
        """构建认证请求头"""
        return {
            "Accept": "*/*",
            "Accept-Language": "zh-CN,zh;q=0.9",
            "Authorization": f"Bearer {token}",
            "Connection": "keep-alive",
            "Content-Type": "application/json",
            "DNT": "1",
            "Referer": "https://chat.z.ai/",
            "Sec-Fetch-Dest": "empty",
            "Sec-Fetch-Mode": "cors",
            "Sec-Fetch-Site": "same-origin",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36",
            "sec-ch-ua": '"Chromium";v="140", "Not=A?Brand";v="24", "Google Chrome";v="140"',
            "sec-ch-ua-mobile": "?0",
            "sec-ch-ua-platform": '"Windows"'
        }

    @classmethod
    async def validate_token(cls, token: str) -> Tuple[str, bool, Optional[str]]:
        """
        验证 Token 有效性并返回类型

        Args:
            token: 待验证的 Token

        Returns:
            (token_type, is_valid, error_message)
            - token_type: "user" | "guest" | "unknown"
            - is_valid: True 表示是有效的认证用户 Token
            - error_message: 失败原因(仅在 is_valid=False 时有值)
        """
        try:
            async with httpx.AsyncClient(timeout=15.0) as client:
                response = await client.get(
                    cls.AUTH_URL,
                    headers=cls.get_headers(token)
                )

                # 解析响应
                return cls._parse_auth_response(response)

        except httpx.TimeoutException:
            return ("unknown", False, "请求超时")
        except httpx.ConnectError:
            return ("unknown", False, "连接失败")
        except Exception as e:
            return ("unknown", False, f"验证异常: {str(e)}")

    @staticmethod
    def _parse_auth_response(response: httpx.Response) -> Tuple[str, bool, Optional[str]]:
        """
        解析 Z.AI 认证接口响应

        响应格式示例:
        {
            "id": "...",
            "email": "user@example.com",
            "role": "user"  # 或 "guest"
        }

        验证规则:
        - role: "user" → 认证用户 Token(有效,可添加)
        - role: "guest" → 匿名用户 Token(无效,拒绝添加)
        - 其他情况 → 无效 Token
        """
        # 检查 HTTP 状态码
        if response.status_code != 200:
            return ("unknown", False, f"HTTP {response.status_code}")

        try:
            data = response.json()

            # 验证响应格式
            if not isinstance(data, dict):
                return ("unknown", False, "无效的响应格式")

            # 检查是否包含错误信息
            if "error" in data or "message" in data:
                error_msg = data.get("error") or data.get("message", "未知错误")
                return ("unknown", False, str(error_msg))

            # 核心验证:检查 role 字段
            role = data.get("role")

            if role == "user":
                return ("user", True, None)
            elif role == "guest":
                return ("guest", False, "匿名用户 Token 不允许添加")
            else:
                return ("unknown", False, f"未知 role: {role}")

        except (ValueError, Exception) as e:
            return ("unknown", False, f"解析响应失败: {str(e)}")


# ==================== Token 池管理器 ====================


class TokenPool:
    """Token 池管理器(数据库驱动)"""

    def __init__(
        self,
        tokens: List[Tuple[int, str, str]],  # [(token_id, token_value, token_type), ...]
        failure_threshold: int = 3,
        recovery_timeout: int = 1800
    ):
        """
        初始化 Token 池

        Args:
            tokens: Token 列表 [(token_id, token_value, token_type), ...]
            failure_threshold: 失败阈值,超过此次数将标记为不可用
            recovery_timeout: 恢复超时时间(秒),失败 Token 在此时间后重新尝试
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self._lock = Lock()
        self._current_index = 0

        # 初始化 Token 状态(内存中)
        self.token_statuses: Dict[str, TokenStatus] = {}
        self.token_id_map: Dict[str, int] = {}  # token -> token_id 映射

        for token_id, token_value, token_type in tokens:
            if token_value and token_value not in self.token_statuses:
                self.token_statuses[token_value] = TokenStatus(
                    token=token_value,
                    token_id=token_id,
                    token_type=token_type
                )
                self.token_id_map[token_value] = token_id

        if not self.token_statuses:
            logger.warning("⚠️ Token 池为空,将依赖匿名模式")

    def get_next_token(self) -> Optional[str]:
        """
        获取下一个可用的认证用户 Token(轮询算法)

        Returns:
            可用的 Token 字符串,如果没有可用 Token 则返回 None
        """
        with self._lock:
            if not self.token_statuses:
                return None

            available_tokens = self._get_available_user_tokens()
            if not available_tokens:
                # 尝试恢复过期的失败 Token
                self._try_recover_failed_tokens()
                available_tokens = self._get_available_user_tokens()

                if not available_tokens:
                    logger.warning("⚠️ 没有可用的认证用户 Token")
                    return None

            # 轮询选择
            token = available_tokens[self._current_index % len(available_tokens)]
            self._current_index = (self._current_index + 1) % len(available_tokens)

            return token

    def _get_available_user_tokens(self) -> List[str]:
        """
        获取当前可用的认证用户 Token 列表

        过滤条件:
        1. is_available = True
        2. token_type == "user"
        """
        available_user_tokens = [
            status.token for status in self.token_statuses.values()
            if status.is_available and status.token_type == "user"
        ]

        # 警告:如果有 guest token 但没有 user token
        if not available_user_tokens and self.token_statuses:
            guest_count = sum(
                1 for status in self.token_statuses.values()
                if status.token_type == "guest"
            )
            if guest_count > 0:
                logger.warning(f"⚠️ 检测到 {guest_count} 个匿名用户 Token,轮询机制将跳过这些 Token")

        return available_user_tokens

    def _try_recover_failed_tokens(self):
        """尝试恢复失败的 Token(仅针对认证用户 Token)"""
        current_time = time.time()
        recovered_count = 0

        for status in self.token_statuses.values():
            # 只恢复认证用户 Token
            if (
                status.token_type == "user"
                and not status.is_available
                and current_time - status.last_failure_time > self.recovery_timeout
            ):
                status.is_available = True
                status.failure_count = 0
                recovered_count += 1
                logger.info(f"🔄 恢复失败 Token: {status.token[:20]}...")

        if recovered_count > 0:
            logger.info(f"✅ 恢复了 {recovered_count} 个失败的 Token")

    def mark_token_success(self, token: str):
        """标记 Token 使用成功"""
        with self._lock:
            if token in self.token_statuses:
                status = self.token_statuses[token]
                status.total_requests += 1
                status.successful_requests += 1
                status.last_success_time = time.time()
                status.failure_count = 0  # 重置失败计数

                if not status.is_available:
                    status.is_available = True
                    logger.info(f"✅ Token 恢复可用: {token[:20]}...")

    def mark_token_failure(self, token: str, error: Exception = None):
        """标记 Token 使用失败"""
        with self._lock:
            if token in self.token_statuses:
                status = self.token_statuses[token]
                status.total_requests += 1
                status.failure_count += 1
                status.last_failure_time = time.time()

                if status.failure_count >= self.failure_threshold:
                    status.is_available = False
                    logger.warning(f"🚫 Token 已禁用: {token[:20]}... (失败 {status.failure_count} 次)")

    def get_token_id(self, token: str) -> Optional[int]:
        """获取 Token 的数据库 ID"""
        return self.token_id_map.get(token)

    def get_pool_status(self) -> Dict:
        """获取 Token 池状态信息"""
        with self._lock:
            available_count = len(self._get_available_user_tokens())
            total_count = len(self.token_statuses)
            healthy_count = sum(1 for status in self.token_statuses.values() if status.is_healthy)

            # 统计各类型 Token
            user_count = sum(1 for s in self.token_statuses.values() if s.token_type == "user")
            guest_count = sum(1 for s in self.token_statuses.values() if s.token_type == "guest")
            unknown_count = sum(1 for s in self.token_statuses.values() if s.token_type == "unknown")

            status_info = {
                "total_tokens": total_count,
                "available_tokens": available_count,
                "unavailable_tokens": total_count - available_count,
                "healthy_tokens": healthy_count,
                "unhealthy_tokens": total_count - healthy_count,
                "user_tokens": user_count,
                "guest_tokens": guest_count,
                "unknown_tokens": unknown_count,
                "current_index": self._current_index,
                "tokens": []
            }

            for token, status in self.token_statuses.items():
                status_info["tokens"].append({
                    "token": f"{token[:10]}...{token[-10:]}",
                    "token_id": status.token_id,
                    "token_type": status.token_type,
                    "is_available": status.is_available,
                    "failure_count": status.failure_count,
                    "success_count": status.successful_requests,
                    "success_rate": f"{status.success_rate:.2%}",
                    "total_requests": status.total_requests,
                    "is_healthy": status.is_healthy,
                    "last_failure_time": status.last_failure_time,
                    "last_success_time": status.last_success_time
                })

            return status_info

    def update_token_type(self, token: str, token_type: str):
        """更新 Token 类型(用于健康检查后更新)"""
        with self._lock:
            if token in self.token_statuses:
                old_type = self.token_statuses[token].token_type
                self.token_statuses[token].token_type = token_type

                if old_type != token_type:
                    logger.info(f"🔄 更新 Token 类型: {token[:20]}... {old_type}{token_type}")

    async def health_check_token(self, token: str) -> bool:
        """
        异步健康检查单个 Token(使用 Z.AI 官方认证接口)

        Args:
            token: 要检查的 Token

        Returns:
            Token 是否健康(True = 有效的认证用户 Token)
        """
        token_type, is_valid, error_message = await ZAITokenValidator.validate_token(token)

        # 更新 Token 类型
        self.update_token_type(token, token_type)

        # 更新状态
        if is_valid:
            self.mark_token_success(token)
        else:
            self.mark_token_failure(token, Exception(error_message or "验证失败"))

        return is_valid

    async def health_check_all(self):
        """异步健康检查所有 Token"""
        if not self.token_statuses:
            logger.warning("⚠️ Token 池为空,跳过健康检查")
            return

        total_tokens = len(self.token_statuses)
        logger.info(f"🔍 开始 Token 池健康检查... (共 {total_tokens} 个 Token)")

        # 并发执行所有 Token 的健康检查
        tasks = [
            self.health_check_token(token)
            for token in self.token_statuses.keys()
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # 统计结果
        healthy_count = sum(1 for r in results if r is True)
        failed_count = sum(1 for r in results if r is False)
        exception_count = sum(1 for r in results if isinstance(r, Exception))

        health_rate = (healthy_count / total_tokens) * 100 if total_tokens > 0 else 0

        if healthy_count == 0 and total_tokens > 0:
            logger.warning(f"⚠️ 健康检查完成: 0/{total_tokens} 个 Token 健康 - 请检查 Token 配置")
        elif failed_count > 0:
            logger.warning(f"⚠️ 健康检查完成: {healthy_count}/{total_tokens} 个 Token 健康 ({health_rate:.1f}%)")
        else:
            logger.info(f"✅ 健康检查完成: {healthy_count}/{total_tokens} 个 Token 健康")

        if exception_count > 0:
            logger.error(f"💥 {exception_count} 个 Token 检查异常")

    async def sync_from_database(self, provider: str = "zai"):
        """
        从数据库同步 Token 状态(禁用/启用状态)

        Args:
            provider: 提供商名称

        说明:
            - 从数据库读取最新的 Token 启用状态
            - 如果数据库中 Token 被禁用,则从池中移除
            - 如果数据库中有新增的启用 Token,则添加到池中
            - 保留现有 Token 的运行时统计(请求数、成功率等)
        """
        from app.services.token_dao import get_token_dao

        dao = get_token_dao()

        # 从数据库加载所有启用的认证用户 Token
        token_records = await dao.get_tokens_by_provider(provider, enabled_only=True)

        # 构建数据库中的 Token 映射
        db_tokens = {
            record["token"]: (record["id"], record.get("token_type", "unknown"))
            for record in token_records
            if record.get("token_type") != "guest"  # 过滤 guest token
        }

        with self._lock:
            # 1. 移除已在数据库中禁用的 Token
            tokens_to_remove = []
            for token_value in list(self.token_statuses.keys()):
                if token_value not in db_tokens:
                    tokens_to_remove.append(token_value)

            for token_value in tokens_to_remove:
                del self.token_statuses[token_value]
                del self.token_id_map[token_value]
                logger.info(f"🗑️ 从池中移除已禁用 Token: {token_value[:20]}...")

            # 2. 添加新启用的 Token
            new_tokens_count = 0
            for token_value, (token_id, token_type) in db_tokens.items():
                if token_value not in self.token_statuses:
                    self.token_statuses[token_value] = TokenStatus(
                        token=token_value,
                        token_id=token_id,
                        token_type=token_type
                    )
                    self.token_id_map[token_value] = token_id
                    new_tokens_count += 1
                    logger.info(f"➕ 添加新启用 Token: {token_value[:20]}...")

            # 3. 更新现有 Token 的类型(如果数据库中有更新)
            for token_value, (token_id, token_type) in db_tokens.items():
                if token_value in self.token_statuses:
                    old_type = self.token_statuses[token_value].token_type
                    if old_type != token_type:
                        self.token_statuses[token_value].token_type = token_type
                        logger.info(f"🔄 更新 Token 类型: {token_value[:20]}... {old_type}{token_type}")

            logger.info(
                f"✅ Token 池同步完成: "
                f"当前 {len(self.token_statuses)} 个 Token "
                f"(移除 {len(tokens_to_remove)}, 新增 {new_tokens_count})"
            )


# ==================== 全局实例管理 ====================


_token_pool: Optional[TokenPool] = None
_pool_lock = Lock()


def get_token_pool() -> Optional[TokenPool]:
    """获取全局 Token 池实例"""
    return _token_pool


async def initialize_token_pool_from_db(
    provider: str = "zai",
    failure_threshold: int = 3,
    recovery_timeout: int = 1800
) -> Optional[TokenPool]:
    """
    从数据库初始化全局 Token 池

    Args:
        provider: 提供商名称 (zai, k2think, longcat)
        failure_threshold: 失败阈值
        recovery_timeout: 恢复超时时间(秒)

    Returns:
        TokenPool 实例(即使没有 Token 也会创建空池)
    """
    global _token_pool

    from app.services.token_dao import get_token_dao

    dao = get_token_dao()

    # 从数据库加载 Token(只加载启用的认证用户 Token)
    token_records = await dao.get_tokens_by_provider(provider, enabled_only=True)

    # 转换为 TokenPool 所需格式
    tokens = []
    if token_records:
        tokens = [
            (record["id"], record["token"], record.get("token_type", "unknown"))
            for record in token_records
        ]

        # 过滤掉 guest token(不应该在数据库中,但防御性检查)
        user_tokens = [
            (tid, tval, ttype) for tid, tval, ttype in tokens
            if ttype != "guest"
        ]

        if len(user_tokens) < len(tokens):
            guest_count = len(tokens) - len(user_tokens)
            logger.warning(f"⚠️ 过滤了 {guest_count} 个匿名用户 Token")

        tokens = user_tokens

    # 始终创建 Token 池实例(即使为空)
    with _pool_lock:
        _token_pool = TokenPool(tokens, failure_threshold, recovery_timeout)

        if not tokens:
            logger.warning(f"⚠️ {provider} 没有有效的认证用户 Token,已创建空 Token 池")
        else:
            logger.info(f"🔧 从数据库初始化 Token 池({provider}),共 {len(tokens)} 个 Token")

        return _token_pool


async def sync_token_stats_to_db():
    """
    将内存中的 Token 统计同步到数据库

    应在服务关闭或定期调用,确保统计数据不丢失
    """
    pool = get_token_pool()
    if not pool:
        return

    from app.services.token_dao import get_token_dao

    dao = get_token_dao()

    with pool._lock:
        for token, status in pool.token_statuses.items():
            token_id = status.token_id

            # 更新数据库统计(简化版,实际可能需要增量更新)
            if status.successful_requests > 0:
                for _ in range(status.successful_requests):
                    await dao.record_success(token_id)

            if status.total_requests - status.successful_requests > 0:
                for _ in range(status.total_requests - status.successful_requests):
                    await dao.record_failure(token_id)

    logger.info("✅ Token 统计已同步到数据库")