Spaces:
Paused
Paused
| import asyncio | |
| import json | |
| import os | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime | |
| from typing import Any, Dict, Iterable, List, Optional, Sequence | |
| import aiomysql | |
| from aiomysql.cursors import DictCursor | |
| from config import ( | |
| IMAGES_DIR, | |
| logger, | |
| MYSQL_HOST, | |
| MYSQL_PORT, | |
| MYSQL_DB, | |
| MYSQL_USER, | |
| MYSQL_PASSWORD, | |
| MYSQL_POOL_MIN_SIZE, | |
| MYSQL_POOL_MAX_SIZE, | |
| ) | |
# Shared connection pool, created lazily by init_mysql_pool().
_pool: Optional[aiomysql.Pool] = None
# Guards pool creation/teardown against concurrent coroutines.
_pool_lock = asyncio.Lock()
async def init_mysql_pool() -> aiomysql.Pool:
    """Initialize the module-wide MySQL connection pool (idempotent).

    Uses double-checked locking so concurrent callers create the pool at
    most once; later calls return the cached pool.

    Returns:
        The shared ``aiomysql.Pool`` instance.

    Raises:
        Exception: re-raised from ``aiomysql.create_pool`` when the pool
            cannot be created (logged with traceback first).
    """
    global _pool
    # Fast path: pool already exists, skip the lock entirely.
    if _pool is not None:
        return _pool
    async with _pool_lock:
        # Re-check under the lock: another coroutine may have won the race.
        if _pool is not None:
            return _pool
        try:
            _pool = await aiomysql.create_pool(
                host=MYSQL_HOST,
                port=MYSQL_PORT,
                user=MYSQL_USER,
                password=MYSQL_PASSWORD,
                db=MYSQL_DB,
                minsize=MYSQL_POOL_MIN_SIZE,
                maxsize=MYSQL_POOL_MAX_SIZE,
                autocommit=True,
                charset="utf8mb4",
                cursorclass=DictCursor,
            )
            logger.info(
                "MySQL 连接池初始化成功,host=%s db=%s",
                MYSQL_HOST,
                MYSQL_DB,
            )
        except Exception as exc:
            # logger.exception keeps the traceback; lazy %-args match the
            # logging style used in the success branch above.
            logger.exception("初始化 MySQL 连接池失败: %s", exc)
            raise
    return _pool
async def close_mysql_pool() -> None:
    """Close the shared MySQL connection pool, if one exists (idempotent)."""
    global _pool
    # Cheap pre-check before taking the lock.
    if _pool is None:
        return
    async with _pool_lock:
        if _pool is None:
            # Another coroutine closed the pool while we waited for the lock.
            return
        pool = _pool
        pool.close()
        await pool.wait_closed()
        _pool = None
        logger.info("MySQL 连接池已关闭")
@asynccontextmanager
async def get_connection():
    """Async context manager yielding a pooled MySQL connection.

    Lazily initializes the pool on first use; the connection is always
    released back to the pool on exit.
    """
    # BUG FIX: the @asynccontextmanager decorator was missing, so
    # ``async with get_connection()`` raised TypeError — an async generator
    # is not an async context manager by itself. The decorator was already
    # imported at the top of the file but never applied.
    if _pool is None:
        await init_mysql_pool()
    assert _pool is not None
    conn = await _pool.acquire()
    try:
        yield conn
    finally:
        _pool.release(conn)
async def execute(query: str,
                  params: Sequence[Any] | Dict[str, Any] | None = None) -> None:
    """Run a write-style SQL statement (INSERT/UPDATE/DELETE).

    The pool is created with autocommit=True, so no explicit commit occurs.
    """
    async with get_connection() as conn, conn.cursor() as cursor:
        await cursor.execute(query, params or ())
async def fetch_all(
    query: str, params: Sequence[Any] | Dict[str, Any] | None = None
) -> List[Dict[str, Any]]:
    """Run a SELECT and return every row (as dicts, via DictCursor)."""
    bound = params or ()
    async with get_connection() as conn:
        async with conn.cursor() as cursor:
            await cursor.execute(query, bound)
            return list(await cursor.fetchall())
