File size: 5,364 Bytes
1d96a60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""

基础任务服务类

提供通用的任务管理、日志记录和账户更新功能

"""
import asyncio
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar

from core.account import update_accounts_config

# Module-level logger; BaseTaskService._append_log mirrors every task log
# entry to it.
logger = logging.getLogger("gemini.base_task")


class TaskStatus(str, Enum):
    """Enumeration of task statuses.

    The ``str`` mix-in makes every member an actual string, so a member
    compares equal to its lowercase value (e.g. ``TaskStatus.PENDING ==
    "pending"``).
    """
    PENDING = "pending"  # default status of a newly constructed BaseTask
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"


@dataclass
class BaseTask:
    """Base record holding the state of a single task.

    Only ``id`` is required; every other field has a default, so a task can
    be created from its identifier alone.
    """

    # Identifier of the task.
    id: str
    # Current status; a new task starts as PENDING.
    status: TaskStatus = TaskStatus.PENDING
    # Progress indicator (unit/scale is defined by the code that updates it).
    progress: int = 0
    # Counters for succeeded / failed items.
    success_count: int = 0
    fail_count: int = 0
    # Creation time, taken from time.time() when the instance is built.
    created_at: float = field(default_factory=time.time)
    # Completion time; None until set (presumably also a time.time() value).
    finished_at: Optional[float] = None
    # Result entries accumulated for the task.
    results: List[Dict[str, Any]] = field(default_factory=list)
    # Error description, if any.
    error: Optional[str] = None
    # Log entries attached to the task.
    logs: List[Dict[str, str]] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return the task as a plain dictionary.

        Keys mirror the field names in declaration order. ``status`` is
        exported as the enum's underlying value; all other fields are
        exported as-is, so ``results`` and ``logs`` are the task's own list
        objects, not copies.
        """
        exported = (
            "id",
            "status",
            "progress",
            "success_count",
            "fail_count",
            "created_at",
            "finished_at",
            "results",
            "error",
            "logs",
        )
        payload = {name: getattr(self, name) for name in exported}
        # Overwriting an existing key keeps its position in the dict.
        payload["status"] = self.status.value
        return payload


# Type variable for the concrete task type handled by a service; it is
# bounded by BaseTask, so only BaseTask or its subclasses satisfy it.
T = TypeVar('T', bound=BaseTask)


class BaseTaskService(Generic[T]):
    """Base class for task services.

    Provides the shared plumbing for task management: an in-memory task
    registry, per-task log capture that is mirrored to the module logger,
    and a helper that rebuilds the multi-account manager from new account
    data.
    """

    # Maximum number of log entries kept per task; the oldest entries are
    # dropped first. Subclasses may override this value.
    MAX_LOG_ENTRIES = 200

    def __init__(
        self,
        multi_account_mgr,
        http_client,
        user_agent: str,
        account_failure_threshold: int,
        rate_limit_cooldown_seconds: int,
        session_cache_ttl_seconds: int,
        global_stats_provider: Callable[[], dict],
        set_multi_account_mgr: Optional[Callable[[Any], None]] = None,
        log_prefix: str = "TASK",
    ) -> None:
        """Initialise the base task service.

        Args:
            multi_account_mgr: Multi-account manager currently in use. It is
                passed to ``update_accounts_config`` and replaced by that
                call's return value in ``_apply_accounts_update``.
            http_client: HTTP client, forwarded to ``update_accounts_config``.
            user_agent: User-agent string, forwarded to
                ``update_accounts_config``.
            account_failure_threshold: Account failure threshold, forwarded
                to ``update_accounts_config``.
            rate_limit_cooldown_seconds: Rate-limit cooldown in seconds,
                forwarded to ``update_accounts_config``.
            session_cache_ttl_seconds: Session cache TTL in seconds,
                forwarded to ``update_accounts_config``.
            global_stats_provider: Zero-argument callable returning the
                global statistics; called on every account update.
            set_multi_account_mgr: Optional callback that receives the new
                multi-account manager after an account update.
            log_prefix: Tag placed in square brackets in front of every
                message sent to the module logger.
        """
        # Single-worker pool; not used by the methods of this class
        # (presumably used by subclasses to run task work -- confirm there).
        self._executor = ThreadPoolExecutor(max_workers=1)
        # Registry of tasks keyed by task id.
        self._tasks: Dict[str, T] = {}
        self._current_task_id: Optional[str] = None
        # Not used by the methods of this class; presumably for subclasses'
        # coroutine code -- confirm there.
        self._lock = asyncio.Lock()
        # Guards appends/trims of a task's log list in _append_log.
        self._log_lock = threading.Lock()
        self._log_prefix = log_prefix

        self.multi_account_mgr = multi_account_mgr
        self.http_client = http_client
        self.user_agent = user_agent
        self.account_failure_threshold = account_failure_threshold
        self.rate_limit_cooldown_seconds = rate_limit_cooldown_seconds
        self.session_cache_ttl_seconds = session_cache_ttl_seconds
        self.global_stats_provider = global_stats_provider
        self.set_multi_account_mgr = set_multi_account_mgr

    def get_task(self, task_id: str) -> Optional[T]:
        """Return the task registered under ``task_id``, or None if absent."""
        return self._tasks.get(task_id)

    def get_current_task(self) -> Optional[T]:
        """Return the current task.

        Returns None when no current task id is set or when that id is not
        in the registry.
        """
        if not self._current_task_id:
            return None
        return self._tasks.get(self._current_task_id)

    def _append_log(self, task: T, level: str, message: str) -> None:
        """Append a log entry to ``task`` and mirror it to the module logger.

        The entry is a dict with the keys ``time`` (local time formatted as
        ``YYYY-MM-DD HH:MM:SS``), ``level`` and ``message``. At most
        ``MAX_LOG_ENTRIES`` entries are kept per task; older ones are
        discarded.

        Args:
            task: Task object that receives the entry.
            level: Log level (info, warning, error). Any value other than
                "warning" or "error" is logged at info level.
            message: Log message.
        """
        entry = {
            "time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "level": level,
            "message": message,
        }
        with self._log_lock:
            task.logs.append(entry)
            overflow = len(task.logs) - self.MAX_LOG_ENTRIES
            if overflow > 0:
                # Slicing from the front by the overflow count keeps exactly
                # the newest MAX_LOG_ENTRIES entries and, unlike a negative
                # slice, stays correct for a cap of 0. The result is a new
                # list object bound to task.logs.
                task.logs = task.logs[overflow:]

        log_message = f"[{self._log_prefix}] {message}"
        # Unknown levels fall back to info.
        emitters = {"warning": logger.warning, "error": logger.error}
        emitters.get(level, logger.info)(log_message)

    def _apply_accounts_update(self, accounts_data: list) -> None:
        """Apply an account update and switch to the resulting manager.

        Calls ``update_accounts_config`` with the new account data and the
        settings stored on this service, stores the returned manager on
        ``self.multi_account_mgr`` and, when a ``set_multi_account_mgr``
        callback was supplied, passes the new manager to it.

        Args:
            accounts_data: List of account data.
        """
        # A falsy provider result (e.g. None) is replaced by an empty dict.
        global_stats = self.global_stats_provider() or {}
        new_mgr = update_accounts_config(
            accounts_data,
            self.multi_account_mgr,
            self.http_client,
            self.user_agent,
            self.account_failure_threshold,
            self.rate_limit_cooldown_seconds,
            self.session_cache_ttl_seconds,
            global_stats,
        )
        self.multi_account_mgr = new_mgr
        if self.set_multi_account_mgr:
            self.set_multi_account_mgr(new_mgr)