diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,1556 +1,1565 @@ -import asyncio -import atexit -import hashlib -import json -import logging -import os -import secrets -import threading -import traceback -import time -from collections import deque -from datetime import datetime -from pathlib import Path -from typing import Any, Optional -from urllib.parse import parse_qsl, unquote, urlsplit - -import pymysql -import uvicorn -from apscheduler.schedulers.background import BackgroundScheduler -from apscheduler.triggers.cron import CronTrigger -from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile, status -from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse -from fastapi.staticfiles import StaticFiles -from fastapi.templating import Jinja2Templates -from pydantic import BaseModel, Field - -from core.tasks import runTasks -from utils.logger import setup_logger - - -logger = setup_logger(level=logging.DEBUG) - -BASE_DIR = Path(__file__).resolve().parent -TEMPLATES_DIR = BASE_DIR / "templates" -STATIC_DIR = BASE_DIR / "static" -ROOT_CONFIG_PATH = BASE_DIR / "config.json" -LEGACY_DATA_DIR = BASE_DIR / "data" -LEGACY_USERS_META_PATH = LEGACY_DATA_DIR / "users.json" -MYSQL_DSN_TEMPLATE = "mysql://SQL_PASSWORD@mysql-2bace9cd-cacode.i.aivencloud.com:21260/defaultdb?ssl-mode=REQUIRED" -MYSQL_DSN_ENV = "MYSQL_DSN_TEMPLATE" -MYSQL_PASSWORD_ENV = "SQL_PASSWORD" -MYSQL_USER_ENV = "MYSQL_USER" -MYSQL_CA_CERT_ENV = "MYSQL_CA_CERT_PATH" -MYSQL_DEFAULT_USER = "avnadmin" -USERS_TABLE = "app_users" -SESSION_COOKIE_NAME = "sparkflow_auth" -DEFAULT_TIMEZONE = "Asia/Shanghai" -MAX_LOG_LINES = 1200 -MAX_TEMPLATE_LENGTH = 2000 -PASSWORD_ITERATIONS = 210000 -FAILED_RETRY_JOB_ID = "_system_retry_failed_tasks" -FAILED_RETRY_INTERVAL_HOURS = 1 - -DEFAULT_USER_CONFIG = { - "multiTask": True, - "taskCount": 5, - "proxyAddress": "", - "messageTemplate": "[续火花]", - "hitokotoTypes": ["文学", "影视", "诗词", "哲学"], - "scheduler": { - "enabled": 
True, - "timezone": DEFAULT_TIMEZONE, - "hour": 9, - "minute": 0, - "runOnStartup": False, - }, -} - -AUTH_SESSIONS: dict[str, dict[str, str]] = {} -db_init_lock = threading.Lock() -scheduler_lock = threading.Lock() -runtime_map_lock = threading.Lock() -db_initialized = False -db_status_lock = threading.Lock() -db_status = { - "connected": None, - "last_checked_at": None, - "last_ok_at": None, - "last_error": "", -} -scheduler_bootstrapped = False -scheduler_bootstrap_lock = threading.Lock() -scheduler_bootstrap_running = False - - -class UserRuntimeState: - def __init__(self, username: str): - self.username = username - self._run_lock = threading.Lock() - self._state_lock = threading.Lock() - self.is_running = False - self.last_status = "未开始" - self.last_error = "" - self.last_trigger = "-" - self.last_start = None - self.last_end = None - self.next_run = None - self.schedule_hour = 9 - self.schedule_minute = 0 - self.schedule_timezone = DEFAULT_TIMEZONE - self.history = deque(maxlen=50) - self.logs = deque(maxlen=2000) - - def _format_ts(self, value: Optional[datetime]): - if not value: - return "-" - return value.strftime("%Y-%m-%d %H:%M:%S") - - def schedule_time(self): - return f"{self.schedule_hour:02d}:{self.schedule_minute:02d}" - - def _set_running(self, value: bool): - with self._state_lock: - self.is_running = value - - def add_log(self, message: str): - ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - with self._state_lock: - self.logs.append(f"{ts} [{self.username}] {message}") - - def update_schedule(self, hour: int, minute: int, timezone: str): - with self._state_lock: - self.schedule_hour = hour - self.schedule_minute = minute - self.schedule_timezone = timezone - - def update_next_run(self, next_run): - with self._state_lock: - self.next_run = next_run - - def snapshot(self, account_count: int, target_count: int): - with self._state_lock: - return { - "is_running": self.is_running, - "last_status": self.last_status, - "last_error": 
self.last_error, - "last_trigger": self.last_trigger, - "last_start": self._format_ts(self.last_start), - "last_end": self._format_ts(self.last_end), - "next_run": self._format_ts(self.next_run), - "account_count": account_count, - "target_count": target_count, - "schedule_time": self.schedule_time(), - "schedule_timezone": self.schedule_timezone, - } - - def history_rows(self): - with self._state_lock: - return list(self.history)[::-1] - - def recent_logs(self, limit=MAX_LOG_LINES): - with self._state_lock: - lines = list(self.logs)[-max(1, limit):] - return "\n".join(lines) if lines else "暂无日志。" - - def run_once(self, trigger: str): - if not self._run_lock.acquire(blocking=False): - self.add_log(f"任务已在运行中,忽略触发:{trigger}") - return False, "已有任务在运行,本次触发已跳过。" - - self._set_running(True) - with self._state_lock: - self.last_trigger = trigger - self.last_start = datetime.now() - self.last_end = None - self.last_error = "" - self.last_status = "运行中" - self.add_log(f"任务开始执行,触发方式:{trigger}") - - ok = True - message = "任务执行完成。" - try: - asyncio.run(_run_user_tasks(self.username)) - with self._state_lock: - self.last_status = "成功" - except Exception as exc: - ok = False - message = f"任务执行失败:{exc}" - with self._state_lock: - self.last_status = "失败" - self.last_error = repr(exc) - self.add_log(f"任务失败:{exc}") - logger.error("Task failed. 
user=%s trigger=%s error=%s", self.username, trigger, exc) - logger.debug("Task traceback:\n%s", traceback.format_exc()) - finally: - end_at = datetime.now() - with self._state_lock: - self.last_end = end_at - duration = (self.last_end - self.last_start).total_seconds() - self.history.append( - { - "trigger": trigger, - "start": self._format_ts(self.last_start), - "end": self._format_ts(self.last_end), - "status": self.last_status, - "duration": f"{duration:.2f}s", - "message": self.last_error or "OK", - } - ) - current_status = self.last_status - self.add_log(f"任务结束,状态={current_status},耗时={duration:.2f}s") - self._set_running(False) - self._run_lock.release() - return ok, message - - -runtime_map: dict[str, UserRuntimeState] = {} -scheduler = None - - -class UserLoginPayload(BaseModel): - username: str - password: str - - -class AdminLoginPayload(BaseModel): - password: str - - -class SchedulePayload(BaseModel): - time: str - - -class MessageTemplatePayload(BaseModel): - message: str - - -class UserTargetsItem(BaseModel): - unique_id: str - targets: list[str] = Field(default_factory=list) - - -class UserTargetsPayload(BaseModel): - users: list[UserTargetsItem] - - -def _ensure_data_layout(): - global db_initialized - if db_initialized: - return - - logger.info("DB layout ensure begin.") - with db_init_lock: - if db_initialized: - logger.info("DB layout already initialized by another worker.") - return - logger.info("DB schema initialization begin.") - _init_db_schema() - db_initialized = True - logger.info("DB schema initialization complete.") - - try: - logger.info("Legacy migration stage begin.") - _migrate_legacy_file_data_if_needed() - logger.info("Legacy migration stage complete.") - except Exception as exc: - logger.warning("Legacy data migration skipped due to error: %s", exc) - - -def _hash_password(password: str, salt_hex: Optional[str] = None): - salt = bytes.fromhex(salt_hex) if salt_hex else secrets.token_bytes(16) - digest = hashlib.pbkdf2_hmac( - 
"sha256", - password.encode("utf-8"), - salt, - PASSWORD_ITERATIONS, - ) - return { - "salt": salt.hex(), - "hash": digest.hex(), - } - - -def _verify_password(password: str, salt_hex: str, expected_hash: str): - data = _hash_password(password, salt_hex=salt_hex) - return secrets.compare_digest(data["hash"], expected_hash) - - -def _deep_copy_json(value): - return json.loads(json.dumps(value, ensure_ascii=False)) - - -def _merge_config_with_defaults(raw_cfg: Any): - base = _deep_copy_json(DEFAULT_USER_CONFIG) - if not isinstance(raw_cfg, dict): - return base - - merged = _deep_copy_json(base) - merged.update(raw_cfg) - base_scheduler = base.get("scheduler", {}) - merged_scheduler = raw_cfg.get("scheduler", {}) - if isinstance(merged_scheduler, dict): - scheduler = _deep_copy_json(base_scheduler) - scheduler.update(merged_scheduler) - merged["scheduler"] = scheduler - else: - merged["scheduler"] = _deep_copy_json(base_scheduler) - return merged - - -def _format_common_ts(value: Optional[datetime]): - if not value: - return "-" - return value.strftime("%Y-%m-%d %H:%M:%S") - - -def _update_db_status(connected: bool, error: Optional[Exception] = None): - now = datetime.now() - with db_status_lock: - db_status["connected"] = connected - db_status["last_checked_at"] = now - if connected: - db_status["last_ok_at"] = now - db_status["last_error"] = "" - else: - db_status["last_error"] = str(error or "数据库连接失败") - - -def _build_db_status_payload(): - with db_status_lock: - connected = db_status.get("connected") - return { - "connected": connected, - "last_checked_at": _format_common_ts(db_status.get("last_checked_at")), - "last_ok_at": _format_common_ts(db_status.get("last_ok_at")), - "last_error": str(db_status.get("last_error") or ""), - } - - -def _resolve_mysql_dsn(): - raw = os.getenv(MYSQL_DSN_ENV, MYSQL_DSN_TEMPLATE).strip() - if "SQL_PASSWORD" in raw: - secret = os.getenv(MYSQL_PASSWORD_ENV, "").strip() - if not secret: - raise RuntimeError(f"环境变量 
{MYSQL_PASSWORD_ENV} 未设置,无法连接 MySQL。") - raw = raw.replace("SQL_PASSWORD", secret, 1) - return raw - - -def _build_mysql_conn_kwargs(): - dsn = _resolve_mysql_dsn() - parsed = urlsplit(dsn) - if parsed.scheme not in ("mysql", "mysql+pymysql"): - raise RuntimeError(f"不支持的 MySQL DSN 协议:{parsed.scheme}") - - host = parsed.hostname - if not host: - raise RuntimeError("MySQL DSN 缺少主机地址。") - - user = unquote(parsed.username or "") - password = unquote(parsed.password) if parsed.password is not None else None - if user and password is None: - password = user - user = os.getenv(MYSQL_USER_ENV, MYSQL_DEFAULT_USER).strip() or MYSQL_DEFAULT_USER - if not user: - user = os.getenv(MYSQL_USER_ENV, MYSQL_DEFAULT_USER).strip() or MYSQL_DEFAULT_USER - if not password: - password = os.getenv(MYSQL_PASSWORD_ENV, "").strip() - if not password: - raise RuntimeError("MySQL 密码为空,请检查 SQL_PASSWORD 环境变量。") - - db_name = parsed.path.lstrip("/") or "defaultdb" - query = {k.lower(): v for k, v in parse_qsl(parsed.query, keep_blank_values=True)} - ssl_mode = str(query.get("ssl-mode", query.get("ssl_mode", ""))).upper() - - kwargs = { - "host": host, - "port": parsed.port or 3306, - "user": user, - "password": password, - "database": db_name, - "charset": "utf8mb4", - "autocommit": True, - "connect_timeout": int(os.getenv("MYSQL_CONNECT_TIMEOUT", "4")), - "read_timeout": int(os.getenv("MYSQL_READ_TIMEOUT", "8")), - "write_timeout": int(os.getenv("MYSQL_WRITE_TIMEOUT", "8")), - "cursorclass": pymysql.cursors.DictCursor, - } - - if ssl_mode in {"REQUIRED", "VERIFY_CA", "VERIFY_IDENTITY"}: - ca_file = Path(os.getenv(MYSQL_CA_CERT_ENV, str(BASE_DIR / "camysql2.pem"))).resolve() - if not ca_file.exists(): - raise RuntimeError(f"MySQL CA 证书不存在:{ca_file}") - kwargs["ssl"] = {"ca": str(ca_file)} - - return kwargs - - -def _db_connect(): - kwargs = _build_mysql_conn_kwargs() - host = kwargs.get("host") - port = kwargs.get("port") - database = kwargs.get("database") - connect_timeout = 
kwargs.get("connect_timeout") - read_timeout = kwargs.get("read_timeout") - write_timeout = kwargs.get("write_timeout") - has_ssl = bool(kwargs.get("ssl")) - started_at = time.perf_counter() - - logger.info( - "MySQL connect begin. host=%s port=%s db=%s connect_timeout=%ss read_timeout=%ss write_timeout=%ss ssl=%s", - host, - port, - database, - connect_timeout, - read_timeout, - write_timeout, - has_ssl, - ) - try: - conn = pymysql.connect(**kwargs) - except Exception as exc: - elapsed = time.perf_counter() - started_at - logger.warning( - "MySQL connect failed after %.2fs. host=%s port=%s db=%s error=%s", - elapsed, - host, - port, - database, - exc, - ) - _update_db_status(False, exc) - raise - - elapsed = time.perf_counter() - started_at - logger.info("MySQL connect success. host=%s port=%s db=%s elapsed=%.2fs", host, port, database, elapsed) - _update_db_status(True) - return conn - - -def _db_query_all(query: str, params=()): - conn = _db_connect() - try: - with conn.cursor() as cursor: - cursor.execute(query, params) - return cursor.fetchall() - finally: - conn.close() - - -def _db_query_one(query: str, params=()): - conn = _db_connect() - try: - with conn.cursor() as cursor: - cursor.execute(query, params) - return cursor.fetchone() - finally: - conn.close() - - -def _db_execute(query: str, params=()): - conn = _db_connect() - try: - with conn.cursor() as cursor: - cursor.execute(query, params) - return cursor.rowcount - finally: - conn.close() - - -def _init_db_schema(): - logger.info("DB schema creation SQL begin.") - _db_execute( - f""" - CREATE TABLE IF NOT EXISTS `{USERS_TABLE}` ( - `username` VARCHAR(128) NOT NULL, - `unique_id` VARCHAR(255) NOT NULL, - `password_hash` VARCHAR(128) NOT NULL, - `password_salt` VARCHAR(64) NOT NULL, - `created_at` VARCHAR(32) NOT NULL, - `config_json` LONGTEXT NOT NULL, - `users_data_json` LONGTEXT NOT NULL, - `updated_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - PRIMARY KEY 
(`username`), - UNIQUE KEY `uniq_unique_id` (`unique_id`) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - ) - logger.info("DB schema creation SQL complete.") - - -def _legacy_load_json(path: Path, default): - if not path.exists(): - return default - with path.open("r", encoding="utf-8") as f: - return json.load(f) - - -def _migrate_legacy_file_data_if_needed(): - logger.info("Legacy migration check. path=%s", LEGACY_USERS_META_PATH) - if not LEGACY_USERS_META_PATH.exists(): - logger.info("Legacy migration skipped: users.json not found.") - return - - row = _db_query_one(f"SELECT COUNT(*) AS cnt FROM `{USERS_TABLE}`") - existing_count = int(row.get("cnt", 0)) if row else 0 - if existing_count > 0: - logger.info("Legacy migration skipped: database already has %s users.", existing_count) - return - - try: - raw = _legacy_load_json(LEGACY_USERS_META_PATH, {"users": []}) - except Exception as exc: - logger.warning("读取旧版 users.json 失败,跳过迁移:%s", exc) - return - - users = raw.get("users", []) if isinstance(raw, dict) else [] - if not users: - logger.info("Legacy migration skipped: legacy users list is empty.") - return - - logger.info("Legacy migration loaded %s legacy users.", len(users)) - migrated = 0 - for item in users: - username = str(item.get("username", "")).strip() - unique_id = str(item.get("unique_id", "")).strip() - password_hash = str(item.get("password_hash", "")).strip() - password_salt = str(item.get("password_salt", "")).strip() - created_at = str(item.get("created_at", "")).strip() or datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - if not (username and unique_id and password_hash and password_salt): - logger.warning("旧版用户数据不完整,跳过:%s", username or "") - continue - - cfg = _get_default_user_config() - users_data = [] - tenant_rel = str(item.get("tenant_dir", "")).strip() - if tenant_rel: - tenant_dir = (BASE_DIR / tenant_rel).resolve() - cfg_path = tenant_dir / "config.json" - users_data_path = tenant_dir / "usersData.json" - 
try: - cfg = _merge_config_with_defaults(_legacy_load_json(cfg_path, cfg)) - except Exception as exc: - logger.warning("读取旧版配置失败,使用默认配置。user=%s error=%s", username, exc) - try: - users_data = _legacy_load_json(users_data_path, []) - except Exception as exc: - logger.warning("读取旧版 usersData 失败。user=%s error=%s", username, exc) - - if not isinstance(users_data, list): - users_data = [] - - try: - _create_user_record( - username=username, - unique_id=unique_id, - password_hash=password_hash, - password_salt=password_salt, - created_at=created_at, - config_payload=cfg, - users_data_payload=users_data, - ) - migrated += 1 - except Exception as exc: - logger.warning("迁移用户失败。user=%s error=%s", username, exc) - - logger.info("Legacy migration finished. migrated=%s total=%s", migrated, len(users)) - if migrated > 0: - logger.info("已完成旧版文件数据迁移,共迁移 %s 个用户。", migrated) - - -def _load_users_meta(): - logger.info("Load users meta begin.") - _ensure_data_layout() - rows = _db_query_all( - f""" - SELECT username, unique_id, password_hash, password_salt, created_at - FROM `{USERS_TABLE}` - ORDER BY username ASC - """ - ) - logger.info("Load users meta complete. 
count=%s", len(rows)) - return {str(row["username"]): row for row in rows} - - -def _load_user_row(username: str): - _ensure_data_layout() - return _db_query_one( - f""" - SELECT username, unique_id, password_hash, password_salt, created_at, config_json, users_data_json - FROM `{USERS_TABLE}` - WHERE username=%s - """, - (username,), - ) - - -def _user_exists(username: str): - _ensure_data_layout() - row = _db_query_one( - f"SELECT 1 AS ok FROM `{USERS_TABLE}` WHERE username=%s", - (username,), - ) - return bool(row) - - -def _create_user_record( - *, - username: str, - unique_id: str, - password_hash: str, - password_salt: str, - created_at: str, - config_payload: dict[str, Any], - users_data_payload: list[dict[str, Any]], -): - _db_execute( - f""" - INSERT INTO `{USERS_TABLE}` - (username, unique_id, password_hash, password_salt, created_at, config_json, users_data_json) - VALUES (%s, %s, %s, %s, %s, %s, %s) - """, - ( - username, - unique_id, - password_hash, - password_salt, - created_at, - json.dumps(config_payload, ensure_ascii=False), - json.dumps(users_data_payload, ensure_ascii=False), - ), - ) - - -def _delete_user_record(username: str): - _ensure_data_layout() - return _db_execute(f"DELETE FROM `{USERS_TABLE}` WHERE username=%s", (username,)) - - -def _get_user_meta_or_404(username: str): - users_map = _load_users_meta() - user = users_map.get(username) - if not user: - raise HTTPException(status_code=404, detail="用户不存在") - return user - - -def _get_default_user_config(): - if ROOT_CONFIG_PATH.exists(): - try: - with ROOT_CONFIG_PATH.open("r", encoding="utf-8") as f: - root_cfg = json.load(f) - return _merge_config_with_defaults(root_cfg) - except Exception: - logger.warning("Failed to read root config.json. 
fallback to DEFAULT_USER_CONFIG") - return _deep_copy_json(DEFAULT_USER_CONFIG) - - -def _load_user_config(username: str): - row = _load_user_row(username) - if not row: - raise FileNotFoundError(f"用户 {username} 不存在") - try: - payload = json.loads(row.get("config_json", "{}")) - except Exception as exc: - raise ValueError(f"用户 {username} 的配置数据损坏:{exc}") - return _merge_config_with_defaults(payload) - - -def _save_user_config(username: str, cfg: dict): - normalized = _merge_config_with_defaults(cfg) - changed = _db_execute( - f"UPDATE `{USERS_TABLE}` SET config_json=%s WHERE username=%s", - (json.dumps(normalized, ensure_ascii=False), username), - ) - if changed == 0 and not _user_exists(username): - raise FileNotFoundError(f"用户 {username} 不存在") - - -def _load_user_users_data(username: str): - row = _load_user_row(username) - if not row: - raise FileNotFoundError(f"用户 {username} 不存在") - try: - data = json.loads(row.get("users_data_json", "[]")) - except Exception as exc: - raise ValueError(f"用户 {username} 的 usersData 数据损坏:{exc}") - if not isinstance(data, list): - raise ValueError("usersData.json 必须是数组") - return data - - -def _save_user_users_data(username: str, users_data: list): - changed = _db_execute( - f"UPDATE `{USERS_TABLE}` SET users_data_json=%s WHERE username=%s", - (json.dumps(users_data, ensure_ascii=False), username), - ) - if changed == 0 and not _user_exists(username): - raise FileNotFoundError(f"用户 {username} 不存在") - - -def _sanitize_targets(values): - cleaned = [] - seen = set() - for value in values or []: - text = str(value).strip() - if not text or text in seen: - continue - seen.add(text) - cleaned.append(text) - return cleaned - - -def _validate_and_normalize_users_data(raw_bytes: bytes): - try: - payload = json.loads(raw_bytes.decode("utf-8")) - except Exception as exc: - raise ValueError(f"上传文件不是合法 JSON:{exc}") - - if not isinstance(payload, list) or not payload: - raise ValueError("usersData.json 必须是非空数组") - - normalized = [] - for idx, 
item in enumerate(payload): - if not isinstance(item, dict): - raise ValueError(f"第 {idx + 1} 条用户数据格式错误(必须是对象)") - - unique_id = str(item.get("unique_id", "")).strip() - username = str(item.get("username", "")).strip() - cookies = item.get("cookies", []) - targets = item.get("targets", []) - - if not unique_id: - raise ValueError(f"第 {idx + 1} 条缺少 unique_id") - if not username: - raise ValueError(f"第 {idx + 1} 条缺少 username") - if not isinstance(cookies, list) or not cookies: - raise ValueError(f"第 {idx + 1} 条 cookies 不能为空且必须是数组") - if not isinstance(targets, list): - raise ValueError(f"第 {idx + 1} 条 targets 必须是数组") - - normalized.append( - { - "unique_id": unique_id, - "username": username, - "cookies": cookies, - "targets": _sanitize_targets(targets), - } - ) - - primary_username = normalized[0]["username"] - primary_unique_id = normalized[0]["unique_id"] - return normalized, primary_username, primary_unique_id - - -def _count_targets(users_data: list): - return sum(len(user.get("targets", [])) for user in users_data) - - -def _get_runtime(username: str): - with runtime_map_lock: - runtime = runtime_map.get(username) - if runtime is None: - runtime = UserRuntimeState(username=username) - runtime_map[username] = runtime - return runtime - - -def _delete_runtime(username: str): - with runtime_map_lock: - runtime_map.pop(username, None) - - -def _session_from_request(request: Request): - token = request.cookies.get(SESSION_COOKIE_NAME) - if not token: - return None - return AUTH_SESSIONS.get(token) - - -def _require_user_session(request: Request): - session = _session_from_request(request) - if not session or session.get("role") != "user": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="未登录或登录已失效", - ) - return session - - -def _require_admin_session(request: Request): - session = _session_from_request(request) - if not session or session.get("role") != "admin": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - 
detail="未登录或登录已失效", - ) - return session - - -def _parse_time_string(value: str): - parts = value.strip().split(":") - if len(parts) not in (2, 3): - raise ValueError("时间格式错误,必须是 HH:MM") - hour = int(parts[0]) - minute = int(parts[1]) - if hour < 0 or hour > 23 or minute < 0 or minute > 59: - raise ValueError("时间范围错误,小时 0-23,分钟 0-59") - return hour, minute - - -def _build_editor_state(username: str): - cfg = _load_user_config(username) - users = _load_user_users_data(username) - return { - "message_template": str(cfg.get("messageTemplate", "")), - "users": [ - { - "unique_id": str(user.get("unique_id", "")), - "username": str(user.get("username", "未知用户")), - "targets": _sanitize_targets(user.get("targets", [])), - } - for user in users - ], - } - - -def _scheduler_job_id(username: str): - return f"daily_task::{username}" - - -def _run_scheduled_once(username: str): - runtime = _get_runtime(username) - runtime.run_once("schedule") - if scheduler: - job = scheduler.get_job(_scheduler_job_id(username)) - runtime.update_next_run(job.next_run_time if job else None) - - -async def _run_user_tasks(username: str): - cfg = _load_user_config(username) - users_data = _load_user_users_data(username) - await runTasks(config=cfg, userData=users_data) - - -def _sync_user_jobs_from_meta(users_map: dict[str, Any], run_startup_tasks: bool = False): - global scheduler_bootstrapped - - logger.info("Sync user jobs begin. count=%s run_startup_tasks=%s", len(users_map), run_startup_tasks) - for username in users_map.keys(): - logger.info("Sync user job. username=%s", username) - _schedule_user_job(username) - if run_startup_tasks: - cfg = _load_user_config(username) - run_on_startup = bool(cfg.get("scheduler", {}).get("runOnStartup", False)) - logger.info("Startup run flag loaded. username=%s run_on_startup=%s", username, run_on_startup) - if run_on_startup: - logger.info("Trigger startup run. 
username=%s", username) - _start_background_run(username, "startup") - - scheduler_bootstrapped = True - logger.info("Sync user jobs complete. count=%s", len(users_map)) - - -def _start_scheduler_bootstrap(run_startup_tasks: bool): - global scheduler_bootstrapped, scheduler_bootstrap_running - - with scheduler_bootstrap_lock: - if scheduler_bootstrap_running: - logger.info("Scheduler bootstrap already running; skip duplicate start.") - return False - scheduler_bootstrap_running = True - - def _worker(): - global scheduler_bootstrapped, scheduler_bootstrap_running - try: - logger.info("Scheduler bootstrap started. run_startup_tasks=%s", run_startup_tasks) - - logger.info("Bootstrap stage begin: ensure_data_layout") - _ensure_data_layout() - logger.info("Bootstrap stage complete: ensure_data_layout") - - logger.info("Bootstrap stage begin: load_users_meta") - users_map = _load_users_meta() - logger.info("Bootstrap stage complete: load_users_meta count=%s", len(users_map)) - - logger.info("Bootstrap stage begin: sync_user_jobs") - _sync_user_jobs_from_meta(users_map, run_startup_tasks=run_startup_tasks) - logger.info("Bootstrap stage complete: sync_user_jobs count=%s", len(users_map)) - - logger.info("Scheduler bootstrap completed. users=%s", len(users_map)) - except Exception as exc: - scheduler_bootstrapped = False - logger.warning("Scheduler bootstrap skipped, database unavailable. error=%s", exc) - finally: - with scheduler_bootstrap_lock: - scheduler_bootstrap_running = False - - thread = threading.Thread(target=_worker, daemon=True, name="scheduler-bootstrap") - thread.start() - return True - - -def _retry_failed_tasks_once(trigger: str, *, raise_on_db_error: bool = False): - try: - users_map = _load_users_meta() - if not scheduler_bootstrapped: - _sync_user_jobs_from_meta(users_map, run_startup_tasks=False) - except Exception as exc: - logger.warning("Failed to load users for failed-task retry. 
error=%s", exc) - if raise_on_db_error: - raise - return [] - - triggered = [] - for username in users_map.keys(): - runtime = _get_runtime(username) - snapshot = runtime.snapshot(account_count=0, target_count=0) - if snapshot.get("is_running") or snapshot.get("last_status") != "失败": - continue - - try: - cfg = _load_user_config(username) - except Exception as exc: - runtime.add_log(f"自动重试前加载配置失败:{exc}") - continue - - if not bool(cfg.get("scheduler", {}).get("enabled", True)): - continue - - runtime.add_log(f"检测到失败任务,准备执行自动重试:{trigger}") - _start_background_run(username, trigger) - triggered.append(username) - - if triggered: - logger.info("Retried failed tasks for users: %s", ", ".join(triggered)) - return triggered - - -def _retry_failed_tasks_job(): - _retry_failed_tasks_once("hourly_retry") - - -def _schedule_user_job(username: str): - global scheduler - - cfg = _load_user_config(username) - scheduler_cfg = cfg.get("scheduler", {}) if isinstance(cfg, dict) else {} - enabled = bool(scheduler_cfg.get("enabled", True)) - timezone = str(scheduler_cfg.get("timezone", DEFAULT_TIMEZONE)) - hour = int(scheduler_cfg.get("hour", 9)) - minute = int(scheduler_cfg.get("minute", 0)) - - runtime = _get_runtime(username) - runtime.update_schedule(hour, minute, timezone) - - with scheduler_lock: - if scheduler is None: - scheduler = BackgroundScheduler(timezone=timezone) - scheduler.start() - - job_id = _scheduler_job_id(username) - if not enabled: - if scheduler.get_job(job_id): - scheduler.remove_job(job_id) - runtime.update_next_run(None) - runtime.add_log("定时任务已禁用") - return - - scheduler.add_job( - _run_scheduled_once, - args=[username], - trigger=CronTrigger(hour=hour, minute=minute, timezone=timezone), - id=job_id, - replace_existing=True, - max_instances=1, - coalesce=True, - ) - job = scheduler.get_job(job_id) - runtime.update_next_run(job.next_run_time if job else None) - runtime.add_log(f"定时任务更新为 {hour:02d}:{minute:02d} ({timezone})") - - -def 
_remove_user_schedule_job(username: str): - with scheduler_lock: - if scheduler is None: - return - job_id = _scheduler_job_id(username) - if scheduler.get_job(job_id): - scheduler.remove_job(job_id) - - -def _start_background_run(username: str, trigger: str): - runtime = _get_runtime(username) - - def _worker(): - runtime.run_once(trigger) - if scheduler: - job = scheduler.get_job(_scheduler_job_id(username)) - runtime.update_next_run(job.next_run_time if job else None) - - thread = threading.Thread(target=_worker, daemon=True) - thread.start() - return True - - -def _start_scheduler(): - global scheduler - with scheduler_lock: - if scheduler is None: - scheduler = BackgroundScheduler(timezone=DEFAULT_TIMEZONE) - scheduler.start() - scheduler.add_job( - _retry_failed_tasks_job, - trigger="interval", - hours=FAILED_RETRY_INTERVAL_HOURS, - id=FAILED_RETRY_JOB_ID, - replace_existing=True, - max_instances=1, - coalesce=True, - ) - - _start_scheduler_bootstrap(run_startup_tasks=True) - - -def _stop_scheduler(): - global scheduler, scheduler_bootstrapped, scheduler_bootstrap_running - with scheduler_lock: - if scheduler and scheduler.running: - scheduler.shutdown(wait=False) - logger.info("Scheduler stopped.") - scheduler = None - scheduler_bootstrapped = False - with scheduler_bootstrap_lock: - scheduler_bootstrap_running = False - - -app = FastAPI(title="DouYin Spark Flow Dashboard") -app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") -templates = Jinja2Templates(directory=str(TEMPLATES_DIR)) - - -@app.on_event("startup") -async def on_startup(): - logger.info("Application startup begin.") - _start_scheduler() - atexit.register(_stop_scheduler) - logger.info("Application startup complete.") - - -@app.on_event("shutdown") -async def on_shutdown(): - _stop_scheduler() - - -@app.get("/", response_class=HTMLResponse) -async def dashboard(request: Request): - session = _session_from_request(request) - if not session: - return 
RedirectResponse(url="/login", status_code=303) - if session.get("role") == "admin": - return RedirectResponse(url="/admin", status_code=303) - - username = session.get("username") - runtime = _get_runtime(username) - return templates.TemplateResponse( - "dashboard.html", - { - "request": request, - "default_time": runtime.schedule_time(), - "username": username, - }, - ) - - -@app.get("/login", response_class=HTMLResponse) -async def login_page(request: Request): - session = _session_from_request(request) - if session: - if session.get("role") == "admin": - return RedirectResponse(url="/admin", status_code=303) - return RedirectResponse(url="/", status_code=303) - return templates.TemplateResponse("login.html", {"request": request}) - - -@app.get("/register", response_class=HTMLResponse) -async def register_page(request: Request): - session = _session_from_request(request) - if session: - if session.get("role") == "admin": - return RedirectResponse(url="/admin", status_code=303) - return RedirectResponse(url="/", status_code=303) - return templates.TemplateResponse("register.html", {"request": request}) - - -@app.get("/admin", response_class=HTMLResponse) -async def admin_page(request: Request): - session = _session_from_request(request) - if not session or session.get("role") != "admin": - return templates.TemplateResponse( - "admin_login.html", - { - "request": request, - "password_missing": not bool(os.getenv("PASSWORD")), - }, - ) - return templates.TemplateResponse("admin.html", {"request": request}) - - -@app.post("/api/login") -async def api_login(payload: UserLoginPayload): - username = payload.username.strip() - if not username: - return JSONResponse(status_code=400, content={"ok": False, "message": "用户名不能为空。"}) - - users_map = _load_users_meta() - user = users_map.get(username) - if not user: - return JSONResponse(status_code=401, content={"ok": False, "message": "用户名或密码错误。"}) - - if not _verify_password(payload.password, user.get("password_salt", ""), 
user.get("password_hash", "")): - return JSONResponse(status_code=401, content={"ok": False, "message": "用户名或密码错误。"}) - - token = secrets.token_urlsafe(32) - AUTH_SESSIONS[token] = {"role": "user", "username": username} - - response = JSONResponse({"ok": True, "message": "登录成功。"}) - response.set_cookie( - key=SESSION_COOKIE_NAME, - value=token, - httponly=True, - samesite="lax", - max_age=7 * 24 * 3600, - ) - return response - - -@app.post("/api/admin/login") -async def api_admin_login(payload: AdminLoginPayload): - expected_password = os.getenv("PASSWORD") - if not expected_password: - return JSONResponse( - status_code=500, - content={"ok": False, "message": "服务端未配置 PASSWORD 环境变量。"}, - ) - - if payload.password != expected_password: - return JSONResponse(status_code=401, content={"ok": False, "message": "密码错误。"}) - - token = secrets.token_urlsafe(32) - AUTH_SESSIONS[token] = {"role": "admin", "username": "admin"} - response = JSONResponse({"ok": True, "message": "登录成功。"}) - response.set_cookie( - key=SESSION_COOKIE_NAME, - value=token, - httponly=True, - samesite="lax", - max_age=7 * 24 * 3600, - ) - return response - - -@app.post("/api/register") -async def api_register(password: str = Form(...), users_file: UploadFile = File(...)): - if len(password.strip()) < 4: - return JSONResponse(status_code=400, content={"ok": False, "message": "密码至少 4 位。"}) - - if not users_file.filename.lower().endswith(".json"): - return JSONResponse(status_code=400, content={"ok": False, "message": "请上传 usersData.json 文件。"}) - - try: - raw = await users_file.read() - users_data, username, unique_id = _validate_and_normalize_users_data(raw) - except Exception as exc: - return JSONResponse(status_code=400, content={"ok": False, "message": str(exc)}) - - users_map = _load_users_meta() - if username in users_map: - return JSONResponse(status_code=409, content={"ok": False, "message": f"用户名 {username} 已注册。"}) - - for existing in users_map.values(): - if str(existing.get("unique_id", 
"")).strip() == unique_id: - return JSONResponse(status_code=409, content={"ok": False, "message": f"unique_id {unique_id} 已注册。"}) - - default_config = _get_default_user_config() - default_config.setdefault("scheduler", {}) - default_config["scheduler"].setdefault("enabled", True) - default_config["scheduler"].setdefault("timezone", DEFAULT_TIMEZONE) - default_config["scheduler"].setdefault("hour", 9) - default_config["scheduler"].setdefault("minute", 0) - default_config["scheduler"].setdefault("runOnStartup", False) - - hash_data = _hash_password(password.strip()) - created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - try: - _create_user_record( - username=username, - unique_id=unique_id, - password_hash=hash_data["hash"], - password_salt=hash_data["salt"], - created_at=created_at, - config_payload=default_config, - users_data_payload=users_data, - ) - except pymysql.err.IntegrityError: - return JSONResponse(status_code=409, content={"ok": False, "message": "用户名或 unique_id 已注册。"}) - - _schedule_user_job(username) - _get_runtime(username).add_log("用户已注册并完成定时任务初始化") - - return { - "ok": True, - "message": "注册成功,请使用用户名和密码登录。", - "username": username, - } - - -@app.post("/api/logout") -async def api_logout(request: Request): - token = request.cookies.get(SESSION_COOKIE_NAME) - if token: - AUTH_SESSIONS.pop(token, None) - response = JSONResponse({"ok": True}) - response.delete_cookie(SESSION_COOKIE_NAME) - return response - - -@app.get("/api/status") -async def api_status(request: Request): - session = _require_user_session(request) - username = session["username"] - runtime = _get_runtime(username) - users_data = _load_user_users_data(username) - return { - "ok": True, - "runtime": runtime.snapshot( - account_count=len(users_data), - target_count=_count_targets(users_data), - ), - "history": runtime.history_rows(), - } - - -@app.get("/api/logs") -async def api_logs(request: Request, limit: int = MAX_LOG_LINES): - session = _require_user_session(request) - 
username = session["username"] - runtime = _get_runtime(username) - limit = min(max(100, limit), 3000) - return {"ok": True, "logs": runtime.recent_logs(limit=limit)} - - -@app.post("/api/run") -async def api_run(request: Request): - session = _require_user_session(request) - username = session["username"] - runtime = _get_runtime(username) - - if runtime.is_running: - return JSONResponse( - status_code=409, - content={"ok": False, "message": "已有任务正在执行,请稍后再试。"}, - ) - - _start_background_run(username, "manual") - return {"ok": True, "message": "任务已开始执行。"} - - -@app.post("/api/schedule") -async def api_schedule(request: Request, payload: SchedulePayload): - session = _require_user_session(request) - username = session["username"] - - try: - hour, minute = _parse_time_string(payload.time) - except Exception as exc: - return JSONResponse(status_code=400, content={"ok": False, "message": str(exc)}) - - cfg = _load_user_config(username) - scheduler_cfg = cfg.setdefault("scheduler", {}) - scheduler_cfg["enabled"] = True - scheduler_cfg["hour"] = hour - scheduler_cfg["minute"] = minute - scheduler_cfg["timezone"] = str(scheduler_cfg.get("timezone", DEFAULT_TIMEZONE)) - scheduler_cfg["runOnStartup"] = bool(scheduler_cfg.get("runOnStartup", False)) - _save_user_config(username, cfg) - - _schedule_user_job(username) - runtime = _get_runtime(username) - return { - "ok": True, - "message": f"定时任务已更新为每天 {hour:02d}:{minute:02d}。", - "time": f"{hour:02d}:{minute:02d}", - "next_run": runtime.snapshot(0, 0)["next_run"], - } - - -@app.get("/api/editor/state") -async def api_editor_state(request: Request): - session = _require_user_session(request) - username = session["username"] - return {"ok": True, **_build_editor_state(username)} - - -@app.post("/api/editor/message") -async def api_editor_message(request: Request, payload: MessageTemplatePayload): - session = _require_user_session(request) - username = session["username"] - - message = payload.message.strip() - if not message: - 
return JSONResponse(status_code=400, content={"ok": False, "message": "消息内容不能为空。"}) - if len(message) > MAX_TEMPLATE_LENGTH: - return JSONResponse( - status_code=400, - content={"ok": False, "message": f"消息内容过长,最多 {MAX_TEMPLATE_LENGTH} 字符。"}, - ) - - cfg = _load_user_config(username) - cfg["messageTemplate"] = message - _save_user_config(username, cfg) - _get_runtime(username).add_log("消息模板已更新") - return {"ok": True, "message": "消息模板已保存。"} - - -@app.post("/api/editor/targets") -async def api_editor_targets(request: Request, payload: UserTargetsPayload): - session = _require_user_session(request) - username = session["username"] - - users_data = _load_user_users_data(username) - updates = {item.unique_id: _sanitize_targets(item.targets) for item in payload.users} - - updated = 0 - for user in users_data: - uid = str(user.get("unique_id", "")) - if uid in updates: - user["targets"] = updates[uid] - updated += 1 - - _save_user_users_data(username, users_data) - _get_runtime(username).add_log(f"目标好友已更新,涉及账号数:{updated}") - return {"ok": True, "message": f"目标好友已保存({updated} 个账号)。"} - - -@app.get("/api/admin/overview") -async def api_admin_overview(request: Request): - _require_admin_session(request) - try: - users_map = _load_users_meta() - if not scheduler_bootstrapped: - _sync_user_jobs_from_meta(users_map, run_startup_tasks=False) - except Exception as exc: - logger.warning("Admin overview failed to reach database. 
error=%s", exc) - return { - "ok": True, - "users": [], - "task_count": 0, - "db_status": _build_db_status_payload(), - "message": f"无法连接 SQL 服务器:{exc}", - } - - rows = [] - for username, meta in sorted(users_map.items(), key=lambda x: x[0]): - try: - cfg = _load_user_config(username) - users_data = _load_user_users_data(username) - except Exception as exc: - rows.append( - { - "username": username, - "unique_id": meta.get("unique_id", ""), - "created_at": meta.get("created_at", "-"), - "error": str(exc), - } - ) - continue - - scheduler_cfg = cfg.get("scheduler", {}) - runtime = _get_runtime(username) - runtime_snapshot = runtime.snapshot( - account_count=len(users_data), - target_count=_count_targets(users_data), - ) - - receivers = [] - for item in users_data: - receivers.extend(item.get("targets", [])) - - rows.append( - { - "username": username, - "unique_id": meta.get("unique_id", ""), - "created_at": meta.get("created_at", "-"), - "scheduler_enabled": bool(scheduler_cfg.get("enabled", True)), - "schedule_time": f"{int(scheduler_cfg.get('hour', 9)):02d}:{int(scheduler_cfg.get('minute', 0)):02d}", - "schedule_timezone": str(scheduler_cfg.get("timezone", DEFAULT_TIMEZONE)), - "message_template": str(cfg.get("messageTemplate", "")), - "targets": receivers, - "target_count": len(receivers), - "next_run": runtime_snapshot.get("next_run", "-"), - "last_status": runtime_snapshot.get("last_status", "-"), - "last_start": runtime_snapshot.get("last_start", "-"), - "is_running": runtime_snapshot.get("is_running", False), - "can_retry": bool( - not runtime_snapshot.get("is_running", False) - and runtime_snapshot.get("last_status") == "失败" - ), - } - ) - - return { - "ok": True, - "users": rows, - "task_count": len(rows), - "db_status": _build_db_status_payload(), - } - - -@app.get("/api/admin/tasks/{username}") -async def api_admin_task_detail(request: Request, username: str, log_limit: int = MAX_LOG_LINES): - _require_admin_session(request) - username = username.strip() 
- user_meta = _get_user_meta_or_404(username) - - try: - cfg = _load_user_config(username) - users_data = _load_user_users_data(username) - except Exception as exc: - return JSONResponse( - status_code=500, - content={"ok": False, "message": f"加载任务详情失败:{exc}"}, - ) - - scheduler_cfg = cfg.get("scheduler", {}) - runtime = _get_runtime(username) - target_count = _count_targets(users_data) - snapshot = runtime.snapshot(account_count=len(users_data), target_count=target_count) - - accounts = [] - all_targets = [] - for item in users_data: - targets = _sanitize_targets(item.get("targets", [])) - all_targets.extend(targets) - accounts.append( - { - "username": str(item.get("username", "未知用户")), - "unique_id": str(item.get("unique_id", "")), - "target_count": len(targets), - "targets": targets, - "cookie_count": len(item.get("cookies", [])) if isinstance(item.get("cookies", []), list) else 0, - } - ) - - log_limit = min(max(100, log_limit), 3000) - return { - "ok": True, - "task": { - "username": username, - "unique_id": user_meta.get("unique_id", ""), - "created_at": user_meta.get("created_at", "-"), - "scheduler_enabled": bool(scheduler_cfg.get("enabled", True)), - "schedule_time": f"{int(scheduler_cfg.get('hour', 9)):02d}:{int(scheduler_cfg.get('minute', 0)):02d}", - "schedule_timezone": str(scheduler_cfg.get("timezone", DEFAULT_TIMEZONE)), - "message_template": str(cfg.get("messageTemplate", "")), - "targets": all_targets, - "target_count": len(all_targets), - "runtime": snapshot, - "can_retry": bool(not snapshot.get("is_running", False) and snapshot.get("last_status") == "失败"), - "history": runtime.history_rows(), - "logs": runtime.recent_logs(limit=log_limit), - "config": { - "multiTask": bool(cfg.get("multiTask", True)), - "taskCount": int(cfg.get("taskCount", 1) or 1), - "hitokotoTypes": cfg.get("hitokotoTypes", []), - "proxyAddress": str(cfg.get("proxyAddress", "")), - }, - "accounts": accounts, - }, - } - - -@app.post("/api/admin/tasks/{username}/retry") -async 
def api_admin_retry_task(request: Request, username: str): - _require_admin_session(request) - username = username.strip() - _get_user_meta_or_404(username) - - runtime = _get_runtime(username) - snapshot = runtime.snapshot(account_count=0, target_count=0) - if snapshot.get("is_running"): - return JSONResponse(status_code=409, content={"ok": False, "message": "任务正在运行中,请稍后再试。"}) - if snapshot.get("last_status") != "失败": - return JSONResponse(status_code=409, content={"ok": False, "message": "当前任务不是失败状态,无需重试。"}) - - runtime.add_log("管理员手动触发失败任务重试") - _start_background_run(username, "admin_retry") - return {"ok": True, "message": f"已开始重试 {username} 的失败任务。"} - - -@app.post("/api/admin/tasks/retry-failed") -async def api_admin_retry_all_failed_tasks(request: Request): - _require_admin_session(request) - - try: - usernames = _retry_failed_tasks_once("admin_bulk_retry", raise_on_db_error=True) - except Exception as exc: - return JSONResponse(status_code=503, content={"ok": False, "message": f"???? SQL ????{exc}"}) - - if not usernames: - return {"ok": True, "message": "?????????????", "count": 0, "usernames": []} - - return { - "ok": True, - "message": f"??????? 
{len(usernames)} ??????", - "count": len(usernames), - "usernames": usernames, - } - - -@app.post("/api/admin/tasks/{username}/delete") -async def api_admin_delete_task(request: Request, username: str): - _require_admin_session(request) - username = username.strip() - _get_user_meta_or_404(username) - - cfg = _load_user_config(username) - scheduler_cfg = cfg.setdefault("scheduler", {}) - scheduler_cfg["enabled"] = False - _save_user_config(username, cfg) - - _remove_user_schedule_job(username) - runtime = _get_runtime(username) - runtime.update_next_run(None) - runtime.add_log("管理员已删除(禁用)该用户定时任务") - - return {"ok": True, "message": f"已删除用户 {username} 的定时任务。"} - - -@app.delete("/api/admin/users/{username}") -async def api_admin_delete_user(request: Request, username: str): - _require_admin_session(request) - username = username.strip() - - _get_user_meta_or_404(username) - - _remove_user_schedule_job(username) - _delete_user_record(username) - _delete_runtime(username) - - return {"ok": True, "message": f"用户 {username} 已删除。"} - - -@app.get("/health") -async def health(): - return {"ok": True, "status": "alive"} - - -def run_server(): - port = int(os.getenv("PORT", "7860")) - uvicorn.run("app:app", host="0.0.0.0", port=port, workers=1) - - -if __name__ == "__main__": - run_server() +import asyncio +import atexit +import hashlib +import json +import logging +import os +import secrets +import threading +import traceback +import time +from collections import deque +from datetime import datetime +from pathlib import Path +from typing import Any, Optional +from urllib.parse import parse_qsl, unquote, urlsplit + +import pymysql +import uvicorn +from apscheduler.schedulers.background import BackgroundScheduler +from apscheduler.triggers.cron import CronTrigger +from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile, status +from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse +from fastapi.staticfiles import StaticFiles +from 
fastapi.templating import Jinja2Templates +from pydantic import BaseModel, Field + +from core.tasks import runTasks +from utils.logger import setup_logger + + +logger = setup_logger(level=logging.DEBUG) + +BASE_DIR = Path(__file__).resolve().parent +TEMPLATES_DIR = BASE_DIR / "templates" +STATIC_DIR = BASE_DIR / "static" +ROOT_CONFIG_PATH = BASE_DIR / "config.json" +LEGACY_DATA_DIR = BASE_DIR / "data" +LEGACY_USERS_META_PATH = LEGACY_DATA_DIR / "users.json" +MYSQL_DSN_TEMPLATE = "mysql://SQL_PASSWORD@mysql-2bace9cd-cacode.i.aivencloud.com:21260/defaultdb?ssl-mode=REQUIRED" +MYSQL_DSN_ENV = "MYSQL_DSN_TEMPLATE" +MYSQL_PASSWORD_ENV = "SQL_PASSWORD" +MYSQL_USER_ENV = "MYSQL_USER" +MYSQL_CA_CERT_ENV = "MYSQL_CA_CERT_PATH" +MYSQL_DEFAULT_USER = "avnadmin" +USERS_TABLE = "app_users" +SESSION_COOKIE_NAME = "sparkflow_auth" +DEFAULT_TIMEZONE = "Asia/Shanghai" +MAX_LOG_LINES = 1200 +MAX_TEMPLATE_LENGTH = 2000 +PASSWORD_ITERATIONS = 210000 +FAILED_RETRY_JOB_ID = "_system_retry_failed_tasks" +FAILED_RETRY_INTERVAL_HOURS = 1 + +DEFAULT_USER_CONFIG = { + "multiTask": True, + "taskCount": 5, + "proxyAddress": "", + "messageTemplate": "[续火花]", + "hitokotoTypes": ["文学", "影视", "诗词", "哲学"], + "scheduler": { + "enabled": True, + "timezone": DEFAULT_TIMEZONE, + "hour": 9, + "minute": 0, + "runOnStartup": False, + }, +} + +AUTH_SESSIONS: dict[str, dict[str, str]] = {} +db_init_lock = threading.Lock() +scheduler_lock = threading.Lock() +runtime_map_lock = threading.Lock() +db_initialized = False +db_status_lock = threading.Lock() +db_status = { + "connected": None, + "last_checked_at": None, + "last_ok_at": None, + "last_error": "", +} +scheduler_bootstrapped = False +scheduler_bootstrap_lock = threading.Lock() +scheduler_bootstrap_running = False + + +class UserRuntimeState: + def __init__(self, username: str): + self.username = username + self._run_lock = threading.Lock() + self._state_lock = threading.Lock() + self.is_running = False + self.last_status = "未开始" + self.last_error = "" 
+ self.last_trigger = "-" + self.last_start = None + self.last_end = None + self.next_run = None + self.schedule_hour = 9 + self.schedule_minute = 0 + self.schedule_timezone = DEFAULT_TIMEZONE + self.history = deque(maxlen=50) + self.logs = deque(maxlen=2000) + + def _format_ts(self, value: Optional[datetime]): + if not value: + return "-" + return value.strftime("%Y-%m-%d %H:%M:%S") + + def schedule_time(self): + return f"{self.schedule_hour:02d}:{self.schedule_minute:02d}" + + def _set_running(self, value: bool): + with self._state_lock: + self.is_running = value + + def add_log(self, message: str): + ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + with self._state_lock: + self.logs.append(f"{ts} [{self.username}] {message}") + + def update_schedule(self, hour: int, minute: int, timezone: str): + with self._state_lock: + self.schedule_hour = hour + self.schedule_minute = minute + self.schedule_timezone = timezone + + def update_next_run(self, next_run): + with self._state_lock: + self.next_run = next_run + + def snapshot(self, account_count: int, target_count: int): + with self._state_lock: + return { + "is_running": self.is_running, + "last_status": self.last_status, + "last_error": self.last_error, + "last_trigger": self.last_trigger, + "last_start": self._format_ts(self.last_start), + "last_end": self._format_ts(self.last_end), + "next_run": self._format_ts(self.next_run), + "account_count": account_count, + "target_count": target_count, + "schedule_time": self.schedule_time(), + "schedule_timezone": self.schedule_timezone, + } + + def history_rows(self): + with self._state_lock: + return list(self.history)[::-1] + + def recent_logs(self, limit=MAX_LOG_LINES): + with self._state_lock: + lines = list(self.logs)[-max(1, limit):] + return "\n".join(lines) if lines else "暂无日志。" + + def run_once(self, trigger: str): + if not self._run_lock.acquire(blocking=False): + self.add_log(f"任务已在运行中,忽略触发:{trigger}") + return False, "已有任务在运行,本次触发已跳过。" + + 
self._set_running(True) + with self._state_lock: + self.last_trigger = trigger + self.last_start = datetime.now() + self.last_end = None + self.last_error = "" + self.last_status = "运行中" + self.add_log(f"任务开始执行,触发方式:{trigger}") + + ok = True + message = "任务执行完成。" + try: + asyncio.run(_run_user_tasks(self.username)) + with self._state_lock: + self.last_status = "成功" + except Exception as exc: + ok = False + message = f"任务执行失败:{exc}" + with self._state_lock: + self.last_status = "失败" + self.last_error = repr(exc) + self.add_log(f"任务失败:{exc}") + logger.error("Task failed. user=%s trigger=%s error=%s", self.username, trigger, exc) + logger.debug("Task traceback:\n%s", traceback.format_exc()) + finally: + end_at = datetime.now() + with self._state_lock: + self.last_end = end_at + duration = (self.last_end - self.last_start).total_seconds() + self.history.append( + { + "trigger": trigger, + "start": self._format_ts(self.last_start), + "end": self._format_ts(self.last_end), + "status": self.last_status, + "duration": f"{duration:.2f}s", + "message": self.last_error or "OK", + } + ) + current_status = self.last_status + self.add_log(f"任务结束,状态={current_status},耗时={duration:.2f}s") + self._set_running(False) + self._run_lock.release() + return ok, message + + +runtime_map: dict[str, UserRuntimeState] = {} +scheduler = None + + +class UserLoginPayload(BaseModel): + username: str + password: str + + +class AdminLoginPayload(BaseModel): + password: str + + +class SchedulePayload(BaseModel): + time: str + + +class MessageTemplatePayload(BaseModel): + message: str + + +class UserTargetsItem(BaseModel): + unique_id: str + targets: list[str] = Field(default_factory=list) + + +class UserTargetsPayload(BaseModel): + users: list[UserTargetsItem] + + +def _ensure_data_layout(): + global db_initialized + if db_initialized: + return + + logger.info("DB layout ensure begin.") + with db_init_lock: + if db_initialized: + logger.info("DB layout already initialized by another worker.") + 
return + logger.info("DB schema initialization begin.") + _init_db_schema() + db_initialized = True + logger.info("DB schema initialization complete.") + + try: + logger.info("Legacy migration stage begin.") + _migrate_legacy_file_data_if_needed() + logger.info("Legacy migration stage complete.") + except Exception as exc: + logger.warning("Legacy data migration skipped due to error: %s", exc) + + +def _hash_password(password: str, salt_hex: Optional[str] = None): + salt = bytes.fromhex(salt_hex) if salt_hex else secrets.token_bytes(16) + digest = hashlib.pbkdf2_hmac( + "sha256", + password.encode("utf-8"), + salt, + PASSWORD_ITERATIONS, + ) + return { + "salt": salt.hex(), + "hash": digest.hex(), + } + + +def _verify_password(password: str, salt_hex: str, expected_hash: str): + data = _hash_password(password, salt_hex=salt_hex) + return secrets.compare_digest(data["hash"], expected_hash) + + +def _deep_copy_json(value): + return json.loads(json.dumps(value, ensure_ascii=False)) + + +def _merge_config_with_defaults(raw_cfg: Any): + base = _deep_copy_json(DEFAULT_USER_CONFIG) + if not isinstance(raw_cfg, dict): + return base + + merged = _deep_copy_json(base) + merged.update(raw_cfg) + base_scheduler = base.get("scheduler", {}) + merged_scheduler = raw_cfg.get("scheduler", {}) + if isinstance(merged_scheduler, dict): + scheduler = _deep_copy_json(base_scheduler) + scheduler.update(merged_scheduler) + merged["scheduler"] = scheduler + else: + merged["scheduler"] = _deep_copy_json(base_scheduler) + return merged + + +def _format_common_ts(value: Optional[datetime]): + if not value: + return "-" + return value.strftime("%Y-%m-%d %H:%M:%S") + + +def _update_db_status(connected: bool, error: Optional[Exception] = None): + now = datetime.now() + with db_status_lock: + db_status["connected"] = connected + db_status["last_checked_at"] = now + if connected: + db_status["last_ok_at"] = now + db_status["last_error"] = "" + else: + db_status["last_error"] = str(error or 
"数据库连接失败") + + +def _build_db_status_payload(): + with db_status_lock: + connected = db_status.get("connected") + return { + "connected": connected, + "last_checked_at": _format_common_ts(db_status.get("last_checked_at")), + "last_ok_at": _format_common_ts(db_status.get("last_ok_at")), + "last_error": str(db_status.get("last_error") or ""), + } + + +def _resolve_mysql_dsn(): + raw = os.getenv(MYSQL_DSN_ENV, MYSQL_DSN_TEMPLATE).strip() + if "SQL_PASSWORD" in raw: + secret = os.getenv(MYSQL_PASSWORD_ENV, "").strip() + if not secret: + raise RuntimeError(f"环境变量 {MYSQL_PASSWORD_ENV} 未设置,无法连接 MySQL。") + raw = raw.replace("SQL_PASSWORD", secret, 1) + return raw + + +def _build_mysql_conn_kwargs(): + dsn = _resolve_mysql_dsn() + parsed = urlsplit(dsn) + if parsed.scheme not in ("mysql", "mysql+pymysql"): + raise RuntimeError(f"不支持的 MySQL DSN 协议:{parsed.scheme}") + + host = parsed.hostname + if not host: + raise RuntimeError("MySQL DSN 缺少主机地址。") + + user = unquote(parsed.username or "") + password = unquote(parsed.password) if parsed.password is not None else None + if user and password is None: + password = user + user = os.getenv(MYSQL_USER_ENV, MYSQL_DEFAULT_USER).strip() or MYSQL_DEFAULT_USER + if not user: + user = os.getenv(MYSQL_USER_ENV, MYSQL_DEFAULT_USER).strip() or MYSQL_DEFAULT_USER + if not password: + password = os.getenv(MYSQL_PASSWORD_ENV, "").strip() + if not password: + raise RuntimeError("MySQL 密码为空,请检查 SQL_PASSWORD 环境变量。") + + db_name = parsed.path.lstrip("/") or "defaultdb" + query = {k.lower(): v for k, v in parse_qsl(parsed.query, keep_blank_values=True)} + ssl_mode = str(query.get("ssl-mode", query.get("ssl_mode", ""))).upper() + + kwargs = { + "host": host, + "port": parsed.port or 3306, + "user": user, + "password": password, + "database": db_name, + "charset": "utf8mb4", + "autocommit": True, + "connect_timeout": int(os.getenv("MYSQL_CONNECT_TIMEOUT", "4")), + "read_timeout": int(os.getenv("MYSQL_READ_TIMEOUT", "8")), + "write_timeout": 
int(os.getenv("MYSQL_WRITE_TIMEOUT", "8")), + "cursorclass": pymysql.cursors.DictCursor, + } + + if ssl_mode in {"REQUIRED", "VERIFY_CA", "VERIFY_IDENTITY"}: + ca_file = Path(os.getenv(MYSQL_CA_CERT_ENV, str(BASE_DIR / "camysql2.pem"))).resolve() + if not ca_file.exists(): + raise RuntimeError(f"MySQL CA 证书不存在:{ca_file}") + kwargs["ssl"] = {"ca": str(ca_file)} + + return kwargs + + +def _db_connect(): + kwargs = _build_mysql_conn_kwargs() + host = kwargs.get("host") + port = kwargs.get("port") + database = kwargs.get("database") + connect_timeout = kwargs.get("connect_timeout") + read_timeout = kwargs.get("read_timeout") + write_timeout = kwargs.get("write_timeout") + has_ssl = bool(kwargs.get("ssl")) + started_at = time.perf_counter() + + logger.info( + "MySQL connect begin. host=%s port=%s db=%s connect_timeout=%ss read_timeout=%ss write_timeout=%ss ssl=%s", + host, + port, + database, + connect_timeout, + read_timeout, + write_timeout, + has_ssl, + ) + try: + conn = pymysql.connect(**kwargs) + except Exception as exc: + elapsed = time.perf_counter() - started_at + logger.warning( + "MySQL connect failed after %.2fs. host=%s port=%s db=%s error=%s", + elapsed, + host, + port, + database, + exc, + ) + _update_db_status(False, exc) + raise + + elapsed = time.perf_counter() - started_at + logger.info("MySQL connect success. 
host=%s port=%s db=%s elapsed=%.2fs", host, port, database, elapsed) + _update_db_status(True) + return conn + + +def _db_query_all(query: str, params=()): + conn = _db_connect() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + return cursor.fetchall() + finally: + conn.close() + + +def _db_query_one(query: str, params=()): + conn = _db_connect() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + return cursor.fetchone() + finally: + conn.close() + + +def _db_execute(query: str, params=()): + conn = _db_connect() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + return cursor.rowcount + finally: + conn.close() + + +def _init_db_schema(): + logger.info("DB schema creation SQL begin.") + _db_execute( + f""" + CREATE TABLE IF NOT EXISTS `{USERS_TABLE}` ( + `username` VARCHAR(128) NOT NULL, + `unique_id` VARCHAR(255) NOT NULL, + `password_hash` VARCHAR(128) NOT NULL, + `password_salt` VARCHAR(64) NOT NULL, + `created_at` VARCHAR(32) NOT NULL, + `config_json` LONGTEXT NOT NULL, + `users_data_json` LONGTEXT NOT NULL, + `updated_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`username`), + UNIQUE KEY `uniq_unique_id` (`unique_id`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + ) + logger.info("DB schema creation SQL complete.") + + +def _legacy_load_json(path: Path, default): + if not path.exists(): + return default + with path.open("r", encoding="utf-8") as f: + return json.load(f) + + +def _migrate_legacy_file_data_if_needed(): + logger.info("Legacy migration check. 
path=%s", LEGACY_USERS_META_PATH) + if not LEGACY_USERS_META_PATH.exists(): + logger.info("Legacy migration skipped: users.json not found.") + return + + row = _db_query_one(f"SELECT COUNT(*) AS cnt FROM `{USERS_TABLE}`") + existing_count = int(row.get("cnt", 0)) if row else 0 + if existing_count > 0: + logger.info("Legacy migration skipped: database already has %s users.", existing_count) + return + + try: + raw = _legacy_load_json(LEGACY_USERS_META_PATH, {"users": []}) + except Exception as exc: + logger.warning("读取旧版 users.json 失败,跳过迁移:%s", exc) + return + + users = raw.get("users", []) if isinstance(raw, dict) else [] + if not users: + logger.info("Legacy migration skipped: legacy users list is empty.") + return + + logger.info("Legacy migration loaded %s legacy users.", len(users)) + migrated = 0 + for item in users: + username = str(item.get("username", "")).strip() + unique_id = str(item.get("unique_id", "")).strip() + password_hash = str(item.get("password_hash", "")).strip() + password_salt = str(item.get("password_salt", "")).strip() + created_at = str(item.get("created_at", "")).strip() or datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + if not (username and unique_id and password_hash and password_salt): + logger.warning("旧版用户数据不完整,跳过:%s", username or "") + continue + + cfg = _get_default_user_config() + users_data = [] + tenant_rel = str(item.get("tenant_dir", "")).strip() + if tenant_rel: + tenant_dir = (BASE_DIR / tenant_rel).resolve() + cfg_path = tenant_dir / "config.json" + users_data_path = tenant_dir / "usersData.json" + try: + cfg = _merge_config_with_defaults(_legacy_load_json(cfg_path, cfg)) + except Exception as exc: + logger.warning("读取旧版配置失败,使用默认配置。user=%s error=%s", username, exc) + try: + users_data = _legacy_load_json(users_data_path, []) + except Exception as exc: + logger.warning("读取旧版 usersData 失败。user=%s error=%s", username, exc) + + if not isinstance(users_data, list): + users_data = [] + + try: + _create_user_record( + 
username=username, + unique_id=unique_id, + password_hash=password_hash, + password_salt=password_salt, + created_at=created_at, + config_payload=cfg, + users_data_payload=users_data, + ) + migrated += 1 + except Exception as exc: + logger.warning("迁移用户失败。user=%s error=%s", username, exc) + + logger.info("Legacy migration finished. migrated=%s total=%s", migrated, len(users)) + if migrated > 0: + logger.info("已完成旧版文件数据迁移,共迁移 %s 个用户。", migrated) + + +def _load_users_meta(): + logger.info("Load users meta begin.") + _ensure_data_layout() + rows = _db_query_all( + f""" + SELECT username, unique_id, password_hash, password_salt, created_at + FROM `{USERS_TABLE}` + ORDER BY username ASC + """ + ) + logger.info("Load users meta complete. count=%s", len(rows)) + return {str(row["username"]): row for row in rows} + + +def _load_user_row(username: str): + _ensure_data_layout() + return _db_query_one( + f""" + SELECT username, unique_id, password_hash, password_salt, created_at, config_json, users_data_json + FROM `{USERS_TABLE}` + WHERE username=%s + """, + (username,), + ) + + +def _user_exists(username: str): + _ensure_data_layout() + row = _db_query_one( + f"SELECT 1 AS ok FROM `{USERS_TABLE}` WHERE username=%s", + (username,), + ) + return bool(row) + + +def _create_user_record( + *, + username: str, + unique_id: str, + password_hash: str, + password_salt: str, + created_at: str, + config_payload: dict[str, Any], + users_data_payload: list[dict[str, Any]], +): + _db_execute( + f""" + INSERT INTO `{USERS_TABLE}` + (username, unique_id, password_hash, password_salt, created_at, config_json, users_data_json) + VALUES (%s, %s, %s, %s, %s, %s, %s) + """, + ( + username, + unique_id, + password_hash, + password_salt, + created_at, + json.dumps(config_payload, ensure_ascii=False), + json.dumps(users_data_payload, ensure_ascii=False), + ), + ) + + +def _delete_user_record(username: str): + _ensure_data_layout() + return _db_execute(f"DELETE FROM `{USERS_TABLE}` WHERE 
username=%s", (username,)) + + +def _get_user_meta_or_404(username: str): + users_map = _load_users_meta() + user = users_map.get(username) + if not user: + raise HTTPException(status_code=404, detail="用户不存在") + return user + + +def _get_default_user_config(): + if ROOT_CONFIG_PATH.exists(): + try: + with ROOT_CONFIG_PATH.open("r", encoding="utf-8") as f: + root_cfg = json.load(f) + return _merge_config_with_defaults(root_cfg) + except Exception: + logger.warning("Failed to read root config.json. fallback to DEFAULT_USER_CONFIG") + return _deep_copy_json(DEFAULT_USER_CONFIG) + + +def _load_user_config(username: str): + row = _load_user_row(username) + if not row: + raise FileNotFoundError(f"用户 {username} 不存在") + try: + payload = json.loads(row.get("config_json", "{}")) + except Exception as exc: + raise ValueError(f"用户 {username} 的配置数据损坏:{exc}") + return _merge_config_with_defaults(payload) + + +def _save_user_config(username: str, cfg: dict): + normalized = _merge_config_with_defaults(cfg) + changed = _db_execute( + f"UPDATE `{USERS_TABLE}` SET config_json=%s WHERE username=%s", + (json.dumps(normalized, ensure_ascii=False), username), + ) + if changed == 0 and not _user_exists(username): + raise FileNotFoundError(f"用户 {username} 不存在") + + +def _load_user_users_data(username: str): + row = _load_user_row(username) + if not row: + raise FileNotFoundError(f"用户 {username} 不存在") + try: + data = json.loads(row.get("users_data_json", "[]")) + except Exception as exc: + raise ValueError(f"用户 {username} 的 usersData 数据损坏:{exc}") + if not isinstance(data, list): + raise ValueError("usersData.json 必须是数组") + return data + + +def _save_user_users_data(username: str, users_data: list): + changed = _db_execute( + f"UPDATE `{USERS_TABLE}` SET users_data_json=%s WHERE username=%s", + (json.dumps(users_data, ensure_ascii=False), username), + ) + if changed == 0 and not _user_exists(username): + raise FileNotFoundError(f"用户 {username} 不存在") + + +def _sanitize_targets(values): + 
cleaned = [] + seen = set() + for value in values or []: + text = str(value).strip() + if not text or text in seen: + continue + seen.add(text) + cleaned.append(text) + return cleaned + + +def _validate_and_normalize_users_data(raw_bytes: bytes): + try: + payload = json.loads(raw_bytes.decode("utf-8")) + except Exception as exc: + raise ValueError(f"上传文件不是合法 JSON:{exc}") + + if not isinstance(payload, list) or not payload: + raise ValueError("usersData.json 必须是非空数组") + + normalized = [] + for idx, item in enumerate(payload): + if not isinstance(item, dict): + raise ValueError(f"第 {idx + 1} 条用户数据格式错误(必须是对象)") + + unique_id = str(item.get("unique_id", "")).strip() + username = str(item.get("username", "")).strip() + cookies = item.get("cookies", []) + targets = item.get("targets", []) + + if not unique_id: + raise ValueError(f"第 {idx + 1} 条缺少 unique_id") + if not username: + raise ValueError(f"第 {idx + 1} 条缺少 username") + if not isinstance(cookies, list) or not cookies: + raise ValueError(f"第 {idx + 1} 条 cookies 不能为空且必须是数组") + if not isinstance(targets, list): + raise ValueError(f"第 {idx + 1} 条 targets 必须是数组") + + normalized.append( + { + "unique_id": unique_id, + "username": username, + "cookies": cookies, + "targets": _sanitize_targets(targets), + } + ) + + primary_username = normalized[0]["username"] + primary_unique_id = normalized[0]["unique_id"] + return normalized, primary_username, primary_unique_id + + +def _count_targets(users_data: list): + return sum(len(user.get("targets", [])) for user in users_data) + + +def _get_runtime(username: str): + with runtime_map_lock: + runtime = runtime_map.get(username) + if runtime is None: + runtime = UserRuntimeState(username=username) + runtime_map[username] = runtime + return runtime + + +def _delete_runtime(username: str): + with runtime_map_lock: + runtime_map.pop(username, None) + + +def _session_from_request(request: Request): + token = request.cookies.get(SESSION_COOKIE_NAME) + if not token: + return None + return 
AUTH_SESSIONS.get(token) + + +def _require_user_session(request: Request): + session = _session_from_request(request) + if not session or session.get("role") != "user": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="未登录或登录已失效", + ) + return session + + +def _require_admin_session(request: Request): + session = _session_from_request(request) + if not session or session.get("role") != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="未登录或登录已失效", + ) + return session + + +def _parse_time_string(value: str): + parts = value.strip().split(":") + if len(parts) not in (2, 3): + raise ValueError("时间格式错误,必须是 HH:MM") + hour = int(parts[0]) + minute = int(parts[1]) + if hour < 0 or hour > 23 or minute < 0 or minute > 59: + raise ValueError("时间范围错误,小时 0-23,分钟 0-59") + return hour, minute + + +def _build_editor_state(username: str): + cfg = _load_user_config(username) + users = _load_user_users_data(username) + return { + "message_template": str(cfg.get("messageTemplate", "")), + "users": [ + { + "unique_id": str(user.get("unique_id", "")), + "username": str(user.get("username", "未知用户")), + "targets": _sanitize_targets(user.get("targets", [])), + } + for user in users + ], + } + + +def _scheduler_job_id(username: str): + return f"daily_task::{username}" + + +def _run_scheduled_once(username: str): + runtime = _get_runtime(username) + runtime.run_once("schedule") + if scheduler: + job = scheduler.get_job(_scheduler_job_id(username)) + runtime.update_next_run(job.next_run_time if job else None) + + +async def _run_user_tasks(username: str): + cfg = _load_user_config(username) + users_data = _load_user_users_data(username) + await runTasks(config=cfg, userData=users_data) + + +def _sync_user_jobs_from_meta(users_map: dict[str, Any], run_startup_tasks: bool = False): + global scheduler_bootstrapped + + logger.info("Sync user jobs begin. 
count=%s run_startup_tasks=%s", len(users_map), run_startup_tasks) + for username in users_map.keys(): + logger.info("Sync user job. username=%s", username) + _schedule_user_job(username) + if run_startup_tasks: + cfg = _load_user_config(username) + run_on_startup = bool(cfg.get("scheduler", {}).get("runOnStartup", False)) + logger.info("Startup run flag loaded. username=%s run_on_startup=%s", username, run_on_startup) + if run_on_startup: + logger.info("Trigger startup run. username=%s", username) + _start_background_run(username, "startup") + + scheduler_bootstrapped = True + logger.info("Sync user jobs complete. count=%s", len(users_map)) + + +def _start_scheduler_bootstrap(run_startup_tasks: bool): + global scheduler_bootstrapped, scheduler_bootstrap_running + + with scheduler_bootstrap_lock: + if scheduler_bootstrap_running: + logger.info("Scheduler bootstrap already running; skip duplicate start.") + return False + scheduler_bootstrap_running = True + + def _worker(): + global scheduler_bootstrapped, scheduler_bootstrap_running + try: + logger.info("Scheduler bootstrap started. run_startup_tasks=%s", run_startup_tasks) + + logger.info("Bootstrap stage begin: ensure_data_layout") + _ensure_data_layout() + logger.info("Bootstrap stage complete: ensure_data_layout") + + logger.info("Bootstrap stage begin: load_users_meta") + users_map = _load_users_meta() + logger.info("Bootstrap stage complete: load_users_meta count=%s", len(users_map)) + + logger.info("Bootstrap stage begin: sync_user_jobs") + _sync_user_jobs_from_meta(users_map, run_startup_tasks=run_startup_tasks) + logger.info("Bootstrap stage complete: sync_user_jobs count=%s", len(users_map)) + + logger.info("Scheduler bootstrap completed. users=%s", len(users_map)) + except Exception as exc: + scheduler_bootstrapped = False + logger.warning("Scheduler bootstrap skipped, database unavailable. 
error=%s", exc) + finally: + with scheduler_bootstrap_lock: + scheduler_bootstrap_running = False + + thread = threading.Thread(target=_worker, daemon=True, name="scheduler-bootstrap") + thread.start() + return True + + +def _retry_failed_tasks_once(trigger: str, *, raise_on_db_error: bool = False): + try: + users_map = _load_users_meta() + if not scheduler_bootstrapped: + _sync_user_jobs_from_meta(users_map, run_startup_tasks=False) + except Exception as exc: + logger.warning("Failed to load users for failed-task retry. error=%s", exc) + if raise_on_db_error: + raise + return [] + + triggered = [] + for username in users_map.keys(): + runtime = _get_runtime(username) + snapshot = runtime.snapshot(account_count=0, target_count=0) + if snapshot.get("is_running") or snapshot.get("last_status") != "失败": + continue + + try: + cfg = _load_user_config(username) + except Exception as exc: + runtime.add_log(f"自动重试前加载配置失败:{exc}") + continue + + if not bool(cfg.get("scheduler", {}).get("enabled", True)): + continue + + runtime.add_log(f"检测到失败任务,准备执行自动重试:{trigger}") + _start_background_run(username, trigger) + triggered.append(username) + + if triggered: + logger.info("Retried failed tasks for users: %s", ", ".join(triggered)) + return triggered + + +def _retry_failed_tasks_job(): + _retry_failed_tasks_once("hourly_retry") + + +def _schedule_user_job(username: str): + global scheduler + + cfg = _load_user_config(username) + scheduler_cfg = cfg.get("scheduler", {}) if isinstance(cfg, dict) else {} + enabled = bool(scheduler_cfg.get("enabled", True)) + timezone = str(scheduler_cfg.get("timezone", DEFAULT_TIMEZONE)) + hour = int(scheduler_cfg.get("hour", 9)) + minute = int(scheduler_cfg.get("minute", 0)) + + runtime = _get_runtime(username) + runtime.update_schedule(hour, minute, timezone) + + with scheduler_lock: + if scheduler is None: + scheduler = BackgroundScheduler(timezone=timezone) + scheduler.start() + + job_id = _scheduler_job_id(username) + if not enabled: + if 
scheduler.get_job(job_id): + scheduler.remove_job(job_id) + runtime.update_next_run(None) + runtime.add_log("定时任务已禁用") + return + + scheduler.add_job( + _run_scheduled_once, + args=[username], + trigger=CronTrigger(hour=hour, minute=minute, timezone=timezone), + id=job_id, + replace_existing=True, + max_instances=1, + coalesce=True, + ) + job = scheduler.get_job(job_id) + runtime.update_next_run(job.next_run_time if job else None) + runtime.add_log(f"定时任务更新为 {hour:02d}:{minute:02d} ({timezone})") + + +def _remove_user_schedule_job(username: str): + with scheduler_lock: + if scheduler is None: + return + job_id = _scheduler_job_id(username) + if scheduler.get_job(job_id): + scheduler.remove_job(job_id) + + +def _start_background_run(username: str, trigger: str): + runtime = _get_runtime(username) + + def _worker(): + runtime.run_once(trigger) + if scheduler: + job = scheduler.get_job(_scheduler_job_id(username)) + runtime.update_next_run(job.next_run_time if job else None) + + thread = threading.Thread(target=_worker, daemon=True) + thread.start() + return True + + +def _start_scheduler(): + global scheduler + with scheduler_lock: + if scheduler is None: + scheduler = BackgroundScheduler(timezone=DEFAULT_TIMEZONE) + scheduler.start() + scheduler.add_job( + _retry_failed_tasks_job, + trigger="interval", + hours=FAILED_RETRY_INTERVAL_HOURS, + id=FAILED_RETRY_JOB_ID, + replace_existing=True, + max_instances=1, + coalesce=True, + ) + + _start_scheduler_bootstrap(run_startup_tasks=True) + + +def _stop_scheduler(): + global scheduler, scheduler_bootstrapped, scheduler_bootstrap_running + with scheduler_lock: + if scheduler and scheduler.running: + scheduler.shutdown(wait=False) + logger.info("Scheduler stopped.") + scheduler = None + scheduler_bootstrapped = False + with scheduler_bootstrap_lock: + scheduler_bootstrap_running = False + + +app = FastAPI(title="DouYin Spark Flow Dashboard") +app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") 
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))


def _render_template(request: Request, name: str, context: Optional[dict[str, Any]] = None, **kwargs):
    """Render template ``name`` with ``request`` merged into the context.

    The keyword-style ``TemplateResponse(request=..., name=..., context=...)``
    call is tried first; if it raises ``TypeError`` the positional
    ``TemplateResponse(name, context)`` form is used instead.
    """
    payload = {"request": request}
    if context:
        payload.update(context)

    try:
        return templates.TemplateResponse(request=request, name=name, context=payload, **kwargs)
    except TypeError:
        return templates.TemplateResponse(name, payload, **kwargs)


@app.on_event("startup")
async def on_startup():
    """Start the scheduler and register its shutdown hook with ``atexit``."""
    logger.info("Application startup begin.")
    _start_scheduler()
    atexit.register(_stop_scheduler)
    logger.info("Application startup complete.")


@app.on_event("shutdown")
async def on_shutdown():
    """Stop the scheduler when the application shuts down."""
    _stop_scheduler()


@app.get("/", response_class=HTMLResponse)
async def dashboard(request: Request):
    """Render the user dashboard, redirecting anonymous and admin sessions."""
    session = _session_from_request(request)
    if not session:
        return RedirectResponse(url="/login", status_code=303)
    if session.get("role") == "admin":
        return RedirectResponse(url="/admin", status_code=303)

    username = session.get("username")
    runtime = _get_runtime(username)
    return _render_template(
        request,
        "dashboard.html",
        {
            "default_time": runtime.schedule_time(),
            "username": username,
        },
    )


@app.get("/login", response_class=HTMLResponse)
async def login_page(request: Request):
    """Render the login page, or redirect if a session already exists."""
    session = _session_from_request(request)
    if session:
        if session.get("role") == "admin":
            return RedirectResponse(url="/admin", status_code=303)
        return RedirectResponse(url="/", status_code=303)
    return _render_template(request, "login.html")


@app.get("/register", response_class=HTMLResponse)
async def register_page(request: Request):
    """Render the registration page, or redirect if a session already exists."""
    session = _session_from_request(request)
    if session:
        if session.get("role") == "admin":
            return RedirectResponse(url="/admin", status_code=303)
        return RedirectResponse(url="/", status_code=303)
    return _render_template(request, "register.html")


@app.get("/admin", response_class=HTMLResponse)
async def admin_page(request: Request):
    """Render the admin console, or the admin login form without an admin session."""
    session = _session_from_request(request)
    if not session or session.get("role") != "admin":
        return _render_template(
            request,
            "admin_login.html",
            {"password_missing": not bool(os.getenv("PASSWORD"))},
        )
    return _render_template(request, "admin.html")


@app.post("/api/login")
async def api_login(payload: UserLoginPayload):
    """Verify a user's credentials and issue a session cookie.

    Returns 400 for an empty username and 401 (with the same message) for
    both an unknown username and a wrong password.
    """
    username = payload.username.strip()
    if not username:
        return JSONResponse(status_code=400, content={"ok": False, "message": "用户名不能为空。"})

    users_map = _load_users_meta()
    user = users_map.get(username)
    if not user:
        return JSONResponse(status_code=401, content={"ok": False, "message": "用户名或密码错误。"})

    if not _verify_password(payload.password, user.get("password_salt", ""), user.get("password_hash", "")):
        return JSONResponse(status_code=401, content={"ok": False, "message": "用户名或密码错误。"})

    token = secrets.token_urlsafe(32)
    AUTH_SESSIONS[token] = {"role": "user", "username": username}

    response = JSONResponse({"ok": True, "message": "登录成功。"})
    response.set_cookie(
        key=SESSION_COOKIE_NAME,
        value=token,
        httponly=True,
        samesite="lax",
        max_age=7 * 24 * 3600,
    )
    return response


@app.post("/api/admin/login")
async def api_admin_login(payload: AdminLoginPayload):
    """Check the admin password against the ``PASSWORD`` env var and issue a session cookie.

    Returns 500 when ``PASSWORD`` is not configured and 401 on a mismatch.
    """
    expected_password = os.getenv("PASSWORD")
    if not expected_password:
        return JSONResponse(
            status_code=500,
            content={"ok": False, "message": "服务端未配置 PASSWORD 环境变量。"},
        )

    # Constant-time comparison so response timing does not leak how much of
    # the password matched. Compare bytes: compare_digest rejects non-ASCII str.
    password_matches = secrets.compare_digest(
        payload.password.encode("utf-8"),
        expected_password.encode("utf-8"),
    )
    if not password_matches:
        return JSONResponse(status_code=401, content={"ok": False, "message": "密码错误。"})

    token = secrets.token_urlsafe(32)
    AUTH_SESSIONS[token] = {"role": "admin", "username": "admin"}
    response = JSONResponse({"ok": True, "message": "登录成功。"})
    response.set_cookie(
        key=SESSION_COOKIE_NAME,
        value=token,
        httponly=True,
        samesite="lax",
        max_age=7 * 24 * 3600,
    )
    return response


@app.post("/api/register")
async def api_register(password: str = Form(...), users_file: UploadFile = File(...)):
    """Register a user from an uploaded ``usersData.json`` plus a password.

    The username and unique_id come from the validated upload. Returns 400
    for a short password, a non-``.json`` upload or invalid file contents,
    and 409 when the username or unique_id is already registered.
    """
    if len(password.strip()) < 4:
        return JSONResponse(status_code=400, content={"ok": False, "message": "密码至少 4 位。"})

    # The upload's filename may be absent (None); treat that as "not a .json
    # file" instead of crashing with AttributeError.
    filename = users_file.filename or ""
    if not filename.lower().endswith(".json"):
        return JSONResponse(status_code=400, content={"ok": False, "message": "请上传 usersData.json 文件。"})

    try:
        raw = await users_file.read()
        users_data, username, unique_id = _validate_and_normalize_users_data(raw)
    except Exception as exc:
        return JSONResponse(status_code=400, content={"ok": False, "message": str(exc)})

    users_map = _load_users_meta()
    if username in users_map:
        return JSONResponse(status_code=409, content={"ok": False, "message": f"用户名 {username} 已注册。"})

    for existing in users_map.values():
        if str(existing.get("unique_id", "")).strip() == unique_id:
            return JSONResponse(status_code=409, content={"ok": False, "message": f"unique_id {unique_id} 已注册。"})

    # Fill in any scheduler keys the default config does not already provide.
    default_config = _get_default_user_config()
    default_config.setdefault("scheduler", {})
    default_config["scheduler"].setdefault("enabled", True)
    default_config["scheduler"].setdefault("timezone", DEFAULT_TIMEZONE)
    default_config["scheduler"].setdefault("hour", 9)
    default_config["scheduler"].setdefault("minute", 0)
    default_config["scheduler"].setdefault("runOnStartup", False)

    hash_data = _hash_password(password.strip())
    created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    try:
        _create_user_record(
            username=username,
            unique_id=unique_id,
            password_hash=hash_data["hash"],
            password_salt=hash_data["salt"],
            created_at=created_at,
            config_payload=default_config,
            users_data_payload=users_data,
        )
    except pymysql.err.IntegrityError:
        # A concurrent registration can pass the checks above and still
        # collide at insert time.
        return JSONResponse(status_code=409, content={"ok": False, "message": "用户名或 unique_id 已注册。"})

    _schedule_user_job(username)
    _get_runtime(username).add_log("用户已注册并完成定时任务初始化")

    return {
        "ok": True,
        "message": "注册成功,请使用用户名和密码登录。",
        "username": username,
    }


@app.post("/api/logout")
async def api_logout(request: Request):
    """Drop the caller's session token (if any) and clear the session cookie."""
    token = request.cookies.get(SESSION_COOKIE_NAME)
    if token:
        AUTH_SESSIONS.pop(token, None)
    response = JSONResponse({"ok": True})
    response.delete_cookie(SESSION_COOKIE_NAME)
    return response


@app.get("/api/status")
async def api_status(request: Request):
    """Return the logged-in user's runtime snapshot and run history."""
    session = _require_user_session(request)
    username = session["username"]
    runtime = _get_runtime(username)
    users_data = _load_user_users_data(username)
    return {
        "ok": True,
        "runtime": runtime.snapshot(
            account_count=len(users_data),
            target_count=_count_targets(users_data),
        ),
        "history": runtime.history_rows(),
    }


@app.get("/api/logs")
async def api_logs(request: Request, limit: int = MAX_LOG_LINES):
    """Return the user's recent log lines; ``limit`` is clamped to 100..3000."""
    session = _require_user_session(request)
    username = session["username"]
    runtime = _get_runtime(username)
    limit = min(max(100, limit), 3000)
    return {"ok": True, "logs": runtime.recent_logs(limit=limit)}


@app.post("/api/run")
async def api_run(request: Request):
    """Start a manual background run; returns 409 if one is already running."""
    session = _require_user_session(request)
    username = session["username"]
    runtime = _get_runtime(username)

    if runtime.is_running:
        return JSONResponse(
            status_code=409,
            content={"ok": False, "message": "已有任务正在执行,请稍后再试。"},
        )

    _start_background_run(username, "manual")
    return {"ok": True, "message": "任务已开始执行。"}


@app.post("/api/schedule")
async def api_schedule(request: Request, payload: SchedulePayload):
    """Enable the user's daily job at the requested time and reschedule it.

    Returns 400 with the parser's message when ``payload.time`` is invalid.
    """
    session = _require_user_session(request)
    username = session["username"]

    try:
        hour, minute = _parse_time_string(payload.time)
    except Exception as exc:
        return JSONResponse(status_code=400, content={"ok": False, "message": str(exc)})

    cfg = _load_user_config(username)
    scheduler_cfg = cfg.setdefault("scheduler", {})
    scheduler_cfg["enabled"] = True
    scheduler_cfg["hour"] = hour
    scheduler_cfg["minute"] = minute
    # Keep the stored timezone / runOnStartup values, normalising their types.
    scheduler_cfg["timezone"] = str(scheduler_cfg.get("timezone", DEFAULT_TIMEZONE))
    scheduler_cfg["runOnStartup"] = bool(scheduler_cfg.get("runOnStartup", False))
    _save_user_config(username, cfg)

    _schedule_user_job(username)
    runtime = _get_runtime(username)
    return {
        "ok": True,
        "message": f"定时任务已更新为每天 {hour:02d}:{minute:02d}。",
        "time": f"{hour:02d}:{minute:02d}",
        "next_run": runtime.snapshot(0, 0)["next_run"],
    }


@app.get("/api/editor/state")
async def api_editor_state(request: Request):
    """Return the message template and per-account targets for the editor UI."""
    session = _require_user_session(request)
    username = session["username"]
    return {"ok": True, **_build_editor_state(username)}


@app.post("/api/editor/message")
async def api_editor_message(request: Request, payload: MessageTemplatePayload):
    """Save a new message template; rejects empty or over-long text with 400."""
    session = _require_user_session(request)
    username = session["username"]

    message = payload.message.strip()
    if not message:
        return JSONResponse(status_code=400, content={"ok": False, "message": "消息内容不能为空。"})
    if len(message) > MAX_TEMPLATE_LENGTH:
        return JSONResponse(
            status_code=400,
            content={"ok": False, "message": f"消息内容过长,最多 {MAX_TEMPLATE_LENGTH} 字符。"},
        )

    cfg = _load_user_config(username)
    cfg["messageTemplate"] = message
    _save_user_config(username, cfg)
    _get_runtime(username).add_log("消息模板已更新")
    return {"ok": True, "message": "消息模板已保存。"}


@app.post("/api/editor/targets")
async def api_editor_targets(request: Request, payload: UserTargetsPayload):
    """Replace the target list of each stored account named in the payload.

    Payload entries whose ``unique_id`` matches no stored account are ignored.
    """
    session = _require_user_session(request)
    username = session["username"]

    users_data = _load_user_users_data(username)
    updates = {item.unique_id: _sanitize_targets(item.targets) for item in payload.users}

    updated = 0
    for user in users_data:
        uid = str(user.get("unique_id", ""))
        if uid in updates:
            user["targets"] = updates[uid]
            updated += 1

    _save_user_users_data(username, users_data)
    _get_runtime(username).add_log(f"目标好友已更新,涉及账号数:{updated}")
    return {"ok": True, "message": f"目标好友已保存({updated} 个账号)。"}


@app.get("/api/admin/overview")
async def api_admin_overview(request: Request):
    """Return one summary row per registered user for the admin console.

    If the user list cannot be loaded the response is still ``ok`` but has an
    empty ``users`` list and a ``message`` describing the failure. A user
    whose config or account data fails to load gets a row with an ``error``
    field instead of schedule details.
    """
    _require_admin_session(request)
    try:
        users_map = _load_users_meta()
        if not scheduler_bootstrapped:
            _sync_user_jobs_from_meta(users_map, run_startup_tasks=False)
    except Exception as exc:
        logger.warning("Admin overview failed to reach database. error=%s", exc)
        return {
            "ok": True,
            "users": [],
            "task_count": 0,
            "db_status": _build_db_status_payload(),
            "message": f"无法连接 SQL 服务器:{exc}",
        }

    rows = []
    for username, meta in sorted(users_map.items(), key=lambda x: x[0]):
        try:
            cfg = _load_user_config(username)
            users_data = _load_user_users_data(username)
        except Exception as exc:
            rows.append(
                {
                    "username": username,
                    "unique_id": meta.get("unique_id", ""),
                    "created_at": meta.get("created_at", "-"),
                    "error": str(exc),
                }
            )
            continue

        scheduler_cfg = cfg.get("scheduler", {})
        runtime = _get_runtime(username)
        runtime_snapshot = runtime.snapshot(
            account_count=len(users_data),
            target_count=_count_targets(users_data),
        )

        # Flatten every account's targets into one list for this user.
        receivers = []
        for item in users_data:
            receivers.extend(item.get("targets", []))

        rows.append(
            {
                "username": username,
                "unique_id": meta.get("unique_id", ""),
                "created_at": meta.get("created_at", "-"),
                "scheduler_enabled": bool(scheduler_cfg.get("enabled", True)),
                "schedule_time": f"{int(scheduler_cfg.get('hour', 9)):02d}:{int(scheduler_cfg.get('minute', 0)):02d}",
                "schedule_timezone": str(scheduler_cfg.get("timezone", DEFAULT_TIMEZONE)),
                "message_template": str(cfg.get("messageTemplate", "")),
                "targets": receivers,
                "target_count": len(receivers),
                "next_run": runtime_snapshot.get("next_run", "-"),
                "last_status": runtime_snapshot.get("last_status", "-"),
                "last_start": runtime_snapshot.get("last_start", "-"),
                "is_running": runtime_snapshot.get("is_running", False),
                "can_retry": bool(
                    not runtime_snapshot.get("is_running", False)
                    and runtime_snapshot.get("last_status") == "失败"
                ),
            }
        )

    return {
        "ok": True,
        "users": rows,
        "task_count": len(rows),
        "db_status": _build_db_status_payload(),
    }
@app.get("/api/admin/tasks/{username}")
async def api_admin_task_detail(request: Request, username: str, log_limit: int = MAX_LOG_LINES):
    """Return schedule, runtime, history, logs and account details for one user.

    ``log_limit`` is clamped to 100..3000. Returns 500 with a message when the
    user's config or account data cannot be loaded.
    """
    _require_admin_session(request)
    username = username.strip()
    user_meta = _get_user_meta_or_404(username)

    try:
        cfg = _load_user_config(username)
        users_data = _load_user_users_data(username)
    except Exception as exc:
        return JSONResponse(
            status_code=500,
            content={"ok": False, "message": f"加载任务详情失败:{exc}"},
        )

    scheduler_cfg = cfg.get("scheduler", {})
    runtime = _get_runtime(username)
    target_count = _count_targets(users_data)
    snapshot = runtime.snapshot(account_count=len(users_data), target_count=target_count)

    accounts = []
    all_targets = []
    for item in users_data:
        targets = _sanitize_targets(item.get("targets", []))
        all_targets.extend(targets)
        cookies = item.get("cookies", [])
        accounts.append(
            {
                "username": str(item.get("username", "未知用户")),
                "unique_id": str(item.get("unique_id", "")),
                "target_count": len(targets),
                "targets": targets,
                # Only the number of cookies is reported, never their values.
                "cookie_count": len(cookies) if isinstance(cookies, list) else 0,
            }
        )

    log_limit = min(max(100, log_limit), 3000)
    return {
        "ok": True,
        "task": {
            "username": username,
            "unique_id": user_meta.get("unique_id", ""),
            "created_at": user_meta.get("created_at", "-"),
            "scheduler_enabled": bool(scheduler_cfg.get("enabled", True)),
            "schedule_time": f"{int(scheduler_cfg.get('hour', 9)):02d}:{int(scheduler_cfg.get('minute', 0)):02d}",
            "schedule_timezone": str(scheduler_cfg.get("timezone", DEFAULT_TIMEZONE)),
            "message_template": str(cfg.get("messageTemplate", "")),
            "targets": all_targets,
            "target_count": len(all_targets),
            "runtime": snapshot,
            "can_retry": bool(not snapshot.get("is_running", False) and snapshot.get("last_status") == "失败"),
            "history": runtime.history_rows(),
            "logs": runtime.recent_logs(limit=log_limit),
            "config": {
                "multiTask": bool(cfg.get("multiTask", True)),
                "taskCount": int(cfg.get("taskCount", 1) or 1),
                "hitokotoTypes": cfg.get("hitokotoTypes", []),
                "proxyAddress": str(cfg.get("proxyAddress", "")),
            },
            "accounts": accounts,
        },
    }


@app.post("/api/admin/tasks/{username}/retry")
async def api_admin_retry_task(request: Request, username: str):
    """Re-run one user's task; returns 409 unless it is idle and last failed."""
    _require_admin_session(request)
    username = username.strip()
    _get_user_meta_or_404(username)

    runtime = _get_runtime(username)
    snapshot = runtime.snapshot(account_count=0, target_count=0)
    if snapshot.get("is_running"):
        return JSONResponse(status_code=409, content={"ok": False, "message": "任务正在运行中,请稍后再试。"})
    if snapshot.get("last_status") != "失败":
        return JSONResponse(status_code=409, content={"ok": False, "message": "当前任务不是失败状态,无需重试。"})

    runtime.add_log("管理员手动触发失败任务重试")
    _start_background_run(username, "admin_retry")
    return {"ok": True, "message": f"已开始重试 {username} 的失败任务。"}


@app.post("/api/admin/tasks/retry-failed")
async def api_admin_retry_all_failed_tasks(request: Request):
    """Trigger a retry for every user whose last run failed.

    Returns 503 when the retry helper raises, otherwise the list of usernames
    for which a retry was started (possibly empty).
    """
    _require_admin_session(request)

    # NOTE: the three messages below replace literals that had been reduced to
    # "?" placeholders by an encoding error; they are reconstructed to match
    # the wording used by the other admin endpoints.
    try:
        usernames = _retry_failed_tasks_once("admin_bulk_retry", raise_on_db_error=True)
    except Exception as exc:
        return JSONResponse(status_code=503, content={"ok": False, "message": f"无法连接 SQL 服务器:{exc}"})

    if not usernames:
        return {"ok": True, "message": "当前没有可重试的失败任务。", "count": 0, "usernames": []}

    return {
        "ok": True,
        "message": f"已开始批量重试 {len(usernames)} 个失败任务。",
        "count": len(usernames),
        "usernames": usernames,
    }


@app.post("/api/admin/tasks/{username}/delete")
async def api_admin_delete_task(request: Request, username: str):
    """Disable a user's scheduled job; the user record itself is kept."""
    _require_admin_session(request)
    username = username.strip()
    _get_user_meta_or_404(username)

    cfg = _load_user_config(username)
    scheduler_cfg = cfg.setdefault("scheduler", {})
    scheduler_cfg["enabled"] = False
    _save_user_config(username, cfg)

    _remove_user_schedule_job(username)
    runtime = _get_runtime(username)
    runtime.update_next_run(None)
    runtime.add_log("管理员已删除(禁用)该用户定时任务")

    return {"ok": True, "message": f"已删除用户 {username} 的定时任务。"}


@app.delete("/api/admin/users/{username}")
async def api_admin_delete_user(request: Request, username: str):
    """Delete a user: remove the scheduled job, the stored record and the runtime state."""
    _require_admin_session(request)
    username = username.strip()

    _get_user_meta_or_404(username)

    _remove_user_schedule_job(username)
    _delete_user_record(username)
    _delete_runtime(username)

    return {"ok": True, "message": f"用户 {username} 已删除。"}


@app.get("/health")
async def health():
    """Liveness probe; always reports ``alive``."""
    return {"ok": True, "status": "alive"}


def run_server():
    """Run the app with uvicorn on ``0.0.0.0`` at ``$PORT`` (default 7860), one worker."""
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run("app:app", host="0.0.0.0", port=port, workers=1)


if __name__ == "__main__":
    run_server()