Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| import os | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime | |
| from typing import Any, Dict, List, Optional, Sequence | |
| import aiomysql | |
| from aiomysql.cursors import DictCursor | |
| from config import ( | |
| IMAGES_DIR, | |
| MYSQL_DB, | |
| MYSQL_HOST, | |
| MYSQL_PASSWORD, | |
| MYSQL_POOL_MAX_SIZE, | |
| MYSQL_POOL_MIN_SIZE, | |
| MYSQL_PORT, | |
| MYSQL_USER, | |
| logger, | |
| ) | |
# Absolute, user-expanded images root; file paths recorded in the database
# are expressed relative to this directory by _normalize_file_path().
_IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))

# Shared aiomysql pool: None until init_mysql_pool() creates it, and reset to
# None again by close_mysql_pool().
_pool: Optional[aiomysql.Pool] = None

# Serialises pool creation and teardown so concurrent callers cannot build or
# close the pool twice.
# NOTE(review): created at import time; on Python 3.10+ an asyncio.Lock binds
# to the loop that first uses it, so reusing this module across several event
# loops (e.g. repeated asyncio.run calls) may fail — confirm deployment model.
_pool_lock: asyncio.Lock = asyncio.Lock()
async def init_mysql_pool() -> aiomysql.Pool:
    """Return the module-wide aiomysql pool, creating it on first call.

    The ``_pool`` check is repeated after ``_pool_lock`` is acquired, so
    concurrent first callers build only one pool. Connections use the MYSQL_*
    settings from ``config`` with autocommit on, the ``utf8mb4`` charset, and
    ``DictCursor`` (rows come back as dicts).

    Returns:
        The shared ``aiomysql.Pool``.
    """
    global _pool
    if _pool is None:
        async with _pool_lock:
            # Another task may have finished creating the pool while this one
            # was waiting on the lock.
            if _pool is None:
                _pool = await aiomysql.create_pool(
                    host=MYSQL_HOST,
                    port=MYSQL_PORT,
                    user=MYSQL_USER,
                    password=MYSQL_PASSWORD,
                    db=MYSQL_DB,
                    minsize=MYSQL_POOL_MIN_SIZE,
                    maxsize=MYSQL_POOL_MAX_SIZE,
                    autocommit=True,
                    charset="utf8mb4",
                    cursorclass=DictCursor,
                )
                logger.info("MySQL 连接池初始化成功,host=%s db=%s", MYSQL_HOST, MYSQL_DB)
    return _pool
async def close_mysql_pool() -> None:
    """Close the shared pool (if one exists) and clear the module handle.

    Does nothing when no pool has been created. ``_pool`` is reset in a
    ``finally`` so that an error or cancellation while waiting for
    connections to close cannot leave an already-closed pool cached. In that
    case a later ``init_mysql_pool()`` call can build a fresh pool instead of
    every caller failing on the dead one. The "closed" message is logged only
    when shutdown completes normally; any exception from ``wait_closed()`` is
    propagated.
    """
    global _pool
    if _pool is None:
        return
    async with _pool_lock:
        if _pool is None:
            return
        pool = _pool
        try:
            pool.close()
            await pool.wait_closed()
        finally:
            _pool = None
        logger.info("MySQL 连接池已关闭")
@asynccontextmanager
async def get_connection():
    """Async context manager yielding a connection from the shared pool.

    The pool is created on first use via ``init_mysql_pool()``. The
    connection is always handed back to the pool on exit, including when the
    ``async with`` body raises.

    Usage::

        async with get_connection() as conn:
            ...
    """
    # Keep a local reference: release must go to the pool the connection came
    # from, even if the module-level ``_pool`` is reset (e.g. by
    # close_mysql_pool) while the caller holds the connection.
    pool = _pool if _pool is not None else await init_mysql_pool()
    conn = await pool.acquire()
    try:
        yield conn
    finally:
        pool.release(conn)
| async def execute(query: str, params: Sequence[Any] | Dict[str, Any] | None = None) -> None: | |
| async with get_connection() as conn: | |
| async with conn.cursor() as cursor: | |
| await cursor.execute(query, params or ()) | |
async def upsert_device_record(
    *,
    device_id: str,
    device_type: Optional[str] = None,
    device_model: Optional[str] = None,
    os_version: Optional[str] = None,
    app_version: Optional[str] = None,
    region: Optional[str] = None,
    timezone: Optional[str] = None,
    language: Optional[str] = None,
) -> None:
    """Insert or refresh a row in ``tpl_app_user_devices`` for *device_id*.

    When the insert collides with an existing unique key (presumably the one
    on ``device_id`` — confirm against the schema), every descriptive column
    is overwritten with the supplied value. Omitted arguments therefore
    overwrite stored values with NULL. ``updated_at`` is set to
    ``CURRENT_TIMESTAMP`` on both insert and update.

    NOTE: ``VALUES(col)`` inside ``ON DUPLICATE KEY UPDATE`` is deprecated as
    of MySQL 8.0.20 (row-alias syntax is the replacement); it still works.
    """
    upsert_sql = """
        INSERT INTO tpl_app_user_devices (
            device_id, device_type, device_model, os_version, app_version,
            region, timezone, language, updated_at
        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, CURRENT_TIMESTAMP)
        ON DUPLICATE KEY UPDATE
            device_type = VALUES(device_type),
            device_model = VALUES(device_model),
            os_version = VALUES(os_version),
            app_version = VALUES(app_version),
            region = VALUES(region),
            timezone = VALUES(timezone),
            language = VALUES(language),
            updated_at = CURRENT_TIMESTAMP
    """
    # Order must match the placeholder order in the column list above.
    await execute(
        upsert_sql,
        (
            device_id, device_type, device_model, os_version,
            app_version, region, timezone, language,
        ),
    )
