|
|
import asyncio |
|
|
import json |
|
|
import os |
|
|
from contextlib import asynccontextmanager |
|
|
from datetime import datetime |
|
|
from typing import Any, Dict, Iterable, List, Optional, Sequence |
|
|
|
|
|
import aiomysql |
|
|
from aiomysql.cursors import DictCursor |
|
|
|
|
|
from config import ( |
|
|
IMAGES_DIR, |
|
|
logger, |
|
|
MYSQL_HOST, |
|
|
MYSQL_PORT, |
|
|
MYSQL_DB, |
|
|
MYSQL_USER, |
|
|
MYSQL_PASSWORD, |
|
|
MYSQL_POOL_MIN_SIZE, |
|
|
MYSQL_POOL_MAX_SIZE, |
|
|
) |
|
|
|
|
|
# Process-wide aiomysql connection pool. It is created lazily by
# init_mysql_pool() and set back to None by close_mysql_pool().
_pool: Optional[aiomysql.Pool] = None
# Guards pool creation/teardown so that concurrent first callers build only
# one pool and close does not race with init.
_pool_lock: asyncio.Lock = asyncio.Lock()
|
|
|
|
|
|
|
|
async def init_mysql_pool() -> aiomysql.Pool:
    """Create the shared MySQL connection pool on first use and return it.

    Later calls return the already-created pool. Creation is checked once
    without the lock (fast path) and again while holding ``_pool_lock``, so
    concurrent first callers end up sharing a single pool.

    Connections are opened with ``autocommit=True``, ``utf8mb4`` and
    ``DictCursor`` so every query returns rows as dicts.

    Returns:
        The module-wide ``aiomysql`` pool.

    Raises:
        Whatever ``aiomysql.create_pool`` raises. The error is logged with its
        traceback and re-raised; ``_pool`` stays ``None`` so a later call
        can retry.
    """
    global _pool

    # Fast path: no lock needed once the pool exists.
    if _pool is not None:
        return _pool

    async with _pool_lock:
        # Another task may have created the pool while we waited for the lock.
        if _pool is not None:
            return _pool

        try:
            _pool = await aiomysql.create_pool(
                host=MYSQL_HOST,
                port=MYSQL_PORT,
                user=MYSQL_USER,
                password=MYSQL_PASSWORD,
                db=MYSQL_DB,
                minsize=MYSQL_POOL_MIN_SIZE,
                maxsize=MYSQL_POOL_MAX_SIZE,
                autocommit=True,
                charset="utf8mb4",
                cursorclass=DictCursor,
            )
        except Exception as exc:
            # logger.exception keeps the traceback that a plain f-string
            # logger.error call would discard; the message text is unchanged.
            logger.exception("初始化 MySQL 连接池失败: %s", exc)
            raise
        else:
            logger.info(
                "MySQL 连接池初始化成功,host=%s db=%s",
                MYSQL_HOST,
                MYSQL_DB,
            )

    return _pool
|
|
|
|
|
|
|
|
async def close_mysql_pool() -> None:
    """Close the shared MySQL connection pool, if one is open.

    ``wait_closed()`` is awaited after ``close()``. The module-level
    reference is always cleared, even when waiting fails or is cancelled, so
    a later ``init_mysql_pool()`` can build a fresh pool instead of handing
    out a closed one. Calling this when no pool exists is a no-op.
    """
    global _pool

    if _pool is None:
        return

    async with _pool_lock:
        # Re-check under the lock: a concurrent close may already have run.
        if _pool is None:
            return

        pool = _pool
        try:
            pool.close()
            await pool.wait_closed()
        finally:
            # Drop the reference unconditionally; otherwise an exception or
            # cancellation in wait_closed() would leave a dead pool installed.
            _pool = None

        logger.info("MySQL 连接池已关闭")
|
|
|
|
|
|
|
|
@asynccontextmanager
async def get_connection():
    """Borrow a connection from the shared pool for the duration of a block.

    Usage: ``async with get_connection() as conn: ...``. The pool is created
    on first use. The connection is released in ``finally``, so it returns
    to the pool even if the body raises.

    Raises:
        Whatever ``init_mysql_pool()`` or ``pool.acquire()`` raises.
    """
    # Bind the pool locally. init_mysql_pool() returns it, so no `assert`
    # (stripped under -O) is needed. The connection is also released to the
    # same pool it came from even if the module-level `_pool` is reset or
    # replaced while the connection is in use.
    pool = _pool if _pool is not None else await init_mysql_pool()
    conn = await pool.acquire()
    try:
        yield conn
    finally:
        pool.release(conn)
