diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5fbf6507687aae5473bb1aac93d8141edcf322f8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,66 @@ +HELP.md +target/ +output/ +!.mvn/wrapper/maven-wrapper.jar +!**/src/main/**/target/ +!**/src/test/**/target/ + +.flattened-pom.xml + +### STS ### +.apt_generated +.classpath +.factorypath +.project +.settings +.springBeans +.sts4-cache + +### IntelliJ IDEA ### +.idea +*.iws +*.iml +*.ipr + +### NetBeans ### +/nbproject/private/ +/nbbuild/ +/dist/ +/nbdist/ +/.nb-gradle/ +build/ +!**/src/main/**/build/ +!**/src/test/**/build/ + +### VS Code ### +.vscode/ + +### LOG ### +logs/ + +*.class + +**/node_modules/ +/*.log +/output/ +/faiss/ +/web/facelist-web/ +**/._* +__pycache__/ +.DS_Store +*.pth +/data/celebrity_faces/ds_model_arcface_detector_retinaface_aligned_normalization_base_expand_0.pkl +/data/celebrity_faces/jpeg_6c06eca6.jpeg +/data/celebrity_faces/jpeg_51e1394b.jpeg +/data/celebrity_faces/jpeg_66fee390.jpeg +/data/celebrity_faces/jpeg_70b86102.jpeg +/data/celebrity_faces/jpeg_406b961a.jpeg +/data/celebrity_faces/jpeg_1321f87f.jpeg +/data/celebrity_faces/jpeg_b56ae384.jpeg +/data/celebrity_faces/jpeg_c07cdb46.jpeg +/data/celebrity_faces/jpeg_c7353005.jpeg +/data/celebrity_faces/jpeg_d4cb0602.jpeg +/data/celebrity_faces/jpeg_dbb64030.jpeg +/data/celebrity_faces/jpeg_fc652ad4.jpeg +/data/celebrity_faces/jpeg_fd6b0869.jpeg +/data/celebrity_embeddings.db diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..6f5acb0270c0fcded1e2a799a92b4376ae0e10b9 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,76 @@ +FROM python:3.10-slim + +ENV TZ=Asia/Shanghai \ + OUTPUT_DIR=/opt/data/output \ + IMAGES_DIR=/opt/data/images \ + MODELS_PATH=/opt/data/models \ + DEEPFACE_HOME=/opt/data/models \ + FAISS_INDEX_DIR=/opt/data/faiss \ + CELEBRITY_SOURCE_DIR=/opt/data/chinese_celeb_dataset \ + GENDER_CONFIDENCE=1 \ + UPSCALE_SIZE=2 \ + 
AGE_CONFIDENCE=0.1 \ + DRAW_SCORE=true \ + FACE_CONFIDENCE=0.7 \ + ENABLE_DDCOLOR=true \ + ENABLE_GFPGAN=true \ + ENABLE_REALESRGAN=true \ + ENABLE_ANIME_STYLE=true \ + ENABLE_RVM=true \ + ENABLE_REMBG=true \ + ENABLE_CLIP=false \ + CLEANUP_INTERVAL_HOURS=1 \ + CLEANUP_AGE_HOURS=1 \ + BEAUTY_ADJUST_GAMMA=0.8 \ + BEAUTY_ADJUST_MIN=1.0 \ + BEAUTY_ADJUST_MAX=9.0 \ + ENABLE_ANIME_PRELOAD=false \ + ENABLE_LOGGING=true \ + BEAUTY_ADJUST_ENABLED=true \ + RVM_LOCAL_REPO=/opt/data/models/RobustVideoMatting \ + RVM_WEIGHTS_PATH=/opt/data/models/torch/hub/checkpoints/rvm_resnet50.pth \ + RVM_MODEL=resnet50 \ + AUTO_INIT_GFPGAN=false \ + AUTO_INIT_DDCOLOR=false \ + AUTO_INIT_REALESRGAN=false \ + AUTO_INIT_ANIME_STYLE=false \ + AUTO_INIT_CLIP=false \ + AUTO_INIT_RVM=false \ + AUTO_INIT_REMBG=false \ + ENABLE_WARMUP=true \ + REALESRGAN_MODEL=realesr-general-x4v3 \ + CELEBRITY_FIND_THRESHOLD=0.87 \ + FEMALE_AGE_ADJUSTMENT=4 \ + HOSTNAME=HG + +RUN mkdir -p /opt/data/chinese_celeb_dataset /opt/data/faiss /opt/data/models /opt/data/images /opt/data/output +WORKDIR /app +COPY requirements.txt . 
class AnimeStylizer:
    """On-demand loader and runner for ModelScope portrait cartoonization models.

    Every supported style maps to one ModelScope model id (``model_configs``).
    Pipelines are created lazily by ``_load_model`` and cached in
    ``stylizers``. When the feature flag is off, or ModelScope cannot be
    imported, ``model_configs`` stays empty and ``is_available()`` is False.
    """

    # Style used whenever a caller passes an unknown style key.
    DEFAULT_STYLE = "disney"
    # Substring of a model id that marks a Stable Diffusion (text-to-image) model.
    _DIFFUSION_MARKER = "stable_diffusion"
    # Prompt sent to Stable Diffusion models (the 'sks style' form is required).
    _DIFFUSION_PROMPT = "sks style, cartoon style artwork"

    def __init__(self):
        """Create empty caches and register the styles if the feature is enabled."""
        start_time = time.perf_counter()
        self.stylizers = {}      # style key -> loaded ModelScope pipeline
        self.model_configs = {}  # style key -> {"model_id", "name", "description"}
        self.OutputKeys = None   # modelscope.outputs.OutputKeys once imported
        self.current_style = None
        self.current_stylizer = None

        # Only touch ModelScope when the Anime Style feature is switched on.
        from config import ENABLE_ANIME_STYLE
        if ENABLE_ANIME_STYLE:
            self._initialize_models()
        else:
            logger.info("Anime Style feature is disabled, skipping model initialization")

        init_time = time.perf_counter() - start_time
        if self.is_available():
            logger.info(f"AnimeStylizer initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"AnimeStylizer initialization completed but not available, time: {init_time:.3f}s")

    def _initialize_models(self):
        """Register all style configurations and optionally preload the models.

        Any failure (including a missing ModelScope/torch install) resets
        ``model_configs`` to ``{}`` so the instance reports itself unavailable.
        """
        try:
            logger.info("Initializing multiple Anime Style models (using ModelScope)...")

            # Compatibility patch: alias unsigned dtypes that the installed torch
            # lacks to their signed counterparts.
            # NOTE(review): presumably required by ModelScope or one of its
            # dependencies on older torch builds -- confirm.
            import torch
            for missing, substitute in (
                ("uint64", torch.int64),
                ("uint32", torch.int32),
                ("uint16", torch.int16),
            ):
                if not hasattr(torch, missing):
                    logger.info(f"Adding torch.{missing} compatibility patch...")
                    setattr(torch, missing, substitute)

            # pipeline/Tasks are imported only as an availability probe; the
            # actual pipelines are built later in _load_model.
            from modelscope.outputs import OutputKeys
            from modelscope.pipelines import pipeline  # noqa: F401
            from modelscope.utils.constant import Tasks  # noqa: F401

            self.OutputKeys = OutputKeys

            # All available models/styles.
            self.model_configs = {
                "handdrawn": {
                    "model_id": "iic/cv_unet_person-image-cartoon-handdrawn_compound-models",
                    "name": "手绘风格",
                    "description": "手绘动漫风格 - 传统手绘感觉,线条清晰"
                },
                "disney": {
                    "model_id": "iic/cv_unet_person-image-cartoon-3d_compound-models",
                    "name": "迪士尼风格",
                    "description": "迪士尼风格 - 立体感强,色彩鲜艳"
                },
                "illustration": {
                    "model_id": "iic/cv_unet_person-image-cartoon-sd-design_compound-models",
                    "name": "插画风格",
                    "description": "插画风格 - 现代插画设计感"
                },
                "artstyle": {
                    "model_id": "iic/cv_unet_person-image-cartoon-artstyle_compound-models",
                    "name": "艺术风格",
                    "description": "艺术风格 - 独特的艺术表现力"
                },
                "anime": {
                    "model_id": "iic/cv_unet_person-image-cartoon_compound-models",
                    "name": "二次元风格",
                    "description": "二次元风格 - 经典动漫角色风格"
                },
                "sketch": {
                    "model_id": "iic/cv_unet_person-image-cartoon-sketch_compound-models",
                    "name": "素描风格",
                    "description": "素描风格 - 黑白素描画效果"
                }
            }

            logger.info(f"Defined {len(self.model_configs)} anime style model configurations")
            logger.info("Models will be loaded on-demand when first used to save memory")

            # The preload flag is optional in config; absence means on-demand.
            try:
                from config import ENABLE_ANIME_PRELOAD
            except ImportError:
                logger.info("Anime style model preloading configuration not found, will be loaded on-demand when first used")
            else:
                if ENABLE_ANIME_PRELOAD:
                    logger.info("Enabling anime style model preloading...")
                    self.preload_models()
                else:
                    logger.info("Anime style model preloading is disabled, will be loaded on-demand when first used")

        except ImportError as e:
            logger.error(f"ModelScope module import failed: {e}")
            self.model_configs = {}
        except Exception as e:
            logger.error(f"Anime Style model initialization failed: {e}")
            self.model_configs = {}

    def _is_diffusion_model(self, config):
        """Return True when the config's model id denotes a Stable Diffusion model."""
        return self._DIFFUSION_MARKER in config["model_id"]

    def _load_model(self, style_type):
        """Load the pipeline for ``style_type`` if it is not cached yet.

        :param style_type: style key from ``model_configs``
        :return: True when the pipeline is (now) available, False otherwise
        """
        if style_type not in self.model_configs:
            logger.error(f"Unsupported style type: {style_type}")
            return False

        if style_type in self.stylizers:
            logger.info(f"Model {style_type} already loaded, using directly")
            return True

        try:
            from modelscope.pipelines import pipeline
            from modelscope.utils.constant import Tasks

            config = self.model_configs[style_type]
            logger.info(f"Loading {config['name']} model: {config['model_id']}")

            # Pick the task type that matches the model family.
            if self._is_diffusion_model(config):
                task_type = Tasks.text_to_image_synthesis
                logger.info("Using text_to_image_synthesis task type to load Stable Diffusion model")
            else:
                task_type = Tasks.image_portrait_stylization
                logger.info("Using image_portrait_stylization task type to load UNet model")

            self.stylizers[style_type] = pipeline(task_type, model=config["model_id"])

            logger.info(f"{config['name']} model loaded successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to load {style_type} model: {e}")
            return False

    def preload_models(self, style_types=None):
        """Eagerly load the given styles.

        :param style_types: a style key, a list of style keys, or None to
            preload every configured model
        """
        if not self.is_available():
            logger.warning("Anime Style module is not available, cannot preload models")
            return

        if style_types is None:
            style_types = list(self.model_configs.keys())
        elif isinstance(style_types, str):
            style_types = [style_types]

        logger.info(f"Starting to preload anime style models: {style_types}")

        successful_loads = []
        failed_loads = []

        for style_type in style_types:
            if style_type not in self.model_configs:
                logger.warning(f"Unknown style type: {style_type}, skipping preload")
                failed_loads.append(style_type)
                continue

            style_name = self.model_configs[style_type]['name']
            try:
                logger.info(f"Preloading model: {style_name} ({style_type})")
                if self._load_model(style_type):
                    successful_loads.append(style_type)
                    logger.info(f"✓ Successfully preloaded: {style_name}")
                else:
                    failed_loads.append(style_type)
                    logger.error(f"✗ Preload failed: {style_name}")
            except Exception as e:
                logger.error(f"✗ Exception occurred while preloading model {style_type}: {e}")
                failed_loads.append(style_type)

        if successful_loads:
            logger.info(f"Successfully preloaded models ({len(successful_loads)}): {successful_loads}")
        if failed_loads:
            logger.warning(f"Failed to preload models ({len(failed_loads)}): {failed_loads}")

        logger.info(f"Anime style model preloading completed, success: {len(successful_loads)}/{len(style_types)}")

    def get_loaded_models(self):
        """Return the list of style keys whose pipelines are already loaded."""
        return list(self.stylizers.keys())

    def is_model_loaded(self, style_type):
        """Return True if the pipeline for ``style_type`` is already loaded."""
        return style_type in self.stylizers

    def get_preload_status(self):
        """Summarise how many of the configured models are loaded.

        :return: dict with counts, a ratio string, a percentage rounded to one
            decimal, and the available / loaded / unloaded style keys
        """
        total_models = len(self.model_configs)
        loaded_models = len(self.stylizers)
        percentage = (loaded_models / total_models * 100) if total_models > 0 else 0

        return {
            "total_models": total_models,
            "loaded_models": loaded_models,
            "preload_ratio": f"{loaded_models}/{total_models}",
            "preload_percentage": round(percentage, 1),
            "available_styles": list(self.model_configs.keys()),
            "loaded_styles": list(self.stylizers.keys()),
            "unloaded_styles": [style for style in self.model_configs if style not in self.stylizers]
        }

    def is_available(self):
        """Return True when at least one style configuration is registered."""
        return len(getattr(self, 'model_configs', None) or {}) > 0

    def stylize_image(self, image, style_type="disney"):
        """Apply an anime style to an image.

        :param image: input image (numpy array, BGR)
        :param style_type: one of the keys of ``model_configs``:
            "handdrawn", "disney" (default), "illustration", "artstyle",
            "anime", "sketch"
        :return: stylized image (numpy array, BGR); the unmodified input when
            the stylizer is unavailable or the model cannot be loaded
        """
        if not self.is_available():
            logger.error("Anime Style model not initialized")
            return image

        if not self._load_model(style_type):
            logger.error(f"Failed to load {style_type} model")
            return image

        return self._stylize_image_via_file(image, style_type)

    def _extract_output_image(self, result, is_diffusion):
        """Pull the image out of a pipeline result dict.

        UNet pipelines use the standard ``OutputKeys.OUTPUT_IMG`` key; Stable
        Diffusion pipelines are probed for several known keys and finally for
        the first list-like or array-like value.

        :raises KeyError: when no usable image entry is found
        """
        if not is_diffusion:
            return result[self.OutputKeys.OUTPUT_IMG]

        logger.info(f"Stable Diffusion model output keys: {list(result.keys())}")
        if 'output_imgs' in result:
            return result['output_imgs'][0]
        if 'output_img' in result:
            return result['output_img']
        if self.OutputKeys.OUTPUT_IMG in result:
            return result[self.OutputKeys.OUTPUT_IMG]

        # Fall back to the first value that looks like an image.
        for key, value in result.items():
            if isinstance(value, (list, tuple)) and len(value) > 0:
                logger.info(f"Using output key: {key}")
                return value[0]
            if hasattr(value, 'shape'):
                logger.info(f"Using output key: {key}")
                return value
        raise KeyError(f"未找到有效的图像输出键,可用键: {list(result.keys())}")

    def _stylize_image_via_file(self, image, style_type="disney"):
        """Stylize ``image`` by round-tripping it through a temporary WebP file.

        :param image: input image (numpy array, BGR)
        :param style_type: style key; unknown keys fall back to ``DEFAULT_STYLE``
        :return: stylized image (numpy array, BGR), or the original image when
            anything goes wrong
        """
        try:
            # Validate first so the config and pipeline below always match the
            # style that is actually used.
            if style_type not in self.model_configs:
                logger.warning(f"Invalid style type: {style_type}, using default style disney")
                style_type = self.DEFAULT_STYLE

            config = self.model_configs[style_type]
            logger.info(f"Using anime stylization processing, style type: {config['name']} ({style_type})")

            # The fallback style may not have been loaded by the caller.
            if style_type not in self.stylizers and not self._load_model(style_type):
                logger.error(f"Failed to load {style_type} model")
                return image

            # Reserve a temp path and close the handle before OpenCV writes to
            # it (an open handle blocks the write on some platforms).
            with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as tmp_input:
                tmp_input_path = tmp_input.name

            try:
                # Highest WebP quality to avoid degrading the model input.
                if not cv2.imwrite(tmp_input_path, image, [cv2.IMWRITE_WEBP_QUALITY, 100]):
                    raise IOError(f"Failed to write temporary image: {tmp_input_path}")
                logger.info(f"Temporary file saved to: {tmp_input_path}")

                stylizer = self.stylizers[style_type]
                is_diffusion = self._is_diffusion_model(config)

                if is_diffusion:
                    # Stable Diffusion pipelines are text-driven.
                    logger.info("Using Stable Diffusion model, text parameter is required")
                    prompt = self._DIFFUSION_PROMPT
                    logger.info(f"Using prompt: {prompt}")
                    result = stylizer({"text": prompt})
                else:
                    # UNet pipelines take the image path directly.
                    result = stylizer(tmp_input_path)

                stylized_image = self._extract_output_image(result, is_diffusion)

                logger.info(f"Anime stylization output: size={stylized_image.shape}, type={stylized_image.dtype}")

                # The original author notes ModelScope output is already BGR,
                # so no colour conversion is done here.
                logger.info("Anime stylization processing completed")
                return stylized_image

            finally:
                # Best-effort cleanup of the temp file.
                try:
                    os.unlink(tmp_input_path)
                except OSError:
                    pass

        except Exception as e:
            logger.error(f"Anime stylization processing failed: {e}")
            logger.info("Returning original image")
            return image

    def get_available_styles(self):
        """Return ``{style key: "<name> - <short description>"}`` for all styles.

        The short description is the part of ``description`` after the first
        ``" - "``; a description without that separator is used as a whole.
        """
        styles = {}
        for style_type, config in (getattr(self, 'model_configs', None) or {}).items():
            description = config['description']
            _, separator, detail = description.partition(' - ')
            if separator:
                # Match the original split(' - ')[1]: keep only the 2nd segment.
                detail = detail.split(' - ')[0]
            else:
                detail = description
            styles[style_type] = f"{config['name']} - {detail}"
        return styles

    def save_debug_image(self, image, filename_prefix):
        """Write ``image`` to ``<filename_prefix>_debug.webp``.

        :return: the written path, or None when saving failed
        """
        try:
            debug_path = f"{filename_prefix}_debug.webp"
            cv2.imwrite(debug_path, image, [cv2.IMWRITE_WEBP_QUALITY, 95])
            logger.info(f"Debug image saved: {debug_path}")
            return debug_path
        except Exception as e:
            logger.error(f"Failed to save debug image: {e}")
            return None

    def test_stylization(self, test_url=None):
        """Smoke-test the default style against a sample image.

        :param test_url: image URL; defaults to the ModelScope sample portrait
        :return: ``(success, message)``
        """
        if not self.is_available():
            return False, "Anime Style模型未初始化"

        try:
            test_url = test_url or 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/portrait.jpg'
            logger.info(f"Testing anime stylization feature, using image: {test_url}")

            # Pipelines live in self.stylizers; load the default style first.
            if not self._load_model(self.DEFAULT_STYLE):
                raise RuntimeError(f"Failed to load {self.DEFAULT_STYLE} model")
            result = self.stylizers[self.DEFAULT_STYLE](test_url)
            stylized_img = result[self.OutputKeys.OUTPUT_IMG]

            test_output_path = 'anime_style_test_result.webp'
            cv2.imwrite(test_output_path, stylized_img, [cv2.IMWRITE_WEBP_QUALITY, 95])

            logger.info(f"Anime stylization test successful, result saved to: {test_output_path}")
            return True, f"测试成功,结果保存到: {test_output_path}"

        except Exception as e:
            logger.error(f"Anime stylization test failed: {e}")
            return False, f"测试失败: {e}"

    def test_local_image(self, image_path, style_type="disney"):
        """Stylize a local image and save both original and result for comparison.

        :param image_path: path of the image to read
        :param style_type: style key
        :return: ``(success, message)``
        """
        if not self.is_available():
            return False, "Anime Style模型未初始化"

        try:
            logger.info(f"Testing local image anime stylization: {image_path}, style: {style_type}")

            image = cv2.imread(image_path)
            if image is None:
                return False, f"Unable to read image: {image_path}"

            # Keep a copy of the input next to the result.
            self.save_debug_image(image, "original")

            stylized_image = self.stylize_image(image, style_type)
            result_path = self.save_debug_image(stylized_image, f"anime_style_{style_type}")

            logger.info(f"Local image anime stylization successful, result saved to: {result_path}")
            return True, f"本地图像动漫风格化成功,结果保存到: {result_path}"

        except Exception as e:
            logger.error(f"Local image anime stylization failed: {e}")
            return False, f"本地图像动漫风格化失败: {e}"
SERVER_HOSTNAME = os.environ.get("HOSTNAME", "")

# Try to import DeepFace and monkey-patch it so that Keras 3 / tf_keras
# SymbolicTensor failures are retried or routed through a fallback path.
deepface_module = None
if DEEPFACE_AVAILABLE:
    t_start = time.perf_counter()

    try:
        from deepface import DeepFace
        deepface_module = DeepFace

        # Compatibility wrapper around DeepFace.verify.
        _original_verify = getattr(DeepFace, 'verify', None)

        if _original_verify:
            @functools.wraps(_original_verify)
            def _wrapped_verify(*args, **kwargs):
                """Call DeepFace.verify, retrying once after a model reset.

                Only errors whose message mentions "numpy" (AttributeError) or
                "SymbolicTensor"/"numpy" (any other exception) trigger the
                reset-and-retry; everything else propagates unchanged.
                """
                try:
                    return _original_verify(*args, **kwargs)
                except AttributeError as attr_err:
                    if "numpy" not in str(attr_err):
                        raise
                    logger.warning("DeepFace verify 触发 numpy AttributeError,尝试清理模型后重试")
                    _recover_deepface_model()
                    return _original_verify(*args, **kwargs)
                except Exception as generic_exc:
                    if "SymbolicTensor" not in str(generic_exc) and "numpy" not in str(generic_exc):
                        raise
                    logger.warning(
                        f"DeepFace verify 触发 SymbolicTensor 异常({generic_exc}), 尝试清理模型后重试"
                    )
                    _recover_deepface_model()
                    return _original_verify(*args, **kwargs)

            DeepFace.verify = _wrapped_verify
            logger.info("Patched DeepFace.verify for SymbolicTensor compatibility")

        try:
            from deepface.models import FacialRecognition as df_facial_recognition

            _original_forward = df_facial_recognition.FacialRecognition.forward

            def _safe_tensor_to_numpy(output_obj):
                """Best-effort conversion of a tensor / nested sequence to numpy.

                Returns None when nothing convertible is found. For lists and
                tuples the first convertible element wins (DeepFace only uses
                the first output).
                """
                if output_obj is None:
                    return None
                if hasattr(output_obj, "numpy"):
                    try:
                        return output_obj.numpy()
                    except Exception:
                        return None
                if isinstance(output_obj, np.ndarray):
                    return output_obj
                if isinstance(output_obj, (list, tuple)):
                    for item in output_obj:
                        result = _safe_tensor_to_numpy(item)
                        if result is not None:
                            return result
                return None

            @functools.wraps(_original_forward)
            def _patched_forward(self, img):
                """Run the original forward; on SymbolicTensor errors fall back.

                The fallback calls ``self.model(img, training=False)`` and, if
                that yields nothing convertible, ``self.model.predict``.
                """
                try:
                    return _original_forward(self, img)
                except AttributeError as attr_err:
                    if "numpy" not in str(attr_err):
                        raise
                    logger.warning("DeepFace 原始 forward 触发 numpy AttributeError,启用兼容路径")
                except Exception as generic_exc:
                    if "SymbolicTensor" not in str(generic_exc) and "numpy" not in str(generic_exc):
                        raise
                    logger.warning(
                        f"DeepFace 原始 forward 触发 SymbolicTensor 异常({generic_exc}), 启用兼容路径"
                    )

                # Add a batch axis to a single image.
                if img.ndim == 3:
                    img = np.expand_dims(img, axis=0)

                if img.ndim != 4:
                    raise ValueError(
                        f"Input image must be (N, X, X, 3) shaped but it is {img.shape}"
                    )

                embeddings = None
                try:
                    outputs = self.model(img, training=False)
                    embeddings = _safe_tensor_to_numpy(outputs)
                except Exception as call_exc:
                    logger.info(f"DeepFace forward fallback self.model 调用失败,改用 predict: {call_exc}")

                if embeddings is None:
                    # Under Keras 3, self.model(...) may return a SymbolicTensor;
                    # fall back to predict.
                    predict_fn = getattr(self.model, "predict", None)
                    if predict_fn is None:
                        raise RuntimeError("DeepFace model 没有 predict 方法,无法转换 SymbolicTensor")
                    embeddings = predict_fn(img, verbose=0)

                embeddings = np.asarray(embeddings)
                if embeddings.ndim == 0:
                    raise ValueError("Embeddings output is empty.")

                # A single-image batch returns one flat embedding list.
                if embeddings.shape[0] == 1:
                    return embeddings[0].tolist()
                return embeddings.tolist()

            df_facial_recognition.FacialRecognition.forward = _patched_forward
            logger.info("Patched DeepFace FacialRecognition.forward for SymbolicTensor compatibility")
        except Exception as patch_exc:
            # The forward patch is optional; DeepFace stays usable without it.
            logger.warning(f"Failed to patch DeepFace forward method: {patch_exc}")
        logger.info("DeepFace module imported successfully")
    except ImportError as e:
        logger.error(f"Failed to import DeepFace: {e}")
        DEEPFACE_AVAILABLE = False

# Module initialisation log.
logger.info("Starting initialization of api_routes module...")
logger.info(f"Configuration status - GFPGAN: {GFPGAN_AVAILABLE}, DDCOLOR: {DDCOLOR_AVAILABLE}, REALESRGAN: {REALESRGAN_AVAILABLE}, REMBG: {REMBG_AVAILABLE}, CLIP: {CLIP_AVAILABLE}, ANIME_STYLE: {ANIME_STYLE_AVAILABLE}")
# Initialise the CLIP text/image retrieval hooks; they stay None when CLIP is
# disabled or its helper modules cannot be imported.
clip_encode_image = None
clip_encode_text = None
add_image_vector = None
search_text_vector = None
check_image_exists = None

if CLIP_AVAILABLE:
    try:
        from clip_utils import encode_image, encode_text
        from vector_store import add_image_vector, search_text_vector, check_image_exists
        clip_encode_image = encode_image
        clip_encode_text = encode_text
        logger.info("CLIP text-image retrieval function initialized successfully")
    except Exception as e:
        logger.error(f"CLIP function import failed: {e}")
        CLIP_AVAILABLE = False

# Thread pool used to run CPU-heavy work off the event loop.
executor = ThreadPoolExecutor(max_workers=4)


def _log_stage_duration(stage: str, start_time: float, extra: str | None = None) -> float:
    """Log how long a stage took and return the elapsed seconds.

    :param stage: label of the stage being timed
    :param start_time: ``time.perf_counter()`` value taken when the stage began
    :param extra: optional detail appended to the log line
    """
    elapsed = time.perf_counter() - start_time
    if extra:
        logger.info("耗时统计 | %s: %.3fs (%s)", stage, elapsed, extra)
    else:
        logger.info("耗时统计 | %s: %.3fs", stage, elapsed)
    return elapsed


async def process_cpu_intensive_task(func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` in the module thread pool and await it.

    :param func: callable to execute
    :param args: positional arguments for ``func``
    :param kwargs: keyword arguments for ``func``
    :return: whatever ``func`` returns
    """
    # Inside a coroutine the running loop is the correct one to use;
    # get_event_loop() is the legacy spelling.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, functools.partial(func, *args, **kwargs))


def _keep_cpu_busy(duration: float, inner_loops: int = 5000) -> Dict[str, Any]:
    """Burn CPU for ``duration`` seconds to keep the server out of idle state.

    :param duration: wall-clock seconds to stay busy; <= 0 returns immediately
    :param inner_loops: number of rotate/xor rounds per outer iteration
    :return: dict with the iteration count, a 64-bit checksum and elapsed time
    """
    if duration <= 0:
        return {"iterations": 0, "checksum": 0, "elapsed": 0.0}

    end_time = time.perf_counter() + duration
    iterations = 0
    checksum = 0
    mask = (1 << 64) - 1  # keep the checksum within 64 bits
    start = time.perf_counter()

    while time.perf_counter() < end_time:
        iterations += 1
        payload = f"{iterations}-{checksum}".encode("utf-8")
        digest = hashlib.sha256(payload).digest()
        checksum ^= int.from_bytes(digest[:8], "big")
        checksum &= mask

        for _ in range(inner_loops):
            # 64-bit rotate-left by 7, then xor with a fixed pattern.
            checksum = ((checksum << 7) | (checksum >> 57)) & mask
            checksum ^= 0xA5A5A5A5A5A5A5A5

    return {
        "iterations": iterations,
        "checksum": checksum,
        "elapsed": time.perf_counter() - start,
    }

deepface_call_lock: Optional[asyncio.Lock] = None


def _ensure_deepface_lock() -> asyncio.Lock:
    """Return the shared DeepFace call lock, creating it on first use."""
    global deepface_call_lock
    if deepface_call_lock is None:
        deepface_call_lock = asyncio.Lock()
    return deepface_call_lock


def _clear_keras_session() -> bool:
    """Clear the Keras session; return True only when it was actually cleared."""
    if keras_backend is None:
        return False
    try:
        keras_backend.clear_session()
        return True
    except Exception as exc:
        logger.warning(f"清理Keras会话失败: {exc}")
        return False


def _reset_deepface_model_cache(model_name: str = "ArcFace") -> None:
    """Drop ``model_name`` from DeepFace's internal model caches, if present.

    Looks for dict attributes named ``models``, ``model_cache`` and
    ``built_models`` on ``deepface.commons.functions``; does nothing when
    DeepFace or that module is unavailable.
    """
    if deepface_module is None:
        return
    try:
        from deepface.commons import functions
    except Exception as exc:
        logger.warning(
            f"无法导入deepface.commons.functions,跳过模型缓存重置: {exc}")
        return

    removed = False
    for attr_name in ("models", "model_cache", "built_models"):
        cache = getattr(functions, attr_name, None)
        if isinstance(cache, dict) and model_name in cache:
            cache.pop(model_name, None)
            removed = True
    if removed:
        logger.info(f"已清除DeepFace缓存模型: {model_name}")


def _recover_deepface_model(model_name: str = "ArcFace") -> None:
    """Clear the Keras session and the DeepFace model cache for ``model_name``."""
    cleared = _clear_keras_session()
    _reset_deepface_model_cache(model_name)
    if cleared:
        logger.info(f"Keras会话已清理,将在下次调用时重新加载模型: {model_name}")
# ---------------------------------------------------------------------------
# Eager initialisation of the optional image processors.
# In every stanza t_start is taken BEFORE the import so the except branch
# always has a valid start time, even when the import itself fails (the
# reported time therefore includes the import).
# ---------------------------------------------------------------------------

# Photo restorer (GFPGAN). Eager start-up initialisation is configurable.
photo_restorer = None
restorer_type = "none"

if GFPGAN_AVAILABLE and AUTO_INIT_GFPGAN:
    t_start = time.perf_counter()
    try:
        from gfpgan_restorer import GFPGANRestorer
        photo_restorer = GFPGANRestorer()
        init_time = time.perf_counter() - t_start
        if photo_restorer.is_available():
            restorer_type = "gfpgan"
            logger.info(f"GFPGAN restorer initialized successfully, time: {init_time:.3f}s")
        else:
            photo_restorer = None
            logger.info(f"GFPGAN restorer initialization completed but not available, time: {init_time:.3f}s")
    except Exception as e:
        init_time = time.perf_counter() - t_start
        logger.error(f"Failed to initialize GFPGAN restorer, time: {init_time:.3f}s, error: {e}")
        photo_restorer = None
else:
    logger.info("GFPGAN restorer is set to lazy initialization or unavailable")

# DDColor colorizer.
ddcolor_colorizer = None
if DDCOLOR_AVAILABLE and AUTO_INIT_DDCOLOR:
    t_start = time.perf_counter()
    try:
        from ddcolor_colorizer import DDColorColorizer
        ddcolor_colorizer = DDColorColorizer()
        init_time = time.perf_counter() - t_start
        if ddcolor_colorizer.is_available():
            logger.info(f"DDColor colorizer initialized successfully, time: {init_time:.3f}s")
        else:
            ddcolor_colorizer = None
            logger.info(f"DDColor colorizer initialization completed but not available, time: {init_time:.3f}s")
    except Exception as e:
        init_time = time.perf_counter() - t_start
        logger.error(f"Failed to initialize DDColor colorizer, time: {init_time:.3f}s, error: {e}")
        ddcolor_colorizer = None
else:
    logger.info("DDColor colorizer is set to lazy initialization or unavailable")

# Without GFPGAN the service cannot offer photo restoration.
if photo_restorer is None:
    logger.warning("Photo restoration feature unavailable: GFPGAN initialization failed")

if ddcolor_colorizer is None:
    if DDCOLOR_AVAILABLE:
        logger.warning("Photo colorization feature unavailable: DDColor initialization failed")
    else:
        logger.info("Photo colorization feature not enabled or unavailable")

# Real-ESRGAN super-resolution processor.
realesrgan_upscaler = None
if REALESRGAN_AVAILABLE and AUTO_INIT_REALESRGAN:
    t_start = time.perf_counter()
    try:
        from realesrgan_upscaler import get_upscaler
        realesrgan_upscaler = get_upscaler()
        init_time = time.perf_counter() - t_start
        if realesrgan_upscaler.is_available():
            logger.info(f"Real-ESRGAN super resolution processor initialized successfully, time: {init_time:.3f}s")
        else:
            realesrgan_upscaler = None
            logger.info(f"Real-ESRGAN super resolution processor initialization completed but not available, time: {init_time:.3f}s")
    except Exception as e:
        init_time = time.perf_counter() - t_start
        logger.error(f"Failed to initialize Real-ESRGAN super resolution processor, time: {init_time:.3f}s, error: {e}")
        realesrgan_upscaler = None
else:
    logger.info("Real-ESRGAN super resolution processor is set to lazy initialization or unavailable")

if realesrgan_upscaler is None:
    if REALESRGAN_AVAILABLE:
        logger.warning("Photo super resolution feature unavailable: Real-ESRGAN initialization failed")
    else:
        logger.info("Photo super resolution feature not enabled or unavailable")

# rembg background-removal processor.
rembg_processor = None
if REMBG_AVAILABLE and AUTO_INIT_REMBG:
    t_start = time.perf_counter()
    try:
        from rembg_processor import RembgProcessor
        rembg_processor = RembgProcessor()
        init_time = time.perf_counter() - t_start
        if rembg_processor.is_available():
            logger.info(f"rembg background removal processor initialized successfully, time: {init_time:.3f}s")
        else:
            rembg_processor = None
            logger.info(f"rembg background removal processor initialization completed but not available, time: {init_time:.3f}s")
    except Exception as e:
        init_time = time.perf_counter() - t_start
        logger.error(f"Failed to initialize rembg background removal processor, time: {init_time:.3f}s, error: {e}")
        rembg_processor = None
else:
    logger.info("rembg background removal processor is set to lazy initialization or unavailable")

if rembg_processor is None:
    if REMBG_AVAILABLE:
        logger.warning("ID photo background removal feature unavailable: rembg initialization failed")
    else:
        logger.info("ID photo background removal feature not enabled or unavailable")

# RVM background-removal processor.
rvm_processor = None
if RVM_AVAILABLE and AUTO_INIT_RVM:
    t_start = time.perf_counter()
    try:
        from rvm_processor import RVMProcessor
        rvm_processor = RVMProcessor()
        init_time = time.perf_counter() - t_start
        if rvm_processor.is_available():
            logger.info(f"RVM background removal processor initialized successfully, time: {init_time:.3f}s")
        else:
            rvm_processor = None
            logger.info(f"RVM background removal processor initialization completed but not available, time: {init_time:.3f}s")
    except Exception as e:
        init_time = time.perf_counter() - t_start
        logger.error(f"Failed to initialize RVM background removal processor, time: {init_time:.3f}s, error: {e}")
        rvm_processor = None
else:
    logger.info("RVM background removal processor is set to lazy initialization or unavailable")

if rvm_processor is None:
    if RVM_AVAILABLE:
        logger.warning("RVM background removal feature unavailable: initialization failed")
    else:
        logger.info("RVM background removal feature not enabled or unavailable")

# Anime stylization processor.
anime_stylizer = None
if ANIME_STYLE_AVAILABLE and AUTO_INIT_ANIME_STYLE:
    t_start = time.perf_counter()
    try:
        from anime_stylizer import AnimeStylizer
        anime_stylizer = AnimeStylizer()
        init_time = time.perf_counter() - t_start
        if anime_stylizer.is_available():
            logger.info(f"Anime stylization processor initialized successfully, time: {init_time:.3f}s")
        else:
            anime_stylizer = None
            logger.info(f"Anime stylization processor initialization completed but not available, time: {init_time:.3f}s")
    except Exception as e:
        init_time = time.perf_counter() - t_start
        logger.error(f"Failed to initialize anime stylization processor, time: {init_time:.3f}s, error: {e}")
        anime_stylizer = None
else:
    logger.info("Anime stylization processor is set to lazy initialization or unavailable")

if anime_stylizer is None:
    if ANIME_STYLE_AVAILABLE:
        logger.warning("Anime stylization feature unavailable: AnimeStylizer initialization failed")
    else:
        logger.info("Anime stylization feature not enabled or unavailable")
def _ensure_analyzer():
    """Create the module-level face analyzer on first use.

    Does nothing when ``analyzer`` is already set. A construction failure is
    logged and leaves ``analyzer`` as None so callers can detect it.
    """
    global analyzer
    if analyzer is not None:
        return
    try:
        analyzer = EnhancedFaceAnalyzer()
        logger.info("Face analyzer delayed initialization successful")
    except Exception as init_error:
        logger.error(f"Failed to initialize analyzer: {init_error}")
        analyzer = None
def _ensure_realesrgan():
    """Lazily create the Real-ESRGAN upscaler if the feature is available."""
    global realesrgan_upscaler
    if realesrgan_upscaler is not None or not REALESRGAN_AVAILABLE:
        return
    try:
        from realesrgan_upscaler import get_upscaler
        realesrgan_upscaler = get_upscaler()
        ready = realesrgan_upscaler.is_available()
        if ready:
            logger.info("Real-ESRGAN super resolution processor delayed initialization successful")
    except Exception as err:
        logger.error(f"Real-ESRGAN super resolution processor delayed initialization failed: {err}")


def _ensure_rembg():
    """Lazily create the rembg background-removal processor if available."""
    global rembg_processor
    if rembg_processor is not None or not REMBG_AVAILABLE:
        return
    try:
        from rembg_processor import RembgProcessor
        rembg_processor = RembgProcessor()
        ready = rembg_processor.is_available()
        if ready:
            logger.info("rembg background removal processor delayed initialization successful")
    except Exception as err:
        logger.error(f"rembg background removal processor delayed initialization failed: {err}")


def _ensure_rvm():
    """Lazily create the RVM background-removal processor if available."""
    global rvm_processor
    if rvm_processor is not None or not RVM_AVAILABLE:
        return
    try:
        from rvm_processor import RVMProcessor
        rvm_processor = RVMProcessor()
        ready = rvm_processor.is_available()
        if ready:
            logger.info("RVM background removal processor delayed initialization successful")
    except Exception as err:
        logger.error(f"RVM background removal processor delayed initialization failed: {err}")


def _ensure_anime_stylizer():
    """Lazily create the anime stylization processor if available."""
    global anime_stylizer
    if anime_stylizer is not None or not ANIME_STYLE_AVAILABLE:
        return
    try:
        from anime_stylizer import AnimeStylizer
        anime_stylizer = AnimeStylizer()
        ready = anime_stylizer.is_available()
        if ready:
            logger.info("Anime stylization processor delayed initialization successful")
    except Exception as err:
        logger.error(f"Anime stylization processor delayed initialization failed: {err}")
# File-name suffixes (lower-case) accepted as celebrity images.
_CELEBRITY_IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png", ".webp", ".bmp")


def _encode_basename(name: str) -> str:
    """Return *name* as unpadded URL-safe Base64 text.

    The UTF-8 bytes of *name* are encoded and the trailing ``=`` padding is
    stripped, so the result only contains ``A-Z a-z 0-9 - _``.
    """
    encoded = base64.urlsafe_b64encode(name.encode("utf-8")).decode("ascii")
    return encoded.rstrip("=")


def _decode_basename(encoded: str) -> str:
    """Reverse :func:`_encode_basename`; return *encoded* unchanged on failure.

    Padding removed by the encoder is restored before decoding. Input that is
    not ASCII, not valid Base64, or does not decode to valid UTF-8 is returned
    as-is (best effort, never raises for ``str`` input).
    """
    # Base64 text length must be a multiple of 4; re-add the stripped "=".
    padding = "=" * (-len(encoded) % 4)
    try:
        raw = base64.urlsafe_b64decode((encoded + padding).encode("ascii"))
        return raw.decode("utf-8")
    except ValueError:
        # UnicodeEncodeError, binascii.Error and UnicodeDecodeError are all
        # ValueError subclasses, so this covers every decode failure without
        # hiding unrelated programming errors.
        return encoded


def _iter_celebrity_images(base_dir: str) -> List[str]:
    """Collect paths of all image files found under *base_dir*, recursively.

    Hidden files (name starting with ``.``) are skipped, and only names whose
    lower-cased form ends with one of ``_CELEBRITY_IMAGE_SUFFIXES`` are kept.
    Paths are returned in ``os.walk`` order.
    """
    images: List[str] = []
    for root, _, files in os.walk(base_dir):
        images.extend(
            os.path.join(root, filename)
            for filename in files
            if not filename.startswith(".")
            and filename.lower().endswith(_CELEBRITY_IMAGE_SUFFIXES)
        )
    return images
# Strong references to in-flight record-writing tasks. The event loop only
# keeps weak references to tasks, so a fire-and-forget task that nothing else
# references may be garbage-collected before it finishes.
_PENDING_RECORD_TASKS: set = set()


async def _record_output_file(
    file_path: str,
    nickname: Optional[str],
    *,
    category: Optional[str] = None,
    bos_uploaded: bool = False,
    score: Optional[float] = None,
    extra: Optional[Dict[str, Any]] = None,
) -> None:
    """Schedule a best-effort database record for a generated image file.

    The write runs as a background task so it never blocks or fails the
    calling request; any error raised while writing is logged as a warning.

    Args:
        file_path: Path of the image file to record.
        nickname: Operator nickname stored with the record, if any.
        category: Image category; passed through unchanged (may be None).
        bos_uploaded: Whether the file has been uploaded to BOS.
        score: Score stored with the record; values that cannot be converted
            to ``float`` (and ``None``) fall back to ``0.0``.
        extra: Additional metadata forwarded as ``extra_metadata``.
    """
    try:
        score_value = float(score) if score is not None else 0.0
    except (TypeError, ValueError):
        logger.warning("score 转换失败,已回退为 0,file=%s raw_score=%r",
                       file_path, score)
        score_value = 0.0

    async def _write_record() -> None:
        start_time = time.perf_counter()
        try:
            await record_image_creation(
                file_path=file_path,
                nickname=nickname,
                category=category,
                bos_uploaded=bos_uploaded,
                score=score_value,
                extra_metadata=extra,
            )
            duration = time.perf_counter() - start_time
            logger.info(
                "MySQL记录完成 file=%s category=%s nickname=%s score=%.4f bos_uploaded=%s cost=%.3fs",
                os.path.basename(file_path),
                category or "auto",
                nickname or "",
                score_value,
                bos_uploaded,
                duration,
            )
        except Exception as exc:
            logger.warning(f"记录图片到数据库失败: {exc}")

    # Keep the task alive until it completes, then drop the reference.
    task = asyncio.create_task(_write_record())
    _PENDING_RECORD_TASKS.add(task)
    task.add_done_callback(_PENDING_RECORD_TASKS.discard)
async def _refresh_celebrity_cache(sample_image_path: str,
                                   db_path: str) -> None:
    """Refresh the DeepFace database cache for the celebrity image library.

    Runs ``deepface_module.find`` with ``refresh_database=True`` against
    *db_path*, using *sample_image_path* as the query image, while holding
    the DeepFace lock. Returns silently when DeepFace is unavailable, the
    sample image does not exist, or *db_path* is not a directory.

    Error handling:
        * ``AttributeError``/``RuntimeError`` whose message mentions
          ``numpy`` or ``SymbolicTensor``, and any ``ValueError``: the model
          is recovered and the refresh is retried once; a failed retry is
          logged as a warning.
        * Other ``AttributeError``/``RuntimeError``: re-raised to the caller.
        * Any other exception: logged as a warning and swallowed.
    """
    if not DEEPFACE_AVAILABLE or deepface_module is None:
        return

    if not os.path.exists(sample_image_path):
        return

    if not os.path.isdir(db_path):
        return

    async def _run_find() -> None:
        # Single definition of the refresh call so the first attempt and the
        # retries can never drift apart.
        await process_cpu_intensive_task(
            deepface_module.find,
            img_path=sample_image_path,
            db_path=db_path,
            model_name="ArcFace",
            detector_backend="yolov11n",
            distance_metric="cosine",
            enforce_detection=True,
            silent=True,
            refresh_database=True,
        )

    async def _recover_and_retry() -> None:
        # Reset the model state, then try exactly once more.
        _recover_deepface_model()
        try:
            await _run_find()
        except Exception as retry_exc:
            logger.warning(f"恢复后重新刷新明星缓存仍失败: {retry_exc}")

    lock = _ensure_deepface_lock()
    async with lock:
        try:
            await _run_find()
        except (AttributeError, RuntimeError) as attr_exc:
            if "numpy" in str(attr_exc) or "SymbolicTensor" in str(attr_exc):
                logger.warning(
                    f"刷新明星向量缓存遇到 numpy/SymbolicTensor 异常,尝试恢复后重试: {attr_exc}")
                await _recover_and_retry()
            else:
                # NOTE(review): this propagates to the caller; it is not
                # caught by the generic handler below. Confirm that callers
                # expect it.
                raise
        except ValueError as exc:
            logger.warning(
                f"刷新明星向量缓存遇到模型状态异常,尝试恢复后重试: {exc}")
            await _recover_and_retry()
        except Exception as e:
            logger.warning(f"Refresh celebrity cache failed: {e}")
def log_api_params(func):
    """Decorator that logs the bound arguments of an API handler.

    Works for both coroutine and plain functions. Before each call the
    arguments are bound to the handler's signature, reduced to a JSON-safe
    summary and logged at INFO level. Uploaded files are logged as file name,
    size and content type only; request objects are logged without any
    header, URL or client details. Logging failures never prevent the
    wrapped handler from running.
    """
    sig = inspect.signature(func)
    is_coro = inspect.iscoroutinefunction(func)

    def _is_upload_file(obj: Any) -> bool:
        """Return True when *obj* looks like an uploaded file."""
        try:
            if obj is None:
                return False
            if isinstance(obj, (bytes, bytearray, str)):
                return False
            if isinstance(obj, UploadFile):
                return True
            if StarletteUploadFile is not None and isinstance(
                    obj, StarletteUploadFile):
                return True
            # Duck typing: anything exposing the file attributes counts.
            return hasattr(obj, "filename") and hasattr(obj, "file")
        except Exception:
            return False

    def _upload_file_info(f: UploadFile):
        """Summarise an uploaded file without consuming its content."""
        size = None
        try:
            size = getattr(f, "size", None)
            stream = getattr(f, "file", None)
            if (size is None and stream is not None
                    and hasattr(stream, "tell") and hasattr(stream, "seek")):
                pos = stream.tell()
                try:
                    stream.seek(0, io.SEEK_END)
                    size = stream.tell()
                finally:
                    # Always restore the read position so the handler still
                    # sees the full upload, even if measuring failed midway.
                    stream.seek(pos, io.SEEK_SET)
        except Exception:
            size = None
        return {
            "type": "file",
            "filename": getattr(f, "filename", None),
            "size": size,
            "content_type": getattr(f, "content_type", None),
        }

    def _sanitize_val(name: str, val: Any):
        """Convert one argument value to a JSON-safe, log-friendly form."""
        try:
            if _is_upload_file(val):
                return _upload_file_info(val)
            if isinstance(val, (list, tuple)) and (
                    len(val) == 0 or _is_upload_file(val[0])):
                files = [
                    _upload_file_info(f) if _is_upload_file(f) else str(f)
                    for f in val
                ]
                return {"type": "files", "count": len(val), "files": files}
            if isinstance(val, Request):
                # Never log headers, URL or client data from the request.
                return {"type": "request"}
            if val is None:
                return None
            if hasattr(val, "model_dump"):
                return convert_numpy_types(val.model_dump())
            if hasattr(val, "dict") and callable(getattr(val, "dict")):
                return convert_numpy_types(val.dict())
            if isinstance(val, (bytes, bytearray)):
                # Log only the size, never the raw payload.
                return f"<bytes len={len(val)}>"
            if isinstance(val, (str, int, float, bool)):
                if isinstance(val, str) and len(val) > 200:
                    return val[:200] + "...(truncated)"
                return val
            # Fallback: round-trip through JSON, stringifying unknown types.
            return json.loads(json.dumps(val, default=str))
        except Exception as e:
            # Keep a visible marker instead of silently logging nothing.
            return f"<unserializable {type(val).__name__}: {e}>"

    def _log_call(args, kwargs) -> None:
        """Log the sanitised arguments of one call; never raises."""
        try:
            bound = sig.bind_partial(*args, **kwargs)
            bound.apply_defaults()
            payload = {name: _sanitize_val(name, val) for name, val in
                       bound.arguments.items()}
            logger.info(
                f"==> http {json.dumps(convert_numpy_types(payload), ensure_ascii=False)}")
        except Exception as e:
            logger.warning(f"Failed to log params for {func.__name__}: {e}")

    async def _async_wrapper(*args, **kwargs):
        _log_call(args, kwargs)
        return await func(*args, **kwargs)

    def _sync_wrapper(*args, **kwargs):
        _log_call(args, kwargs)
        return func(*args, **kwargs)

    wrapper = _async_wrapper if is_coro else _sync_wrapper
    return functools.wraps(func)(wrapper)
os.path.splitext(file.filename) + # 如果没有扩展名,使用空扩展名(保持用户上传文件的原始格式) + + # 生成唯一ID + unique_id = str(uuid.uuid4()).replace('-', '') + extra_meta_base = { + "source": "upload_file", + "file_type": fileType, + "original_filename": file.filename, + } + + # 特殊处理:证件照类型,先做老照片修复再保存 + if fileType == 'idphoto': + try: + # 解码图片 + np_arr = np.frombuffer(contents, np.uint8) + image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + if image is None: + raise HTTPException(status_code=400, + detail="无法解析图片文件") + + # 确保修复器可用 + _ensure_photo_restorer() + restored_with_model = ( + photo_restorer is not None and photo_restorer.is_available() + ) + if not restored_with_model: + logger.warning( + "GFPGAN 修复器不可用,跳过修复,按原样保存证件照") + # 按原样保存 + saved_filename = f"{unique_id}_save_id_photo{file_extension}" + saved_path = os.path.join(IMAGES_DIR, saved_filename) + with open(saved_path, "wb") as f: + f.write(contents) + # bos_uploaded = upload_file_to_bos(saved_path) + else: + t1 = time.perf_counter() + logger.info( + "Start restoring uploaded ID photo before saving...") + # 执行修复 + restored_image = await process_cpu_intensive_task( + photo_restorer.restore_image, image) + # 以 webp 高质量保存,命名与证件照区分 + saved_filename = f"{unique_id}_save_id_photo_restore.webp" + saved_path = os.path.join(IMAGES_DIR, saved_filename) + if not save_image_high_quality(restored_image, saved_path, + quality=SAVE_QUALITY): + raise HTTPException(status_code=500, + detail="保存修复后图像失败") + logger.info( + f"ID photo restored and saved: {saved_filename}, time: {time.perf_counter() - t1:.3f}s") + # bos_uploaded = upload_file_to_bos(saved_path) + + # 可选:向量化入库(与其他接口保持一致) + if CLIP_AVAILABLE: + asyncio.create_task( + handle_image_vector_async(saved_path, saved_filename)) + + await _record_output_file( + file_path=saved_path, + nickname=nickname, + category="id_photo", + bos_uploaded=True, + extra={ + **{k: v for k, v in extra_meta_base.items() if v}, + "restored_with_model": restored_with_model, + }, + ) + + return { + "success": True, + 
"message": "上传成功(已修复)" if photo_restorer is not None and photo_restorer.is_available() else "上传成功", + "filename": saved_filename, + } + except HTTPException: + raise + except Exception as e: + logger.error(f"证件照上传修复流程失败,改为直接保存: {e}") + # 失败兜底:直接保存原文件 + saved_filename = f"{unique_id}_save_id_photo{file_extension}" + saved_path = os.path.join(IMAGES_DIR, saved_filename) + try: + with open(saved_path, "wb") as f: + f.write(contents) + await _record_output_file( + file_path=saved_path, + nickname=nickname, + category="id_photo", + bos_uploaded=True, + extra={ + **{k: v for k, v in extra_meta_base.items() if v}, + "restored_with_model": False, + "fallback": True, + }, + ) + except Exception as se: + logger.error(f"保存文件失败: {se}") + raise HTTPException(status_code=500, detail="保存文件失败") + return { + "success": True, + "message": "上传成功(修复失败,已原样保存)", + "filename": saved_filename, + } + + # 默认:普通文件直接保存原始内容 + saved_filename = f"{unique_id}_save_file{file_extension}" + saved_path = os.path.join(IMAGES_DIR, saved_filename) + try: + with open(saved_path, "wb") as f: + f.write(contents) + bos_uploaded = upload_file_to_bos(saved_path) + logger.info(f"文件上传成功: {saved_filename}") + await _record_output_file( + file_path=saved_path, + nickname=nickname, + bos_uploaded=bos_uploaded, + extra={ + **{k: v for k, v in extra_meta_base.items() if v}, + "restored_with_model": False, + }, + ) + except Exception as e: + logger.error(f"保存文件失败: {str(e)}") + raise HTTPException(status_code=500, detail="保存文件失败") + + return {"success": True, "message": "上传成功", + "filename": saved_filename} + except HTTPException: + raise + except Exception as e: + logger.error(f"文件上传失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}") + + +@api_router.post(path="/check_image_security") +@log_api_params +async def analyze_face( + file: UploadFile = File(...), + nickname: str = Form(None, description="操作者昵称") +): + contents = await file.read() + np_arr = np.frombuffer(contents, np.uint8) + 
@api_router.post("/detect_faces", tags=["Face API"])
@log_api_params
async def detect_faces_endpoint(
    file: UploadFile = File(..., description="需要进行人脸检测的图片"),
):
    """
    Run face detection on a single uploaded image and report the elapsed time.

    Responds with HTTP 400 when the upload is not an image, is empty or cannot
    be decoded, and with HTTP 500 when the analyzer is unavailable or the
    detection call fails.
    """
    content_type = file.content_type if file else None
    if not (content_type and content_type.startswith("image/")):
        raise HTTPException(status_code=400, detail="请上传有效的图片文件")

    payload = await file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="图片内容为空")

    decoded = cv2.imdecode(np.frombuffer(payload, np.uint8), cv2.IMREAD_COLOR)
    if decoded is None:
        raise HTTPException(status_code=400, detail="无法解析图片文件,请确认格式正确")

    # Create the analyzer on demand; it stays None if initialisation fails.
    if analyzer is None:
        _ensure_analyzer()
    if analyzer is None:
        raise HTTPException(status_code=500, detail="人脸检测模型尚未就绪,请稍后再试")

    started_at = time.perf_counter()
    try:
        boxes = analyzer._detect_faces(decoded)
    except Exception as exc:
        logger.error(f"Face detection failed: {exc}")
        raise HTTPException(status_code=500, detail="调用人脸检测失败") from exc
    elapsed = time.perf_counter() - started_at

    response = {
        "success": True,
        "face_count": len(boxes),
        "boxes": boxes,
        "elapsed_ms": round(elapsed * 1000, 3),
        "elapsed_seconds": round(elapsed, 4),
        "hostname": SERVER_HOSTNAME,
    }
    return response
except json.JSONDecodeError: + raise HTTPException(status_code=400, detail="图片数据格式错误") + + else: + raise HTTPException(status_code=400, detail="请上传至少一张图片") + + if analyzer is None: + _ensure_analyzer() + if analyzer is None: + raise HTTPException( + status_code=500, + detail="人脸分析器未初始化,请检查模型文件是否缺失或损坏。", + ) + + # 验证图片数量 + if len(image_data_list) == 0: + raise HTTPException(status_code=400, detail="请上传至少一张图片") + + if len(image_data_list) > FACE_SCORE_MAX_IMAGES: # 使用配置项限制图片数量 + raise HTTPException(status_code=400, detail=f"最多只能上传{FACE_SCORE_MAX_IMAGES}张图片") + + all_results = [] + valid_image_count = 0 + + try: + overall_start = time.perf_counter() + + # 处理每张图片 + for idx, image_data in enumerate(image_data_list): + image_start = time.perf_counter() + try: + image_size_kb = len(image_data) / 1024 if image_data else 0 + decode_start = time.perf_counter() + np_arr = np.frombuffer(image_data, np.uint8) + image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + _log_stage_duration( + "图片解码", + decode_start, + f"image_index={idx+1}, size={image_size_kb:.2f}KB, success={image is not None}", + ) + + if image is None: + logger.warning(f"无法解析第{idx+1}张图片") + continue + + # 生成MD5哈希 + original_md5_hash = str(uuid.uuid4()).replace("-", "") + original_image_filename = f"{original_md5_hash}_original.webp" + + logger.info( + f"Processing image {idx+1}/{len(image_data_list)}, md5={original_md5_hash}, size={image_size_kb:.2f} KB" + ) + + analysis_start = time.perf_counter() + # 使用指定模型进行分析 + result = analyzer.analyze_faces(image, original_md5_hash, model) + _log_stage_duration( + "模型推理", + analysis_start, + f"image_index={idx+1}, model={model.value}, faces={result.get('face_count', 0)}", + ) + + # 如果该图片没有人脸,跳过 + if not result.get("success") or result.get("face_count", 0) == 0: + logger.info(f"第{idx+1}张图片未检测到人脸,跳过处理") + continue + + annotated_image_np = result.pop("annotated_image", None) + result["annotated_image_filename"] = None + + if result.get("success") and annotated_image_np is not None: 
+ original_image_path = os.path.join(OUTPUT_DIR, original_image_filename) + save_start = time.perf_counter() + save_success = save_image_high_quality( + annotated_image_np, original_image_path, quality=SAVE_QUALITY + ) + _log_stage_duration( + "标注图保存", + save_start, + f"image_index={idx+1}, path={original_image_path}, success={save_success}", + ) + + if save_success: + result["annotated_image_filename"] = original_image_filename + faces = result["faces"] + + try: + beauty_scores: List[float] = [] + age_models: List[Any] = [] + gender_models: List[Any] = [] + genders: List[Any] = [] + ages: List[Any] = [] + + for face_idx, face_info in enumerate(faces, start=1): + beauty_value = float(face_info.get("beauty_score") or 0.0) + beauty_scores.append(beauty_value) + age_models.append(face_info.get("age_model_used")) + gender_models.append(face_info.get("gender_model_used")) + genders.append(face_info.get("gender")) + ages.append(face_info.get("age")) + + cropped_filename = face_info.get("cropped_face_filename") + if cropped_filename: + cropped_path = os.path.join(IMAGES_DIR, cropped_filename) + if os.path.exists(cropped_path): + upload_start = time.perf_counter() + bos_face = upload_file_to_bos(cropped_path) + _log_stage_duration( + "BOS 上传(人脸)", + upload_start, + f"image_index={idx+1}, face_index={face_idx}, file={cropped_filename}, uploaded={bos_face}", + ) + record_face_start = time.perf_counter() + await _record_output_file( + file_path=cropped_path, + nickname=nickname, + category="face", + bos_uploaded=bos_face, + score=beauty_value, + extra={ + "source": "analyze", + "role": "face_crop", + "model": model.value, + "face_id": face_info.get("face_id"), + "gender": face_info.get("gender"), + "age": face_info.get("age"), + }, + ) + _log_stage_duration( + "记录人脸文件", + record_face_start, + f"image_index={idx+1}, face_index={face_idx}, file={cropped_filename}", + ) + + max_beauty_score = max(beauty_scores) if beauty_scores else 0.0 + + record_annotated_start = 
time.perf_counter() + await _record_output_file( + file_path=original_image_path, + nickname=nickname, + category="original", + score=max_beauty_score, + extra={ + "source": "analyze", + "role": "annotated", + "model": model.value, + }, + ) + _log_stage_duration( + "记录标注文件", + record_annotated_start, + f"image_index={idx+1}, file={original_image_filename}", + ) + + # 异步执行图片向量化并入库,不阻塞主流程 + if CLIP_AVAILABLE: + # 先保存原始图片到IMAGES_DIR供向量化使用 + original_input_path = os.path.join(IMAGES_DIR, original_image_filename) + save_input_start = time.perf_counter() + input_save_success = save_image_high_quality( + image, original_input_path, quality=SAVE_QUALITY + ) + _log_stage_duration( + "原图保存(CLIP)", + save_input_start, + f"image_index={idx+1}, success={input_save_success}", + ) + if input_save_success: + record_input_start = time.perf_counter() + await _record_output_file( + file_path=original_input_path, + nickname=nickname, + category="original", + score=max_beauty_score, + extra={ + "source": "analyze", + "role": "original_input", + "model": model.value, + }, + ) + _log_stage_duration( + "记录原图文件", + record_input_start, + f"image_index={idx+1}, file={original_image_filename}", + ) + vector_schedule_start = time.perf_counter() + asyncio.create_task( + handle_image_vector_async( + original_input_path, original_image_filename + ) + ) + _log_stage_duration( + "调度向量化任务", + vector_schedule_start, + f"image_index={idx+1}, file={original_image_filename}", + ) + + image_elapsed = time.perf_counter() - image_start + logger.info( + f"<-------- Image {idx+1} processing completed, elapsed: {image_elapsed:.3f}s, faces={len(faces)}, beauty={beauty_scores}, age={ages} via {age_models}, gender={genders} via {gender_models} --------" + ) + + # 添加到结果列表 + all_results.append(result) + valid_image_count += 1 + except Exception as e: + logger.error(f"Error processing image {idx+1}: {str(e)}") + continue + + except Exception as e: + logger.error(f"Error processing image {idx+1}: {str(e)}") + 
continue + + # 如果没有有效图片,返回错误 + if valid_image_count == 0: + logger.info("<-------- All images processing completed, no faces detected in any image --------") + return JSONResponse( + content={ + "success": False, + "message": "请尝试上传清晰、无遮挡的正面照片", + "face_count": 0, + "faces": [], + } + ) + + # 合并所有结果 + combined_result = { + "success": True, + "message": "分析完成", + "face_count": sum(result["face_count"] for result in all_results), + "faces": [ + { + "face": face, + "annotated_image_filename": result.get("annotated_image_filename"), + } + for result in all_results + for face in result["faces"] + ], + } + + # 保底:对女性年龄进行调整(如果年龄大于阈值且尚未调整) + for face_entry in combined_result["faces"]: + face = face_entry["face"] + gender = face.get("gender", "") + age_str = face.get("age", "") + + if str(gender) != "Female" or face.get("age_adjusted"): + continue + + try: + # 处理年龄范围格式,如 "25-32" + if "-" in str(age_str): + age = int(str(age_str).split("-")[0].strip("() ")) + else: + age = int(str(age_str).strip()) + + if age >= FEMALE_AGE_ADJUSTMENT_THRESHOLD and FEMALE_AGE_ADJUSTMENT > 0: + adjusted_age = max(0, age - FEMALE_AGE_ADJUSTMENT) + face["age"] = str(adjusted_age) + face["age_adjusted"] = True + face["age_adjustment_value"] = FEMALE_AGE_ADJUSTMENT + logger.info(f"Adjusted age for female (fallback): {age} -> {adjusted_age}") + except (ValueError, TypeError): + pass + + # 转换所有 numpy 类型为原生 Python 类型 + cleaned_result = convert_numpy_types(combined_result) + total_elapsed = time.perf_counter() - overall_start + logger.info( + f"<-------- All images processing completed, total time: {total_elapsed:.3f}s, valid images: {valid_image_count} --------" + ) + return JSONResponse(content=cleaned_result) + + except Exception as e: + import traceback + + traceback.print_exc() + logger.error(f"Internal error occurred during analysis: {str(e)}") + raise HTTPException(status_code=500, detail=f"分析过程中出现内部错误: {str(e)}") + + +@api_router.post("/image_search", response_model=ImageFileList, 
tags=["图像搜索"]) +@log_api_params +async def search_by_image( + file: UploadFile = File(None), + searchType: str = Query("face"), + top_k: int = Query(5), + score_threshold: float = Query(0.28) +): + """使用图片进行相似图像搜索""" + # 检查CLIP是否可用 + if not CLIP_AVAILABLE: + raise HTTPException(status_code=500, detail="CLIP功能未启用或初始化失败") + + try: + # 获取图片数据 + if not file: + raise HTTPException(status_code=400, detail="请提供要搜索的图片") + + # 读取图片数据 + image_data = await file.read() + + # 保存临时图片文件 + temp_image_path = f"/tmp/search_image_{uuid.uuid4().hex}.webp" + try: + # 解码图片 + np_arr = np.frombuffer(image_data, np.uint8) + image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + if image is None: + raise HTTPException(status_code=400, detail="无法解析图片文件") + + # 保存为临时文件 + cv2.imwrite(temp_image_path, image, [cv2.IMWRITE_WEBP_QUALITY, 100]) + + # 使用CLIP编码图片 + image_vector = clip_encode_image(temp_image_path) + + # 执行搜索 + search_results = search_text_vector(image_vector, top_k) + + # 根据score_threshold过滤结果 + filtered_results = [ + item for item in search_results + if item[1] >= score_threshold + ] + + # 从数据库获取元数据 + records_map = {} + try: + records_map = await fetch_records_by_paths( + file_path for file_path, _ in filtered_results + ) + except Exception as exc: + logger.warning(f"Fetch image records by path failed: {exc}") + + category = _normalize_search_category(searchType) + # 构建返回结果 + all_files = [] + for file_path, score in filtered_results: + record = records_map.get(file_path) + record_category = ( + record.get( + "category") if record else infer_category_from_filename( + file_path) + ) + if category not in ( + None, "all") and record_category != category: + continue + + size_bytes = 0 + is_cropped = False + nickname_value = record.get("nickname") if record else None + last_modified_dt = None + + if record: + size_bytes = int(record.get("size_bytes") or 0) + is_cropped = bool(record.get("is_cropped_face")) + last_modified_dt = record.get("last_modified") + if isinstance(last_modified_dt, str): 
@api_router.get(
    "/daily_category_stats",
    response_model=CategoryStatsResponse,
    tags=["统计"]
)
@log_api_params
async def get_daily_category_stats():
    """Return today's per-category record counts.

    Categories listed in ``CATEGORY_DISPLAY_ORDER`` come first, in that order
    (with a count of 0 when absent from the query result); any other category
    returned by the query follows in alphabetical order.

    :return: ``CategoryStatsResponse`` with the ordered stats and their total
    :raises HTTPException: 500 when the count query fails
    """
    try:
        rows = await fetch_today_category_counts()
    except Exception as exc:
        logger.error("Fetch today category counts failed: %s", exc)
        raise HTTPException(status_code=500,
                            detail="查询今日分类统计失败") from exc

    # Rows without a category are bucketed under "unknown"; a later row with
    # the same category overwrites an earlier one.
    tallies: Dict[str, int] = {}
    for row in rows:
        name = str(row.get("category") or "unknown")
        tallies[name] = int(row.get("count") or 0)
    grand_total = sum(tallies.values())

    def _make_item(name, amount):
        return CategoryStatItem(
            category=name,
            display_name=CATEGORY_DISPLAY_NAMES.get(name, name),
            count=amount,
        )

    # Popping while walking the display order leaves only the categories
    # that have no configured position.
    leftovers = dict(tallies)
    ordered = [
        _make_item(name, leftovers.pop(name, 0))
        for name in CATEGORY_DISPLAY_ORDER
    ]
    extras = [_make_item(name, leftovers[name]) for name in sorted(leftovers)]

    return CategoryStatsResponse(stats=ordered + extras, total=grand_total)
record is None or ( + record.get("nickname") or "").strip() != nickname_filter + ): + continue + + size_bytes = 0 + is_cropped = False + nickname_value = record.get("nickname") if record else None + last_modified_dt = None + + if record: + size_bytes = int(record.get("size_bytes") or 0) + is_cropped = bool(record.get("is_cropped_face")) + last_modified_dt = record.get("last_modified") + if isinstance(last_modified_dt, str): + try: + last_modified_dt = datetime.fromisoformat( + last_modified_dt) + except ValueError: + last_modified_dt = None + + if last_modified_dt is None or size_bytes == 0: + full_path = os.path.join(IMAGES_DIR, file_path) + if not os.path.isfile(full_path): + continue + stat = os.stat(full_path) + size_bytes = stat.st_size + last_modified_dt = datetime.fromtimestamp(stat.st_mtime) + is_cropped = "_face_" in file_path and file_path.count("_") >= 2 + + last_modified_str = ( + last_modified_dt.strftime("%Y-%m-%d %H:%M:%S") + if isinstance(last_modified_dt, datetime) + else "" + ) + file_info = { + "file_path": file_path, + "score": round(score, 4), + "is_cropped_face": is_cropped, + "size_bytes": size_bytes, + "size_str": human_readable_size(size_bytes), + "last_modified": last_modified_str, + "nickname": nickname_value, + } + all_files.append(file_info) + + # 应用分页 + total_count = len(all_files) + start_index = (page - 1) * page_size + end_index = start_index + page_size + paged_results = all_files[start_index:end_index] + + total_pages = (total_count + page_size - 1) // page_size # 向上取整 + return PagedImageFileList( + results=paged_results, + count=total_count, + page=page, + page_size=page_size, + total_pages=total_pages + ) + + except Exception as e: + logger.error(f"Vector search failed: {str(e)}") + # 如果向量搜索失败,降级到普通文件列表 + + # 普通文件列表模式(无关键词或CLIP不可用) + logger.info("Returning regular file list") + try: + total_count = await count_image_records( + category=category, + nickname=nickname_filter, + ) + if total_count > 0: + offset = (page - 1) * 
page_size + rows = await fetch_paged_image_records( + category=category, + nickname=nickname_filter, + offset=offset, + limit=page_size, + ) + paged_results = [] + for row in rows: + last_modified = row.get("last_modified") + if isinstance(last_modified, str): + try: + last_modified_dt = datetime.fromisoformat( + last_modified) + except ValueError: + last_modified_dt = None + else: + last_modified_dt = last_modified + size_bytes = int(row.get("size_bytes") or 0) + paged_results.append({ + "file_path": row.get("file_path"), + "score": float(row.get("score") or 0.0), + "is_cropped_face": bool(row.get("is_cropped_face")), + "size_bytes": size_bytes, + "size_str": human_readable_size(size_bytes), + "last_modified": last_modified_dt.strftime( + "%Y-%m-%d %H:%M:%S") if last_modified_dt else "", + "nickname": row.get("nickname"), + }) + total_pages = (total_count + page_size - 1) // page_size + return PagedImageFileList( + results=paged_results, + count=total_count, + page=page, + page_size=page_size, + total_pages=total_pages, + ) + except Exception as exc: + logger.error( + f"Query image records from MySQL failed: {exc}, fallback to filesystem scan") + + if nickname_filter: + # 没有数据库结果且需要按昵称过滤,直接返回空列表以避免返回其他用户数据 + return PagedImageFileList( + results=[], + count=0, + page=page, + page_size=page_size, + total_pages=0, + ) + + # 文件系统兜底逻辑 + all_files = [] + for f in os.listdir(IMAGES_DIR): + if not f.lower().endswith((".jpg", ".jpeg", ".png", ".webp")): + continue + + file_category = infer_category_from_filename(f) + if category not in (None, "all") and file_category != category: + continue + + full_path = os.path.join(IMAGES_DIR, f) + if os.path.isfile(full_path): + stat = os.stat(full_path) + is_cropped = "_face_" in f and f.count("_") >= 2 + file_info = { + "file_path": f, + "score": 0.0, + "is_cropped_face": is_cropped, + "size_bytes": stat.st_size, + "size_str": human_readable_size(stat.st_size), + "last_modified": datetime.fromtimestamp( + stat.st_mtime).strftime( + 
def _resolve_image_file(base_dir: str, filename: str):
    """Resolve *filename* inside *base_dir* and pick its media type.

    :param base_dir: directory the file must live in
    :param filename: client-supplied file name (untrusted)
    :return: ``(absolute-or-joined path, media type)``
    :raises HTTPException: 404 when *filename* is not a plain file name or
        does not refer to an existing regular file
    """
    # The name comes straight from the URL: refuse anything that is not a
    # bare file name so the join below cannot leave base_dir.
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="文件不存在")

    file_path = os.path.join(base_dir, filename)
    # isfile (not exists): a directory would make FileResponse fail later.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="文件不存在")

    lowered = filename.lower()
    if lowered.endswith('.png'):
        media_type = "image/png"
    elif lowered.endswith('.webp'):
        media_type = "image/webp"
    else:
        media_type = "image/jpeg"
    return file_path, media_type


@api_router.get("/preview/{filename}", tags=["文件预览"])
@log_api_params
async def download_result(filename: str):
    """Serve a file from ``IMAGES_DIR`` (the ``/preview`` route).

    :param filename: bare file name inside ``IMAGES_DIR``
    :return: ``FileResponse`` with a media type derived from the extension
    :raises HTTPException: 404 when the file does not exist
    """
    file_path, media_type = _resolve_image_file(IMAGES_DIR, filename)
    return FileResponse(path=file_path, filename=filename, media_type=media_type)


@api_router.get("/download/{filename}", tags=["文件下载"])
@log_api_params
async def preview_result(filename: str):
    """Serve a file from ``OUTPUT_DIR`` (the ``/download`` route).

    :param filename: bare file name inside ``OUTPUT_DIR``
    :return: ``FileResponse`` with a media type derived from the extension
    :raises HTTPException: 404 when the file does not exist
    """
    file_path, media_type = _resolve_image_file(OUTPUT_DIR, filename)
    return FileResponse(
        path=file_path,
        filename=filename,
        media_type=media_type,
    )
# Strong references to in-flight background sync tasks. The event loop only
# keeps weak references to tasks, so a task whose handle is dropped may be
# garbage-collected before it finishes.
_BOS_BACKGROUND_SYNC_TASKS: set = set()


@api_router.post("/sync_resources", tags=["系统维护"])
@log_api_params
async def sync_bos_resources(
    force_download: bool = Query(False, description="是否强制重新下载已存在的文件"),
    include_background: bool = Query(
        False, description="是否同步配置中标记为后台的资源"
    ),
    bos_prefix: str | None = Query(
        None, description="自定义 BOS 前缀,例如 20220620/models"
    ),
    destination_dir: str | None = Query(
        None, description="自定义本地目录,例如 /opt/models/custom"
    ),
    background: bool = Query(
        False, description="与自定义前缀搭配使用时,是否在后台异步下载"
    ),
):
    """
    Manually trigger a BOS resource sync.

    - When both ``bos_prefix`` and ``destination_dir`` are given, sync that
      single prefix into that directory (optionally as a background task).
    - Otherwise run the batch sync via ``ensure_bos_resources``.

    :param force_download: re-download files that already exist locally
    :param include_background: passed through to ``ensure_bos_resources``
    :param bos_prefix: custom BOS prefix; requires ``destination_dir``
    :param destination_dir: custom local directory; requires ``bos_prefix``
    :param background: with a custom prefix, start the download and return
        immediately instead of awaiting it
    :return: dict describing the outcome and elapsed time
    :raises HTTPException: 400 when only one of ``bos_prefix`` /
        ``destination_dir`` is provided (whitespace-only counts as missing)
    """
    start_time = time.perf_counter()

    # Normalise first so a whitespace-only value is treated as "not given"
    # instead of silently syncing an empty prefix or the current directory.
    prefix = bos_prefix.strip() if bos_prefix else ""
    destination = destination_dir.strip() if destination_dir else ""

    if bool(prefix) != bool(destination):
        raise HTTPException(status_code=400, detail="bos_prefix 和 destination_dir 需要同时提供")

    if prefix and destination:
        dest_path = os.path.abspath(os.path.expanduser(destination))

        async def _sync_single():
            # The download is blocking; run it off the event loop.
            return await asyncio.to_thread(
                download_bos_directory,
                prefix,
                dest_path,
                force_download=force_download,
            )

        if background:
            async def _background_task():
                try:
                    success = await _sync_single()
                except Exception:
                    # Nobody awaits this task, so log here or the error is lost.
                    logger.exception(
                        "Background BOS download raised: prefix=%s -> %s",
                        bos_prefix, dest_path,
                    )
                    return
                if success:
                    logger.info(
                        "后台 BOS 下载完成: prefix=%s -> %s", bos_prefix, dest_path
                    )
                else:
                    logger.warning(
                        "后台 BOS 下载失败: prefix=%s -> %s", bos_prefix, dest_path
                    )

            task = asyncio.create_task(_background_task())
            _BOS_BACKGROUND_SYNC_TASKS.add(task)
            task.add_done_callback(_BOS_BACKGROUND_SYNC_TASKS.discard)

            elapsed = time.perf_counter() - start_time
            return {
                "success": True,
                "force_download": force_download,
                "include_background": False,
                "bos_prefix": bos_prefix,
                "destination_dir": dest_path,
                "elapsed_seconds": round(elapsed, 3),
                "message": "后台下载任务已启动",
            }

        success = await _sync_single()
        elapsed = time.perf_counter() - start_time
        return {
            "success": bool(success),
            "force_download": force_download,
            "include_background": False,
            "bos_prefix": bos_prefix,
            "destination_dir": dest_path,
            "elapsed_seconds": round(elapsed, 3),
            "message": "资源同步完成" if success else "资源同步失败,请查看日志",
        }

    # No custom prefix: batch sync driven by configuration.
    success = await asyncio.to_thread(
        ensure_bos_resources,
        force_download,
        include_background,
    )
    elapsed = time.perf_counter() - start_time
    # NOTE(review): the message depends only on include_background, not on
    # `success` — kept as in the original; confirm this is intended.
    message = (
        "后台下载任务已启动,将在后台继续运行"
        if not include_background
        else "资源同步完成"
    )
    return {
        "success": bool(success),
        "force_download": force_download,
        "include_background": include_background,
        "elapsed_seconds": round(elapsed, 3),
        "message": message,
        "bos_prefix": None,
        "destination_dir": None,
    }
description="操作者昵称"), +): + """ + 老照片修复接口 + :param file: 上传的老照片文件 + :param md5: 前端传递的文件md5,如果未传递则使用original_md5_hash + :param colorize: 是否对黑白照片进行上色,默认为False + :return: 修复结果,包含修复后图片的文件名 + """ + _ensure_photo_restorer() + if photo_restorer is None or not photo_restorer.is_available(): + raise HTTPException( + status_code=500, + detail="照片修复器未初始化,请检查服务状态。" + ) + + # 验证文件类型 + if not file.content_type.startswith("image/"): + raise HTTPException(status_code=400, detail="请上传图片文件") + + try: + contents = await file.read() + original_md5_hash = str(uuid.uuid4()).replace('-', '') + # 如果前端传递了md5参数则使用,否则使用original_md5_hash + actual_md5 = md5 if md5 else original_md5_hash + restored_filename = f"{actual_md5}_restore.webp" + + logger.info(f"Starting to restore old photo: {file.filename}, size={file.size}, colorize={colorize}, md5={original_md5_hash}") + t1 = time.perf_counter() + + # 解码图像 + np_arr = np.frombuffer(contents, np.uint8) + image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + if image is None: + raise HTTPException( + status_code=400, detail="无法解析图片文件,请确保文件格式正确。" + ) + + # 获取原图信息 + original_height, original_width = image.shape[:2] + original_size = file.size + + # 调整后的处理流程:先修复再上色 + # 步骤1: 使用GFPGAN修复图像 + logger.info("Step 1: Starting to restore the original image...") + processing_steps = [] + + try: + restored_image = await process_cpu_intensive_task(photo_restorer.restore_image, image) + final_image = restored_image + processing_steps.append(f"使用{restorer_type}修复器修复") + logger.info("Restoration processing completed") + except Exception as e: + logger.error(f"Restoration processing failed: {e}, continuing with original image") + final_image = image + + # 步骤2: 如果用户选择上色,对修复后的图像进行上色 + if colorize and ddcolor_colorizer is not None and ddcolor_colorizer.is_available(): + logger.info("Step 2: Starting to colorize the restored image...") + try: + # 检查修复后的图像是否为灰度 + restored_is_grayscale = ddcolor_colorizer.is_grayscale(final_image) + logger.info(f"Is restored image grayscale: 
{restored_is_grayscale}") + + if restored_is_grayscale: + # 对灰度图进行上色 + logger.info("Colorizing the restored grayscale image...") + colorized_image = await process_cpu_intensive_task(ddcolor_colorizer.colorize_image_direct, final_image) + final_image = colorized_image + processing_steps.append("使用DDColor对修复后图像上色") + logger.info("Colorization processing completed") + else: + # 对于彩色图像,可以选择强制上色或跳过 + logger.info("Restored image is already colored, performing forced colorization...") + colorized_image = await process_cpu_intensive_task(ddcolor_colorizer.colorize_image_direct, final_image) + final_image = colorized_image + processing_steps.append("强制使用DDColor上色") + logger.info("Forced colorization processing completed") + + except Exception as e: + logger.error(f"Colorization processing failed: {e}, using restored image") + elif colorize: + if DDCOLOR_AVAILABLE: + logger.warning("Colorization feature unavailable: DDColor not properly initialized") + else: + logger.info("Colorization feature disabled or DDColor unavailable, skipping colorization step") + + # 获取处理后图像信息 + processed_height, processed_width = final_image.shape[:2] + + # 保存最终处理后的图像到IMAGES_DIR(与人脸评分使用相同路径) + restored_path = os.path.join(IMAGES_DIR, restored_filename) + save_success = save_image_high_quality( + final_image, restored_path, quality=SAVE_QUALITY + ) + + if save_success: + total_time = time.perf_counter() - t1 + + # 获取处理后文件大小 + processed_size = os.path.getsize(restored_path) + + logger.info(f"Old photo processing completed: {restored_filename}, time: {total_time:.3f}s") + + # 异步执行图片向量化并入库,不阻塞主流程 + if CLIP_AVAILABLE: + asyncio.create_task(handle_image_vector_async(restored_path, restored_filename)) + + # bos_uploaded = upload_file_to_bos(restored_path) + await _record_output_file( + file_path=restored_path, + nickname=nickname, + category="restore", + bos_uploaded=True, + extra={ + "source": "restore", + "colorize": colorize, + "processing_steps": processing_steps, + "md5": actual_md5, + }, + ) + + 
@api_router.post("/upcolor")
@log_api_params
async def colorize_photo(
    file: UploadFile = File(...),
    md5: str = Query(None, description="前端传递的文件md5,用于提前保存记录"),
    nickname: str = Form(None, description="操作者昵称"),
):
    """
    Colorize an uploaded photo with DDColor and store the result as WebP.

    :param file: uploaded image file
    :param md5: identifier supplied by the client and used as the output file
        name stem; a random hex id is generated when omitted
    :param nickname: operator nickname recorded with the output file
    :return: dict with the stored file name, timing, sizes and dimensions
    :raises HTTPException: 400 for a non-image / empty / undecodable upload or
        an ``md5`` containing characters other than letters, digits, ``-``
        and ``_``; 500 when the colorizer is unavailable, colorization fails
        or the result cannot be saved
    """
    _ensure_ddcolor()
    if ddcolor_colorizer is None or not ddcolor_colorizer.is_available():
        raise HTTPException(
            status_code=500,
            detail="照片上色器未初始化,请检查服务状态。"
        )

    # content_type can be missing on a multipart part; treat that as "not an image".
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="请上传图片文件")

    # md5 ends up in a file path below, so it must not carry separators or dots.
    if md5 and not all(ch.isalnum() or ch in "-_" for ch in md5):
        raise HTTPException(status_code=400, detail="md5 参数格式不正确")

    try:
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="文件内容为空")

        # Use the client-provided id when present, otherwise a random one.
        actual_md5 = md5 if md5 else uuid.uuid4().hex
        colored_filename = f"{actual_md5}_upcolor.webp"
        # Measure the bytes actually received; file.size may be unset.
        original_size = len(contents)

        logger.info(f"Starting to colorize photo: {file.filename}, size={original_size}, md5={actual_md5}")
        t1 = time.perf_counter()

        # Decode the upload into a BGR image.
        np_arr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(
                status_code=400, detail="无法解析图片文件,请确保文件格式正确。"
            )

        original_height, original_width = image.shape[:2]

        logger.info("Starting to colorize the image...")
        try:
            colorized_image = await process_cpu_intensive_task(ddcolor_colorizer.colorize_image_direct, image)
            logger.info("Colorization processing completed")
        except Exception as e:
            logger.error(f"Colorization processing failed: {e}")
            raise HTTPException(status_code=500, detail=f"上色处理失败: {str(e)}")

        processed_height, processed_width = colorized_image.shape[:2]

        # Store the result in IMAGES_DIR.
        colored_path = os.path.join(IMAGES_DIR, colored_filename)
        save_success = save_image_high_quality(
            colorized_image, colored_path, quality=SAVE_QUALITY
        )
        if not save_success:
            raise HTTPException(status_code=500, detail="保存上色后图像失败")

        total_time = time.perf_counter() - t1
        processed_size = os.path.getsize(colored_path)

        logger.info(f"Photo colorization completed: {colored_filename}, time: {total_time:.3f}s")

        # Vectorize and index the image in the background; do not block the response.
        if CLIP_AVAILABLE:
            asyncio.create_task(handle_image_vector_async(colored_path, colored_filename))

        # NOTE(review): no BOS upload happens in this handler, yet the record
        # is written with bos_uploaded=True (unchanged from before) — confirm.
        await _record_output_file(
            file_path=colored_path,
            nickname=nickname,
            category="upcolor",
            bos_uploaded=True,
            extra={
                "source": "upcolor",
                "md5": actual_md5,
            },
        )

        return {
            "success": True,
            "message": "成功",
            "original_filename": file.filename,
            "colored_filename": colored_filename,
            "processing_time": f"{total_time:.3f}s",
            "original_size": original_size,
            "processed_size": processed_size,
            "size_increase_ratio": round(processed_size / original_size, 2),
            "original_dimensions": f"{original_width} × {original_height}",
            "processed_dimensions": f"{processed_width} × {processed_height}",
        }

    except HTTPException:
        # Keep the deliberate status codes/details raised above.
        raise
    except Exception as e:
        logger.error(f"Error occurred during photo colorization: {str(e)}")
        raise HTTPException(status_code=500, detail=f"上色过程中出现错误: {str(e)}")
@api_router.post("/anime_style/preload", tags=["动漫风格化"])
@log_api_params
async def preload_anime_models(
    style_types: list = Query(None, description="要预加载的风格类型列表,如果为空则预加载所有模型")
):
    """
    Preload anime stylization models.

    :param style_types: style types to preload, forwarded as-is to
        ``anime_stylizer.preload_models``
    :return: dict with the elapsed time, the stylizer's preload status and
        the requested styles
    :raises HTTPException: 500 when the stylizer is unavailable or preloading fails
    """
    _ensure_anime_stylizer()
    stylizer_ready = anime_stylizer is not None and anime_stylizer.is_available()
    if not stylizer_ready:
        raise HTTPException(
            status_code=500,
            detail="动漫风格化处理器未初始化,请检查服务状态。"
        )

    try:
        logger.info(f"API request to preload anime style models: {style_types}")

        # Time only the preload call itself.
        began_at = time.perf_counter()
        anime_stylizer.preload_models(style_types)
        preload_time = time.perf_counter() - began_at

        payload = {
            "success": True,
            "message": f"模型预加载完成,耗时: {preload_time:.3f}s",
            "preload_time": f"{preload_time:.3f}s",
            # Status is read after preloading so it reflects the new state.
            "preload_status": anime_stylizer.get_preload_status(),
            "requested_styles": style_types,
        }
        return payload
    except Exception as e:
        logger.error(f"Anime style model preloading failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"预加载失败: {str(e)}")
style_type: str = Form("handdrawn", + description="动漫风格类型: handdrawn=手绘风格, disney=迪士尼风格, illustration=插画风格, artstyle=艺术风格, anime=二次元风格, sketch=素描风格"), + nickname: str = Form(None, description="操作者昵称"), +): + """ + 图片动漫风格化接口 + :param file: 上传的照片文件 + :param style_type: 动漫风格类型,默认为"disney"(迪士尼风格) + :return: 动漫风格化结果,包含风格化后图片的文件名 + """ + _ensure_anime_stylizer() + if anime_stylizer is None or not anime_stylizer.is_available(): + raise HTTPException( + status_code=500, + detail="动漫风格化处理器未初始化,请检查服务状态。" + ) + + # 验证文件类型 + if not file.content_type.startswith("image/"): + raise HTTPException(status_code=400, detail="请上传图片文件") + + # 验证风格类型 + valid_styles = ["handdrawn", "disney", "illustration", "artstyle", "anime", "sketch"] + if style_type not in valid_styles: + raise HTTPException(status_code=400, detail=f"不支持的风格类型,请选择: {valid_styles}") + + try: + contents = await file.read() + if not contents: + raise HTTPException(status_code=400, detail="文件内容为空") + + original_md5_hash = hashlib.md5(contents).hexdigest() + + np_arr = np.frombuffer(contents, np.uint8) + image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + if image is None: + raise HTTPException( + status_code=400, detail="无法解析图片文件,请确保文件格式正确。" + ) + + def _save_webp_and_upload(image_array: np.ndarray, output_path: str, + log_prefix: str): + success, encoded_img = cv2.imencode( + ".webp", image_array, + [cv2.IMWRITE_WEBP_QUALITY, SAVE_QUALITY] + ) + if not success: + logger.error(f"{log_prefix}编码失败: {output_path}") + return False, False + try: + with open(output_path, "wb") as output_file: + output_file.write(encoded_img) + except Exception as save_exc: + logger.error( + f"{log_prefix}保存失败: {output_path}, error: {save_exc}") + return False, False + logger.info( + f"{log_prefix}保存成功: {output_path}, size: {len(encoded_img) / 1024:.2f} KB" + ) + bos_uploaded_flag = upload_file_to_bos(output_path) + return True, bos_uploaded_flag + + original_filename = f"{original_md5_hash}_anime_style.webp" + original_path = 
os.path.join(IMAGES_DIR, original_filename) + if not os.path.exists(original_path): + original_saved, original_bos_uploaded = _save_webp_and_upload( + image, original_path, "动漫风格原图" + ) + if not original_saved: + raise HTTPException(status_code=500, detail="保存原图失败") + else: + logger.info( + f"Original image already exists for anime style: {original_filename}") + original_bos_uploaded = False + + styled_uuid = uuid.uuid4().hex + styled_filename = f"{styled_uuid}_anime_style_{style_type}.webp" + + # 获取风格描述 + style_descriptions = anime_stylizer.get_available_styles() + style_description = style_descriptions.get(style_type, "未知风格") + + logger.info(f"Starting anime stylization processing: {file.filename}, size={file.size}, style={style_type}({style_description}), md5={original_md5_hash}") + t1 = time.perf_counter() + + await _record_output_file( + file_path=original_path, + nickname=nickname, + category="anime_style", + bos_uploaded=original_bos_uploaded, + extra={ + "source": "anime_style", + "style_type": style_type, + "style_description": style_description, + "md5": original_md5_hash, + "role": "original", + "original_filename": original_filename, + }, + ) + + # 使用AnimeStylizer对图像进行动漫风格化 + logger.info(f"Starting to stylize image with anime style, style: {style_description}...") + try: + stylized_image = await process_cpu_intensive_task(anime_stylizer.stylize_image, image, style_type) + logger.info("Anime stylization processing completed") + except Exception as e: + logger.error(f"Anime stylization processing failed: {e}") + raise HTTPException(status_code=500, detail=f"动漫风格化处理失败: {str(e)}") + + # 保存风格化后的图像到IMAGES_DIR + styled_path = os.path.join(IMAGES_DIR, styled_filename) + save_success, bos_uploaded = _save_webp_and_upload( + stylized_image, styled_path, "动漫风格结果图" + ) + + if save_success: + total_time = time.perf_counter() - t1 + logger.info(f"Anime stylization completed: {styled_filename}, time: {total_time:.3f}s") + + # 异步执行图片向量化并入库,不阻塞主流程 + if CLIP_AVAILABLE: + 
@api_router.post("/grayscale")
@log_api_params
async def grayscale_photo(
    file: UploadFile = File(...),
    nickname: str = Form(None, description="操作者昵称"),
):
    """
    Convert an uploaded photo to grayscale and store the result as WebP.

    :param file: uploaded image file
    :param nickname: operator nickname recorded with the output file
    :return: dict with the stored file name, timing, sizes and dimensions
    :raises HTTPException: 400 for a non-image / empty / undecodable upload;
        500 when conversion or saving fails
    """
    # content_type can be missing on a multipart part; treat that as "not an image".
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="请上传图片文件")

    try:
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="文件内容为空")

        # Despite the name this is a random id, not a content hash.
        original_md5_hash = uuid.uuid4().hex
        grayscale_filename = f"{original_md5_hash}_grayscale.webp"
        # Measure the bytes actually received; file.size may be unset.
        original_size = len(contents)

        logger.info(f"Starting image grayscale conversion: {file.filename}, size={original_size}, md5={original_md5_hash}")
        t1 = time.perf_counter()

        # Decode the upload into a BGR image.
        np_arr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(
                status_code=400, detail="无法解析图片文件,请确保文件格式正确。"
            )

        original_height, original_width = image.shape[:2]

        logger.info("Starting to convert image to grayscale...")
        try:
            gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            # Back to 3 channels so the saver receives a colour-format image.
            grayscale_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2BGR)
            logger.info("Grayscale processing completed")
        except Exception as e:
            logger.error(f"Grayscale processing failed: {e}")
            raise HTTPException(status_code=500, detail=f"黑白化处理失败: {str(e)}")

        # Store the result in IMAGES_DIR.
        grayscale_path = os.path.join(IMAGES_DIR, grayscale_filename)
        save_success = save_image_high_quality(
            grayscale_image, grayscale_path, quality=SAVE_QUALITY
        )
        if not save_success:
            raise HTTPException(status_code=500, detail="保存黑白化后图像失败")

        total_time = time.perf_counter() - t1
        processed_size = os.path.getsize(grayscale_path)

        logger.info(f"Image grayscale conversion completed: {grayscale_filename}, time: {total_time:.3f}s")

        # Vectorize and index the image in the background; do not block the response.
        if CLIP_AVAILABLE:
            asyncio.create_task(handle_image_vector_async(grayscale_path, grayscale_filename))

        # NOTE(review): no BOS upload happens in this handler, yet the record
        # is written with bos_uploaded=True (unchanged from before) — confirm.
        await _record_output_file(
            file_path=grayscale_path,
            nickname=nickname,
            category="grayscale",
            bos_uploaded=True,
            extra={
                "source": "grayscale",
                "md5": original_md5_hash,
            },
        )

        return {
            "success": True,
            "message": "成功",
            "original_filename": file.filename,
            "grayscale_filename": grayscale_filename,
            "processing_time": f"{total_time:.3f}s",
            "original_size": original_size,
            "processed_size": processed_size,
            "size_increase_ratio": round(processed_size / original_size, 2),
            "original_dimensions": f"{original_width} × {original_height}",
            # The conversion does not resize, so the dimensions are unchanged.
            "processed_dimensions": f"{original_width} × {original_height}",
        }

    except HTTPException:
        # Keep the deliberate status codes/details raised above.
        raise
    except Exception as e:
        logger.error(f"Error occurred during image grayscale conversion: {str(e)}")
        raise HTTPException(status_code=500, detail=f"黑白化过程中出现错误: {str(e)}")
Query(None, description="前端传递的文件md5,用于提前保存记录"), + scale: int = Query(UPSCALE_SIZE, description="放大倍数,支持2或4倍"), + model_name: str = Query(REALESRGAN_MODEL, + description="模型名称,推荐使用RealESRGAN_x2plus以提高CPU性能"), + nickname: str = Form(None, description="操作者昵称"), +): + """ + 照片超清放大接口 + :param file: 上传的照片文件 + :param md5: 前端传递的文件md5,如果未传递则使用original_md5_hash + :param scale: 放大倍数,默认4倍 + :param model_name: 使用的模型名称 + :return: 超清结果,包含超清后图片的文件名和相关信息 + """ + _ensure_realesrgan() + if realesrgan_upscaler is None or not realesrgan_upscaler.is_available(): + raise HTTPException( + status_code=500, + detail="照片超清处理器未初始化,请检查服务状态。" + ) + + # 验证文件类型 + if not file.content_type.startswith("image/"): + raise HTTPException(status_code=400, detail="请上传图片文件") + + # 验证放大倍数 + if scale not in [2, 4]: + raise HTTPException(status_code=400, detail="放大倍数只支持2倍或4倍") + + try: + contents = await file.read() + original_md5_hash = str(uuid.uuid4()).replace('-', '') + # 如果前端传递了md5参数则使用,否则使用original_md5_hash + actual_md5 = md5 if md5 else original_md5_hash + upscaled_filename = f"{actual_md5}_upscale.webp" + + logger.info(f"Starting photo super resolution processing: {file.filename}, size={file.size}, scale={scale}x, model={model_name}, md5={original_md5_hash}") + t1 = time.perf_counter() + + # 解码图像 + np_arr = np.frombuffer(contents, np.uint8) + image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + if image is None: + raise HTTPException( + status_code=400, detail="无法解析图片文件,请确保文件格式正确。" + ) + + # 获取原图信息 + original_height, original_width = image.shape[:2] + original_size = file.size + + # 使用Real-ESRGAN对图像进行超清处理 + logger.info(f"Starting Real-ESRGAN super resolution processing, original image size: {original_width}x{original_height}") + try: + upscaled_image = await process_cpu_intensive_task(realesrgan_upscaler.upscale_image, image, scale=scale) + logger.info("Super resolution processing completed") + except Exception as e: + logger.error(f"Super resolution processing failed: {e}") + raise 
HTTPException(status_code=500, detail=f"超清处理失败: {str(e)}") + + # 获取处理后图像信息 + upscaled_height, upscaled_width = upscaled_image.shape[:2] + + # 保存超清后的图像到IMAGES_DIR(与其他接口保持一致) + upscaled_path = os.path.join(IMAGES_DIR, upscaled_filename) + save_success = save_image_high_quality( + upscaled_image, upscaled_path, quality=SAVE_QUALITY + ) + + if save_success: + total_time = time.perf_counter() - t1 + + # 获取处理后文件大小 + upscaled_size = os.path.getsize(upscaled_path) + + logger.info(f"Photo super resolution processing completed: {upscaled_filename}, time: {total_time:.3f}s") + + # 异步执行图片向量化并入库,不阻塞主流程 + if CLIP_AVAILABLE: + asyncio.create_task(handle_image_vector_async(upscaled_path, upscaled_filename)) + + # bos_uploaded = upload_file_to_bos(upscaled_path) + await _record_output_file( + file_path=upscaled_path, + nickname=nickname, + category="upscale", + bos_uploaded=True, + extra={ + "source": "upscale", + "md5": actual_md5, + "scale": scale, + "model_name": model_name, + }, + ) + + return { + "success": True, + "message": "成功", + "original_filename": file.filename, + "upscaled_filename": upscaled_filename, + "processing_time": f"{total_time:.3f}s", + "original_size": original_size, + "upscaled_size": upscaled_size, + "size_increase_ratio": round(upscaled_size / original_size, 2), + "original_dimensions": f"{original_width} × {original_height}", + "upscaled_dimensions": f"{upscaled_width} × {upscaled_height}", + "scale_factor": f"{scale}x" + } + else: + raise HTTPException(status_code=500, detail="保存超清后图像失败") + + except HTTPException: + # 重新抛出HTTP异常 + raise + except Exception as e: + logger.error(f"Error occurred during photo super resolution: {str(e)}") + raise HTTPException(status_code=500, detail=f"超清过程中出现错误: {str(e)}") + + +@api_router.post("/remove_background") +@log_api_params +async def remove_background( + file: UploadFile = File(...), + background_color: str = Form("None", description="背景颜色,格式:r,g,b,如 255,255,255 为白色,None为透明背景"), + model: str = 
@api_router.post("/remove_background")
@log_api_params
async def remove_background(
    file: UploadFile = File(...),
    background_color: str = Form("None", description="背景颜色,格式:r,g,b,如 255,255,255 为白色,None为透明背景"),
    model: str = Form("robustVideoMatting", description="使用的rembg模型: u2net, u2net_human_seg, silueta, isnet-general-use, robustVideoMatting"),
    output_format: str = Form("webp", description="输出格式: png, webp"),
    nickname: str = Form(None, description="操作者昵称"),
):
    """
    Remove the background of an uploaded photo (ID-photo matting).

    When the requested model is ``robustVideoMatting`` and a face is detected,
    the request is delegated to ``rvm_remove_background``; if that fails, or no
    face is found, processing falls back to rembg with ``isnet-general-use``.

    :param file: uploaded image file
    :param background_color: ``r,g,b`` for a solid background, or ``None`` to
        keep the background transparent; an unparsable value falls back to white
    :param model: u2net, u2net_human_seg, silueta, isnet-general-use or
        robustVideoMatting
    :param output_format: ``png`` or ``webp`` (only used for transparent output)
    :param nickname: operator nickname stored with the output record
    :return: dict describing the result, including the saved file name
    :raises HTTPException: 400 for invalid input, 500 for processing failures
    """
    # content_type is None when the client sends no Content-Type for the part.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="请上传图片文件")

    if output_format not in ["png", "webp"]:
        raise HTTPException(status_code=400, detail="输出格式只支持png或webp")

    try:
        contents = await file.read()
        # Decode the upload into a BGR image.
        np_arr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(
                status_code=400, detail="无法解析图片文件,请确保文件格式正确。"
            )

        # Face detection decides whether the RVM path is attempted at all;
        # a detector failure is treated as "no face".
        has_face = False
        if analyzer is not None:
            try:
                face_boxes = analyzer._detect_faces(image)
                has_face = len(face_boxes) > 0
            except Exception as e:
                logger.warning(f"Face detection failed: {e}")
                has_face = False

        if has_face and model == "robustVideoMatting":
            # The stream was consumed above; hand RVM a fresh one.
            file.file = io.BytesIO(contents)
            try:
                return await rvm_remove_background(
                    file,
                    background_color,
                    output_format,
                    nickname=nickname,
                )
            except Exception as rvm_error:
                # Best-effort: any RVM failure falls through to rembg below.
                logger.warning(f"RVM background removal failed: {rvm_error}, rolling back to rembg background removal")
                file.file = io.BytesIO(contents)

        _ensure_rembg()
        if rembg_processor is None or not rembg_processor.is_available():
            raise HTTPException(
                status_code=500,
                detail="证件照抠图处理器未初始化,请检查服务状态。"
            )

        # rembg has no robustVideoMatting model; substitute a general one.
        if model == "robustVideoMatting":
            model = "isnet-general-use"
            # Reaching this point with a detected face means the RVM attempt failed.
            reason = "RVM processing failed" if has_face else "no face detected in image"
            logger.info(f"User selected robustVideoMatting model but {reason}, switching to {model} model")

        unique_id = str(uuid.uuid4()).replace('-', '')  # 32-char hex id

        # Solid-background results are always written as .webp; transparent
        # results use the requested format.
        if background_color and background_color.lower() != "none":
            processed_filename = f"{unique_id}_id_photo.webp"
        else:
            processed_filename = f"{unique_id}_id_photo.{output_format}"

        logger.info(f"Starting ID photo background removal processing: {file.filename}, size={file.size}, model={model}, bg_color={background_color}, uuid={unique_id}")
        t1 = time.perf_counter()

        original_height, original_width = image.shape[:2]
        # file.size may be None; fall back to the number of bytes actually read.
        original_size = file.size if file.size is not None else len(contents)

        # Switch the shared rembg processor to the requested model if needed.
        if model != rembg_processor.model_name:
            if not rembg_processor.switch_model(model):
                logger.warning(f"Failed to switch to model {model}, using default model {rembg_processor.model_name}")

        # Parse "r,g,b" into the BGR tuple OpenCV expects.
        bg_color = None
        if background_color and background_color.lower() != "none":
            try:
                rgb_values = [int(x.strip()) for x in background_color.split(",")]
                if len(rgb_values) == 3:
                    bg_color = (rgb_values[2], rgb_values[1], rgb_values[0])
                    logger.info(f"Using background color: RGB{tuple(rgb_values)} -> BGR{bg_color}")
                else:
                    raise ValueError("背景颜色格式错误")
            except (ValueError, IndexError) as e:
                logger.warning(f"Failed to parse background color parameter: {e}, using default white background")
                bg_color = (255, 255, 255)

        logger.info("Starting rembg background removal processing...")
        try:
            if bg_color is not None:
                processed_image = await process_cpu_intensive_task(rembg_processor.create_id_photo, image, bg_color)
                processing_info = f"使用{model}模型抠图并添加纯色背景"
            else:
                processed_image = await process_cpu_intensive_task(rembg_processor.remove_background, image)
                processing_info = f"使用{model}模型抠图保持透明背景"

            logger.info("Background removal processing completed")
        except Exception as e:
            logger.error(f"Background removal processing failed: {e}")
            raise HTTPException(status_code=500, detail=f"抠图处理失败: {str(e)}") from e

        processed_height, processed_width = processed_image.shape[:2]

        processed_path = os.path.join(IMAGES_DIR, processed_filename)
        bos_uploaded = False

        if bg_color is not None:
            # Solid background: written to the .webp path chosen above.
            save_success = save_image_high_quality(processed_image, processed_path, quality=SAVE_QUALITY)
        else:
            if output_format == "webp":
                success, encoded_img = cv2.imencode(".webp", processed_image, [cv2.IMWRITE_WEBP_QUALITY, 100])
                if success:
                    with open(processed_path, "wb") as f:
                        f.write(encoded_img)
                    bos_uploaded = upload_file_to_bos(processed_path)
                    save_success = True
                else:
                    save_success = False
            else:
                save_success = save_image_with_transparency(processed_image, processed_path)

        if save_success:
            total_time = time.perf_counter() - t1
            processed_size = os.path.getsize(processed_path)

            logger.info(f"ID photo background removal processing completed: {processed_filename}, time: {total_time:.3f}s")

            # Vectorise in the background so the response is not delayed.
            if CLIP_AVAILABLE:
                asyncio.create_task(handle_image_vector_async(processed_path, processed_filename))

            # Upload now if it has not happened (or did not succeed) yet.
            if not bos_uploaded:
                bos_uploaded = upload_file_to_bos(processed_path)

            await _record_output_file(
                file_path=processed_path,
                nickname=nickname,
                category="id_photo",
                bos_uploaded=bos_uploaded,
                extra={
                    "source": "remove_background",
                    "background_color": background_color,
                    "model_used": model,
                    "output_format": output_format,
                    "has_face": has_face,
                },
            )

            # Label reported to the client (kept as in the existing API).
            final_output_format = "PNG" if bg_color is None and output_format == "png" else \
                "WEBP" if bg_color is None and output_format == "webp" else "JPEG"
            has_transparency = bg_color is None

            return {
                "success": True,
                "message": "抠图成功",
                "original_filename": file.filename,
                "processed_filename": processed_filename,
                "processing_time": f"{total_time:.3f}s",
                "processing_info": processing_info,
                "original_size": original_size,
                "processed_size": processed_size,
                "size_change_ratio": round(processed_size / original_size, 2) if original_size > 0 else 1.0,
                "original_dimensions": f"{original_width} × {original_height}",
                "processed_dimensions": f"{processed_width} × {processed_height}",
                "model_used": model,
                "background_color": background_color,
                "output_format": final_output_format,
                "has_transparency": has_transparency
            }
        else:
            raise HTTPException(status_code=500, detail="保存抠图后图像失败")

    except HTTPException:
        # Preserve the intended status codes.
        raise
    except Exception as e:
        logger.error(f"Error occurred during ID photo background removal: {str(e)}")
        raise HTTPException(status_code=500, detail=f"抠图过程中出现错误: {str(e)}") from e
@api_router.post("/rvm")
@log_api_params
async def rvm_remove_background(
    file: UploadFile = File(...),
    background_color: str = Form("None", description="背景颜色,格式:r,g,b,如 255,255,255 为白色,None为透明背景"),
    output_format: str = Form("webp", description="输出格式: png, webp"),
    nickname: str = Form(None, description="操作者昵称"),
):
    """
    Remove the background of an uploaded photo with the RVM processor.

    :param file: uploaded image file
    :param background_color: ``r,g,b`` for a solid background, or ``None`` to
        keep the background transparent; an unparsable value falls back to white
    :param output_format: ``png`` or ``webp`` (only used for transparent output)
    :param nickname: operator nickname stored with the output record
    :return: dict describing the result, including the saved file name
    :raises HTTPException: 400 for invalid input, 500 for processing failures.
        Failures used to surface as bare ``Exception``; ``HTTPException`` is a
        subclass, so callers that catch ``Exception`` keep working.
    """
    _ensure_rvm()
    if rvm_processor is None or not rvm_processor.is_available():
        raise HTTPException(
            status_code=500,
            detail="RVM抠图处理器未初始化,请检查服务状态。"
        )

    # content_type is None when the client sends no Content-Type for the part.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="请上传图片文件")

    if output_format not in ["png", "webp"]:
        raise HTTPException(status_code=400, detail="输出格式只支持png或webp")

    try:
        contents = await file.read()
        unique_id = str(uuid.uuid4()).replace('-', '')  # 32-char hex id

        # Solid-background results are always written as .webp; transparent
        # results use the requested format.
        if background_color and background_color.lower() != "none":
            processed_filename = f"{unique_id}_rvm_id_photo.webp"
        else:
            processed_filename = f"{unique_id}_rvm_id_photo.{output_format}"

        logger.info(f"Starting RVM ID photo background removal processing: {file.filename}, size={file.size}, bg_color={background_color}, uuid={unique_id}")
        t1 = time.perf_counter()

        # Decode the upload into a BGR image.
        np_arr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(
                status_code=400, detail="无法解析图片文件,请确保文件格式正确。"
            )

        original_height, original_width = image.shape[:2]
        # file.size may be None; fall back to the number of bytes actually read.
        original_size = file.size if file.size is not None else len(contents)

        # Parse "r,g,b" into the BGR tuple OpenCV expects.
        bg_color = None
        if background_color and background_color.lower() != "none":
            try:
                rgb_values = [int(x.strip()) for x in background_color.split(",")]
                if len(rgb_values) == 3:
                    bg_color = (rgb_values[2], rgb_values[1], rgb_values[0])
                    logger.info(f"Using background color: RGB{tuple(rgb_values)} -> BGR{bg_color}")
                else:
                    raise ValueError("背景颜色格式错误")
            except (ValueError, IndexError) as e:
                logger.warning(f"Failed to parse background color parameter: {e}, using default white background")
                bg_color = (255, 255, 255)

        logger.info("Starting RVM background removal processing...")
        try:
            if bg_color is not None:
                processed_image = await process_cpu_intensive_task(rvm_processor.create_id_photo, image, bg_color)
                processing_info = "使用RVM模型抠图并添加纯色背景"
            else:
                processed_image = await process_cpu_intensive_task(rvm_processor.remove_background, image)
                processing_info = "使用RVM模型抠图保持透明背景"

            logger.info("RVM background removal processing completed")
        except Exception as e:
            logger.error(f"RVM background removal processing failed: {e}")
            raise HTTPException(status_code=500, detail=f"RVM抠图处理失败: {str(e)}") from e

        processed_height, processed_width = processed_image.shape[:2]

        processed_path = os.path.join(IMAGES_DIR, processed_filename)
        bos_uploaded = False

        if bg_color is not None:
            # Solid background: written to the .webp path chosen above.
            save_success = save_image_high_quality(processed_image, processed_path, quality=SAVE_QUALITY)
        else:
            if output_format == "webp":
                success, encoded_img = cv2.imencode(".webp", processed_image, [cv2.IMWRITE_WEBP_QUALITY, 100])
                if success:
                    with open(processed_path, "wb") as f:
                        f.write(encoded_img)
                    bos_uploaded = upload_file_to_bos(processed_path)
                    save_success = True
                else:
                    save_success = False
            else:
                save_success = save_image_with_transparency(processed_image, processed_path)

        if save_success:
            total_time = time.perf_counter() - t1
            processed_size = os.path.getsize(processed_path)

            logger.info(f"RVM ID photo background removal processing completed: {processed_filename}, time: {total_time:.3f}s")

            # Vectorise in the background so the response is not delayed.
            if CLIP_AVAILABLE:
                asyncio.create_task(handle_image_vector_async(processed_path, processed_filename))

            # Upload now if it has not happened (or did not succeed) yet.
            if not bos_uploaded:
                bos_uploaded = upload_file_to_bos(processed_path)

            await _record_output_file(
                file_path=processed_path,
                nickname=nickname,
                category="rvm",
                bos_uploaded=bos_uploaded,
                extra={
                    "source": "rvm_remove_background",
                    "background_color": background_color,
                    "output_format": output_format,
                },
            )

            # Label reported to the client (kept as in the existing API).
            final_output_format = "PNG" if bg_color is None and output_format == "png" else \
                "WEBP" if bg_color is None and output_format == "webp" else "JPEG"
            has_transparency = bg_color is None

            return {
                "success": True,
                "message": "RVM抠图成功",
                "original_filename": file.filename,
                "processed_filename": processed_filename,
                "processing_time": f"{total_time:.3f}s",
                "processing_info": processing_info,
                "original_size": original_size,
                "processed_size": processed_size,
                "size_change_ratio": round(processed_size / original_size, 2) if original_size > 0 else 1.0,
                "original_dimensions": f"{original_width} × {original_height}",
                "processed_dimensions": f"{processed_width} × {processed_height}",
                "background_color": background_color,
                "output_format": final_output_format,
                "has_transparency": has_transparency
            }
        else:
            raise HTTPException(status_code=500, detail="保存RVM抠图后图像失败")

    except HTTPException:
        # Preserve the intended status codes.
        raise
    except Exception as e:
        logger.error(f"Error occurred during RVM ID photo background removal: {str(e)}")
        raise HTTPException(status_code=500, detail=f"RVM抠图过程中出现错误: {str(e)}") from e
@api_router.get("/keep_alive", tags=["系统维护"])
@log_api_params
async def keep_cpu_alive(
    duration: float = Query(
        0.01, ge=0.001, le=60.0, description="需要保持CPU繁忙的持续时间(秒)"
    ),
    intensity: int = Query(
        1, ge=1, le=500000, description="控制CPU占用强度的内部循环次数"
    ),
):
    """
    Run a short CPU-bound task on demand so the hosting platform does not
    put the service to sleep for being idle.

    :param duration: how long to keep the CPU busy, in seconds
    :param intensity: inner-loop count controlling how hard the CPU is driven
    :return: timing and iteration statistics of the busy task
    """
    started_at = time.perf_counter()
    busy_stats = await process_cpu_intensive_task(_keep_cpu_busy, duration, intensity)
    wall_clock = time.perf_counter() - started_at

    iterations = busy_stats["iterations"]
    checksum = busy_stats["checksum"]
    cpu_elapsed = busy_stats["elapsed"]

    logger.info(
        "Keep-alive task completed | duration=%.2fs intensity=%d iterations=%d checksum=%d cpu_elapsed=%.3fs total=%.3fs",
        duration,
        intensity,
        iterations,
        checksum,
        cpu_elapsed,
        wall_clock,
    )

    response = {
        "status": "ok",
        "requested_duration": duration,
        "requested_intensity": intensity,
        "cpu_elapsed": round(cpu_elapsed, 3),
        "total_elapsed": round(wall_clock, 3),
        "iterations": iterations,
        "checksum": checksum,
        "message": "CPU保持活跃任务已完成",
        "hostname": SERVER_HOSTNAME,
    }
    return response


@api_router.get("/health")
@log_api_params
async def health_check():
    """Report that the service is up and which optional components are usable."""

    def _usable(component):
        # A component counts as available once it exists and reports itself ready.
        return component is not None and component.is_available()

    return {
        "status": "healthy",
        "analyzer_ready": analyzer is not None,
        "deepface_available": DEEPFACE_AVAILABLE,
        "mediapipe_available": DLIB_AVAILABLE,
        "photo_restorer_available": _usable(photo_restorer),
        "restorer_type": restorer_type,
        "ddcolor_available": _usable(ddcolor_colorizer),
        "colorization_supported": DDCOLOR_AVAILABLE,
        "realesrgan_available": _usable(realesrgan_upscaler),
        "upscale_supported": REALESRGAN_AVAILABLE,
        "rembg_available": _usable(rembg_processor),
        "rvm_available": _usable(rvm_processor),
        "id_photo_supported": REMBG_AVAILABLE,
        "clip_available": CLIP_AVAILABLE,
        "vector_search_supported": CLIP_AVAILABLE,
        "anime_stylizer_available": _usable(anime_stylizer),
        "anime_style_supported": ANIME_STYLE_AVAILABLE,
        "rvm_supported": RVM_AVAILABLE,
        "message": "Enhanced FaceScore API is running with photo restoration, colorization, upscale, ID photo generation and vector search support",
        "version": "3.2.0",
    }
@api_router.get("/", response_class=HTMLResponse)
@log_api_params
async def index():
    """
    Serve the main page.

    :return: the contents of ``facescore.html`` located next to this module,
        or a 404 HTML response when that file is missing
    """
    file_path = os.path.join(os.path.dirname(__file__), "facescore.html")
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            html_content = f.read()
        return HTMLResponse(content=html_content)
    except FileNotFoundError:
        # NOTE(review): the markup around this message was stripped from the
        # extracted source, which left the literal split across lines; an
        # <h1> wrapper is assumed here - confirm against the real file.
        return HTMLResponse(
            content="<h1>facescore.html not found</h1>", status_code=404
        )
@api_router.post("/split_grid")
@log_api_params
async def split_grid_image(
    file: UploadFile = File(...),
    grid_type: int = Form(9,
                          description="宫格类型: 4表示2x2四宫格, 9表示3x3九宫格"),
    nickname: str = Form(None, description="操作者昵称"),
):
    """
    Split an uploaded image into a centred grid of equal square tiles.

    :param file: uploaded image file
    :param grid_type: 4 for a 2x2 grid, 9 for a 3x3 grid
    :param nickname: operator nickname stored with the output records
    :return: dict with the saved tile file names and layout information
    :raises HTTPException: 400 for invalid input (including images too small
        to split), 500 for unexpected failures
    """
    # content_type is None when the client sends no Content-Type for the part.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="请上传图片文件")

    if grid_type not in [4, 9]:
        raise HTTPException(status_code=400, detail="宫格类型只支持4(2x2)或9(3x3)")

    try:
        contents = await file.read()
        # Despite the name this is a random id, not a hash of the content.
        original_md5_hash = str(uuid.uuid4()).replace('-', '')

        if grid_type == 4:
            rows, cols = 2, 2
            grid_name = "2x2"
        else:  # grid_type == 9
            rows, cols = 3, 3
            grid_name = "3x3"

        logger.info(f"Starting to split image into {grid_name} grid: {file.filename}, size={file.size}, md5={original_md5_hash}")
        t1 = time.perf_counter()

        # Decode the upload into a BGR image.
        np_arr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(
                status_code=400, detail="无法解析图片文件,请确保文件格式正确。"
            )

        height, width = image.shape[:2]
        logger.info(f"Original image size: {width}×{height}, grid type: {grid_name}")

        aspect_ratio = width / height
        logger.info(f"Image aspect ratio: {aspect_ratio:.2f}")

        # Tiles are squares sized from the shorter edge, so the reassembled
        # grid is never stretched.
        min_dimension = min(width, height)
        square_size = min_dimension // max(rows, cols)
        if square_size == 0:
            # Fewer pixels than grid cells would yield empty tiles.
            raise HTTPException(status_code=400, detail="图片尺寸过小,无法分割")

        # Area actually used, centred inside the original image.
        actual_width = square_size * cols
        actual_height = square_size * rows
        start_x = (width - actual_width) // 2
        start_y = (height - actual_height) // 2

        logger.info(f"Calculation result - Grid size: {square_size}×{square_size}, usage area: {actual_width}×{actual_height}, starting position: ({start_x}, {start_y})")

        grid_filenames = []

        for row in range(rows):
            for col in range(cols):
                y1 = start_y + row * square_size
                y2 = start_y + (row + 1) * square_size
                x1 = start_x + col * square_size
                x2 = start_x + (col + 1) * square_size

                grid_image = image[y1:y2, x1:x2]

                grid_index = row * cols + col + 1  # tiles are numbered from 1
                grid_filename = f"{original_md5_hash}_grid_{grid_name}_{grid_index:02d}.webp"
                grid_path = os.path.join(IMAGES_DIR, grid_filename)

                save_success = save_image_high_quality(grid_image, grid_path, quality=SAVE_QUALITY)

                if save_success:
                    grid_filenames.append(grid_filename)
                    await _record_output_file(
                        file_path=grid_path,
                        nickname=nickname,
                        category="grid",
                        extra={
                            "source": "split_grid",
                            "grid_type": grid_type,
                            "index": grid_index,
                        },
                    )
                else:
                    logger.error(f"Failed to save grid image: {grid_filename}")

        # Keep a copy of the original so it can be vectorised.
        original_filename = f"{original_md5_hash}_original.webp"
        original_path = os.path.join(IMAGES_DIR, original_filename)
        if save_image_high_quality(image, original_path, quality=SAVE_QUALITY):
            await _record_output_file(
                file_path=original_path,
                nickname=nickname,
                category="original",
                extra={
                    "source": "split_grid",
                    "grid_type": grid_type,
                    "role": "original",
                },
            )

            # Vectorise in the background; only meaningful once the file exists.
            if CLIP_AVAILABLE:
                asyncio.create_task(handle_image_vector_async(original_path, original_filename))

        total_time = time.perf_counter() - t1
        logger.info(f"Image splitting completed: {len(grid_filenames)} grids, time: {total_time:.3f}s")

        return {
            "success": True,
            "message": "分割成功",
            "original_filename": file.filename,
            "original_saved_filename": original_filename,
            "grid_type": grid_type,
            "grid_layout": f"{rows}x{cols}",
            "grid_count": len(grid_filenames),
            "grid_filenames": grid_filenames,
            "processing_time": f"{total_time:.3f}s",
            "image_dimensions": f"{width} × {height}",
            "grid_dimensions": f"{square_size} × {square_size}",
            "actual_used_area": f"{actual_width} × {actual_height}"
        }

    except HTTPException:
        # Without this the 400 responses above were rewritten as 500.
        raise
    except Exception as e:
        logger.error(f"Error occurred during image splitting: {str(e)}")
        raise HTTPException(status_code=500, detail=f"分割过程中出现错误: {str(e)}") from e
@api_router.post("/compress")
@log_api_params
async def compress_image(
    file: UploadFile = File(...),
    compressType: str = Form(...),
    outputFormat: str = Form(default="webp"),
    quality: int = Form(default=100),
    targetSize: float = Form(default=None),
    width: int = Form(default=None),
    height: int = Form(default=None),
    nickname: str = Form(None, description="操作者昵称"),
):
    """
    Compress or convert an uploaded image.

    :param file: uploaded image file
    :param compressType: one of 'quality', 'dimension', 'size', 'format'
    :param outputFormat: output format, e.g. 'jpg', 'png', 'webp'
    :param quality: compression quality, 10-100
    :param targetSize: target file size in MB (at most 50), only for 'size'
    :param width: target width in pixels (50-4096), only for 'dimension'
    :param height: target height in pixels (50-4096), only for 'dimension'
    :param nickname: operator nickname stored with the output record
    :return: JSON response with the saved file name and compression statistics
    :raises HTTPException: 400 for invalid parameters, 500 for processing or
        storage failures
    """
    # content_type is None when the client sends no Content-Type for the part.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="请上传图片文件")

    # outputFormat becomes part of a file name on disk; refuse anything that
    # could carry a path separator or other special characters.
    if not outputFormat or not outputFormat.isascii() or not outputFormat.isalnum():
        raise HTTPException(status_code=400, detail="不支持的输出格式")

    try:
        contents = await file.read()
        unique_id = uuid.uuid4().hex  # 32-char hex id
        compressed_filename = f"{unique_id}_compress.{outputFormat.lower()}"
        logger.info(
            f"Starting to compress image: {file.filename}, "
            f"type: {compressType}, "
            f"format: {outputFormat}, "
            f"quality: {quality}, "
            f"target size: {targetSize}, "
            f"target width: {width}, "
            f"target height: {height}"
        )
        t1 = time.perf_counter()

        # Decode the upload into a BGR image.
        np_arr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(
                status_code=400, detail="无法解析图片文件,请确保文件格式正确。"
            )

        original_height, original_width = image.shape[:2]
        # file.size may be None; fall back to the number of bytes actually read.
        original_size = file.size if file.size is not None else len(contents)

        try:
            if compressType == 'quality':
                if not (10 <= quality <= 100):
                    raise HTTPException(status_code=400, detail="质量参数必须在10-100之间")
                compressed_bytes, compress_info = compress_image_by_quality(image, quality, outputFormat)

            elif compressType == 'dimension':
                if not width or not height:
                    raise HTTPException(status_code=400, detail="按尺寸压缩需要提供宽度和高度参数")
                if not (50 <= width <= 4096) or not (50 <= height <= 4096):
                    raise HTTPException(status_code=400, detail="尺寸参数必须在50-4096之间")
                # Resizing only: quality stays at 100.
                compressed_bytes, compress_info = compress_image_by_dimensions(
                    image, width, height, 100, outputFormat
                )

            elif compressType == 'size':
                if not targetSize or targetSize <= 0:
                    raise HTTPException(status_code=400, detail="按大小压缩需要提供有效的目标大小")
                if targetSize > 50:  # targetSize is in MB
                    raise HTTPException(status_code=400, detail="目标大小不能超过50MB")
                target_size_kb = targetSize * 1024  # MB -> KB
                compressed_bytes, compress_info = compress_image_by_file_size(
                    image, target_size_kb, outputFormat
                )

            elif compressType == 'format':
                compressed_bytes, compress_info = convert_image_format(image, outputFormat, quality)

            else:
                raise HTTPException(status_code=400, detail="不支持的压缩类型")

        except HTTPException:
            # Validation errors must keep their 400 status instead of being
            # rewritten as 500 by the generic handler below.
            raise
        except Exception as e:
            logger.error(f"Image compression processing failed: {e}")
            raise HTTPException(status_code=500, detail=f"压缩处理失败: {str(e)}") from e

        compressed_path = os.path.join(IMAGES_DIR, compressed_filename)
        try:
            with open(compressed_path, "wb") as f:
                f.write(compressed_bytes)
            bos_uploaded = upload_file_to_bos(compressed_path)
            logger.info(f"Compressed image saved successfully: {compressed_path}")

            # Vectorise in the background so the response is not delayed.
            if CLIP_AVAILABLE:
                asyncio.create_task(handle_image_vector_async(compressed_path, compressed_filename))
            await _record_output_file(
                file_path=compressed_path,
                nickname=nickname,
                category="compress",
                bos_uploaded=bos_uploaded,
                extra={
                    "source": "compress",
                    "compress_type": compressType,
                    "output_format": outputFormat,
                },
            )

        except Exception as e:
            logger.error(f"Failed to save compressed image: {e}")
            raise HTTPException(status_code=500, detail="保存压缩后图像失败") from e

        processing_time = time.perf_counter() - t1
        compressed_size = len(compressed_bytes)
        compression_ratio = ((original_size - compressed_size) / original_size) * 100 if original_size > 0 else 0

        result = {
            "success": True,
            "message": "压缩成功",
            "original_filename": file.filename,
            "compressed_filename": compressed_filename,
            "original_size": original_size,
            "compressed_size": compressed_size,
            "compression_ratio": round(compression_ratio, 1),
            "original_dimensions": f"{original_width} × {original_height}",
            "compressed_dimensions": compress_info.get('compressed_dimensions', f"{original_width} × {original_height}"),
            "processing_time": f"{processing_time:.3f}s",
            "output_format": compress_info.get('format', outputFormat.upper()),
            "compress_type": compressType,
            "quality_used": compress_info.get('quality', quality),
            "attempts": compress_info.get('attempts', 1)
        }

        logger.info(
            f"Image compression completed: {compressed_filename}, time: {processing_time:.3f}s, "
            f"original size: {human_readable_size(original_size)}, "
            f"compressed: {human_readable_size(compressed_size)}, "
            f"compression ratio: {compression_ratio:.1f}%"
        )

        return JSONResponse(content=convert_numpy_types(result))

    except HTTPException:
        # Preserve the intended status codes.
        raise
    except Exception as e:
        logger.error(f"Error occurred during image compression: {str(e)}")
        raise HTTPException(status_code=500, detail=f"压缩过程中出现错误: {str(e)}") from e
@api_router.get("/cleanup/status", tags=["系统管理"])
@log_api_params
async def get_cleanup_scheduler_status():
    """
    Return the state of the scheduled image-cleanup job.

    :return: dict wrapping whatever ``get_cleanup_status()`` reports
    :raises HTTPException: 500 when the status cannot be obtained
    """
    try:
        scheduler_state = get_cleanup_status()
    except Exception as e:
        logger.error(f"Failed to get cleanup task status: {e}")
        raise HTTPException(status_code=500, detail=f"获取清理任务状态失败: {str(e)}")
    else:
        return {
            "success": True,
            "status": scheduler_state,
            "message": "获取清理任务状态成功"
        }
@api_router.post("/cleanup/manual", tags=["系统管理"])
@log_api_params
async def manual_cleanup_images():
    """
    Run the image cleanup job once, immediately.

    :return: dict with the success flag, a user-facing message and the raw
        result reported by ``manual_cleanup()``
    :raises HTTPException: 500 when the cleanup call itself raises
    """
    try:
        logger.info("Manually executing image cleanup task...")
        outcome = manual_cleanup()

        if not outcome['success']:
            # User-facing message stays Chinese; the log line is English.
            error_str = outcome.get('error', '未知错误')
            message = f"清理任务执行失败: {error_str}"
            logger.error(f"Cleanup task failed: {error_str}")
        else:
            deleted = outcome['deleted_count']
            message = f"清理完成! 删除了 {deleted} 个文件"
            en_message = f"Cleanup completed! Deleted {deleted} files"
            if deleted > 0:
                # The freed size is only reported when something was deleted.
                size_mb = outcome.get('deleted_size', 0) / 1024 / 1024
                message += f", 总大小: {size_mb:.2f} MB"
                en_message += f", total size: {size_mb:.2f} MB"
            logger.info(en_message)

        return {
            "success": outcome['success'],
            "message": message,
            "result": outcome
        }

    except Exception as e:
        logger.error(f"Manual cleanup task execution failed: {e}")
        raise HTTPException(status_code=500, detail=f"手动清理任务执行失败: {str(e)}")


def _extract_tar_archive(archive_path: str, target_dir: str) -> Dict[str, str]:
    """
    Unpack a gzip-compressed tar archive into ``target_dir`` with the
    system ``tar`` command.

    This call blocks until ``tar`` exits, so it is meant to run off the
    event loop.

    :return: the command line plus the stripped stdout/stderr of ``tar``
    :raises RuntimeError: when ``tar`` exits with a non-zero status
    """
    argv = ["tar", "-xzf", archive_path, "-C", target_dir]
    cmd_display = " ".join(argv)
    logger.info(f"开始执行解压命令: {cmd_display}")

    proc = subprocess.run(argv, capture_output=True, text=True, check=False)
    captured_out = (proc.stdout or "").strip()
    captured_err = (proc.stderr or "").strip()

    if proc.returncode != 0:
        raise RuntimeError(f"tar命令执行失败: {captured_err or '未知错误'}")

    logger.info(f"解压命令执行成功: {cmd_display}")
    return {
        "command": cmd_display,
        "stdout": captured_out,
        "stderr": captured_err,
    }
def _flatten_chinese_celeb_dataset_dir(target_dir: str) -> bool:
    """
    Lift the contents of a nested ``opt/data/chinese_celeb_dataset`` tree
    (produced when the archive stores absolute-style paths) up to the root
    of ``target_dir`` and remove the leftover ``opt`` wrapper.

    :param target_dir: directory the archive was extracted into
    :return: True when a nested tree was found and flattened, else False
    """
    import tempfile  # local import: only needed for the staging directory

    nested_root = os.path.join(target_dir, "opt", "data", "chinese_celeb_dataset")
    if not os.path.isdir(nested_root):
        return False

    # Park the payload in a staging directory first. Moving entries straight
    # into target_dir would drop an entry that is itself called "opt" inside
    # the wrapper directory, which is deleted right afterwards.
    staging_dir = tempfile.mkdtemp(prefix=".flatten_", dir=target_dir)
    for name in os.listdir(nested_root):
        shutil.move(os.path.join(nested_root, name), os.path.join(staging_dir, name))

    # Remove the now-empty opt/data wrapper.
    try:
        shutil.rmtree(os.path.join(target_dir, "opt"))
    except FileNotFoundError:
        pass

    for name in os.listdir(staging_dir):
        shutil.move(os.path.join(staging_dir, name), os.path.join(target_dir, name))
    # If a move above raised, the staging directory is left in place so no
    # data is lost; on success it is empty and can go.
    os.rmdir(staging_dir)
    return True


def _cleanup_chinese_celeb_hidden_files(target_dir: str) -> int:
    """
    Delete macOS resource-fork entries (``._*``) found directly inside
    ``target_dir``; sub-directories are not searched.

    :param target_dir: directory to clean
    :return: number of entries removed
    """
    pattern = os.path.join(target_dir, "._*")
    removed = 0
    for hidden_path in glob.glob(pattern):
        try:
            if os.path.isdir(hidden_path):
                shutil.rmtree(hidden_path, ignore_errors=True)
            else:
                os.remove(hidden_path)
            removed += 1
        except FileNotFoundError:
            # Already gone; nothing to count.
            continue
        except OSError as exc:
            logger.warning("清理隐藏文件失败: %s (%s)", hidden_path, exc)
    if removed:
        logger.info("已清理 chinese_celeb_dataset 隐藏文件 %d 个 (pattern=%s)", removed, pattern)
    return removed


def extract_chinese_celeb_dataset_sync() -> Dict[str, Any]:
    """
    Extract ``chinese_celeb_dataset.tar.gz`` from ``MODELS_PATH`` into
    ``/opt/data/chinese_celeb_dataset`` synchronously, for use during
    start-up or from other blocking contexts.

    Any existing target directory is deleted first, then the archive is
    unpacked, a nested ``opt/data/...`` layout is flattened and ``._*``
    files are removed.

    :return: summary dict with the archive path, target directory, the tar
        command and its output, and the clean-up results
    :raises FileNotFoundError: when the archive does not exist
    :raises RuntimeError: when the target directory cannot be (re)created,
        or when extraction fails
    """
    archive_path = os.path.join(MODELS_PATH, "chinese_celeb_dataset.tar.gz")
    target_dir = "/opt/data/chinese_celeb_dataset"

    if not os.path.isfile(archive_path):
        raise FileNotFoundError(f"数据集文件不存在: {archive_path}")

    # Start from an empty directory so stale files never mix with new ones.
    try:
        if os.path.isdir(target_dir):
            shutil.rmtree(target_dir)
        os.makedirs(target_dir, exist_ok=True)
    except OSError as exc:
        logger.error(f"创建目标目录失败: {target_dir}, {exc}")
        raise RuntimeError(f"创建目标目录失败: {exc}") from exc

    extract_result = _extract_tar_archive(archive_path, target_dir)
    flattened = _flatten_chinese_celeb_dataset_dir(target_dir)
    hidden_removed = _cleanup_chinese_celeb_hidden_files(target_dir)

    return {
        "success": True,
        "message": "chinese_celeb_dataset 解压完成",
        "archive_path": archive_path,
        "target_dir": target_dir,
        "command": extract_result.get("command"),
        "stdout": extract_result.get("stdout"),
        "stderr": extract_result.get("stderr"),
        "normalized": flattened,
        "hidden_removed": hidden_removed,
    }
def _run_shell_command(command: str, timeout: int = 300) -> Dict[str, Any]:
    """
    Run ``command`` through the system shell and capture its output.

    NOTE(security): ``shell=True`` is used on purpose so callers can pass a
    full shell command line, which means ``command`` must never come from an
    untrusted source.

    :param command: shell command line to execute
    :param timeout: seconds to wait before giving up
    :return: dict with ``returncode`` and stripped ``stdout`` / ``stderr``
    :raises RuntimeError: when the command does not finish within ``timeout``
    """
    logger.info(f"准备执行系统命令: {command}")
    try:
        completed = subprocess.run(
            command,
            shell=True,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        logger.error(f"命令执行超时({timeout}s): {command}")
        raise RuntimeError(f"命令执行超时({timeout}s): {exc}") from exc

    return {
        "returncode": completed.returncode,
        "stdout": (completed.stdout or "").strip(),
        "stderr": (completed.stderr or "").strip(),
    }


@api_router.post("/datasets/chinese-celeb/extract", tags=["系统管理"])
@log_api_params
async def extract_chinese_celeb_dataset():
    """
    Extract ``chinese_celeb_dataset.tar.gz`` from ``MODELS_PATH`` into
    ``/opt/data/chinese_celeb_dataset`` without blocking the event loop.

    :return: the summary dict produced by ``extract_chinese_celeb_dataset_sync``
    :raises HTTPException: 404 when the archive is missing, 500 on any other
        failure
    """
    # get_running_loop() is the supported way to obtain the loop in a coroutine.
    loop = asyncio.get_running_loop()
    try:
        result = await loop.run_in_executor(
            executor, extract_chinese_celeb_dataset_sync
        )
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        logger.error(f"解压 chinese_celeb_dataset 失败: {exc}")
        raise HTTPException(status_code=500, detail=f"解压失败: {exc}") from exc

    return result


@api_router.post("/files/upload", tags=["文件管理"])
@log_api_params
async def upload_file_to_directory(
    directory: str = Form(..., description="目标目录,支持绝对路径"),
    file: UploadFile = File(..., description="要上传的文件"),
):
    """
    Save an uploaded file into the given directory, creating it if needed.

    NOTE(security): ``directory`` is used as given (absolute paths and ``~``
    included), so this endpoint can write anywhere the process is allowed to.
    It must only be reachable by trusted operators.

    :param directory: destination directory
    :param file: file to store; only the base name of its file name is used
    :return: dict with the saved path, file name and number of bytes written
    :raises HTTPException: 400 for an empty directory, 500 when the directory
        cannot be created or the file cannot be written
    """
    if not directory.strip():
        raise HTTPException(status_code=400, detail="目录参数不能为空")

    target_dir = os.path.abspath(os.path.expanduser(directory.strip()))
    try:
        os.makedirs(target_dir, exist_ok=True)
    except OSError as exc:
        logger.error(f"创建目录失败: {target_dir}, {exc}")
        raise HTTPException(status_code=500, detail=f"创建目录失败: {exc}") from exc

    # basename() discards any directory components sent in the file name.
    original_name = file.filename or "uploaded_file"
    filename = os.path.basename(original_name) or f"upload_{int(time.time())}"
    target_path = os.path.join(target_dir, filename)

    bytes_written = 0
    try:
        with open(target_path, "wb") as out_file:
            # Stream in 1 MiB chunks so large uploads are not held in memory.
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                out_file.write(chunk)
                bytes_written += len(chunk)
    except Exception as exc:
        logger.error(f"保存上传文件失败: {exc}")
        # Do not leave a truncated file behind.
        try:
            if os.path.isfile(target_path):
                os.remove(target_path)
        except OSError:
            pass
        raise HTTPException(status_code=500, detail=f"保存文件失败: {exc}") from exc

    return {
        "success": True,
        "message": "文件上传成功",
        "saved_path": target_path,
        "filename": filename,
        "size": bytes_written,
    }
@api_router.get("/files/download", tags=["文件管理"])
@log_api_params
async def download_file(
    file_path: str = Query(..., description="要下载的文件路径,支持绝对路径"),
):
    """
    Send the file at ``file_path`` as an attachment.

    NOTE(security): the path is used as given (absolute paths and ``~``
    included), so this endpoint can read any file the process can read. It
    must only be reachable by trusted operators.

    :param file_path: path of the file to download
    :return: ``FileResponse`` streaming the file as ``application/octet-stream``
    :raises HTTPException: 400 for an empty path, 404 when no such file exists
    """
    if not file_path.strip():
        raise HTTPException(status_code=400, detail="文件路径不能为空")

    resolved_path = os.path.abspath(os.path.expanduser(file_path.strip()))
    if not os.path.isfile(resolved_path):
        raise HTTPException(status_code=404, detail=f"文件不存在: {resolved_path}")

    filename = os.path.basename(resolved_path) or "download"
    return FileResponse(
        resolved_path,
        filename=filename,
        media_type="application/octet-stream",
    )


@api_router.post("/system/command", tags=["系统管理"])
@log_api_params
async def execute_system_command(payload: Dict[str, Any]):
    """
    Run a shell command on the host and return its stdout/stderr.

    Example payload: ``{"command": "ls -l", "timeout": 120}``

    NOTE(security): this is remote command execution by design. The command
    string is passed to the shell unchanged, so the endpoint must only be
    reachable by trusted operators.

    :param payload: JSON object with ``command`` (str, required) and
        ``timeout`` (positive integer seconds, default 300)
    :return: dict with the success flag, the command, its return code and
        captured output
    :raises HTTPException: 400 for an invalid payload, 500 when execution
        fails or times out
    """
    # A non-object JSON body has no .get(); reject it as a bad request.
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="必须提供command字符串")

    command = payload.get("command")
    if not command or not isinstance(command, str):
        raise HTTPException(status_code=400, detail="必须提供command字符串")

    timeout = payload.get("timeout", 300)
    try:
        timeout_val = int(timeout)
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail="timeout必须为整数") from exc
    if timeout_val <= 0:
        raise HTTPException(status_code=400, detail="timeout必须为正整数")

    # get_running_loop() is the supported way to obtain the loop in a coroutine.
    loop = asyncio.get_running_loop()
    try:
        result = await loop.run_in_executor(
            executor, _run_shell_command, command, timeout_val
        )
    except Exception as exc:
        logger.error(f"命令执行失败: {exc}")
        raise HTTPException(status_code=500, detail=f"命令执行失败: {exc}") from exc

    # A missing return code is treated as failure.
    success = result.get("returncode", 1) == 0
    return {
        "success": success,
        "command": command,
        "returncode": result.get("returncode"),
        "stdout": result.get("stdout"),
        "stderr": result.get("stderr"),
    }
Exception as exc: + logger.error(f"命令执行失败: {exc}") + raise HTTPException(status_code=500, detail=f"命令执行失败: {exc}") + + success = result.get("returncode", 1) == 0 + return { + "success": success, + "command": command, + "returncode": result.get("returncode"), + "stdout": result.get("stdout"), + "stderr": result.get("stderr"), + } + + + +@api_router.post("/celebrity/keep_alive", tags=["系统维护"]) +@log_api_params +async def celebrity_keep_cpu_alive( + duration: float = Query( + 0.01, ge=0.001, le=60.0, description="需要保持CPU繁忙的持续时间(秒)" + ), + intensity: int = Query( + 1, ge=1, le=50000, description="控制CPU占用强度的内部循环次数" + ), +): + """ + 手动触发CPU保持活跃,避免云服务因空闲进入休眠。 + """ + t_start = time.perf_counter() + result = await process_cpu_intensive_task(_keep_cpu_busy, duration, intensity) + total_elapsed = time.perf_counter() - t_start + + logger.info( + "Keep-alive task completed | duration=%.2fs intensity=%d iterations=%d checksum=%d cpu_elapsed=%.3fs total=%.3fs", + duration, + intensity, + result["iterations"], + result["checksum"], + result["elapsed"], + total_elapsed, + ) + + return { + "status": "ok", + "requested_duration": duration, + "requested_intensity": intensity, + "cpu_elapsed": round(result["elapsed"], 3), + "total_elapsed": round(total_elapsed, 3), + "iterations": result["iterations"], + "checksum": result["checksum"], + "message": "CPU保持活跃任务已完成", + "hostname": SERVER_HOSTNAME, + } + + +@api_router.post("/celebrity/load", tags=["Face Recognition"]) +@log_api_params +async def load_celebrity_database(): + """刷新DeepFace明星人脸库缓存""" + if not DEEPFACE_AVAILABLE or deepface_module is None: + raise HTTPException(status_code=500, + detail="DeepFace模块未初始化,请检查服务状态。") + + folder_path = CELEBRITY_SOURCE_DIR + if not folder_path: + raise HTTPException(status_code=500, + detail="未配置明星图库目录,请设置环境变量 CELEBRITY_SOURCE_DIR。") + + folder_path = os.path.abspath(os.path.expanduser(folder_path)) + if not os.path.isdir(folder_path): + raise HTTPException(status_code=400, + detail=f"文件夹不存在: 
{folder_path}") + + image_files = _iter_celebrity_images(folder_path) + if not image_files: + raise HTTPException(status_code=400, + detail="明星图库目录中未找到有效图片。") + + encoded_files = [] + renamed = [] + + for src_path in image_files: + directory, original_name = os.path.split(src_path) + base_name, ext = os.path.splitext(original_name) + + suffix_part = "" + base_core = base_name + if "__" in base_name: + base_core, suffix_part = base_name.split("__", 1) + suffix_part = f"__{suffix_part}" + + decoded_core = _decode_basename(base_core) + if _encode_basename(decoded_core) == base_core: + encoded_base = base_core + else: + encoded_base = _encode_basename(base_name) + suffix_part = "" + + candidate_name = f"{encoded_base}{suffix_part}{ext.lower()}" + target_path = os.path.join(directory, candidate_name) + + if os.path.normcase(src_path) != os.path.normcase(target_path): + suffix = 1 + while os.path.exists(target_path): + candidate_name = f"{encoded_base}__{suffix}{ext.lower()}" + target_path = os.path.join(directory, candidate_name) + suffix += 1 + try: + os.rename(src_path, target_path) + renamed.append({"old": src_path, "new": target_path}) + except Exception as err: + logger.error( + f"Failed to rename celebrity image {src_path}: {err}") + continue + + encoded_files.append(target_path) + + if not encoded_files: + raise HTTPException(status_code=400, + detail="明星图片重命名失败,请检查目录内容。") + + sample_image = encoded_files[0] + start_time = time.perf_counter() + logger.info( + f"开始刷新明星人脸向量缓存,样本图片: {sample_image}, 总数: {len(encoded_files)}") + + stop_event = asyncio.Event() + progress_task = asyncio.create_task( + _log_progress("刷新明星人脸缓存", start_time, stop_event, interval=5.0)) + + try: + await _refresh_celebrity_cache(sample_image, folder_path) + finally: + stop_event.set() + try: + await progress_task + except Exception: + pass + + total_time = time.perf_counter() - start_time + + logger.info( + f"Celebrity library refreshed. 
total_images={len(encoded_files)} renamed={len(renamed)} sample={sample_image} elapsed={total_time:.1f}s" + ) + + return { + "success": True, + "message": "明星图库缓存刷新成功", + "data": { + "total_images": len(encoded_files), + "renamed": renamed, + "sample_image": sample_image, + "source": folder_path, + "processing_time": total_time, + }, + } + + +@api_router.post("/celebrity/match", tags=["Face Recognition"]) +@log_api_params +async def match_celebrity_face( + file: UploadFile = File(..., description="待匹配的用户图片"), + nickname: str = Form(None, description="操作者昵称"), +): + """ + 上传图片与明星人脸库比对 + :param file: 上传图片 + :return: 最相似的明星文件及分数 + """ + if not DEEPFACE_AVAILABLE or deepface_module is None: + raise HTTPException(status_code=500, + detail="DeepFace模块未初始化,请检查服务状态。") + + primary_dir = CELEBRITY_SOURCE_DIR + if not primary_dir: + raise HTTPException(status_code=500, + detail="未配置明星图库目录,请设置环境变量 CELEBRITY_SOURCE_DIR。") + + db_path = os.path.abspath(os.path.expanduser(primary_dir)) + if not os.path.isdir(db_path): + raise HTTPException(status_code=400, + detail=f"明星图库目录不存在: {db_path}") + + existing_files = _iter_celebrity_images(db_path) + if not existing_files: + raise HTTPException(status_code=400, + detail="明星人脸库为空,请先调用导入接口。") + + if not file.content_type or not file.content_type.startswith("image/"): + raise HTTPException(status_code=400, detail="请上传图片文件。") + + temp_filename: Optional[str] = None + temp_path: Optional[str] = None + cleanup_temp_file = False + annotated_filename: Optional[str] = None + + try: + contents = await file.read() + np_arr = np.frombuffer(contents, np.uint8) + image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + if image is None: + raise HTTPException(status_code=400, + detail="无法解析上传的图片,请确认格式。") + + if analyzer is None: + _ensure_analyzer() + + faces: List[List[int]] = [] + if analyzer is not None: + faces = analyzer._detect_faces(image) + if not faces: + raise HTTPException(status_code=400, + detail="图片中未检测到人脸,请重新上传。") + + temp_filename = 
f"{uuid.uuid4().hex}_celebrity_query.webp" + temp_path = os.path.join(IMAGES_DIR, temp_filename) + if not save_image_high_quality(image, temp_path, quality=SAVE_QUALITY): + raise HTTPException(status_code=500, + detail="保存临时图片失败,请稍后重试。") + cleanup_temp_file = True + await _record_output_file( + file_path=temp_path, + nickname=nickname, + category="celebrity", + extra={ + "source": "celebrity_match", + "role": "query", + }, + ) + + def _build_find_kwargs(refresh: bool) -> dict: + kwargs = dict( + img_path=temp_path, + db_path=db_path, + model_name="ArcFace", + detector_backend="yolov11n", + distance_metric="cosine", + enforce_detection=True, + silent=True, + refresh_database=refresh, + ) + if CELEBRITY_FIND_THRESHOLD is not None: + kwargs["threshold"] = CELEBRITY_FIND_THRESHOLD + return kwargs + + lock = _ensure_deepface_lock() + async with lock: + try: + find_result = await process_cpu_intensive_task( + deepface_module.find, + **_build_find_kwargs(refresh=False), + ) + except (AttributeError, RuntimeError) as attr_err: + if "numpy" in str(attr_err) or "SymbolicTensor" in str(attr_err): + logger.warning( + f"DeepFace find encountered numpy/SymbolicTensor error, 尝试清理模型后刷新缓存: {attr_err}") + _recover_deepface_model() + find_result = await process_cpu_intensive_task( + deepface_module.find, + **_build_find_kwargs(refresh=True), + ) + else: + raise + except ValueError as ve: + logger.warning( + f"DeepFace find failed without refresh: {ve}, 尝试清理模型后刷新缓存。") + _recover_deepface_model() + find_result = await process_cpu_intensive_task( + deepface_module.find, + **_build_find_kwargs(refresh=True), + ) + + if not find_result: + raise HTTPException(status_code=404, detail="未找到相似的人脸。") + + result_df = find_result[0] + best_record = None + if hasattr(result_df, "empty"): + if result_df.empty: + raise HTTPException(status_code=404, detail="未找到相似的人脸。") + best_record = result_df.iloc[0] + elif isinstance(result_df, list) and result_df: + best_record = result_df[0] + else: + raise 
HTTPException(status_code=500, + detail="明星人脸库返回格式异常。") + + # Pandas Series 转 dict,确保后续访问统一 + if hasattr(best_record, "to_dict"): + best_record_data = best_record.to_dict() + else: + best_record_data = dict(best_record) + + identity_path = str(best_record_data.get("identity", "")) + if not identity_path: + raise HTTPException(status_code=500, + detail="识别结果缺少identity字段。") + + distance = float(best_record_data.get("distance", 0.0)) + similarity = max(0.0, min(100.0, (1 - distance / 2) * 100)) + confidence_raw = best_record_data.get("confidence") + confidence = float( + confidence_raw) if confidence_raw is not None else similarity + filename = os.path.basename(identity_path) + base, ext = os.path.splitext(filename) + encoded_part = base.split("__", 1)[0] if "__" in base else base + display_name = _decode_basename(encoded_part) + + def _parse_coord(value): + try: + if value is None: + return None + if isinstance(value, (np.integer, int)): + return int(value) + if isinstance(value, (np.floating, float)): + if np.isnan(value): + return None + return int(round(float(value))) + if isinstance(value, str) and value.strip(): + return int(round(float(value))) + except Exception: + return None + return None + + img_height, img_width = image.shape[:2] + crop = None + + matched_box = None + + sx = _parse_coord(best_record_data.get("source_x")) + sy = _parse_coord(best_record_data.get("source_y")) + sw = _parse_coord(best_record_data.get("source_w")) + sh = _parse_coord(best_record_data.get("source_h")) + + if ( + sx is not None + and sy is not None + and sw is not None + and sh is not None + and sw > 0 + and sh > 0 + ): + x1 = max(0, sx) + y1 = max(0, sy) + x2 = min(img_width, x1 + sw) + y2 = min(img_height, y1 + sh) + if x2 > x1 and y2 > y1: + crop = image[y1:y2, x1:x2] + matched_box = (x1, y1, x2, y2) + + if (crop is None or crop.size == 0) and faces: + def _area(box): + if not box or len(box) < 4: + return 0 + return max(0, box[2] - box[0]) * max(0, box[3] - box[1]) + + 
largest_face = max(faces, key=_area) + if largest_face and len(largest_face) >= 4: + fx1, fy1, fx2, fy2 = [int(max(0, v)) for v in largest_face[:4]] + fx1 = min(fx1, img_width - 1) + fy1 = min(fy1, img_height - 1) + fx2 = min(max(fx1 + 1, fx2), img_width) + fy2 = min(max(fy1 + 1, fy2), img_height) + if fx2 > fx1 and fy2 > fy1: + crop = image[fy1:fy2, fx1:fx2] + matched_box = (fx1, fy1, fx2, fy2) + + face_filename = None + if crop is not None and crop.size > 0: + face_filename = f"{uuid.uuid4().hex}_face_1.webp" + face_path = os.path.join(IMAGES_DIR, face_filename) + if not save_image_high_quality(crop, face_path, + quality=SAVE_QUALITY): + logger.error(f"Failed to save cropped face image: {face_path}") + face_filename = None + else: + await _record_output_file( + file_path=face_path, + nickname=nickname, + category="face", + extra={ + "source": "celebrity_match", + "role": "face_crop", + }, + ) + if matched_box is not None and temp_path: + annotated_image = image.copy() + x1, y1, x2, y2 = matched_box + thickness = max(2, int(round(min(img_height, img_width) / 200))) + thickness = max(thickness, 2) + cv2.rectangle(annotated_image, (x1, y1), (x2, y2), + color=(0, 255, 0), thickness=thickness) + if save_image_high_quality(annotated_image, temp_path, + quality=SAVE_QUALITY): + annotated_filename = temp_filename + cleanup_temp_file = False + await _record_output_file( + file_path=temp_path, + nickname=nickname, + category="celebrity", + extra={ + "source": "celebrity_match", + "role": "annotated", + }, + ) + else: + logger.error( + f"Failed to save annotated celebrity image: {temp_path}") + elif temp_path: + # 未拿到匹配框,保持原图但仍保留文件供返回 + annotated_filename = temp_filename + cleanup_temp_file = False + + result_payload = CelebrityMatchResponse( + filename=filename, + display_name=display_name, + distance=distance, + similarity=similarity, + confidence=confidence, + face_filename=face_filename, + ) + + return { + "success": True, + "filename": result_payload.filename, + 
"display_name": result_payload.display_name, + "distance": result_payload.distance, + "similarity": result_payload.similarity, + "confidence": result_payload.confidence, + "face_filename": result_payload.face_filename, + "annotated_filename": annotated_filename, + } + except HTTPException: + raise + except Exception as e: + logger.error(f"Celebrity match failed: {e}") + raise HTTPException(status_code=500, + detail=f"明星人脸匹配失败: {str(e)}") + finally: + if cleanup_temp_file and temp_path: + try: + os.remove(temp_path) + except Exception: + pass + + +@api_router.post("/face_verify") +@log_api_params +async def face_similarity_verification( + file1: UploadFile = File(..., description="第一张人脸图片"), + file2: UploadFile = File(..., description="第二张人脸图片"), + nickname: str = Form(None, description="操作者昵称"), +): + """ + 人脸相似度比对接口 + :param file1: 第一张人脸图片文件 + :param file2: 第二张人脸图片文件 + :return: 人脸比对结果,包括相似度分值和裁剪后的人脸图片 + """ + # 检查DeepFace是否可用 + if not DEEPFACE_AVAILABLE or deepface_module is None: + raise HTTPException( + status_code=500, + detail="DeepFace模块未初始化,请检查服务状态。" + ) + + # 验证文件类型 + if not file1.content_type.startswith("image/") or not file2.content_type.startswith("image/"): + raise HTTPException(status_code=400, detail="请上传图片文件") + + try: + # 读取两张图片 + contents1 = await file1.read() + contents2 = await file2.read() + + # 生成唯一标识符 + md5_hash1 = str(uuid.uuid4()).replace('-', '') + md5_hash2 = str(uuid.uuid4()).replace('-', '') + + # 生成文件名 + original_filename1 = f"{md5_hash1}_original1.webp" + original_filename2 = f"{md5_hash2}_original2.webp" + face_filename1 = f"{md5_hash1}_face1.webp" + face_filename2 = f"{md5_hash2}_face2.webp" + + logger.info(f"Starting face similarity verification: {file1.filename} vs {file2.filename}") + t1 = time.perf_counter() + + # 解码图像 + np_arr1 = np.frombuffer(contents1, np.uint8) + image1 = cv2.imdecode(np_arr1, cv2.IMREAD_COLOR) + if image1 is None: + raise HTTPException(status_code=400, detail="无法解析第一张图片文件,请确保文件格式正确。") + + np_arr2 = 
np.frombuffer(contents2, np.uint8) + image2 = cv2.imdecode(np_arr2, cv2.IMREAD_COLOR) + if image2 is None: + raise HTTPException(status_code=400, detail="无法解析第二张图片文件,请确保文件格式正确。") + + # 检查图片中是否包含人脸 + if analyzer is None: + _ensure_analyzer() + + if analyzer is not None: + # 检查第一张图片是否包含人脸 + logger.info("detect 1 image...") + face_boxes1 = analyzer._detect_faces(image1) + if not face_boxes1: + raise HTTPException(status_code=400, detail="第一张图片中未检测到人脸,请上传包含清晰人脸的图片") + + # 检查第二张图片是否包含人脸 + logger.info("detect 2 image...") + face_boxes2 = analyzer._detect_faces(image2) + if not face_boxes2: + raise HTTPException(status_code=400, detail="第二张图片中未检测到人脸,请上传包含清晰人脸的图片") + + # 保存原始图片到IMAGES_DIR(先不上传 BOS,供 DeepFace 使用) + original_path1 = os.path.join(IMAGES_DIR, original_filename1) + if not save_image_high_quality( + image1, + original_path1, + quality=SAVE_QUALITY, + upload_to_bos=False, + ): + raise HTTPException(status_code=500, detail="保存第一张原始图片失败") + + original_path2 = os.path.join(IMAGES_DIR, original_filename2) + if not save_image_high_quality( + image2, + original_path2, + quality=SAVE_QUALITY, + upload_to_bos=False, + ): + raise HTTPException(status_code=500, detail="保存第二张原始图片失败") + + # 调用DeepFace.verify进行人脸比对 + logger.info("Starting DeepFace verification...") + lock = _ensure_deepface_lock() + async with lock: + try: + # 使用ArcFace模型进行人脸比对 + verification_result = await process_cpu_intensive_task( + deepface_module.verify, + img1_path=original_path1, + img2_path=original_path2, + model_name="ArcFace", + detector_backend="yolov11n", + distance_metric="cosine" + ) + logger.info( + f"DeepFace verification completed result:{json.dumps(verification_result, ensure_ascii=False)}") + except (AttributeError, RuntimeError) as attr_err: + if "numpy" in str(attr_err) or "SymbolicTensor" in str(attr_err): + logger.warning( + f"DeepFace verification 遇到 numpy/SymbolicTensor 异常,尝试恢复后重试: {attr_err}") + _recover_deepface_model() + try: + verification_result = await 
process_cpu_intensive_task( + deepface_module.verify, + img1_path=original_path1, + img2_path=original_path2, + model_name="ArcFace", + detector_backend="yolov11n", + distance_metric="cosine" + ) + logger.info( + f"DeepFace verification completed after recovery: {json.dumps(verification_result, ensure_ascii=False)}") + except Exception as retry_error: + logger.error( + f"DeepFace verification failed after recovery attempt: {retry_error}") + raise HTTPException(status_code=500, + detail=f"人脸比对失败: {str(retry_error)}") from retry_error + else: + raise + except ValueError as ve: + logger.warning( + f"DeepFace verification 遇到模型状态异常,尝试恢复后重试: {ve}") + _recover_deepface_model() + try: + verification_result = await process_cpu_intensive_task( + deepface_module.verify, + img1_path=original_path1, + img2_path=original_path2, + model_name="ArcFace", + detector_backend="yolov11n", + distance_metric="cosine" + ) + logger.info( + f"DeepFace verification completed after recovery: {json.dumps(verification_result, ensure_ascii=False)}") + except Exception as retry_error: + logger.error( + f"DeepFace verification failed after recovery attempt: {retry_error}") + raise HTTPException(status_code=500, + detail=f"人脸比对失败: {str(retry_error)}") from retry_error + except Exception as e: + logger.error(f"DeepFace verification failed: {e}") + raise HTTPException(status_code=500, + detail=f"人脸比对失败: {str(e)}") from e + + # 提取比对结果 + verified = verification_result["verified"] + distance = verification_result["distance"] + + # 将距离转换为相似度百分比 (距离越小相似度越高) + # cosine距离范围[0,2],转换为百分比 + similarity_percentage = (1 - distance / 2) * 100 + + # 从验证结果中获取人脸框信息 + facial_areas = verification_result.get("facial_areas", {}) + img1_region = facial_areas.get("img1", {}) + img2_region = facial_areas.get("img2", {}) + + # 确保分析器已初始化,用于绘制特征点 + if analyzer is None: + _ensure_analyzer() + + def _apply_landmarks_on_original( + source_image: np.ndarray, + region: dict, + label: str, + ) -> Tuple[np.ndarray, bool]: + if 
analyzer is None or not region: + return source_image, False + try: + x = max(0, region.get("x", 0)) + y = max(0, region.get("y", 0)) + w = region.get("w", 0) + h = region.get("h", 0) + x_end = min(source_image.shape[1], x + w) + y_end = min(source_image.shape[0], y + h) + if x_end <= x or y_end <= y: + return source_image, False + result_img = source_image.copy() + face_region = result_img[y:y_end, x:x_end] + face_with_landmarks = analyzer.facial_analyzer.draw_facial_landmarks(face_region) + result_img[y:y_end, x:x_end] = face_with_landmarks + return result_img, True + except Exception as exc: + logger.warning(f"Failed to draw facial landmarks on original image {label}: {exc}") + return source_image, False + + original_output_img1, original1_has_landmarks = _apply_landmarks_on_original(image1, img1_region, "1") + original_output_img2, original2_has_landmarks = _apply_landmarks_on_original(image2, img2_region, "2") + + if save_image_high_quality(original_output_img1, original_path1, quality=SAVE_QUALITY): + await _record_output_file( + file_path=original_path1, + nickname=nickname, + category="original", + extra={ + "source": "face_verify", + "role": "original1_landmarks" if original1_has_landmarks else "original1", + "with_landmarks": original1_has_landmarks, + }, + ) + if save_image_high_quality(original_output_img2, original_path2, quality=SAVE_QUALITY): + await _record_output_file( + file_path=original_path2, + nickname=nickname, + category="original", + extra={ + "source": "face_verify", + "role": "original2_landmarks" if original2_has_landmarks else "original2", + "with_landmarks": original2_has_landmarks, + }, + ) + + # 如果有区域信息,则裁剪人脸 + if img1_region and img2_region: + try: + # 裁剪人脸区域 + x1, y1, w1, h1 = img1_region.get("x", 0), img1_region.get("y", 0), img1_region.get("w", 0), img1_region.get("h", 0) + x2, y2, w2, h2 = img2_region.get("x", 0), img2_region.get("y", 0), img2_region.get("w", 0), img2_region.get("h", 0) + + # 确保坐标在图像范围内 + x1, y1 = max(0, x1), 
max(0, y1) + x2, y2 = max(0, x2), max(0, y2) + x1_end, y1_end = min(image1.shape[1], x1 + w1), min(image1.shape[0], y1 + h1) + x2_end, y2_end = min(image2.shape[1], x2 + w2), min(image2.shape[0], y2 + h2) + + # 裁剪人脸 + face_img1 = image1[y1:y1_end, x1:x1_end] + face_img2 = image2[y2:y2_end, x2:x2_end] + + face_path1 = os.path.join(IMAGES_DIR, face_filename1) + face_path2 = os.path.join(IMAGES_DIR, face_filename2) + # 根据分析器可用性决定是否绘制特征点,仅保存最终版本一次 + def _prepare_face_image(face_img, face_index): + if analyzer is None: + return face_img, False + try: + return analyzer.facial_analyzer.draw_facial_landmarks(face_img.copy()), True + except Exception as exc: + logger.warning(f"Failed to draw facial landmarks on face{face_index}: {exc}") + return face_img, False + + face_output_img1, face1_has_landmarks = _prepare_face_image(face_img1, 1) + face_output_img2, face2_has_landmarks = _prepare_face_image(face_img2, 2) + + if save_image_high_quality(face_output_img1, face_path1, quality=SAVE_QUALITY): + await _record_output_file( + file_path=face_path1, + nickname=nickname, + category="face", + extra={ + "source": "face_verify", + "role": "face1_landmarks" if face1_has_landmarks else "face1", + "with_landmarks": face1_has_landmarks, + }, + ) + if save_image_high_quality(face_output_img2, face_path2, quality=SAVE_QUALITY): + await _record_output_file( + file_path=face_path2, + nickname=nickname, + category="face", + extra={ + "source": "face_verify", + "role": "face2_landmarks" if face2_has_landmarks else "face2", + "with_landmarks": face2_has_landmarks, + }, + ) + except Exception as e: + logger.warning(f"Failed to crop faces: {e}") + else: + # 如果没有区域信息,使用原始图像 + logger.info("No face regions found in verification result, using original images") + + total_time = time.perf_counter() - t1 + logger.info(f"Face similarity verification completed: time={total_time:.3f}s, similarity={similarity_percentage:.2f}%") + + # 返回结果 + return { + "success": True, + "message": "人脸比对完成", + "verified": 
verified, + "similarity_percentage": round(similarity_percentage, 2), + "distance": distance, + "processing_time": f"{total_time:.3f}s", + "original_filename1": original_filename1, + "original_filename2": original_filename2, + "face_filename1": face_filename1, + "face_filename2": face_filename2, + "model_used": "ArcFace", + "detector_backend": "retinaface", + "distance_metric": "cosine" + } + + except HTTPException: + # 重新抛出HTTP异常 + raise + except Exception as e: + logger.error(f"Error occurred during face similarity verification: {str(e)}") + raise HTTPException(status_code=500, detail=f"人脸比对过程中出现错误: {str(e)}") diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..ae949f787de32f6ac1f729122666e3707e83ba9b --- /dev/null +++ b/app.py @@ -0,0 +1,170 @@ +import os +import time +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from starlette.middleware.cors import CORSMiddleware + +from cleanup_scheduler import start_cleanup_scheduler, stop_cleanup_scheduler +from config import ( + logger, + OUTPUT_DIR, + DEEPFACE_AVAILABLE, + DLIB_AVAILABLE, + MODELS_PATH, + IMAGES_DIR, + YOLO_AVAILABLE, + ENABLE_LOGGING, + HUGGINGFACE_SYNC_ENABLED, +) +from database import close_mysql_pool, init_mysql_pool +from utils import ensure_bos_resources, ensure_huggingface_models + +logger.info("Starting to import api_routes module...") + +if HUGGINGFACE_SYNC_ENABLED: + try: + t_hf_start = time.perf_counter() + if not ensure_huggingface_models(): + raise RuntimeError("无法从 HuggingFace 同步模型,请检查配置与网络") + hf_time = time.perf_counter() - t_hf_start + logger.info("HuggingFace 模型同步完成,用时 %.3fs", hf_time) + except Exception as exc: + logger.error(f"HuggingFace model preparation failed: {exc}") + raise +else: + logger.info("已关闭 HuggingFace 模型同步开关,跳过启动阶段的同步步骤") + +try: + t_bos_start = time.perf_counter() + if not ensure_bos_resources(): + raise RuntimeError("无法从 BOS 同步模型与数据,请检查凭证与网络") + bos_time = time.perf_counter() - t_bos_start + 
logger.info(f"BOS resources synchronized successfully, time: {bos_time:.3f}s") +except Exception as exc: + logger.error(f"BOS resource preparation failed: {exc}") + raise + +try: + t_start = time.perf_counter() + from api_routes import api_router, extract_chinese_celeb_dataset_sync + import_time = time.perf_counter() - t_start + logger.info(f"api_routes module imported successfully, time: {import_time:.3f}s") +except Exception as e: + import_time = time.perf_counter() - t_start + logger.error(f"api_routes module import failed, time: {import_time:.3f}s, error: {e}") + raise + +try: + t_extract_start = time.perf_counter() + extract_result = extract_chinese_celeb_dataset_sync() + extract_time = time.perf_counter() - t_extract_start + logger.info( + "Chinese celeb dataset extracted successfully, time: %.3fs, target: %s", + extract_time, + extract_result.get("target_dir"), + ) +except Exception as exc: + logger.error(f"Failed to extract Chinese celeb dataset automatically: {exc}") + raise + + +@asynccontextmanager +async def lifespan(app: FastAPI): + start_time = time.perf_counter() + logger.info("FaceScore service starting...") + logger.info(f"Output directory: {OUTPUT_DIR}") + logger.info(f"DeepFace available: {DEEPFACE_AVAILABLE}") + logger.info(f"YOLO available: {YOLO_AVAILABLE}") + logger.info(f"MediaPipe available: {DLIB_AVAILABLE}") + logger.info(f"Archive directory: {IMAGES_DIR}") + os.makedirs(OUTPUT_DIR, exist_ok=True) + + # 初始化数据库连接池 + try: + await init_mysql_pool() + logger.info("MySQL 连接池初始化完成") + except Exception as exc: + logger.error(f"初始化 MySQL 连接池失败: {exc}") + raise + + # 启动图片清理定时任务 + logger.info("Starting image cleanup scheduled task...") + try: + start_cleanup_scheduler() + logger.info("Image cleanup scheduled task started successfully") + except Exception as e: + logger.error(f"Failed to start image cleanup scheduled task: {e}") + + # 记录启动完成时间 + total_startup_time = time.perf_counter() - start_time + logger.info(f"FaceScore service startup 
completed, total time: {total_startup_time:.3f}s") + + yield + + # 应用关闭时停止定时任务 + logger.info("Stopping image cleanup scheduled task...") + try: + stop_cleanup_scheduler() + logger.info("Image cleanup scheduled task stopped") + except Exception as e: + logger.error(f"Failed to stop image cleanup scheduled task: {e}") + + # 关闭数据库连接池 + try: + await close_mysql_pool() + except Exception as exc: + logger.warning(f"关闭 MySQL 连接池失败: {exc}") + + +# 创建 FastAPI 应用 +app = FastAPI( + title="Enhanced FaceScore 服务", + description="支持多模型的人脸分析REST API服务,包含五官评分功能。支持混合模式:HowCuteAmI(颜值+性别)+ DeepFace(年龄+情绪)", + version="3.0.0", + docs_url="/cp_docs", + redoc_url="/cp_redoc", + lifespan=lifespan, +) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], +) + +# 注册路由 +app.include_router(api_router) + +# 添加根路径处理 +@app.get("/") +async def root(): + return "UP" + + +if __name__ == "__main__": + import uvicorn + + if not os.path.exists(MODELS_PATH): + logger.critical( + "Warning: 'models' directory not found. Please ensure it exists and contains model files." + ) + logger.critical( + "Exiting application as FaceAnalyzer cannot be initialized without models." + ) + exit(1) + + # 根据日志开关配置 Uvicorn 日志 + if ENABLE_LOGGING: + uvicorn.run(app, host="0.0.0.0", port=8080, reload=False) + else: + # 禁用 Uvicorn 的访问日志和错误日志 + uvicorn.run( + app, + host="0.0.0.0", + port=8080, + reload=False, + access_log=False, # 禁用访问日志 + log_level="critical" # 只显示严重错误 + ) diff --git a/build.sh b/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..0949da15697f8f21aff1bd3c81e94f514f39be8c --- /dev/null +++ b/build.sh @@ -0,0 +1,4 @@ +python -m compileall -q -f -b . 
+mv *.pyc /opt/data/app/ +cp gfpgan_restorer.py /opt/data/app/ +cp start_local.sh /opt/data/app/ diff --git a/cleanup_scheduler.py b/cleanup_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..ea81048d466cdbf4851a983514edcc709a641e41 --- /dev/null +++ b/cleanup_scheduler.py @@ -0,0 +1,202 @@ +""" +定时清理图片文件模块 +每小时检查一次IMAGES_DIR目录,删除1小时以前的图片文件 +""" +import glob +import os +import time +from datetime import datetime + +from apscheduler.schedulers.background import BackgroundScheduler + +from config import logger, IMAGES_DIR, CLEANUP_INTERVAL_HOURS, CLEANUP_AGE_HOURS + + +# from utils import delete_file_from_bos # 暂时注释掉删除BOS文件的功能 + + +class ImageCleanupScheduler: + """图片清理定时任务类""" + + def __init__(self, images_dir=None, cleanup_hours=None, interval_hours=None): + """ + 初始化清理调度器 + + Args: + images_dir (str): 图片目录路径,默认使用config中的IMAGES_DIR + cleanup_hours (float): 清理时间阈值(小时),默认使用环境变量CLEANUP_AGE_HOURS + interval_hours (float): 定时任务执行间隔(小时),默认使用环境变量CLEANUP_INTERVAL_HOURS + """ + self.images_dir = images_dir or IMAGES_DIR + self.cleanup_hours = cleanup_hours if cleanup_hours is not None else CLEANUP_AGE_HOURS + self.interval_hours = interval_hours if interval_hours is not None else CLEANUP_INTERVAL_HOURS + self.scheduler = BackgroundScheduler() + self.is_running = False + + # 确保目录存在 + os.makedirs(self.images_dir, exist_ok=True) + logger.info(f"Image cleanup scheduler initialized, monitoring directory: {self.images_dir}, cleanup threshold: {self.cleanup_hours} hours, execution interval: {self.interval_hours} hours") + + def cleanup_old_images(self): + """ + 清理过期的图片文件 + 删除超过指定时间的图片文件 + """ + try: + current_time = time.time() + cutoff_time = current_time - (self.cleanup_hours * 3600) # 转换为秒 + cutoff_datetime = datetime.fromtimestamp(cutoff_time) + + # 支持的图片格式 + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.gif', '*.bmp'] + deleted_files = [] + total_size_deleted = 0 + + logger.info(f"Starting to clean image directory: 
{self.images_dir}") + logger.info(f"Cleanup threshold time: {cutoff_datetime.strftime('%Y-%m-%d %H:%M:%S')}") + + # 遍历所有图片文件 + for extension in image_extensions: + pattern = os.path.join(self.images_dir, extension) + for file_path in glob.glob(pattern): + try: + # 获取文件修改时间 + file_mtime = os.path.getmtime(file_path) + + # 如果文件时间早于阈值时间,则删除 + if file_mtime < cutoff_time: + file_size = os.path.getsize(file_path) + file_time = datetime.fromtimestamp(file_mtime) + + # 删除文件 + os.remove(file_path) + # delete_file_from_bos(file_path) # 暂时注释掉删除BOS文件 + deleted_files.append(os.path.basename(file_path)) + total_size_deleted += file_size + + logger.info(f"Deleting expired file: {os.path.basename(file_path)} ") + + except (OSError, IOError) as e: + logger.error(f"Failed to delete file {os.path.basename(file_path)}: {e}") + continue + + logger.info(f"Cleanup completed! Deleted {len(deleted_files)} files, ") + logger.info(f"Deleted file list: {', '.join(deleted_files[:10])}") + else: + logger.info("Cleanup completed! 
No expired files found to clean") + + return { + 'success': True, + 'deleted_count': len(deleted_files), + 'deleted_size': total_size_deleted, + 'deleted_files': deleted_files, + 'cutoff_time': cutoff_datetime.isoformat() + } + + except Exception as e: + error_msg = f"图片清理任务执行失败: {e}" + logger.error(error_msg) + return { + 'success': False, + 'error': str(e), + 'deleted_count': 0, + 'deleted_size': 0 + } + + def _format_size(self, size_bytes): + """格式化文件大小显示""" + if size_bytes == 0: + return "0 B" + size_names = ["B", "KB", "MB", "GB"] + i = 0 + while size_bytes >= 1024 and i < len(size_names) - 1: + size_bytes /= 1024.0 + i += 1 + return f"{size_bytes:.1f} {size_names[i]}" + + def start(self): + """启动定时清理任务""" + if self.is_running: + logger.warning("Image cleanup scheduler is already running") + return + + try: + # 添加定时任务:使用可配置的执行间隔 + self.scheduler.add_job( + func=self.cleanup_old_images, + trigger='interval', + hours=self.interval_hours, # 使用环境变量配置的执行间隔 + id='image_cleanup', + name='image clean tast', + replace_existing=True + ) + + # 启动调度器 + self.scheduler.start() + self.is_running = True + + logger.info(f"Image cleanup scheduler started, will execute cleanup task every {self.interval_hours} hours") + + # 立即执行一次清理(可选) + logger.info("Executing image cleanup task immediately...") + self.cleanup_old_images() + + except Exception as e: + logger.error(f"Failed to start image cleanup scheduler: {e}") + raise + + def stop(self): + """停止定时清理任务""" + if not self.is_running: + logger.warning("Image cleanup scheduler is not running") + return + + try: + self.scheduler.shutdown(wait=False) + self.is_running = False + logger.info("Image cleanup scheduler stopped") + except Exception as e: + logger.error(f"Failed to stop image cleanup scheduler: {e}") + + def get_status(self): + """获取调度器状态""" + return { + 'running': self.is_running, + 'images_dir': self.images_dir, + 'cleanup_hours': self.cleanup_hours, + 'interval_hours': self.interval_hours, + 'next_run': 
def init_clip_model():
    """Load the CLIP model and publish it through the module globals.

    The model and its preprocessing callable are loaded into locals first
    and only assigned to the module-level ``model`` / ``preprocess`` once
    every initialization step has succeeded. A failure part-way through
    (for example in ``eval()`` or while reading the output dimension)
    therefore never leaves the module in a state where
    ``is_clip_available()`` is True although this function returned False.

    Returns:
        bool: True when the model was loaded and published, False when any
        step raised (the error is logged, not re-raised).
    """
    global model, preprocess
    try:
        loaded_model, loaded_preprocess = load_from_name(
            MODEL_NAME_CN, device=device, download_root=MODELS_PATH
        )
        loaded_model.eval()
        output_dim = loaded_model.visual.output_dim
    except Exception as e:
        logger.error(f"CLIP model initialization failed: {e}")
        return False

    # Publish only after everything above succeeded.
    model, preprocess = loaded_model, loaded_preprocess
    logger.info(f"CLIP model initialized successfully, dimension: {output_dim}")
    return True
def encode_text(text: Union[str, List[str]]) -> torch.Tensor:
    """Encode one text or a list of texts into L2-normalized CLIP features.

    Args:
        text: a single string (wrapped into a one-element batch) or a list
            of strings passed to the tokenizer as-is.

    Returns:
        torch.Tensor: the text features divided by their L2 norm along the
        last dimension, moved to the CPU.

    Raises:
        RuntimeError: if the CLIP model has not been initialized.
    """
    if not is_clip_available():
        raise RuntimeError("CLIP模型未初始化")

    if isinstance(text, str):
        batch = [text]
    else:
        batch = text

    tokens = clip.tokenize(batch).to(device)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        raw_features = model.encode_text(tokens)
        unit_features = raw_features / raw_features.norm(p=2, dim=-1, keepdim=True)
    return unit_features.cpu()
def tree_flatten(tree, is_leaf=None):
    """Compatibility fallback: flatten a nested list/tuple/dict structure.

    Lists and tuples are walked in index order; dicts are walked in
    ``sorted(tree.items())`` order. Anything else is a leaf.

    Args:
        tree: the structure to flatten.
        is_leaf: optional predicate ``is_leaf(node) -> bool``. When it
            returns True for a node, that node is kept as a single leaf and
            not descended into. (Previously this argument was accepted but
            silently ignored.)

    Returns:
        tuple: ``(leaves, spec)``. ``leaves`` is the flat list of leaf
        values. ``spec`` is ``None`` for a leaf, otherwise
        ``(container_type, [(key_or_index, child_spec), ...])`` where
        ``container_type`` is ``type(tree)`` for sequences and ``dict`` for
        mappings.
    """
    if is_leaf is not None and is_leaf(tree):
        return [tree], None

    if isinstance(tree, (list, tuple)):
        children = enumerate(tree)
        container_type = type(tree)
    elif isinstance(tree, dict):
        children = sorted(tree.items())
        container_type = dict
    else:
        return [tree], None

    flat = []
    spec = []
    for key, child in children:
        # A leaf child comes back as ([child], None), which gives the same
        # (key, None) entry the dedicated leaf branch used to produce.
        child_flat, child_spec = tree_flatten(child, is_leaf)
        flat.extend(child_flat)
        spec.append((key, child_spec))
    return flat, (container_type, spec)
# Celebrity gallery directories.
# CELEBRITY_SOURCE_DIR is normalized to an absolute path unless the
# (stripped) environment value is empty, in which case it stays "".
_celebrity_source_raw = os.environ.get(
    "CELEBRITY_SOURCE_DIR", "/opt/data/chinese_celeb_dataset"
).strip()
CELEBRITY_SOURCE_DIR = (
    os.path.abspath(os.path.expanduser(_celebrity_source_raw))
    if _celebrity_source_raw
    else _celebrity_source_raw
)

# CELEBRITY_DATASET_DIR falls back to the source dir, then to the built-in
# default, and is always made absolute.
_celebrity_dataset_raw = os.environ.get(
    "CELEBRITY_DATASET_DIR",
    CELEBRITY_SOURCE_DIR or "/opt/data/chinese_celeb_dataset",
)
CELEBRITY_DATASET_DIR = os.path.abspath(os.path.expanduser(_celebrity_dataset_raw))

# Threshold used when searching for a matching celebrity.
CELEBRITY_FIND_THRESHOLD = float(os.environ.get("CELEBRITY_FIND_THRESHOLD", 0.88))
# ModelScope cache location: an explicit MODELSCOPE_CACHE environment value
# wins; otherwise a "modelscope" folder under MODELS_PATH is used.
_MODELSCOPE_CACHE_ENV = os.environ.get("MODELSCOPE_CACHE", "").strip()
if _MODELSCOPE_CACHE_ENV:
    MODELSCOPE_CACHE_DIR = os.path.abspath(os.path.expanduser(_MODELSCOPE_CACHE_ENV))
else:
    MODELSCOPE_CACHE_DIR = os.path.join(MODELS_PATH, "modelscope")

# Best effort: a failure to create the directory must not abort the import
# of the configuration module, so it is only reported.
try:
    os.makedirs(MODELSCOPE_CACHE_DIR, exist_ok=True)
except Exception as exc:
    # The previous message used logging-style "%s" placeholders with
    # print(), so the placeholders were printed literally.
    print(f"创建 ModelScope 缓存目录失败: {MODELSCOPE_CACHE_DIR} ({exc})")
def _env_number(name, default, cast):
    """Read environment variable *name* and convert it with *cast*."""
    return cast(os.environ.get(name, default))


def _env_text(name, default):
    """Read environment variable *name* as text with whitespace stripped."""
    return os.environ.get(name, default).strip()


# Numeric thresholds and output settings, each overridable by an
# environment variable of the same name.
IMG_QUALITY = _env_number("IMG_QUALITY", 0.5, float)
FACE_CONFIDENCE = _env_number("FACE_CONFIDENCE", 0.7, float)
AGE_CONFIDENCE = _env_number("AGE_CONFIDENCE", 0.99, float)
GENDER_CONFIDENCE = _env_number("GENDER_CONFIDENCE", 1.1, float)
UPSCALE_SIZE = _env_number("UPSCALE_SIZE", 2, int)
SAVE_QUALITY = _env_number("SAVE_QUALITY", 85, int)

# Model selection.
REALESRGAN_MODEL = os.environ.get("REALESRGAN_MODEL", "realesr-general-x4v3")
# yolov11n-face.pt / yolov8n-face.pt
YOLO_MODEL = os.environ.get("YOLO_MODEL", "yolov11n-face.pt")
# mobilenetv3 / resnet50
RVM_MODEL = os.environ.get("RVM_MODEL", "resnet50")
RVM_LOCAL_REPO = _env_text("RVM_LOCAL_REPO", "/opt/data/RobustVideoMatting")
RVM_WEIGHTS_PATH = _env_text(
    "RVM_WEIGHTS_PATH",
    "/opt/data/models/torch/hub/checkpoints/rvm_resnet50.pth",
)

# Treated as enabled when the value is "1", "true" or "on" (any case).
DRAW_SCORE = os.environ.get("DRAW_SCORE", "true").lower() in ("1", "true", "on")
@asynccontextmanager
async def get_connection():
    """Yield a connection from the shared MySQL pool, creating it on demand.

    The pool is captured in a local variable so the connection is always
    released to the pool it was acquired from. The module-level ``_pool``
    can be reset to ``None`` by ``close_mysql_pool`` while a connection is
    still checked out; dereferencing the global in ``finally`` would then
    fail with ``AttributeError``.

    Yields:
        The connection object returned by the pool's ``acquire()``.
    """
    pool = _pool
    if pool is None:
        # init_mysql_pool() returns the pool it created (or the existing
        # one), so no assert on the global is needed afterwards.
        pool = await init_mysql_pool()

    conn = await pool.acquire()
    try:
        yield conn
    finally:
        pool.release(conn)
cursor: + await cursor.execute(query, params or ()) + rows = await cursor.fetchall() + return list(rows) + + +def _serialize_extra(extra: Optional[Dict[str, Any]]) -> Optional[str]: + if extra is None: + return None + try: + return json.dumps(extra, ensure_ascii=False) + except Exception: + logger.warning("无法序列化 extra_metadata,已忽略") + return None + + +async def upsert_image_record( + *, + file_path: str, + category: str, + nickname: Optional[str], + score: float, + is_cropped_face: bool, + size_bytes: int, + last_modified: datetime, + bos_uploaded: bool, + hostname: Optional[str] = None, + extra_metadata: Optional[Dict[str, Any]] = None, +) -> None: + """写入或更新图片记录""" + query = """ + INSERT INTO tpl_app_processed_images ( + file_path, + category, + nickname, + score, + is_cropped_face, + size_bytes, + last_modified, + bos_uploaded, + hostname, + extra_metadata + ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ON DUPLICATE KEY UPDATE + category = VALUES(category), + nickname = VALUES(nickname), + score = VALUES(score), + is_cropped_face = VALUES(is_cropped_face), + size_bytes = VALUES(size_bytes), + last_modified = VALUES(last_modified), + bos_uploaded = VALUES(bos_uploaded), + hostname = VALUES(hostname), + extra_metadata = VALUES(extra_metadata), + updated_at = CURRENT_TIMESTAMP + """ + extra_value = _serialize_extra(extra_metadata) + await execute( + query, + ( + file_path, + category, + nickname, + score, + 1 if is_cropped_face else 0, + size_bytes, + last_modified, + 1 if bos_uploaded else 0, + hostname, + extra_value, + ), + ) + + +async def fetch_paged_image_records( + *, + category: Optional[str], + nickname: Optional[str], + offset: int, + limit: int, +) -> List[Dict[str, Any]]: + """按条件分页查询图片记录""" + where_clauses: List[str] = [] + params: List[Any] = [] + if category and category != "all": + where_clauses.append("category = %s") + params.append(category) + if nickname: + where_clauses.append("nickname = %s") + params.append(nickname) + where_sql = 
async def count_image_records(
    *, category: Optional[str], nickname: Optional[str]
) -> int:
    """Count image records matching the optional filters.

    Args:
        category: restrict to this category; ignored when empty or "all".
        nickname: restrict to this nickname; ignored when empty.

    Returns:
        int: the number of matching rows, 0 when the query returns no row
        or a null total.
    """
    # Each filter pairs its SQL fragment with the bound parameter.
    filters = []
    if category and category != "all":
        filters.append(("category = %s", category))
    if nickname:
        filters.append(("nickname = %s", nickname))

    where_sql = ""
    if filters:
        where_sql = "WHERE " + " AND ".join(clause for clause, _ in filters)
    params: List[Any] = [value for _, value in filters]

    query = f"SELECT COUNT(*) AS total FROM tpl_app_processed_images {where_sql}"
    rows = await fetch_all(query, params)
    return int(rows[0].get("total", 0) or 0) if rows else 0
# Ordered (marker, category) pairs; the first marker found in the lower-cased
# file name wins. "_rvm_id_photo" must come before "_id_photo" because the
# latter is a substring of the former.
_CATEGORY_MARKERS = (
    ("_face_", "face"),
    ("_original", "original"),
    ("_restore", "restore"),
    ("_upcolor", "upcolor"),
    ("_compress", "compress"),
    ("_upscale", "upscale"),
    ("_anime_style_", "anime_style"),
    ("_grayscale", "grayscale"),
    ("_rvm_id_photo", "rvm"),
    ("_id_photo", "id_photo"),
    ("_grid_", "grid"),
    ("_celebrity", "celebrity"),
)


def infer_category_from_filename(filename: str, default: str = "other") -> str:
    """Infer the image category from marker substrings in the file name.

    Matching is case-insensitive and order-sensitive (see
    ``_CATEGORY_MARKERS``). Previously the "rvm" category could never be
    returned because the "_id_photo" check ran first and also matched
    "_rvm_id_photo" names.

    Args:
        filename: file name (or relative path) to inspect.
        default: category returned when no marker matches.

    Returns:
        str: the category of the first matching marker, else *default*.
    """
    lower_name = filename.lower()
    for marker, category in _CATEGORY_MARKERS:
        if marker in lower_name:
            return category
    return default
def __init__(self):
    """Create the colorizer, loading the model only when the feature is on.

    ``self.colorizer`` starts as None; ``_initialize_model`` is called only
    when ``config.ENABLE_DDCOLOR`` is truthy. The elapsed construction time
    is logged together with whether a colorizer is available afterwards.
    """
    began = time.perf_counter()
    self.colorizer = None

    # Imported here, as in the original, to read the feature switch.
    from config import ENABLE_DDCOLOR
    if not ENABLE_DDCOLOR:
        logger.info("DDColor feature is disabled, skipping model initialization")
    else:
        self._initialize_model()

    elapsed = time.perf_counter() - began
    if self.colorizer is None:
        logger.info(f"DDColorColorizer initialization completed but not available, time: {elapsed:.3f}s")
    else:
        logger.info(f"DDColorColorizer initialized successfully, time: {elapsed:.3f}s")
def is_grayscale(self, image):
    """Decide whether an image is grayscale or pseudo-grayscale.

    A 2-D array, or a 3-D array with a single channel, is grayscale by
    definition. For colour data the decision combines the mean absolute
    difference between the channels and the mean HSV saturation.

    Args:
        image: numpy array. 3-channel data is treated as BGR (it is passed
            to ``cv2.COLOR_BGR2HSV``); 4-channel data is assumed to be BGRA
            and the 4th channel is ignored — TODO confirm against callers.

    Returns:
        bool: True when the image is considered grayscale. Arrays with an
        unsupported number of dimensions or channels return False instead
        of raising on the channel unpacking as before.
    """
    if len(image.shape) == 2:
        return True
    if len(image.shape) != 3:
        return False

    channels = image.shape[2]
    if channels == 1:
        return True
    if channels == 4:
        # Drop the 4th channel; a contiguous copy keeps OpenCV happy.
        image = np.ascontiguousarray(image[:, :, :3])
    elif channels != 3:
        return False

    b, g, r = (plane.astype(float) for plane in cv2.split(image))

    # Mean absolute difference between each pair of channels.
    avg_diff = (
        np.mean(np.abs(b - g)) + np.mean(np.abs(g - r)) + np.mean(np.abs(r - b))
    ) / 3.0

    # Mean saturation (S channel of HSV).
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    avg_saturation = np.mean(hsv[:, :, 1])

    # Either signal alone is enough to treat the image as grayscale.
    is_gray = (avg_diff < 5.0) or (avg_saturation < 20.0)

    logger.info(f"Grayscale detection - Average channel difference: {avg_diff:.2f}, Average saturation: {avg_saturation:.2f}, Result: {is_gray}")
    return is_gray
def restore_and_colorize(self, image, gfpgan_restorer=None):
    """Restore first, then colorize (legacy order, kept for compatibility).

    Args:
        image: input image (numpy array).
        gfpgan_restorer: optional restorer exposing ``is_available()`` and
            ``restore_image(image)``; skipped when missing or unavailable.

    Returns:
        The restored image, colorized as well when it is detected as
        grayscale. On any exception the error is logged and the original
        *image* is returned.
    """
    try:
        # Step 1: restoration, when a usable restorer was supplied.
        if gfpgan_restorer and gfpgan_restorer.is_available():
            logger.info("First performing image restoration...")
            restored_image = gfpgan_restorer.restore_image(image)
        else:
            restored_image = image

        # Step 2: colorization, only for grayscale input.
        if self.is_grayscale(restored_image):
            logger.info("Grayscale image detected, performing colorization...")
            # Grayscale was just established, so use the direct path
            # instead of colorize_image(), which would repeat the detection
            # (and its log line). This matches colorize_and_restore().
            return self.colorize_image_direct(restored_image)

        logger.info("Image is already colored, only returning restoration result")
        return restored_image

    except Exception as e:
        logger.error(f"Restoration and colorization combination processing failed: {e}")
        return image
def test_colorization(self, test_url=None):
    """Smoke-test colorization on a remote sample image.

    The colorized result is written to ``ddcolor_test_result.webp`` in the
    current working directory.

    Args:
        test_url: image URL to colorize; defaults to the ModelScope sample.

    Returns:
        tuple: ``(success, message)``. ``success`` is False when the model
        is unavailable, the pipeline raises, or the result cannot be
        written to disk.
    """
    if not self.is_available():
        return False, "DDColor模型未初始化"

    try:
        test_url = test_url or 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/audrey_hepburn.jpg'
        logger.info(f"Testing DDColor colorization feature, using image: {test_url}")

        result = self.colorizer(test_url)
        colorized_img = result[self.OutputKeys.OUTPUT_IMG]

        # cv2.imwrite reports failure through its boolean return value
        # rather than an exception, so it has to be checked explicitly.
        test_output_path = 'ddcolor_test_result.webp'
        written = cv2.imwrite(test_output_path, colorized_img, [cv2.IMWRITE_WEBP_QUALITY, 95])
        if not written:
            logger.error(f"DDColor test failed: could not write {test_output_path}")
            return False, f"测试失败: 无法写入 {test_output_path}"

        logger.info(f"DDColor test successful, result saved to: {test_output_path}")
        return True, f"测试成功,结果保存到: {test_output_path}"

    except Exception as e:
        logger.error(f"DDColor test failed: {e}")
        return False, f"测试失败: {e}"
{e}") + return False, f"本地图像上色失败: {e}" diff --git a/debug_colorize.py b/debug_colorize.py new file mode 100644 index 0000000000000000000000000000000000000000..270e46c00f3395dc8b41455f2970838ceb112390 --- /dev/null +++ b/debug_colorize.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +""" +调试上色效果差异的脚本 +""" + +import sys +import os +import cv2 +import numpy as np + +# 添加当前目录到路径 +sys.path.insert(0, os.path.dirname(__file__)) + +from ddcolor_colorizer import DDColorColorizer +from gfpgan_restorer import GFPGANRestorer +import logging + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s') + +def simulate_api_processing(image_path): + """ + 模拟API接口的完整处理流程 + """ + print("\n=== 模拟API接口处理流程 ===") + + # 初始化组件 + print("初始化GFPGAN修复器...") + try: + gfpgan_restorer = GFPGANRestorer() + if not gfpgan_restorer.is_available(): + print("❌ GFPGAN不可用") + return None + print("✅ GFPGAN初始化成功") + except Exception as e: + print(f"❌ GFPGAN初始化失败: {e}") + return None + + print("初始化DDColor上色器...") + try: + ddcolor_colorizer = DDColorColorizer() + if not ddcolor_colorizer.is_available(): + print("❌ DDColor不可用") + return None + print("✅ DDColor初始化成功") + except Exception as e: + print(f"❌ DDColor初始化失败: {e}") + return None + + # 读取图像 + print(f"读取图像: {image_path}") + image = cv2.imread(image_path) + if image is None: + print(f"❌ 无法读取图像: {image_path}") + return None + + print(f"原图尺寸: {image.shape}") + + # 保存原图 + ddcolor_colorizer.save_debug_image(image, "api_original") + + # 检查原图灰度状态 + original_is_grayscale = ddcolor_colorizer.is_grayscale(image) + print(f"原图灰度检测: {original_is_grayscale}") + + # 新的处理流程:先上色再修复 + # 步骤1: 上色处理 + print("\n步骤1: 上色处理...") + if original_is_grayscale: + print("策略: 对原图进行上色") + colorized_image = ddcolor_colorizer.colorize_image_direct(image) + ddcolor_colorizer.save_debug_image(colorized_image, "api_colorized") + strategy = "先上色" + current_image = colorized_image + else: + print("策略: 图像已经是彩色的,跳过上色") + strategy = "跳过上色" + current_image = image + + # 
def test_direct_colorization(image_path):
    """Exercise DDColor on its own, without the GFPGAN restore step.

    Runs the colorizer's built-in sample test (``test_colorization``) and
    then, if a local image path was supplied, colorizes that file through
    ``test_local_image``. Each outcome is printed; nothing is returned.

    Args:
        image_path: Path of a local image to colorize, or ``None`` to run
            only the sample test. ``main`` passes ``None`` when the
            requested test image does not exist.
    """
    print("\n=== 测试直接上色 ===")

    colorizer = DDColorColorizer()
    if not colorizer.is_available():
        print("❌ DDColor不可用")
        return None

    # Sample test first: it does not need any local file.
    print("使用官方示例URL上色...")
    success, message = colorizer.test_colorization()
    if success:
        print(f"✅ URL上色成功: {message}")
    else:
        print(f"❌ URL上色失败: {message}")

    if image_path is None:
        # The caller has already announced a sample-only run; handing None
        # to test_local_image would only produce a spurious failure line.
        return None

    # Colorize the caller's local image directly.
    print(f"对本地图像直接上色: {image_path}")
    success, message = colorizer.test_local_image(image_path)
    if success:
        print(f"✅ 本地图像上色成功: {message}")
    else:
        print(f"❌ 本地图像上色失败: {message}")
    return None
def analyze_image_quality(image_path):
    """Print basic quality metrics for the image stored at ``image_path``.

    Reports, in order: size and channel count, mean brightness and contrast
    (mean and standard deviation of the grayscale image), mean saturation
    (S channel of the HSV conversion) and sharpness (variance of the
    Laplacian). Prints a message and returns early when the file is missing
    or cannot be decoded.
    """
    print(f"\n=== 分析图像质量: {image_path} ===")

    if not os.path.exists(image_path):
        print(f"文件不存在: {image_path}")
        return

    bgr = cv2.imread(image_path)
    if bgr is None:
        print(f"无法读取图像: {image_path}")
        return

    # Basic geometry.
    height, width, channels = bgr.shape
    print(f"尺寸: {width}x{height}, 通道数: {channels}")

    # Brightness and contrast are both derived from the grayscale image.
    grayscale = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    print(f"平均亮度: {np.mean(grayscale):.2f}")
    print(f"对比度(标准差): {np.std(grayscale):.2f}")

    # Saturation is channel 1 of the HSV representation.
    saturation = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)[:, :, 1]
    print(f"平均饱和度: {np.mean(saturation):.2f}")

    # Sharpness: variance of the Laplacian response.
    edges = cv2.Laplacian(grayscale, cv2.CV_64F)
    print(f"锐度: {np.var(edges):.2f}")
0000000000000000000000000000000000000000..3ec722d220becc6dc10cd12dea160bd6cf8d402b --- /dev/null +++ b/face_analyzer.py @@ -0,0 +1,1099 @@ +import os +import random +import time +from typing import List, Dict, Any + +import cv2 +import numpy as np + +import config +from config import logger, MODELS_PATH, OUTPUT_DIR, DEEPFACE_AVAILABLE, \ + YOLO_AVAILABLE +from facial_analyzer import FacialFeatureAnalyzer +from models import ModelType +from utils import save_image_high_quality + +if DEEPFACE_AVAILABLE: + from deepface import DeepFace + +# 可选导入 YOLO +if YOLO_AVAILABLE: + try: + from ultralytics import YOLO + + YOLO_AVAILABLE = True + except ImportError: + YOLO_AVAILABLE = False + YOLO = None + print("Warning: ENABLE_YOLO=true but ultralytics not available") + + +class EnhancedFaceAnalyzer: + """增强版人脸分析器 - 支持混合模型""" + + def __init__(self, models_dir: str = MODELS_PATH): + """ + 初始化人脸分析器 + :param models_dir: 模型文件目录 + """ + start_time = time.perf_counter() + self.models_dir = models_dir + self.MODEL_MEAN_VALUES = (104, 117, 123) + self.age_list = [ + "(0-2)", + "(4-6)", + "(8-12)", + "(15-20)", + "(25-32)", + "(38-43)", + "(48-53)", + "(60-100)", + ] + self.gender_list = ["Male", "Female"] + # 性别对应的颜色 (BGR格式) + self.gender_colors = { + "Male": (255, 165, 0), # 橙色 Orange + "Female": (255, 0, 255), # 洋红 Magenta / Fuchsia + } + + # 初始化五官分析器 + self.facial_analyzer = FacialFeatureAnalyzer() + # 加载HowCuteAmI模型 + self._load_howcuteami_models() + # 加载YOLOv人脸检测模型 + self._load_yolo_model() + + # 预热模型(可选,通过配置开关) + if getattr(config, "ENABLE_WARMUP", False): + self._warmup_models() + + init_time = time.perf_counter() - start_time + logger.info(f"EnhancedFaceAnalyzer initialized successfully, time: {init_time:.3f}s") + + def _cap_conf(self, value: float) -> float: + """将置信度限制在 [0, 0.9999] 并保留4位小数。""" + try: + v = float(value if value is not None else 0.0) + except Exception: + v = 0.0 + if v >= 1.0: + v = 0.9999 + if v < 0.0: + v = 0.0 + return round(v, 4) + + def 
def _load_yolo_model(self):
    """Load the YOLO face detector into ``self.yolo_model``.

    ``self.yolo_model`` stays ``None`` when YOLO is disabled, when
    ultralytics could not be imported, or when loading fails; failures are
    logged and never raised.

    The weights file ``config.YOLO_MODEL`` is looked up in
    ``self.models_dir`` first. If it is not there, the model is requested
    by the name ``yolov11n-face.pt`` instead (presumably fetched by
    ultralytics on first use — TODO confirm).
    """
    self.yolo_model = None
    if not config.YOLO_AVAILABLE:
        logger.warning("ultralytics not installed, cannot use YOLOv")
        return

    try:
        if YOLO is None:
            # The module-level import guard rebinds YOLO to None when the
            # ultralytics import fails; report that instead of failing with
            # "'NoneType' object is not callable".
            logger.warning("ultralytics not installed, cannot use YOLOv")
            return

        # Prefer the local weights file.
        yolo_face_path = os.path.join(self.models_dir, config.YOLO_MODEL)
        if os.path.exists(yolo_face_path):
            self.yolo_model = YOLO(yolo_face_path)
            logger.info(f"Local YOLO face model loaded successfully: {yolo_face_path}")
            return

        # No local weights: ask for the face model by name.
        logger.info("Local YOLO face model does not exist, attempting to download...")
        model_name = "yolov11n-face.pt"
        try:
            self.yolo_model = YOLO(model_name)
            # Log the model that was actually requested.
            logger.info(f"YOLO face model loaded by name: {model_name}")
        except Exception as e:
            logger.warning(f"YOLOv model download failed: {e}")
    except Exception as e:
        logger.error(f"YOLOv model loading failed: {e}")
def _detect_faces(
    self, frame: np.ndarray, conf_threshold: float = config.FACE_CONFIDENCE
) -> List[List[int]]:
    """Detect faces with YOLO, falling back to the OpenCV DNN detector.

    The OpenCV fallback is used when no YOLO model is loaded or when the
    YOLO call raises. When YOLO runs but keeps no detection, an empty list
    is returned.

    Args:
        frame: Image to search; ``frame.shape[:2]`` is read as
            (height, width).
        conf_threshold: Confidence threshold passed to the detector.

    Returns:
        One ``[x1, y1, x2, y2]`` box per face, squared by ``_scale_box``.
    """
    if self.yolo_model is None:
        # A missing model counts as a YOLO failure: returning [] here would
        # make every image look face-less when YOLO is unavailable.
        return self._detect_faces_opencv_fallback(frame, conf_threshold)

    face_boxes = []
    try:
        frame_height, frame_width = frame.shape[:2]
        results = self.yolo_model(frame, conf=conf_threshold, verbose=False)
        for result in results:
            boxes = result.boxes
            if boxes is None:
                continue
            for box in boxes:
                class_id = int(box.cls[0])
                # Box corners in xyxy order.
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
                confidence = float(box.conf[0])
                logger.info(
                    f"detect class_id={class_id}, confidence={confidence}"
                )

                # Clip the box to the frame.
                x1 = max(0, int(x1))
                y1 = max(0, int(y1))
                x2 = min(frame_width, int(x2))
                y2 = min(frame_height, int(y2))

                # Drop boxes that are not larger than 30 px on both sides.
                width, height = x2 - x1, y2 - y1
                if width <= 30 or height <= 30:
                    continue
                # Extra plausibility filter on shape and position.
                if self._is_likely_face_region(x1, y1, x2, y2, frame):
                    face_boxes.append(self._scale_box([x1, y1, x2, y2]))
        logger.info(
            f"YOLO detected {len(face_boxes)} faces, conf_threshold={conf_threshold}"
        )
    except Exception as e:
        logger.warning(f"YOLO detection failed, falling back to OpenCV DNN: {e}")
        return self._detect_faces_opencv_fallback(frame, conf_threshold)

    return face_boxes
self.face_net.forward() + + # 提取检测结果 + for i in range(detections.shape[2]): + confidence = detections[0, 0, i, 2] + if confidence > config["threshold"]: + x1 = int(detections[0, 0, i, 3] * frame_width) + y1 = int(detections[0, 0, i, 4] * frame_height) + x2 = int(detections[0, 0, i, 5] * frame_width) + y2 = int(detections[0, 0, i, 6] * frame_height) + + # 基本边界检查 + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(frame_width, x2), min(frame_height, y2) + + # 过滤太小或不合理的检测框 + width, height = x2 - x1, y2 - y1 + if ( + width > 20 + and height > 20 + and width < frame_width * 0.8 + and height < frame_height * 0.8 + ): + # 长宽比检查 - 人脸通常接近正方形 + aspect_ratio = width / height + if 0.6 <= aspect_ratio <= 1.8: # 允许一定的椭圆形变 + all_boxes.append( + { + "box": [x1, y1, x2, y2], + "confidence": confidence, + "area": width * height, + } + ) + except Exception as e: + logger.warning(f"Scale {config['size']} detection failed: {e}") + continue + + # 如果没有检测到任何人脸,尝试更宽松的条件 + if not all_boxes: + logger.info("No faces detected, trying more relaxed detection conditions...") + try: + # 最后一次尝试:最低阈值 + 图像增强 + enhanced_frame = cv2.equalizeHist( + cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + ) + enhanced_frame = cv2.cvtColor(enhanced_frame, cv2.COLOR_GRAY2BGR) + + blob = cv2.dnn.blobFromImage( + enhanced_frame, 1.0, (300, 300), [104, 117, 123], True, False + ) + self.face_net.setInput(blob) + detections = self.face_net.forward() + + for i in range(detections.shape[2]): + confidence = detections[0, 0, i, 2] + if confidence > 0.15: # 非常低的阈值 + x1 = int(detections[0, 0, i, 3] * frame_width) + y1 = int(detections[0, 0, i, 4] * frame_height) + x2 = int(detections[0, 0, i, 5] * frame_width) + y2 = int(detections[0, 0, i, 6] * frame_height) + + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(frame_width, x2), min(frame_height, y2) + + width, height = x2 - x1, y2 - y1 + if width > 15 and height > 15: # 更小的最小尺寸 + aspect_ratio = width / height + if 0.5 <= aspect_ratio <= 2.0: # 更宽松的长宽比 + all_boxes.append( + { + "box": 
[x1, y1, x2, y2], + "confidence": confidence, + "area": width * height, + } + ) + except Exception as e: + logger.warning(f"Relaxed condition detection also failed: {e}") + + # NMS (非极大值抑制) 去除重复检测 + if all_boxes: + final_boxes = self._apply_nms(all_boxes, overlap_threshold=0.4) + return [self._scale_box(box["box"]) for box in final_boxes] + + return [] + + def _apply_nms( + self, detections: List[Dict], overlap_threshold: float = 0.4 + ) -> List[Dict]: + """ + 非极大值抑制,去除重复的检测框 + """ + if not detections: + return [] + + # 按置信度排序 + detections.sort(key=lambda x: x["confidence"], reverse=True) + + keep = [] + while detections: + # 保留置信度最高的 + best = detections.pop(0) + keep.append(best) + + # 移除与最佳检测重叠度高的其他检测 + remaining = [] + for det in detections: + if self._calculate_iou(best["box"], det["box"]) < overlap_threshold: + remaining.append(det) + detections = remaining + + return keep + + def _calculate_iou(self, box1: List[int], box2: List[int]) -> float: + """ + 计算两个边界框的IoU (交并比) + """ + x1_1, y1_1, x2_1, y2_1 = box1 + x1_2, y1_2, x2_2, y2_2 = box2 + + # 计算交集 + x1_i = max(x1_1, x1_2) + y1_i = max(y1_1, y1_2) + x2_i = min(x2_1, x2_2) + y2_i = min(y2_1, y2_2) + + if x2_i <= x1_i or y2_i <= y1_i: + return 0.0 + + intersection = (x2_i - x1_i) * (y2_i - y1_i) + + # 计算并集 + area1 = (x2_1 - x1_1) * (y2_1 - y1_1) + area2 = (x2_2 - x1_2) * (y2_2 - y1_2) + union = area1 + area2 - intersection + + return intersection / union if union > 0 else 0.0 + + def _scale_box(self, box: List[int]) -> List[int]: + """将矩形框缩放为正方形""" + width = box[2] - box[0] + height = box[3] - box[1] + maximum = max(width, height) + dx = int((maximum - width) / 2) + dy = int((maximum - height) / 2) + + return [box[0] - dx, box[1] - dy, box[2] + dx, box[3] + dy] + + def _crop_face(self, image: np.ndarray, box: List[int]) -> np.ndarray: + """裁剪人脸区域""" + x1, y1, x2, y2 = box + h, w = image.shape[:2] + x1 = max(0, x1) + y1 = max(0, y1) + x2 = min(w, x2) + y2 = min(h, y2) + return image[y1:y2, x1:x2] + + def 
_predict_beauty_gender_with_howcuteami( + self, face: np.ndarray + ) -> Dict[str, Any]: + """使用HowCuteAmI模型预测颜值和性别""" + try: + blob = cv2.dnn.blobFromImage( + face, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False + ) + + # 性别预测 + self.gender_net.setInput(blob) + gender_preds = self.gender_net.forward() + gender = self.gender_list[gender_preds[0].argmax()] + gender_confidence = float(np.max(gender_preds[0])) + gender_confidence = self._cap_conf(gender_confidence) + # 年龄预测 + self.age_net.setInput(blob) + age_preds = self.age_net.forward() + age = self.age_list[age_preds[0].argmax()] + age_confidence = float(np.max(age_preds[0])) + # 颜值预测 + blob_beauty = cv2.dnn.blobFromImage( + face, 1.0 / 255, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False + ) + self.beauty_net.setInput(blob_beauty) + beauty_preds = self.beauty_net.forward() + beauty_score = round(float(2.0 * np.sum(beauty_preds[0])), 1) + beauty_score = min(10.0, max(0.0, beauty_score)) + beauty_score = self._adjust_beauty_score(beauty_score) + raw_score = float(np.sum(beauty_preds[0])) + + return { + "age": age, + "age_confidence": round(age_confidence, 4), + "gender": gender, + "gender_confidence": gender_confidence, + "beauty_score": beauty_score, + "beauty_raw_score": round(raw_score, 4), + "age_model_used": "HowCuteAmI", + "gender_model_used": "HowCuteAmI", + "beauty_model_used": "HowCuteAmI", + } + except Exception as e: + logger.error(f"HowCuteAmI beauty gender prediction failed: {e}") + raise e + + def _predict_age_emotion_with_deepface( + self, face_image: np.ndarray + ) -> Dict[str, Any]: + """使用DeepFace预测年龄、情绪(并返回可用的性别信息用于回退)""" + if not DEEPFACE_AVAILABLE: + # 如果DeepFace不可用,使用HowCuteAmI的年龄预测作为回退 + return self._predict_age_with_howcuteami_fallback(face_image) + + if face_image is None or face_image.size == 0: + raise ValueError("无效的人脸图像") + + try: + # DeepFace分析 - 禁用进度条和详细输出 + result = DeepFace.analyze( + img_path=face_image, + actions=["age", "emotion", "gender"], + enforce_detection=False, + 
def _predict_with_hybrid_model(
    self, face: np.ndarray, face_image: np.ndarray
) -> Dict[str, Any]:
    """Predict with HowCuteAmI, consulting DeepFace when age confidence is low.

    HowCuteAmI always supplies gender and beauty score. If its age
    confidence is below ``config.AGE_CONFIDENCE``, age and emotion come
    from ``_predict_age_emotion_with_deepface`` and the two gender votes
    are merged. Otherwise the HowCuteAmI age is kept, emotion fields are
    ``None`` and no gender merge takes place.

    Args:
        face: Face image passed to the HowCuteAmI networks.
        face_image: Face image passed to the DeepFace-based predictor.

    Returns:
        Dict with gender, age, beauty and emotion fields plus
        ``model_used``, ``age_model_used`` and ``gender_model_used``.

    Raises:
        Exception: whatever the HowCuteAmI prediction re-raises.
    """
    beauty_gender_result = self._predict_beauty_gender_with_howcuteami(face)

    howcuteami_age_confidence = beauty_gender_result.get("age_confidence", 0)
    age = beauty_gender_result["age"]

    # Fields shared by both branches; gender may still be overridden below.
    result = {
        "gender": beauty_gender_result["gender"],
        "gender_confidence": self._cap_conf(
            beauty_gender_result.get("gender_confidence", 0)
        ),
        "beauty_score": beauty_gender_result["beauty_score"],
        "beauty_raw_score": beauty_gender_result["beauty_raw_score"],
    }

    age_emotion_result = None
    agec = config.AGE_CONFIDENCE
    if howcuteami_age_confidence < agec:
        # Low age confidence: take age and emotion from the DeepFace path.
        age_emotion_result = self._predict_age_emotion_with_deepface(face_image)
        deep_age = age_emotion_result["age"]
        logger.info(
            f"HowCuteAmI age confidence ({howcuteami_age_confidence}) below {agec}, value=({age}); using DeepFace for age prediction, value={deep_age}"
        )
        result.update(
            {
                "age": deep_age,
                "age_confidence": age_emotion_result["age_confidence"],
                "emotion": age_emotion_result["emotion"],
                "emotion_analysis": age_emotion_result["emotion_analysis"],
                "model_used": "hybrid_deepface_age",
                "age_model_used": "DeepFace",
                "gender_model_used": "HowCuteAmI",
            }
        )
    else:
        logger.info(
            f"HowCuteAmI age confidence ({howcuteami_age_confidence}) is high enough, value={age}; using HowCuteAmI for age prediction"
        )
        # NOTE(review): the age keeps HowCuteAmI's bracketed bucket here,
        # while _predict_with_howcuteami strips the brackets — confirm
        # which form consumers expect.
        result.update(
            {
                "age": beauty_gender_result["age"],
                "age_confidence": beauty_gender_result["age_confidence"],
                "emotion": None,
                "emotion_analysis": None,
                "model_used": "hybrid",
                "age_model_used": "HowCuteAmI",
                "gender_model_used": "HowCuteAmI",
            }
        )

    if age_emotion_result is None:
        # DeepFace was not consulted, so there is no second vote to merge.
        # NOTE(review): if the "Female wins" rule should apply on this path
        # too, DeepFace has to be called here as well — confirm intent.
        return result

    # Gender rule: Female if either model says Female; Male only if both
    # say Male; otherwise keep the HowCuteAmI verdict.
    try:
        how_gender = beauty_gender_result.get("gender")
        how_conf = float(beauty_gender_result.get("gender_confidence", 0) or 0)
        deep_gender = age_emotion_result.get("gender")
        deep_conf = float(age_emotion_result.get("gender_confidence", 0) or 0)

        final_gender = result.get("gender")
        final_conf = float(result.get("gender_confidence", 0) or 0)

        if (str(how_gender) == "Female") or (str(deep_gender) == "Female"):
            final_gender = "Female"
            final_conf = max(
                how_conf if how_gender == "Female" else 0,
                deep_conf if deep_gender == "Female" else 0,
            )
            result["gender_model_used"] = "Combined(H+DF)"
        elif (str(how_gender) == "Male") and (str(deep_gender) == "Male"):
            final_gender = "Male"
            final_conf = max(
                how_conf if how_gender == "Male" else 0,
                deep_conf if deep_gender == "Male" else 0,
            )
            result["gender_model_used"] = "Combined(H+DF)"

        result["gender"] = final_gender
        result["gender_confidence"] = self._cap_conf(final_conf)
    except Exception as e:
        # Best effort: a failed merge keeps the HowCuteAmI gender, but the
        # failure is now visible in the log.
        logger.warning(f"Gender merge failed, keeping HowCuteAmI gender: {e}")

    return result
blob_beauty = cv2.dnn.blobFromImage( + face, 1.0 / 255, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False + ) + self.beauty_net.setInput(blob_beauty) + beauty_preds = self.beauty_net.forward() + beauty_score = round(float(2.0 * np.sum(beauty_preds[0])), 1) + beauty_score = min(10.0, max(0.0, beauty_score)) + beauty_score = self._adjust_beauty_score(beauty_score) + raw_score = float(np.sum(beauty_preds[0])) + + return { + "gender": gender, + "gender_confidence": gender_confidence, + "age": age[1:-1], # 去掉括号 + "age_confidence": round(age_confidence, 4), + "beauty_score": beauty_score, + "beauty_raw_score": round(raw_score, 4), + "model_used": "HowCuteAmI", + "emotion": "neutral", # HowCuteAmI不支持情绪分析 + "emotion_analysis": {"neutral": 100.0}, + "age_model_used": "HowCuteAmI", + "gender_model_used": "HowCuteAmI", + "beauty_model_used": "HowCuteAmI", + } + except Exception as e: + logger.error(f"HowCuteAmI prediction failed: {e}") + raise e + + def _predict_with_deepface(self, face_image: np.ndarray) -> Dict[str, Any]: + """使用DeepFace进行预测""" + if not DEEPFACE_AVAILABLE: + raise ValueError("DeepFace未安装") + + if face_image is None or face_image.size == 0: + raise ValueError("无效的人脸图像") + + try: + # DeepFace分析 - 禁用进度条和详细输出 + result = DeepFace.analyze( + img_path=face_image, + actions=["age", "gender", "emotion"], + enforce_detection=False, + detector_backend="skip", + silent=True # 禁用进度条输出 + ) + + # 处理结果 (DeepFace返回的结果格式可能是list或dict) + if isinstance(result, list): + result = result[0] + + # 提取信息 + age = result.get("age", 25) + gender = result.get("dominant_gender", "Woman") + gender_confidence = result.get("gender", {}).get(gender, 0.5) / 100 + gender_confidence = self._cap_conf(gender_confidence) + + # 统一性别标签 + if gender.lower() in ["woman", "female"]: + gender = "Female" + else: + gender = "Male" + + # DeepFace没有内置颜值评分,这里使用简单的启发式方法 + emotion = result.get("dominant_emotion", "neutral") + emotion_scores = result.get("emotion", {}) + + # 基于情绪和年龄的简单颜值估算 + happiness_score = 
emotion_scores.get("happy", 0) / 100 + neutral_score = emotion_scores.get("neutral", 0) / 100 + + # 简单的颜值算法 (可以改进) + base_beauty = 6.0 # 基础分 + emotion_bonus = happiness_score * 2 + neutral_score * 1 + age_factor = max(0.5, 1 - abs(age - 25) / 50) # 25岁为最佳年龄 + + beauty_score = round(min(10.0, base_beauty + emotion_bonus + age_factor), 2) + + age_conf = round(random.uniform(0.7613, 0.9599), 4) + return { + "gender": gender, + "gender_confidence": gender_confidence, + "age": str(int(age)), + "age_confidence": age_conf, # DeepFace年龄置信度(随机范围) + "beauty_score": beauty_score, + "beauty_raw_score": round(beauty_score / 10, 4), + "model_used": "DeepFace", + "emotion": emotion, + "emotion_analysis": emotion_scores, + "age_model_used": "DeepFace", + "gender_model_used": "DeepFace", + "beauty_model_used": "Heuristic", + } + except Exception as e: + logger.error(f"DeepFace prediction failed: {e}") + raise e + + def analyze_faces( + self, + image: np.ndarray, + original_image_hash: str, + model_type: ModelType = ModelType.HYBRID, + ) -> Dict[str, Any]: + """ + 分析图片中的人脸 + :param image: 输入图像 + :param original_image_hash: 原始图片的MD5哈希值 + :param model_type: 使用的模型类型 + :return: 分析结果 + """ + if image is None: + raise ValueError("无效的图像输入") + + # 检测人脸 + face_boxes = self._detect_faces(image) + + if not face_boxes: + return { + "success": False, + "message": "请尝试上传清晰、无遮挡的正面照片", + "face_count": 0, + "faces": [], + "annotated_image": None, + "model_used": model_type.value, + } + + results = { + "success": True, + "message": f"成功检测到 {len(face_boxes)} 张人脸", + "face_count": len(face_boxes), + "faces": [], + "model_used": model_type.value, + } + + # 复制原图用于绘制 + annotated_image = image.copy() + logger.info( + f"Input annotated_image shape: {annotated_image.shape}, dtype: {annotated_image.dtype}, ndim: {annotated_image.ndim}" + ) + # 分析每张人脸 + for i, face_box in enumerate(face_boxes): + # 裁剪人脸 + face_cropped = self._crop_face(image, face_box) + if face_cropped.size == 0: + logger.warning(f"Cropped 
face {i + 1} is empty, skipping.") + continue + + face_resized = cv2.resize(face_cropped, (224, 224)) + face_for_deepface = face_cropped.copy() + + # 根据模型类型进行预测 + try: + if model_type == ModelType.HYBRID: + # 混合模式:颜值性别用HowCuteAmI,年龄情绪用DeepFace + prediction_result = self._predict_with_hybrid_model( + face_resized, face_for_deepface + ) + elif model_type == ModelType.HOWCUTEAMI: + prediction_result = self._predict_with_howcuteami(face_resized) + # 非混合模式也进行性别合并:引入DeepFace性别 + try: + age_emotion_result = self._predict_age_emotion_with_deepface( + face_for_deepface + ) + how_gender = prediction_result.get("gender") + how_conf = float(prediction_result.get("gender_confidence", 0) or 0) + deep_gender = age_emotion_result.get("gender") + deep_conf = float(age_emotion_result.get("gender_confidence", 0) or 0) + final_gender = prediction_result.get("gender") + final_conf = float(prediction_result.get("gender_confidence", 0) or 0) + if (str(how_gender) == "Female") or (str(deep_gender) == "Female"): + final_gender = "Female" + final_conf = max(how_conf if how_gender == "Female" else 0, + deep_conf if deep_gender == "Female" else 0) + prediction_result["gender_model_used"] = "Combined(H+DF)" + elif (str(how_gender) == "Male") and (str(deep_gender) == "Male"): + final_gender = "Male" + final_conf = max(how_conf if how_gender == "Male" else 0, + deep_conf if deep_gender == "Male" else 0) + prediction_result["gender_model_used"] = "Combined(H+DF)" + prediction_result["gender"] = final_gender + prediction_result["gender_confidence"] = round(float(final_conf), 4) + except Exception: + pass + elif model_type == ModelType.DEEPFACE and DEEPFACE_AVAILABLE: + prediction_result = self._predict_with_deepface(face_for_deepface) + # 非混合模式也进行性别合并:引入HowCuteAmI性别 + try: + beauty_gender_result = self._predict_beauty_gender_with_howcuteami( + face_resized + ) + deep_gender = prediction_result.get("gender") + deep_conf = float(prediction_result.get("gender_confidence", 0) or 0) + how_gender = 
beauty_gender_result.get("gender") + how_conf = float(beauty_gender_result.get("gender_confidence", 0) or 0) + final_gender = prediction_result.get("gender") + final_conf = float(prediction_result.get("gender_confidence", 0) or 0) + if (str(how_gender) == "Female") or (str(deep_gender) == "Female"): + final_gender = "Female" + final_conf = max(how_conf if how_gender == "Female" else 0, + deep_conf if deep_gender == "Female" else 0) + prediction_result["gender_model_used"] = "Combined(H+DF)" + elif (str(how_gender) == "Male") and (str(deep_gender) == "Male"): + final_gender = "Male" + final_conf = max(how_conf if how_gender == "Male" else 0, + deep_conf if deep_gender == "Male" else 0) + prediction_result["gender_model_used"] = "Combined(H+DF)" + prediction_result["gender"] = final_gender + prediction_result["gender_confidence"] = round(float(final_conf), 4) + except Exception: + pass + else: + # 回退到混合模式 + prediction_result = self._predict_with_hybrid_model( + face_resized, face_for_deepface + ) + logger.warning(f"Model {model_type.value} is not available, using hybrid mode") + + except Exception as e: + logger.error(f"Prediction failed, using default values: {e}") + prediction_result = { + "gender": "Unknown", + "gender_confidence": 0.5, + "age": "25-32", + "age_confidence": 0.5, + "beauty_score": 5.0, + "beauty_raw_score": 0.5, + "emotion": "neutral", + "emotion_analysis": {"neutral": 100.0}, + "model_used": "fallback", + } + + # 五官分析 + # facial_features = self.facial_analyzer.analyze_facial_features( + # face_cropped, face_box + # ) + + # 颜色设置与年龄显示统一(应用女性年龄调整) + gender = prediction_result.get("gender", "Unknown") + color_bgr = self.gender_colors.get(gender, (128, 128, 128)) + color_hex = f"#{color_bgr[2]:02x}{color_bgr[1]:02x}{color_bgr[0]:02x}" + + # 年龄文本与调整 + raw_age_str = prediction_result.get("age", "Unknown") + display_age_str = str(raw_age_str) + age_adjusted_flag = False + age_adjustment_value = int(getattr(config, "FEMALE_AGE_ADJUSTMENT", 0) or 0) + 
age_adjustment_threshold = int(getattr(config, "FEMALE_AGE_ADJUSTMENT_THRESHOLD", 999) or 999) + + # 仅对女性且年龄达到阈值时进行调整 + try: + # 支持 "25-32" 或 "25" 格式 + if "-" in str(raw_age_str): + age_num = int(str(raw_age_str).split("-")[0].strip("() ")) + else: + age_num = int(str(raw_age_str).strip()) + + if str(gender) == "Female" and age_num >= age_adjustment_threshold and age_adjustment_value > 0: + adjusted_age = max(0, age_num - age_adjustment_value) + display_age_str = str(adjusted_age) + age_adjusted_flag = True + try: + logger.info(f"Adjusted age for female (draw+data): {age_num} -> {adjusted_age}") + except Exception: + pass + except Exception: + # 无法解析年龄时,保持原样 + pass + + # 保存裁剪的人脸 + cropped_face_filename = f"{original_image_hash}_face_{i + 1}.webp" + cropped_face_path = os.path.join(OUTPUT_DIR, cropped_face_filename) + try: + save_image_high_quality(face_cropped, cropped_face_path) + logger.info(f"cropped face: {cropped_face_path}") + except Exception as e: + logger.error(f"Failed to save cropped face {cropped_face_path}: {e}") + cropped_face_filename = None + + # 在图片上绘制标注 + if config.DRAW_SCORE: + cv2.rectangle( + annotated_image, + (face_box[0], face_box[1]), + (face_box[2], face_box[3]), + color_bgr, + int(round(image.shape[0] / 400)), + 8, + ) + + # 标签文本 + beauty_score = prediction_result.get("beauty_score", 0) + label = f"{gender}, {display_age_str}, {beauty_score}" + + font_scale = max( + 0.3, min(0.7, image.shape[0] / 800) + ) # 从500改为800,范围从0.5-1.0改为0.3-0.7 + font_thickness = 2 + font = cv2.FONT_HERSHEY_SIMPLEX + # 绘制文本 + text_x = face_box[0] + text_y = face_box[1] - 10 if face_box[1] - 10 > 20 else face_box[1] + 30 + + # 计算文字大小(宽高) + (text_width, text_height), baseline = cv2.getTextSize(label, font, font_scale, font_thickness) + + # 画黑色矩形背景,稍微比文字框大一点,增加边距 + background_tl = (text_x, text_y - text_height - baseline) # 矩形左上角 + background_br = (text_x + text_width, text_y + baseline) # 矩形右下角 + + if config.DRAW_SCORE: + cv2.rectangle( + annotated_image, + 
def _warmup_models(self):
    """Warm up the loaded models to reduce first-call latency.

    Runs one dummy inference through DeepFace (only when
    ``DEEPFACE_AVAILABLE`` is true) and through the OpenCV DNN networks
    ``face_net``, ``age_net``, ``gender_net`` and ``beauty_net``.  Every
    failure is logged as a warning and swallowed: warm-up is best-effort.
    """
    try:
        logger.info("Starting to warm up models...")

        # Small uniform grey test image (64x64, 3 channels).
        test_image = np.ones((64, 64, 3), dtype=np.uint8) * 128

        # Warm up DeepFace, if available.
        if DEEPFACE_AVAILABLE:
            try:
                import tempfile

                # Only the path is needed; leave the ``with`` immediately so
                # the handle is closed before other code writes/reads the file.
                with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as tmp_file:
                    tmp_path = tmp_file.name
                try:
                    cv2.imwrite(tmp_path, test_image, [cv2.IMWRITE_WEBP_QUALITY, 95])

                    # Minimal set of actions, just enough to load the models.
                    DeepFace.analyze(
                        img_path=tmp_path,
                        actions=["age", "emotion", "gender"],
                        detector_backend="yolov8",
                        enforce_detection=False,
                        silent=True
                    )
                finally:
                    # Remove the temp file even when writing/analysis raises;
                    # previously a failure left it behind (delete=False).
                    try:
                        os.unlink(tmp_path)
                    except OSError:
                        pass
                logger.info("DeepFace model warm-up completed")
            except Exception as e:
                logger.warning(f"DeepFace model warm-up failed: {e}")

        # Warm up the OpenCV DNN models.
        try:
            # Face detection network.
            blob = cv2.dnn.blobFromImage(test_image, 1.0, (300, 300), (104, 117, 123))
            self.face_net.setInput(blob)
            self.face_net.forward()

            # Age network; the same 224x224 blob is reused for the
            # gender and beauty networks below.
            test_face = cv2.resize(test_image, (224, 224))
            blob = cv2.dnn.blobFromImage(test_face, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False)
            self.age_net.setInput(blob)
            self.age_net.forward()

            # Gender network.
            self.gender_net.setInput(blob)
            self.gender_net.forward()

            # Beauty-score network.
            self.beauty_net.setInput(blob)
            self.beauty_net.forward()

            logger.info("OpenCV DNN model warm-up completed")
        except Exception as e:
            logger.warning(f"OpenCV DNN model warm-up failed: {e}")

        logger.info("Model warm-up completed")
    except Exception as e:
        logger.warning(f"Error occurred during model warm-up: {e}")
def analyze_facial_features(
    self, face_image: np.ndarray, face_box: List[int]
) -> Dict[str, Any]:
    """
    Analyze the facial features of a cropped face.

    :param face_image: face image (converted from BGR to RGB before detection)
    :param face_box: face bounding box [x1, y1, x2, y2]; not read by this method
    :return: landmark-based analysis result, or the basic fallback result when
        the detector is unavailable, finds no landmarks, or raises
    """
    detector_ready = DLIB_AVAILABLE and self.face_mesh is not None
    if detector_ready:
        try:
            # MediaPipe expects an RGB image.
            detection = self.face_mesh.process(
                cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
            )

            if detection.multi_face_landmarks:
                # Use the first detected face and map its landmarks onto the
                # dlib-like 68-point layout used by the scorers.
                first_face = detection.multi_face_landmarks[0]
                dlib_points = self._convert_mediapipe_to_dlib_format(
                    first_face, face_image.shape
                )
                return self._analyze_features_from_landmarks(
                    dlib_points, face_image.shape
                )

            logger.warning("No facial landmarks detected")
        except Exception as e:
            logger.error(f"Facial feature analysis failed: {e}")
            traceback.print_exc()  # print the full stack, including line numbers

    # Every non-success path ends with the basic estimation.
    return self._basic_facial_analysis(face_image)
def _analyze_features_from_landmarks(
    self, landmarks: List[tuple], image_shape: tuple
) -> Dict[str, Any]:
    """Score the facial regions from a 68-point landmark list.

    :param landmarks: (x, y) points; sliced by the index ranges
        0-16 jawline, 17-21 / 22-26 eyebrows, 27-35 nose,
        36-41 / 42-47 eyes, 48-67 mouth
    :param image_shape: image shape, forwarded unchanged to each scorer
    :return: dict with per-feature scores, the adjusted harmony score, the
        mean of the feature scores and the analysis method name; the basic
        fallback result when any step raises
    """
    try:
        # Simplified per-region scoring; each scorer receives its own slice.
        region_scores = {
            "eyes": self._score_eyes(
                landmarks[36:42], landmarks[42:48], image_shape
            ),
            "nose": self._score_nose(landmarks[27:36], image_shape),
            "mouth": self._score_mouth(landmarks[48:68], image_shape),
            "eyebrows": self._score_eyebrows(
                landmarks[17:22], landmarks[22:27], image_shape
            ),
            "jawline": self._score_jawline(landmarks[0:17], image_shape),
        }

        # Overall harmony, then the configurable gentle upward adjustment.
        harmony = self._adjust_harmony_score(
            self._calculate_harmony_new(landmarks, image_shape)
        )

        mean_region_score = sum(region_scores.values()) / len(region_scores)
        return {
            "facial_features": region_scores,
            "harmony_score": round(harmony, 2),
            "overall_facial_score": round(mean_region_score, 2),
            "analysis_method": "mediapipe_landmarks",
        }

    except Exception as e:
        logger.error(f"Landmark analysis failed: {e}")
        return self._basic_facial_analysis(None)
max(0, 1 - abs(aspect_ratio - ideal_aspect) / ideal_aspect) + + final_score = ( + width_symmetry * 0.3 + + height_symmetry * 0.3 + + size_score * 0.25 + + aspect_score * 0.15 + ) * 10 + return round(max(0, min(10, final_score)), 2) + except: + return 6.21 + + def _score_nose(self, nose: List[tuple], image_shape: tuple) -> float: + """鼻部评分""" + try: + # 鼻子关键点 + nose_tip = nose[3] # 鼻尖 + nose_bridge_top = nose[0] # 鼻梁顶部 + left_nostril = nose[1] + right_nostril = nose[5] + + # 计算鼻子的直线度 (鼻梁是否挺直) + straightness = 1 - min( + abs(nose_tip[0] - nose_bridge_top[0]) / (image_shape[1] * 0.1), 1.0 + ) + + # 鼻宽评分 - 使用鼻翼宽度 + nose_width = abs(right_nostril[0] - left_nostril[0]) + face_width = image_shape[1] + ideal_nose_ratio = 0.06 # 调整理想比例 + width_score = max( + 0, + 1 - abs(nose_width / face_width - ideal_nose_ratio) / ideal_nose_ratio, + ) + + # 鼻子长度评分 + nose_length = abs(nose_tip[1] - nose_bridge_top[1]) + face_height = image_shape[0] + ideal_length_ratio = 0.08 + length_score = max( + 0, + 1 + - abs(nose_length / face_height - ideal_length_ratio) + / ideal_length_ratio, + ) + + final_score = ( + straightness * 0.4 + width_score * 0.35 + length_score * 0.25 + ) * 10 + return round(max(0, min(10, final_score)), 2) + except: + return 6.21 + + def _score_mouth(self, mouth: List[tuple], image_shape: tuple) -> float: + """嘴部评分 - 大幅优化,更宽松的评分标准""" + try: + # 嘴角点 + left_corner = mouth[0] # 左嘴角 + right_corner = mouth[6] # 右嘴角 + + # 上唇和下唇中心点 + upper_lip_center = mouth[3] # 上唇中心 + lower_lip_center = mouth[9] # 下唇中心 + + # 基础分数,避免过低 + base_score = 6.0 + + # 1. 嘴宽评分 - 更宽松的标准 + mouth_width = abs(right_corner[0] - left_corner[0]) + face_width = image_shape[1] + mouth_ratio = mouth_width / face_width + + # 设置更宽的合理范围 (0.04-0.15) + if 0.04 <= mouth_ratio <= 0.15: + width_score = 1.0 # 在合理范围内就给满分 + elif mouth_ratio < 0.04: + width_score = max(0.3, mouth_ratio / 0.04) # 太小时渐减 + else: + width_score = max(0.3, 0.15 / mouth_ratio) # 太大时渐减 + + # 2. 
唇厚度评分 - 简化并放宽标准 + lip_thickness = abs(lower_lip_center[1] - upper_lip_center[1]) + # 只要厚度不是极端值就给高分 + if lip_thickness > 3: # 像素值,有一定厚度 + thickness_score = min(1.0, lip_thickness / 25) # 25像素为满分 + else: + thickness_score = 0.5 # 太薄给中等分数 + + # 3. 嘴部对称性评分 - 更宽松 + mouth_center_x = (left_corner[0] + right_corner[0]) / 2 + face_center_x = image_shape[1] / 2 + center_deviation = abs(mouth_center_x - face_center_x) / face_width + + if center_deviation < 0.02: # 偏差小于2% + symmetry_score = 1.0 + elif center_deviation < 0.05: # 偏差小于5% + symmetry_score = 0.8 + else: + symmetry_score = max(0.5, 1 - center_deviation * 10) # 最低0.5分 + + # 4. 嘴唇形状评分 - 简化 + # 检查嘴角是否在合理位置 + corner_height_diff = abs(left_corner[1] - right_corner[1]) + if corner_height_diff < face_width * 0.02: # 嘴角高度差异小 + shape_score = 1.0 + else: + shape_score = max(0.6, 1 - corner_height_diff / (face_width * 0.02)) + + # 5. 综合评分 - 调整权重,给基础分更大权重 + feature_score = ( + width_score * 0.3 + + thickness_score * 0.25 + + symmetry_score * 0.25 + + shape_score * 0.2 + ) + + # 最终分数 = 基础分 + 特征分奖励 + final_score = base_score + feature_score * 4 # 最高10分 + + return round(max(4.0, min(10, final_score)), 2) # 最低4分,最高10分 + except Exception as e: + return 6.21 + + def _score_eyebrows( + self, left_brow: List[tuple], right_brow: List[tuple], image_shape: tuple + ) -> float: + """眉毛评分 - 改进算法""" + try: + # 计算眉毛长度 + left_length = abs(left_brow[-1][0] - left_brow[0][0]) + right_length = abs(right_brow[-1][0] - right_brow[0][0]) + + # 长度对称性 + length_symmetry = 1 - min( + abs(left_length - right_length) / max(left_length, right_length), 0.5 + ) + + # 计算眉毛拱形 - 改进方法 + left_peak_y = min([p[1] for p in left_brow]) # 眉峰(y坐标最小) + left_ends_y = (left_brow[0][1] + left_brow[-1][1]) / 2 # 眉毛两端平均高度 + left_arch = max(0, left_ends_y - left_peak_y) # 拱形高度 + + right_peak_y = min([p[1] for p in right_brow]) + right_ends_y = (right_brow[0][1] + right_brow[-1][1]) / 2 + right_arch = max(0, right_ends_y - right_peak_y) + + # 拱形对称性 + arch_symmetry = 1 - min( + 
abs(left_arch - right_arch) / max(left_arch, right_arch, 1), 0.5 + ) + + # 眉形适中性评分 + avg_arch = (left_arch + right_arch) / 2 + face_height = image_shape[0] + ideal_arch_ratio = 0.015 # 理想拱形比例 + arch_ratio = avg_arch / face_height + arch_score = max( + 0, 1 - abs(arch_ratio - ideal_arch_ratio) / ideal_arch_ratio + ) + + # 眉毛浓密度(通过点的密集程度估算) + density_score = min(1.0, (len(left_brow) + len(right_brow)) / 10) + + final_score = ( + length_symmetry * 0.3 + + arch_symmetry * 0.3 + + arch_score * 0.25 + + density_score * 0.15 + ) * 10 + return round(max(0, min(10, final_score)), 2) + except: + return 6.21 + + def _score_jawline(self, jawline: List[tuple], image_shape: tuple) -> float: + """下颌线评分 - 改进算法""" + try: + jaw_points = [(p[0], p[1]) for p in jawline] + + # 关键点 + left_jaw = jaw_points[2] # 左下颌角 + jaw_tip = jaw_points[8] # 下巴尖 + right_jaw = jaw_points[14] # 右下颌角 + + # 对称性评分 - 改进计算 + left_dist = ( + (left_jaw[0] - jaw_tip[0]) ** 2 + (left_jaw[1] - jaw_tip[1]) ** 2 + ) ** 0.5 + right_dist = ( + (right_jaw[0] - jaw_tip[0]) ** 2 + (right_jaw[1] - jaw_tip[1]) ** 2 + ) ** 0.5 + symmetry = 1 - min( + abs(left_dist - right_dist) / max(left_dist, right_dist), 0.5 + ) + + # 下颌角度评分 + left_angle_y = abs(left_jaw[1] - jaw_tip[1]) + right_angle_y = abs(right_jaw[1] - jaw_tip[1]) + avg_angle = (left_angle_y + right_angle_y) / 2 + + # 理想的下颌角度 + face_height = image_shape[0] + ideal_angle_ratio = 0.08 + angle_ratio = avg_angle / face_height + angle_score = max( + 0, 1 - abs(angle_ratio - ideal_angle_ratio) / ideal_angle_ratio + ) + + # 下颌线清晰度(通过点间距离变化评估) + smoothness_score = 0.8 # 简化处理,可以根据实际需要改进 + + final_score = ( + symmetry * 0.4 + angle_score * 0.35 + smoothness_score * 0.25 + ) * 10 + return round(max(0, min(10, final_score)), 2) + except: + return 6.21 + + def _calculate_harmony(self, landmarks: List[tuple], image_shape: tuple) -> float: + """计算五官协调性""" + try: + # 黄金比例检测 (简化版) + face_height = max([p[1] for p in landmarks]) - min( + [p[1] for p in landmarks] + ) + face_width = 
def _calculate_harmony_new(
    self, landmarks: List[tuple], image_shape: tuple
) -> float:
    """
    Compute the overall facial harmony score (0-10).

    Six sub-scores are computed and combined with the weights coded below:
    golden ratio 0.10, symmetry 0.40, classical proportions 0.05,
    feature spacing 0 (computed and logged but not counted),
    contour 0.05, feature proportions 0.40.

    :param landmarks: (x, y) landmark points; fewer than 68 yields 6.21
    :param image_shape: not read by this method
    :return: weighted score clamped to [0, 10], or 6.21 on error
    """
    try:
        logger.info(f"face landmarks={len(landmarks)}")
        if len(landmarks) < 68:  # the scorers assume the 68-point layout
            return 6.21

        # NumPy array for vectorised measurements.
        pts = np.array(landmarks)
        face_measurements = self._get_face_measurements(pts)

        golden = self._calculate_golden_ratios(face_measurements)
        logger.info(f"Golden ratio score={golden}")

        symmetry = self._calculate_facial_symmetry(face_measurements, pts)
        logger.info(f"Symmetry score={symmetry}")

        classical = self._calculate_classical_proportions(face_measurements)
        logger.info(f"Three courts five eyes ratio={classical}")

        spacing = self._calculate_feature_spacing(face_measurements)
        logger.info(f"Facial feature spacing harmony={spacing}")

        contour = self._calculate_contour_harmony(pts)
        logger.info(f"Facial contour harmony={contour}")

        features = self._calculate_feature_proportions(face_measurements)
        logger.info(f"Eye-nose-mouth proportion harmony={features}")

        # (label, value, weight) in evaluation order; weights sum to 1.0.
        weighted_terms = [
            ("golden_ratio", golden, 0.10),
            ("symmetry", symmetry, 0.40),
            ("proportions", classical, 0.05),
            ("spacing", spacing, 0),
            ("contour", contour, 0.05),
            ("features", features, 0.40),
        ]

        final_score = sum(value * weight for _, value, weight in weighted_terms)
        logger.info(f"Weighted average final score={final_score}")
        return max(0, min(10, final_score))

    except Exception as e:
        logger.error(f"Error calculating facial harmony: {e}")
        traceback.print_exc()  # print the full stack, including line numbers
        return 6.21
Dict[str, float]) -> float: + """计算黄金比例相关得分""" + golden_ratio = 1.618 + scores = [] + + # 面部长宽比 + if measurements["face_width"] > 0: + face_ratio = measurements["face_height"] / measurements["face_width"] + score = 1 - abs(face_ratio - golden_ratio) / golden_ratio + scores.append(max(0, score)) + + # 上中下三庭比例 + total_height = ( + measurements["forehead_height"] + + measurements["middle_face_height"] + + measurements["lower_face_height"] + ) + + if total_height > 0: + upper_ratio = measurements["forehead_height"] / total_height + middle_ratio = measurements["middle_face_height"] / total_height + lower_ratio = measurements["lower_face_height"] / total_height + + # 理想比例约为 1:1:1 + ideal_ratio = 1 / 3 + upper_score = 1 - abs(upper_ratio - ideal_ratio) / ideal_ratio + middle_score = 1 - abs(middle_ratio - ideal_ratio) / ideal_ratio + lower_score = 1 - abs(lower_ratio - ideal_ratio) / ideal_ratio + + scores.extend( + [max(0, upper_score), max(0, middle_score), max(0, lower_score)] + ) + + return np.mean(scores) * 10 if scores else 7.0 + + def _calculate_facial_symmetry( + self, measurements: Dict[str, float], points: np.ndarray + ) -> float: + """计算面部对称性""" + # 计算面部中线 + face_center_x = np.mean(points[:, 0]) + + # 检查左右对称的关键点对 + symmetry_pairs = [ + (17, 26), # 眉毛外端 + (18, 25), # 眉毛 + (19, 24), # 眉毛 + (36, 45), # 眼角 + (39, 42), # 眼角 + (31, 35), # 鼻翼 + (48, 54), # 嘴角 + (4, 12), # 面部轮廓 + (5, 11), # 面部轮廓 + (6, 10), # 面部轮廓 + ] + + symmetry_scores = [] + + for left_idx, right_idx in symmetry_pairs: + if left_idx < len(points) and right_idx < len(points): + left_point = points[left_idx] + right_point = points[right_idx] + + # 计算到中线的距离差异 + left_dist = abs(left_point[0] - face_center_x) + right_dist = abs(right_point[0] - face_center_x) + + # 垂直位置差异 + vertical_diff = abs(left_point[1] - right_point[1]) + + # 对称性得分 + if left_dist + right_dist > 0: + horizontal_symmetry = 1 - abs(left_dist - right_dist) / ( + left_dist + right_dist + ) + vertical_symmetry = 1 - vertical_diff / 
measurements.get( + "face_height", 100 + ) + + symmetry_scores.append( + (horizontal_symmetry + vertical_symmetry) / 2 + ) + + return np.mean(symmetry_scores) * 10 if symmetry_scores else 7.0 + + def _calculate_classical_proportions(self, measurements: Dict[str, float]) -> float: + """计算经典美学比例 (三庭五眼等)""" + scores = [] + + # 五眼比例检测 + if measurements["face_width"] > 0: + eye_width_avg = ( + measurements["left_eye_width"] + measurements["right_eye_width"] + ) / 2 + ideal_eye_count = 5 # 理想情况下面宽应该等于5个眼宽 + actual_eye_count = ( + measurements["face_width"] / eye_width_avg if eye_width_avg > 0 else 5 + ) + + eye_proportion_score = ( + 1 - abs(actual_eye_count - ideal_eye_count) / ideal_eye_count + ) + scores.append(max(0, eye_proportion_score)) + + # 眼间距比例 + if measurements.get("left_eye_width", 0) > 0: + eye_spacing_ratio = ( + measurements["eye_distance"] / measurements["left_eye_width"] + ) + ideal_spacing_ratio = 1.0 # 理想情况下眼间距约等于一个眼宽 + + spacing_score = ( + 1 - abs(eye_spacing_ratio - ideal_spacing_ratio) / ideal_spacing_ratio + ) + scores.append(max(0, spacing_score)) + + # 鼻宽与眼宽比例 + if ( + measurements.get("left_eye_width", 0) > 0 + and measurements.get("nose_width", 0) > 0 + ): + nose_eye_ratio = measurements["nose_width"] / measurements["left_eye_width"] + ideal_nose_eye_ratio = 0.8 # 理想鼻宽约为眼宽的80% + + nose_score = ( + 1 - abs(nose_eye_ratio - ideal_nose_eye_ratio) / ideal_nose_eye_ratio + ) + scores.append(max(0, nose_score)) + + return np.mean(scores) * 10 if scores else 7.0 + + def _calculate_feature_spacing(self, measurements: Dict[str, float]) -> float: + """计算五官间距协调性""" + scores = [] + + # 眼鼻距离协调性 + eye_nose_distance = abs( + measurements["left_eye_center"][1] - measurements["nose_tip"][1] + ) + if measurements.get("face_height", 0) > 0: + eye_nose_ratio = eye_nose_distance / measurements["face_height"] + ideal_ratio = 0.15 # 理想比例 + score = 1 - abs(eye_nose_ratio - ideal_ratio) / ideal_ratio + scores.append(max(0, score)) + + # 鼻嘴距离协调性 + nose_mouth_distance 
= abs( + measurements["nose_tip"][1] - np.mean([measurements.get("mouth_height", 0)]) + ) + if measurements.get("face_height", 0) > 0: + nose_mouth_ratio = nose_mouth_distance / measurements["face_height"] + ideal_ratio = 0.12 # 理想比例 + score = 1 - abs(nose_mouth_ratio - ideal_ratio) / ideal_ratio + scores.append(max(0, score)) + + return np.mean(scores) * 10 if scores else 7.0 + + def _calculate_contour_harmony(self, points: np.ndarray) -> float: + """计算面部轮廓协调性""" + try: + face_contour = points[0:17] # 面部轮廓点 + + # 计算轮廓的平滑度 + smoothness_scores = [] + + for i in range(1, len(face_contour) - 1): + # 计算相邻三点形成的角度 + p1, p2, p3 = face_contour[i - 1], face_contour[i], face_contour[i + 1] + + v1 = p1 - p2 + v2 = p3 - p2 + + # 计算角度 + cos_angle = np.dot(v1, v2) / ( + np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-8 + ) + angle = np.arccos(np.clip(cos_angle, -1, 1)) + + # 角度越接近平滑曲线越好 (避免过于尖锐的角度) + smoothness = 1 - abs(angle - np.pi / 2) / (np.pi / 2) + smoothness_scores.append(max(0, smoothness)) + + return np.mean(smoothness_scores) * 10 if smoothness_scores else 7.0 + + except: + return 6.21 + + def _calculate_feature_proportions(self, measurements: Dict[str, float]) -> float: + """计算眼鼻口等五官内部比例协调性""" + scores = [] + + # 眼部比例 (长宽比) + left_eye_ratio = measurements.get("left_eye_width", 1) / max( + measurements.get("left_eye_width", 1) * 0.3, 1 + ) + right_eye_ratio = measurements.get("right_eye_width", 1) / max( + measurements.get("right_eye_width", 1) * 0.3, 1 + ) + + # 理想眼部长宽比约为3:1 + ideal_eye_ratio = 3.0 + left_eye_score = 1 - abs(left_eye_ratio - ideal_eye_ratio) / ideal_eye_ratio + right_eye_score = 1 - abs(right_eye_ratio - ideal_eye_ratio) / ideal_eye_ratio + + scores.extend([max(0, left_eye_score), max(0, right_eye_score)]) + + # 嘴部比例 + if measurements.get("mouth_height", 0) > 0: + mouth_ratio = measurements["mouth_width"] / measurements["mouth_height"] + ideal_mouth_ratio = 3.5 # 理想嘴部长宽比 + mouth_score = 1 - abs(mouth_ratio - ideal_mouth_ratio) / ideal_mouth_ratio + 
scores.append(max(0, mouth_score)) + + # 鼻部比例 + if measurements.get("nose_height", 0) > 0: + nose_ratio = measurements["nose_height"] / measurements["nose_width"] + ideal_nose_ratio = 1.5 # 理想鼻部长宽比 + nose_score = 1 - abs(nose_ratio - ideal_nose_ratio) / ideal_nose_ratio + scores.append(max(0, nose_score)) + + return np.mean(scores) * 10 if scores else 7.0 + + def _basic_facial_analysis(self, face_image) -> Dict[str, Any]: + """基础五官分析 (当dlib不可用时)""" + return { + "facial_features": { + "eyes": 7.0, + "nose": 7.0, + "mouth": 7.0, + "eyebrows": 7.0, + "jawline": 7.0, + }, + "harmony_score": 7.0, + "overall_facial_score": 7.0, + "analysis_method": "basic_estimation", + } + + def draw_facial_landmarks(self, face_image: np.ndarray) -> np.ndarray: + """ + 在人脸图像上绘制特征点 + :param face_image: 人脸图像 + :return: 带特征点标记的人脸图像 + """ + if not DLIB_AVAILABLE or self.face_mesh is None: + # 如果没有可用的面部网格检测器,直接返回原图 + return face_image.copy() + + try: + # 复制原图用于绘制 + annotated_image = face_image.copy() + + # MediaPipe需要RGB图像 + rgb_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB) + + # 检测关键点 + results = self.face_mesh.process(rgb_image) + + if not results.multi_face_landmarks: + logger.warning("No facial landmarks detected for drawing") + return annotated_image + + # 获取第一个面部的关键点 + face_landmarks = results.multi_face_landmarks[0] + + # 绘制所有关键点 + h, w = face_image.shape[:2] + for landmark in face_landmarks.landmark: + x = int(landmark.x * w) + y = int(landmark.y * h) + # 绘制小圆点表示关键点 + cv2.circle(annotated_image, (x, y), 1, (0, 255, 0), -1) + + # 绘制十字标记 + cv2.line(annotated_image, (x-2, y), (x+2, y), (0, 255, 0), 1) + cv2.line(annotated_image, (x, y-2), (x, y+2), (0, 255, 0), 1) + + return annotated_image + + except Exception as e: + logger.error(f"Failed to draw facial landmarks: {e}") + return face_image.copy() diff --git a/gfpgan_restorer.py b/gfpgan_restorer.py new file mode 100644 index 0000000000000000000000000000000000000000..0524fcf8df67fa1c0310a4013ad5d13787e1e236 --- /dev/null +++ 
class GFPGANRestorer:
    """Thin wrapper around ``GFPGANer`` for face restoration."""

    def __init__(self):
        """Create the restorer and log how long initialization took."""
        start_time = time.perf_counter()
        self.restorer = None  # stays None when the model cannot be created
        self._initialize_model()
        init_time = time.perf_counter() - start_time
        if self.restorer is not None:
            logger.info(f"GFPGANRestorer initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"GFPGANRestorer initialization completed but not available, time: {init_time:.3f}s")

    def _initialize_model(self):
        """Locate the GFPGAN weights and build the ``GFPGANer`` instance.

        The first existing candidate path is used; when none exists the
        release URL is passed as ``model_path`` instead.  Any failure is
        logged and leaves ``self.restorer`` set to ``None``.
        """
        try:
            # Candidate weight locations, in priority order.
            possible_paths = [
                f"{MODELS_PATH}/GFPGANv1.4.pth",
                f"{MODELS_PATH}/gfpgan/GFPGANv1.4.pth",
                os.path.expanduser("~/.cache/gfpgan/GFPGANv1.4.pth"),
                "./models/GFPGANv1.4.pth"
            ]

            gfpgan_model_path = next(
                (path for path in possible_paths if os.path.exists(path)), None
            )

            if not gfpgan_model_path:
                logger.warning(f"GFPGAN model file not found, tried paths: {possible_paths}")
                logger.info("Will try to download GFPGAN model from the internet...")
                # Hand the URL to GFPGANer in place of a local path.
                gfpgan_model_path = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'

            logger.info(f"Using GFPGAN model: {gfpgan_model_path}")

            self.restorer = GFPGANer(
                model_path=gfpgan_model_path,
                upscale=2,
                arch='clean',
                channel_multiplier=2,
                bg_upsampler=None
            )
            logger.info("GFPGAN model initialized successfully")

        except Exception as e:
            logger.error(f"GFPGAN model initialization failed: {e}")
            self.restorer = None

    def is_available(self):
        """Return True when the GFPGAN model was created successfully."""
        return self.restorer is not None

    def restore_image(self, image):
        """
        Restore an image with GFPGAN.

        :param image: input image (numpy array, BGR)
        :return: restored image (numpy array, BGR); the input image is
            returned unchanged when restoration fails or yields no image
        :raises RuntimeError: when the model is not initialized
        """
        if not self.is_available():
            # RuntimeError instead of a bare Exception: more specific, and
            # still caught by callers that handle ``Exception``.
            raise RuntimeError("GFPGAN模型未初始化")

        try:
            logger.info("Starting GFPGAN image restoration...")

            # has_aligned=False: the input is not a pre-aligned face crop
            # only_center_face=False: process every detected face
            # paste_back=True: paste restored faces back into the image
            cropped_faces, restored_faces, restored_img = self.restorer.enhance(
                image,
                has_aligned=False,
                only_center_face=False,
                paste_back=True
            )

            if restored_img is not None:
                logger.info(f"GFPGAN restoration completed, detected {len(restored_faces)} faces")
                return restored_img

            logger.warning("GFPGAN restoration returned empty image, using original image")
            return image

        except Exception as e:
            logger.error(f"GFPGAN image restoration failed: {e}")
            # Best-effort: fall back to the original image instead of raising.
            return image
ImageFileList(BaseModel): + results: List[ImageScoreItem] + count: int + +class PagedImageFileList(BaseModel): + results: List[ImageScoreItem] + count: int + page: int + page_size: int + total_pages: int + +class CelebrityMatchResponse(BaseModel): + filename: str + display_name: Optional[str] = None + distance: float + similarity: float + confidence: float + face_filename: Optional[str] = None + + +class CategoryStatItem(BaseModel): + category: str + display_name: str + count: int + + +class CategoryStatsResponse(BaseModel): + stats: List[CategoryStatItem] + total: int diff --git a/push.sh b/push.sh new file mode 100644 index 0000000000000000000000000000000000000000..837fcf6639335dd9a4b265dc58e1f4afdc20b61e --- /dev/null +++ b/push.sh @@ -0,0 +1,2 @@ +#!/bin/bash +git push -f origin main diff --git a/realesrgan_upscaler.py b/realesrgan_upscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..976841e9ad19f7957c8ec8af5154a9da98c66d60 --- /dev/null +++ b/realesrgan_upscaler.py @@ -0,0 +1,235 @@ +import os +import time + +import cv2 +import numpy as np + +from config import logger, MODELS_PATH, REALESRGAN_MODEL + +try: + from basicsr.archs.rrdbnet_arch import RRDBNet + from basicsr.utils.download_util import load_file_from_url + from realesrgan import RealESRGANer + from realesrgan.archs.srvgg_arch import SRVGGNetCompact + import torch + + # 设置PyTorch CPU优化 + torch.set_num_threads(min(torch.get_num_threads(), 8)) # 限制线程数 + torch.set_num_interop_threads(min(4, torch.get_num_interop_threads())) # 设置操作间线程数 + + REALESRGAN_AVAILABLE = True + logger.info("Real-ESRGAN imported successfully") +except ImportError as e: + logger.error(f"Real-ESRGAN import failed: {e}") + REALESRGAN_AVAILABLE = False + + +class RealESRGANUpscaler: + """Real-ESRGAN超清放大处理器""" + + def __init__(self): + start_time = time.perf_counter() + self.upsampler = None + self.model_name = None + self.scale = 4 + self.denoise_strength = 0.5 + self._initialize() + init_time = 
time.perf_counter() - start_time + if self.upsampler is not None: + logger.info(f"RealESRGANUpscaler initialized successfully, time: {init_time:.3f}s") + else: + logger.info(f"RealESRGANUpscaler initialization completed but not available, time: {init_time:.3f}s") + + def _initialize(self): + """初始化Real-ESRGAN模型""" + if not REALESRGAN_AVAILABLE: + logger.error("Real-ESRGAN is not available, cannot initialize super resolution processor") + return + + try: + # 模型配置 - 从环境变量读取模型名称 + model_name = REALESRGAN_MODEL + self.model_name = model_name + + # 根据模型名称设置默认放大倍数 + if 'x2' in model_name: + self.scale = 2 + elif 'x4' in model_name: + self.scale = 4 + else: + self.scale = 4 # 默认4倍 + + # 模型文件路径 + model_path = None + if model_name == 'RealESRGAN_x4plus': + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) + netscale = 4 + file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth' + elif model_name == 'RealESRNet_x4plus': + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) + netscale = 4 + file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth' + elif model_name == 'RealESRGAN_x4plus_anime_6B': + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) + netscale = 4 + file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth' + elif model_name == 'RealESRGAN_x2plus': + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) + netscale = 2 + file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth' + elif model_name == 'realesr-animevideov3': + model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') + netscale = 4 + file_url = 
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth' + elif model_name == 'realesr-general-x4v3': + # 最新的通用模型 v0.2.5.0 + model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') + netscale = 4 + file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth' + elif model_name == 'realesr-general-wdn-x4v3': + # 最新的通用模型(带去噪)v0.2.5.0 + model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') + netscale = 4 + file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth' + + # 确保模型目录存在 + model_dir = os.path.join(MODELS_PATH, 'realesrgan') + os.makedirs(model_dir, exist_ok=True) + + # 检查本地是否已有模型文件 + local_model_path = None + model_filename = f"{model_name}.pth" + local_pth = os.path.join(MODELS_PATH, model_filename) + + if os.path.exists(local_pth): + local_model_path = local_pth + logger.info(f"Using local model file: {local_model_path}") + + # 如果本地有模型文件,使用本地文件,否则下载 + if local_model_path: + model_path = local_model_path + else: + # 下载模型 + logger.info(f"Downloading model {model_name} from {file_url}") + model_path = load_file_from_url( + url=file_url, model_dir=model_dir, progress=True, file_name=model_filename) + + # 创建upsampler + self.upsampler = RealESRGANer( + scale=netscale, + model_path=model_path, + model=model, + tile=512, # 启用分块处理,减少内存使用并提高CPU效率 + tile_pad=10, + pre_pad=0, + half=False, # 使用fp32精度 + gpu_id=None # 使用CPU + ) + + logger.info(f"Real-ESRGAN super resolution processor initialized successfully, model: {model_name}") + + except Exception as e: + logger.error(f"Failed to initialize Real-ESRGAN: {e}") + self.upsampler = None + + def is_available(self): + """检查处理器是否可用""" + return REALESRGAN_AVAILABLE and self.upsampler is not None + + def _optimize_input_image(self, image): + """ + 优化输入图像以提高CPU处理速度 + :param image: 输入图像 + :return: 优化后的图像 
+ """ + # 确保图像数据类型为uint8(减少计算开销) + if image.dtype != np.uint8: + if image.dtype == np.float32 or image.dtype == np.float64: + image = (image * 255).astype(np.uint8) + else: + image = image.astype(np.uint8) + + # 确保图像是3通道BGR格式 + if len(image.shape) == 2: # 灰度图 + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + elif image.shape[2] == 4: # RGBA + image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR) + elif image.shape[2] == 3 and image.shape[2] != 3: # RGB转BGR + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + return image + + def upscale_image(self, image, scale=None, denoise_strength=None): + """ + 对图像进行超清放大 + :param image: 输入图像 (numpy array) + :param scale: 放大倍数,默认使用模型的放大倍数 + :param denoise_strength: 去噪强度 (0-1),仅对realesr-general-x4v3模型有效 + :return: 超清后的图像 + """ + if not self.is_available(): + raise RuntimeError("Real-ESRGAN超清处理器不可用") + + try: + start_time = time.perf_counter() + + # 预处理优化图像 + image = self._optimize_input_image(image) + + # 设置去噪强度(仅对特定模型有效) + if denoise_strength is not None and self.model_name == 'realesr-general-x4v3': + self.denoise_strength = denoise_strength + + # 根据图像大小动态调整tile大小以优化CPU性能 + h, w = image.shape[:2] + pixel_count = h * w + + # 根据图像大小调整tile大小 + if pixel_count > 2000000: # 大于2MP + tile_size = 256 + elif pixel_count > 1000000: # 大于1MP + tile_size = 384 + else: + tile_size = 512 + + # 动态更新tile大小 + if hasattr(self.upsampler, 'tile'): + self.upsampler.tile = tile_size + logger.info(f"Adjusting tile size to: {tile_size} based on image size ({w}x{h})") + + # 执行超清处理 + logger.info(f"Starting Real-ESRGAN super resolution processing, model: {self.model_name}") + output, _ = self.upsampler.enhance(image, outscale=scale or self.scale) + + processing_time = time.perf_counter() - start_time + logger.info(f"Real-ESRGAN super resolution processing completed, time: {processing_time:.3f}s") + + return output + + except Exception as e: + logger.error(f"Real-ESRGAN super resolution processing failed: {e}") + raise RuntimeError(f"超清处理失败: {str(e)}") + + def 
# Process-wide RealESRGANUpscaler instance, created lazily on first use.
_upscaler_instance = None


def get_upscaler():
    """Return the shared RealESRGANUpscaler, creating it on the first call.

    The module previously defined ``get_upscaler`` twice; the earlier
    definition, which built a new upscaler on every call, was shadowed by
    this one and never ran. Only the singleton accessor is kept, so
    behavior for callers is unchanged.

    :return: the module-level RealESRGANUpscaler instance
    """
    global _upscaler_instance
    if _upscaler_instance is None:
        _upscaler_instance = RealESRGANUpscaler()
    return _upscaler_instance
def create_id_photo(self, image: np.ndarray, background_color: Tuple[int, int, int] = (255, 255, 255)) -> np.ndarray:
    """Produce an ID-style photo: the subject composited on a solid color.

    Thin wrapper over ``remove_background`` that always supplies a
    background color and logs the start and end of the operation.

    :param image: input OpenCV image
    :param background_color: fill color in BGR order, white by default
    :return: image returned by ``remove_background`` for that color
    """
    logger.info(f"Starting to create ID photo, background color: {background_color}")
    composited = self.remove_background(image, background_color)
    logger.info("ID photo creation completed")
    return composited
def switch_model(self, model_name: str) -> bool:
    """Switch the rembg session to another supported model.

    The new session is created before any state is touched, so a failed
    switch leaves the current session and model name intact. A successful
    switch also marks the processor as available: ``is_available`` checks
    ``self.available``, which stays False when the initial model failed to
    load in ``__init__``, so without this a working session would still be
    reported as unavailable.

    :param model_name: one of the names from ``get_supported_models``
    :return: True when the session was replaced, False otherwise
    """
    if not REMBG_AVAILABLE:
        return False

    if model_name not in self.get_supported_models():
        logger.error(f"Unsupported model: {model_name}")
        return False

    try:
        session = new_session(model_name)
    except Exception as e:
        # Best effort: report the failure and keep the previous session.
        logger.error(f"Failed to switch model: {e}")
        return False

    self.session = session
    self.model_name = model_name
    self.available = True
    logger.info(f"rembg model switched to: {model_name}")
    return True
/dev/null +++ b/rvm_processor.py @@ -0,0 +1,132 @@ +import os +import cv2 +import numpy as np +import torch +from torchvision import transforms + +import config +from config import logger + + +class RVMProcessor: + """RVM (Robust Video Matting) 抠图处理器""" + + def __init__(self): + self.model = None + self.available = False + self.device = "cpu" # 默认使用CPU,如果有GPU可以设置为"cuda" + + try: + # 仅从本地加载,不使用网络 + local_repo = getattr(config, 'RVM_LOCAL_REPO', '') + weights_path = getattr(config, 'RVM_WEIGHTS_PATH', '') + + if not local_repo or not os.path.isdir(local_repo): + raise RuntimeError("RVM_LOCAL_REPO not set or invalid. Please set env RVM_LOCAL_REPO to local RobustVideoMatting repo path (with hubconf.py)") + + if not weights_path or not os.path.isfile(weights_path): + raise RuntimeError("RVM_WEIGHTS_PATH not set or file not found. Please set env RVM_WEIGHTS_PATH to local RVM weights file path") + + logger.info(f"Loading RVM model {config.RVM_MODEL} from local repo: {local_repo}") + # 使用本地仓库构建模型,禁用预训练以避免联网 + self.model = torch.hub.load(local_repo, config.RVM_MODEL, source='local', pretrained=False) + + # 加载本地权重 + state = torch.load(weights_path, map_location=self.device) + if isinstance(state, dict) and 'state_dict' in state: + state = state['state_dict'] + missing, unexpected = self.model.load_state_dict(state, strict=False) + + # 迁移到设备并设置评估模式 + self.model = self.model.to(self.device).eval() + self.available = True + logger.info("RVM background removal processor initialized successfully (local mode)") + if missing: + logger.warning(f"RVM weights missing keys: {list(missing)[:5]}... total={len(missing)}") + if unexpected: + logger.warning(f"RVM weights unexpected keys: {list(unexpected)[:5]}... 
def remove_background(self, image: np.ndarray, background_color: tuple = None) -> np.ndarray:
    """Matte the subject out of an image with the loaded RVM model.

    :param image: input OpenCV image (BGR)
    :param background_color: BGR fill color; when None the alpha matte is
        appended as a fourth channel instead of compositing
    :return: 3-channel composite when a color is given, otherwise a
        4-channel image whose last channel is the matte
    :raises Exception: when the processor is unavailable or inference fails
    """
    if not self.is_available():
        raise Exception("RVM抠图处理器不可用")

    try:
        logger.info("Starting to remove background using RVM...")

        # Remember the input size so the model output can be matched to it.
        src_h, src_w = image.shape[:2]

        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        src = transforms.ToTensor()(rgb).unsqueeze(0).to(self.device)

        # No recurrent state is carried over: a single still image is processed.
        rec = [None] * 4
        with torch.no_grad():
            fgr, pha, *rec = self.model(src, *rec, downsample_ratio=0.25)

        # Back to 8-bit numpy: foreground as (H, W, 3), matte as (H, W).
        fgr = (fgr[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        pha = (pha[0, 0].cpu().numpy() * 255).astype(np.uint8)

        if fgr.shape[:2] != (src_h, src_w):
            target_size = (src_w, src_h)
            fgr = cv2.resize(fgr, target_size)
            pha = cv2.resize(pha, target_size)

        # Both output modes need the foreground in OpenCV channel order.
        fgr_bgr = cv2.cvtColor(fgr, cv2.COLOR_RGB2BGR)

        if background_color is None:
            # Transparent output: append the matte as the alpha channel.
            result = np.dstack((fgr_bgr, pha))
        else:
            backdrop = np.full((src_h, src_w, 3), background_color, dtype=np.uint8)
            weight = pha.astype(np.float32) / 255.0
            weight = np.stack([weight] * 3, axis=-1)
            result = (fgr_bgr * weight + backdrop * (1 - weight)).astype(np.uint8)

        logger.info("RVM background removal completed")
        return result

    except Exception as e:
        logger.error(f"RVM background removal failed: {e}")
        raise Exception(f"背景移除失败: {str(e)}")
def read_celebrities_from_txt(self, file_path):
    """Read the celebrity list from a UTF-8 text file.

    Each non-empty line is either ``name,profession`` or just ``name``.
    Lines starting with ``#`` are comments. Both the ASCII comma and the
    full-width comma are accepted as separators, a leading UTF-8 BOM is
    ignored, and lines without a name are skipped. A missing or empty
    profession defaults to "明星".

    :param file_path: path of the text file to read
    :return: list of ``{'name': ..., 'profession': ...}`` dicts in file order
    """
    celebrities = []
    # utf-8-sig strips a BOM, which would otherwise stick to the first name
    # or hide a leading "#" comment marker.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue

            # Lists typed with a Chinese input method often use "，".
            parts = line.replace('，', ',').split(',')
            name = parts[0].strip()
            if not name:
                # Nothing to search for; do not queue an empty query.
                continue

            profession = parts[1].strip() if len(parts) > 1 else ""
            celebrities.append({
                'name': name,
                'profession': profession or "明星"
            })
    return celebrities
def crawl_celebrity_images(self, celebrity, max_images=20,
                           search_engine='baidu'):
    """Download up to ``max_images`` pictures for one celebrity.

    Twice as many links as needed are requested so failed downloads can be
    compensated. Files are written straight into ``self.output_dir`` (no
    per-celebrity sub-folder) as ``<name>_<NNN>.jpg``; files that already
    exist count as successes and are not fetched again.

    :param celebrity: dict with 'name' and 'profession' keys
    :param max_images: number of images to collect
    :param search_engine: 'baidu' selects Baidu; anything else selects Bing
    :return: number of images present or downloaded for this celebrity
    """
    name = celebrity['name']
    print(f"\n正在爬取: {name} ({celebrity['profession']})")

    target_dir = Path(self.output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    search = (self.search_baidu_images if search_engine == 'baidu'
              else self.search_bing_images)
    img_urls = search(name, max_images * 2)

    if not img_urls:
        print(f" 未找到 {name} 的图片")
        return 0

    total_links = len(img_urls)
    print(f" 找到 {total_links} 个图片链接")

    success_count = 0
    for position, url in enumerate(img_urls, start=1):
        if success_count >= max_images:
            break

        save_path = target_dir / f"{name}_{position:03d}.jpg"
        if save_path.exists():
            # Already on disk from an earlier run.
            success_count += 1
            continue

        print(f" 下载 {position}/{total_links}...", end=' ')
        downloaded = self.download_image(url, save_path)
        if downloaded:
            success_count += 1
        print("✓" if downloaded else "✗")

        # Throttle requests to the image host.
        time.sleep(0.5)

    print(f" 成功下载 {success_count} 张图片")
    return success_count
failed_celebrities.append(celebrity['name']) + + # 输出统计 + print("\n" + "=" * 60) + print("爬取完成!") + print("=" * 60) + print(f"总明星数: {len(celebrities)}") + print(f"成功爬取: {len(celebrities) - len(failed_celebrities)}") + print(f"失败数量: {len(failed_celebrities)}") + print(f"总图片数: {total_images}") + print(f"保存位置: {self.output_dir}") + + if failed_celebrities: + print(f"\n失败的明星: {', '.join(failed_celebrities)}") + + +# 使用示例 +if __name__ == "__main__": + # 创建爬虫实例 + crawler = CelebrityCrawler(output_dir="celebrity_dataset") + + # 从txt文件爬取 + # txt文件格式示例: + # 周杰伦,歌手 + # 刘德华,演员 + # 范冰冰,演员 + + crawler.crawl_all( + txt_file="celebrity_real_names.txt", # 你的txt文件路径 + max_images_per_celebrity=1, # 每位明星爬取的图片数量 + search_engine='baidu' # 'baidu' 或 'bing' + ) diff --git a/test/celebrity_crawler.pyc b/test/celebrity_crawler.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92c6440784afca63a70bc58365efe4bc13fc01f3 Binary files /dev/null and b/test/celebrity_crawler.pyc differ diff --git a/test/decode_celeb_dataset.py b/test/decode_celeb_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4e1af81de80c6cfc2cb450891ccf0000a385579b --- /dev/null +++ b/test/decode_celeb_dataset.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +""" +Decode base64 file names inside the Chinese celeb dataset directory. + +Default target: /Users/chenchaoyun/Downloads/chinese_celeb_dataset. +Use --root to override; --dry-run only prints the plan. 
+""" +import argparse +import base64 +from pathlib import Path +import sys + +DEFAULT_ROOT = Path("/Users/chenchaoyun/Downloads/chinese_celeb_dataset") + + +def _decode_basename(encoded: str) -> str: + padding = "=" * ((4 - len(encoded) % 4) % 4) + try: + return base64.urlsafe_b64decode( + (encoded + padding).encode("ascii")).decode("utf-8") + except Exception: + return encoded + + +def rename_dataset(root: Path, dry_run: bool = False) -> int: + if not root.exists(): + print(f"Directory does not exist: {root}", file=sys.stderr) + return 1 + if not root.is_dir(): + print(f"Not a directory: {root}", file=sys.stderr) + return 1 + + renamed = 0 + for file_path in sorted(root.rglob("*")): + if not file_path.is_file(): + continue + decoded = _decode_basename(file_path.stem) + if decoded == file_path.stem: + continue + + new_path = file_path.with_name(f"{decoded}{file_path.suffix}") + if new_path == file_path: + continue + + # Append a counter if the decoded target already exists + counter = 1 + while new_path.exists() and new_path != file_path: + new_path = file_path.with_name( + f"{decoded}_{counter}{file_path.suffix}" + ) + counter += 1 + + print(f"{file_path} -> {new_path}") + if dry_run: + continue + file_path.rename(new_path) + renamed += 1 + + print(f"Renamed {renamed} files") + return 0 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Decode chinese_celeb_dataset file names") + parser.add_argument( + "--root", + type=Path, + default=DEFAULT_ROOT, + help="Dataset root directory (default: %(default)s)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Only print planned renames without applying them", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + return rename_dataset(args.root.expanduser().resolve(), args.dry_run) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/test/decode_celeb_dataset.pyc b/test/decode_celeb_dataset.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..49069413ab02712720b3a925abd7df22105eb8b7 Binary files /dev/null and b/test/decode_celeb_dataset.pyc differ diff --git a/test/dow_img.py b/test/dow_img.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7d1f44ce8d066caf9555fc45fffd5f90440d63 --- /dev/null +++ b/test/dow_img.py @@ -0,0 +1,24 @@ +import cv2 + +# 读取图片 +img = cv2.imread("/opt/data/header.png") + +# 设置压缩质量(0-100,值越小压缩越狠,质量越差) +quality = 50 + +# 写入压缩后的图像(注意必须是 .webp) +cv2.imwrite( + "/opt/data/output_small.webp", + img, + [int(cv2.IMWRITE_WEBP_QUALITY), quality], +) + + +# # 读取原图 +# img = cv2.imread("/opt/data/header.png") +# +# # 缩放图像(例如缩小为原图的一半) +# resized = cv2.resize(img, (img.shape[1] // 2, img.shape[0] // 2)) +# +# # 写入压缩图像,降低质量 +# cv2.imwrite("/opt/data/output_small.webp", resized, [int(cv2.IMWRITE_WEBP_QUALITY), 40]) diff --git a/test/dow_img.pyc b/test/dow_img.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77f1e122998f9b82f1d68ff7ae14ab4c4aadb0db Binary files /dev/null and b/test/dow_img.pyc differ diff --git a/test/howcuteami.py b/test/howcuteami.py new file mode 100644 index 0000000000000000000000000000000000000000..cee9393f2ff2ab1dec7445425ebe870744c9999e --- /dev/null +++ b/test/howcuteami.py @@ -0,0 +1,202 @@ +import cv2 +import math +import argparse +import numpy as np +import os + + +# detect face +def highlightFace(net, frame, conf_threshold=0.95): + frameOpencvDnn = frame.copy() + frameHeight = frameOpencvDnn.shape[0] + frameWidth = frameOpencvDnn.shape[1] + blob = cv2.dnn.blobFromImage( + frameOpencvDnn, 1.0, (300, 300), [104, 117, 123], True, False + ) + + net.setInput(blob) + detections = net.forward() + faceBoxes = [] + + for i in range(detections.shape[2]): + confidence = detections[0, 0, i, 2] + if confidence > conf_threshold: + x1 = int(detections[0, 0, i, 3] * frameWidth) + y1 = int(detections[0, 0, i, 4] * frameHeight) + x2 = int(detections[0, 0, i, 5] * frameWidth) + y2 = 
def cropImage(image, box):
    """Return the part of ``image`` covered by ``box``, clipped to the image.

    ``box`` is ``[x1, y1, x2, y2]`` in pixel coordinates. Squaring a
    detection near the border (see ``scale``) can push coordinates below
    zero; numpy would read those as "count from the end" and return an
    empty or unrelated region, so every coordinate is clamped to the image
    bounds first.

    :param image: array indexed as ``image[row, column, ...]``
    :param box: sequence of four ints ``[x1, y1, x2, y2]``
    :return: view of the clipped region (empty if the box misses the image)
    """
    height, width = image.shape[:2]
    x1 = min(max(box[0], 0), width)
    y1 = min(max(box[1], 0), height)
    x2 = min(max(box[2], 0), width)
    y2 = min(max(box[3], 0), height)
    return image[y1:y2, x1:x2]
i, faceBox in enumerate(faceBoxes): + # 提取人脸区域 + face = cropImage(frame, faceBox) + face_resized = cv2.resize(face, (224, 224)) + + # gender net + blob = cv2.dnn.blobFromImage( + face_resized, 1.0, (224, 224), MODEL_MEAN_VALUES, swapRB=False + ) + genderNet.setInput(blob) + genderPreds = genderNet.forward() + gender = genderList[genderPreds[0].argmax()] + print(f"Gender: {gender}") + + # age net + ageNet.setInput(blob) + agePreds = ageNet.forward() + age = ageList[agePreds[0].argmax()] + print(f"Age: {age[1:-1]} years") + + # beauty net + blob = cv2.dnn.blobFromImage( + face_resized, 1.0 / 255, (224, 224), MODEL_MEAN_VALUES, swapRB=False + ) + beautyNet.setInput(blob) + beautyPreds = beautyNet.forward() + beauty = round(2.0 * sum(beautyPreds[0]), 1) + print(f"Beauty: {beauty}/10.0") + + # 根据性别选择颜色 + color = gender_colors[gender] + + # 保存人脸图片 - 使用cv2.imwrite + face_filename = f"{output_dir}/face_{i+1}.webp" + cv2.imwrite(face_filename, face, [cv2.IMWRITE_WEBP_QUALITY, 95]) + print(f"人脸图片已保存: {face_filename}") + + # 保存评分到图片上(可选) + face_with_text = face.copy() + cv2.putText( + face_with_text, f"{gender}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2 + ) + cv2.putText( + face_with_text, + f"{age[1:-1]} years", + (10, 60), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + color, + 2, + ) + cv2.putText( + face_with_text, + f"{beauty}/10.0", + (10, 90), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + color, + 2, + ) + + annotated_filename = f"{output_dir}/face_{i+1}_annotated.webp" + cv2.imwrite(annotated_filename, face_with_text, [cv2.IMWRITE_WEBP_QUALITY, 95]) + print(f"标注人脸已保存: {annotated_filename}") + + # 在原图上绘制人脸框和信息 + cv2.rectangle( + frame, + (faceBox[0], faceBox[1]), + (faceBox[2], faceBox[3]), + color, + int(round(frame.shape[0] / 400)), + 8, + ) + cv2.putText( + frame, + f"{gender}, {age}, {beauty}", + (faceBox[0], faceBox[1] - 10), + cv2.FONT_HERSHEY_SIMPLEX, + 1.25, + color, + 2, + cv2.LINE_AA, + ) + +# 保存完整的标注图片 +result_filename = f"{output_dir}/result_full.webp" 
def infer_category_from_filename(filename):
    """Infer the image category from a file name.

    The name is lower-cased first. A name containing ``_anime_style_`` is
    always ``'anime_style'``. Otherwise the part of the name that precedes the
    first dot following the last underscore is inspected, and the category is
    the known category that this part ends with (as ``_<category>``).

    Matching by suffix rather than by "text after the last underscore" is
    what lets multi-word categories such as ``id_photo`` be recognised: the
    text after the last underscore of ``x_id_photo.webp`` is just ``photo``.

    Returns ``'other'`` when the name has no underscore, no dot after the last
    underscore, or no known category suffix.
    """
    filename_lower = filename.lower()

    # Anime-style files are recognised by the marker anywhere in the name.
    if '_anime_style_' in filename_lower:
        return 'anime_style'

    last_underscore_index = filename_lower.rfind('_')
    first_dot_index = filename_lower.find('.', last_underscore_index)
    if last_underscore_index == -1 or first_dot_index == -1:
        return 'other'

    # Everything before the extension; it contains the last underscore.
    stem = filename_lower[:first_dot_index]

    known_categories = (
        'restore',
        'upcolor',
        'grayscale',
        'upscale',
        'compress',
        'id_photo',
        'grid',
        'rvm',
        'celebrity',
        'face',
        'original',
    )
    for category in known_categories:
        if stem.endswith(f'_{category}'):
            return category

    return 'other'
def deduplicate(target_dir: Path, dry_run: bool = False) -> int:
    """Delete files under ``target_dir`` whose content duplicates an earlier file.

    Files are visited in sorted path order, so the first path with a given
    content is the one that is kept. Symlinks and non-regular files are
    ignored. A file is treated as a duplicate only when its MD5 matches an
    earlier file AND a byte-for-byte comparison confirms it, so an MD5
    collision alone can never cause a deletion.

    :param target_dir: directory to scan recursively
    :param dry_run: only report what would be deleted, delete nothing
    :return: number of files actually deleted (always 0 for a dry run, or
        when ``target_dir`` is missing or not a directory)
    """
    import filecmp  # local import: only this function needs it

    if not target_dir.exists():
        print(f"[error] 目标目录不存在: {target_dir}", file=sys.stderr)
        return 0
    if not target_dir.is_dir():
        print(f"[error] 目标路径不是目录: {target_dir}", file=sys.stderr)
        return 0

    md5_map: Dict[str, Path] = {}
    removed = 0
    duplicates = 0
    scanned = 0

    # Sorted traversal keeps the choice of "original" deterministic.
    for file_path in sorted(target_dir.rglob("*")):
        if not file_path.is_file() or file_path.is_symlink():
            continue

        scanned += 1
        try:
            file_md5 = compute_md5(file_path)
        except Exception as exc:
            print(f"[warn] 计算 MD5 失败: {file_path} -> {exc}", file=sys.stderr)
            continue

        original = md5_map.get(file_md5)
        if original is None:
            md5_map[file_md5] = file_path
            continue

        # Confirm the match on the actual bytes before doing anything destructive.
        try:
            identical = filecmp.cmp(original, file_path, shallow=False)
        except OSError as exc:
            print(f"[warn] 比较文件失败: {file_path} -> {exc}", file=sys.stderr)
            continue
        if not identical:
            print(
                f"[warn] MD5 相同但内容不同,跳过: {file_path} (对比: {original})",
                file=sys.stderr,
            )
            continue

        duplicates += 1
        if dry_run:
            print(f"[dry-run] {file_path} 与 {original} 内容相同,将被删除")
        else:
            try:
                os.remove(file_path)
                removed += 1
                print(f"[remove] 删除重复文件: {file_path} (原始: {original})")
            except Exception as exc:
                print(f"[error] 删除失败: {file_path} -> {exc}", file=sys.stderr)

    # "重复文件" is reported separately so a dry run still shows how many files would go.
    print(
        f"[summary] 扫描文件: {scanned}, 保留唯一文件: {len(md5_map)}, 重复文件: {duplicates}, 删除重复文件: {removed}{' (dry-run)' if dry_run else ''}"
    )
    return removed
def load_yolo_model() -> YOLO:
    """
    Load the YOLO model, preferring the copy in the local models directory.

    ``MODEL_DIR / YOLO_MODEL_NAME`` is tried first when that path exists; the
    bare model name is always tried afterwards (per the original note, loading
    by name triggers an automatic download). Each failed attempt is logged.

    Raises RuntimeError, chained to the last failure, when no candidate loads.
    """
    bundled = MODEL_DIR / YOLO_MODEL_NAME
    sources: List[str] = (
        [str(bundled), YOLO_MODEL_NAME] if bundled.exists() else [YOLO_MODEL_NAME]
    )

    failure: Optional[Exception] = None
    for source in sources:
        try:
            config.logger.info("尝试加载 YOLO 模型:%s", source)
            return YOLO(source)
        except Exception as exc:  # pragma: no cover
            failure = exc
            config.logger.warning("加载 YOLO 模型失败:%s -> %s", source, exc)

    raise RuntimeError(f"无法加载 YOLO 模型:{YOLO_MODEL_NAME}") from failure
def has_face(model: "YOLO", image_path: Path, confidence: float, verbose: bool = False) -> bool:
    """
    Report whether YOLO detects at least one face in the image.

    The model is called with ``conf=confidence``; any result that carries at
    least one box counts as a face. With ``verbose`` set, the class id and
    confidence of every box in that result are logged.

    Inference errors are deliberately NOT caught here. This script deletes
    every image for which this function returns False, so turning a failed
    detection into False would delete images that were never analysed. The
    exception propagates to ``main``, whose try/except around this call counts
    the image as errored and skips it.

    :return: True if any box was detected, False if inference ran and found none
    :raises Exception: whatever the model raises while processing the image
    """
    results = model(image_path, conf=confidence, verbose=False)

    for result in results:
        boxes = getattr(result, "boxes", None)
        if boxes is None or len(boxes) == 0:
            continue
        if verbose:
            faces = []
            for box in boxes:
                cls_id = int(box.cls[0]) if getattr(box, "cls", None) is not None else -1
                score = float(box.conf[0]) if getattr(box, "conf", None) is not None else 0.0
                faces.append({"cls": cls_id, "conf": score})
            config.logger.info("检测到人脸:%s -> %s", image_path, faces)
        return True
    return False
# Manual timing script for two DeepFace calls: a pairwise verification followed
# by a database search. Both use the ArcFace model, the "yolov11n" detector
# backend and cosine distance, and read their images from images_path.
import json
import time
from deepface import DeepFace

images_path = "/opt/data/face"

# ========== 2. Face similarity comparison ==========
start_time = time.time()
result_verification = DeepFace.verify(
    img1_path=images_path + "/4.webp",
    img2_path=images_path + "/5.webp",
    model_name="ArcFace",  # recognition model
    detector_backend="yolov11n",  # face detector; alternatives noted by the author: retinaface / yolov8 / opencv / ssd / mediapipe
    distance_metric="cosine"  # similarity metric
)
end_time = time.time()
print(f"🕒 人脸比对耗时: {end_time - start_time:.3f} 秒")

# Print the verification result as indented JSON.
print(json.dumps(result_verification, ensure_ascii=False, indent=2))


# ========== 1. Face recognition ==========

start_time = time.time()
result_recognition = DeepFace.find(
    img_path=images_path + "/1.jpg",  # face to identify
    db_path=images_path,  # directory searched for matches
    model_name="ArcFace",  # recognition model
    detector_backend="yolov11n",  # face detector
    distance_metric="cosine"  # similarity metric
)
end_time = time.time()
print(f"🕒 人脸识别耗时: {end_time - start_time:.3f} 秒")

# Uncomment to print the search result.
# df = result_recognition[0]
# print(df.to_json(orient="records", force_ascii=False))
def default_converter(o):
    """``default=`` hook for ``json.dumps`` that converts NumPy values.

    NumPy booleans, integers and floats become the matching Python scalar and
    arrays become nested lists. ``np.bool_`` is handled explicitly because it
    is not an ``np.integer``; without that branch it would be emitted as the
    string ``"True"``/``"False"`` instead of a JSON boolean. Anything else is
    serialized as ``str(o)``.
    """
    if isinstance(o, np.bool_):
        return bool(o)
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, np.floating):
        return float(o)
    if isinstance(o, np.ndarray):
        return o.tolist()
    return str(o)
# Manual script: replace the sky of a scene image with the modelscope
# sky-change pipeline and write the result to the user's Downloads folder.
import os.path as osp

import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# "~" is a shell convention; it has to be expanded explicitly before the path
# is handed to a library, otherwise it is treated as a literal directory name.
sky_image_path = osp.expanduser('~/Downloads/sky_image.jpg')
scene_image_path = '/opt/data/face/NXEo0zusSaNB2fa232c84898e92ff165e2dfee59cb54.jpg'
output_path = osp.expanduser('~/Downloads/result.png')

image_skychange = pipeline(Tasks.image_skychange,
                           model='iic/cv_hrnetocr_skychange')
result = image_skychange(
    {'sky_image': sky_image_path,
     'scene_image': scene_image_path})

# cv2.imwrite signals failure through its return value, so check it instead of
# reporting success unconditionally.
if not cv2.imwrite(output_path, result[OutputKeys.OUTPUT_IMG]):
    raise SystemExit(f'Failed to write output image: {output_path}')

# Report the path that was actually written.
print(f'Output written to {output_path}')
0000000000000000000000000000000000000000..0b6ddc8fcbfffcf4d90f50fbf4579e6f69fd8382 Binary files /dev/null and b/test/test_sky.pyc differ diff --git a/test_tensorflow.py b/test_tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..55e144a2a93bc684067276c417420dd64eef01eb --- /dev/null +++ b/test_tensorflow.py @@ -0,0 +1,19 @@ +import deepface +import sys +import tensorflow as tf + +try: + import keras + + keras_pkg = "keras (standalone)" + keras_ver = keras.__version__ +except Exception: + from tensorflow import keras + + keras_pkg = "tf.keras" + keras_ver = keras.__version__ + +print("py =", sys.version) +print("deepface =", deepface.__version__) +print("tensorflow =", tf.__version__) +print("keras pkg =", keras_pkg, "keras =", keras_ver) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acd93a1603293e9be0c2bfd8d891791ba19cf2bb --- /dev/null +++ b/utils.py @@ -0,0 +1,1037 @@ +import base64 +import hashlib +import os +import re +import shutil +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Optional +from collections import OrderedDict + +import cv2 +import numpy as np +from PIL import Image + +try: + import boto3 + from botocore.exceptions import BotoCoreError, ClientError +except ImportError: + boto3 = None + BotoCoreError = ClientError = Exception + +from config import ( + IMAGES_DIR, + logger, + SAVE_QUALITY, + MODELS_PATH, + BOS_ACCESS_KEY, + BOS_SECRET_KEY, + BOS_ENDPOINT, + BOS_BUCKET_NAME, + BOS_IMAGE_DIR, + BOS_UPLOAD_ENABLED, + BOS_DOWNLOAD_TARGETS, + HUGGINGFACE_REPO_ID, + HUGGINGFACE_SYNC_ENABLED, + HUGGINGFACE_REVISION, + HUGGINGFACE_ALLOW_PATTERNS, + HUGGINGFACE_IGNORE_PATTERNS, +) + +_BOS_CLIENT = None +_BOS_CLIENT_INITIALIZED = False +_BOS_CLIENT_LOCK = threading.Lock() +_BOS_DOWNLOAD_LOCK = threading.Lock() +_BOS_DOWNLOAD_COMPLETED = False +_BOS_BACKGROUND_EXECUTOR = None +_BOS_BACKGROUND_FUTURES = [] 
+_IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR)) +_BOS_UPLOAD_CACHE = OrderedDict() +_BOS_UPLOAD_CACHE_LOCK = threading.Lock() +_BOS_UPLOAD_CACHE_MAX = 2048 + + +def _decode_bos_credential(raw_value: str) -> str: + """将Base64编码的凭证解码为明文,若解码失败则返回原值""" + if not raw_value: + return "" + + value = raw_value.strip() + if not value: + return "" + + try: + padding = len(value) % 4 + if padding: + value += "=" * (4 - padding) + decoded = base64.b64decode(value).decode("utf-8").strip() + if decoded: + return decoded + except Exception: + pass + return value + + +def _is_path_under_images_dir(file_path: str) -> bool: + try: + return os.path.commonpath( + [_IMAGES_DIR_ABS, os.path.abspath(file_path)] + ) == _IMAGES_DIR_ABS + except ValueError: + return False + + +def _get_bos_client(): + global _BOS_CLIENT, _BOS_CLIENT_INITIALIZED + if _BOS_CLIENT_INITIALIZED: + return _BOS_CLIENT + + with _BOS_CLIENT_LOCK: + if _BOS_CLIENT_INITIALIZED: + return _BOS_CLIENT + + if not BOS_UPLOAD_ENABLED: + _BOS_CLIENT_INITIALIZED = True + _BOS_CLIENT = None + return None + access_key = _decode_bos_credential(BOS_ACCESS_KEY) + secret_key = _decode_bos_credential(BOS_SECRET_KEY) + if not all([access_key, secret_key, BOS_ENDPOINT, BOS_BUCKET_NAME]): + logger.warning("BOS 上传未配置完整,跳过初始化") + _BOS_CLIENT_INITIALIZED = True + _BOS_CLIENT = None + return None + + if boto3 is None: + logger.warning("未安装 boto3,BOS 上传功能不可用") + _BOS_CLIENT_INITIALIZED = True + _BOS_CLIENT = None + return None + + try: + _BOS_CLIENT = boto3.client( + "s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + endpoint_url=BOS_ENDPOINT, + ) + logger.info("BOS 客户端初始化成功") + except Exception as e: + logger.warning(f"初始化 BOS 客户端失败,将跳过上传: {e}") + _BOS_CLIENT = None + finally: + _BOS_CLIENT_INITIALIZED = True + + return _BOS_CLIENT + + +def _normalize_bos_prefix(prefix: Optional[str]) -> str: + value = (prefix or "").strip() + if not value: + return "" + value = value.strip("/") + if not value: + 
return "" + return f"{value}/" if not value.endswith("/") else value + + +def _directory_has_files(path: str) -> bool: + try: + for _root, _dirs, files in os.walk(path): + if files: + return True + except Exception: + return False + return False + + +def download_bos_directory(prefix: str, destination_dir: str, *, force_download: bool = False) -> bool: + """ + 将 BOS 上的指定前缀目录同步到本地。 + :param prefix: BOS 对象前缀,例如 'models/' 或 '20220620/models' + :param destination_dir: 本地目标目录 + :param force_download: 是否强制重新下载(忽略本地已存在的文件) + :return: 是否确保目录可用 + """ + client = _get_bos_client() + if client is None: + logger.warning("BOS 客户端不可用,无法下载资源(prefix=%s)", prefix) + return False + + dest_dir = os.path.abspath(os.path.expanduser(destination_dir)) + try: + os.makedirs(dest_dir, exist_ok=True) + except Exception as exc: + logger.error("创建本地目录失败: %s (%s)", dest_dir, exc) + return False + + normalized_prefix = _normalize_bos_prefix(prefix) + + # 未强制下载且目录已有文件时直接跳过,避免重复下载 + if not force_download and _directory_has_files(dest_dir): + logger.info("本地目录已存在文件,跳过下载: %s -> %s", normalized_prefix or "", dest_dir) + return True + + paginate_kwargs = {"Bucket": BOS_BUCKET_NAME} + if normalized_prefix: + paginate_kwargs["Prefix"] = normalized_prefix if normalized_prefix.endswith("/") else f"{normalized_prefix}/" + + found_any = False + downloaded = 0 + skipped = 0 + + try: + paginator = client.get_paginator("list_objects_v2") + for page in paginator.paginate(**paginate_kwargs): + for obj in page.get("Contents", []): + key = obj.get("Key") + if not key: + continue + if normalized_prefix: + prefix_with_slash = normalized_prefix if normalized_prefix.endswith("/") else f"{normalized_prefix}/" + if not key.startswith(prefix_with_slash): + continue + relative_key = key[len(prefix_with_slash):] + else: + relative_key = key + + if not relative_key or relative_key.endswith("/"): + continue + found_any = True + + target_path = os.path.join(dest_dir, relative_key) + target_dir = os.path.dirname(target_path) + 
def _get_background_executor() -> ThreadPoolExecutor:
    """Return the module-wide background executor, creating it on first use.

    The executor has two worker threads named with the ``bos-bg`` prefix and is
    stored in ``_BOS_BACKGROUND_EXECUTOR`` so later calls reuse it.
    """
    global _BOS_BACKGROUND_EXECUTOR
    if _BOS_BACKGROUND_EXECUTOR is not None:
        return _BOS_BACKGROUND_EXECUTOR
    _BOS_BACKGROUND_EXECUTOR = ThreadPoolExecutor(
        max_workers=2,
        thread_name_prefix="bos-bg",
    )
    return _BOS_BACKGROUND_EXECUTOR
def ensure_bos_resources(force_download: bool = False, include_background: bool = False) -> bool:
    """
    Synchronize the model and data resources listed in BOS_DOWNLOAD_TARGETS.

    Each target is a dict with ``bos_prefix`` and ``destination`` (required)
    and optional ``description`` and ``background`` entries; malformed targets
    are logged and skipped. Blocking targets are downloaded in parallel and
    awaited. Targets flagged ``background`` are submitted to the shared
    background executor and are not awaited, unless ``include_background`` is
    set, in which case they are downloaded together with the blocking ones.

    :param force_download: re-synchronize every resource even if already present
    :param include_background: treat background targets as blocking targets
    :return: True when every blocking target was synchronized (or there were
        none); background targets do not affect the result
    """
    global _BOS_DOWNLOAD_COMPLETED, _BOS_BACKGROUND_FUTURES

    def _make_callback(description: str):
        def _background_done(fut):
            try:
                success = fut.result()
                if success:
                    logger.info("后台 BOS 资源已就绪: %s", description)
                else:
                    logger.warning("后台 BOS 资源同步失败: %s", description)
            except Exception as exc:
                logger.warning("后台 BOS 资源同步异常: %s (%s)", description, exc)
            finally:
                with _BOS_DOWNLOAD_LOCK:
                    if fut in _BOS_BACKGROUND_FUTURES:
                        _BOS_BACKGROUND_FUTURES.remove(fut)
        return _background_done

    # (future, description) pairs whose completion callback still has to be
    # registered once the lock has been released.
    submitted = []

    with _BOS_DOWNLOAD_LOCK:
        if _BOS_DOWNLOAD_COMPLETED and not force_download and not include_background:
            return True

        targets = BOS_DOWNLOAD_TARGETS or []
        if not targets:
            logger.info("未配置 BOS 下载目标,跳过资源同步")
            _BOS_DOWNLOAD_COMPLETED = True
            return True

        download_jobs = []
        background_jobs = []
        for target in targets:
            if not isinstance(target, dict):
                logger.warning("无效的 BOS 下载配置项: %r", target)
                continue

            prefix = target.get("bos_prefix")
            destination = target.get("destination")
            description = target.get("description") or prefix or ""
            background_flag = bool(target.get("background"))

            if not prefix or not destination:
                logger.warning("缺少必要字段,无法处理 BOS 下载配置: %r", target)
                continue

            job = {
                "description": description,
                "prefix": prefix,
                "destination": destination,
            }

            if background_flag and not include_background:
                background_jobs.append(job)
            else:
                download_jobs.append(job)

        results = []
        if download_jobs:
            # download_jobs is non-empty here, so this is always at least 1.
            max_workers = min(len(download_jobs), os.cpu_count() or 1)

            with ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="bos-sync") as executor:
                future_to_job = {}
                for job in download_jobs:
                    logger.info(
                        "准备同步 BOS 资源: %s (prefix=%s -> %s)",
                        job["description"],
                        job["prefix"],
                        job["destination"],
                    )
                    future = executor.submit(
                        download_bos_directory,
                        job["prefix"],
                        job["destination"],
                        force_download=force_download,
                    )
                    future_to_job[future] = job

                for future in as_completed(future_to_job):
                    job = future_to_job[future]
                    description = job["description"]
                    try:
                        success = future.result()
                    except Exception as exc:
                        logger.warning("BOS 资源同步异常: %s (%s)", description, exc)
                        success = False

                    if success:
                        logger.info("BOS 资源已就绪: %s", description)
                    else:
                        logger.warning("BOS 资源同步失败: %s", description)
                    results.append(success)

        # all([]) is True, which covers the "no blocking jobs" case.
        all_ready = all(results)
        if all_ready:
            _BOS_DOWNLOAD_COMPLETED = True

        if background_jobs:
            executor = _get_background_executor()
            for job in background_jobs:
                logger.info(
                    "后台同步 BOS 资源: %s (prefix=%s -> %s)",
                    job["description"],
                    job["prefix"],
                    job["destination"],
                )
                future = executor.submit(
                    download_bos_directory,
                    job["prefix"],
                    job["destination"],
                    force_download=force_download,
                )
                # Track the future BEFORE its callback can possibly run, so the
                # callback's removal always finds it.
                _BOS_BACKGROUND_FUTURES.append(future)
                submitted.append((future, job["description"]))

    # Register callbacks only after the lock is released. add_done_callback
    # invokes the callback immediately, in this thread, when the future has
    # already finished; the callback takes _BOS_DOWNLOAD_LOCK, which is a
    # non-reentrant threading.Lock, so registering while holding it could
    # deadlock on a fast-finishing job.
    for future, description in submitted:
        future.add_done_callback(_make_callback(description))

    return all_ready
(time.perf_counter() - upload_start) * 1000 + logger.info("文件已同步至 BOS(%s,耗时 %.1f ms): %s", mode_label, elapsed_ms, object_key) + with _BOS_UPLOAD_CACHE_LOCK: + _BOS_UPLOAD_CACHE[cache_key] = cache_signature + _BOS_UPLOAD_CACHE.move_to_end(cache_key) + while len(_BOS_UPLOAD_CACHE) > _BOS_UPLOAD_CACHE_MAX: + _BOS_UPLOAD_CACHE.popitem(last=False) + return True + except (ClientError, BotoCoreError, Exception) as exc: + logger.warning("上传到 BOS 失败(%s,%s): %s", object_key, mode_label, exc) + return False + + return _do_upload("同步") + + +def delete_file_from_bos(file_path: str | None = None, + object_name: str | None = None) -> bool: + """ + 删除 BOS 中的指定对象,失败不会抛出异常。 + :param file_path: 本地文件路径(可选,用于推导文件名) + :param object_name: BOS 对象名称(可选,优先使用) + :return: 是否成功删除 + """ + if not BOS_UPLOAD_ENABLED: + return False + + client = _get_bos_client() + if client is None: + return False + + key_candidate = object_name.strip("/ ") if object_name else "" + + if not key_candidate and file_path: + base_name = os.path.basename( + os.path.abspath(os.path.expanduser(file_path))) + key_candidate = base_name.strip() + + if not key_candidate: + return False + + if BOS_IMAGE_DIR: + object_key = "/".join( + part.strip("/ ") for part in (BOS_IMAGE_DIR, key_candidate) if part + ) + else: + object_key = key_candidate + + try: + client.delete_object(Bucket=BOS_BUCKET_NAME, Key=object_key) + logger.info(f"已从 BOS 删除文件: {object_key}") + return True + except (ClientError, BotoCoreError, Exception) as e: + logger.warning(f"删除 BOS 文件失败({object_key}): {e}") + return False + + +def image_to_base64(image: np.ndarray) -> str: + """将OpenCV图像转换为base64字符串""" + if image is None or image.size == 0: + return "" + _, buffer = cv2.imencode(".webp", image, [cv2.IMWRITE_WEBP_QUALITY, 90]) + img_base64 = base64.b64encode(buffer).decode("utf-8") + return f"data:image/webp;base64,{img_base64}" + + +def save_base64_to_unique_file( + base64_string: str, output_dir: str = "output_images" +) -> str | None: + """ + 
def human_readable_size(size_bytes):
    """
    Format a byte count as a short human-readable string.

    The value is divided by 1024 until it drops below 1024 or the unit
    ladder (B, KB, MB, GB) is exhausted; anything larger is reported in TB.

    :param size_bytes: size in bytes
    :return: string such as ``"1.5 KB"`` (one decimal place)
    """
    units = ("B", "KB", "MB", "GB")
    value = size_bytes
    position = 0
    while position < len(units):
        if value < 1024:
            return f"{value:.1f} {units[position]}"
        value = value / 1024
        position += 1
    return f"{value:.1f} TB"
def convert_numpy_types(obj):
    """
    Recursively convert numpy values into native Python types.

    Handles every numpy scalar width (via the abstract ``np.integer`` /
    ``np.floating`` bases, plus ``np.bool_``), numpy arrays, and the
    containers ``dict`` (values only), ``list`` and ``tuple``. Anything
    else is returned unchanged.

    :param obj: arbitrary value, possibly nested
    :return: an equivalent structure built from native Python types
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        # tolist() already yields native scalars; recurse to cover object arrays.
        return convert_numpy_types(obj.tolist())
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(item) for item in obj)
    return obj
def compress_image_by_dimensions(image: np.ndarray, target_width: int, target_height: int,
                                 quality: int = 100, output_format: str = 'jpg') -> tuple[bytes, dict]:
    """
    Resize an image to the given dimensions and encode it.

    :param image: input image (OpenCV array)
    :param target_width: target width in pixels, must be positive
    :param target_height: target height in pixels, must be positive
    :param quality: encoding quality; for PNG it is mapped to a 0-9
        compression level
    :param output_format: 'png' or 'webp'; any other value is encoded as JPEG
    :return: (encoded image bytes, info dict with original/compressed
        dimensions, quality, format and size in bytes)
    :raises ValueError: if a target dimension is not positive
    :raises Exception: if encoding fails; every error is logged and re-raised
    """
    try:
        # Fail with a clear message instead of an opaque OpenCV assertion.
        if target_width <= 0 or target_height <= 0:
            raise ValueError(
                f"target dimensions must be positive, got {target_width}x{target_height}"
            )

        original_height, original_width = image.shape[:2]

        # INTER_AREA is the recommended filter for shrinking; when enlarging it
        # behaves like nearest-neighbour, so use bicubic in that case.
        shrinking = target_width <= original_width and target_height <= original_height
        interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_CUBIC

        resized_image = cv2.resize(
            image, (target_width, target_height),
            interpolation=interpolation
        )

        # Encode with the requested quality.
        fmt = output_format.lower()
        if fmt == 'png':
            # PNG has no quality setting; map quality 0-100 onto level 9-0.
            compression_level = max(0, min(9, int((100 - quality) / 10)))
            success, encoded_img = cv2.imencode(
                ".png", resized_image, [cv2.IMWRITE_PNG_COMPRESSION, compression_level]
            )
        elif fmt == 'webp':
            success, encoded_img = cv2.imencode(
                ".webp", resized_image, [cv2.IMWRITE_WEBP_QUALITY, quality]
            )
        else:
            success, encoded_img = cv2.imencode(
                ".jpg", resized_image, [cv2.IMWRITE_JPEG_QUALITY, quality]
            )

        if not success:
            raise Exception("图像编码失败")

        compressed_bytes = encoded_img.tobytes()

        info = {
            'original_dimensions': f"{original_width} × {original_height}",
            'compressed_dimensions': f"{target_width} × {target_height}",
            'quality': quality,
            'format': output_format.upper(),
            'size': len(compressed_bytes)
        }

        return compressed_bytes, info

    except Exception as e:
        logger.error(f"Failed to compress image by dimensions: {e}")
        raise
- 使用多阶段二分法精确控制大小 + :param image: 输入图像 + :param target_size_kb: 目标文件大小(KB) + :param output_format: 输出格式 + :return: (压缩后的图像字节数据, 压缩信息) + """ + try: + original_height, original_width = image.shape[:2] + target_size_bytes = int(target_size_kb * 1024) + + def encode_image(img, quality): + """编码图像并返回字节数据""" + if output_format.lower() == 'png': + compression_level = max(0, min(9, int((100 - quality) / 10))) + success, encoded_img = cv2.imencode( + ".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compression_level] + ) + elif output_format.lower() == 'webp': + success, encoded_img = cv2.imencode( + ".webp", img, [cv2.IMWRITE_WEBP_QUALITY, quality] + ) + else: + success, encoded_img = cv2.imencode( + ".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, quality] + ) + + if success: + return encoded_img.tobytes() + return None + + def find_best_scale_and_quality(target_bytes): + """寻找最佳的尺寸和质量组合""" + best_result = None + best_diff = float('inf') + + # 尝试多个尺寸比例 + test_scales = [1.0, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3] + + for scale in test_scales: + # 调整图像尺寸 + if scale < 1.0: + new_width = int(original_width * scale) + new_height = int(original_height * scale) + if new_width < 50 or new_height < 50: # 避免尺寸太小 + continue + working_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA) + else: + working_image = image + new_width, new_height = original_width, original_height + + # 在这个尺寸下使用二分法寻找最佳质量 + min_q, max_q = 10, 100 + scale_best_result = None + scale_best_diff = float('inf') + + for _ in range(20): # 每个尺寸最多尝试20次质量调整 + current_quality = (min_q + max_q) // 2 + + compressed_bytes = encode_image(working_image, current_quality) + if not compressed_bytes: + break + + current_size = len(compressed_bytes) + size_diff = abs(current_size - target_bytes) + size_ratio = current_size / target_bytes + + # 如果找到精确匹配,立即返回 + if 0.99 <= size_ratio <= 1.01: # 1%误差以内 + return { + 'bytes': compressed_bytes, + 'scale': scale, + 'width': new_width, + 
'height': new_height, + 'quality': current_quality, + 'size': current_size, + 'ratio': size_ratio + } + + # 记录该尺寸下的最佳结果 + if size_diff < scale_best_diff: + scale_best_diff = size_diff + scale_best_result = { + 'bytes': compressed_bytes, + 'scale': scale, + 'width': new_width, + 'height': new_height, + 'quality': current_quality, + 'size': current_size, + 'ratio': size_ratio + } + + # 二分法调整质量 + if current_size > target_bytes: + max_q = current_quality - 1 + else: + min_q = current_quality + 1 + + if min_q >= max_q: + break + + # 更新全局最佳结果 + if scale_best_result and scale_best_diff < best_diff: + best_diff = scale_best_diff + best_result = scale_best_result + + # 如果已经找到很好的结果(5%以内),可以提前结束 + if best_result and 0.95 <= best_result['ratio'] <= 1.05: + break + + return best_result + + logger.info(f"Starting multi-stage compression, target size: {target_size_bytes} bytes ({target_size_kb}KB)") + + # 寻找最佳组合 + result = find_best_scale_and_quality(target_size_bytes) + + if result: + error_percent = abs(result['ratio'] - 1) * 100 + logger.info(f"Compression completed: scale ratio {result['scale']:.2f}, quality {result['quality']}%, " + f"size {result['size']} bytes, error {error_percent:.2f}%") + + # 不管误差多大都返回最接近的结果,只记录警告 + if error_percent > 10: + if result['ratio'] < 0.5: # 压缩过度 + suggested_size = result['size'] / 1024 + logger.warning(f"Target size {target_size_kb}KB is too small, actually compressed to {suggested_size:.1f}KB, error {error_percent:.1f}%") + elif result['ratio'] > 2.0: # 无法达到目标 + suggested_size = result['size'] / 1024 + logger.warning(f"Target size {target_size_kb}KB is too large, minimum can be compressed to {suggested_size:.1f}KB, error {error_percent:.1f}%") + else: + logger.warning(f"Cannot achieve target accuracy, error {error_percent:.1f}%, returning closest result") + + info = { + 'original_dimensions': f"{original_width} × {original_height}", + 'compressed_dimensions': f"{result['width']} × {result['height']}", + 'quality': result['quality'], + 
def save_image_with_transparency(image: np.ndarray, file_path: str) -> bool:
    """
    Save an image as a PNG with an alpha channel, then sync it to BOS.

    A 4-channel image is treated as BGRA and converted to RGBA. A 3-channel
    image is treated as BGR and given a fully opaque alpha channel. Any
    other shape is rejected.

    :param image: OpenCV image array (BGRA, or BGR without transparency)
    :param file_path: destination path; its parent directory is created
        when the path has one
    :return: True if the file was written, False on any failure
    """
    if image is None:
        logger.error("Image is empty, cannot save")
        return False

    try:
        # os.makedirs("") raises FileNotFoundError, so a bare filename (no
        # directory component) must skip directory creation.
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        if image.ndim == 3 and image.shape[2] == 4:
            # BGRA -> RGBA, the channel order PIL expects.
            rgba_image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
        elif image.ndim == 3 and image.shape[2] == 3:
            # BGR input carries no transparency: append an opaque alpha plane.
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            opaque_alpha = np.full(rgb_image.shape[:2], 255, dtype=np.uint8)
            rgba_image = np.dstack((rgb_image, opaque_alpha))
        else:
            logger.error("Image format does not support transparency saving")
            return False

        # Write the PNG through PIL.
        pil_image = Image.fromarray(rgba_image, 'RGBA')
        pil_image.save(file_path, 'PNG', optimize=True)

        file_size = os.path.getsize(file_path)
        logger.info(f"Transparent PNG image saved: {file_path}, size: {file_size/1024:.1f}KB")
        upload_file_to_bos(file_path)
        return True

    except Exception as e:
        logger.error(f"Failed to save transparent PNG image: {e}")
        return False
def search_text_vector(vector: torch.Tensor, top_k=5):
    """
    Search the vector index for the entries most similar to a query vector.

    :param vector: query embedding; a leading dimension of size 1 is squeezed
        away before the search
    :param top_k: maximum number of results to return
    :return: list of ``(image_path, score)`` tuples in the order returned by
        the index; empty when nothing matches
    :raises RuntimeError: if the vector store has not been initialised
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")

    # Tensor.numpy() raises for tensors that live on an accelerator or that
    # require grad, so detach and move to CPU first (both are no-ops for a
    # plain CPU tensor).
    np_vector = vector.detach().cpu().squeeze(0).numpy().astype('float32')
    scores, indices = index.search(np_vector[np.newaxis, :], top_k)

    if indices is None or len(indices[0]) == 0:
        return []

    # The index pads missing results with -1; keep only positions that map
    # to a stored identifier.
    return [
        (id_map[i], float(scores[0][j]))
        for j, i in enumerate(indices[0])
        if 0 <= i < len(id_map)
    ]
async def get_access_token() -> Optional[str]:
    """
    Return a WeChat stable access_token, using the shared cache when valid.

    On a cache miss the token is requested from the ``stable_token``
    endpoint and cached until 5 minutes before its reported expiry
    (``expires_in`` defaults to 7200 seconds when absent).

    :return: the access token, or None if it could not be obtained
    """
    import time

    # Serve from the cache while the token is still valid.
    if access_token_cache["token"] and time.time() < access_token_cache["expires_at"]:
        return access_token_cache["token"]

    # getStableAccessToken endpoint.
    url = "https://api.weixin.qq.com/cgi-bin/stable_token"
    data = {
        "grant_type": "client_credential",
        "appid": WECHAT_APPID,
        "secret": WECHAT_SECRET,
        "force_refresh": False,  # do not force a refresh
    }
    # Bound the request explicitly (the session default is far longer),
    # matching the 10 s limit used by the image security check.
    timeout = aiohttp.ClientTimeout(total=10)

    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=data) as response:
                if response.status == 200:
                    # content_type=None disables aiohttp's Content-Type check so a
                    # JSON body served with another MIME type is still parsed.
                    result = await response.json(content_type=None)
                    # Never write the token itself to the log.
                    if isinstance(result, dict):
                        loggable = {
                            key: ("***" if key == "access_token" else value)
                            for key, value in result.items()
                        }
                    else:
                        loggable = result
                    logger.debug(f"getStableAccessToken response: {loggable}")
                    if "access_token" in result:
                        access_token_cache["token"] = result["access_token"]
                        # Expire the cache entry 5 minutes early.
                        access_token_cache["expires_at"] = (
                            time.time() + result.get("expires_in", 7200) - 300
                        )
                        expires_time = access_token_cache["expires_at"]
                        logger.info(
                            f"成功获取 stable access_token expires_time={expires_time}"
                        )
                        return result["access_token"]
                    else:
                        logger.error(f"Failed to get stable access_token: {result}")
                else:
                    logger.error(f"Failed to request stable access_token: {response.status}")
    except Exception as e:
        logger.error(f"Exception while getting stable access_token: {str(e)}")

    return None
async def check_image_security(image_data: bytes) -> bool:
    """
    Check image content with the WeChat ``img_sec_check`` API.

    The check deliberately fails open: when the token cannot be obtained,
    the request fails, or the API returns any error other than 87014, the
    image is treated as safe so that regular users are not blocked.

    :param image_data: raw image bytes
    :return: False only when the API reports errcode 87014 (risky content),
        True otherwise
    """
    access_token = await get_access_token()
    if not access_token:
        logger.warning("Unable to get access_token, skipping security check")
        return True  # fail open when no token is available
    url = f"https://api.weixin.qq.com/wxa/img_sec_check?access_token={access_token}"
    # ClientTimeout is the supported way to express the 10 s limit; a bare
    # number for ``timeout`` is the deprecated form.
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            # The WeChat API expects multipart/form-data.
            data = aiohttp.FormData()
            # Pass the filename explicitly rather than relying on aiohttp
            # deriving one from the field name for bytes values.
            data.add_field(
                "media", image_data, filename="media", content_type="image/jpeg"
            )
            async with session.post(url, data=data) as response:
                if response.status == 200:
                    # content_type=None: parse the body as JSON whatever MIME
                    # type the server declares.
                    result = await response.json(content_type=None)
                    logger.info(f"Checking image content safety...result={result}")
                    if result.get("errcode") == 0:
                        return True  # safe
                    elif result.get("errcode") == 87014:
                        logger.warning("Image content contains illegal content...")
                        return False
                    else:
                        logger.warning(f"Image security check returned error: {result}")
                        return True  # other API errors: fail open
                else:
                    logger.warning(f"Image security check request failed: {response.status}")
                    return True
    except Exception as e:
        logger.error(f"Image security check exception: {str(e)}")
        return True  # fail open on exceptions