picpocket-anime-style / database.py
chenchaoyun
Fix device DB writes for existing schema
b681d82
import asyncio
import json
import os
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence
import aiomysql
from aiomysql.cursors import DictCursor
from config import (
IMAGES_DIR,
MYSQL_DB,
MYSQL_HOST,
MYSQL_PASSWORD,
MYSQL_POOL_MAX_SIZE,
MYSQL_POOL_MIN_SIZE,
MYSQL_PORT,
MYSQL_USER,
logger,
)
# Absolute, user-expanded form of IMAGES_DIR; stored image paths are made
# relative to this directory before being written to the database.
_IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))

# Serializes creation and teardown of the shared pool so that concurrent
# coroutines never build or close it twice.
_pool_lock = asyncio.Lock()

# Shared aiomysql pool: None until init_mysql_pool() creates it, and reset to
# None again by close_mysql_pool().
_pool: Optional[aiomysql.Pool] = None
async def init_mysql_pool() -> aiomysql.Pool:
    """Return the shared aiomysql pool, creating it on first call.

    The pool is checked once without the lock (cheap fast path) and again
    while holding ``_pool_lock``, so concurrent first callers create at most
    one pool. Connections are opened with autocommit, the utf8mb4 charset and
    ``DictCursor`` as the default cursor class.

    Returns:
        The module-wide pool instance.
    """
    global _pool
    if _pool is None:
        async with _pool_lock:
            # Another coroutine may have finished creating the pool while
            # this one was waiting for the lock.
            if _pool is None:
                pool_options = dict(
                    host=MYSQL_HOST,
                    port=MYSQL_PORT,
                    user=MYSQL_USER,
                    password=MYSQL_PASSWORD,
                    db=MYSQL_DB,
                    minsize=MYSQL_POOL_MIN_SIZE,
                    maxsize=MYSQL_POOL_MAX_SIZE,
                    autocommit=True,
                    charset="utf8mb4",
                    cursorclass=DictCursor,
                )
                _pool = await aiomysql.create_pool(**pool_options)
                logger.info("MySQL 连接池初始化成功,host=%s db=%s", MYSQL_HOST, MYSQL_DB)
    return _pool
async def close_mysql_pool() -> None:
    """Close the shared pool, if one exists, and wait until it has shut down.

    Does nothing when no pool has been created. Teardown happens under
    ``_pool_lock`` so it cannot interleave with pool creation. Afterwards the
    module-level pool reference is cleared.
    """
    global _pool
    if _pool is not None:
        async with _pool_lock:
            # Re-check: another coroutine may have closed it while we waited.
            if _pool is not None:
                closing_pool = _pool
                closing_pool.close()
                await closing_pool.wait_closed()
                _pool = None
                logger.info("MySQL 连接池已关闭")
@asynccontextmanager
async def get_connection():
    """Borrow a connection from the shared pool for the duration of an ``async with``.

    The pool is created lazily via ``init_mysql_pool()`` on first use. The
    connection is always handed back in a ``finally`` clause, even if the
    body raises.

    Yields:
        A pooled aiomysql connection (pool options are set by init_mysql_pool).
    """
    # Bind the pool once and reuse that object for both acquire and release.
    # Re-reading the mutable global after an await, or relying on an
    # ``assert`` (stripped under -O), could otherwise hit a different or a
    # None pool.
    pool = _pool if _pool is not None else await init_mysql_pool()
    conn = await pool.acquire()
    try:
        yield conn
    finally:
        pool.release(conn)
async def execute(query: str, params: Sequence[Any] | Dict[str, Any] | None = None) -> None:
async with get_connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, params or ())
async def upsert_device_record(
    *,
    device_id: str,
    device_type: Optional[str] = None,
    device_model: Optional[str] = None,
    os_version: Optional[str] = None,
    app_version: Optional[str] = None,
    region: Optional[str] = None,
    timezone: Optional[str] = None,
    language: Optional[str] = None,
) -> None:
    """Insert or refresh the ``tpl_app_user_devices`` row for ``device_id``.

    On a duplicate key the provided attributes are updated and ``updated_at``
    is bumped. An attribute passed as ``None`` is treated as "not supplied":
    the previously stored value is kept rather than overwritten with NULL.

    Args:
        device_id: Key of the device row.
        device_type, device_model, os_version, app_version, region,
        timezone, language: Optional device attributes to store.
    """
    # COALESCE(VALUES(col), col) keeps the existing column value when the
    # incoming value is NULL, so partial updates do not erase stored data.
    query = """
        INSERT INTO tpl_app_user_devices (
            device_id, device_type, device_model, os_version, app_version,
            region, timezone, language, updated_at
        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, CURRENT_TIMESTAMP)
        ON DUPLICATE KEY UPDATE
            device_type = COALESCE(VALUES(device_type), device_type),
            device_model = COALESCE(VALUES(device_model), device_model),
            os_version = COALESCE(VALUES(os_version), os_version),
            app_version = COALESCE(VALUES(app_version), app_version),
            region = COALESCE(VALUES(region), region),
            timezone = COALESCE(VALUES(timezone), timezone),
            language = COALESCE(VALUES(language), language),
            updated_at = CURRENT_TIMESTAMP
    """
    params = (
        device_id,
        device_type,
        device_model,
        os_version,
        app_version,
        region,
        timezone,
        language,
    )
    await execute(query, params)