|
|
|
|
|
|
|
|
async def execute(query: str,
                  params: Sequence[Any] | Dict[str, Any] | None = None) -> None:
    """Run a write-type SQL statement on a pooled connection.

    The pool is created with autocommit enabled, so no explicit commit is
    issued here.

    Args:
        query: SQL text using ``%s`` / ``%(name)s`` placeholders.
        params: Positional or named bind values; any falsy value is sent as an
            empty tuple.
    """
    bound = params if params else ()
    async with get_connection() as conn, conn.cursor() as cur:
        await cur.execute(query, bound)
|
|
|
|
|
|
|
|
async def fetch_all(
    query: str, params: Sequence[Any] | Dict[str, Any] | None = None
) -> List[Dict[str, Any]]:
    """Run a query and return every result row as a list.

    Rows are dicts because the pool is configured with ``DictCursor``.

    Args:
        query: SQL text using ``%s`` / ``%(name)s`` placeholders.
        params: Positional or named bind values; any falsy value is sent as an
            empty tuple.
    """
    async with get_connection() as conn, conn.cursor() as cur:
        await cur.execute(query, () if not params else params)
        result = await cur.fetchall()
        return [*result]
|
|
|
|
|
|
|
|
def _serialize_extra(extra: Optional[Dict[str, Any]]) -> Optional[str]:
    """Encode *extra* as JSON text for the ``extra_metadata`` column.

    Non-ASCII characters are kept verbatim (``ensure_ascii=False``).

    Returns:
        The JSON string, or ``None`` when *extra* is ``None`` or cannot be
        serialized (a warning is logged in that case, and the error is
        otherwise ignored).
    """
    if extra is not None:
        try:
            return json.dumps(extra, ensure_ascii=False)
        except Exception:
            logger.warning("无法序列化 extra_metadata,已忽略")
    return None