| def _serialize_extra(extra: Optional[Dict[str, Any]]) -> Optional[str]: | |
| if extra is None: | |
| return None | |
| try: | |
| return json.dumps(extra, ensure_ascii=False) | |
| except Exception: | |
| logger.warning("无法序列化 extra_metadata,已忽略") | |
| return None | |
async def upsert_image_record(
    *,
    file_path: str,
    category: str,
    nickname: Optional[str],
    score: float,
    is_cropped_face: bool,
    size_bytes: int,
    last_modified: datetime,
    bos_uploaded: bool,
    hostname: Optional[str] = None,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Insert an image record, updating it in place on duplicate key."""
    sql = """
    INSERT INTO tpl_app_processed_images (
        file_path,
        category,
        nickname,
        score,
        is_cropped_face,
        size_bytes,
        last_modified,
        bos_uploaded,
        hostname,
        extra_metadata
    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    ON DUPLICATE KEY UPDATE
        category = VALUES(category),
        nickname = VALUES(nickname),
        score = VALUES(score),
        is_cropped_face = VALUES(is_cropped_face),
        size_bytes = VALUES(size_bytes),
        last_modified = VALUES(last_modified),
        bos_uploaded = VALUES(bos_uploaded),
        hostname = VALUES(hostname),
        extra_metadata = VALUES(extra_metadata),
        updated_at = CURRENT_TIMESTAMP
    """
    # Booleans are stored as 0/1 TINYINT-style flags.
    row = (
        file_path,
        category,
        nickname,
        score,
        int(bool(is_cropped_face)),
        size_bytes,
        last_modified,
        int(bool(bos_uploaded)),
        hostname,
        _serialize_extra(extra_metadata),
    )
    await execute(sql, row)
async def fetch_paged_image_records(
    *,
    category: Optional[str],
    nickname: Optional[str],
    offset: int,
    limit: int,
) -> List[Dict[str, Any]]:
    """Fetch a page of image records, optionally filtered by category/nickname.

    A category of ``"all"`` (or None/empty) means "no category filter".
    """
    filters: List[str] = []
    args: List[Any] = []
    if category and category != "all":
        filters.append("category = %s")
        args.append(category)
    if nickname:
        filters.append("nickname = %s")
        args.append(nickname)
    where_sql = ""
    if filters:
        where_sql = "WHERE " + " AND ".join(filters)
    query = f"""
    SELECT
        file_path,
        category,
        nickname,
        score,
        is_cropped_face,
        size_bytes,
        last_modified,
        bos_uploaded,
        hostname
    FROM tpl_app_processed_images
    {where_sql}
    ORDER BY last_modified DESC, id DESC
    LIMIT %s OFFSET %s
    """
    args += [limit, offset]
    return await fetch_all(query, args)
async def count_image_records(
    *, category: Optional[str], nickname: Optional[str]
) -> int:
    """Count image records matching the optional category/nickname filters."""
    filters: List[str] = []
    args: List[Any] = []
    if category and category != "all":
        filters.append("category = %s")
        args.append(category)
    if nickname:
        filters.append("nickname = %s")
        args.append(nickname)
    where_sql = ""
    if filters:
        where_sql = "WHERE " + " AND ".join(filters)
    query = f"SELECT COUNT(*) AS total FROM tpl_app_processed_images {where_sql}"
    rows = await fetch_all(query, args)
    if not rows:
        return 0
    # ``or 0`` guards against a NULL total from the driver.
    return int(rows[0].get("total", 0) or 0)
async def fetch_today_category_counts() -> List[Dict[str, Any]]:
    """Count today's records grouped by category (NULL becomes 'unknown')."""
    query = """
    SELECT
        COALESCE(category, 'unknown') AS category,
        COUNT(*) AS count
    FROM tpl_app_processed_images
    WHERE last_modified >= CURDATE()
      AND last_modified < DATE_ADD(CURDATE(), INTERVAL 1 DAY)
    GROUP BY COALESCE(category, 'unknown')
    """
    results: List[Dict[str, Any]] = []
    for row in await fetch_all(query):
        results.append(
            {
                "category": str(row.get("category") or "unknown"),
                "count": int(row.get("count") or 0),
            }
        )
    return results
async def fetch_records_by_paths(file_paths: Iterable[str]) -> Dict[
        str, Dict[str, Any]]:
    """Batch-fetch image records, returned as a mapping keyed by file_path.

    Empty/falsy paths are dropped and duplicates collapsed before querying.
    """
    unique_paths = {p for p in file_paths if p}
    if not unique_paths:
        return {}
    paths = list(unique_paths)
    placeholders = ", ".join("%s" for _ in paths)
    query = f"""
    SELECT
        file_path,
        category,
        nickname,
        score,
        is_cropped_face,
        size_bytes,
        last_modified,
        bos_uploaded,
        hostname
    FROM tpl_app_processed_images
    WHERE file_path IN ({placeholders})
    """
    records: Dict[str, Dict[str, Any]] = {}
    for row in await fetch_all(query, paths):
        records[row["file_path"]] = row
    return records
# Absolute, user-expanded form of IMAGES_DIR, computed once at import time;
# used below to normalize incoming paths to names relative to this directory.
_IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))
def _normalize_file_path(file_path: str) -> Optional[str]:
    """Convert an absolute path to a filename relative to IMAGES_DIR.

    Returns the forward-slash relative path when the file lives under
    IMAGES_DIR, the bare basename when it lives elsewhere, and None for
    directories or paths that cannot be resolved.
    """
    try:
        abs_path = os.path.abspath(os.path.expanduser(file_path))
        if os.path.isdir(abs_path):
            return None
        under_images_dir = (
            os.path.commonpath([_IMAGES_DIR_ABS, abs_path]) == _IMAGES_DIR_ABS
        )
        if not under_images_dir:
            # Outside the images directory: fall back to the basename only.
            return os.path.basename(abs_path)
        relative = os.path.relpath(abs_path, _IMAGES_DIR_ABS)
        return relative.replace("\\", "/")
    except Exception:
        # commonpath raises ValueError on e.g. mixed drives or empty input;
        # any failure is treated as "not normalizable".
        return None