| def _serialize_extra(extra: Optional[Dict[str, Any]]) -> Optional[str]: | |
| if extra is None: | |
| return None | |
| try: | |
| return json.dumps(extra, ensure_ascii=False) | |
| except Exception: | |
| logger.warning("无法序列化 extra_metadata,已忽略") | |
| return None | |
| def _normalize_file_path(file_path: str) -> Optional[str]: | |
| if not file_path: | |
| return None | |
| abs_path = os.path.abspath(os.path.expanduser(file_path)) | |
| try: | |
| rel_path = os.path.relpath(abs_path, _IMAGES_DIR_ABS) | |
| except ValueError: | |
| return os.path.basename(file_path) | |
| if rel_path.startswith(".."): | |
| return os.path.basename(file_path) | |
| return rel_path.replace(os.sep, "/") | |
async def record_image_creation(
    *,
    file_path: str,
    nickname: Optional[str],
    score: float = 0.0,
    category: Optional[str] = None,
    bos_uploaded: bool = False,
    region: Optional[str] = None,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Upsert a ``tpl_app_processed_images`` row describing an image on disk.

    *file_path* is normalised by ``_normalize_file_path`` and the file that is
    inspected is always ``<images dir>/<normalised path>``. For paths outside
    the images directory this means the basename is re-rooted there. Size and
    modification time come from that file.

    Nothing is written, and an INFO message is logged, when:
      * no name can be derived from *file_path*, or
      * the target is not a regular file, or it disappears before it can be
        stat-ed.

    Args:
        file_path: Path of the image as produced by the caller.
        nickname: Display name. Whitespace is stripped, and blank or non-str
            values are stored as NULL.
        score: Stored as-is.
        category: Defaults to ``"anime_style"`` when falsy.
        bos_uploaded: Stored as 1/0.
        region: Stored as-is.
        extra_metadata: JSON-encoded via ``_serialize_extra``; stored as NULL
            if it cannot be encoded.

    ``is_cropped_face`` is always written as 0. On a duplicate key
    (presumably ``file_path`` — confirm against the schema) all columns
    except ``file_path`` are overwritten and ``updated_at`` is refreshed.
    """
    normalized = _normalize_file_path(file_path)
    if normalized is None:
        logger.info("record_image_creation: 无法计算文件名,路径=%s", file_path)
        return
    abs_path = os.path.join(_IMAGES_DIR_ABS, normalized)
    if not os.path.isfile(abs_path):
        logger.info("record_image_creation: 文件不存在,跳过记录 file=%s", abs_path)
        return
    try:
        file_stat = os.stat(abs_path)
    except OSError:
        # The file can vanish between isfile() and stat(); treat that like a
        # missing file instead of crashing the caller.
        logger.info("record_image_creation: 文件不存在,跳过记录 file=%s", abs_path)
        return
    stripped_nickname = nickname.strip() if isinstance(nickname, str) else ""
    nickname_value = stripped_nickname or None
    query = """
        INSERT INTO tpl_app_processed_images (
            file_path, category, nickname, score, is_cropped_face, size_bytes,
            last_modified, bos_uploaded, region, extra_metadata
        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
            category = VALUES(category),
            nickname = VALUES(nickname),
            score = VALUES(score),
            is_cropped_face = VALUES(is_cropped_face),
            size_bytes = VALUES(size_bytes),
            last_modified = VALUES(last_modified),
            bos_uploaded = VALUES(bos_uploaded),
            region = VALUES(region),
            extra_metadata = VALUES(extra_metadata),
            updated_at = CURRENT_TIMESTAMP
    """
    await execute(
        query,
        (
            normalized,
            category or "anime_style",
            nickname_value,
            score,
            0,  # is_cropped_face: this code path never records cropped faces
            file_stat.st_size,
            # Naive local time from the file's mtime.
            datetime.fromtimestamp(file_stat.st_mtime),
            1 if bos_uploaded else 0,
            region,
            _serialize_extra(extra_metadata),
        ),
    )