def _serialize_extra(extra: Optional[Dict[str, Any]]) -> Optional[str]:
    """Serialize ``extra`` to JSON text for the ``extra_metadata`` column.

    Non-ASCII characters are written as-is (``ensure_ascii=False``).
    Serialization is best-effort. If ``extra`` holds a value ``json`` cannot
    encode, contains a circular reference, or is nested too deeply, a warning
    (with the underlying cause) is logged and None is returned.

    Returns:
        The JSON string, or None when ``extra`` is None or not serializable.
    """
    if extra is None:
        return None
    try:
        return json.dumps(extra, ensure_ascii=False)
    except (TypeError, ValueError, RecursionError):
        # TypeError: unsupported value/key type; ValueError: circular
        # reference; RecursionError: pathologically deep nesting.
        logger.warning("无法序列化 extra_metadata,已忽略", exc_info=True)
        return None
def _normalize_file_path(file_path: str) -> Optional[str]:
    """Convert ``file_path`` to the key stored in the database.

    Paths inside ``_IMAGES_DIR_ABS`` become a "/"-separated path relative to
    that directory. Paths outside it (or on another drive, where ``relpath``
    raises ValueError) fall back to their bare file name.

    Returns:
        The normalized path, or None when ``file_path`` is empty or no file
        name can be derived from it.
    """
    if not file_path:
        return None
    abs_path = os.path.abspath(os.path.expanduser(file_path))
    # Use the normalized absolute path so trailing separators are dropped;
    # an empty basename (e.g. the filesystem root) maps to None.
    fallback = os.path.basename(abs_path) or None
    try:
        rel_path = os.path.relpath(abs_path, _IMAGES_DIR_ABS)
    except ValueError:
        return fallback
    # Compare whole path components: a bare startswith("..") would also match
    # in-tree entries whose names merely begin with "..", e.g. "..cache/x.png".
    if rel_path == os.pardir or rel_path.startswith(os.pardir + os.sep):
        return fallback
    return rel_path.replace(os.sep, "/")
async def record_image_creation(
    *,
    file_path: str,
    nickname: Optional[str],
    score: float = 0.0,
    category: Optional[str] = None,
    bos_uploaded: bool = False,
    region: Optional[str] = None,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Upsert a ``tpl_app_processed_images`` row describing a generated image.

    The path is normalized relative to the images directory and must refer to
    an existing regular file there. Otherwise an info message is logged and
    nothing is written. Size and modification time are read from the file.

    Args:
        file_path: Path of the image file.
        nickname: Optional nickname; whitespace is stripped and blank becomes NULL.
        score: Score to store.
        category: Category; defaults to "anime_style" when falsy.
        bos_uploaded: Stored as 1/0.
        region: Optional region value.
        extra_metadata: Optional dict serialized to JSON (None if not serializable).
    """
    from stat import S_ISREG

    normalized = _normalize_file_path(file_path)
    if normalized is None:
        logger.info("record_image_creation: 无法计算文件名,路径=%s", file_path)
        return
    abs_path = os.path.join(_IMAGES_DIR_ABS, normalized)
    # A single stat() both verifies the path and reads size/mtime. A separate
    # isfile() check followed by stat() could raise FileNotFoundError if the
    # file is removed in between. OSError/ValueError mirror what isfile() tolerates.
    try:
        file_stat = os.stat(abs_path)
    except (OSError, ValueError):
        file_stat = None
    if file_stat is None or not S_ISREG(file_stat.st_mode):
        logger.info("record_image_creation: 文件不存在,跳过记录 file=%s", abs_path)
        return
    stripped_nickname = nickname.strip() if isinstance(nickname, str) else ""
    nickname_value = stripped_nickname or None
    query = """
        INSERT INTO tpl_app_processed_images (
            file_path, category, nickname, score, is_cropped_face, size_bytes,
            last_modified, bos_uploaded, region, extra_metadata
        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
            category = VALUES(category),
            nickname = VALUES(nickname),
            score = VALUES(score),
            is_cropped_face = VALUES(is_cropped_face),
            size_bytes = VALUES(size_bytes),
            last_modified = VALUES(last_modified),
            bos_uploaded = VALUES(bos_uploaded),
            region = VALUES(region),
            extra_metadata = VALUES(extra_metadata),
            updated_at = CURRENT_TIMESTAMP
    """
    await execute(
        query,
        (
            normalized,
            category or "anime_style",
            nickname_value,
            score,
            0,  # is_cropped_face: this code path always records 0
            file_stat.st_size,
            datetime.fromtimestamp(file_stat.st_mtime),
            1 if bos_uploaded else 0,
            region,
            _serialize_extra(extra_metadata),
        ),
    )