def infer_category_from_filename(filename: str, default: str = "other") -> str:
    """Infer an image category from marker substrings in the filename.

    Matching is case-insensitive and order-sensitive: the first marker
    found wins. Returns *default* when no marker matches.
    """
    # BUG FIX: "_rvm_id_photo" contains "_id_photo", so the original rvm
    # branch was unreachable; it is now checked before the generic id_photo
    # marker. Redundant sub-conditions (endswith "_original.webp",
    # "_save_id_photo", "_celebrity_") were subsumed by their substring
    # checks and have been dropped without behavior change.
    markers = (
        ("_face_", "face"),
        ("_original", "original"),
        ("_restore", "restore"),
        ("_upcolor", "upcolor"),
        ("_compress", "compress"),
        ("_upscale", "upscale"),
        ("_anime_style_", "anime_style"),
        ("_grayscale", "grayscale"),
        ("_rvm_id_photo", "rvm"),
        ("_id_photo", "id_photo"),
        ("_grid_", "grid"),
        ("_celebrity", "celebrity"),
    )
    lower_name = filename.lower()
    for marker, category in markers:
        if marker in lower_name:
            return category
    return default
| from config import HOSTNAME | |
async def record_image_creation(
    *,
    file_path: str,
    nickname: Optional[str],
    score: float = 0.0,
    category: Optional[str] = None,
    bos_uploaded: bool = False,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Record image metadata in the database; silently skip if the DB fails.

    :param file_path: absolute or relative file path
    :param nickname: user nickname
    :param score: associated score
    :param category: file category; inferred from the filename when omitted
    :param bos_uploaded: whether the file was already uploaded to BOS
    :param extra_metadata: additional information
    """
    normalized = _normalize_file_path(file_path)
    if normalized is None:
        logger.info("record_image_creation: 无法计算文件名,路径=%s", file_path)
        return
    abs_path = os.path.join(_IMAGES_DIR_ABS, normalized)
    if not os.path.isfile(abs_path):
        logger.info("record_image_creation: 文件不存在,跳过记录 file=%s", abs_path)
        return
    try:
        file_stat = os.stat(abs_path)
        cleaned_nickname: Optional[str] = None
        if isinstance(nickname, str):
            stripped = nickname.strip()
            if stripped:
                cleaned_nickname = stripped
        await upsert_image_record(
            file_path=normalized,
            category=category or infer_category_from_filename(normalized),
            nickname=cleaned_nickname,
            score=score,
            # Heuristic: face crops carry a "_face_" marker plus at least
            # two underscores in the name.
            is_cropped_face="_face_" in normalized
            and normalized.count("_") >= 2,
            size_bytes=file_stat.st_size,
            # NOTE(review): naive local-time timestamp — confirm the DB
            # column expects local time rather than UTC.
            last_modified=datetime.fromtimestamp(file_stat.st_mtime),
            bos_uploaded=bos_uploaded,
            hostname=HOSTNAME,
            extra_metadata=extra_metadata,
        )
    except Exception as exc:
        # Best effort by design: DB unavailability must not break callers.
        logger.warning(f"写入图片记录失败: {exc}")