"""Async MySQL helpers backed by a shared, lazily created aiomysql pool.

Provides pool lifecycle management, a pooled-connection context manager, a
result-discarding ``execute`` helper, and upserts for the
``tpl_app_user_devices`` and ``tpl_app_processed_images`` tables.
"""

import asyncio
import json
import os
import stat
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence

import aiomysql
from aiomysql.cursors import DictCursor

from config import (
    IMAGES_DIR,
    MYSQL_DB,
    MYSQL_HOST,
    MYSQL_PASSWORD,
    MYSQL_POOL_MAX_SIZE,
    MYSQL_POOL_MIN_SIZE,
    MYSQL_PORT,
    MYSQL_USER,
    logger,
)

# Module-wide pool; created on first use by ``init_mysql_pool``.
_pool: Optional[aiomysql.Pool] = None
# Serializes pool creation/teardown so racing coroutines don't build two pools.
_pool_lock = asyncio.Lock()
# Absolute root directory that stored image paths are relative to.
_IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))


async def init_mysql_pool() -> aiomysql.Pool:
    """Return the shared connection pool, creating it on first call.

    The unlocked check is a fast path once the pool exists; the second check
    under ``_pool_lock`` prevents duplicate creation when several coroutines
    arrive here concurrently.

    Returns:
        The shared ``aiomysql`` pool (autocommit on, utf8mb4, ``DictCursor``).
    """
    global _pool
    if _pool is not None:
        return _pool
    async with _pool_lock:
        if _pool is not None:
            return _pool
        _pool = await aiomysql.create_pool(
            host=MYSQL_HOST,
            port=MYSQL_PORT,
            user=MYSQL_USER,
            password=MYSQL_PASSWORD,
            db=MYSQL_DB,
            minsize=MYSQL_POOL_MIN_SIZE,
            maxsize=MYSQL_POOL_MAX_SIZE,
            autocommit=True,
            charset="utf8mb4",
            cursorclass=DictCursor,
        )
        logger.info("MySQL 连接池初始化成功,host=%s db=%s", MYSQL_HOST, MYSQL_DB)
        return _pool


async def close_mysql_pool() -> None:
    """Close the shared pool, if one exists, and wait for it to shut down.

    The module-level reference is cleared *before* waiting. If
    ``wait_closed`` raises (e.g. the shutdown task is cancelled), a later
    ``init_mysql_pool`` then builds a fresh pool instead of handing out the
    closed one.
    """
    global _pool
    if _pool is None:
        return
    async with _pool_lock:
        if _pool is None:
            return
        pool, _pool = _pool, None
        pool.close()
        await pool.wait_closed()
        logger.info("MySQL 连接池已关闭")


@asynccontextmanager
async def get_connection():
    """Yield a pooled connection and always return it to the pool afterwards.

    The shared pool is created on demand.

    Yields:
        A connection acquired from the shared ``aiomysql`` pool.
    """
    # Hold a local reference: the connection must be released to the pool it
    # came from, even if ``close_mysql_pool`` clears ``_pool`` meanwhile.
    pool = _pool if _pool is not None else await init_mysql_pool()
    conn = await pool.acquire()
    try:
        yield conn
    finally:
        pool.release(conn)


async def execute(query: str, params: Sequence[Any] | Dict[str, Any] | None = None) -> None:
    """Run one SQL statement on a pooled connection, discarding any result.

    The pool is created with ``autocommit=True``, so no explicit commit is
    issued here.

    Args:
        query: SQL text using ``%s`` / ``%(name)s`` placeholders.
        params: Positional or named parameters; ``None`` means none.

    Note:
        ``None`` is forwarded to the cursor as ``()``. The driver therefore
        always %-formats ``query``, so a literal ``%`` must be written as ``%%``
        even when no parameters are given.
    """
    async with get_connection() as conn:
        async with conn.cursor() as cursor:
            await cursor.execute(query, params or ())


async def upsert_device_record(
    *,
    device_id: str,
    device_type: Optional[str] = None,
    device_model: Optional[str] = None,
    os_version: Optional[str] = None,
    app_version: Optional[str] = None,
    region: Optional[str] = None,
    timezone: Optional[str] = None,
    language: Optional[str] = None,
) -> None:
    """Insert a device row, or overwrite its attributes if it already exists.

    Every attribute column is overwritten on conflict, including with
    ``NULL`` when the corresponding argument is ``None``. ``updated_at`` is
    refreshed in both the insert and update paths.
    """
    # NOTE(review): the conflict target is presumably a unique key on
    # device_id; confirm against the table schema. ``VALUES(col)`` in ON
    # DUPLICATE KEY UPDATE is deprecated (not removed) in MySQL 8.0.20+.
    query = """
        INSERT INTO tpl_app_user_devices (
            device_id, device_type, device_model, os_version, app_version, region,
            timezone, language, updated_at
        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, CURRENT_TIMESTAMP)
        ON DUPLICATE KEY UPDATE
            device_type = VALUES(device_type),
            device_model = VALUES(device_model),
            os_version = VALUES(os_version),
            app_version = VALUES(app_version),
            region = VALUES(region),
            timezone = VALUES(timezone),
            language = VALUES(language),
            updated_at = CURRENT_TIMESTAMP
    """
    params = (
        device_id,
        device_type,
        device_model,
        os_version,
        app_version,
        region,
        timezone,
        language,
    )
    await execute(query, params)