|
|
|
|
|
|
|
|
async def upsert_image_record(
    *,
    file_path: str,
    category: str,
    nickname: Optional[str],
    score: float,
    is_cropped_face: bool,
    size_bytes: int,
    last_modified: datetime,
    bos_uploaded: bool,
    hostname: Optional[str] = None,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Insert an image row, or update it when the insert hits a duplicate key.

    On a duplicate key every column except ``file_path`` is overwritten and
    ``updated_at`` is bumped. The unique key is presumably on ``file_path``
    (TODO: confirm against the table schema). Boolean flags are stored as
    0/1, and ``extra_metadata`` is stored as JSON text, or NULL when it is
    absent or unserializable.
    """
    row = (
        file_path,
        category,
        nickname,
        score,
        int(bool(is_cropped_face)),
        size_bytes,
        last_modified,
        int(bool(bos_uploaded)),
        hostname,
        _serialize_extra(extra_metadata),
    )
    sql = """
        INSERT INTO tpl_app_processed_images (
            file_path,
            category,
            nickname,
            score,
            is_cropped_face,
            size_bytes,
            last_modified,
            bos_uploaded,
            hostname,
            extra_metadata
        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
            category = VALUES(category),
            nickname = VALUES(nickname),
            score = VALUES(score),
            is_cropped_face = VALUES(is_cropped_face),
            size_bytes = VALUES(size_bytes),
            last_modified = VALUES(last_modified),
            bos_uploaded = VALUES(bos_uploaded),
            hostname = VALUES(hostname),
            extra_metadata = VALUES(extra_metadata),
            updated_at = CURRENT_TIMESTAMP
    """
    await execute(sql, row)
|
|
|
|
|
|
|
|
async def fetch_paged_image_records(
    *,
    category: Optional[str],
    nickname: Optional[str],
    offset: int,
    limit: int,
) -> List[Dict[str, Any]]:
    """Return one page of image records, newest first.

    Args:
        category: Exact category to match; ``None``, ``""`` or ``"all"``
            disables the filter.
        nickname: Exact nickname to match; any falsy value disables the
            filter.
        offset: Number of matching rows to skip. Coerced to ``int`` and
            clamped at 0.
        limit: Maximum number of rows to return. Coerced to ``int`` and
            clamped at 0.

    Returns:
        Row dicts with the columns selected below, ordered by
        ``last_modified`` then ``id``, both descending.
    """
    where_clauses: List[str] = []
    params: List[Any] = []
    if category and category != "all":
        where_clauses.append("category = %s")
        params.append(category)
    if nickname:
        where_clauses.append("nickname = %s")
        params.append(nickname)

    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

    # MySQL only accepts non-negative integers for LIMIT/OFFSET. A negative
    # value, a float such as 10.0 or a numeric string bound as a parameter
    # produces invalid SQL, so normalise before binding.
    safe_limit = max(int(limit), 0)
    safe_offset = max(int(offset), 0)

    query = f"""
        SELECT
            file_path,
            category,
            nickname,
            score,
            is_cropped_face,
            size_bytes,
            last_modified,
            bos_uploaded,
            hostname
        FROM tpl_app_processed_images
        {where_sql}
        ORDER BY last_modified DESC, id DESC
        LIMIT %s OFFSET %s
    """
    params.extend([safe_limit, safe_offset])
    return await fetch_all(query, params)
|
|
|
|
|
|
|
|
async def count_image_records(
    *, category: Optional[str], nickname: Optional[str]
) -> int:
    """Count image records matching the optional category/nickname filters.

    ``category`` values that are falsy or equal to ``"all"`` disable the
    category filter; a falsy ``nickname`` disables the nickname filter.
    Returns 0 when the query yields no row or a NULL total.
    """
    filters = []
    if category and category != "all":
        filters.append(("category = %s", category))
    if nickname:
        filters.append(("nickname = %s", nickname))

    where_sql = ("WHERE " + " AND ".join(cond for cond, _ in filters)) if filters else ""
    args: List[Any] = [value for _, value in filters]

    query = f"SELECT COUNT(*) AS total FROM tpl_app_processed_images {where_sql}"
    rows = await fetch_all(query, args)
    return int(rows[0].get("total", 0) or 0) if rows else 0
|
|
|
|
|
|
|
|
async def fetch_today_category_counts() -> List[Dict[str, Any]]:
    """Count today's image records grouped by category.

    "Today" is the calendar day in the application host's local time. That
    is the same clock ``record_image_creation`` uses when it stores
    ``last_modified`` via ``datetime.fromtimestamp``. Previously the bounds
    came from MySQL's ``CURDATE()``, which follows the DB session time zone
    and miscounts whenever that differs from the app host's zone.

    Returns:
        A list of ``{"category": str, "count": int}`` dicts. A NULL or empty
        category is reported as ``"unknown"``.
    """
    from datetime import timedelta

    day_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
    day_end = day_start + timedelta(days=1)

    query = """
        SELECT
            COALESCE(category, 'unknown') AS category,
            COUNT(*) AS count
        FROM tpl_app_processed_images
        WHERE last_modified >= %s
          AND last_modified < %s
        GROUP BY COALESCE(category, 'unknown')
    """
    rows = await fetch_all(query, (day_start, day_end))
    return [
        {
            "category": str(row.get("category") or "unknown"),
            "count": int(row.get("count") or 0),
        }
        for row in rows
    ]
|
|
|
|
|
|
|
|
async def fetch_records_by_paths(file_paths: Iterable[str]) -> Dict[str, Dict[str, Any]]:
    """Look up image records for many stored file names in a single query.

    Empty/falsy names are ignored and duplicates are collapsed. Names with no
    matching row are simply absent from the result.

    Returns:
        A mapping of ``file_path`` to its row dict; ``{}`` when there is
        nothing to look up.
    """
    unique_paths = list(set(filter(None, file_paths)))
    if not unique_paths:
        return {}

    marks = ", ".join("%s" for _ in unique_paths)
    sql = f"""
        SELECT
            file_path,
            category,
            nickname,
            score,
            is_cropped_face,
            size_bytes,
            last_modified,
            bos_uploaded,
            hostname
        FROM tpl_app_processed_images
        WHERE file_path IN ({marks})
    """
    found = await fetch_all(sql, unique_paths)

    by_path: Dict[str, Dict[str, Any]] = {}
    for record in found:
        by_path[record["file_path"]] = record
    return by_path
|
|
|
|
|
|
|
|
# Absolute, user-expanded form of IMAGES_DIR, computed once at import time.
# File paths are normalized relative to this directory before being stored.
_IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))
|
|
|
|
|
|
|
|
def _normalize_file_path(file_path: str) -> Optional[str]:
    """Map *file_path* to the name stored in the ``file_path`` column.

    The path is ``~``-expanded and made absolute (relative paths resolve
    against the current working directory). Then:

    * a path inside IMAGES_DIR becomes its path relative to IMAGES_DIR, with
      forward slashes;
    * a path outside IMAGES_DIR collapses to its base name;
    * a directory, or any error while resolving, yields ``None``.
    """
    try:
        resolved = os.path.abspath(os.path.expanduser(file_path))
        if os.path.isdir(resolved):
            return None

        inside_root = (
            os.path.commonpath([_IMAGES_DIR_ABS, resolved]) == _IMAGES_DIR_ABS
        )
        if not inside_root:
            return os.path.basename(resolved)

        return os.path.relpath(resolved, _IMAGES_DIR_ABS).replace("\\", "/")
    except Exception:
        return None
|
|
|
|
|
|
|
|
def infer_category_from_filename(filename: str, default: str = "other") -> str:
    """Infer an image category from marker substrings in *filename*.

    Matching is case-insensitive and ordered: the first rule whose marker
    occurs anywhere in the name wins. If no marker matches, *default* is
    returned.

    Args:
        filename: File name or relative path to inspect.
        default: Category returned when nothing matches.
    """
    lower_name = filename.lower()

    # Ordered (marker, category) rules; the order matters when a name
    # contains several markers. Redundant alternatives from the original
    # chain were dropped because each was a superstring of the kept marker:
    # "_original.webp" -> "_original", "_save_id_photo" -> "_id_photo",
    # "_celebrity_" -> "_celebrity".
    rules = (
        ("_face_", "face"),
        ("_original", "original"),
        ("_restore", "restore"),
        ("_upcolor", "upcolor"),
        ("_compress", "compress"),
        ("_upscale", "upscale"),
        ("_anime_style_", "anime_style"),
        ("_grayscale", "grayscale"),
        # Must precede "_id_photo": every "_rvm_id_photo" name also contains
        # "_id_photo", so testing that first made "rvm" unreachable.
        ("_rvm_id_photo", "rvm"),
        ("_id_photo", "id_photo"),
        ("_grid_", "grid"),
        ("_celebrity", "celebrity"),
    )
    for marker, category in rules:
        if marker in lower_name:
            return category
    return default
|
|
|
|
|
|
|
|
from config import HOSTNAME |
|
|
|
|
|
async def record_image_creation(
    *,
    file_path: str,
    nickname: Optional[str],
    score: float = 0.0,
    category: Optional[str] = None,
    bos_uploaded: bool = False,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Record an image file's metadata in the database, best-effort.

    The path is normalized to its stored name and the file must exist under
    IMAGES_DIR; otherwise an info message is logged and nothing is written.
    Errors while stat-ing the file or writing the row (e.g. the database is
    unavailable) are logged as a warning and swallowed.

    Args:
        file_path: Absolute or relative path of the image file.
        nickname: User nickname; surrounding whitespace is stripped and a
            blank value is stored as NULL.
        score: Associated score.
        category: Category; inferred from the file name when not provided
            (or empty).
        bos_uploaded: Whether the file has already been uploaded to BOS.
        extra_metadata: Extra information, stored as JSON.
    """
    normalized = _normalize_file_path(file_path)
    if normalized is None:
        logger.info("record_image_creation: 无法计算文件名,路径=%s", file_path)
        return

    abs_path = os.path.join(_IMAGES_DIR_ABS, normalized)
    if not os.path.isfile(abs_path):
        logger.info("record_image_creation: 文件不存在,跳过记录 file=%s", abs_path)
        return

    try:
        file_stat = os.stat(abs_path)
        category_name = category or infer_category_from_filename(normalized)
        # "_face_" already contains two underscores, so the former extra
        # `count("_") >= 2` test was always true. The match is
        # case-insensitive so it agrees with infer_category_from_filename.
        is_cropped_face = "_face_" in normalized.lower()
        last_modified = datetime.fromtimestamp(file_stat.st_mtime)

        stripped = nickname.strip() if isinstance(nickname, str) else ""
        nickname_value = stripped or None

        await upsert_image_record(
            file_path=normalized,
            category=category_name,
            nickname=nickname_value,
            score=score,
            is_cropped_face=is_cropped_face,
            size_bytes=file_stat.st_size,
            last_modified=last_modified,
            bos_uploaded=bos_uploaded,
            hostname=HOSTNAME,
            extra_metadata=extra_metadata,
        )
    except Exception as exc:
        # Deliberately best-effort: metadata recording must not break callers.
        logger.warning("写入图片记录失败: %s", exc)
|
|
|