def _serialize_extra(extra: Optional[Dict[str, Any]]) -> Optional[str]:
    """Serialize ``extra`` to a JSON string for storage, best-effort.

    Returns:
        The JSON text (non-ASCII kept as-is), or ``None`` when ``extra`` is
        ``None`` or cannot be serialized. A failure is logged with its
        traceback and never raised, so metadata problems don't block the write.
    """
    if extra is None:
        return None
    try:
        return json.dumps(extra, ensure_ascii=False)
    except Exception:
        logger.warning("无法序列化 extra_metadata,已忽略", exc_info=True)
        return None


def _normalize_file_path(file_path: str) -> Optional[str]:
    """Map ``file_path`` to the value stored in the ``file_path`` column.

    Paths inside ``IMAGES_DIR`` become ``/``-separated paths relative to it.
    Paths outside it, or where no relative path exists (e.g. a different
    drive on Windows), fall back to the bare file name.

    Returns:
        The normalized path, or ``None`` if ``file_path`` is empty or no
        file name can be derived from it.
    """
    if not file_path:
        return None
    abs_path = os.path.abspath(os.path.expanduser(file_path))
    try:
        rel_path = os.path.relpath(abs_path, _IMAGES_DIR_ABS)
    except ValueError:
        return os.path.basename(file_path) or None
    # Only a leading ".." *component* means "outside IMAGES_DIR". A file
    # named e.g. "..thumb.png" directly inside the directory is a valid
    # relative path and must not be collapsed to its basename.
    if rel_path == os.pardir or rel_path.startswith(os.pardir + os.sep):
        return os.path.basename(file_path) or None
    return rel_path.replace(os.sep, "/")


async def record_image_creation(
    *,
    file_path: str,
    nickname: Optional[str],
    score: float = 0.0,
    category: Optional[str] = None,
    bos_uploaded: bool = False,
    region: Optional[str] = None,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Upsert a ``tpl_app_processed_images`` row for an image under IMAGES_DIR.

    The row is written under the normalized path from
    ``_normalize_file_path``. Size and modification time are read from the
    file at that path inside ``IMAGES_DIR``. Nothing is recorded (only
    logged) when the path cannot be normalized or no regular file exists
    there.

    Args:
        file_path: Path to the image; absolute, relative or ``~``-prefixed.
        nickname: Stored stripped; blank or non-string values become ``NULL``.
        score: Stored as-is.
        category: Defaults to ``"anime_style"`` when ``None`` or empty.
        bos_uploaded: Stored as 1/0.
        region: Stored as-is.
        extra_metadata: Stored as JSON; ``NULL`` if absent or unserializable.
    """
    normalized = _normalize_file_path(file_path)
    if normalized is None:
        logger.info("record_image_creation: 无法计算文件名,路径=%s", file_path)
        return

    abs_path = os.path.join(_IMAGES_DIR_ABS, normalized)
    # One stat call instead of isfile() + stat(): the file could vanish
    # between two separate filesystem calls. ValueError covers e.g. embedded
    # NUL bytes, which isfile() also treated as "not a file".
    try:
        file_stat = os.stat(abs_path)
    except (OSError, ValueError):
        file_stat = None
    if file_stat is None or not stat.S_ISREG(file_stat.st_mode):
        logger.info("record_image_creation: 文件不存在,跳过记录 file=%s", abs_path)
        return

    stripped_nickname = nickname.strip() if isinstance(nickname, str) else ""
    nickname_value = stripped_nickname or None

    # NOTE(review): file_path is not in the UPDATE list, so it is presumably
    # the unique key the upsert conflicts on; confirm against the schema.
    query = """
        INSERT INTO tpl_app_processed_images (
            file_path, category, nickname, score, is_cropped_face, size_bytes,
            last_modified, bos_uploaded, region, extra_metadata
        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
            category = VALUES(category),
            nickname = VALUES(nickname),
            score = VALUES(score),
            is_cropped_face = VALUES(is_cropped_face),
            size_bytes = VALUES(size_bytes),
            last_modified = VALUES(last_modified),
            bos_uploaded = VALUES(bos_uploaded),
            region = VALUES(region),
            extra_metadata = VALUES(extra_metadata),
            updated_at = CURRENT_TIMESTAMP
    """
    await execute(
        query,
        (
            normalized,
            category or "anime_style",
            nickname_value,
            score,
            0,  # is_cropped_face: always recorded as 0 by this function
            file_stat.st_size,
            # Naive datetime in the server process's local time zone.
            datetime.fromtimestamp(file_stat.st_mtime),
            1 if bos_uploaded else 0,
            region,
            _serialize_extra(extra_metadata),
        ),
    )