Spaces:
Paused
Paused
chawin.chen
commited on
Commit
·
7a6cb13
1
Parent(s):
e499f6c
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +66 -0
- Dockerfile +76 -0
- README.md +4 -4
- anime_stylizer.py +427 -0
- api_routes.py +0 -0
- app.py +170 -0
- build.sh +4 -0
- cleanup_scheduler.py +202 -0
- clip_utils.py +65 -0
- config.py +543 -0
- database.py +377 -0
- ddcolor_colorizer.py +301 -0
- debug_colorize.py +238 -0
- face_analyzer.py +1099 -0
- facial_analyzer.py +912 -0
- gfpgan_restorer.py +96 -0
- install.sh +2 -0
- models.py +69 -0
- push.sh +2 -0
- realesrgan_upscaler.py +235 -0
- rembg_processor.py +136 -0
- requirements.txt +72 -0
- rvm_processor.py +132 -0
- start_local.sh +52 -0
- test/celebrity_crawler.py +227 -0
- test/celebrity_crawler.pyc +0 -0
- test/decode_celeb_dataset.py +86 -0
- test/decode_celeb_dataset.pyc +0 -0
- test/dow_img.py +24 -0
- test/dow_img.pyc +0 -0
- test/howcuteami.py +202 -0
- test/howcuteami.pyc +0 -0
- test/import_history_images.py +162 -0
- test/import_history_images.pyc +0 -0
- test/remove_duplicate_celeb_images.py +99 -0
- test/remove_duplicate_celeb_images.pyc +0 -0
- test/remove_faceless_images.py +169 -0
- test/remove_faceless_images.pyc +0 -0
- test/test_deepface.py +38 -0
- test/test_deepface.pyc +0 -0
- test/test_main.http +11 -0
- test/test_rvm_infer.py +46 -0
- test/test_rvm_infer.pyc +0 -0
- test/test_score.py +26 -0
- test/test_score.pyc +0 -0
- test/test_score_adjustment_demo.py +30 -0
- test/test_score_adjustment_demo.pyc +0 -0
- test/test_sky.py +15 -0
- test/test_sky.pyc +0 -0
- test_tensorflow.py +19 -0
.gitignore
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
HELP.md
|
| 2 |
+
target/
|
| 3 |
+
output/
|
| 4 |
+
!.mvn/wrapper/maven-wrapper.jar
|
| 5 |
+
!**/src/main/**/target/
|
| 6 |
+
!**/src/test/**/target/
|
| 7 |
+
|
| 8 |
+
.flattened-pom.xml
|
| 9 |
+
|
| 10 |
+
### STS ###
|
| 11 |
+
.apt_generated
|
| 12 |
+
.classpath
|
| 13 |
+
.factorypath
|
| 14 |
+
.project
|
| 15 |
+
.settings
|
| 16 |
+
.springBeans
|
| 17 |
+
.sts4-cache
|
| 18 |
+
|
| 19 |
+
### IntelliJ IDEA ###
|
| 20 |
+
.idea
|
| 21 |
+
*.iws
|
| 22 |
+
*.iml
|
| 23 |
+
*.ipr
|
| 24 |
+
|
| 25 |
+
### NetBeans ###
|
| 26 |
+
/nbproject/private/
|
| 27 |
+
/nbbuild/
|
| 28 |
+
/dist/
|
| 29 |
+
/nbdist/
|
| 30 |
+
/.nb-gradle/
|
| 31 |
+
build/
|
| 32 |
+
!**/src/main/**/build/
|
| 33 |
+
!**/src/test/**/build/
|
| 34 |
+
|
| 35 |
+
### VS Code ###
|
| 36 |
+
.vscode/
|
| 37 |
+
|
| 38 |
+
### LOG ###
|
| 39 |
+
logs/
|
| 40 |
+
|
| 41 |
+
*.class
|
| 42 |
+
|
| 43 |
+
**/node_modules/
|
| 44 |
+
/*.log
|
| 45 |
+
/output/
|
| 46 |
+
/faiss/
|
| 47 |
+
/web/facelist-web/
|
| 48 |
+
**/._*
|
| 49 |
+
__pycache__/
|
| 50 |
+
.DS_Store
|
| 51 |
+
*.pth
|
| 52 |
+
/data/celebrity_faces/ds_model_arcface_detector_retinaface_aligned_normalization_base_expand_0.pkl
|
| 53 |
+
/data/celebrity_faces/jpeg_6c06eca6.jpeg
|
| 54 |
+
/data/celebrity_faces/jpeg_51e1394b.jpeg
|
| 55 |
+
/data/celebrity_faces/jpeg_66fee390.jpeg
|
| 56 |
+
/data/celebrity_faces/jpeg_70b86102.jpeg
|
| 57 |
+
/data/celebrity_faces/jpeg_406b961a.jpeg
|
| 58 |
+
/data/celebrity_faces/jpeg_1321f87f.jpeg
|
| 59 |
+
/data/celebrity_faces/jpeg_b56ae384.jpeg
|
| 60 |
+
/data/celebrity_faces/jpeg_c07cdb46.jpeg
|
| 61 |
+
/data/celebrity_faces/jpeg_c7353005.jpeg
|
| 62 |
+
/data/celebrity_faces/jpeg_d4cb0602.jpeg
|
| 63 |
+
/data/celebrity_faces/jpeg_dbb64030.jpeg
|
| 64 |
+
/data/celebrity_faces/jpeg_fc652ad4.jpeg
|
| 65 |
+
/data/celebrity_faces/jpeg_fd6b0869.jpeg
|
| 66 |
+
/data/celebrity_embeddings.db
|
Dockerfile
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
ENV TZ=Asia/Shanghai \
|
| 4 |
+
OUTPUT_DIR=/opt/data/output \
|
| 5 |
+
IMAGES_DIR=/opt/data/images \
|
| 6 |
+
MODELS_PATH=/opt/data/models \
|
| 7 |
+
DEEPFACE_HOME=/opt/data/models \
|
| 8 |
+
FAISS_INDEX_DIR=/opt/data/faiss \
|
| 9 |
+
CELEBRITY_SOURCE_DIR=/opt/data/chinese_celeb_dataset \
|
| 10 |
+
GENDER_CONFIDENCE=1 \
|
| 11 |
+
UPSCALE_SIZE=2 \
|
| 12 |
+
AGE_CONFIDENCE=0.1 \
|
| 13 |
+
DRAW_SCORE=true \
|
| 14 |
+
FACE_CONFIDENCE=0.7 \
|
| 15 |
+
ENABLE_DDCOLOR=true \
|
| 16 |
+
ENABLE_GFPGAN=true \
|
| 17 |
+
ENABLE_REALESRGAN=true \
|
| 18 |
+
ENABLE_ANIME_STYLE=true \
|
| 19 |
+
ENABLE_RVM=true \
|
| 20 |
+
ENABLE_REMBG=true \
|
| 21 |
+
ENABLE_CLIP=false \
|
| 22 |
+
CLEANUP_INTERVAL_HOURS=1 \
|
| 23 |
+
CLEANUP_AGE_HOURS=1 \
|
| 24 |
+
BEAUTY_ADJUST_GAMMA=0.8 \
|
| 25 |
+
BEAUTY_ADJUST_MIN=1.0 \
|
| 26 |
+
BEAUTY_ADJUST_MAX=9.0 \
|
| 27 |
+
ENABLE_ANIME_PRELOAD=false \
|
| 28 |
+
ENABLE_LOGGING=true \
|
| 29 |
+
BEAUTY_ADJUST_ENABLED=true \
|
| 30 |
+
RVM_LOCAL_REPO=/opt/data/models/RobustVideoMatting \
|
| 31 |
+
RVM_WEIGHTS_PATH=/opt/data/models/torch/hub/checkpoints/rvm_resnet50.pth \
|
| 32 |
+
RVM_MODEL=resnet50 \
|
| 33 |
+
AUTO_INIT_GFPGAN=false \
|
| 34 |
+
AUTO_INIT_DDCOLOR=false \
|
| 35 |
+
AUTO_INIT_REALESRGAN=false \
|
| 36 |
+
AUTO_INIT_ANIME_STYLE=false \
|
| 37 |
+
AUTO_INIT_CLIP=false \
|
| 38 |
+
AUTO_INIT_RVM=false \
|
| 39 |
+
AUTO_INIT_REMBG=false \
|
| 40 |
+
ENABLE_WARMUP=true \
|
| 41 |
+
REALESRGAN_MODEL=realesr-general-x4v3 \
|
| 42 |
+
CELEBRITY_FIND_THRESHOLD=0.87 \
|
| 43 |
+
FEMALE_AGE_ADJUSTMENT=4 \
|
| 44 |
+
HOSTNAME=HG
|
| 45 |
+
|
| 46 |
+
RUN mkdir -p /opt/data/chinese_celeb_dataset /opt/data/faiss /opt/data/models /opt/data/images /opt/data/output
|
| 47 |
+
WORKDIR /app
|
| 48 |
+
COPY requirements.txt .
|
| 49 |
+
COPY *.py /app/
|
| 50 |
+
|
| 51 |
+
# 安装必要的系统工具和依赖
|
| 52 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 53 |
+
build-essential \
|
| 54 |
+
cmake \
|
| 55 |
+
git \
|
| 56 |
+
wget \
|
| 57 |
+
curl \
|
| 58 |
+
ca-certificates \
|
| 59 |
+
libopenblas-dev \
|
| 60 |
+
liblapack-dev \
|
| 61 |
+
libx11-dev \
|
| 62 |
+
libgtk-3-dev \
|
| 63 |
+
libboost-python-dev \
|
| 64 |
+
libglib2.0-0 \
|
| 65 |
+
libsm6 \
|
| 66 |
+
libxext6 \
|
| 67 |
+
libxrender-dev \
|
| 68 |
+
libgomp1 \
|
| 69 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 70 |
+
|
| 71 |
+
RUN pip install --upgrade pip
|
| 72 |
+
# 安装所有依赖 - 现在可以一次性完成
|
| 73 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 74 |
+
EXPOSE 7860
|
| 75 |
+
CMD ["uvicorn", "app:app", "--workers", "1", "--loop", "asyncio", "--http", "httptools", "--host", "0.0.0.0", "--port", "7860", "--timeout-keep-alive", "600"]
|
| 76 |
+
|
README.md
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
---
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Picpocket
|
| 3 |
+
emoji: 🔥
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
---
|
anime_stylizer.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
|
| 7 |
+
from config import logger
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AnimeStylizer:
|
| 11 |
+
def __init__(self):
|
| 12 |
+
start_time = time.perf_counter()
|
| 13 |
+
self.stylizers = {} # 存储不同风格的模型
|
| 14 |
+
self.current_style = None
|
| 15 |
+
self.current_stylizer = None
|
| 16 |
+
|
| 17 |
+
# 检查是否启用Anime Style功能
|
| 18 |
+
from config import ENABLE_ANIME_STYLE
|
| 19 |
+
if ENABLE_ANIME_STYLE:
|
| 20 |
+
self._initialize_models()
|
| 21 |
+
else:
|
| 22 |
+
logger.info("Anime Style feature is disabled, skipping model initialization")
|
| 23 |
+
init_time = time.perf_counter() - start_time
|
| 24 |
+
if hasattr(self, 'model_configs') and len(self.model_configs) > 0:
|
| 25 |
+
logger.info(f"AnimeStylizer initialized successfully, time: {init_time:.3f}s")
|
| 26 |
+
else:
|
| 27 |
+
logger.info(f"AnimeStylizer initialization completed but not available, time: {init_time:.3f}s")
|
| 28 |
+
|
| 29 |
+
def _initialize_models(self):
|
| 30 |
+
"""初始化所有Anime Style模型(使用ModelScope)"""
|
| 31 |
+
try:
|
| 32 |
+
logger.info("Initializing multiple Anime Style models (using ModelScope)...")
|
| 33 |
+
|
| 34 |
+
# 添加torch类型兼容性补丁
|
| 35 |
+
import torch
|
| 36 |
+
if not hasattr(torch, 'uint64'):
|
| 37 |
+
logger.info("Adding torch.uint64 compatibility patch...")
|
| 38 |
+
torch.uint64 = torch.int64 # 使用int64作为uint64的替代
|
| 39 |
+
if not hasattr(torch, 'uint32'):
|
| 40 |
+
logger.info("Adding torch.uint32 compatibility patch...")
|
| 41 |
+
torch.uint32 = torch.int32 # 使用int32作为uint32的替代
|
| 42 |
+
if not hasattr(torch, 'uint16'):
|
| 43 |
+
logger.info("Adding torch.uint16 compatibility patch...")
|
| 44 |
+
torch.uint16 = torch.int16 # 使用int16作为uint16的替代
|
| 45 |
+
|
| 46 |
+
# 导入ModelScope相关模块
|
| 47 |
+
from modelscope.outputs import OutputKeys
|
| 48 |
+
from modelscope.pipelines import pipeline
|
| 49 |
+
from modelscope.utils.constant import Tasks
|
| 50 |
+
|
| 51 |
+
self.OutputKeys = OutputKeys
|
| 52 |
+
|
| 53 |
+
# 定义所有可用的模型和风格
|
| 54 |
+
self.model_configs = {
|
| 55 |
+
"handdrawn": {
|
| 56 |
+
"model_id": "iic/cv_unet_person-image-cartoon-handdrawn_compound-models",
|
| 57 |
+
"name": "手绘风格",
|
| 58 |
+
"description": "手绘动漫风格 - 传统手绘感觉,线条清晰"
|
| 59 |
+
},
|
| 60 |
+
"disney": {
|
| 61 |
+
"model_id": "iic/cv_unet_person-image-cartoon-3d_compound-models",
|
| 62 |
+
"name": "迪士尼风格",
|
| 63 |
+
"description": "迪士尼风格 - 立体感强,色彩鲜艳"
|
| 64 |
+
},
|
| 65 |
+
"illustration": {
|
| 66 |
+
"model_id": "iic/cv_unet_person-image-cartoon-sd-design_compound-models",
|
| 67 |
+
"name": "插画风格",
|
| 68 |
+
"description": "插画风格 - 现代插画设计感"
|
| 69 |
+
},
|
| 70 |
+
"artstyle": {
|
| 71 |
+
"model_id": "iic/cv_unet_person-image-cartoon-artstyle_compound-models",
|
| 72 |
+
"name": "艺术风格",
|
| 73 |
+
"description": "艺术风格 - 独特的艺术表现力"
|
| 74 |
+
},
|
| 75 |
+
"anime": {
|
| 76 |
+
"model_id": "iic/cv_unet_person-image-cartoon_compound-models",
|
| 77 |
+
"name": "二次元风格",
|
| 78 |
+
"description": "二次元风格 - 经典动漫角色风格"
|
| 79 |
+
},
|
| 80 |
+
"sketch": {
|
| 81 |
+
"model_id": "iic/cv_unet_person-image-cartoon-sketch_compound-models",
|
| 82 |
+
"name": "素描风格",
|
| 83 |
+
"description": "素描风格 - 黑白素描画效果"
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
logger.info(f"Defined {len(self.model_configs)} anime style model configurations")
|
| 88 |
+
logger.info("Models will be loaded on-demand when first used to save memory")
|
| 89 |
+
|
| 90 |
+
# 检查是否启用预加载
|
| 91 |
+
try:
|
| 92 |
+
from config import ENABLE_ANIME_PRELOAD
|
| 93 |
+
if ENABLE_ANIME_PRELOAD:
|
| 94 |
+
logger.info("Enabling anime style model preloading...")
|
| 95 |
+
self.preload_models()
|
| 96 |
+
else:
|
| 97 |
+
logger.info("Anime style model preloading is disabled, will be loaded on-demand when first used")
|
| 98 |
+
except ImportError:
|
| 99 |
+
logger.info("Anime style model preloading configuration not found, will be loaded on-demand when first used")
|
| 100 |
+
|
| 101 |
+
except ImportError as e:
|
| 102 |
+
logger.error(f"ModelScope module import failed: {e}")
|
| 103 |
+
self.model_configs = {}
|
| 104 |
+
except Exception as e:
|
| 105 |
+
logger.error(f"Anime Style model initialization failed: {e}")
|
| 106 |
+
self.model_configs = {}
|
| 107 |
+
|
| 108 |
+
def _load_model(self, style_type):
|
| 109 |
+
"""按需加载指定风格的模型"""
|
| 110 |
+
if style_type not in self.model_configs:
|
| 111 |
+
logger.error(f"Unsupported style type: {style_type}")
|
| 112 |
+
return False
|
| 113 |
+
|
| 114 |
+
if style_type in self.stylizers:
|
| 115 |
+
logger.info(f"Model {style_type} already loaded, using directly")
|
| 116 |
+
return True
|
| 117 |
+
|
| 118 |
+
try:
|
| 119 |
+
from modelscope.pipelines import pipeline
|
| 120 |
+
from modelscope.utils.constant import Tasks
|
| 121 |
+
|
| 122 |
+
config = self.model_configs[style_type]
|
| 123 |
+
logger.info(f"Loading {config['name']} model: {config['model_id']}")
|
| 124 |
+
|
| 125 |
+
# 根据模型类型选择合适的任务类型
|
| 126 |
+
if "stable_diffusion" in config["model_id"]:
|
| 127 |
+
# Stable Diffusion 系列模型使用文生图任务类型
|
| 128 |
+
task_type = Tasks.text_to_image_synthesis
|
| 129 |
+
logger.info(f"Using text_to_image_synthesis task type to load Stable Diffusion model")
|
| 130 |
+
else:
|
| 131 |
+
# UNet 系列模型使用人像风格化任务
|
| 132 |
+
task_type = Tasks.image_portrait_stylization
|
| 133 |
+
logger.info(f"Using image_portrait_stylization task type to load UNet model")
|
| 134 |
+
|
| 135 |
+
stylizer = pipeline(task_type, model=config["model_id"])
|
| 136 |
+
self.stylizers[style_type] = stylizer
|
| 137 |
+
|
| 138 |
+
logger.info(f"{config['name']} model loaded successfully")
|
| 139 |
+
return True
|
| 140 |
+
|
| 141 |
+
except Exception as e:
|
| 142 |
+
logger.error(f"Failed to load {style_type} model: {e}")
|
| 143 |
+
return False
|
| 144 |
+
|
| 145 |
+
def preload_models(self, style_types=None):
|
| 146 |
+
"""
|
| 147 |
+
预加载指定的动漫风格模型
|
| 148 |
+
:param style_types: 要预加载的风格类型列表,如果为None则预加载所有模型
|
| 149 |
+
"""
|
| 150 |
+
if not self.is_available():
|
| 151 |
+
logger.warning("Anime Style module is not available, cannot preload models")
|
| 152 |
+
return
|
| 153 |
+
|
| 154 |
+
if style_types is None:
|
| 155 |
+
style_types = list(self.model_configs.keys())
|
| 156 |
+
elif isinstance(style_types, str):
|
| 157 |
+
style_types = [style_types]
|
| 158 |
+
|
| 159 |
+
logger.info(f"Starting to preload anime style models: {style_types}")
|
| 160 |
+
|
| 161 |
+
successful_loads = []
|
| 162 |
+
failed_loads = []
|
| 163 |
+
|
| 164 |
+
for style_type in style_types:
|
| 165 |
+
if style_type not in self.model_configs:
|
| 166 |
+
logger.warning(f"Unknown style type: {style_type}, skipping preload")
|
| 167 |
+
failed_loads.append(style_type)
|
| 168 |
+
continue
|
| 169 |
+
|
| 170 |
+
try:
|
| 171 |
+
logger.info(f"Preloading model: {self.model_configs[style_type]['name']} ({style_type})")
|
| 172 |
+
if self._load_model(style_type):
|
| 173 |
+
successful_loads.append(style_type)
|
| 174 |
+
logger.info(f"✓ Successfully preloaded: {self.model_configs[style_type]['name']}")
|
| 175 |
+
else:
|
| 176 |
+
failed_loads.append(style_type)
|
| 177 |
+
logger.error(f"✗ Preload failed: {self.model_configs[style_type]['name']}")
|
| 178 |
+
except Exception as e:
|
| 179 |
+
logger.error(f"✗ Exception occurred while preloading model {style_type}: {e}")
|
| 180 |
+
failed_loads.append(style_type)
|
| 181 |
+
|
| 182 |
+
if successful_loads:
|
| 183 |
+
logger.info(f"Successfully preloaded models ({len(successful_loads)}): {successful_loads}")
|
| 184 |
+
if failed_loads:
|
| 185 |
+
logger.warning(f"Failed to preload models ({len(failed_loads)}): {failed_loads}")
|
| 186 |
+
|
| 187 |
+
logger.info(f"Anime style model preloading completed, success: {len(successful_loads)}/{len(style_types)}")
|
| 188 |
+
|
| 189 |
+
def get_loaded_models(self):
|
| 190 |
+
"""
|
| 191 |
+
获取已加载的模型列表
|
| 192 |
+
:return: 已加载的模型风格类型列表
|
| 193 |
+
"""
|
| 194 |
+
return list(self.stylizers.keys())
|
| 195 |
+
|
| 196 |
+
def is_model_loaded(self, style_type):
|
| 197 |
+
"""
|
| 198 |
+
检查指定风格的模型是否已加载
|
| 199 |
+
:param style_type: 风格类型
|
| 200 |
+
:return: 是否已加载
|
| 201 |
+
"""
|
| 202 |
+
return style_type in self.stylizers
|
| 203 |
+
|
| 204 |
+
def get_preload_status(self):
|
| 205 |
+
"""
|
| 206 |
+
获取模型预加载状态
|
| 207 |
+
:return: 包含预加载状态的字典
|
| 208 |
+
"""
|
| 209 |
+
total_models = len(self.model_configs)
|
| 210 |
+
loaded_models = len(self.stylizers)
|
| 211 |
+
|
| 212 |
+
status = {
|
| 213 |
+
"total_models": total_models,
|
| 214 |
+
"loaded_models": loaded_models,
|
| 215 |
+
"preload_ratio": f"{loaded_models}/{total_models}",
|
| 216 |
+
"preload_percentage": round((loaded_models / total_models * 100) if total_models > 0 else 0, 1),
|
| 217 |
+
"available_styles": list(self.model_configs.keys()),
|
| 218 |
+
"loaded_styles": list(self.stylizers.keys()),
|
| 219 |
+
"unloaded_styles": [style for style in self.model_configs.keys() if style not in self.stylizers]
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
return status
|
| 223 |
+
|
| 224 |
+
def is_available(self):
|
| 225 |
+
"""检查Anime Stylizer是否可用"""
|
| 226 |
+
return hasattr(self, 'model_configs') and len(self.model_configs) > 0
|
| 227 |
+
|
| 228 |
+
def stylize_image(self, image, style_type="disney"):
|
| 229 |
+
"""
|
| 230 |
+
对图像进行动漫风格化
|
| 231 |
+
:param image: 输入图像 (numpy array, BGR格式)
|
| 232 |
+
:param style_type: 动漫风格类型,支持的类型:
|
| 233 |
+
"handdrawn" - 手绘风格
|
| 234 |
+
"disney" - 迪士尼风格 (默认)
|
| 235 |
+
"illustration" - 插画风格
|
| 236 |
+
"flat" - 扁平风格
|
| 237 |
+
"clipart" - 剪贴画风格
|
| 238 |
+
"watercolor" - 水彩风格
|
| 239 |
+
"artstyle" - 艺术风格
|
| 240 |
+
"anime" - 二次元风格
|
| 241 |
+
"sketch" - 素描风格
|
| 242 |
+
:return: 动漫风格化后的图像 (numpy array, BGR格式)
|
| 243 |
+
"""
|
| 244 |
+
if not self.is_available():
|
| 245 |
+
logger.error("Anime Style model not initialized")
|
| 246 |
+
return image
|
| 247 |
+
|
| 248 |
+
# 加载指定风格的模型
|
| 249 |
+
if not self._load_model(style_type):
|
| 250 |
+
logger.error(f"Failed to load {style_type} model")
|
| 251 |
+
return image
|
| 252 |
+
|
| 253 |
+
return self._stylize_image_via_file(image, style_type)
|
| 254 |
+
|
| 255 |
+
def _stylize_image_via_file(self, image, style_type="disney"):
|
| 256 |
+
"""
|
| 257 |
+
通过临时文件进行动漫风格化
|
| 258 |
+
:param image: 输入图像 (numpy array, BGR格式)
|
| 259 |
+
:param style_type: 动漫风格类型
|
| 260 |
+
:return: 动漫风格化后的图像 (numpy array, BGR格式)
|
| 261 |
+
"""
|
| 262 |
+
try:
|
| 263 |
+
config = self.model_configs.get(style_type, {})
|
| 264 |
+
style_name = config.get('name', style_type)
|
| 265 |
+
logger.info(f"Using anime stylization processing, style type: {style_name} ({style_type})")
|
| 266 |
+
|
| 267 |
+
# 验证风格类型
|
| 268 |
+
if style_type not in self.model_configs:
|
| 269 |
+
logger.warning(f"Invalid style type: {style_type}, using default style disney")
|
| 270 |
+
style_type = "disney"
|
| 271 |
+
|
| 272 |
+
# 使用最高质量设置保存临时图像
|
| 273 |
+
with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as tmp_input:
|
| 274 |
+
# 使用WebP格式,最高质量设置
|
| 275 |
+
cv2.imwrite(tmp_input.name, image, [cv2.IMWRITE_WEBP_QUALITY, 100])
|
| 276 |
+
tmp_input_path = tmp_input.name
|
| 277 |
+
|
| 278 |
+
try:
|
| 279 |
+
logger.info(f"Temporary file saved to: {tmp_input_path}")
|
| 280 |
+
|
| 281 |
+
# 使用ModelScope进行动漫风格化
|
| 282 |
+
stylizer = self.stylizers[style_type]
|
| 283 |
+
|
| 284 |
+
# 根据模型类型使用不同的调用方式
|
| 285 |
+
if "stable_diffusion" in config["model_id"]:
|
| 286 |
+
# Stable Diffusion模型需要特殊处理
|
| 287 |
+
logger.info("Using Stable Diffusion model, text parameter is required")
|
| 288 |
+
# 对于Stable Diffusion,必须使用'sks style'格式的提示词
|
| 289 |
+
style_prompts = {}
|
| 290 |
+
prompt = style_prompts.get(style_type, "sks style, cartoon style artwork")
|
| 291 |
+
logger.info(f"Using prompt: {prompt}")
|
| 292 |
+
result = stylizer({"text": prompt})
|
| 293 |
+
else:
|
| 294 |
+
# UNet模型直接处理
|
| 295 |
+
result = stylizer(tmp_input_path)
|
| 296 |
+
|
| 297 |
+
# 获取风格化后的图像
|
| 298 |
+
# 不同模型的输出键名可能不同,需要适配
|
| 299 |
+
if "stable_diffusion" in config["model_id"]:
|
| 300 |
+
# Stable Diffusion模型通常使用不同的输出键名
|
| 301 |
+
logger.info(f"Stable Diffusion model output keys: {list(result.keys())}")
|
| 302 |
+
if 'output_imgs' in result:
|
| 303 |
+
stylized_image = result['output_imgs'][0]
|
| 304 |
+
elif 'output_img' in result:
|
| 305 |
+
stylized_image = result['output_img']
|
| 306 |
+
elif self.OutputKeys.OUTPUT_IMG in result:
|
| 307 |
+
stylized_image = result[self.OutputKeys.OUTPUT_IMG]
|
| 308 |
+
else:
|
| 309 |
+
# 尝试获取第一个图像输出
|
| 310 |
+
for key in result.keys():
|
| 311 |
+
if isinstance(result[key], (list, tuple)) and len(result[key]) > 0:
|
| 312 |
+
stylized_image = result[key][0]
|
| 313 |
+
logger.info(f"Using output key: {key}")
|
| 314 |
+
break
|
| 315 |
+
elif hasattr(result[key], 'shape'):
|
| 316 |
+
stylized_image = result[key]
|
| 317 |
+
logger.info(f"Using output key: {key}")
|
| 318 |
+
break
|
| 319 |
+
else:
|
| 320 |
+
raise KeyError(f"未找到有效的图像输出键,可用键: {list(result.keys())}")
|
| 321 |
+
else:
|
| 322 |
+
# UNet模型使用标准输出键
|
| 323 |
+
stylized_image = result[self.OutputKeys.OUTPUT_IMG]
|
| 324 |
+
|
| 325 |
+
logger.info(f"Anime stylization output: size={stylized_image.shape}, type={stylized_image.dtype}")
|
| 326 |
+
|
| 327 |
+
# ModelScope输出的图像已经是BGR格式,不需要转换
|
| 328 |
+
logger.info("Anime stylization processing completed")
|
| 329 |
+
return stylized_image
|
| 330 |
+
|
| 331 |
+
finally:
|
| 332 |
+
# 清理临时文件
|
| 333 |
+
try:
|
| 334 |
+
os.unlink(tmp_input_path)
|
| 335 |
+
except:
|
| 336 |
+
pass
|
| 337 |
+
|
| 338 |
+
except Exception as e:
|
| 339 |
+
logger.error(f"Anime stylization processing failed: {e}")
|
| 340 |
+
logger.info("Returning original image")
|
| 341 |
+
return image
|
| 342 |
+
|
| 343 |
+
def get_available_styles(self):
|
| 344 |
+
"""
|
| 345 |
+
获取支持的动漫风格类型
|
| 346 |
+
:return: 字典,包含风格代码和描述
|
| 347 |
+
"""
|
| 348 |
+
if not hasattr(self, 'model_configs'):
|
| 349 |
+
return {}
|
| 350 |
+
|
| 351 |
+
return {
|
| 352 |
+
style_type: f"{config['name']} - {config['description'].split(' - ')[1]}"
|
| 353 |
+
for style_type, config in self.model_configs.items()
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
def save_debug_image(self, image, filename_prefix):
|
| 357 |
+
"""保存调试用的图像"""
|
| 358 |
+
try:
|
| 359 |
+
debug_path = f"{filename_prefix}_debug.webp"
|
| 360 |
+
cv2.imwrite(debug_path, image, [cv2.IMWRITE_WEBP_QUALITY, 95])
|
| 361 |
+
logger.info(f"Debug image saved: {debug_path}")
|
| 362 |
+
return debug_path
|
| 363 |
+
except Exception as e:
|
| 364 |
+
logger.error(f"Failed to save debug image: {e}")
|
| 365 |
+
return None
|
| 366 |
+
|
| 367 |
+
def test_stylization(self, test_url=None):
|
| 368 |
+
"""
|
| 369 |
+
测试动漫风格化功能
|
| 370 |
+
:param test_url: 测试图像URL,默认使用官方示例
|
| 371 |
+
:return: 测试结果
|
| 372 |
+
"""
|
| 373 |
+
if not self.is_available():
|
| 374 |
+
return False, "Anime Style模型未初始化"
|
| 375 |
+
|
| 376 |
+
try:
|
| 377 |
+
test_url = test_url or 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/portrait.jpg'
|
| 378 |
+
logger.info(f"Testing anime stylization feature, using image: {test_url}")
|
| 379 |
+
|
| 380 |
+
# 测试默认风格
|
| 381 |
+
result = self.stylizer(test_url)
|
| 382 |
+
stylized_img = result[self.OutputKeys.OUTPUT_IMG]
|
| 383 |
+
|
| 384 |
+
# 保存测试结果
|
| 385 |
+
test_output_path = 'anime_style_test_result.webp'
|
| 386 |
+
cv2.imwrite(test_output_path, stylized_img, [cv2.IMWRITE_WEBP_QUALITY, 95])
|
| 387 |
+
|
| 388 |
+
logger.info(f"Anime stylization test successful, result saved to: {test_output_path}")
|
| 389 |
+
return True, f"测试成功,结果保存到: {test_output_path}"
|
| 390 |
+
|
| 391 |
+
except Exception as e:
|
| 392 |
+
logger.error(f"Anime stylization test failed: {e}")
|
| 393 |
+
return False, f"测试失败: {e}"
|
| 394 |
+
|
| 395 |
+
def test_local_image(self, image_path, style_type="disney"):
|
| 396 |
+
"""
|
| 397 |
+
测试本地图像动漫风格化
|
| 398 |
+
:param image_path: 本地图像路径
|
| 399 |
+
:param style_type: 动漫风格类型
|
| 400 |
+
:return: 测试结果
|
| 401 |
+
"""
|
| 402 |
+
if not self.is_available():
|
| 403 |
+
return False, "Anime Style模型未初始化"
|
| 404 |
+
|
| 405 |
+
try:
|
| 406 |
+
logger.info(f"Testing local image anime stylization: {image_path}, style: {style_type}")
|
| 407 |
+
|
| 408 |
+
# 读取本地图像
|
| 409 |
+
image = cv2.imread(image_path)
|
| 410 |
+
if image is None:
|
| 411 |
+
return False, f"Unable to read image: {image_path}"
|
| 412 |
+
|
| 413 |
+
# 保存原图用于对比
|
| 414 |
+
self.save_debug_image(image, "original")
|
| 415 |
+
|
| 416 |
+
# 动漫风格化处理
|
| 417 |
+
stylized_image = self.stylize_image(image, style_type)
|
| 418 |
+
|
| 419 |
+
# 保存风格化结果
|
| 420 |
+
result_path = self.save_debug_image(stylized_image, f"anime_style_{style_type}")
|
| 421 |
+
|
| 422 |
+
logger.info(f"Local image anime stylization successful, result saved to: {result_path}")
|
| 423 |
+
return True, f"本地图像动漫风格化成功,结果保存到: {result_path}"
|
| 424 |
+
|
| 425 |
+
except Exception as e:
|
| 426 |
+
logger.error(f"Local image anime stylization failed: {e}")
|
| 427 |
+
return False, f"本地图像动漫风格化失败: {e}"
|
api_routes.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from contextlib import asynccontextmanager
|
| 4 |
+
|
| 5 |
+
from fastapi import FastAPI
|
| 6 |
+
from starlette.middleware.cors import CORSMiddleware
|
| 7 |
+
|
| 8 |
+
from cleanup_scheduler import start_cleanup_scheduler, stop_cleanup_scheduler
|
| 9 |
+
from config import (
|
| 10 |
+
logger,
|
| 11 |
+
OUTPUT_DIR,
|
| 12 |
+
DEEPFACE_AVAILABLE,
|
| 13 |
+
DLIB_AVAILABLE,
|
| 14 |
+
MODELS_PATH,
|
| 15 |
+
IMAGES_DIR,
|
| 16 |
+
YOLO_AVAILABLE,
|
| 17 |
+
ENABLE_LOGGING,
|
| 18 |
+
HUGGINGFACE_SYNC_ENABLED,
|
| 19 |
+
)
|
| 20 |
+
from database import close_mysql_pool, init_mysql_pool
|
| 21 |
+
from utils import ensure_bos_resources, ensure_huggingface_models
|
| 22 |
+
|
| 23 |
+
logger.info("Starting to import api_routes module...")
|
| 24 |
+
|
| 25 |
+
if HUGGINGFACE_SYNC_ENABLED:
|
| 26 |
+
try:
|
| 27 |
+
t_hf_start = time.perf_counter()
|
| 28 |
+
if not ensure_huggingface_models():
|
| 29 |
+
raise RuntimeError("无法从 HuggingFace 同步模型,请检查配置与网络")
|
| 30 |
+
hf_time = time.perf_counter() - t_hf_start
|
| 31 |
+
logger.info("HuggingFace 模型同步完成,用时 %.3fs", hf_time)
|
| 32 |
+
except Exception as exc:
|
| 33 |
+
logger.error(f"HuggingFace model preparation failed: {exc}")
|
| 34 |
+
raise
|
| 35 |
+
else:
|
| 36 |
+
logger.info("已关闭 HuggingFace 模型同步开关,跳过启动阶段的同步步骤")
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
t_bos_start = time.perf_counter()
|
| 40 |
+
if not ensure_bos_resources():
|
| 41 |
+
raise RuntimeError("无法从 BOS 同步模型与数据,请检查凭证与网络")
|
| 42 |
+
bos_time = time.perf_counter() - t_bos_start
|
| 43 |
+
logger.info(f"BOS resources synchronized successfully, time: {bos_time:.3f}s")
|
| 44 |
+
except Exception as exc:
|
| 45 |
+
logger.error(f"BOS resource preparation failed: {exc}")
|
| 46 |
+
raise
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
t_start = time.perf_counter()
|
| 50 |
+
from api_routes import api_router, extract_chinese_celeb_dataset_sync
|
| 51 |
+
import_time = time.perf_counter() - t_start
|
| 52 |
+
logger.info(f"api_routes module imported successfully, time: {import_time:.3f}s")
|
| 53 |
+
except Exception as e:
|
| 54 |
+
import_time = time.perf_counter() - t_start
|
| 55 |
+
logger.error(f"api_routes module import failed, time: {import_time:.3f}s, error: {e}")
|
| 56 |
+
raise
|
| 57 |
+
|
| 58 |
+
try:
|
| 59 |
+
t_extract_start = time.perf_counter()
|
| 60 |
+
extract_result = extract_chinese_celeb_dataset_sync()
|
| 61 |
+
extract_time = time.perf_counter() - t_extract_start
|
| 62 |
+
logger.info(
|
| 63 |
+
"Chinese celeb dataset extracted successfully, time: %.3fs, target: %s",
|
| 64 |
+
extract_time,
|
| 65 |
+
extract_result.get("target_dir"),
|
| 66 |
+
)
|
| 67 |
+
except Exception as exc:
|
| 68 |
+
logger.error(f"Failed to extract Chinese celeb dataset automatically: {exc}")
|
| 69 |
+
raise
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@asynccontextmanager
|
| 73 |
+
async def lifespan(app: FastAPI):
|
| 74 |
+
start_time = time.perf_counter()
|
| 75 |
+
logger.info("FaceScore service starting...")
|
| 76 |
+
logger.info(f"Output directory: {OUTPUT_DIR}")
|
| 77 |
+
logger.info(f"DeepFace available: {DEEPFACE_AVAILABLE}")
|
| 78 |
+
logger.info(f"YOLO available: {YOLO_AVAILABLE}")
|
| 79 |
+
logger.info(f"MediaPipe available: {DLIB_AVAILABLE}")
|
| 80 |
+
logger.info(f"Archive directory: {IMAGES_DIR}")
|
| 81 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 82 |
+
|
| 83 |
+
# 初始化数据库连接池
|
| 84 |
+
try:
|
| 85 |
+
await init_mysql_pool()
|
| 86 |
+
logger.info("MySQL 连接池初始化完成")
|
| 87 |
+
except Exception as exc:
|
| 88 |
+
logger.error(f"初始化 MySQL 连接池失败: {exc}")
|
| 89 |
+
raise
|
| 90 |
+
|
| 91 |
+
# 启动图片清理定时任务
|
| 92 |
+
logger.info("Starting image cleanup scheduled task...")
|
| 93 |
+
try:
|
| 94 |
+
start_cleanup_scheduler()
|
| 95 |
+
logger.info("Image cleanup scheduled task started successfully")
|
| 96 |
+
except Exception as e:
|
| 97 |
+
logger.error(f"Failed to start image cleanup scheduled task: {e}")
|
| 98 |
+
|
| 99 |
+
# 记录启动完成时间
|
| 100 |
+
total_startup_time = time.perf_counter() - start_time
|
| 101 |
+
logger.info(f"FaceScore service startup completed, total time: {total_startup_time:.3f}s")
|
| 102 |
+
|
| 103 |
+
yield
|
| 104 |
+
|
| 105 |
+
# 应用关闭时停止定时任务
|
| 106 |
+
logger.info("Stopping image cleanup scheduled task...")
|
| 107 |
+
try:
|
| 108 |
+
stop_cleanup_scheduler()
|
| 109 |
+
logger.info("Image cleanup scheduled task stopped")
|
| 110 |
+
except Exception as e:
|
| 111 |
+
logger.error(f"Failed to stop image cleanup scheduled task: {e}")
|
| 112 |
+
|
| 113 |
+
# 关闭数据库连接池
|
| 114 |
+
try:
|
| 115 |
+
await close_mysql_pool()
|
| 116 |
+
except Exception as exc:
|
| 117 |
+
logger.warning(f"关闭 MySQL 连接池失败: {exc}")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# 创建 FastAPI 应用
|
| 121 |
+
app = FastAPI(
|
| 122 |
+
title="Enhanced FaceScore 服务",
|
| 123 |
+
description="支持多模型的人脸分析REST API服务,包含五官评分功能。支持混合模式:HowCuteAmI(颜值+性别)+ DeepFace(年龄+情绪)",
|
| 124 |
+
version="3.0.0",
|
| 125 |
+
docs_url="/cp_docs",
|
| 126 |
+
redoc_url="/cp_redoc",
|
| 127 |
+
lifespan=lifespan,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
app.add_middleware(
|
| 131 |
+
CORSMiddleware,
|
| 132 |
+
allow_origins=["*"],
|
| 133 |
+
allow_methods=["*"],
|
| 134 |
+
allow_headers=["*"],
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# 注册路由
|
| 138 |
+
app.include_router(api_router)
|
| 139 |
+
|
| 140 |
+
# 添加根路径处理
|
| 141 |
+
@app.get("/")
|
| 142 |
+
async def root():
|
| 143 |
+
return "UP"
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
if __name__ == "__main__":
|
| 147 |
+
import uvicorn
|
| 148 |
+
|
| 149 |
+
if not os.path.exists(MODELS_PATH):
|
| 150 |
+
logger.critical(
|
| 151 |
+
"Warning: 'models' directory not found. Please ensure it exists and contains model files."
|
| 152 |
+
)
|
| 153 |
+
logger.critical(
|
| 154 |
+
"Exiting application as FaceAnalyzer cannot be initialized without models."
|
| 155 |
+
)
|
| 156 |
+
exit(1)
|
| 157 |
+
|
| 158 |
+
# 根据日志开关配置 Uvicorn 日志
|
| 159 |
+
if ENABLE_LOGGING:
|
| 160 |
+
uvicorn.run(app, host="0.0.0.0", port=8080, reload=False)
|
| 161 |
+
else:
|
| 162 |
+
# 禁用 Uvicorn 的访问日志和错误日志
|
| 163 |
+
uvicorn.run(
|
| 164 |
+
app,
|
| 165 |
+
host="0.0.0.0",
|
| 166 |
+
port=8080,
|
| 167 |
+
reload=False,
|
| 168 |
+
access_log=False, # 禁用访问日志
|
| 169 |
+
log_level="critical" # 只显示严重错误
|
| 170 |
+
)
|
build.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python -m compileall -q -f -b .
|
| 2 |
+
mv *.pyc /opt/data/app/
|
| 3 |
+
cp gfpgan_restorer.py /opt/data/app/
|
| 4 |
+
cp start_local.sh /opt/data/app/
|
cleanup_scheduler.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
定时清理图片文件模块
|
| 3 |
+
每小时检查一次IMAGES_DIR目录,删除1小时以前的图片文件
|
| 4 |
+
"""
|
| 5 |
+
import glob
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
|
| 10 |
+
from apscheduler.schedulers.background import BackgroundScheduler
|
| 11 |
+
|
| 12 |
+
from config import logger, IMAGES_DIR, CLEANUP_INTERVAL_HOURS, CLEANUP_AGE_HOURS
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# from utils import delete_file_from_bos # 暂时注释掉删除BOS文件的功能
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ImageCleanupScheduler:
|
| 19 |
+
"""图片清理定时任务类"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, images_dir=None, cleanup_hours=None, interval_hours=None):
|
| 22 |
+
"""
|
| 23 |
+
初始化清理调度器
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
images_dir (str): 图片目录路径,默认使用config中的IMAGES_DIR
|
| 27 |
+
cleanup_hours (float): 清理时间阈值(小时),默认使用环境变量CLEANUP_AGE_HOURS
|
| 28 |
+
interval_hours (float): 定时任务执行间隔(小时),默认使用环境变量CLEANUP_INTERVAL_HOURS
|
| 29 |
+
"""
|
| 30 |
+
self.images_dir = images_dir or IMAGES_DIR
|
| 31 |
+
self.cleanup_hours = cleanup_hours if cleanup_hours is not None else CLEANUP_AGE_HOURS
|
| 32 |
+
self.interval_hours = interval_hours if interval_hours is not None else CLEANUP_INTERVAL_HOURS
|
| 33 |
+
self.scheduler = BackgroundScheduler()
|
| 34 |
+
self.is_running = False
|
| 35 |
+
|
| 36 |
+
# 确保目录存在
|
| 37 |
+
os.makedirs(self.images_dir, exist_ok=True)
|
| 38 |
+
logger.info(f"Image cleanup scheduler initialized, monitoring directory: {self.images_dir}, cleanup threshold: {self.cleanup_hours} hours, execution interval: {self.interval_hours} hours")
|
| 39 |
+
|
| 40 |
+
def cleanup_old_images(self):
|
| 41 |
+
"""
|
| 42 |
+
清理过期的图片文件
|
| 43 |
+
删除超过指定时间的图片文件
|
| 44 |
+
"""
|
| 45 |
+
try:
|
| 46 |
+
current_time = time.time()
|
| 47 |
+
cutoff_time = current_time - (self.cleanup_hours * 3600) # 转换为秒
|
| 48 |
+
cutoff_datetime = datetime.fromtimestamp(cutoff_time)
|
| 49 |
+
|
| 50 |
+
# 支持的图片格式
|
| 51 |
+
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.gif', '*.bmp']
|
| 52 |
+
deleted_files = []
|
| 53 |
+
total_size_deleted = 0
|
| 54 |
+
|
| 55 |
+
logger.info(f"Starting to clean image directory: {self.images_dir}")
|
| 56 |
+
logger.info(f"Cleanup threshold time: {cutoff_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
|
| 57 |
+
|
| 58 |
+
# 遍历所有图片文件
|
| 59 |
+
for extension in image_extensions:
|
| 60 |
+
pattern = os.path.join(self.images_dir, extension)
|
| 61 |
+
for file_path in glob.glob(pattern):
|
| 62 |
+
try:
|
| 63 |
+
# 获取文件修改时间
|
| 64 |
+
file_mtime = os.path.getmtime(file_path)
|
| 65 |
+
|
| 66 |
+
# 如果文件时间早于阈值时间,则删除
|
| 67 |
+
if file_mtime < cutoff_time:
|
| 68 |
+
file_size = os.path.getsize(file_path)
|
| 69 |
+
file_time = datetime.fromtimestamp(file_mtime)
|
| 70 |
+
|
| 71 |
+
# 删除文件
|
| 72 |
+
os.remove(file_path)
|
| 73 |
+
# delete_file_from_bos(file_path) # 暂时注释掉删除BOS文件
|
| 74 |
+
deleted_files.append(os.path.basename(file_path))
|
| 75 |
+
total_size_deleted += file_size
|
| 76 |
+
|
| 77 |
+
logger.info(f"Deleting expired file: {os.path.basename(file_path)} ")
|
| 78 |
+
|
| 79 |
+
except (OSError, IOError) as e:
|
| 80 |
+
logger.error(f"Failed to delete file {os.path.basename(file_path)}: {e}")
|
| 81 |
+
continue
|
| 82 |
+
|
| 83 |
+
logger.info(f"Cleanup completed! Deleted {len(deleted_files)} files, ")
|
| 84 |
+
logger.info(f"Deleted file list: {', '.join(deleted_files[:10])}")
|
| 85 |
+
else:
|
| 86 |
+
logger.info("Cleanup completed! No expired files found to clean")
|
| 87 |
+
|
| 88 |
+
return {
|
| 89 |
+
'success': True,
|
| 90 |
+
'deleted_count': len(deleted_files),
|
| 91 |
+
'deleted_size': total_size_deleted,
|
| 92 |
+
'deleted_files': deleted_files,
|
| 93 |
+
'cutoff_time': cutoff_datetime.isoformat()
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
except Exception as e:
|
| 97 |
+
error_msg = f"图片清理任务执行失败: {e}"
|
| 98 |
+
logger.error(error_msg)
|
| 99 |
+
return {
|
| 100 |
+
'success': False,
|
| 101 |
+
'error': str(e),
|
| 102 |
+
'deleted_count': 0,
|
| 103 |
+
'deleted_size': 0
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
def _format_size(self, size_bytes):
|
| 107 |
+
"""格式化文件大小显示"""
|
| 108 |
+
if size_bytes == 0:
|
| 109 |
+
return "0 B"
|
| 110 |
+
size_names = ["B", "KB", "MB", "GB"]
|
| 111 |
+
i = 0
|
| 112 |
+
while size_bytes >= 1024 and i < len(size_names) - 1:
|
| 113 |
+
size_bytes /= 1024.0
|
| 114 |
+
i += 1
|
| 115 |
+
return f"{size_bytes:.1f} {size_names[i]}"
|
| 116 |
+
|
| 117 |
+
def start(self):
|
| 118 |
+
"""启动定时清理任务"""
|
| 119 |
+
if self.is_running:
|
| 120 |
+
logger.warning("Image cleanup scheduler is already running")
|
| 121 |
+
return
|
| 122 |
+
|
| 123 |
+
try:
|
| 124 |
+
# 添加定时任务:使用可配置的执���间隔
|
| 125 |
+
self.scheduler.add_job(
|
| 126 |
+
func=self.cleanup_old_images,
|
| 127 |
+
trigger='interval',
|
| 128 |
+
hours=self.interval_hours, # 使用环境变量配置的执行间隔
|
| 129 |
+
id='image_cleanup',
|
| 130 |
+
name='image clean tast',
|
| 131 |
+
replace_existing=True
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# 启动调度器
|
| 135 |
+
self.scheduler.start()
|
| 136 |
+
self.is_running = True
|
| 137 |
+
|
| 138 |
+
logger.info(f"Image cleanup scheduler started, will execute cleanup task every {self.interval_hours} hours")
|
| 139 |
+
|
| 140 |
+
# 立即执行一次清理(可选)
|
| 141 |
+
logger.info("Executing image cleanup task immediately...")
|
| 142 |
+
self.cleanup_old_images()
|
| 143 |
+
|
| 144 |
+
except Exception as e:
|
| 145 |
+
logger.error(f"Failed to start image cleanup scheduler: {e}")
|
| 146 |
+
raise
|
| 147 |
+
|
| 148 |
+
def stop(self):
|
| 149 |
+
"""停止定时清理任务"""
|
| 150 |
+
if not self.is_running:
|
| 151 |
+
logger.warning("Image cleanup scheduler is not running")
|
| 152 |
+
return
|
| 153 |
+
|
| 154 |
+
try:
|
| 155 |
+
self.scheduler.shutdown(wait=False)
|
| 156 |
+
self.is_running = False
|
| 157 |
+
logger.info("Image cleanup scheduler stopped")
|
| 158 |
+
except Exception as e:
|
| 159 |
+
logger.error(f"Failed to stop image cleanup scheduler: {e}")
|
| 160 |
+
|
| 161 |
+
def get_status(self):
|
| 162 |
+
"""获取调度器状态"""
|
| 163 |
+
return {
|
| 164 |
+
'running': self.is_running,
|
| 165 |
+
'images_dir': self.images_dir,
|
| 166 |
+
'cleanup_hours': self.cleanup_hours,
|
| 167 |
+
'interval_hours': self.interval_hours,
|
| 168 |
+
'next_run': self.scheduler.get_jobs()[0].next_run_time.isoformat()
|
| 169 |
+
if self.is_running and self.scheduler.get_jobs() else None
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# 创建全局调度器实例
|
| 174 |
+
cleanup_scheduler = ImageCleanupScheduler()
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def start_cleanup_scheduler():
|
| 178 |
+
"""启动图片清理调度器"""
|
| 179 |
+
cleanup_scheduler.start()
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def stop_cleanup_scheduler():
|
| 183 |
+
"""停止图片清理调度器"""
|
| 184 |
+
cleanup_scheduler.stop()
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def get_cleanup_status():
|
| 188 |
+
"""获取清理调度器状态"""
|
| 189 |
+
return cleanup_scheduler.get_status()
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def manual_cleanup():
|
| 193 |
+
"""手动执行一次清理"""
|
| 194 |
+
return cleanup_scheduler.cleanup_old_images()
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
if __name__ == "__main__":
|
| 198 |
+
# 测试代码
|
| 199 |
+
print("测试图片清理功能...")
|
| 200 |
+
test_scheduler = ImageCleanupScheduler()
|
| 201 |
+
result = test_scheduler.cleanup_old_images()
|
| 202 |
+
print(f"清理结果: {result}")
|
clip_utils.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# clip_utils.py
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from typing import Union, List
|
| 5 |
+
|
| 6 |
+
import cn_clip.clip as clip
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from cn_clip.clip import load_from_name
|
| 10 |
+
|
| 11 |
+
from config import MODELS_PATH
|
| 12 |
+
|
| 13 |
+
# 配置日志
|
| 14 |
+
logging.basicConfig(level=logging.INFO)
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
# 环境变量配置
|
| 18 |
+
MODEL_NAME_CN = os.environ.get('MODEL_NAME_CN', 'ViT-B-16')
|
| 19 |
+
|
| 20 |
+
# 设备配置
|
| 21 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 22 |
+
|
| 23 |
+
# 模型初始化
|
| 24 |
+
model = None
|
| 25 |
+
preprocess = None
|
| 26 |
+
|
| 27 |
+
def init_clip_model():
|
| 28 |
+
"""初始化CLIP模型"""
|
| 29 |
+
global model, preprocess
|
| 30 |
+
try:
|
| 31 |
+
model, preprocess = load_from_name(MODEL_NAME_CN, device=device, download_root=MODELS_PATH)
|
| 32 |
+
model.eval()
|
| 33 |
+
logger.info(f"CLIP model initialized successfully, dimension: {model.visual.output_dim}")
|
| 34 |
+
return True
|
| 35 |
+
except Exception as e:
|
| 36 |
+
logger.error(f"CLIP model initialization failed: {e}")
|
| 37 |
+
return False
|
| 38 |
+
|
| 39 |
+
def is_clip_available():
|
| 40 |
+
"""检查CLIP模型是否可用"""
|
| 41 |
+
return model is not None and preprocess is not None
|
| 42 |
+
|
| 43 |
+
def encode_image(image_path: str) -> torch.Tensor:
|
| 44 |
+
"""编码图片为向量"""
|
| 45 |
+
if not is_clip_available():
|
| 46 |
+
raise RuntimeError("CLIP模型未初始化")
|
| 47 |
+
|
| 48 |
+
image = Image.open(image_path).convert("RGB")
|
| 49 |
+
image_tensor = preprocess(image).unsqueeze(0).to(device)
|
| 50 |
+
with torch.no_grad():
|
| 51 |
+
features = model.encode_image(image_tensor)
|
| 52 |
+
features = features / features.norm(p=2, dim=-1, keepdim=True)
|
| 53 |
+
return features.cpu()
|
| 54 |
+
|
| 55 |
+
def encode_text(text: Union[str, List[str]]) -> torch.Tensor:
|
| 56 |
+
"""编码文本为向量"""
|
| 57 |
+
if not is_clip_available():
|
| 58 |
+
raise RuntimeError("CLIP模型未初始化")
|
| 59 |
+
|
| 60 |
+
texts = [text] if isinstance(text, str) else text
|
| 61 |
+
text_tokens = clip.tokenize(texts).to(device)
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
features = model.encode_text(text_tokens)
|
| 64 |
+
features = features / features.norm(p=2, dim=-1, keepdim=True)
|
| 65 |
+
return features.cpu()
|
config.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# 解决OpenMP库冲突问题
|
| 5 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
| 6 |
+
# 设置CPU线程数为CPU核心数,提高CPU利用率
|
| 7 |
+
import multiprocessing
|
| 8 |
+
cpu_cores = multiprocessing.cpu_count()
|
| 9 |
+
os.environ["OMP_NUM_THREADS"] = str(min(cpu_cores, 8)) # 最多使用8个线程
|
| 10 |
+
os.environ["MKL_NUM_THREADS"] = str(min(cpu_cores, 8))
|
| 11 |
+
os.environ["NUMEXPR_NUM_THREADS"] = str(min(cpu_cores, 8))
|
| 12 |
+
|
| 13 |
+
# 修复torchvision兼容性问题
|
| 14 |
+
try:
|
| 15 |
+
import torchvision.transforms.functional_tensor
|
| 16 |
+
except ImportError:
|
| 17 |
+
# 为缺失的functional_tensor模块创建兼容性补丁
|
| 18 |
+
import torchvision.transforms.functional as F
|
| 19 |
+
import torchvision.transforms as transforms
|
| 20 |
+
import sys
|
| 21 |
+
from types import ModuleType
|
| 22 |
+
|
| 23 |
+
# 创建functional_tensor模块
|
| 24 |
+
functional_tensor = ModuleType('torchvision.transforms.functional_tensor')
|
| 25 |
+
|
| 26 |
+
# 添加常用的函数映射
|
| 27 |
+
if hasattr(F, 'rgb_to_grayscale'):
|
| 28 |
+
functional_tensor.rgb_to_grayscale = F.rgb_to_grayscale
|
| 29 |
+
if hasattr(F, 'adjust_brightness'):
|
| 30 |
+
functional_tensor.adjust_brightness = F.adjust_brightness
|
| 31 |
+
if hasattr(F, 'adjust_contrast'):
|
| 32 |
+
functional_tensor.adjust_contrast = F.adjust_contrast
|
| 33 |
+
if hasattr(F, 'adjust_saturation'):
|
| 34 |
+
functional_tensor.adjust_saturation = F.adjust_saturation
|
| 35 |
+
if hasattr(F, 'normalize'):
|
| 36 |
+
functional_tensor.normalize = F.normalize
|
| 37 |
+
if hasattr(F, 'resize'):
|
| 38 |
+
functional_tensor.resize = F.resize
|
| 39 |
+
if hasattr(F, 'crop'):
|
| 40 |
+
functional_tensor.crop = F.crop
|
| 41 |
+
if hasattr(F, 'pad'):
|
| 42 |
+
functional_tensor.pad = F.pad
|
| 43 |
+
|
| 44 |
+
# 将模块添加到sys.modules
|
| 45 |
+
sys.modules['torchvision.transforms.functional_tensor'] = functional_tensor
|
| 46 |
+
transforms.functional_tensor = functional_tensor
|
| 47 |
+
|
| 48 |
+
# 环境变量配置 - 禁用TensorFlow优化和GPU
|
| 49 |
+
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 50 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
| 51 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # 强制使用CPU
|
| 52 |
+
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "false"
|
| 53 |
+
|
| 54 |
+
# 修复PyTorch兼容性问题
|
| 55 |
+
try:
|
| 56 |
+
import torch
|
| 57 |
+
import torch.onnx
|
| 58 |
+
|
| 59 |
+
# 修复GFPGAN的ONNX兼容性
|
| 60 |
+
if not hasattr(torch.onnx._internal.exporter, 'ExportOptions'):
|
| 61 |
+
from types import SimpleNamespace
|
| 62 |
+
torch.onnx._internal.exporter.ExportOptions = SimpleNamespace
|
| 63 |
+
|
| 64 |
+
# 修复ModelScope的PyTree兼容性 - 更完整的实现
|
| 65 |
+
import torch.utils
|
| 66 |
+
if not hasattr(torch.utils, '_pytree'):
|
| 67 |
+
# 如果_pytree模块不存在,创建一个
|
| 68 |
+
from types import ModuleType
|
| 69 |
+
torch.utils._pytree = ModuleType('_pytree')
|
| 70 |
+
|
| 71 |
+
pytree = torch.utils._pytree
|
| 72 |
+
|
| 73 |
+
if not hasattr(pytree, 'register_pytree_node'):
|
| 74 |
+
def register_pytree_node(typ, flatten_fn, unflatten_fn, *, flatten_with_keys_fn=None, **kwargs):
|
| 75 |
+
"""兼容性实现:注册PyTree节点类型"""
|
| 76 |
+
pass # 简单实现,不做实际操作
|
| 77 |
+
pytree.register_pytree_node = register_pytree_node
|
| 78 |
+
|
| 79 |
+
if not hasattr(pytree, 'tree_flatten'):
|
| 80 |
+
def tree_flatten(tree, is_leaf=None):
|
| 81 |
+
"""兼容性实现:展平树结构"""
|
| 82 |
+
if isinstance(tree, (list, tuple)):
|
| 83 |
+
flat = []
|
| 84 |
+
spec = []
|
| 85 |
+
for i, item in enumerate(tree):
|
| 86 |
+
if isinstance(item, (list, tuple, dict)):
|
| 87 |
+
sub_flat, sub_spec = tree_flatten(item, is_leaf)
|
| 88 |
+
flat.extend(sub_flat)
|
| 89 |
+
spec.append((i, sub_spec))
|
| 90 |
+
else:
|
| 91 |
+
flat.append(item)
|
| 92 |
+
spec.append((i, None))
|
| 93 |
+
return flat, (type(tree), spec)
|
| 94 |
+
elif isinstance(tree, dict):
|
| 95 |
+
flat = []
|
| 96 |
+
spec = []
|
| 97 |
+
for key, value in sorted(tree.items()):
|
| 98 |
+
if isinstance(value, (list, tuple, dict)):
|
| 99 |
+
sub_flat, sub_spec = tree_flatten(value, is_leaf)
|
| 100 |
+
flat.extend(sub_flat)
|
| 101 |
+
spec.append((key, sub_spec))
|
| 102 |
+
else:
|
| 103 |
+
flat.append(value)
|
| 104 |
+
spec.append((key, None))
|
| 105 |
+
return flat, (dict, spec)
|
| 106 |
+
else:
|
| 107 |
+
return [tree], None
|
| 108 |
+
pytree.tree_flatten = tree_flatten
|
| 109 |
+
|
| 110 |
+
if not hasattr(pytree, 'tree_unflatten'):
|
| 111 |
+
def tree_unflatten(values, spec):
|
| 112 |
+
"""兼容性实现:重构树结构"""
|
| 113 |
+
if spec is None:
|
| 114 |
+
return values[0] if values else None
|
| 115 |
+
|
| 116 |
+
tree_type, tree_spec = spec
|
| 117 |
+
if tree_type in (list, tuple):
|
| 118 |
+
result = []
|
| 119 |
+
value_idx = 0
|
| 120 |
+
for pos, sub_spec in tree_spec:
|
| 121 |
+
if sub_spec is None:
|
| 122 |
+
result.append(values[value_idx])
|
| 123 |
+
value_idx += 1
|
| 124 |
+
else:
|
| 125 |
+
# 计算子树需要的值数量
|
| 126 |
+
sub_count = _count_tree_values(sub_spec)
|
| 127 |
+
sub_values = values[value_idx:value_idx + sub_count]
|
| 128 |
+
result.append(tree_unflatten(sub_values, sub_spec))
|
| 129 |
+
value_idx += sub_count
|
| 130 |
+
return tree_type(result)
|
| 131 |
+
elif tree_type == dict:
|
| 132 |
+
result = {}
|
| 133 |
+
value_idx = 0
|
| 134 |
+
for key, sub_spec in tree_spec:
|
| 135 |
+
if sub_spec is None:
|
| 136 |
+
result[key] = values[value_idx]
|
| 137 |
+
value_idx += 1
|
| 138 |
+
else:
|
| 139 |
+
sub_count = _count_tree_values(sub_spec)
|
| 140 |
+
sub_values = values[value_idx:value_idx + sub_count]
|
| 141 |
+
result[key] = tree_unflatten(sub_values, sub_spec)
|
| 142 |
+
value_idx += sub_count
|
| 143 |
+
return result
|
| 144 |
+
return values[0] if values else None
|
| 145 |
+
pytree.tree_unflatten = tree_unflatten
|
| 146 |
+
|
| 147 |
+
if not hasattr(pytree, 'tree_map'):
|
| 148 |
+
def tree_map(fn, tree, *other_trees, is_leaf=None):
|
| 149 |
+
"""兼容性实现:树映射"""
|
| 150 |
+
flat, spec = tree_flatten(tree, is_leaf)
|
| 151 |
+
if other_trees:
|
| 152 |
+
other_flats = [tree_flatten(t, is_leaf)[0] for t in other_trees]
|
| 153 |
+
mapped = [fn(x, *others) for x, *others in zip(flat, *other_flats)]
|
| 154 |
+
else:
|
| 155 |
+
mapped = [fn(x) for x in flat]
|
| 156 |
+
return tree_unflatten(mapped, spec)
|
| 157 |
+
pytree.tree_map = tree_map
|
| 158 |
+
|
| 159 |
+
# 辅助函数
|
| 160 |
+
def _count_tree_values(spec):
|
| 161 |
+
"""计算树规格中的值数量"""
|
| 162 |
+
if spec is None:
|
| 163 |
+
return 1
|
| 164 |
+
tree_type, tree_spec = spec
|
| 165 |
+
return sum(_count_tree_values(sub_spec) if sub_spec else 1 for _, sub_spec in tree_spec)
|
| 166 |
+
|
| 167 |
+
# 修复pyarrow兼容性问题
|
| 168 |
+
try:
|
| 169 |
+
import pyarrow
|
| 170 |
+
if not hasattr(pyarrow, 'PyExtensionType'):
|
| 171 |
+
# 为旧版本pyarrow添加PyExtensionType兼容性
|
| 172 |
+
pyarrow.PyExtensionType = type('PyExtensionType', (), {})
|
| 173 |
+
except ImportError:
|
| 174 |
+
pass
|
| 175 |
+
|
| 176 |
+
except (ImportError, AttributeError) as e:
|
| 177 |
+
print(f"Warning: PyTorch/PyArrow compatibility patch failed: {e}")
|
| 178 |
+
pass
|
| 179 |
+
IMAGES_DIR = os.environ.get("IMAGES_DIR", "/opt/data/images")
|
| 180 |
+
OUTPUT_DIR = IMAGES_DIR
|
| 181 |
+
|
| 182 |
+
# 明星图库目录配置
|
| 183 |
+
CELEBRITY_SOURCE_DIR = os.environ.get(
|
| 184 |
+
"CELEBRITY_SOURCE_DIR", "/opt/data/chinese_celeb_dataset"
|
| 185 |
+
).strip()
|
| 186 |
+
if CELEBRITY_SOURCE_DIR:
|
| 187 |
+
CELEBRITY_SOURCE_DIR = os.path.abspath(os.path.expanduser(CELEBRITY_SOURCE_DIR))
|
| 188 |
+
|
| 189 |
+
CELEBRITY_DATASET_DIR = os.path.abspath(
|
| 190 |
+
os.path.expanduser(
|
| 191 |
+
os.environ.get(
|
| 192 |
+
"CELEBRITY_DATASET_DIR",
|
| 193 |
+
CELEBRITY_SOURCE_DIR or "/opt/data/chinese_celeb_dataset",
|
| 194 |
+
)
|
| 195 |
+
)
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
CELEBRITY_FIND_THRESHOLD = float(
|
| 199 |
+
os.environ.get("CELEBRITY_FIND_THRESHOLD", 0.88)
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# ---- start ----
|
| 203 |
+
# 微信小程序配置(默认值仅用于本地开发)
|
| 204 |
+
WECHAT_APPID = os.environ.get("WECHAT_APPID", "******").strip()
|
| 205 |
+
WECHAT_SECRET = os.environ.get("WCT_SECRET", "******").strip()
|
| 206 |
+
APP_SECRET_TOKEN = os.environ.get("APP_SECRET_TOKEN", "******")
|
| 207 |
+
# MySQL 数据库配置
|
| 208 |
+
MYSQL_HOST = os.environ.get("MYSQL_HOST", "******")
|
| 209 |
+
MYSQL_PORT = int(os.environ.get("MYSQL_PORT", "3306"))
|
| 210 |
+
MYSQL_DB = os.environ.get("MYSQL_DB", "******")
|
| 211 |
+
MYSQL_USER = os.environ.get("MYSQL_USER", "******")
|
| 212 |
+
MYSQL_PASSWORD = os.environ.get("MYSQL_PASSWORD", "******")
|
| 213 |
+
# BOS 对象存储配置(默认存储为Base64编码字符串)
|
| 214 |
+
BOS_ACCESS_KEY = os.environ.get("BOS_ACCESS_KEY", "******").strip()
|
| 215 |
+
BOS_SECRET_KEY = os.environ.get("BOS_SECRET_KEY", "******").strip()
|
| 216 |
+
BOS_ENDPOINT = os.environ.get("BOS_ENDPOINT", "******").strip()
|
| 217 |
+
BOS_BUCKET_NAME = os.environ.get("BOS_BUCKET_NAME", "******").strip()
|
| 218 |
+
BOS_IMAGE_DIR = os.environ.get("BOS_IMAGE_DIR", "******").strip()
|
| 219 |
+
BOS_MODELS_PREFIX = os.environ.get("BOS_MODELS_PREFIX", "******").strip()
|
| 220 |
+
BOS_CELEBRITY_PREFIX = os.environ.get("BOS_CELEBRITY_PREFIX", "******").strip()
|
| 221 |
+
# ---- end ---
|
| 222 |
+
|
| 223 |
+
_bos_enabled_env = os.environ.get("BOS_UPLOAD_ENABLED")
|
| 224 |
+
MYSQL_POOL_MIN_SIZE = int(os.environ.get("MYSQL_POOL_MIN_SIZE", "1"))
|
| 225 |
+
MYSQL_POOL_MAX_SIZE = int(os.environ.get("MYSQL_POOL_MAX_SIZE", "10"))
|
| 226 |
+
if _bos_enabled_env is not None:
|
| 227 |
+
BOS_UPLOAD_ENABLED = _bos_enabled_env.lower() in ("1", "true", "on")
|
| 228 |
+
else:
|
| 229 |
+
BOS_UPLOAD_ENABLED = all(
|
| 230 |
+
[
|
| 231 |
+
BOS_ACCESS_KEY.strip(),
|
| 232 |
+
BOS_SECRET_KEY.strip(),
|
| 233 |
+
BOS_ENDPOINT,
|
| 234 |
+
BOS_BUCKET_NAME,
|
| 235 |
+
]
|
| 236 |
+
)
|
| 237 |
+
HOSTNAME = os.environ.get("HOSTNAME", "default-hostname")
|
| 238 |
+
MODELS_PATH = os.path.abspath(
|
| 239 |
+
os.path.expanduser(os.environ.get("MODELS_PATH", "/opt/data/models"))
|
| 240 |
+
)
|
| 241 |
+
MODELS_DOWNLOAD_DIR = os.path.abspath(
|
| 242 |
+
os.path.expanduser(os.environ.get("MODELS_DOWNLOAD_DIR", MODELS_PATH))
|
| 243 |
+
)
|
| 244 |
+
# HuggingFace 仓库配置
|
| 245 |
+
HUGGINGFACE_SYNC_ENABLED = os.environ.get(
|
| 246 |
+
"HUGGINGFACE_SYNC_ENABLED", "true"
|
| 247 |
+
).lower() in ("1", "true", "on")
|
| 248 |
+
HUGGINGFACE_REPO_ID = os.environ.get(
|
| 249 |
+
"HUGGINGFACE_REPO_ID", "ethonmax/facescore"
|
| 250 |
+
).strip()
|
| 251 |
+
HUGGINGFACE_REVISION = os.environ.get(
|
| 252 |
+
"HUGGINGFACE_REVISION", "main"
|
| 253 |
+
).strip()
|
| 254 |
+
_hf_allow_env = os.environ.get("HUGGINGFACE_ALLOW_PATTERNS", "").strip()
|
| 255 |
+
HUGGINGFACE_ALLOW_PATTERNS = [
|
| 256 |
+
pattern.strip() for pattern in _hf_allow_env.split(",") if pattern.strip()
|
| 257 |
+
]
|
| 258 |
+
_hf_ignore_env = os.environ.get("HUGGINGFACE_IGNORE_PATTERNS", "").strip()
|
| 259 |
+
HUGGINGFACE_IGNORE_PATTERNS = [
|
| 260 |
+
pattern.strip() for pattern in _hf_ignore_env.split(",") if pattern.strip()
|
| 261 |
+
]
|
| 262 |
+
|
| 263 |
+
_MODELSCOPE_CACHE_ENV = os.environ.get("MODELSCOPE_CACHE", "").strip()
|
| 264 |
+
if _MODELSCOPE_CACHE_ENV:
|
| 265 |
+
MODELSCOPE_CACHE_DIR = os.path.abspath(os.path.expanduser(_MODELSCOPE_CACHE_ENV))
|
| 266 |
+
else:
|
| 267 |
+
MODELSCOPE_CACHE_DIR = os.path.join(MODELS_PATH, "modelscope")
|
| 268 |
+
|
| 269 |
+
try:
|
| 270 |
+
os.makedirs(MODELSCOPE_CACHE_DIR, exist_ok=True)
|
| 271 |
+
except Exception as exc:
|
| 272 |
+
print(f"创建 ModelScope 缓存目录失败: %s (%s)", MODELSCOPE_CACHE_DIR, exc)
|
| 273 |
+
|
| 274 |
+
os.environ.setdefault("MODELSCOPE_CACHE", MODELSCOPE_CACHE_DIR)
|
| 275 |
+
os.environ.setdefault("MODELSCOPE_HOME", MODELSCOPE_CACHE_DIR)
|
| 276 |
+
# ---------------------------------------------------------------------------
# Model cache / download locations.
# ---------------------------------------------------------------------------
os.environ.setdefault("MODELSCOPE_CACHE_HOME", MODELSCOPE_CACHE_DIR)

DEEPFACE_HOME = os.environ.get("DEEPFACE_HOME", "/opt/data/models")
os.environ["DEEPFACE_HOME"] = DEEPFACE_HOME

# Download path for GFPGAN-related model weights.
GFPGAN_MODEL_DIR = MODELS_DOWNLOAD_DIR
os.makedirs(GFPGAN_MODEL_DIR, exist_ok=True)

# Cache-directory environment variables for the various model libraries.
os.environ["GFPGAN_MODEL_ROOT"] = GFPGAN_MODEL_DIR
os.environ["FACEXLIB_CACHE_DIR"] = GFPGAN_MODEL_DIR
os.environ["BASICSR_CACHE_DIR"] = GFPGAN_MODEL_DIR
os.environ["REALESRGAN_MODEL_ROOT"] = GFPGAN_MODEL_DIR
os.environ["HUB_CACHE_DIR"] = GFPGAN_MODEL_DIR  # PyTorch Hub cache

# rembg model downloads go to the unified AI-model directory.
REMBG_MODEL_DIR = os.path.expanduser(MODELS_PATH.replace("$HOME", "~"))
os.environ["U2NET_HOME"] = REMBG_MODEL_DIR  # u2net model cache directory
os.environ["REMBG_HOME"] = REMBG_MODEL_DIR  # generic rembg cache directory


def _env_bool(name: str, default: str) -> bool:
    """Parse a boolean feature flag from the environment ("1"/"true"/"on" => True)."""
    return os.environ.get(name, default).lower() in ("1", "true", "on")


IMG_QUALITY = float(os.environ.get("IMG_QUALITY", 0.5))
FACE_CONFIDENCE = float(os.environ.get("FACE_CONFIDENCE", 0.7))
AGE_CONFIDENCE = float(os.environ.get("AGE_CONFIDENCE", 0.99))
# NOTE(review): default 1.1 is above 1.0 — presumably intentional so the
# gender prediction is always accepted; confirm with the scoring code.
GENDER_CONFIDENCE = float(os.environ.get("GENDER_CONFIDENCE", 1.1))
UPSCALE_SIZE = int(os.environ.get("UPSCALE_SIZE", 2))
SAVE_QUALITY = int(os.environ.get("SAVE_QUALITY", 85))
REALESRGAN_MODEL = os.environ.get("REALESRGAN_MODEL", "realesr-general-x4v3")
# yolov11n-face.pt / yolov8n-face.pt
YOLO_MODEL = os.environ.get("YOLO_MODEL", "yolov11n-face.pt")
# mobilenetv3 / resnet50
RVM_MODEL = os.environ.get("RVM_MODEL", "resnet50")
RVM_LOCAL_REPO = os.environ.get("RVM_LOCAL_REPO", "/opt/data/RobustVideoMatting").strip()
RVM_WEIGHTS_PATH = os.environ.get("RVM_WEIGHTS_PATH", "/opt/data/models/torch/hub/checkpoints/rvm_resnet50.pth").strip()
DRAW_SCORE = _env_bool("DRAW_SCORE", "true")

# Gentle beauty-score boost configuration (enabled by default;
# default range [6.0, 8.0], gamma=0.3 per the original notes).
# - BEAUTY_ADJUST_ENABLED: whether boosting is on
# - BEAUTY_ADJUST_MIN: lower bound (scores below it are not boosted)
# - BEAUTY_ADJUST_MAX: target upper bound; boosting applies only in [min, max)
# - BEAUTY_ADJUST_THRESHOLD: legacy alias, equivalent to BEAUTY_ADJUST_MAX
# - BEAUTY_ADJUST_GAMMA: boost strength in (0, 1]; smaller boosts more
BEAUTY_ADJUST_ENABLED = _env_bool("BEAUTY_ADJUST_ENABLED", "true")
BEAUTY_ADJUST_MIN = float(os.environ.get("BEAUTY_ADJUST_MIN", 1.0))
# Backward compatibility: when BEAUTY_ADJUST_MAX is absent, fall back to the
# legacy BEAUTY_ADJUST_THRESHOLD, then to 8.0.
_legacy_thr = os.environ.get("BEAUTY_ADJUST_THRESHOLD")
BEAUTY_ADJUST_MAX = float(os.environ.get("BEAUTY_ADJUST_MAX", _legacy_thr if _legacy_thr is not None else 8.0))
BEAUTY_ADJUST_GAMMA = float(os.environ.get("BEAUTY_ADJUST_GAMMA", 0.5))  # 0 < gamma <= 1; smaller boosts more

# Kept for old call sites; no longer used by the internal logic.
BEAUTY_ADJUST_THRESHOLD = BEAUTY_ADJUST_MAX

# Gentle harmony-score boost configuration (enabled by default; T=8.0, gamma=0.5
# per the original notes, though the shipped defaults below are 9.0 / 0.3).
HARMONY_ADJUST_ENABLED = _env_bool("HARMONY_ADJUST_ENABLED", "true")
HARMONY_ADJUST_THRESHOLD = float(os.environ.get("HARMONY_ADJUST_THRESHOLD", 9.0))
HARMONY_ADJUST_GAMMA = float(os.environ.get("HARMONY_ADJUST_GAMMA", 0.3))

# Startup optimisation: whether heavy components auto-initialise / warm up.
ENABLE_WARMUP = _env_bool("ENABLE_WARMUP", "false")
AUTO_INIT_ANALYZER = _env_bool("AUTO_INIT_ANALYZER", "true")
AUTO_INIT_GFPGAN = _env_bool("AUTO_INIT_GFPGAN", "false")
AUTO_INIT_DDCOLOR = _env_bool("AUTO_INIT_DDCOLOR", "false")
AUTO_INIT_REALESRGAN = _env_bool("AUTO_INIT_REALESRGAN", "false")
AUTO_INIT_REMBG = _env_bool("AUTO_INIT_REMBG", "false")
AUTO_INIT_ANIME_STYLE = _env_bool("AUTO_INIT_ANIME_STYLE", "false")
AUTO_INIT_RVM = _env_bool("AUTO_INIT_RVM", "false")

# Scheduled cleanup task configuration.
CLEANUP_INTERVAL_HOURS = float(os.environ.get("CLEANUP_INTERVAL_HOURS", 1.0))  # run interval (hours)
CLEANUP_AGE_HOURS = float(os.environ.get("CLEANUP_AGE_HOURS", 1.0))  # file age threshold (hours)

# BOS auto-sync manifest: maps BOS prefixes to local directories; startup code
# can iterate this structure to perform batch downloads.
BOS_DOWNLOAD_TARGETS = [
    # {
    #     "description": "celebrity image dataset",
    #     "bos_prefix": BOS_CELEBRITY_PREFIX,
    #     "destination": CELEBRITY_DATASET_DIR,
    #     "background": True,
    # },
    # {
    #     "description": "AI model weights",
    #     "bos_prefix": BOS_MODELS_PREFIX,
    #     "destination": MODELS_DOWNLOAD_DIR,
    # },
]

log_level_str = os.getenv("LOG_LEVEL", "INFO").upper()
log_level = getattr(logging, log_level_str, logging.INFO)

# Master switch controlling whether any log output is produced at all.
ENABLE_LOGGING = _env_bool("ENABLE_LOGGING", "true")

# Feature switches.
ENABLE_DDCOLOR = _env_bool("ENABLE_DDCOLOR", "true")
ENABLE_REALESRGAN = _env_bool("ENABLE_REALESRGAN", "true")
ENABLE_GFPGAN = _env_bool("ENABLE_GFPGAN", "true")
ENABLE_ANIME_STYLE = _env_bool("ENABLE_ANIME_STYLE", "true")
ENABLE_ANIME_PRELOAD = _env_bool("ENABLE_ANIME_PRELOAD", "false")
ENABLE_RVM = _env_bool("ENABLE_RVM", "true")


# Beauty-scoring module: maximum number of uploaded images per request.
FACE_SCORE_MAX_IMAGES = int(os.environ.get("FACE_SCORE_MAX_IMAGES", 10))

# Female age adjustment: the displayed age for women above the threshold is
# reduced by the configured number of years.
FEMALE_AGE_ADJUSTMENT = int(os.environ.get("FEMALE_AGE_ADJUSTMENT", 3))  # default: minus 3 years
FEMALE_AGE_ADJUSTMENT_THRESHOLD = int(os.environ.get("FEMALE_AGE_ADJUSTMENT_THRESHOLD", 20))  # default: 20 years

# Logging setup.
if ENABLE_LOGGING:
    logging.basicConfig(
        level=log_level,
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger = logging.getLogger(__name__)
else:
    # Disable all log output.
    logging.basicConfig(level=logging.CRITICAL + 10)
    logger = logging.getLogger(__name__)
    logger.disabled = True

# Global cache for the access token.
access_token_cache = {"token": None, "expires_at": 0}
| 401 |
+
# Probe the optional heavy dependencies; each probe sets an *_AVAILABLE flag
# instead of letting a missing package crash the import of this module.
DEEPFACE_AVAILABLE = True
try:
    from deepface import DeepFace  # noqa: F401
except ImportError:
    DEEPFACE_AVAILABLE = False
    print("Warning: DeepFace not installed. Install with: pip install deepface")

MEDIAPIPE_AVAILABLE = True
try:
    import mediapipe as mp  # noqa: F401
except ImportError:
    MEDIAPIPE_AVAILABLE = False
    print("Warning: mediapipe not installed. Install with: pip install mediapipe")

# Backward compatibility: old code checks DLIB_AVAILABLE, which now simply
# mirrors the mediapipe availability flag.
DLIB_AVAILABLE = MEDIAPIPE_AVAILABLE

YOLO_AVAILABLE = True
try:
    from ultralytics import YOLO  # noqa: F401
except ImportError:
    YOLO_AVAILABLE = False
    print("Warning: ultralytics not installed. Install with: pip install ultralytics")
| 429 |
+
# GFPGAN gate: check the feature switch, then verify the helper module and
# its weight files. Missing weights are tolerated — they may arrive later via
# the HuggingFace/BOS sync — so only the import decides availability.
if ENABLE_GFPGAN:
    try:
        required_files = [
            os.path.join(os.path.dirname(__file__), "gfpgan_restorer.py"),
            os.path.join(MODELS_PATH, "gfpgan/weights/detection_Resnet50_Final.pth"),
            os.path.join(MODELS_PATH, "gfpgan/weights/parsing_parsenet.pth"),
        ]
        missing_files = [path for path in required_files if not os.path.exists(path)]
        for file_path in missing_files:
            logger.info("GFPGAN 所需文件暂未找到,将等待模型同步: %s", file_path)

        from gfpgan_restorer import GFPGANRestorer  # noqa: F401
        GFPGAN_AVAILABLE = True

        if missing_files:
            logger.warning(
                "GFPGAN 文件尚未全部就绪,将在 HuggingFace/BOS 同步完成后继续初始化: %s",
                ", ".join(missing_files),
            )
        else:
            logger.info("GFPGAN photo restoration feature prerequisites detected")
    except ImportError as e:
        GFPGAN_AVAILABLE = False
        print(f"Warning: GFPGAN enabled but not available: {e}")
        logger.warning(f"GFPGAN photo restoration feature is enabled but import failed: {e}")
else:
    GFPGAN_AVAILABLE = False
    logger.info("GFPGAN photo restoration feature is disabled (via ENABLE_GFPGAN environment variable)")
| 461 |
+
# DDColor gate: availability hinges on the switch plus a successful import
# of the local colorizer module.
DDCOLOR_AVAILABLE = False
if not ENABLE_DDCOLOR:
    logger.info("DDColor feature is disabled (via ENABLE_DDCOLOR environment variable)")
else:
    try:
        from ddcolor_colorizer import DDColorColorizer  # noqa: F401
        DDCOLOR_AVAILABLE = True
        logger.info("DDColor feature is enabled and available")
    except ImportError as e:
        print(f"Warning: DDColor enabled but not available: {e}")
        logger.warning(f"DDColor feature is enabled but import failed: {e}")

# Only the GFPGAN restorer is used; the simple fallback restorer stays off.
SIMPLE_RESTORER_AVAILABLE = False
|
| 478 |
+
# Real-ESRGAN gate: feature switch plus a successful import of the upscaler.
REALESRGAN_AVAILABLE = False
if not ENABLE_REALESRGAN:
    logger.info("Real-ESRGAN super resolution feature is disabled (via ENABLE_REALESRGAN environment variable)")
else:
    try:
        from realesrgan_upscaler import RealESRGANUpscaler  # noqa: F401
        REALESRGAN_AVAILABLE = True
        logger.info("Real-ESRGAN super resolution feature is enabled and available")
    except ImportError as e:
        print(f"Warning: Real-ESRGAN enabled but not available: {e}")
        logger.warning(f"Real-ESRGAN super resolution feature is enabled but import failed: {e}")
|
| 492 |
+
# rembg feature switch.
ENABLE_REMBG = os.environ.get("ENABLE_REMBG", "true").lower() in ("1", "true", "on")

# rembg gate: switch plus a successful import of the library.
REMBG_AVAILABLE = False
if not ENABLE_REMBG:
    logger.info("rembg background removal feature is disabled (via ENABLE_REMBG environment variable)")
else:
    try:
        import rembg  # noqa: F401
        from rembg import new_session  # noqa: F401
        REMBG_AVAILABLE = True
        logger.info("rembg background removal feature is enabled and available")
        logger.info(f"rembg model storage path: {REMBG_MODEL_DIR}")
    except ImportError as e:
        print(f"Warning: rembg enabled but not available: {e}")
        logger.warning(f"rembg background removal feature is enabled but import failed: {e}")

# CLIP support is not wired up in this build.
CLIP_AVAILABLE = False
|
| 513 |
+
# Anime stylization gate: switch plus a successful import of the stylizer.
ANIME_STYLE_AVAILABLE = False
if not ENABLE_ANIME_STYLE:
    logger.info("Anime stylization feature is disabled (via ENABLE_ANIME_STYLE environment variable)")
else:
    try:
        from anime_stylizer import AnimeStylizer  # noqa: F401
        ANIME_STYLE_AVAILABLE = True
        logger.info("Anime stylization feature is enabled and available")
    except ImportError as e:
        print(f"Warning: Anime Style enabled but not available: {e}")
        logger.warning(f"Anime stylization feature is enabled but import failed: {e}")
|
| 527 |
+
# RVM feature switch.
# NOTE(review): this re-reads the same variable already parsed earlier in the
# file; kept for safety, the value is identical either way.
ENABLE_RVM = os.environ.get("ENABLE_RVM", "true").lower() in ("1", "true", "on")

# RVM gate: only torch needs to import — the model itself loads lazily later.
RVM_AVAILABLE = False
if not ENABLE_RVM:
    logger.info("RVM background removal feature is disabled (via ENABLE_RVM environment variable)")
else:
    try:
        import torch  # noqa: F401
        RVM_AVAILABLE = True
        logger.info("RVM background removal feature is enabled and available")
    except ImportError as e:
        print(f"Warning: RVM enabled but not available: {e}")
        logger.warning(f"RVM background removal feature is enabled but import failed: {e}")
database.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from contextlib import asynccontextmanager
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Any, Dict, Iterable, List, Optional, Sequence
|
| 7 |
+
|
| 8 |
+
import aiomysql
|
| 9 |
+
from aiomysql.cursors import DictCursor
|
| 10 |
+
|
| 11 |
+
from config import (
|
| 12 |
+
IMAGES_DIR,
|
| 13 |
+
logger,
|
| 14 |
+
MYSQL_HOST,
|
| 15 |
+
MYSQL_PORT,
|
| 16 |
+
MYSQL_DB,
|
| 17 |
+
MYSQL_USER,
|
| 18 |
+
MYSQL_PASSWORD,
|
| 19 |
+
MYSQL_POOL_MIN_SIZE,
|
| 20 |
+
MYSQL_POOL_MAX_SIZE,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
_pool: Optional[aiomysql.Pool] = None
|
| 24 |
+
_pool_lock = asyncio.Lock()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
async def init_mysql_pool() -> aiomysql.Pool:
    """Create the global MySQL connection pool (idempotent, lock-guarded)."""
    global _pool
    # Fast path: pool already exists, no lock needed.
    if _pool is not None:
        return _pool

    async with _pool_lock:
        # Double-checked under the lock: another task may have won the race.
        if _pool is None:
            try:
                pool_kwargs = dict(
                    host=MYSQL_HOST,
                    port=MYSQL_PORT,
                    user=MYSQL_USER,
                    password=MYSQL_PASSWORD,
                    db=MYSQL_DB,
                    minsize=MYSQL_POOL_MIN_SIZE,
                    maxsize=MYSQL_POOL_MAX_SIZE,
                    autocommit=True,
                    charset="utf8mb4",
                    cursorclass=DictCursor,
                )
                _pool = await aiomysql.create_pool(**pool_kwargs)
                logger.info(
                    "MySQL 连接池初始化成功,host=%s db=%s",
                    MYSQL_HOST,
                    MYSQL_DB,
                )
            except Exception as exc:
                logger.error(f"初始化 MySQL 连接池失败: {exc}")
                raise
    return _pool
+
|
| 60 |
+
async def close_mysql_pool() -> None:
    """Tear down the global MySQL connection pool, waiting for it to drain."""
    global _pool
    if _pool is None:
        return

    async with _pool_lock:
        # Another task may have closed it while we waited for the lock.
        if _pool is None:
            return
        pool = _pool
        pool.close()
        await pool.wait_closed()
        _pool = None
        logger.info("MySQL 连接池已关闭")
|
| 74 |
+
|
| 75 |
+
@asynccontextmanager
async def get_connection():
    """Yield a pooled connection, lazily creating the pool on first use.

    The connection is always returned to the pool, even on error.
    """
    if _pool is None:
        await init_mysql_pool()
    assert _pool is not None
    connection = await _pool.acquire()
    try:
        yield connection
    finally:
        _pool.release(connection)
|
| 87 |
+
|
| 88 |
+
async def execute(query: str,
                  params: Sequence[Any] | Dict[str, Any] | None = None) -> None:
    """Run a write-style SQL statement (INSERT / UPDATE / DELETE)."""
    bound = params or ()
    async with get_connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute(query, bound)
|
| 95 |
+
|
| 96 |
+
async def fetch_all(
    query: str, params: Sequence[Any] | Dict[str, Any] | None = None
) -> List[Dict[str, Any]]:
    """Run a SELECT and return every row (as dicts, via DictCursor)."""
    bound = params or ()
    async with get_connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute(query, bound)
            return list(await cur.fetchall())
|
| 106 |
+
|
| 107 |
+
def _serialize_extra(extra: Optional[Dict[str, Any]]) -> Optional[str]:
    """JSON-encode *extra*; return None when absent or unserialisable."""
    if extra is None:
        return None
    try:
        return json.dumps(extra, ensure_ascii=False)
    except Exception:
        # Best effort: a bad payload must never block the DB write.
        logger.warning("无法序列化 extra_metadata,已忽略")
        return None
|
| 116 |
+
|
| 117 |
+
async def upsert_image_record(
    *,
    file_path: str,
    category: str,
    nickname: Optional[str],
    score: float,
    is_cropped_face: bool,
    size_bytes: int,
    last_modified: datetime,
    bos_uploaded: bool,
    hostname: Optional[str] = None,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Insert an image record, or update it in place when file_path exists."""
    query = """
    INSERT INTO tpl_app_processed_images (
        file_path,
        category,
        nickname,
        score,
        is_cropped_face,
        size_bytes,
        last_modified,
        bos_uploaded,
        hostname,
        extra_metadata
    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    ON DUPLICATE KEY UPDATE
        category = VALUES(category),
        nickname = VALUES(nickname),
        score = VALUES(score),
        is_cropped_face = VALUES(is_cropped_face),
        size_bytes = VALUES(size_bytes),
        last_modified = VALUES(last_modified),
        bos_uploaded = VALUES(bos_uploaded),
        hostname = VALUES(hostname),
        extra_metadata = VALUES(extra_metadata),
        updated_at = CURRENT_TIMESTAMP
    """
    row = (
        file_path,
        category,
        nickname,
        score,
        int(is_cropped_face),   # booleans stored as TINYINT 0/1
        size_bytes,
        last_modified,
        int(bos_uploaded),
        hostname,
        _serialize_extra(extra_metadata),
    )
    await execute(query, row)
+
|
| 173 |
+
|
| 174 |
+
async def fetch_paged_image_records(
    *,
    category: Optional[str],
    nickname: Optional[str],
    offset: int,
    limit: int,
) -> List[Dict[str, Any]]:
    """Page through image records filtered by category and/or nickname.

    A category of "all" (or empty) means no category filter.
    """
    filters: List[str] = []
    args: List[Any] = []
    if category and category != "all":
        filters.append("category = %s")
        args.append(category)
    if nickname:
        filters.append("nickname = %s")
        args.append(nickname)
    where_sql = f"WHERE {' AND '.join(filters)}" if filters else ""
    query = f"""
        SELECT
            file_path,
            category,
            nickname,
            score,
            is_cropped_face,
            size_bytes,
            last_modified,
            bos_uploaded,
            hostname
        FROM tpl_app_processed_images
        {where_sql}
        ORDER BY last_modified DESC, id DESC
        LIMIT %s OFFSET %s
    """
    args.extend([limit, offset])
    return await fetch_all(query, args)
|
| 210 |
+
|
| 211 |
+
async def count_image_records(
    *, category: Optional[str], nickname: Optional[str]
) -> int:
    """Count image records matching the same filters as the paged query."""
    filters: List[str] = []
    args: List[Any] = []
    if category and category != "all":
        filters.append("category = %s")
        args.append(category)
    if nickname:
        filters.append("nickname = %s")
        args.append(nickname)
    where_sql = f"WHERE {' AND '.join(filters)}" if filters else ""
    query = f"SELECT COUNT(*) AS total FROM tpl_app_processed_images {where_sql}"
    rows = await fetch_all(query, args)
    if not rows:
        return 0
    return int(rows[0].get("total", 0) or 0)
|
| 230 |
+
|
| 231 |
+
async def fetch_today_category_counts() -> List[Dict[str, Any]]:
    """Count today's records grouped by category (NULL maps to 'unknown')."""
    query = """
    SELECT
        COALESCE(category, 'unknown') AS category,
        COUNT(*) AS count
    FROM tpl_app_processed_images
    WHERE last_modified >= CURDATE()
      AND last_modified < DATE_ADD(CURDATE(), INTERVAL 1 DAY)
    GROUP BY COALESCE(category, 'unknown')
    """
    rows = await fetch_all(query)
    result: List[Dict[str, Any]] = []
    for row in rows:
        result.append({
            "category": str(row.get("category") or "unknown"),
            "count": int(row.get("count") or 0),
        })
    return result
|
| 251 |
+
|
| 252 |
+
async def fetch_records_by_paths(file_paths: Iterable[str]) -> Dict[
    str, Dict[str, Any]]:
    """Look up image records for a batch of file names, keyed by file_path."""
    # De-duplicate and drop falsy entries before building the IN clause.
    unique_paths = list({p for p in file_paths if p})
    if not unique_paths:
        return {}

    placeholders = ", ".join(["%s"] * len(unique_paths))
    query = f"""
        SELECT
            file_path,
            category,
            nickname,
            score,
            is_cropped_face,
            size_bytes,
            last_modified,
            bos_uploaded,
            hostname
        FROM tpl_app_processed_images
        WHERE file_path IN ({placeholders})
    """
    rows = await fetch_all(query, unique_paths)
    return {row["file_path"]: row for row in rows}
+
|
| 277 |
+
|
| 278 |
+
_IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def _normalize_file_path(file_path: str) -> Optional[str]:
    """Convert a path into a name relative to IMAGES_DIR.

    Paths outside IMAGES_DIR degrade to their basename; directories and
    unresolvable inputs yield None.
    """
    try:
        abs_path = os.path.abspath(os.path.expanduser(file_path))
        if os.path.isdir(abs_path):
            return None
        inside_images_dir = (
            os.path.commonpath([_IMAGES_DIR_ABS, abs_path]) == _IMAGES_DIR_ABS
        )
        if not inside_images_dir:
            return os.path.basename(abs_path)
        return os.path.relpath(abs_path, _IMAGES_DIR_ABS).replace("\\", "/")
    except Exception:
        # commonpath/relpath can raise (e.g. mixed drives on Windows);
        # treat any failure as "cannot normalize".
        return None
|
| 294 |
+
|
| 295 |
+
def infer_category_from_filename(filename: str, default: str = "other") -> str:
    """Infer an image category from marker substrings in *filename*.

    Checks are ordered so that more specific markers win. In particular,
    "_rvm_id_photo" must be tested before "_id_photo": the former contains
    the latter as a substring, so the previous ordering made the "rvm"
    branch unreachable (bug fixed here).

    :param filename: file name (case-insensitive)
    :param default: category returned when no marker matches
    :return: category name
    """
    lower_name = filename.lower()
    if "_face_" in lower_name:
        return "face"
    # "_original" also covers names ending in "_original.webp".
    if "_original" in lower_name:
        return "original"
    if "_restore" in lower_name:
        return "restore"
    if "_upcolor" in lower_name:
        return "upcolor"
    if "_compress" in lower_name:
        return "compress"
    if "_upscale" in lower_name:
        return "upscale"
    if "_anime_style_" in lower_name:
        return "anime_style"
    if "_grayscale" in lower_name:
        return "grayscale"
    # Most specific first: "_rvm_id_photo" contains "_id_photo".
    if "_rvm_id_photo" in lower_name:
        return "rvm"
    # "_id_photo" also covers "_save_id_photo".
    if "_id_photo" in lower_name:
        return "id_photo"
    if "_grid_" in lower_name:
        return "grid"
    # "_celebrity" also covers "_celebrity_".
    if "_celebrity" in lower_name:
        return "celebrity"
    return default
|
| 324 |
+
|
| 325 |
+
from config import HOSTNAME
|
| 326 |
+
|
| 327 |
+
async def record_image_creation(
    *,
    file_path: str,
    nickname: Optional[str],
    score: float = 0.0,
    category: Optional[str] = None,
    bos_uploaded: bool = False,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Persist image metadata to the database; failures are logged, not raised.

    :param file_path: absolute or relative file path
    :param nickname: user nickname
    :param score: associated score
    :param category: file category; inferred from the file name when omitted
    :param bos_uploaded: whether the file has been uploaded to BOS
    :param extra_metadata: additional info
    """
    normalized = _normalize_file_path(file_path)
    if normalized is None:
        logger.info("record_image_creation: 无法计算文件名,路径=%s", file_path)
        return

    abs_path = os.path.join(_IMAGES_DIR_ABS, normalized)
    if not os.path.isfile(abs_path):
        logger.info("record_image_creation: 文件不存在,跳过记录 file=%s", abs_path)
        return

    try:
        file_stat = os.stat(abs_path)
        resolved_category = category or infer_category_from_filename(normalized)
        # Cropped-face heuristic: marker present and at least two underscores.
        cropped = "_face_" in normalized and normalized.count("_") >= 2
        clean_nickname = (
            nickname.strip()
            if isinstance(nickname, str) and nickname.strip()
            else None
        )

        await upsert_image_record(
            file_path=normalized,
            category=resolved_category,
            nickname=clean_nickname,
            score=score,
            is_cropped_face=cropped,
            size_bytes=file_stat.st_size,
            last_modified=datetime.fromtimestamp(file_stat.st_mtime),
            bos_uploaded=bos_uploaded,
            hostname=HOSTNAME,
            extra_metadata=extra_metadata,
        )
    except Exception as exc:
        logger.warning(f"写入图片记录失败: {exc}")
|
ddcolor_colorizer.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from config import logger
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DDColorColorizer:
|
| 12 |
+
def __init__(self):
|
| 13 |
+
start_time = time.perf_counter()
|
| 14 |
+
self.colorizer = None
|
| 15 |
+
# 检查是否启用DDColor功能
|
| 16 |
+
from config import ENABLE_DDCOLOR
|
| 17 |
+
if ENABLE_DDCOLOR:
|
| 18 |
+
self._initialize_model()
|
| 19 |
+
else:
|
| 20 |
+
logger.info("DDColor feature is disabled, skipping model initialization")
|
| 21 |
+
init_time = time.perf_counter() - start_time
|
| 22 |
+
if self.colorizer is not None:
|
| 23 |
+
logger.info(f"DDColorColorizer initialized successfully, time: {init_time:.3f}s")
|
| 24 |
+
else:
|
| 25 |
+
logger.info(f"DDColorColorizer initialization completed but not available, time: {init_time:.3f}s")
|
| 26 |
+
|
| 27 |
+
def _initialize_model(self):
|
| 28 |
+
"""初始化DDColor模型(使用ModelScope)"""
|
| 29 |
+
try:
|
| 30 |
+
logger.info("Initializing DDColor model (using ModelScope)...")
|
| 31 |
+
|
| 32 |
+
# 添加torch类型兼容性补丁
|
| 33 |
+
import torch
|
| 34 |
+
if not hasattr(torch, 'uint64'):
|
| 35 |
+
logger.info("Adding torch.uint64 compatibility patch...")
|
| 36 |
+
torch.uint64 = torch.int64 # 使用int64作为uint64的替代
|
| 37 |
+
if not hasattr(torch, 'uint32'):
|
| 38 |
+
logger.info("Adding torch.uint32 compatibility patch...")
|
| 39 |
+
torch.uint32 = torch.int32 # 使用int32作为uint32的替代
|
| 40 |
+
if not hasattr(torch, 'uint16'):
|
| 41 |
+
logger.info("Adding torch.uint16 compatibility patch...")
|
| 42 |
+
torch.uint16 = torch.int16 # 使用int16作为uint16的替代
|
| 43 |
+
|
| 44 |
+
# 导入ModelScope相关模块
|
| 45 |
+
from modelscope.outputs import OutputKeys
|
| 46 |
+
from modelscope.pipelines import pipeline
|
| 47 |
+
from modelscope.utils.constant import Tasks
|
| 48 |
+
|
| 49 |
+
# 初始化DDColor pipeline
|
| 50 |
+
self.colorizer = pipeline(
|
| 51 |
+
Tasks.image_colorization,
|
| 52 |
+
model='damo/cv_ddcolor_image-colorization'
|
| 53 |
+
)
|
| 54 |
+
self.OutputKeys = OutputKeys
|
| 55 |
+
|
| 56 |
+
logger.info("DDColor model initialized successfully")
|
| 57 |
+
|
| 58 |
+
except ImportError as e:
|
| 59 |
+
logger.error(f"ModelScope module import failed: {e}")
|
| 60 |
+
self.colorizer = None
|
| 61 |
+
except Exception as e:
|
| 62 |
+
logger.error(f"DDColor model initialization failed: {e}")
|
| 63 |
+
self.colorizer = None
|
| 64 |
+
|
| 65 |
+
def is_available(self):
|
| 66 |
+
"""检查DDColor是否可用"""
|
| 67 |
+
return self.colorizer is not None
|
| 68 |
+
|
| 69 |
+
def is_grayscale(self, image):
|
| 70 |
+
"""检查图像是否为灰度图像"""
|
| 71 |
+
if len(image.shape) == 2:
|
| 72 |
+
return True
|
| 73 |
+
elif len(image.shape) == 3:
|
| 74 |
+
# 检查是否为伪彩色图像(RGB三个通道值相等)
|
| 75 |
+
b, g, r = cv2.split(image)
|
| 76 |
+
|
| 77 |
+
# 计算通道间的差异
|
| 78 |
+
diff_bg = np.abs(b.astype(float) - g.astype(float))
|
| 79 |
+
diff_gr = np.abs(g.astype(float) - r.astype(float))
|
| 80 |
+
diff_rb = np.abs(r.astype(float) - b.astype(float))
|
| 81 |
+
|
| 82 |
+
# 计算平均差异
|
| 83 |
+
avg_diff = (np.mean(diff_bg) + np.mean(diff_gr) + np.mean(diff_rb)) / 3.0
|
| 84 |
+
|
| 85 |
+
# 计算色彩饱和度
|
| 86 |
+
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
| 87 |
+
saturation = hsv[:, :, 1] # S通道
|
| 88 |
+
avg_saturation = np.mean(saturation)
|
| 89 |
+
|
| 90 |
+
# 改进的灰度检测:同时考虑通道差异和饱和度
|
| 91 |
+
is_gray = (avg_diff < 5.0) or (avg_saturation < 20.0)
|
| 92 |
+
|
| 93 |
+
logger.info(f"Grayscale detection - Average channel difference: {avg_diff:.2f}, Average saturation: {avg_saturation:.2f}, Result: {is_gray}")
|
| 94 |
+
return is_gray
|
| 95 |
+
return False
|
| 96 |
+
|
| 97 |
+
def colorize_image(self, image):
    """Colorize *image* with DDColor, but only when it looks grayscale.

    :param image: input image (numpy array, BGR).
    :return: colorized image (numpy array, BGR); the input unchanged when
        the model is unavailable or the image already carries color.
    """
    if self.is_available():
        if self.is_grayscale(image):
            return self.colorize_image_direct(image)
        logger.info("Image is already colored, no need for colorization")
        return image
    logger.error("DDColor model not initialized")
    return image
|
| 113 |
+
|
| 114 |
+
def colorize_image_direct(self, image):
    """Colorize *image* unconditionally (no grayscale pre-check).

    Delegates to the temp-file path, which matches the quality of the
    reference test_ddcolor.py flow.

    :param image: input image (numpy array, BGR).
    :return: colorized image (numpy array, BGR); the input when the model
        is unavailable.
    """
    if self.is_available():
        # The file-path route is the verified best-quality inference path.
        return self._colorize_image_via_file(image)
    logger.error("DDColor model not initialized")
    return image
|
| 127 |
+
|
| 128 |
+
def _colorize_image_via_file(self, image):
    """Colorize *image* by round-tripping it through a temporary file.

    Mirrors the reference test_ddcolor.py flow: write a near-lossless WebP,
    feed its path to the ModelScope pipeline, read the BGR result back.

    :param image: input image (numpy array, BGR).
    :return: colorized image (numpy array, BGR); the original image on any
        failure.
    """
    try:
        logger.info("Using high-quality file path method for colorization...")

        with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as tmp_input:
            # WebP at quality 100 keeps the intermediate nearly lossless.
            cv2.imwrite(tmp_input.name, image, [cv2.IMWRITE_WEBP_QUALITY, 100])
            tmp_input_path = tmp_input.name

        try:
            logger.info(f"Temporary file saved to: {tmp_input_path}")

            # Same call shape as test_colorization.
            result = self.colorizer(tmp_input_path)

            # Same output extraction as test_colorization.
            colorized_image = result[self.OutputKeys.OUTPUT_IMG]

            logger.info(f"Colorization output: size={colorized_image.shape}, type={colorized_image.dtype}")

            # ModelScope already returns BGR; no conversion needed
            # (test_colorization writes it straight out with cv2.imwrite).
            logger.info("High-quality file path method colorization completed")
            return colorized_image

        finally:
            # Best-effort cleanup of the temporary file. Catch only OSError:
            # the previous bare `except` also swallowed KeyboardInterrupt and
            # SystemExit during cleanup.
            try:
                os.unlink(tmp_input_path)
            except OSError:
                pass

    except Exception as e:
        logger.error(f"High-quality file path method colorization failed: {e}")
        logger.info("Returning original image")
        return image
|
| 170 |
+
|
| 171 |
+
def restore_and_colorize(self, image, gfpgan_restorer=None):
    """
    Restore first, then colorize (legacy order, kept for compatibility).

    :param image: input image (numpy array, BGR).
    :param gfpgan_restorer: optional GFPGAN restorer instance.
    :return: restored (and, when grayscale, colorized) image; the original
        image on any processing error.
    """
    try:
        # Restoration first, when a restorer is supplied and ready.
        if gfpgan_restorer and gfpgan_restorer.is_available():
            logger.info("First performing image restoration...")
            restored_image = gfpgan_restorer.restore_image(image)
        else:
            restored_image = image

        # Then colorization, only for grayscale inputs.
        if self.is_grayscale(restored_image):
            logger.info("Grayscale image detected, performing colorization...")
            colorized_image = self.colorize_image(restored_image)
            return colorized_image
        else:
            logger.info("Image is already colored, only returning restoration result")
            return restored_image

    except Exception as e:
        logger.error(f"Restoration and colorization combination processing failed: {e}")
        return image
|
| 198 |
+
|
| 199 |
+
def colorize_and_restore(self, image, gfpgan_restorer=None):
    """
    Colorize first, then restore (newer order).

    :param image: input image (numpy array, BGR).
    :param gfpgan_restorer: optional GFPGAN restorer instance.
    :return: colorized (and, when a restorer is available, restored) image;
        the original image on any processing error.
    """
    try:
        # Colorize first, but only for grayscale inputs.
        if self.is_grayscale(image):
            logger.info("Grayscale image detected, performing colorization first...")
            colorized_image = self.colorize_image_direct(image)
        else:
            logger.info("Image is already colored, skipping colorization step")
            colorized_image = image

        # Then restoration, when a restorer is supplied and ready.
        if gfpgan_restorer and gfpgan_restorer.is_available():
            logger.info("Performing restoration on the colorized image...")
            final_image = gfpgan_restorer.restore_image(colorized_image)
            return final_image
        else:
            logger.info("No restorer available, returning colorization result")
            return colorized_image

    except Exception as e:
        logger.error(f"Colorization and restoration combination processing failed: {e}")
        return image
|
| 227 |
+
|
| 228 |
+
def save_debug_image(self, image, filename_prefix):
    """Write *image* to '<prefix>_debug.webp' for offline inspection.

    :param image: image to persist (numpy array, BGR).
    :param filename_prefix: prefix for the output file name.
    :return: the saved path, or None when writing failed.
    """
    try:
        debug_path = f"{filename_prefix}_debug.webp"
        cv2.imwrite(debug_path, image, [cv2.IMWRITE_WEBP_QUALITY, 95])
    except Exception as e:
        logger.error(f"Failed to save debug image: {e}")
        return None
    logger.info(f"Debug image saved: {debug_path}")
    return debug_path
|
| 238 |
+
|
| 239 |
+
def test_colorization(self, test_url=None):
    """
    Smoke-test the colorization pipeline against a URL.

    :param test_url: test image URL; defaults to the official ModelScope sample.
    :return: (success: bool, message: str); messages are user-facing Chinese.
    """
    if not self.is_available():
        return False, "DDColor模型未初始化"

    try:
        test_url = test_url or 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/audrey_hepburn.jpg'
        logger.info(f"Testing DDColor colorization feature, using image: {test_url}")

        result = self.colorizer(test_url)
        colorized_img = result[self.OutputKeys.OUTPUT_IMG]

        # Persist the result for manual inspection.
        test_output_path = 'ddcolor_test_result.webp'
        cv2.imwrite(test_output_path, colorized_img, [cv2.IMWRITE_WEBP_QUALITY, 95])

        logger.info(f"DDColor test successful, result saved to: {test_output_path}")
        return True, f"测试成功,结果保存到: {test_output_path}"

    except Exception as e:
        logger.error(f"DDColor test failed: {e}")
        return False, f"测试失败: {e}"
|
| 265 |
+
|
| 266 |
+
def test_local_image(self, image_path):
    """
    Colorize a local image and save intermediate debug files for comparison.

    :param image_path: path to the local image.
    :return: (success: bool, message: str); messages are user-facing Chinese.
    """
    if not self.is_available():
        return False, "DDColor模型未初始化"

    try:
        logger.info(f"Testing local image colorization: {image_path}")

        # Load the local image.
        image = cv2.imread(image_path)
        if image is None:
            return False, f"无法读取图像: {image_path}"

        # Report the grayscale-detection verdict (informational only).
        is_gray = self.is_grayscale(image)
        logger.info(f"Local image grayscale detection result: {is_gray}")

        # Keep the original for side-by-side comparison.
        self.save_debug_image(image, "original")

        # Colorize unconditionally.
        colorized_image = self.colorize_image_direct(image)

        # Persist the colorized result.
        result_path = self.save_debug_image(colorized_image, "local_colorized")

        logger.info(f"Local image colorization successful, result saved to: {result_path}")
        return True, f"本地图像上色成功,结果保存到: {result_path}"

    except Exception as e:
        logger.error(f"Local image colorization failed: {e}")
        return False, f"本地图像上色失败: {e}"
|
debug_colorize.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
调试上色效果差异的脚本
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
# 添加当前目录到路径
|
| 12 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 13 |
+
|
| 14 |
+
from ddcolor_colorizer import DDColorColorizer
|
| 15 |
+
from gfpgan_restorer import GFPGANRestorer
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
# 设置日志
|
| 19 |
+
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
|
| 20 |
+
|
| 21 |
+
def simulate_api_processing(image_path):
    """
    Replay the full API processing flow offline: init GFPGAN and DDColor,
    colorize-when-grayscale, then restore, saving debug images at each stage.

    :param image_path: path to the test image.
    :return: dict with 'original', 'colorized' (or None), 'final', 'strategy';
        None when initialization or image loading fails.
    """
    print("\n=== 模拟API接口处理流程 ===")

    # Initialize the restoration component.
    print("初始化GFPGAN修复器...")
    try:
        gfpgan_restorer = GFPGANRestorer()
        if not gfpgan_restorer.is_available():
            print("❌ GFPGAN不可用")
            return None
        print("✅ GFPGAN初始化成功")
    except Exception as e:
        print(f"❌ GFPGAN初始化失败: {e}")
        return None

    # Initialize the colorization component.
    print("初始化DDColor上色器...")
    try:
        ddcolor_colorizer = DDColorColorizer()
        if not ddcolor_colorizer.is_available():
            print("❌ DDColor不可用")
            return None
        print("✅ DDColor初始化成功")
    except Exception as e:
        print(f"❌ DDColor初始化失败: {e}")
        return None

    # Load the input image.
    print(f"读取图像: {image_path}")
    image = cv2.imread(image_path)
    if image is None:
        print(f"❌ 无法读取图像: {image_path}")
        return None

    print(f"原图尺寸: {image.shape}")

    # Save the original for comparison.
    ddcolor_colorizer.save_debug_image(image, "api_original")

    # Record the grayscale verdict for the original.
    original_is_grayscale = ddcolor_colorizer.is_grayscale(image)
    print(f"原图灰度检测: {original_is_grayscale}")

    # New pipeline order: colorize first, restore second.
    # Step 1: colorization.
    print("\n步骤1: 上色处理...")
    if original_is_grayscale:
        print("策略: 对原图进行上色")
        colorized_image = ddcolor_colorizer.colorize_image_direct(image)
        ddcolor_colorizer.save_debug_image(colorized_image, "api_colorized")
        strategy = "先上色"
        current_image = colorized_image
    else:
        print("策略: 图像已经是彩色的,跳过上色")
        strategy = "跳过上色"
        current_image = image

    # Step 2: GFPGAN restoration.
    print("\n步骤2: GFPGAN修复...")
    final_image = gfpgan_restorer.restore_image(current_image)
    print(f"修复后图像尺寸: {final_image.shape}")

    # Save the final result.
    result_path = ddcolor_colorizer.save_debug_image(final_image, "api_final")

    strategy += " -> 再修复"

    print(f"\n✅ API模拟完成")
    print(f"  - 处理策略: {strategy}")
    print(f"  - 最终结果: {result_path}")

    return {
        'original': image,
        'colorized': colorized_image if original_is_grayscale else None,
        'final': final_image,
        'strategy': strategy
    }
|
| 100 |
+
|
| 101 |
+
def test_direct_colorization(image_path):
    """
    Exercise direct colorization (the same path as test_ddcolor.py): first
    the official sample URL, then the given local image.

    :param image_path: local image path; may be None/nonexistent, in which
        case the local test simply reports failure.
    """
    print("\n=== 测试直接上色 ===")

    colorizer = DDColorColorizer()
    if not colorizer.is_available():
        print("❌ DDColor不可用")
        return None

    # URL colorization, identical to the test_ddcolor.py flow.
    print("使用官方示例URL上色...")
    success, message = colorizer.test_colorization()

    if success:
        print(f"✅ URL上色成功: {message}")
    else:
        print(f"❌ URL上色失败: {message}")

    # Direct colorization of the local image.
    print(f"对本地图像直接上色: {image_path}")
    success, message = colorizer.test_local_image(image_path)

    if success:
        print(f"✅ 本地图像上色成功: {message}")
    else:
        print(f"❌ 本地图像上色失败: {message}")
|
| 129 |
+
|
| 130 |
+
def compare_results():
    """List the generated debug images and print manual comparison hints."""
    print("\n=== 结果对比分析 ===")

    # Collect every '*_debug.webp' produced by earlier steps.
    debug_files = [f for f in os.listdir(".") if f.endswith("_debug.webp")]

    if not debug_files:
        print("未找到调试文件")
        return

    print("生成的调试文件:")
    for f in sorted(debug_files):
        print(f"  - {f}")

    print("\n对比建议:")
    print("1. 比较 original_debug.webp 和 api_original_debug.webp")
    print("2. 比较 local_colorized_debug.webp 和 api_final_debug.webp")
    print("3. 检查 api_restored_debug.webp 的修复效果")
    print("4. 观察 ddcolor_test_result.webp 的官方示例效果")
|
| 154 |
+
|
| 155 |
+
def analyze_image_quality(image_path):
    """
    Print basic quality metrics for the image at *image_path*:
    size, brightness, contrast, saturation, and sharpness.
    """
    print(f"\n=== 分析图像质量: {image_path} ===")

    if not os.path.exists(image_path):
        print(f"文件不存在: {image_path}")
        return

    image = cv2.imread(image_path)
    if image is None:
        print(f"无法读取图像: {image_path}")
        return

    # Basic info. NOTE(review): assumes a 3-channel result, which holds for
    # cv2.imread's default color flag.
    h, w, c = image.shape
    print(f"尺寸: {w}x{h}, 通道数: {c}")

    # Brightness: mean of the grayscale image.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    mean_brightness = np.mean(gray)
    print(f"平均亮度: {mean_brightness:.2f}")

    # Contrast: standard deviation of the grayscale image.
    contrast = np.std(gray)
    print(f"对比度(标准差): {contrast:.2f}")

    # Colorfulness: mean HSV saturation.
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    mean_saturation = np.mean(hsv[:, :, 1])
    print(f"平均饱和度: {mean_saturation:.2f}")

    # Sharpness: variance of the Laplacian response.
    laplacian = cv2.Laplacian(gray, cv2.CV_64F)
    sharpness = np.var(laplacian)
    print(f"锐度: {sharpness:.2f}")
|
| 192 |
+
|
| 193 |
+
def main():
    """Entry point: run the colorization debug workflow end to end."""
    print("DDColor 上色效果调试工具")
    print("=" * 60)

    # Test-image path: overridable via argv[1]; the default is a placeholder.
    test_image_path = "/path/to/your/test/image.jpg"  # replace with a real path

    if len(sys.argv) > 1:
        test_image_path = sys.argv[1]

    print(f"测试图像路径: {test_image_path}")

    if not os.path.exists(test_image_path):
        print("⚠️ 测试图像不存在,将只运行URL测试")

        # URL-only test (the local-image step will report failure for None).
        test_direct_colorization(None)

    else:
        # Quality metrics of the original.
        analyze_image_quality(test_image_path)

        # Direct colorization test.
        test_direct_colorization(test_image_path)

        # Simulated API pipeline.
        api_result = simulate_api_processing(test_image_path)

        # Quality metrics of the results, when present.
        if os.path.exists("api_final_debug.webp"):
            print("\n--- API处理结果质量分析 ---")
            analyze_image_quality("api_final_debug.webp")

        if os.path.exists("local_colorized_debug.webp"):
            print("\n--- 直接上色结果质量分析 ---")
            analyze_image_quality("local_colorized_debug.webp")

        # Side-by-side summary.
        compare_results()

    print("\n调试完成!")
    print("请检查生成的调试图像来识别问题所在。")
| 237 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
face_analyzer.py
ADDED
|
@@ -0,0 +1,1099 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import time
|
| 4 |
+
from typing import List, Dict, Any
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
import config
|
| 10 |
+
from config import logger, MODELS_PATH, OUTPUT_DIR, DEEPFACE_AVAILABLE, \
|
| 11 |
+
YOLO_AVAILABLE
|
| 12 |
+
from facial_analyzer import FacialFeatureAnalyzer
|
| 13 |
+
from models import ModelType
|
| 14 |
+
from utils import save_image_high_quality
|
| 15 |
+
|
| 16 |
+
# DeepFace is optional; config detected availability at import time.
if DEEPFACE_AVAILABLE:
    from deepface import DeepFace

# Optional YOLO import.
if YOLO_AVAILABLE:
    try:
        from ultralytics import YOLO

        # NOTE(review): redundant — YOLO_AVAILABLE is already True to reach
        # this branch; kept for symmetry with the except path.
        YOLO_AVAILABLE = True
    except ImportError:
        # Downgrade gracefully when ultralytics is missing despite the flag.
        YOLO_AVAILABLE = False
        YOLO = None
        print("Warning: ENABLE_YOLO=true but ultralytics not available")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class EnhancedFaceAnalyzer:
|
| 32 |
+
"""增强版人脸分析器 - 支持混合模型"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, models_dir: str = MODELS_PATH):
    """
    Initialize the analyzer: facial-feature analyzer, HowCuteAmI nets, and
    the YOLO face detector; optionally warms the models up.

    :param models_dir: directory containing the model files.
    """
    start_time = time.perf_counter()
    self.models_dir = models_dir
    # Mean BGR values used to normalize inputs for the Caffe nets.
    self.MODEL_MEAN_VALUES = (104, 117, 123)
    # Age buckets emitted by age_googlenet, in output-index order.
    self.age_list = [
        "(0-2)",
        "(4-6)",
        "(8-12)",
        "(15-20)",
        "(25-32)",
        "(38-43)",
        "(48-53)",
        "(60-100)",
    ]
    self.gender_list = ["Male", "Female"]
    # Per-gender annotation colors (BGR format).
    self.gender_colors = {
        "Male": (255, 165, 0),  # orange
        "Female": (255, 0, 255),  # magenta / fuchsia
    }

    # Facial-feature analyzer.
    self.facial_analyzer = FacialFeatureAnalyzer()
    # HowCuteAmI model bundle (face/age/gender/beauty nets).
    self._load_howcuteami_models()
    # YOLO face-detection model.
    self._load_yolo_model()

    # Optional model warm-up, gated by config.ENABLE_WARMUP.
    if getattr(config, "ENABLE_WARMUP", False):
        self._warmup_models()

    init_time = time.perf_counter() - start_time
    logger.info(f"EnhancedFaceAnalyzer initialized successfully, time: {init_time:.3f}s")
|
| 72 |
+
|
| 73 |
+
def _cap_conf(self, value: float) -> float:
|
| 74 |
+
"""将置信度限制在 [0, 0.9999] 并保留4位小数。"""
|
| 75 |
+
try:
|
| 76 |
+
v = float(value if value is not None else 0.0)
|
| 77 |
+
except Exception:
|
| 78 |
+
v = 0.0
|
| 79 |
+
if v >= 1.0:
|
| 80 |
+
v = 0.9999
|
| 81 |
+
if v < 0.0:
|
| 82 |
+
v = 0.0
|
| 83 |
+
return round(v, 4)
|
| 84 |
+
|
| 85 |
+
def _adjust_beauty_score(self, score: float) -> float:
    """
    Gently lift beauty scores inside a configured band toward its upper bound.

    Scores below BEAUTY_ADJUST_MIN and at/above the upper bound are returned
    unchanged; scores inside [low, high) become high - gamma * (high - score),
    clamped to [0, 10] and rounded to one decimal. Any configuration or
    processing error leaves the score untouched.
    """
    try:
        if not config.BEAUTY_ADJUST_ENABLED:
            return score
        # Adjustment band and strength from config (with fallbacks).
        low = float(getattr(config, "BEAUTY_ADJUST_MIN", 6.0))
        high = float(getattr(config, "BEAUTY_ADJUST_MAX", getattr(config, "BEAUTY_ADJUST_THRESHOLD", 8.0)))
        gamma = float(getattr(config, "BEAUTY_ADJUST_GAMMA", 0.3))
        gamma = max(0.0001, min(1.0, gamma))

        # Guard against a misconfigured band.
        if not (0.0 <= low < high <= 10.0):
            return score

        # Below the band: unchanged. Inside: pulled toward high. Above: unchanged.
        if score < low:
            return score
        if score < high:
            # Soft approach to the upper bound: adjusted = high - gamma * (high - score)
            adjusted = high - gamma * (high - score)
            adjusted = round(min(10.0, max(0.0, adjusted)), 1)
            try:
                logger.info(
                    f"beauty_score adjusted: original={score:.1f} -> adjusted={adjusted:.1f} "
                    f"(range=[{low:.1f},{high:.1f}], gamma={gamma:.3f})"
                )
            except Exception:
                pass
            return adjusted
        return score
    except Exception:
        return score
|
| 117 |
+
|
| 118 |
+
def _load_yolo_model(self):
    """Load the YOLO face-detection model: local file first, else download.

    Leaves self.yolo_model as None when ultralytics is unavailable or every
    load attempt fails; callers must handle that case.
    """
    self.yolo_model = None
    if config.YOLO_AVAILABLE:
        try:
            # Prefer the local YOLO face model.
            yolo_face_path = os.path.join(self.models_dir, config.YOLO_MODEL)

            if os.path.exists(yolo_face_path):
                self.yolo_model = YOLO(yolo_face_path)
                logger.info(f"Local YOLO face model loaded successfully: {yolo_face_path}")
            else:
                # No local copy: let ultralytics download on first use.
                logger.info("Local YOLO face model does not exist, attempting to download...")
                try:
                    # NOTE(review): this is a yolov11 face model, yet the log
                    # below says "YOLOv8 general model" — confirm which model
                    # is actually intended here.
                    model_name = "yolov11n-face.pt"
                    self.yolo_model = YOLO(model_name)
                    logger.info(
                        f"YOLOv8 general model loaded successfully (detecting 'person' class as face regions)"
                    )
                except Exception as e:
                    logger.warning(f"YOLOv model download failed: {e}")

        except Exception as e:
            logger.error(f"YOLOv model loading failed: {e}")
    else:
        logger.warning("ultralytics not installed, cannot use YOLOv")
|
| 146 |
+
|
| 147 |
+
def _load_howcuteami_models(self):
    """Load the HowCuteAmI model bundle into OpenCV DNN nets.

    Loads face detection, age, gender, and beauty networks from
    self.models_dir. Re-raises on any failure: these nets are mandatory.
    """
    try:
        # Face-detection model (TensorFlow frozen graph).
        face_proto = os.path.join(self.models_dir, "opencv_face_detector.pbtxt")
        face_model = os.path.join(self.models_dir, "opencv_face_detector_uint8.pb")
        self.face_net = cv2.dnn.readNet(face_model, face_proto)

        # Age-prediction model (Caffe).
        age_proto = os.path.join(self.models_dir, "age_googlenet.prototxt")
        age_model = os.path.join(self.models_dir, "age_googlenet.caffemodel")
        self.age_net = cv2.dnn.readNet(age_model, age_proto)

        # Gender-prediction model (Caffe).
        gender_proto = os.path.join(self.models_dir, "gender_googlenet.prototxt")
        gender_model = os.path.join(self.models_dir, "gender_googlenet.caffemodel")
        self.gender_net = cv2.dnn.readNet(gender_model, gender_proto)

        # Beauty-score model (Caffe).
        beauty_proto = os.path.join(self.models_dir, "beauty_resnet.prototxt")
        beauty_model = os.path.join(self.models_dir, "beauty_resnet.caffemodel")
        self.beauty_net = cv2.dnn.readNet(beauty_model, beauty_proto)

        logger.info("HowCuteAmI model loaded successfully!")

    except Exception as e:
        logger.error(f"HowCuteAmI model loading failed: {e}")
        raise e
|
| 175 |
+
|
| 176 |
+
# 人脸检测方法
|
| 177 |
+
def _detect_faces(
    self, frame: np.ndarray, conf_threshold: float = config.FACE_CONFIDENCE
) -> List[List[int]]:
    """
    Detect faces with YOLO, falling back to the OpenCV DNN detector.

    Fix: the fallback now also runs when no YOLO model is loaded or when
    YOLO returns zero usable boxes — previously both cases silently
    returned an empty list even though a working fallback detector exists.

    :param frame: BGR image (numpy array).
    :param conf_threshold: minimum detection confidence.
    :return: list of [x1, y1, x2, y2] face boxes (scaled via _scale_box).
    """
    face_boxes: List[List[int]] = []
    if self.yolo_model is not None:
        try:
            results = self.yolo_model(frame, conf=conf_threshold, verbose=False)
            frame_height, frame_width = frame.shape[:2]
            for result in results:
                boxes = result.boxes
                if boxes is None:
                    continue
                for box in boxes:
                    # Class id: 0 for dedicated face models; for a generic
                    # model, 'person' is also class 0.
                    class_id = int(box.cls[0])
                    # Box coordinates in xyxy format.
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
                    confidence = float(box.conf[0])
                    logger.info(
                        f"detect class_id={class_id}, confidence={confidence}"
                    )
                    # Clip to image bounds.
                    x1 = max(0, int(x1))
                    y1 = max(0, int(y1))
                    x2 = min(frame_width, int(x2))
                    y2 = min(frame_height, int(y2))

                    # Drop tiny boxes; YOLO is precise enough to raise the
                    # minimum size versus the DNN fallback.
                    width, height = x2 - x1, y2 - y1
                    if width > 30 and height > 30:
                        # With a generic model, further filter for head-like
                        # regions before accepting the box.
                        if self._is_likely_face_region(x1, y1, x2, y2, frame):
                            face_boxes.append(self._scale_box([x1, y1, x2, y2]))
            logger.info(
                f"YOLO detected {len(face_boxes)} faces, conf_threshold={conf_threshold}"
            )
            if face_boxes:  # YOLO succeeded: use its boxes.
                return face_boxes
        except Exception as e:
            logger.warning(f"YOLO detection failed, falling back to OpenCV DNN: {e}")
            return self._detect_faces_opencv_fallback(frame, conf_threshold)

    # Fallback: no YOLO model loaded, or YOLO found no faces.
    return self._detect_faces_opencv_fallback(frame, conf_threshold)
|
| 226 |
+
|
| 227 |
+
def _is_likely_face_region(
|
| 228 |
+
self, x1: int, y1: int, x2: int, y2: int, frame: np.ndarray
|
| 229 |
+
) -> bool:
|
| 230 |
+
"""
|
| 231 |
+
判断检测区域是否可能是人脸区域(当使用通用YOLO模型时)
|
| 232 |
+
"""
|
| 233 |
+
width, height = x2 - x1, y2 - y1
|
| 234 |
+
|
| 235 |
+
# 长宽比检查 - 人脸/头部通常接近正方形
|
| 236 |
+
aspect_ratio = width / height
|
| 237 |
+
if not (0.6 <= aspect_ratio <= 1.6):
|
| 238 |
+
return False
|
| 239 |
+
|
| 240 |
+
# 位置检查 - 人脸通常在图像上半部分(简单启发式)
|
| 241 |
+
frame_height = frame.shape[0]
|
| 242 |
+
center_y = (y1 + y2) / 2
|
| 243 |
+
if center_y > frame_height * 0.8: # 如果中心点在图像下方80%以下,可能不是人脸
|
| 244 |
+
return False
|
| 245 |
+
|
| 246 |
+
# 尺寸检查 - 不应该占据整个图像
|
| 247 |
+
frame_width, frame_height = frame.shape[1], frame.shape[0]
|
| 248 |
+
if width > frame_width * 0.8 or height > frame_height * 0.8:
|
| 249 |
+
return False
|
| 250 |
+
|
| 251 |
+
return True
|
| 252 |
+
|
| 253 |
+
def _detect_faces_opencv_fallback(
    self, frame: np.ndarray, conf_threshold: float = 0.5
) -> List[List[int]]:
    """
    Optimized OpenCV DNN face detection — multi-scale, tuned for small faces.

    Fix: the per-scale settings loop variable used to be named ``config``,
    shadowing the module-level ``config`` import used elsewhere in this
    class; renamed to ``det_cfg``. No behavioral change.

    :param frame: BGR image.
    :param conf_threshold: base confidence threshold (relaxed at larger scales).
    :return: list of square [x1, y1, x2, y2] boxes after NMS; empty if none.
    """
    frame_height, frame_width = frame.shape[:2]
    all_boxes = []

    # Multi-scale configs, small to large; larger input sizes use lower
    # thresholds so small faces are not filtered out.
    detection_configs = [
        {"size": (300, 300), "threshold": conf_threshold},
        {
            "size": (416, 416),
            "threshold": max(0.3, conf_threshold - 0.2),
        },  # lower threshold at larger scale
        {
            "size": (512, 512),
            "threshold": max(0.25, conf_threshold - 0.25),
        },  # even lower, to catch small faces
    ]
    logger.info(f"Detecting faces using opencv, conf_threshold={conf_threshold}")
    for det_cfg in detection_configs:
        try:
            # Contrast boost helps small-face detection.
            processed_frame = cv2.convertScaleAbs(frame, alpha=1.1, beta=10)

            blob = cv2.dnn.blobFromImage(
                processed_frame, 1.0, det_cfg["size"], [104, 117, 123], True, False
            )
            self.face_net.setInput(blob)
            detections = self.face_net.forward()

            # Collect raw detections above the per-scale threshold.
            for i in range(detections.shape[2]):
                confidence = detections[0, 0, i, 2]
                if confidence > det_cfg["threshold"]:
                    x1 = int(detections[0, 0, i, 3] * frame_width)
                    y1 = int(detections[0, 0, i, 4] * frame_height)
                    x2 = int(detections[0, 0, i, 5] * frame_width)
                    y2 = int(detections[0, 0, i, 6] * frame_height)

                    # Clamp to frame bounds.
                    x1, y1 = max(0, x1), max(0, y1)
                    x2, y2 = min(frame_width, x2), min(frame_height, y2)

                    # Drop boxes that are too small or implausibly large.
                    width, height = x2 - x1, y2 - y1
                    if (
                        width > 20
                        and height > 20
                        and width < frame_width * 0.8
                        and height < frame_height * 0.8
                    ):
                        # Faces are roughly square; allow some ellipticity.
                        aspect_ratio = width / height
                        if 0.6 <= aspect_ratio <= 1.8:
                            all_boxes.append(
                                {
                                    "box": [x1, y1, x2, y2],
                                    "confidence": confidence,
                                    "area": width * height,
                                }
                            )
        except Exception as e:
            logger.warning(f"Scale {det_cfg['size']} detection failed: {e}")
            continue

    # If nothing was found, retry once with relaxed conditions.
    if not all_boxes:
        logger.info("No faces detected, trying more relaxed detection conditions...")
        try:
            # Last attempt: lowest threshold + histogram equalization.
            enhanced_frame = cv2.equalizeHist(
                cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            )
            enhanced_frame = cv2.cvtColor(enhanced_frame, cv2.COLOR_GRAY2BGR)

            blob = cv2.dnn.blobFromImage(
                enhanced_frame, 1.0, (300, 300), [104, 117, 123], True, False
            )
            self.face_net.setInput(blob)
            detections = self.face_net.forward()

            for i in range(detections.shape[2]):
                confidence = detections[0, 0, i, 2]
                if confidence > 0.15:  # very low threshold
                    x1 = int(detections[0, 0, i, 3] * frame_width)
                    y1 = int(detections[0, 0, i, 4] * frame_height)
                    x2 = int(detections[0, 0, i, 5] * frame_width)
                    y2 = int(detections[0, 0, i, 6] * frame_height)

                    x1, y1 = max(0, x1), max(0, y1)
                    x2, y2 = min(frame_width, x2), min(frame_height, y2)

                    width, height = x2 - x1, y2 - y1
                    if width > 15 and height > 15:  # smaller minimum size
                        aspect_ratio = width / height
                        if 0.5 <= aspect_ratio <= 2.0:  # looser aspect ratio
                            all_boxes.append(
                                {
                                    "box": [x1, y1, x2, y2],
                                    "confidence": confidence,
                                    "area": width * height,
                                }
                            )
        except Exception as e:
            logger.warning(f"Relaxed condition detection also failed: {e}")

    # Non-maximum suppression to remove duplicate detections.
    if all_boxes:
        final_boxes = self._apply_nms(all_boxes, overlap_threshold=0.4)
        return [self._scale_box(box["box"]) for box in final_boxes]

    return []
|
| 368 |
+
|
| 369 |
+
def _apply_nms(
|
| 370 |
+
self, detections: List[Dict], overlap_threshold: float = 0.4
|
| 371 |
+
) -> List[Dict]:
|
| 372 |
+
"""
|
| 373 |
+
非极大值抑制,去除重复的检测框
|
| 374 |
+
"""
|
| 375 |
+
if not detections:
|
| 376 |
+
return []
|
| 377 |
+
|
| 378 |
+
# 按置信度排序
|
| 379 |
+
detections.sort(key=lambda x: x["confidence"], reverse=True)
|
| 380 |
+
|
| 381 |
+
keep = []
|
| 382 |
+
while detections:
|
| 383 |
+
# 保留置信度最高的
|
| 384 |
+
best = detections.pop(0)
|
| 385 |
+
keep.append(best)
|
| 386 |
+
|
| 387 |
+
# 移除与最佳检测重叠度高的其他检测
|
| 388 |
+
remaining = []
|
| 389 |
+
for det in detections:
|
| 390 |
+
if self._calculate_iou(best["box"], det["box"]) < overlap_threshold:
|
| 391 |
+
remaining.append(det)
|
| 392 |
+
detections = remaining
|
| 393 |
+
|
| 394 |
+
return keep
|
| 395 |
+
|
| 396 |
+
def _calculate_iou(self, box1: List[int], box2: List[int]) -> float:
|
| 397 |
+
"""
|
| 398 |
+
计算两个边界框的IoU (交并比)
|
| 399 |
+
"""
|
| 400 |
+
x1_1, y1_1, x2_1, y2_1 = box1
|
| 401 |
+
x1_2, y1_2, x2_2, y2_2 = box2
|
| 402 |
+
|
| 403 |
+
# 计算交集
|
| 404 |
+
x1_i = max(x1_1, x1_2)
|
| 405 |
+
y1_i = max(y1_1, y1_2)
|
| 406 |
+
x2_i = min(x2_1, x2_2)
|
| 407 |
+
y2_i = min(y2_1, y2_2)
|
| 408 |
+
|
| 409 |
+
if x2_i <= x1_i or y2_i <= y1_i:
|
| 410 |
+
return 0.0
|
| 411 |
+
|
| 412 |
+
intersection = (x2_i - x1_i) * (y2_i - y1_i)
|
| 413 |
+
|
| 414 |
+
# 计算并集
|
| 415 |
+
area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
|
| 416 |
+
area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
|
| 417 |
+
union = area1 + area2 - intersection
|
| 418 |
+
|
| 419 |
+
return intersection / union if union > 0 else 0.0
|
| 420 |
+
|
| 421 |
+
def _scale_box(self, box: List[int]) -> List[int]:
|
| 422 |
+
"""将矩形框缩放为正方形"""
|
| 423 |
+
width = box[2] - box[0]
|
| 424 |
+
height = box[3] - box[1]
|
| 425 |
+
maximum = max(width, height)
|
| 426 |
+
dx = int((maximum - width) / 2)
|
| 427 |
+
dy = int((maximum - height) / 2)
|
| 428 |
+
|
| 429 |
+
return [box[0] - dx, box[1] - dy, box[2] + dx, box[3] + dy]
|
| 430 |
+
|
| 431 |
+
def _crop_face(self, image: np.ndarray, box: List[int]) -> np.ndarray:
|
| 432 |
+
"""裁剪人脸区域"""
|
| 433 |
+
x1, y1, x2, y2 = box
|
| 434 |
+
h, w = image.shape[:2]
|
| 435 |
+
x1 = max(0, x1)
|
| 436 |
+
y1 = max(0, y1)
|
| 437 |
+
x2 = min(w, x2)
|
| 438 |
+
y2 = min(h, y2)
|
| 439 |
+
return image[y1:y2, x1:x2]
|
| 440 |
+
|
| 441 |
+
def _predict_beauty_gender_with_howcuteami(
    self, face: np.ndarray
) -> Dict[str, Any]:
    """Predict beauty score and gender (plus an age estimate) with the HowCuteAmI nets.

    :param face: BGR face crop; assumed already resized to 224x224 by the
        caller — TODO confirm (blobFromImage resizes anyway).
    :return: dict with age, gender, beauty fields and the model names used.
    :raises Exception: re-raises any inference failure after logging it.
    """
    try:
        blob = cv2.dnn.blobFromImage(
            face, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
        )

        # Gender prediction
        self.gender_net.setInput(blob)
        gender_preds = self.gender_net.forward()
        gender = self.gender_list[gender_preds[0].argmax()]
        gender_confidence = float(np.max(gender_preds[0]))
        gender_confidence = self._cap_conf(gender_confidence)
        # Age prediction (reuses the same blob)
        self.age_net.setInput(blob)
        age_preds = self.age_net.forward()
        age = self.age_list[age_preds[0].argmax()]
        age_confidence = float(np.max(age_preds[0]))
        # Beauty prediction — note the different 1/255 scale factor.
        blob_beauty = cv2.dnn.blobFromImage(
            face, 1.0 / 255, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
        )
        self.beauty_net.setInput(blob_beauty)
        beauty_preds = self.beauty_net.forward()
        # Beauty score = 2 * sum of outputs, clamped to [0, 10], then adjusted.
        beauty_score = round(float(2.0 * np.sum(beauty_preds[0])), 1)
        beauty_score = min(10.0, max(0.0, beauty_score))
        beauty_score = self._adjust_beauty_score(beauty_score)
        raw_score = float(np.sum(beauty_preds[0]))

        return {
            "age": age,
            "age_confidence": round(age_confidence, 4),
            "gender": gender,
            "gender_confidence": gender_confidence,
            "beauty_score": beauty_score,
            "beauty_raw_score": round(raw_score, 4),
            "age_model_used": "HowCuteAmI",
            "gender_model_used": "HowCuteAmI",
            "beauty_model_used": "HowCuteAmI",
        }
    except Exception as e:
        logger.error(f"HowCuteAmI beauty gender prediction failed: {e}")
        raise e
|
| 486 |
+
|
| 487 |
+
def _predict_age_emotion_with_deepface(
    self, face_image: np.ndarray
) -> Dict[str, Any]:
    """Predict age and emotion with DeepFace (also returns gender for fallback merging).

    Falls back to the HowCuteAmI age predictor when DeepFace is not
    installed or when its analysis raises.

    :param face_image: BGR face crop at its original resolution.
    :return: dict with age, age_confidence, emotion, emotion_analysis,
        gender and gender_confidence.
    :raises ValueError: if the face image is None or empty.
    """
    if not DEEPFACE_AVAILABLE:
        # DeepFace missing: use HowCuteAmI's age prediction as a fallback.
        return self._predict_age_with_howcuteami_fallback(face_image)

    if face_image is None or face_image.size == 0:
        raise ValueError("无效的人脸图像")

    try:
        # DeepFace analysis — detection skipped (image is already a crop),
        # progress bars and verbose output disabled.
        result = DeepFace.analyze(
            img_path=face_image,
            actions=["age", "emotion", "gender"],
            enforce_detection=False,
            detector_backend="skip",
            silent=True  # suppress progress output
        )

        # DeepFace may return either a list or a dict.
        if isinstance(result, list):
            result = result[0]

        # Extract fields, with defaults.
        age = result.get("age", 25)
        emotion = result.get("dominant_emotion", "neutral")
        emotion_scores = result.get("emotion", {})
        # Gender info (used as a fallback when HowCuteAmI confidence is low).
        deep_gender = result.get("dominant_gender", "Woman")
        deep_gender_conf = result.get("gender", {}).get(deep_gender, 50.0) / 100.0
        deep_gender_conf = self._cap_conf(deep_gender_conf)
        if str(deep_gender).lower() in ["woman", "female"]:
            deep_gender = "Female"
        else:
            deep_gender = "Male"

        # NOTE(review): the age confidence is synthesized from a random
        # range, not produced by the model.
        age_conf = round(random.uniform(0.7613, 0.9599), 4)
        return {
            "age": str(int(age)),
            "age_confidence": age_conf,
            "emotion": emotion,
            "emotion_analysis": emotion_scores,
            "gender": deep_gender,
            "gender_confidence": deep_gender_conf,
        }
    except Exception as e:
        logger.error(f"DeepFace age emotion prediction failed, falling back to HowCuteAmI: {e}")
        return self._predict_age_with_howcuteami_fallback(face_image)
|
| 537 |
+
|
| 538 |
+
def _predict_age_with_howcuteami_fallback(
    self, face_image: np.ndarray
) -> Dict[str, Any]:
    """HowCuteAmI-based age prediction, used as a fallback when DeepFace is unavailable.

    Emotion fields are filled with neutral defaults (HowCuteAmI has no
    emotion model). On any failure a fixed default result is returned.
    NOTE(review): unlike the DeepFace path, this result carries no
    gender/gender_confidence keys — callers use .get() and tolerate that.
    """
    try:
        if face_image is None or face_image.size == 0:
            raise ValueError("无法读取人脸图像")

        face_resized = cv2.resize(face_image, (224, 224))
        blob = cv2.dnn.blobFromImage(
            face_resized, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
        )

        # Age prediction
        self.age_net.setInput(blob)
        age_preds = self.age_net.forward()
        age = self.age_list[age_preds[0].argmax()]
        age_confidence = float(np.max(age_preds[0]))

        return {
            "age": age[1:-1],  # strip the surrounding brackets from the label
            "age_confidence": round(age_confidence, 4),
            "emotion": "neutral",  # default emotion
            "emotion_analysis": {"neutral": 100.0},  # default emotion analysis
        }
    except Exception as e:
        logger.error(f"HowCuteAmI age prediction fallback failed: {e}")
        return {
            "age": "25-32",
            "age_confidence": 0.5,
            "emotion": "neutral",
            "emotion_analysis": {"neutral": 100.0},
        }
|
| 571 |
+
|
| 572 |
+
def _predict_with_hybrid_model(
    self, face: np.ndarray, face_image: np.ndarray
) -> Dict[str, Any]:
    """Hybrid prediction: HowCuteAmI for beauty/gender; DeepFace for age and
    emotion when HowCuteAmI's age confidence falls below config.AGE_CONFIDENCE.

    Fix: the combined-gender merge below used to read ``age_emotion_result``
    unconditionally. In the high-confidence branch that name was never
    defined, so the merge died on a NameError silently swallowed by the
    blanket ``except Exception: pass``. The merge is now explicitly skipped
    when DeepFace was not consulted — observable behavior is unchanged, but
    the control flow no longer relies on a swallowed NameError.

    :param face: 224x224 face crop for the HowCuteAmI nets.
    :param face_image: original-resolution face crop for DeepFace.
    :return: merged prediction dict (gender, age, beauty, emotion, models used).
    """
    # Beauty + gender from HowCuteAmI (also yields an age estimate).
    beauty_gender_result = self._predict_beauty_gender_with_howcuteami(face)

    # HowCuteAmI's age/gender confidences.
    howcuteami_age_confidence = beauty_gender_result.get("age_confidence", 0)
    gender_confidence = beauty_gender_result.get("gender_confidence", 0)
    if gender_confidence >= 1:
        gender_confidence = 0.9999
    age = beauty_gender_result["age"]

    age_emotion_result = None  # populated only when DeepFace is consulted
    # If HowCuteAmI's age confidence is below threshold, use DeepFace's age.
    agec = config.AGE_CONFIDENCE
    if howcuteami_age_confidence < agec:
        # DeepFace for age/emotion (and optional gender fallback info).
        age_emotion_result = self._predict_age_emotion_with_deepface(
            face_image
        )
        deep_age = age_emotion_result["age"]
        logger.info(
            f"HowCuteAmI age confidence ({howcuteami_age_confidence}) below {agec}, value=({age}); using DeepFace for age prediction, value={deep_age}"
        )
        # Merged result, taking DeepFace's age prediction.
        result = {
            "gender": beauty_gender_result["gender"],  # HowCuteAmI first; may be merged below
            "gender_confidence": self._cap_conf(gender_confidence),
            "beauty_score": beauty_gender_result["beauty_score"],
            "beauty_raw_score": beauty_gender_result["beauty_raw_score"],
            "age": deep_age,
            "age_confidence": age_emotion_result["age_confidence"],
            "emotion": age_emotion_result["emotion"],
            "emotion_analysis": age_emotion_result["emotion_analysis"],
            "model_used": "hybrid_deepface_age",
            "age_model_used": "DeepFace",
            "gender_model_used": "HowCuteAmI",
        }
    else:
        # HowCuteAmI's age confidence is high enough; keep its age.
        logger.info(
            f"HowCuteAmI age confidence ({howcuteami_age_confidence}) is high enough, value={age}; using HowCuteAmI for age prediction"
        )
        result = {
            "gender": beauty_gender_result["gender"],  # HowCuteAmI first; may be merged below
            "gender_confidence": self._cap_conf(gender_confidence),
            "beauty_score": beauty_gender_result["beauty_score"],
            "beauty_raw_score": beauty_gender_result["beauty_raw_score"],
            "age": beauty_gender_result["age"],
            "age_confidence": beauty_gender_result["age_confidence"],
            "emotion": None,
            "emotion_analysis": None,
            "model_used": "hybrid",
            "age_model_used": "HowCuteAmI",
            "gender_model_used": "HowCuteAmI",
        }

    # Unified gender rule: Female if either model says Female; Male only if
    # both say Male. Only applicable when DeepFace actually ran.
    if age_emotion_result is not None:
        try:
            how_gender = beauty_gender_result.get("gender")
            how_conf = float(beauty_gender_result.get("gender_confidence", 0) or 0)
            deep_gender = age_emotion_result.get("gender")
            deep_conf = float(age_emotion_result.get("gender_confidence", 0) or 0)

            final_gender = result.get("gender")
            final_conf = float(result.get("gender_confidence", 0) or 0)
            if (str(how_gender) == "Female") or (str(deep_gender) == "Female"):
                final_gender = "Female"
                final_conf = max(how_conf if how_gender == "Female" else 0,
                                 deep_conf if deep_gender == "Female" else 0)
                result["gender_model_used"] = "Combined(H+DF)"
            elif (str(how_gender) == "Male") and (str(deep_gender) == "Male"):
                final_gender = "Male"
                final_conf = max(how_conf if how_gender == "Male" else 0,
                                 deep_conf if deep_gender == "Male" else 0)
                result["gender_model_used"] = "Combined(H+DF)"
            # Otherwise keep the original verdict.

            result["gender"] = final_gender
            result["gender_confidence"] = self._cap_conf(final_conf)
        except Exception:
            pass

    return result
|
| 659 |
+
|
| 660 |
+
def _predict_with_howcuteami(self, face: np.ndarray) -> Dict[str, Any]:
    """Full prediction (gender, age, beauty) using only the HowCuteAmI models.

    Emotion fields are filled with neutral defaults since HowCuteAmI has no
    emotion model.

    :param face: BGR face crop; assumed already resized to 224x224 by the
        caller — TODO confirm (blobFromImage resizes anyway).
    :raises Exception: re-raises any inference failure after logging it.
    """
    try:
        # Gender prediction
        blob = cv2.dnn.blobFromImage(
            face, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
        )
        self.gender_net.setInput(blob)
        gender_preds = self.gender_net.forward()
        gender = self.gender_list[gender_preds[0].argmax()]
        gender_confidence = float(np.max(gender_preds[0]))
        gender_confidence = self._cap_conf(gender_confidence)

        # Age prediction (reuses the same blob)
        self.age_net.setInput(blob)
        age_preds = self.age_net.forward()
        age = self.age_list[age_preds[0].argmax()]
        age_confidence = float(np.max(age_preds[0]))

        # Beauty prediction — note the different 1/255 scale factor.
        blob_beauty = cv2.dnn.blobFromImage(
            face, 1.0 / 255, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
        )
        self.beauty_net.setInput(blob_beauty)
        beauty_preds = self.beauty_net.forward()
        # Beauty score = 2 * sum of outputs, clamped to [0, 10], then adjusted.
        beauty_score = round(float(2.0 * np.sum(beauty_preds[0])), 1)
        beauty_score = min(10.0, max(0.0, beauty_score))
        beauty_score = self._adjust_beauty_score(beauty_score)
        raw_score = float(np.sum(beauty_preds[0]))

        return {
            "gender": gender,
            "gender_confidence": gender_confidence,
            "age": age[1:-1],  # strip the surrounding brackets from the label
            "age_confidence": round(age_confidence, 4),
            "beauty_score": beauty_score,
            "beauty_raw_score": round(raw_score, 4),
            "model_used": "HowCuteAmI",
            "emotion": "neutral",  # HowCuteAmI has no emotion analysis
            "emotion_analysis": {"neutral": 100.0},
            "age_model_used": "HowCuteAmI",
            "gender_model_used": "HowCuteAmI",
            "beauty_model_used": "HowCuteAmI",
        }
    except Exception as e:
        logger.error(f"HowCuteAmI prediction failed: {e}")
        raise e
|
| 707 |
+
|
| 708 |
+
def _predict_with_deepface(self, face_image: np.ndarray) -> Dict[str, Any]:
    """Full prediction using DeepFace (age, gender, emotion) plus a heuristic beauty score.

    :param face_image: BGR face crop at its original resolution.
    :raises ValueError: if DeepFace is not installed, or the image is None/empty.
    :raises Exception: re-raises DeepFace analysis failures after logging.
    """
    if not DEEPFACE_AVAILABLE:
        raise ValueError("DeepFace未安装")

    if face_image is None or face_image.size == 0:
        raise ValueError("无效的人脸图像")

    try:
        # DeepFace analysis — detection skipped (image is already a crop),
        # progress bars and verbose output disabled.
        result = DeepFace.analyze(
            img_path=face_image,
            actions=["age", "gender", "emotion"],
            enforce_detection=False,
            detector_backend="skip",
            silent=True  # suppress progress output
        )

        # DeepFace may return either a list or a dict.
        if isinstance(result, list):
            result = result[0]

        # Extract fields.
        age = result.get("age", 25)
        gender = result.get("dominant_gender", "Woman")
        gender_confidence = result.get("gender", {}).get(gender, 0.5) / 100
        gender_confidence = self._cap_conf(gender_confidence)

        # Normalize gender labels to Female/Male.
        if gender.lower() in ["woman", "female"]:
            gender = "Female"
        else:
            gender = "Male"

        # DeepFace has no built-in beauty score; use a simple heuristic
        # based on emotion and age.
        emotion = result.get("dominant_emotion", "neutral")
        emotion_scores = result.get("emotion", {})

        # Beauty estimate driven by happiness/neutrality and age.
        happiness_score = emotion_scores.get("happy", 0) / 100
        neutral_score = emotion_scores.get("neutral", 0) / 100

        # Simple beauty formula (could be improved).
        base_beauty = 6.0  # baseline score
        emotion_bonus = happiness_score * 2 + neutral_score * 1
        age_factor = max(0.5, 1 - abs(age - 25) / 50)  # 25 treated as peak age

        beauty_score = round(min(10.0, base_beauty + emotion_bonus + age_factor), 2)

        # NOTE(review): the age confidence is synthesized from a random
        # range, not produced by the model.
        age_conf = round(random.uniform(0.7613, 0.9599), 4)
        return {
            "gender": gender,
            "gender_confidence": gender_confidence,
            "age": str(int(age)),
            "age_confidence": age_conf,  # randomized DeepFace age "confidence"
            "beauty_score": beauty_score,
            "beauty_raw_score": round(beauty_score / 10, 4),
            "model_used": "DeepFace",
            "emotion": emotion,
            "emotion_analysis": emotion_scores,
            "age_model_used": "DeepFace",
            "gender_model_used": "DeepFace",
            "beauty_model_used": "Heuristic",
        }
    except Exception as e:
        logger.error(f"DeepFace prediction failed: {e}")
        raise e
|
| 775 |
+
|
| 776 |
+
def analyze_faces(
    self,
    image: np.ndarray,
    original_image_hash: str,
    model_type: ModelType = ModelType.HYBRID,
) -> Dict[str, Any]:
    """
    Analyze all faces in an image: detect, predict attributes, save crops,
    and draw annotations.

    :param image: input BGR image.
    :param original_image_hash: MD5 hash of the original image; used to name
        the cropped face files written to OUTPUT_DIR.
    :param model_type: which prediction model(s) to use.
    :return: dict with success flag, message, face_count, per-face results,
        and the annotated image (numpy array) under "annotated_image".
    :raises ValueError: if image is None.
    """
    if image is None:
        raise ValueError("无效的图像输入")

    # Face detection
    face_boxes = self._detect_faces(image)

    if not face_boxes:
        return {
            "success": False,
            "message": "请尝试上传清晰、无遮挡的正面照片",
            "face_count": 0,
            "faces": [],
            "annotated_image": None,
            "model_used": model_type.value,
        }

    results = {
        "success": True,
        "message": f"成功检测到 {len(face_boxes)} 张人脸",
        "face_count": len(face_boxes),
        "faces": [],
        "model_used": model_type.value,
    }

    # Copy the original image for annotation drawing.
    annotated_image = image.copy()
    logger.info(
        f"Input annotated_image shape: {annotated_image.shape}, dtype: {annotated_image.dtype}, ndim: {annotated_image.ndim}"
    )
    # Analyze each detected face.
    for i, face_box in enumerate(face_boxes):
        # Crop the face region.
        face_cropped = self._crop_face(image, face_box)
        if face_cropped.size == 0:
            logger.warning(f"Cropped face {i + 1} is empty, skipping.")
            continue

        face_resized = cv2.resize(face_cropped, (224, 224))
        face_for_deepface = face_cropped.copy()

        # Run the prediction selected by model_type.
        try:
            if model_type == ModelType.HYBRID:
                # Hybrid: HowCuteAmI for beauty/gender, DeepFace for age/emotion.
                prediction_result = self._predict_with_hybrid_model(
                    face_resized, face_for_deepface
                )
            elif model_type == ModelType.HOWCUTEAMI:
                prediction_result = self._predict_with_howcuteami(face_resized)
                # Non-hybrid mode still merges in DeepFace's gender:
                # Female wins if either model says Female; Male needs both.
                try:
                    age_emotion_result = self._predict_age_emotion_with_deepface(
                        face_for_deepface
                    )
                    how_gender = prediction_result.get("gender")
                    how_conf = float(prediction_result.get("gender_confidence", 0) or 0)
                    deep_gender = age_emotion_result.get("gender")
                    deep_conf = float(age_emotion_result.get("gender_confidence", 0) or 0)
                    final_gender = prediction_result.get("gender")
                    final_conf = float(prediction_result.get("gender_confidence", 0) or 0)
                    if (str(how_gender) == "Female") or (str(deep_gender) == "Female"):
                        final_gender = "Female"
                        final_conf = max(how_conf if how_gender == "Female" else 0,
                                         deep_conf if deep_gender == "Female" else 0)
                        prediction_result["gender_model_used"] = "Combined(H+DF)"
                    elif (str(how_gender) == "Male") and (str(deep_gender) == "Male"):
                        final_gender = "Male"
                        final_conf = max(how_conf if how_gender == "Male" else 0,
                                         deep_conf if deep_gender == "Male" else 0)
                        prediction_result["gender_model_used"] = "Combined(H+DF)"
                    prediction_result["gender"] = final_gender
                    prediction_result["gender_confidence"] = round(float(final_conf), 4)
                except Exception:
                    pass
            elif model_type == ModelType.DEEPFACE and DEEPFACE_AVAILABLE:
                prediction_result = self._predict_with_deepface(face_for_deepface)
                # Non-hybrid mode still merges in HowCuteAmI's gender
                # using the same Female-wins rule.
                try:
                    beauty_gender_result = self._predict_beauty_gender_with_howcuteami(
                        face_resized
                    )
                    deep_gender = prediction_result.get("gender")
                    deep_conf = float(prediction_result.get("gender_confidence", 0) or 0)
                    how_gender = beauty_gender_result.get("gender")
                    how_conf = float(beauty_gender_result.get("gender_confidence", 0) or 0)
                    final_gender = prediction_result.get("gender")
                    final_conf = float(prediction_result.get("gender_confidence", 0) or 0)
                    if (str(how_gender) == "Female") or (str(deep_gender) == "Female"):
                        final_gender = "Female"
                        final_conf = max(how_conf if how_gender == "Female" else 0,
                                         deep_conf if deep_gender == "Female" else 0)
                        prediction_result["gender_model_used"] = "Combined(H+DF)"
                    elif (str(how_gender) == "Male") and (str(deep_gender) == "Male"):
                        final_gender = "Male"
                        final_conf = max(how_conf if how_gender == "Male" else 0,
                                         deep_conf if deep_gender == "Male" else 0)
                        prediction_result["gender_model_used"] = "Combined(H+DF)"
                    prediction_result["gender"] = final_gender
                    prediction_result["gender_confidence"] = round(float(final_conf), 4)
                except Exception:
                    pass
            else:
                # Requested model unavailable: fall back to hybrid mode.
                prediction_result = self._predict_with_hybrid_model(
                    face_resized, face_for_deepface
                )
                logger.warning(f"Model {model_type.value} is not available, using hybrid mode")

        except Exception as e:
            # Any prediction failure yields neutral default values.
            logger.error(f"Prediction failed, using default values: {e}")
            prediction_result = {
                "gender": "Unknown",
                "gender_confidence": 0.5,
                "age": "25-32",
                "age_confidence": 0.5,
                "beauty_score": 5.0,
                "beauty_raw_score": 0.5,
                "emotion": "neutral",
                "emotion_analysis": {"neutral": 100.0},
                "model_used": "fallback",
            }

        # Facial-feature analysis (currently disabled)
        # facial_features = self.facial_analyzer.analyze_facial_features(
        #     face_cropped, face_box
        # )

        # Annotation color per gender; hex derived from BGR components.
        gender = prediction_result.get("gender", "Unknown")
        color_bgr = self.gender_colors.get(gender, (128, 128, 128))
        color_hex = f"#{color_bgr[2]:02x}{color_bgr[1]:02x}{color_bgr[0]:02x}"

        # Age text and the optional female age adjustment (applies to both
        # the drawn label and the returned data).
        raw_age_str = prediction_result.get("age", "Unknown")
        display_age_str = str(raw_age_str)
        age_adjusted_flag = False
        age_adjustment_value = int(getattr(config, "FEMALE_AGE_ADJUSTMENT", 0) or 0)
        age_adjustment_threshold = int(getattr(config, "FEMALE_AGE_ADJUSTMENT_THRESHOLD", 999) or 999)

        # Only adjust for females whose age reaches the threshold.
        try:
            # Supports both "25-32" range and plain "25" formats.
            if "-" in str(raw_age_str):
                age_num = int(str(raw_age_str).split("-")[0].strip("() "))
            else:
                age_num = int(str(raw_age_str).strip())

            if str(gender) == "Female" and age_num >= age_adjustment_threshold and age_adjustment_value > 0:
                adjusted_age = max(0, age_num - age_adjustment_value)
                display_age_str = str(adjusted_age)
                age_adjusted_flag = True
                try:
                    logger.info(f"Adjusted age for female (draw+data): {age_num} -> {adjusted_age}")
                except Exception:
                    pass
        except Exception:
            # Unparseable age: leave it as-is.
            pass

        # Save the cropped face to OUTPUT_DIR.
        cropped_face_filename = f"{original_image_hash}_face_{i + 1}.webp"
        cropped_face_path = os.path.join(OUTPUT_DIR, cropped_face_filename)
        try:
            save_image_high_quality(face_cropped, cropped_face_path)
            logger.info(f"cropped face: {cropped_face_path}")
        except Exception as e:
            logger.error(f"Failed to save cropped face {cropped_face_path}: {e}")
            cropped_face_filename = None

        # Draw the face box on the annotated image (if enabled).
        if config.DRAW_SCORE:
            cv2.rectangle(
                annotated_image,
                (face_box[0], face_box[1]),
                (face_box[2], face_box[3]),
                color_bgr,
                int(round(image.shape[0] / 400)),
                8,
            )

        # Label text: gender, (possibly adjusted) age, beauty score.
        beauty_score = prediction_result.get("beauty_score", 0)
        label = f"{gender}, {display_age_str}, {beauty_score}"

        font_scale = max(
            0.3, min(0.7, image.shape[0] / 800)
        )  # scale divisor 500->800; range 0.5-1.0 -> 0.3-0.7
        font_thickness = 2
        font = cv2.FONT_HERSHEY_SIMPLEX
        # Label position: above the box, or inside it when too close to the top.
        text_x = face_box[0]
        text_y = face_box[1] - 10 if face_box[1] - 10 > 20 else face_box[1] + 30

        # Measure the rendered text (width, height).
        (text_width, text_height), baseline = cv2.getTextSize(label, font, font_scale, font_thickness)

        # Filled background rectangle slightly larger than the text.
        background_tl = (text_x, text_y - text_height - baseline)  # top-left
        background_br = (text_x + text_width, text_y + baseline)  # bottom-right

        if config.DRAW_SCORE:
            cv2.rectangle(
                annotated_image,
                background_tl,
                background_br,
                color_bgr,  # label background uses the gender color
                thickness=-1  # filled
            )
            cv2.putText(
                annotated_image,
                label,
                (text_x, text_y),
                font,
                font_scale,
                (255, 255, 255),
                font_thickness,
                cv2.LINE_AA,
            )

        # Build the per-face result entry.
        face_result = {
            "face_id": i + 1,
            "gender": gender,
            "gender_confidence": prediction_result.get("gender_confidence", 0),
            "gender_model_used": prediction_result.get("gender_model_used", prediction_result.get("model_used", model_type.value)),
            "age": display_age_str,
            "age_confidence": prediction_result.get("age_confidence", 0),
            "age_model_used": prediction_result.get("age_model_used", prediction_result.get("model_used", model_type.value)),
            "beauty_score": prediction_result.get("beauty_score", 0),
            "beauty_raw_score": prediction_result.get("beauty_raw_score", 0),
            "emotion": prediction_result.get("emotion", "neutral"),
            "emotion_analysis": prediction_result.get("emotion_analysis", {}),
            # "facial_features": facial_features,  # facial-feature analysis (disabled)
            "bounding_box": {
                "x1": int(face_box[0]),
                "y1": int(face_box[1]),
                "x2": int(face_box[2]),
                "y2": int(face_box[3]),
            },
            "color": {
                "bgr": [int(color_bgr[0]), int(color_bgr[1]), int(color_bgr[2])],
                "hex": color_hex,
            },
            "cropped_face_filename": cropped_face_filename,
            "model_used": prediction_result.get("model_used", model_type.value),
        }

        if age_adjusted_flag:
            face_result["age_adjusted"] = True
            face_result["age_adjustment_value"] = int(age_adjustment_value)

        results["faces"].append(face_result)

    results["annotated_image"] = annotated_image
    return results
|
| 1044 |
+
|
| 1045 |
+
    def _warmup_models(self):
        """Warm up all models to reduce first-call latency.

        Runs a tiny synthetic image through DeepFace (when available) and
        through each OpenCV DNN net held on this instance (face, age,
        gender, beauty). All failures are logged and swallowed so warm-up
        can never break startup.
        """
        try:
            logger.info("Starting to warm up models...")

            # Small 64x64 mid-gray BGR test image
            test_image = np.ones((64, 64, 3), dtype=np.uint8) * 128

            # Warm up the DeepFace models (when available)
            if DEEPFACE_AVAILABLE:
                try:
                    import tempfile
                    with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as tmp_file:
                        cv2.imwrite(tmp_file.name, test_image, [cv2.IMWRITE_WEBP_QUALITY, 95])
                        # Warm up DeepFace with the minimal set of actions
                        DeepFace.analyze(
                            img_path=tmp_file.name,
                            actions=["age", "emotion", "gender"],
                            detector_backend="yolov8",
                            enforce_detection=False,
                            silent=True
                        )
                        os.unlink(tmp_file.name)
                    logger.info("DeepFace model warm-up completed")
                except Exception as e:
                    logger.warning(f"DeepFace model warm-up failed: {e}")

            # Warm up the OpenCV DNN models
            try:
                # Face detection net (300x300 input, mean subtraction)
                blob = cv2.dnn.blobFromImage(test_image, 1.0, (300, 300), (104, 117, 123))
                self.face_net.setInput(blob)
                self.face_net.forward()

                # Age prediction net (224x224 input)
                test_face = cv2.resize(test_image, (224, 224))
                blob = cv2.dnn.blobFromImage(test_face, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False)
                self.age_net.setInput(blob)
                self.age_net.forward()

                # Gender prediction net (reuses the same 224x224 blob)
                self.gender_net.setInput(blob)
                self.gender_net.forward()

                # Beauty scoring net
                self.beauty_net.setInput(blob)
                self.beauty_net.forward()

                logger.info("OpenCV DNN model warm-up completed")
            except Exception as e:
                logger.warning(f"OpenCV DNN model warm-up failed: {e}")

            logger.info("Model warm-up completed")
        except Exception as e:
            logger.warning(f"Error occurred during model warm-up: {e}")
|
facial_analyzer.py
ADDED
|
@@ -0,0 +1,912 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import traceback
|
| 2 |
+
from typing import List, Dict, Any
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
import config
|
| 8 |
+
from config import logger, DLIB_AVAILABLE
|
| 9 |
+
|
| 10 |
+
if DLIB_AVAILABLE:
|
| 11 |
+
import mediapipe as mp
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class FacialFeatureAnalyzer:
|
| 15 |
+
"""五官分析器"""
|
| 16 |
+
|
| 17 |
+
    def __init__(self):
        """Initialize the analyzer; lazily sets up MediaPipe Face Mesh.

        `self.face_mesh` stays None when DLIB_AVAILABLE is False or when
        MediaPipe fails to load, in which case callers fall back to
        `_basic_facial_analysis`.
        """
        # MediaPipe FaceMesh instance, or None when unavailable
        self.face_mesh = None
        if DLIB_AVAILABLE:
            try:
                # Initialize MediaPipe Face Mesh for single-face, static images
                mp_face_mesh = mp.solutions.face_mesh
                self.face_mesh = mp_face_mesh.FaceMesh(
                    static_image_mode=True,
                    max_num_faces=1,
                    refine_landmarks=True,
                    min_detection_confidence=0.5,
                    min_tracking_confidence=0.5
                )
                logger.info("MediaPipe face landmark detector loaded successfully")
            except Exception as e:
                logger.error(f"Failed to load MediaPipe model: {e}")
|
| 33 |
+
|
| 34 |
+
    def analyze_facial_features(
        self, face_image: np.ndarray, face_box: List[int]
    ) -> Dict[str, Any]:
        """
        Analyze facial features (eyes, nose, mouth, eyebrows, jawline).

        :param face_image: face crop in OpenCV BGR order (converted to RGB below)
        :param face_box: face bounding box [x1, y1, x2, y2] (currently unused here)
        :return: feature analysis result dict; falls back to
                 `_basic_facial_analysis` when landmarks are unavailable
        """
        if not DLIB_AVAILABLE or self.face_mesh is None:
            return self._basic_facial_analysis(face_image)

        try:
            # MediaPipe expects RGB input
            rgb_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)

            # Detect facial landmarks
            results = self.face_mesh.process(rgb_image)

            if not results.multi_face_landmarks:
                logger.warning("No facial landmarks detected")
                return self._basic_facial_analysis(face_image)

            # Landmarks of the first detected face
            face_landmarks = results.multi_face_landmarks[0]

            # Map MediaPipe's 468 mesh landmarks onto a dlib-style 68-point layout
            points = self._convert_mediapipe_to_dlib_format(face_landmarks, face_image.shape)

            return self._analyze_features_from_landmarks(points, face_image.shape)

        except Exception as e:
            logger.error(f"Facial feature analysis failed: {e}")
            traceback.print_exc()  # print full stack trace with exact line numbers
            return self._basic_facial_analysis(face_image)
|
| 69 |
+
|
| 70 |
+
    def _convert_mediapipe_to_dlib_format(self, face_landmarks, image_shape):
        """
        Map MediaPipe's 468 face-mesh landmarks onto a dlib-style 68-point list.

        :param face_landmarks: MediaPipe landmark list for one face
                               (normalized coordinates in [0, 1])
        :param image_shape: (height, width[, channels]) used to de-normalize
        :return: list of 68 (x, y) pixel tuples in dlib point order
        """
        h, w = image_shape[:2]

        # Hand-picked MediaPipe mesh indices for each dlib landmark slot,
        # based on the standard Face Mesh index layout.
        # NOTE(review): some slots reuse the same mesh index (e.g. 58/55 both
        # map to 318, 60/66 both to 17) — verify the mapping if precision matters.
        mediapipe_to_dlib_map = {
            # Face contour (0-16)
            0: 234,  # lowest chin point
            1: 132,  # lower right cheek
            2: 172,  # right cheek
            3: 136,  # upper right cheek
            4: 150,  # right cheekbone
            5: 149,  # right temple
            6: 176,  # right forehead edge
            7: 148,  # right forehead
            8: 152,  # forehead center
            9: 377,  # left forehead
            10: 400,  # left forehead edge
            11: 378,  # left temple
            12: 379,  # left cheekbone
            13: 365,  # upper left cheek
            14: 397,  # left cheek
            15: 361,  # lower left cheek
            16: 454,  # left side of chin

            # Right eyebrow (17-21)
            17: 70,  # outer end
            18: 63,
            19: 105,
            20: 66,
            21: 107,  # inner end

            # Left eyebrow (22-26)
            22: 336,  # inner end
            23: 296,
            24: 334,
            25: 293,
            26: 300,  # outer end

            # Nose bridge (27-30)
            27: 168,  # bridge top
            28: 8,
            29: 9,
            30: 10,  # bridge bottom

            # Nostrils / nose base (31-35)
            31: 151,  # right nostril wing
            32: 134,  # right nostril
            33: 2,  # nose tip
            34: 363,  # left nostril
            35: 378,  # left nostril wing

            # Right eye (36-41)
            36: 33,  # outer corner
            37: 7,  # upper lid
            38: 163,  # upper lid
            39: 144,  # inner corner
            40: 145,  # lower lid
            41: 153,  # lower lid

            # Left eye (42-47)
            42: 362,  # inner corner
            43: 382,  # upper lid
            44: 381,  # upper lid
            45: 380,  # outer corner
            46: 374,  # lower lid
            47: 373,  # lower lid

            # Mouth outline (48-67)
            48: 78,  # right mouth corner
            49: 95,  # right upper lip
            50: 88,  # upper lip, right side
            51: 178,  # upper lip, center-right
            52: 87,  # upper lip center
            53: 14,  # upper lip, center-left
            54: 317,  # upper lip, left side
            55: 318,  # left upper lip
            56: 308,  # left mouth corner
            57: 324,  # left lower lip
            58: 318,  # lower lip, left side
            59: 16,  # lower lip, center-left
            60: 17,  # lower lip center
            61: 18,  # lower lip, center-right
            62: 200,  # lower lip, right side
            63: 199,  # right lower lip
            64: 175,  # inner right mouth corner
            65: 84,  # inner upper lip, right
            66: 17,  # inner lower lip, center
            67: 314,  # inner upper lip, left
        }

        # Build the 68-point list, de-normalizing to pixel coordinates.
        points = []
        for i in range(68):
            if i in mediapipe_to_dlib_map:
                mp_idx = mediapipe_to_dlib_map[i]
                if mp_idx < len(face_landmarks.landmark):
                    landmark = face_landmarks.landmark[mp_idx]
                    x = int(landmark.x * w)
                    y = int(landmark.y * h)
                    points.append((x, y))
                else:
                    # Mesh index out of range: fall back to the image center
                    points.append((w//2, h//2))
            else:
                # No mapping for this slot: fall back to the image center
                points.append((w//2, h//2))

        return points
|
| 183 |
+
|
| 184 |
+
    def _analyze_features_from_landmarks(
        self, landmarks: List[tuple], image_shape: tuple
    ) -> Dict[str, Any]:
        """Score each facial region from a 68-point dlib-style landmark list.

        :param landmarks: 68 (x, y) pixel tuples in dlib order
        :param image_shape: (height, width[, channels]) of the face image
        :return: dict with per-feature scores, harmony score, overall score
        """
        try:
            # Landmark index ranges for each facial region (dlib convention)
            jawline = landmarks[0:17]  # jawline
            left_eyebrow = landmarks[17:22]  # left eyebrow
            right_eyebrow = landmarks[22:27]  # right eyebrow
            nose = landmarks[27:36]  # nose
            left_eye = landmarks[36:42]  # left eye
            right_eye = landmarks[42:48]  # right eye
            mouth = landmarks[48:68]  # mouth

            # Per-feature scores (simplified heuristics; see individual scorers)
            scores = {
                "eyes": self._score_eyes(left_eye, right_eye, image_shape),
                "nose": self._score_nose(nose, image_shape),
                "mouth": self._score_mouth(mouth, image_shape),
                "eyebrows": self._score_eyebrows(
                    left_eyebrow, right_eyebrow, image_shape
                ),
                "jawline": self._score_jawline(jawline, image_shape),
            }

            # Overall facial harmony
            harmony_score = self._calculate_harmony_new(landmarks, image_shape)
            # Gently lift the harmony score (same uplift strategy as beauty score)
            harmony_score = self._adjust_harmony_score(harmony_score)

            return {
                "facial_features": scores,
                "harmony_score": round(harmony_score, 2),
                "overall_facial_score": round(sum(scores.values()) / len(scores), 2),
                "analysis_method": "mediapipe_landmarks",
            }

        except Exception as e:
            logger.error(f"Landmark analysis failed: {e}")
            return self._basic_facial_analysis(None)
|
| 224 |
+
|
| 225 |
+
def _adjust_harmony_score(self, score: float) -> float:
|
| 226 |
+
"""整体协调性分值温和拉升:当低于阈值时往阈值靠拢一点。"""
|
| 227 |
+
try:
|
| 228 |
+
if not getattr(config, "HARMONY_ADJUST_ENABLED", False):
|
| 229 |
+
return round(float(score), 2)
|
| 230 |
+
thr = float(getattr(config, "HARMONY_ADJUST_THRESHOLD", 8.0))
|
| 231 |
+
gamma = float(getattr(config, "HARMONY_ADJUST_GAMMA", 0.5))
|
| 232 |
+
gamma = max(0.0001, min(1.0, gamma))
|
| 233 |
+
s = float(score)
|
| 234 |
+
if s < thr:
|
| 235 |
+
s = thr - gamma * (thr - s)
|
| 236 |
+
return round(min(10.0, max(0.0, s)), 2)
|
| 237 |
+
except Exception:
|
| 238 |
+
try:
|
| 239 |
+
return round(float(score), 2)
|
| 240 |
+
except Exception:
|
| 241 |
+
return 6.21
|
| 242 |
+
|
| 243 |
+
def _score_eyes(
|
| 244 |
+
self, left_eye: List[tuple], right_eye: List[tuple], image_shape: tuple
|
| 245 |
+
) -> float:
|
| 246 |
+
"""眼部评分"""
|
| 247 |
+
try:
|
| 248 |
+
# 计算眼部对称性和大小
|
| 249 |
+
left_width = abs(left_eye[3][0] - left_eye[0][0])
|
| 250 |
+
right_width = abs(right_eye[3][0] - right_eye[0][0])
|
| 251 |
+
|
| 252 |
+
# 计算眼部高度
|
| 253 |
+
left_height = abs(left_eye[1][1] - left_eye[5][1])
|
| 254 |
+
right_height = abs(right_eye[1][1] - right_eye[5][1])
|
| 255 |
+
|
| 256 |
+
# 对称性评分 - 宽度对称性
|
| 257 |
+
width_symmetry = 1 - min(
|
| 258 |
+
abs(left_width - right_width) / max(left_width, right_width), 0.5
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
# 高度对称性
|
| 262 |
+
height_symmetry = 1 - min(
|
| 263 |
+
abs(left_height - right_height) / max(left_height, right_height), 0.5
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# 大小适中性评分 (相对于脸部宽度) - 调整理想比例
|
| 267 |
+
avg_eye_width = (left_width + right_width) / 2
|
| 268 |
+
face_width = image_shape[1]
|
| 269 |
+
ideal_ratio = 0.08 # 调整理想比例,原来0.15太大
|
| 270 |
+
size_score = max(
|
| 271 |
+
0, 1 - abs(avg_eye_width / face_width - ideal_ratio) / ideal_ratio
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# 眼部长宽比评分
|
| 275 |
+
avg_eye_height = (left_height + right_height) / 2
|
| 276 |
+
aspect_ratio = avg_eye_width / max(avg_eye_height, 1) # 避免除零
|
| 277 |
+
ideal_aspect = 3.0 # 理想长宽比
|
| 278 |
+
aspect_score = max(0, 1 - abs(aspect_ratio - ideal_aspect) / ideal_aspect)
|
| 279 |
+
|
| 280 |
+
final_score = (
|
| 281 |
+
width_symmetry * 0.3
|
| 282 |
+
+ height_symmetry * 0.3
|
| 283 |
+
+ size_score * 0.25
|
| 284 |
+
+ aspect_score * 0.15
|
| 285 |
+
) * 10
|
| 286 |
+
return round(max(0, min(10, final_score)), 2)
|
| 287 |
+
except:
|
| 288 |
+
return 6.21
|
| 289 |
+
|
| 290 |
+
def _score_nose(self, nose: List[tuple], image_shape: tuple) -> float:
|
| 291 |
+
"""鼻部评分"""
|
| 292 |
+
try:
|
| 293 |
+
# 鼻子关键点
|
| 294 |
+
nose_tip = nose[3] # 鼻尖
|
| 295 |
+
nose_bridge_top = nose[0] # 鼻梁顶部
|
| 296 |
+
left_nostril = nose[1]
|
| 297 |
+
right_nostril = nose[5]
|
| 298 |
+
|
| 299 |
+
# 计算鼻子的直线度 (鼻梁是否挺直)
|
| 300 |
+
straightness = 1 - min(
|
| 301 |
+
abs(nose_tip[0] - nose_bridge_top[0]) / (image_shape[1] * 0.1), 1.0
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# 鼻宽评分 - 使用鼻翼宽度
|
| 305 |
+
nose_width = abs(right_nostril[0] - left_nostril[0])
|
| 306 |
+
face_width = image_shape[1]
|
| 307 |
+
ideal_nose_ratio = 0.06 # 调整理想比例
|
| 308 |
+
width_score = max(
|
| 309 |
+
0,
|
| 310 |
+
1 - abs(nose_width / face_width - ideal_nose_ratio) / ideal_nose_ratio,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# 鼻子长度评分
|
| 314 |
+
nose_length = abs(nose_tip[1] - nose_bridge_top[1])
|
| 315 |
+
face_height = image_shape[0]
|
| 316 |
+
ideal_length_ratio = 0.08
|
| 317 |
+
length_score = max(
|
| 318 |
+
0,
|
| 319 |
+
1
|
| 320 |
+
- abs(nose_length / face_height - ideal_length_ratio)
|
| 321 |
+
/ ideal_length_ratio,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
final_score = (
|
| 325 |
+
straightness * 0.4 + width_score * 0.35 + length_score * 0.25
|
| 326 |
+
) * 10
|
| 327 |
+
return round(max(0, min(10, final_score)), 2)
|
| 328 |
+
except:
|
| 329 |
+
return 6.21
|
| 330 |
+
|
| 331 |
+
def _score_mouth(self, mouth: List[tuple], image_shape: tuple) -> float:
|
| 332 |
+
"""嘴部评分 - 大幅优化,更宽松的评分标准"""
|
| 333 |
+
try:
|
| 334 |
+
# 嘴角点
|
| 335 |
+
left_corner = mouth[0] # 左嘴角
|
| 336 |
+
right_corner = mouth[6] # 右嘴角
|
| 337 |
+
|
| 338 |
+
# 上唇和下唇中心点
|
| 339 |
+
upper_lip_center = mouth[3] # 上唇中心
|
| 340 |
+
lower_lip_center = mouth[9] # 下唇中心
|
| 341 |
+
|
| 342 |
+
# 基础分数,避免过低
|
| 343 |
+
base_score = 6.0
|
| 344 |
+
|
| 345 |
+
# 1. 嘴宽评分 - 更宽松的标准
|
| 346 |
+
mouth_width = abs(right_corner[0] - left_corner[0])
|
| 347 |
+
face_width = image_shape[1]
|
| 348 |
+
mouth_ratio = mouth_width / face_width
|
| 349 |
+
|
| 350 |
+
# 设置更宽的合理范围 (0.04-0.15)
|
| 351 |
+
if 0.04 <= mouth_ratio <= 0.15:
|
| 352 |
+
width_score = 1.0 # 在合理范围内就给满分
|
| 353 |
+
elif mouth_ratio < 0.04:
|
| 354 |
+
width_score = max(0.3, mouth_ratio / 0.04) # 太小时渐减
|
| 355 |
+
else:
|
| 356 |
+
width_score = max(0.3, 0.15 / mouth_ratio) # 太大时渐减
|
| 357 |
+
|
| 358 |
+
# 2. 唇厚度评分 - 简化并放宽标准
|
| 359 |
+
lip_thickness = abs(lower_lip_center[1] - upper_lip_center[1])
|
| 360 |
+
# 只要厚度不是极端值就给高分
|
| 361 |
+
if lip_thickness > 3: # 像素值,有一定厚度
|
| 362 |
+
thickness_score = min(1.0, lip_thickness / 25) # 25像素为满分
|
| 363 |
+
else:
|
| 364 |
+
thickness_score = 0.5 # 太薄给中等分数
|
| 365 |
+
|
| 366 |
+
# 3. 嘴部对称性评分 - 更宽松
|
| 367 |
+
mouth_center_x = (left_corner[0] + right_corner[0]) / 2
|
| 368 |
+
face_center_x = image_shape[1] / 2
|
| 369 |
+
center_deviation = abs(mouth_center_x - face_center_x) / face_width
|
| 370 |
+
|
| 371 |
+
if center_deviation < 0.02: # 偏差小于2%
|
| 372 |
+
symmetry_score = 1.0
|
| 373 |
+
elif center_deviation < 0.05: # 偏差小于5%
|
| 374 |
+
symmetry_score = 0.8
|
| 375 |
+
else:
|
| 376 |
+
symmetry_score = max(0.5, 1 - center_deviation * 10) # 最低0.5分
|
| 377 |
+
|
| 378 |
+
# 4. 嘴唇形状评分 - 简化
|
| 379 |
+
# 检查嘴角是否在合理位置
|
| 380 |
+
corner_height_diff = abs(left_corner[1] - right_corner[1])
|
| 381 |
+
if corner_height_diff < face_width * 0.02: # 嘴角高度差异小
|
| 382 |
+
shape_score = 1.0
|
| 383 |
+
else:
|
| 384 |
+
shape_score = max(0.6, 1 - corner_height_diff / (face_width * 0.02))
|
| 385 |
+
|
| 386 |
+
# 5. 综合评分 - 调整权重,给基础分更大权重
|
| 387 |
+
feature_score = (
|
| 388 |
+
width_score * 0.3
|
| 389 |
+
+ thickness_score * 0.25
|
| 390 |
+
+ symmetry_score * 0.25
|
| 391 |
+
+ shape_score * 0.2
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
# 最终分数 = 基础分 + 特征分奖励
|
| 395 |
+
final_score = base_score + feature_score * 4 # 最高10分
|
| 396 |
+
|
| 397 |
+
return round(max(4.0, min(10, final_score)), 2) # 最低4分,最高10分
|
| 398 |
+
except Exception as e:
|
| 399 |
+
return 6.21
|
| 400 |
+
|
| 401 |
+
def _score_eyebrows(
|
| 402 |
+
self, left_brow: List[tuple], right_brow: List[tuple], image_shape: tuple
|
| 403 |
+
) -> float:
|
| 404 |
+
"""眉毛评分 - 改进算法"""
|
| 405 |
+
try:
|
| 406 |
+
# 计算眉毛长度
|
| 407 |
+
left_length = abs(left_brow[-1][0] - left_brow[0][0])
|
| 408 |
+
right_length = abs(right_brow[-1][0] - right_brow[0][0])
|
| 409 |
+
|
| 410 |
+
# 长度对称性
|
| 411 |
+
length_symmetry = 1 - min(
|
| 412 |
+
abs(left_length - right_length) / max(left_length, right_length), 0.5
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
# 计算眉毛拱形 - 改进方法
|
| 416 |
+
left_peak_y = min([p[1] for p in left_brow]) # 眉峰(y坐标最小)
|
| 417 |
+
left_ends_y = (left_brow[0][1] + left_brow[-1][1]) / 2 # 眉毛两端平均高度
|
| 418 |
+
left_arch = max(0, left_ends_y - left_peak_y) # 拱形高度
|
| 419 |
+
|
| 420 |
+
right_peak_y = min([p[1] for p in right_brow])
|
| 421 |
+
right_ends_y = (right_brow[0][1] + right_brow[-1][1]) / 2
|
| 422 |
+
right_arch = max(0, right_ends_y - right_peak_y)
|
| 423 |
+
|
| 424 |
+
# 拱形对称性
|
| 425 |
+
arch_symmetry = 1 - min(
|
| 426 |
+
abs(left_arch - right_arch) / max(left_arch, right_arch, 1), 0.5
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
# 眉形适中性评分
|
| 430 |
+
avg_arch = (left_arch + right_arch) / 2
|
| 431 |
+
face_height = image_shape[0]
|
| 432 |
+
ideal_arch_ratio = 0.015 # 理想拱形比例
|
| 433 |
+
arch_ratio = avg_arch / face_height
|
| 434 |
+
arch_score = max(
|
| 435 |
+
0, 1 - abs(arch_ratio - ideal_arch_ratio) / ideal_arch_ratio
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
# 眉毛浓密度(通过点的密集程度估算)
|
| 439 |
+
density_score = min(1.0, (len(left_brow) + len(right_brow)) / 10)
|
| 440 |
+
|
| 441 |
+
final_score = (
|
| 442 |
+
length_symmetry * 0.3
|
| 443 |
+
+ arch_symmetry * 0.3
|
| 444 |
+
+ arch_score * 0.25
|
| 445 |
+
+ density_score * 0.15
|
| 446 |
+
) * 10
|
| 447 |
+
return round(max(0, min(10, final_score)), 2)
|
| 448 |
+
except:
|
| 449 |
+
return 6.21
|
| 450 |
+
|
| 451 |
+
def _score_jawline(self, jawline: List[tuple], image_shape: tuple) -> float:
|
| 452 |
+
"""下颌线评分 - 改进算法"""
|
| 453 |
+
try:
|
| 454 |
+
jaw_points = [(p[0], p[1]) for p in jawline]
|
| 455 |
+
|
| 456 |
+
# 关键点
|
| 457 |
+
left_jaw = jaw_points[2] # 左下颌角
|
| 458 |
+
jaw_tip = jaw_points[8] # 下巴尖
|
| 459 |
+
right_jaw = jaw_points[14] # 右下颌角
|
| 460 |
+
|
| 461 |
+
# 对称性评分 - 改进计算
|
| 462 |
+
left_dist = (
|
| 463 |
+
(left_jaw[0] - jaw_tip[0]) ** 2 + (left_jaw[1] - jaw_tip[1]) ** 2
|
| 464 |
+
) ** 0.5
|
| 465 |
+
right_dist = (
|
| 466 |
+
(right_jaw[0] - jaw_tip[0]) ** 2 + (right_jaw[1] - jaw_tip[1]) ** 2
|
| 467 |
+
) ** 0.5
|
| 468 |
+
symmetry = 1 - min(
|
| 469 |
+
abs(left_dist - right_dist) / max(left_dist, right_dist), 0.5
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
# 下颌角度评分
|
| 473 |
+
left_angle_y = abs(left_jaw[1] - jaw_tip[1])
|
| 474 |
+
right_angle_y = abs(right_jaw[1] - jaw_tip[1])
|
| 475 |
+
avg_angle = (left_angle_y + right_angle_y) / 2
|
| 476 |
+
|
| 477 |
+
# 理想的下颌角度
|
| 478 |
+
face_height = image_shape[0]
|
| 479 |
+
ideal_angle_ratio = 0.08
|
| 480 |
+
angle_ratio = avg_angle / face_height
|
| 481 |
+
angle_score = max(
|
| 482 |
+
0, 1 - abs(angle_ratio - ideal_angle_ratio) / ideal_angle_ratio
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
# 下颌线清晰度(通过点间距离变化评估)
|
| 486 |
+
smoothness_score = 0.8 # 简化处理,可以根据实际需要改进
|
| 487 |
+
|
| 488 |
+
final_score = (
|
| 489 |
+
symmetry * 0.4 + angle_score * 0.35 + smoothness_score * 0.25
|
| 490 |
+
) * 10
|
| 491 |
+
return round(max(0, min(10, final_score)), 2)
|
| 492 |
+
except:
|
| 493 |
+
return 6.21
|
| 494 |
+
|
| 495 |
+
def _calculate_harmony(self, landmarks: List[tuple], image_shape: tuple) -> float:
|
| 496 |
+
"""计算五官协调性"""
|
| 497 |
+
try:
|
| 498 |
+
# 黄金比例检测 (简化版)
|
| 499 |
+
face_height = max([p[1] for p in landmarks]) - min(
|
| 500 |
+
[p[1] for p in landmarks]
|
| 501 |
+
)
|
| 502 |
+
face_width = max([p[0] for p in landmarks]) - min([p[0] for p in landmarks])
|
| 503 |
+
|
| 504 |
+
# 理想比例约为1.618
|
| 505 |
+
ratio = face_height / face_width if face_width > 0 else 1
|
| 506 |
+
golden_ratio = 1.618
|
| 507 |
+
harmony = 1 - abs(ratio - golden_ratio) / golden_ratio
|
| 508 |
+
|
| 509 |
+
return max(0, min(10, harmony * 10))
|
| 510 |
+
except:
|
| 511 |
+
return 6.21
|
| 512 |
+
|
| 513 |
+
    def _calculate_harmony_new(
        self, landmarks: List[tuple], image_shape: tuple
    ) -> float:
        """
        Compute overall facial harmony (0-10) — optimized version.

        Blends several aesthetic-ratio and symmetry metrics via a weighted
        sum. Weights in the code are 0.10 golden-ratio, 0.40 symmetry,
        0.05 classical proportions, 0 spacing (disabled), 0.05 contour,
        0.40 feature proportions. Returns 6.21 on any failure.
        """
        try:
            logger.info(f"face landmarks={len(landmarks)}")
            if len(landmarks) < 68:  # requires a dlib-style 68-point layout
                return 6.21

            # As a numpy array for vectorized math
            points = np.array(landmarks)

            # 1. Base facial measurements
            face_measurements = self._get_face_measurements(points)

            # 2. Individual harmony metrics, each paired with its weight
            scores = []

            # Golden-ratio score (weight 0.10)
            golden_score = self._calculate_golden_ratios(face_measurements)
            logger.info(f"Golden ratio score={golden_score}")
            scores.append(("golden_ratio", golden_score, 0.10))

            # Symmetry score (weight 0.40)
            symmetry_score = self._calculate_facial_symmetry(face_measurements, points)
            logger.info(f"Symmetry score={symmetry_score}")
            scores.append(("symmetry", symmetry_score, 0.40))

            # "Three courts, five eyes" classical proportions (weight 0.05)
            proportion_score = self._calculate_classical_proportions(face_measurements)
            logger.info(f"Three courts five eyes ratio={proportion_score}")
            scores.append(("proportions", proportion_score, 0.05))

            # Feature spacing harmony (weight 0 — currently disabled)
            spacing_score = self._calculate_feature_spacing(face_measurements)
            logger.info(f"Facial feature spacing harmony={spacing_score}")
            scores.append(("spacing", spacing_score, 0))

            # Facial contour harmony (weight 0.05)
            contour_score = self._calculate_contour_harmony(points)
            logger.info(f"Facial contour harmony={contour_score}")
            scores.append(("contour", contour_score, 0.05))

            # Eye/nose/mouth proportion harmony (weight 0.40)
            feature_score = self._calculate_feature_proportions(face_measurements)
            logger.info(f"Eye-nose-mouth proportion harmony={feature_score}")
            scores.append(("features", feature_score, 0.40))

            # Weighted sum of all metrics, clamped to [0, 10]
            final_score = sum(score * weight for _, score, weight in scores)
            logger.info(f"Weighted average final score={final_score}")
            return max(0, min(10, final_score))

        except Exception as e:
            logger.error(f"Error calculating facial harmony: {e}")
            traceback.print_exc()  # print full stack trace with exact line numbers
            return 6.21
|
| 573 |
+
|
| 574 |
+
    def _get_face_measurements(self, points: np.ndarray) -> Dict[str, float]:
        """Extract key facial measurements from a 68-point landmark array.

        :param points: array of 68 (x, y) landmark coordinates in dlib order
        :return: dict of scalar spans plus a few (x, y) center/tip points
                 (note: center/tip entries are arrays, not floats)
        """
        measurements = {}

        # Face contour points (0-16)
        face_contour = points[0:17]

        # Eyebrow points (17-26)
        left_eyebrow = points[17:22]
        right_eyebrow = points[22:27]

        # Eye points (36-47)
        left_eye = points[36:42]
        right_eye = points[42:48]

        # Nose points (27-35)
        nose = points[27:36]

        # Mouth points (48-67)
        mouth = points[48:68]

        # Base measurements (overall face extents)
        measurements["face_width"] = np.max(face_contour[:, 0]) - np.min(
            face_contour[:, 0]
        )
        measurements["face_height"] = np.max(points[:, 1]) - np.min(points[:, 1])

        # Eye measurements
        measurements["left_eye_width"] = np.max(left_eye[:, 0]) - np.min(left_eye[:, 0])
        measurements["right_eye_width"] = np.max(right_eye[:, 0]) - np.min(
            right_eye[:, 0]
        )
        measurements["eye_distance"] = np.min(right_eye[:, 0]) - np.max(left_eye[:, 0])
        measurements["left_eye_center"] = np.mean(left_eye, axis=0)
        measurements["right_eye_center"] = np.mean(right_eye, axis=0)

        # Nose measurements
        measurements["nose_width"] = np.max(nose[:, 0]) - np.min(nose[:, 0])
        measurements["nose_height"] = np.max(nose[:, 1]) - np.min(nose[:, 1])
        measurements["nose_tip"] = points[33]  # nose tip

        # Mouth measurements
        measurements["mouth_width"] = np.max(mouth[:, 0]) - np.min(mouth[:, 0])
        measurements["mouth_height"] = np.max(mouth[:, 1]) - np.min(mouth[:, 1])

        # Key vertical bands: forehead / mid-face / lower-face heights
        measurements["forehead_height"] = measurements["left_eye_center"][1] - np.min(
            points[:, 1]
        )
        measurements["middle_face_height"] = (
            measurements["nose_tip"][1] - measurements["left_eye_center"][1]
        )
        measurements["lower_face_height"] = (
            np.max(points[:, 1]) - measurements["nose_tip"][1]
        )

        return measurements
|
| 631 |
+
|
| 632 |
+
def _calculate_golden_ratios(self, measurements: Dict[str, float]) -> float:
|
| 633 |
+
"""计算黄金比例相关得分"""
|
| 634 |
+
golden_ratio = 1.618
|
| 635 |
+
scores = []
|
| 636 |
+
|
| 637 |
+
# 面部长宽比
|
| 638 |
+
if measurements["face_width"] > 0:
|
| 639 |
+
face_ratio = measurements["face_height"] / measurements["face_width"]
|
| 640 |
+
score = 1 - abs(face_ratio - golden_ratio) / golden_ratio
|
| 641 |
+
scores.append(max(0, score))
|
| 642 |
+
|
| 643 |
+
# 上中下三庭比例
|
| 644 |
+
total_height = (
|
| 645 |
+
measurements["forehead_height"]
|
| 646 |
+
+ measurements["middle_face_height"]
|
| 647 |
+
+ measurements["lower_face_height"]
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
if total_height > 0:
|
| 651 |
+
upper_ratio = measurements["forehead_height"] / total_height
|
| 652 |
+
middle_ratio = measurements["middle_face_height"] / total_height
|
| 653 |
+
lower_ratio = measurements["lower_face_height"] / total_height
|
| 654 |
+
|
| 655 |
+
# 理想比例约为 1:1:1
|
| 656 |
+
ideal_ratio = 1 / 3
|
| 657 |
+
upper_score = 1 - abs(upper_ratio - ideal_ratio) / ideal_ratio
|
| 658 |
+
middle_score = 1 - abs(middle_ratio - ideal_ratio) / ideal_ratio
|
| 659 |
+
lower_score = 1 - abs(lower_ratio - ideal_ratio) / ideal_ratio
|
| 660 |
+
|
| 661 |
+
scores.extend(
|
| 662 |
+
[max(0, upper_score), max(0, middle_score), max(0, lower_score)]
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
return np.mean(scores) * 10 if scores else 7.0
|
| 666 |
+
|
| 667 |
+
def _calculate_facial_symmetry(
|
| 668 |
+
self, measurements: Dict[str, float], points: np.ndarray
|
| 669 |
+
) -> float:
|
| 670 |
+
"""计算面部对称性"""
|
| 671 |
+
# 计算面部中线
|
| 672 |
+
face_center_x = np.mean(points[:, 0])
|
| 673 |
+
|
| 674 |
+
# 检查左右对称的关键点对
|
| 675 |
+
symmetry_pairs = [
|
| 676 |
+
(17, 26), # 眉毛外端
|
| 677 |
+
(18, 25), # 眉毛
|
| 678 |
+
(19, 24), # 眉毛
|
| 679 |
+
(36, 45), # 眼角
|
| 680 |
+
(39, 42), # 眼角
|
| 681 |
+
(31, 35), # 鼻翼
|
| 682 |
+
(48, 54), # 嘴角
|
| 683 |
+
(4, 12), # 面部轮廓
|
| 684 |
+
(5, 11), # 面部轮廓
|
| 685 |
+
(6, 10), # 面部轮廓
|
| 686 |
+
]
|
| 687 |
+
|
| 688 |
+
symmetry_scores = []
|
| 689 |
+
|
| 690 |
+
for left_idx, right_idx in symmetry_pairs:
|
| 691 |
+
if left_idx < len(points) and right_idx < len(points):
|
| 692 |
+
left_point = points[left_idx]
|
| 693 |
+
right_point = points[right_idx]
|
| 694 |
+
|
| 695 |
+
# 计算到中线的距离差异
|
| 696 |
+
left_dist = abs(left_point[0] - face_center_x)
|
| 697 |
+
right_dist = abs(right_point[0] - face_center_x)
|
| 698 |
+
|
| 699 |
+
# 垂直位置差异
|
| 700 |
+
vertical_diff = abs(left_point[1] - right_point[1])
|
| 701 |
+
|
| 702 |
+
# 对称性得分
|
| 703 |
+
if left_dist + right_dist > 0:
|
| 704 |
+
horizontal_symmetry = 1 - abs(left_dist - right_dist) / (
|
| 705 |
+
left_dist + right_dist
|
| 706 |
+
)
|
| 707 |
+
vertical_symmetry = 1 - vertical_diff / measurements.get(
|
| 708 |
+
"face_height", 100
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
symmetry_scores.append(
|
| 712 |
+
(horizontal_symmetry + vertical_symmetry) / 2
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
return np.mean(symmetry_scores) * 10 if symmetry_scores else 7.0
|
| 716 |
+
|
| 717 |
+
def _calculate_classical_proportions(self, measurements: Dict[str, float]) -> float:
|
| 718 |
+
"""计算经典美学比例 (三庭五眼等)"""
|
| 719 |
+
scores = []
|
| 720 |
+
|
| 721 |
+
# 五眼比例检测
|
| 722 |
+
if measurements["face_width"] > 0:
|
| 723 |
+
eye_width_avg = (
|
| 724 |
+
measurements["left_eye_width"] + measurements["right_eye_width"]
|
| 725 |
+
) / 2
|
| 726 |
+
ideal_eye_count = 5 # 理想情况下面宽应该等于5个眼宽
|
| 727 |
+
actual_eye_count = (
|
| 728 |
+
measurements["face_width"] / eye_width_avg if eye_width_avg > 0 else 5
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
eye_proportion_score = (
|
| 732 |
+
1 - abs(actual_eye_count - ideal_eye_count) / ideal_eye_count
|
| 733 |
+
)
|
| 734 |
+
scores.append(max(0, eye_proportion_score))
|
| 735 |
+
|
| 736 |
+
# 眼间距比例
|
| 737 |
+
if measurements.get("left_eye_width", 0) > 0:
|
| 738 |
+
eye_spacing_ratio = (
|
| 739 |
+
measurements["eye_distance"] / measurements["left_eye_width"]
|
| 740 |
+
)
|
| 741 |
+
ideal_spacing_ratio = 1.0 # 理想情况下眼间距约等于一个眼宽
|
| 742 |
+
|
| 743 |
+
spacing_score = (
|
| 744 |
+
1 - abs(eye_spacing_ratio - ideal_spacing_ratio) / ideal_spacing_ratio
|
| 745 |
+
)
|
| 746 |
+
scores.append(max(0, spacing_score))
|
| 747 |
+
|
| 748 |
+
# 鼻宽与眼宽比例
|
| 749 |
+
if (
|
| 750 |
+
measurements.get("left_eye_width", 0) > 0
|
| 751 |
+
and measurements.get("nose_width", 0) > 0
|
| 752 |
+
):
|
| 753 |
+
nose_eye_ratio = measurements["nose_width"] / measurements["left_eye_width"]
|
| 754 |
+
ideal_nose_eye_ratio = 0.8 # 理想鼻宽约为眼宽的80%
|
| 755 |
+
|
| 756 |
+
nose_score = (
|
| 757 |
+
1 - abs(nose_eye_ratio - ideal_nose_eye_ratio) / ideal_nose_eye_ratio
|
| 758 |
+
)
|
| 759 |
+
scores.append(max(0, nose_score))
|
| 760 |
+
|
| 761 |
+
return np.mean(scores) * 10 if scores else 7.0
|
| 762 |
+
|
| 763 |
+
def _calculate_feature_spacing(self, measurements: Dict[str, float]) -> float:
|
| 764 |
+
"""计算五官间距协调性"""
|
| 765 |
+
scores = []
|
| 766 |
+
|
| 767 |
+
# 眼鼻距离协调性
|
| 768 |
+
eye_nose_distance = abs(
|
| 769 |
+
measurements["left_eye_center"][1] - measurements["nose_tip"][1]
|
| 770 |
+
)
|
| 771 |
+
if measurements.get("face_height", 0) > 0:
|
| 772 |
+
eye_nose_ratio = eye_nose_distance / measurements["face_height"]
|
| 773 |
+
ideal_ratio = 0.15 # 理想比例
|
| 774 |
+
score = 1 - abs(eye_nose_ratio - ideal_ratio) / ideal_ratio
|
| 775 |
+
scores.append(max(0, score))
|
| 776 |
+
|
| 777 |
+
# 鼻嘴距离协调性
|
| 778 |
+
nose_mouth_distance = abs(
|
| 779 |
+
measurements["nose_tip"][1] - np.mean([measurements.get("mouth_height", 0)])
|
| 780 |
+
)
|
| 781 |
+
if measurements.get("face_height", 0) > 0:
|
| 782 |
+
nose_mouth_ratio = nose_mouth_distance / measurements["face_height"]
|
| 783 |
+
ideal_ratio = 0.12 # 理想比例
|
| 784 |
+
score = 1 - abs(nose_mouth_ratio - ideal_ratio) / ideal_ratio
|
| 785 |
+
scores.append(max(0, score))
|
| 786 |
+
|
| 787 |
+
return np.mean(scores) * 10 if scores else 7.0
|
| 788 |
+
|
| 789 |
+
def _calculate_contour_harmony(self, points: np.ndarray) -> float:
|
| 790 |
+
"""计算面部轮廓协调性"""
|
| 791 |
+
try:
|
| 792 |
+
face_contour = points[0:17] # 面部轮廓点
|
| 793 |
+
|
| 794 |
+
# 计算轮廓的平滑度
|
| 795 |
+
smoothness_scores = []
|
| 796 |
+
|
| 797 |
+
for i in range(1, len(face_contour) - 1):
|
| 798 |
+
# 计算相邻三点形成的角度
|
| 799 |
+
p1, p2, p3 = face_contour[i - 1], face_contour[i], face_contour[i + 1]
|
| 800 |
+
|
| 801 |
+
v1 = p1 - p2
|
| 802 |
+
v2 = p3 - p2
|
| 803 |
+
|
| 804 |
+
# 计算角度
|
| 805 |
+
cos_angle = np.dot(v1, v2) / (
|
| 806 |
+
np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-8
|
| 807 |
+
)
|
| 808 |
+
angle = np.arccos(np.clip(cos_angle, -1, 1))
|
| 809 |
+
|
| 810 |
+
# 角度越接近平滑曲线越好 (避免过于尖锐的角度)
|
| 811 |
+
smoothness = 1 - abs(angle - np.pi / 2) / (np.pi / 2)
|
| 812 |
+
smoothness_scores.append(max(0, smoothness))
|
| 813 |
+
|
| 814 |
+
return np.mean(smoothness_scores) * 10 if smoothness_scores else 7.0
|
| 815 |
+
|
| 816 |
+
except:
|
| 817 |
+
return 6.21
|
| 818 |
+
|
| 819 |
+
def _calculate_feature_proportions(self, measurements: Dict[str, float]) -> float:
|
| 820 |
+
"""计算眼鼻口等五官内部比例协调性"""
|
| 821 |
+
scores = []
|
| 822 |
+
|
| 823 |
+
# 眼部比例 (长宽比)
|
| 824 |
+
left_eye_ratio = measurements.get("left_eye_width", 1) / max(
|
| 825 |
+
measurements.get("left_eye_width", 1) * 0.3, 1
|
| 826 |
+
)
|
| 827 |
+
right_eye_ratio = measurements.get("right_eye_width", 1) / max(
|
| 828 |
+
measurements.get("right_eye_width", 1) * 0.3, 1
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
# 理想眼部长宽比约为3:1
|
| 832 |
+
ideal_eye_ratio = 3.0
|
| 833 |
+
left_eye_score = 1 - abs(left_eye_ratio - ideal_eye_ratio) / ideal_eye_ratio
|
| 834 |
+
right_eye_score = 1 - abs(right_eye_ratio - ideal_eye_ratio) / ideal_eye_ratio
|
| 835 |
+
|
| 836 |
+
scores.extend([max(0, left_eye_score), max(0, right_eye_score)])
|
| 837 |
+
|
| 838 |
+
# 嘴部比例
|
| 839 |
+
if measurements.get("mouth_height", 0) > 0:
|
| 840 |
+
mouth_ratio = measurements["mouth_width"] / measurements["mouth_height"]
|
| 841 |
+
ideal_mouth_ratio = 3.5 # 理想嘴部长宽比
|
| 842 |
+
mouth_score = 1 - abs(mouth_ratio - ideal_mouth_ratio) / ideal_mouth_ratio
|
| 843 |
+
scores.append(max(0, mouth_score))
|
| 844 |
+
|
| 845 |
+
# 鼻部比例
|
| 846 |
+
if measurements.get("nose_height", 0) > 0:
|
| 847 |
+
nose_ratio = measurements["nose_height"] / measurements["nose_width"]
|
| 848 |
+
ideal_nose_ratio = 1.5 # 理想鼻部长宽比
|
| 849 |
+
nose_score = 1 - abs(nose_ratio - ideal_nose_ratio) / ideal_nose_ratio
|
| 850 |
+
scores.append(max(0, nose_score))
|
| 851 |
+
|
| 852 |
+
return np.mean(scores) * 10 if scores else 7.0
|
| 853 |
+
|
| 854 |
+
def _basic_facial_analysis(self, face_image) -> Dict[str, Any]:
|
| 855 |
+
"""基础五官分析 (当dlib不可用时)"""
|
| 856 |
+
return {
|
| 857 |
+
"facial_features": {
|
| 858 |
+
"eyes": 7.0,
|
| 859 |
+
"nose": 7.0,
|
| 860 |
+
"mouth": 7.0,
|
| 861 |
+
"eyebrows": 7.0,
|
| 862 |
+
"jawline": 7.0,
|
| 863 |
+
},
|
| 864 |
+
"harmony_score": 7.0,
|
| 865 |
+
"overall_facial_score": 7.0,
|
| 866 |
+
"analysis_method": "basic_estimation",
|
| 867 |
+
}
|
| 868 |
+
|
| 869 |
+
def draw_facial_landmarks(self, face_image: np.ndarray) -> np.ndarray:
    """
    Draw MediaPipe face-mesh landmarks onto a copy of the face image.

    :param face_image: BGR face image
    :return: annotated copy with green landmark markers, or an unmarked
        copy when detection is unavailable or fails
    """
    # Without a usable face-mesh detector there is nothing to draw.
    if not DLIB_AVAILABLE or self.face_mesh is None:
        return face_image.copy()

    try:
        canvas = face_image.copy()

        # MediaPipe expects RGB input.
        detection = self.face_mesh.process(
            cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
        )

        if not detection.multi_face_landmarks:
            logger.warning("No facial landmarks detected for drawing")
            return canvas

        # Annotate only the first detected face; landmark coordinates are
        # normalized and must be scaled to pixel space.
        height, width = face_image.shape[:2]
        for lm in detection.multi_face_landmarks[0].landmark:
            px = int(lm.x * width)
            py = int(lm.y * height)
            # A filled dot plus a small cross per landmark.
            cv2.circle(canvas, (px, py), 1, (0, 255, 0), -1)
            cv2.line(canvas, (px - 2, py), (px + 2, py), (0, 255, 0), 1)
            cv2.line(canvas, (px, py - 2), (px, py + 2), (0, 255, 0), 1)

        return canvas

    except Exception as e:
        logger.error(f"Failed to draw facial landmarks: {e}")
        return face_image.copy()
|
gfpgan_restorer.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
from config import logger, MODELS_PATH
|
| 5 |
+
from gfpgan import GFPGANer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class GFPGANRestorer:
    """Wrapper around GFPGAN for face/old-photo restoration.

    Looks for GFPGANv1.4 weights in several local locations and falls back
    to the official release URL (GFPGAN downloads it on demand).
    """

    def __init__(self):
        # Time the (potentially slow) model initialization for logging.
        start_time = time.perf_counter()
        self.restorer = None  # GFPGANer instance, or None when unavailable
        self._initialize_model()
        init_time = time.perf_counter() - start_time
        if self.restorer is not None:
            logger.info(f"GFPGANRestorer initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"GFPGANRestorer initialization completed but not available, time: {init_time:.3f}s")

    def _initialize_model(self):
        """Initialize the GFPGAN model, preferring local weights over download."""
        try:
            # Candidate locations for the pre-downloaded weights.
            possible_paths = [
                f"{MODELS_PATH}/GFPGANv1.4.pth",
                f"{MODELS_PATH}/gfpgan/GFPGANv1.4.pth",
                os.path.expanduser("~/.cache/gfpgan/GFPGANv1.4.pth"),
                "./models/GFPGANv1.4.pth"
            ]

            gfpgan_model_path = None
            for path in possible_paths:
                if os.path.exists(path):
                    gfpgan_model_path = path
                    break

            if not gfpgan_model_path:
                logger.warning(f"GFPGAN model file not found, tried paths: {possible_paths}")
                logger.info("Will try to download GFPGAN model from the internet...")
                # Fall back to the release URL; GFPGANer downloads it itself.
                gfpgan_model_path = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'

            logger.info(f"Using GFPGAN model: {gfpgan_model_path}")

            # upscale=2 with the "clean" architecture; no background upsampler.
            self.restorer = GFPGANer(
                model_path=gfpgan_model_path,
                upscale=2,
                arch='clean',
                channel_multiplier=2,
                bg_upsampler=None
            )
            logger.info("GFPGAN model initialized successfully")

        except Exception as e:
            logger.error(f"GFPGAN model initialization failed: {e}")
            self.restorer = None


    def is_available(self):
        """Return True when the GFPGAN model initialized successfully."""
        return self.restorer is not None

    def restore_image(self, image):
        """
        Restore an (old) photo with GFPGAN.

        :param image: input image (numpy array, BGR)
        :return: restored image (numpy array, BGR); on processing failure
            the original image is returned unchanged
        :raises Exception: when the model was never initialized
        """
        if not self.is_available():
            raise Exception("GFPGAN模型未初始化")

        try:
            logger.info("Starting GFPGAN image restoration...")

            # has_aligned=False: input faces are not pre-aligned
            # only_center_face=False: restore every detected face
            # paste_back=True: paste restored faces back into the full image
            cropped_faces, restored_faces, restored_img = self.restorer.enhance(
                image,
                has_aligned=False,
                only_center_face=False,
                paste_back=True
            )

            if restored_img is not None:
                logger.info(f"GFPGAN restoration completed, detected {len(restored_faces)} faces")
                return restored_img
            else:
                logger.warning("GFPGAN restoration returned empty image, using original image")
                return image

        except Exception as e:
            logger.error(f"GFPGAN image restoration failed: {e}")
            # Degrade gracefully: return the untouched input on failure.
            return image
|
install.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Install Python dependencies for the app.
# Default PyPI index (kept for reference):
# pip install -r requirements.txt -i https://pypi.python.org/simple
# The Tsinghua mirror below is used for faster downloads in mainland China.
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
models.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ModelType(str, Enum):
    """Analysis backend selection for face scoring."""

    HOWCUTEAMI = "howcuteami"
    DEEPFACE = "deepface"
    HYBRID = "hybrid"  # hybrid: howcuteami for beauty/gender, deepface for age/emotion
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ImageScoreItem(BaseModel):
    """One scored image entry returned by listing/search endpoints."""

    file_path: str
    score: float
    is_cropped_face: bool = False  # True when the file is a cropped face image
    size_bytes: int
    size_str: str  # human-readable file size
    last_modified: str
    nickname: Optional[str] = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SearchRequest(BaseModel):
    """Keyword-based search request payload."""

    keyword: Optional[str] = ""
    searchType: Optional[str] = "face"
    top_k: Optional[int] = 5  # maximum number of results to return
    score_threshold: float = 0.0  # minimum similarity score to include
    nickname: Optional[str] = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ImageSearchRequest(BaseModel):
    """Image-based (query-by-example) search request payload."""

    image: Optional[str] = None  # base64-encoded query image
    searchType: Optional[str] = "face"
    top_k: Optional[int] = 5  # maximum number of results to return
    score_threshold: float = 0.0  # minimum similarity score to include
    nickname: Optional[str] = None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ImageFileList(BaseModel):
    """Unpaged list of scored images."""

    results: List[ImageScoreItem]
    count: int  # total number of entries in results
|
| 44 |
+
|
| 45 |
+
class PagedImageFileList(BaseModel):
    """Paginated list of scored images."""

    results: List[ImageScoreItem]
    count: int  # total matches across all pages
    page: int
    page_size: int
    total_pages: int
|
| 51 |
+
|
| 52 |
+
class CelebrityMatchResponse(BaseModel):
    """One celebrity look-alike match result."""

    filename: str
    display_name: Optional[str] = None
    distance: float  # embedding distance (lower = more similar)
    similarity: float
    confidence: float
    face_filename: Optional[str] = None  # cropped face file, when available
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class CategoryStatItem(BaseModel):
    """Per-category image count."""

    category: str
    display_name: str
    count: int
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class CategoryStatsResponse(BaseModel):
    """Aggregated category statistics."""

    stats: List[CategoryStatItem]
    total: int  # total image count across all categories
|
push.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Force-push the local main branch to origin.
# NOTE(review): -f overwrites remote history — confirm this is intended for
# this deployment workflow before reusing.
git push -f origin main
|
realesrgan_upscaler.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from config import logger, MODELS_PATH, REALESRGAN_MODEL
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 11 |
+
from basicsr.utils.download_util import load_file_from_url
|
| 12 |
+
from realesrgan import RealESRGANer
|
| 13 |
+
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
# 设置PyTorch CPU优化
|
| 17 |
+
torch.set_num_threads(min(torch.get_num_threads(), 8)) # 限制线程数
|
| 18 |
+
torch.set_num_interop_threads(min(4, torch.get_num_interop_threads())) # 设置操作间线程数
|
| 19 |
+
|
| 20 |
+
REALESRGAN_AVAILABLE = True
|
| 21 |
+
logger.info("Real-ESRGAN imported successfully")
|
| 22 |
+
except ImportError as e:
|
| 23 |
+
logger.error(f"Real-ESRGAN import failed: {e}")
|
| 24 |
+
REALESRGAN_AVAILABLE = False
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class RealESRGANUpscaler:
    """Real-ESRGAN super-resolution processor (CPU-only configuration)."""

    def __init__(self):
        # Time the (potentially slow) model load for logging.
        start_time = time.perf_counter()
        self.upsampler = None          # RealESRGANer instance, or None
        self.model_name = None         # model name from REALESRGAN_MODEL
        self.scale = 4                 # default output scale factor
        self.denoise_strength = 0.5    # only meaningful for realesr-general-x4v3
        self._initialize()
        init_time = time.perf_counter() - start_time
        if self.upsampler is not None:
            logger.info(f"RealESRGANUpscaler initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"RealESRGANUpscaler initialization completed but not available, time: {init_time:.3f}s")

    def _initialize(self):
        """Initialize the Real-ESRGAN model selected by REALESRGAN_MODEL."""
        if not REALESRGAN_AVAILABLE:
            logger.error("Real-ESRGAN is not available, cannot initialize super resolution processor")
            return

        try:
            # Model selection comes from the environment (config module).
            model_name = REALESRGAN_MODEL
            self.model_name = model_name

            # Default output scale inferred from the model name.
            if 'x2' in model_name:
                self.scale = 2
            elif 'x4' in model_name:
                self.scale = 4
            else:
                self.scale = 4  # fall back to 4x

            # Per-model network architecture and download URL.
            # NOTE(review): an unrecognized model_name leaves `model`,
            # `netscale` and `file_url` unbound; the resulting NameError is
            # caught by the except below and disables the upsampler.
            model_path = None
            if model_name == 'RealESRGAN_x4plus':
                model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
                netscale = 4
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
            elif model_name == 'RealESRNet_x4plus':
                model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
                netscale = 4
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth'
            elif model_name == 'RealESRGAN_x4plus_anime_6B':
                model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
                netscale = 4
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'
            elif model_name == 'RealESRGAN_x2plus':
                model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
                netscale = 2
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'
            elif model_name == 'realesr-animevideov3':
                model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
                netscale = 4
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth'
            elif model_name == 'realesr-general-x4v3':
                # Latest general-purpose model (v0.2.5.0).
                model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
                netscale = 4
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
            elif model_name == 'realesr-general-wdn-x4v3':
                # Latest general-purpose model with denoising (v0.2.5.0).
                model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
                netscale = 4
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth'

            # Ensure the model cache directory exists.
            model_dir = os.path.join(MODELS_PATH, 'realesrgan')
            os.makedirs(model_dir, exist_ok=True)

            # Prefer a pre-downloaded local weights file.
            local_model_path = None
            model_filename = f"{model_name}.pth"
            local_pth = os.path.join(MODELS_PATH, model_filename)

            if os.path.exists(local_pth):
                local_model_path = local_pth
                logger.info(f"Using local model file: {local_model_path}")

            # Use the local file when present, otherwise download.
            if local_model_path:
                model_path = local_model_path
            else:
                logger.info(f"Downloading model {model_name} from {file_url}")
                model_path = load_file_from_url(
                    url=file_url, model_dir=model_dir, progress=True, file_name=model_filename)

            # Build the upsampler. Tiling keeps memory bounded on CPU;
            # half=False keeps fp32 precision; gpu_id=None forces CPU.
            self.upsampler = RealESRGANer(
                scale=netscale,
                model_path=model_path,
                model=model,
                tile=512,
                tile_pad=10,
                pre_pad=0,
                half=False,
                gpu_id=None
            )

            logger.info(f"Real-ESRGAN super resolution processor initialized successfully, model: {model_name}")

        except Exception as e:
            logger.error(f"Failed to initialize Real-ESRGAN: {e}")
            self.upsampler = None

    def is_available(self):
        """Return True when the upsampler was initialized successfully."""
        return REALESRGAN_AVAILABLE and self.upsampler is not None

    def _optimize_input_image(self, image):
        """
        Normalize the input image to uint8 3-channel BGR for faster CPU work.

        :param image: input image (any dtype / channel layout)
        :return: normalized image
        """
        # Convert to uint8 (floats are assumed to be in [0, 1]).
        if image.dtype != np.uint8:
            if image.dtype == np.float32 or image.dtype == np.float64:
                image = (image * 255).astype(np.uint8)
            else:
                image = image.astype(np.uint8)

        # Force 3-channel BGR layout.
        if len(image.shape) == 2:  # grayscale
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        elif image.shape[2] == 4:  # RGBA
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        elif image.shape[2] == 3 and image.shape[2] != 3:  # NOTE(review): condition is self-contradictory — this RGB->BGR branch is unreachable
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        return image

    def upscale_image(self, image, scale=None, denoise_strength=None):
        """
        Super-resolve an image with Real-ESRGAN.

        :param image: input image (numpy array)
        :param scale: output scale factor; defaults to the model's scale
        :param denoise_strength: denoise strength (0-1), only effective for
            the realesr-general-x4v3 model
        :return: upscaled image
        :raises RuntimeError: when the processor is unavailable or fails
        """
        if not self.is_available():
            raise RuntimeError("Real-ESRGAN超清处理器不可用")

        try:
            start_time = time.perf_counter()

            # Normalize dtype/channels before inference.
            image = self._optimize_input_image(image)

            # Denoise strength applies only to the general x4v3 model.
            if denoise_strength is not None and self.model_name == 'realesr-general-x4v3':
                self.denoise_strength = denoise_strength

            # Pick a tile size based on the input resolution: smaller tiles
            # for larger images to keep CPU memory use in check.
            h, w = image.shape[:2]
            pixel_count = h * w

            if pixel_count > 2000000:  # > 2MP
                tile_size = 256
            elif pixel_count > 1000000:  # > 1MP
                tile_size = 384
            else:
                tile_size = 512

            # Apply the tile size dynamically when supported.
            if hasattr(self.upsampler, 'tile'):
                self.upsampler.tile = tile_size
                logger.info(f"Adjusting tile size to: {tile_size} based on image size ({w}x{h})")

            # Run inference.
            logger.info(f"Starting Real-ESRGAN super resolution processing, model: {self.model_name}")
            output, _ = self.upsampler.enhance(image, outscale=scale or self.scale)

            processing_time = time.perf_counter() - start_time
            logger.info(f"Real-ESRGAN super resolution processing completed, time: {processing_time:.3f}s")

            return output

        except Exception as e:
            logger.error(f"Real-ESRGAN super resolution processing failed: {e}")
            raise RuntimeError(f"超清处理失败: {str(e)}")

    def get_model_info(self):
        """Return the active model's name, scale and availability."""
        return {
            "model_name": self.model_name,
            "scale": self.scale,
            "available": self.is_available()
        }
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def get_upscaler():
    """Return a new Real-ESRGAN upscaler instance.

    NOTE(review): this definition is shadowed by the singleton
    get_upscaler() defined later in this module, so calls resolve to the
    cached version and this one is effectively dead code.
    """
    return RealESRGANUpscaler()
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# 全局实例(单例模式)
|
| 228 |
+
_upscaler_instance = None
|
| 229 |
+
|
| 230 |
+
def get_upscaler():
    """Return the process-wide RealESRGANUpscaler, creating it on first use."""
    global _upscaler_instance
    instance = _upscaler_instance
    if instance is None:
        # First call: build and cache the (expensive) upscaler singleton.
        instance = _upscaler_instance = RealESRGANUpscaler()
    return instance
|
rembg_processor.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from config import logger, REMBG_AVAILABLE
|
| 8 |
+
|
| 9 |
+
if REMBG_AVAILABLE:
|
| 10 |
+
import rembg
|
| 11 |
+
from rembg import new_session
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RembgProcessor:
|
| 16 |
+
"""rembg抠图处理器"""
|
| 17 |
+
|
| 18 |
+
def __init__(self):
|
| 19 |
+
start_time = time.perf_counter()
|
| 20 |
+
self.session = None
|
| 21 |
+
self.available = False
|
| 22 |
+
self.model_name = "u2net" # 默认使用u2net模型,适合人像抠图
|
| 23 |
+
|
| 24 |
+
if REMBG_AVAILABLE:
|
| 25 |
+
try:
|
| 26 |
+
# 初始化rembg会话
|
| 27 |
+
self.session = new_session(self.model_name)
|
| 28 |
+
self.available = True
|
| 29 |
+
logger.info(f"rembg background removal processor initialized successfully, using model: {self.model_name}")
|
| 30 |
+
except Exception as e:
|
| 31 |
+
logger.error(f"rembg background removal processor initialization failed: {e}")
|
| 32 |
+
self.available = False
|
| 33 |
+
else:
|
| 34 |
+
logger.warning("rembg is not available, background removal function will be disabled")
|
| 35 |
+
init_time = time.perf_counter() - start_time
|
| 36 |
+
if self.available:
|
| 37 |
+
logger.info(f"RembgProcessor initialized successfully, time: {init_time:.3f}s")
|
| 38 |
+
else:
|
| 39 |
+
logger.info(f"RembgProcessor initialization completed but not available, time: {init_time:.3f}s")
|
| 40 |
+
|
| 41 |
+
def is_available(self) -> bool:
|
| 42 |
+
"""检查抠图处理器是否可用"""
|
| 43 |
+
return self.available and self.session is not None
|
| 44 |
+
|
| 45 |
+
def remove_background(self, image: np.ndarray, background_color: Optional[Tuple[int, int, int]] = None) -> np.ndarray:
|
| 46 |
+
"""
|
| 47 |
+
移除图片背景
|
| 48 |
+
:param image: 输入的OpenCV图像(BGR格式)
|
| 49 |
+
:param background_color: 替换的背景颜色(BGR格式),如果为None则保持透明背景
|
| 50 |
+
:return: 处理后的图像
|
| 51 |
+
"""
|
| 52 |
+
if not self.is_available():
|
| 53 |
+
raise Exception("rembg抠图处理器不可用")
|
| 54 |
+
|
| 55 |
+
try:
|
| 56 |
+
# 将OpenCV图像(BGR)转换为PIL图像(RGB)
|
| 57 |
+
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 58 |
+
pil_image = Image.fromarray(image_rgb)
|
| 59 |
+
|
| 60 |
+
# 使用rembg移除背景
|
| 61 |
+
logger.info("Starting to remove background using rembg...")
|
| 62 |
+
output_image = rembg.remove(pil_image, session=self.session)
|
| 63 |
+
|
| 64 |
+
# 转换回OpenCV格式
|
| 65 |
+
if background_color is not None:
|
| 66 |
+
# 如果指定了背景颜色,创建纯色背景
|
| 67 |
+
background = Image.new('RGB', output_image.size, background_color[::-1]) # BGR转RGB
|
| 68 |
+
# 将透明图像粘贴到背景上
|
| 69 |
+
background.paste(output_image, mask=output_image)
|
| 70 |
+
result_array = np.array(background)
|
| 71 |
+
result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)
|
| 72 |
+
else:
|
| 73 |
+
# 保持透明背景,转换为BGRA格式
|
| 74 |
+
result_array = np.array(output_image)
|
| 75 |
+
if result_array.shape[2] == 4: # RGBA格式
|
| 76 |
+
# 转换RGBA到BGRA
|
| 77 |
+
result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGBA2BGRA)
|
| 78 |
+
else: # RGB格式
|
| 79 |
+
result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)
|
| 80 |
+
|
| 81 |
+
logger.info("rembg background removal completed")
|
| 82 |
+
return result_bgr
|
| 83 |
+
|
| 84 |
+
except Exception as e:
|
| 85 |
+
logger.error(f"rembg background removal failed: {e}")
|
| 86 |
+
raise Exception(f"背景移除失败: {str(e)}")
|
| 87 |
+
|
| 88 |
+
def create_id_photo(self, image: np.ndarray, background_color: Tuple[int, int, int] = (255, 255, 255)) -> np.ndarray:
|
| 89 |
+
"""
|
| 90 |
+
创建证件照(移除背景并添加纯色背景)
|
| 91 |
+
:param image: 输入的OpenCV图像
|
| 92 |
+
:param background_color: 背景颜色,默认白色(BGR格式)
|
| 93 |
+
:return: 处理后的证件照
|
| 94 |
+
"""
|
| 95 |
+
logger.info(f"Starting to create ID photo, background color: {background_color}")
|
| 96 |
+
|
| 97 |
+
# 移除背景并添加指定颜色背景
|
| 98 |
+
id_photo = self.remove_background(image, background_color)
|
| 99 |
+
|
| 100 |
+
logger.info("ID photo creation completed")
|
| 101 |
+
return id_photo
|
| 102 |
+
|
| 103 |
+
def get_supported_models(self) -> list:
    """Return the rembg model names this processor can switch between."""
    supported = [
        "u2net",              # general-purpose model, good for portraits
        "u2net_human_seg",    # model specialised for human segmentation
        "silueta",            # suited to object cutouts
        "isnet-general-use",  # more precise general-purpose model
    ]
    return supported if REMBG_AVAILABLE else []
|
| 115 |
+
|
| 116 |
+
def switch_model(self, model_name: str) -> bool:
    """Swap the active rembg session to a different model.

    :param model_name: one of the names returned by get_supported_models()
    :return: True when the switch succeeded, False otherwise
    """
    if not REMBG_AVAILABLE:
        return False

    try:
        if model_name not in self.get_supported_models():
            logger.error(f"Unsupported model: {model_name}")
            return False
        # Replacing the session drops the old model; subsequent calls
        # to remove_background use the new weights.
        self.session = new_session(model_name)
        self.model_name = model_name
        logger.info(f"rembg model switched to: {model_name}")
        return True
    except Exception as e:
        logger.error(f"Failed to switch model: {e}")
        return False
|
requirements.txt
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 固定NumPy版本避免兼容性问题 - 必须最先安装
|
| 2 |
+
numpy>=1.24.0,<2.0.0
|
| 3 |
+
|
| 4 |
+
# 基础依赖
|
| 5 |
+
fastapi>=0.104.0
|
| 6 |
+
uvicorn[standard]>=0.24.0
|
| 7 |
+
python-multipart>=0.0.6
|
| 8 |
+
aiofiles>=23.2.1
|
| 9 |
+
|
| 10 |
+
# 图像处理
|
| 11 |
+
opencv-python>=4.8.0
|
| 12 |
+
Pillow>=10.0.0
|
| 13 |
+
|
| 14 |
+
# PyTorch 相关包 - 升级到2.x版本解决依赖冲突
|
| 15 |
+
torch>=2.0.0,<2.9.0
|
| 16 |
+
torchvision>=0.15.0
|
| 17 |
+
|
| 18 |
+
# 机器学习和CV相关
|
| 19 |
+
tf-keras
|
| 20 |
+
aiohttp
|
| 21 |
+
ultralytics
|
| 22 |
+
deepface
|
| 23 |
+
mediapipe>=0.10.0
|
| 24 |
+
# ModelScope-related packages - pinned to versions verified compatible with DDColor
|
| 25 |
+
modelscope==1.28.2
|
| 26 |
+
datasets==2.21.0
|
| 27 |
+
transformers==4.40.0
|
| 28 |
+
# ModelScope DDColor的额外依赖
|
| 29 |
+
timm==1.0.19
|
| 30 |
+
sortedcontainers==2.4.0
|
| 31 |
+
fsspec==2024.6.1
|
| 32 |
+
multiprocess==0.70.16
|
| 33 |
+
xxhash==3.5.0
|
| 34 |
+
dill==0.3.8
|
| 35 |
+
huggingface-hub==0.34.3
|
| 36 |
+
# 修复pyarrow兼容性问题 - 使用稳定版本
|
| 37 |
+
pyarrow==20.0.0
|
| 38 |
+
|
| 39 |
+
# API相关
|
| 40 |
+
pydantic>=2.4.0
|
| 41 |
+
starlette>=0.27.0
|
| 42 |
+
simplejson==3.20.1
|
| 43 |
+
# 科学计算和工具
|
| 44 |
+
scipy>=1.7.0,<1.13.0
|
| 45 |
+
tqdm
|
| 46 |
+
lmdb
|
| 47 |
+
pyyaml
|
| 48 |
+
|
| 49 |
+
# 定时任务
|
| 50 |
+
apscheduler>=3.10.0
|
| 51 |
+
|
| 52 |
+
# 数据库
|
| 53 |
+
aiomysql>=0.2.0
|
| 54 |
+
|
| 55 |
+
# 对象存储
|
| 56 |
+
boto3>=1.34.0
|
| 57 |
+
|
| 58 |
+
# GFPGAN 和相关包 - 修复依赖兼容性
|
| 59 |
+
basicsr>=1.3.3
|
| 60 |
+
facexlib>=0.2.5
|
| 61 |
+
gfpgan>=1.3.0
|
| 62 |
+
realesrgan>=0.3.0
|
| 63 |
+
|
| 64 |
+
# CLIP 相关依赖
|
| 65 |
+
cn_clip
|
| 66 |
+
faiss-cpu
|
| 67 |
+
onnxruntime
|
| 68 |
+
diffusers
|
| 69 |
+
accelerate
|
| 70 |
+
# rembg 抠图处理
|
| 71 |
+
rembg>=2.0.50
|
| 72 |
+
easydict
|
rvm_processor.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
|
| 7 |
+
import config
|
| 8 |
+
from config import logger
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RVMProcessor:
    """RVM (Robust Video Matting) background matting processor.

    Loads the RobustVideoMatting model strictly from local files (repo
    path and weights path come from config), so no network access is
    needed at startup.
    """

    def __init__(self):
        # Model handle and availability flag; callers must check
        # is_available() before using the processor.
        self.model = None
        self.available = False
        self.device = "cpu"  # default to CPU; set to "cuda" when a GPU is available

        try:
            # Load only from local paths; never hit the network.
            local_repo = getattr(config, 'RVM_LOCAL_REPO', '')
            weights_path = getattr(config, 'RVM_WEIGHTS_PATH', '')

            if not local_repo or not os.path.isdir(local_repo):
                raise RuntimeError("RVM_LOCAL_REPO not set or invalid. Please set env RVM_LOCAL_REPO to local RobustVideoMatting repo path (with hubconf.py)")

            if not weights_path or not os.path.isfile(weights_path):
                raise RuntimeError("RVM_WEIGHTS_PATH not set or file not found. Please set env RVM_WEIGHTS_PATH to local RVM weights file path")

            logger.info(f"Loading RVM model {config.RVM_MODEL} from local repo: {local_repo}")
            # Build the architecture from the local hub repo; pretrained=False
            # avoids any download attempt.
            self.model = torch.hub.load(local_repo, config.RVM_MODEL, source='local', pretrained=False)

            # Load local weights; some checkpoints wrap them in a
            # 'state_dict' key.
            state = torch.load(weights_path, map_location=self.device)
            if isinstance(state, dict) and 'state_dict' in state:
                state = state['state_dict']
            # strict=False tolerates partial key mismatches; they are
            # logged as warnings below.
            missing, unexpected = self.model.load_state_dict(state, strict=False)

            # Move to the target device and switch to eval mode.
            self.model = self.model.to(self.device).eval()
            self.available = True
            logger.info("RVM background removal processor initialized successfully (local mode)")
            # NOTE(review): these warnings come after the success log, so a
            # checkpoint with many missing keys still reports "available".
            if missing:
                logger.warning(f"RVM weights missing keys: {list(missing)[:5]}... total={len(missing)}")
            if unexpected:
                logger.warning(f"RVM weights unexpected keys: {list(unexpected)[:5]}... total={len(unexpected)}")

        except Exception as e:
            logger.error(f"RVM background removal processor initialization failed: {e}")
            self.available = False

    def is_available(self) -> bool:
        """Return True when the RVM model loaded and is ready for inference."""
        return self.available and self.model is not None

    def remove_background(self, image: np.ndarray, background_color: tuple = None) -> np.ndarray:
        """
        Remove the image background using RVM.

        :param image: input OpenCV image (BGR)
        :param background_color: replacement background color (BGR); when
            None the background is kept transparent (4-channel output)
        :return: processed image — 3-channel BGR when a color is given,
            BGR + alpha otherwise
        :raises Exception: when the processor is unavailable or inference fails
        """
        if not self.is_available():
            raise Exception("RVM抠图处理器不可用")

        try:
            logger.info("Starting to remove background using RVM...")

            # Remember the input size so the output can be restored to it.
            original_height, original_width = image.shape[:2]

            # OpenCV delivers BGR; the model expects RGB.
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # To a (1, 3, H, W) float tensor in [0, 1].
            src = transforms.ToTensor()(image_rgb).unsqueeze(0).to(self.device)

            # Single-frame inference; rec carries the recurrent states,
            # all None for the first (and only) frame.
            rec = [None] * 4
            with torch.no_grad():
                fgr, pha, *rec = self.model(src, *rec, downsample_ratio=0.25)

            # Back to uint8 numpy: foreground (H, W, 3) and alpha matte (H, W).
            fgr = (fgr[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)  # (H,W,3)
            pha = (pha[0, 0].cpu().numpy() * 255).astype(np.uint8)  # (H,W)

            # The model may return a smaller map; resize back when needed.
            if fgr.shape[:2] != (original_height, original_width):
                fgr = cv2.resize(fgr, (original_width, original_height))
                pha = cv2.resize(pha, (original_width, original_height))

            if background_color is not None:
                # Solid-color background requested: alpha-blend the
                # foreground over it.
                fgr_bgr = cv2.cvtColor(fgr, cv2.COLOR_RGB2BGR)

                # Broadcast the BGR tuple across a full-size background plate.
                background = np.full((original_height, original_width, 3), background_color, dtype=np.uint8)

                # Per-pixel blend weights replicated over the 3 channels.
                alpha = pha.astype(np.float32) / 255.0
                alpha = np.stack([alpha] * 3, axis=-1)

                result = (fgr_bgr * alpha + background * (1 - alpha)).astype(np.uint8)
            else:
                # Keep transparency: stack alpha onto the BGR foreground
                # (BGRA layout, despite the variable name "rgba").
                fgr_bgr = cv2.cvtColor(fgr, cv2.COLOR_RGB2BGR)
                rgba = np.dstack((fgr_bgr, pha))  # (H,W,4)
                result = rgba

            logger.info("RVM background removal completed")
            return result

        except Exception as e:
            logger.error(f"RVM background removal failed: {e}")
            raise Exception(f"背景移除失败: {str(e)}")

    def create_id_photo(self, image: np.ndarray, background_color: tuple = (255, 255, 255)) -> np.ndarray:
        """
        Create an ID photo (remove the background, add a solid fill).

        :param image: input OpenCV image
        :param background_color: background color, white by default (BGR)
        :return: the processed ID photo
        """
        logger.info(f"Starting to create ID photo, background color: {background_color}")

        # Matting plus solid-color compositing in one call.
        id_photo = self.remove_background(image, background_color)

        logger.info("ID photo creation completed")
        return id_photo
|
start_local.sh
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Local launcher: exports the configuration read by config.py, then starts
# the FastAPI app with uvicorn on port 7860.
export TZ=Asia/Shanghai

# Data directories: outputs, source images, model weights, FAISS index and
# the celebrity reference dataset. DEEPFACE_HOME points deepface at the
# shared models directory.
export OUTPUT_DIR=/opt/data/output
export IMAGES_DIR=/opt/data/images
export MODELS_PATH=/opt/data/models
export DEEPFACE_HOME=/opt/data/models
export FAISS_INDEX_DIR=/opt/data/faiss
export CELEBRITY_SOURCE_DIR=/opt/data/chinese_celeb_dataset
# Detection/analysis thresholds and behavior toggles.
export GENDER_CONFIDENCE=1
export UPSCALE_SIZE=2
export AGE_CONFIDENCE=1.0
export DRAW_SCORE=true
export FACE_CONFIDENCE=0.7

# Feature flags: which processing pipelines are exposed by the API.
export ENABLE_DDCOLOR=true
export ENABLE_GFPGAN=true
export ENABLE_REALESRGAN=true
export ENABLE_ANIME_STYLE=true
export ENABLE_RVM=true
export ENABLE_REMBG=true
export ENABLE_CLIP=false

# Periodic cleanup of generated files (see cleanup_scheduler.py).
export CLEANUP_INTERVAL_HOURS=1
export CLEANUP_AGE_HOURS=1

# Beauty-score adjustment curve and logging/preload switches.
export BEAUTY_ADJUST_GAMMA=0.8
export BEAUTY_ADJUST_MIN=1.0
export BEAUTY_ADJUST_MAX=9.0
export ENABLE_ANIME_PRELOAD=true
export ENABLE_LOGGING=true
export BEAUTY_ADJUST_ENABLED=true

# RVM is loaded strictly from local files (see rvm_processor.py).
export RVM_LOCAL_REPO=/opt/data/models/RobustVideoMatting
export RVM_WEIGHTS_PATH=/opt/data/models/torch/hub/checkpoints/rvm_resnet50.pth
export RVM_MODEL=resnet50

# Which models are initialized eagerly at startup vs. lazily on first use.
export AUTO_INIT_GFPGAN=false
export AUTO_INIT_DDCOLOR=false
export AUTO_INIT_REALESRGAN=false
export AUTO_INIT_ANIME_STYLE=true
export AUTO_INIT_CLIP=false
export AUTO_INIT_RVM=false
export AUTO_INIT_REMBG=false

export ENABLE_WARMUP=true
export REALESRGAN_MODEL=realesr-general-x4v3
export CELEBRITY_FIND_THRESHOLD=0.87
export FEMALE_AGE_ADJUSTMENT=4

# Single worker: the ML models are large and not fork-safe.
uvicorn app:app --workers 1 --loop asyncio --http httptools --host 0.0.0.0 --port 7860 --timeout-keep-alive 600
|
| 52 |
+
|
test/celebrity_crawler.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import requests
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CelebrityCrawler:
    """Download celebrity reference photos via Baidu/Bing image search."""

    def __init__(self, output_dir="celebrity_images"):
        # All images land directly in output_dir (created if absent).
        self.output_dir = output_dir
        # Browser-like User-Agent so the search endpoints serve normal HTML/JSON.
        self.headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    def read_celebrities_from_txt(self, file_path):
        """
        Read celebrity entries from a txt file.

        Supported line formats (blank lines and '#' comments skipped):
        1. name,profession
        2. name
        """
        celebrities = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue

                parts = line.split(',')
                name = parts[0].strip()
                # Default profession when only the name is given.
                profession = parts[1].strip() if len(parts) > 1 else "明星"

                celebrities.append({
                    'name': name,
                    'profession': profession
                })
        return celebrities

    def search_bing_images(self, celebrity_name, max_images=20):
        """Fetch image URLs from Bing image search; returns [] on any error."""
        search_url = "https://www.bing.com/images/search"
        params = {
            'q': celebrity_name + " 明星",
            'first': 0,
            'count': max_images
        }

        try:
            response = requests.get(search_url, params=params, headers=self.headers,
                                    timeout=10)
            response.raise_for_status()

            # Crude HTML scrape: pull the "murl" (media URL) fields out of
            # the embedded JSON.
            import re
            img_urls = re.findall(r'murl":"(.*?)"', response.text)
            return img_urls[:max_images]
        except Exception as e:
            print(f"搜索 {celebrity_name} 时出错: {e}")
            return []

    def search_baidu_images(self, celebrity_name, max_images=20):
        """Fetch thumbnail URLs from Baidu image search; returns [] on any error."""
        search_url = "https://image.baidu.com/search/acjson"
        params = {
            'tn': 'resultjson_com',
            'word': celebrity_name + " 明星",
            'pn': 0,
            'rn': max_images,
            'ie': 'utf-8'
        }

        try:
            response = requests.get(search_url, params=params, headers=self.headers,
                                    timeout=10)
            response.raise_for_status()
            data = response.json()

            img_urls = []
            if 'data' in data:
                for item in data['data']:
                    if 'thumbURL' in item:
                        img_urls.append(item['thumbURL'])
            return img_urls[:max_images]
        except Exception as e:
            print(f"搜索 {celebrity_name} 时出错: {e}")
            return []

    def download_image(self, url, save_path):
        """Download one image; returns False on error or if it is too small."""
        try:
            response = requests.get(url, headers=self.headers, timeout=15)
            response.raise_for_status()

            # Decode in memory to verify it really is an image.
            img = Image.open(BytesIO(response.content))

            # Skip thumbnails/icons smaller than 100x100.
            if img.size[0] < 100 or img.size[1] < 100:
                return False

            # Normalize to RGB and save as high-quality JPEG.
            img = img.convert('RGB')
            img.save(save_path, 'JPEG', quality=95)
            return True
        except Exception as e:
            print(f" 下载失败: {str(e)[:50]}")
            return False

    def crawl_celebrity_images(self, celebrity, max_images=20,
                               search_engine='baidu'):
        """Crawl images for a single celebrity; returns the success count."""
        name = celebrity['name']
        print(f"\n正在爬取: {name} ({celebrity['profession']})")

        # Despite the name, all celebrities share the same output directory.
        celebrity_dir = Path(self.output_dir)
        celebrity_dir.mkdir(parents=True, exist_ok=True)

        # Over-fetch (2x) so failed downloads can be compensated.
        if search_engine == 'baidu':
            img_urls = self.search_baidu_images(name, max_images * 2)
        else:
            img_urls = self.search_bing_images(name, max_images * 2)

        if not img_urls:
            print(f" 未找到 {name} 的图片")
            return 0

        print(f" 找到 {len(img_urls)} 个图片链接")

        # Download until max_images succeed or the URL list is exhausted.
        success_count = 0
        for idx, url in enumerate(img_urls):
            if success_count >= max_images:
                break

            save_path = celebrity_dir / f"{name}_{idx + 1:03d}.jpg"

            # Already downloaded in a previous run: count it and move on.
            if save_path.exists():
                success_count += 1
                continue

            print(f" 下载 {idx + 1}/{len(img_urls)}...", end=' ')
            if self.download_image(url, save_path):
                success_count += 1
                print("✓")
            else:
                print("✗")

            # Throttle to avoid hammering the search engine.
            time.sleep(0.5)

        print(f" 成功下载 {success_count} 张图片")
        return success_count

    def crawl_all(self, txt_file, max_images_per_celebrity=20,
                  search_engine='baidu'):
        """Crawl every celebrity listed in txt_file and print a summary."""
        print("=" * 60)
        print("明星照片爬取工具")
        print("=" * 60)

        # Load the celebrity roster.
        celebrities = self.read_celebrities_from_txt(txt_file)
        print(f"\n从 {txt_file} 读取到 {len(celebrities)} 位明星")

        # Running totals for the final summary.
        total_images = 0
        failed_celebrities = []

        # Crawl each celebrity in turn.
        for i, celebrity in enumerate(celebrities, 1):
            print(f"\n[{i}/{len(celebrities)}]", end=' ')

            try:
                count = self.crawl_celebrity_images(
                    celebrity,
                    max_images=max_images_per_celebrity,
                    search_engine=search_engine
                )
                total_images += count

                if count == 0:
                    failed_celebrities.append(celebrity['name'])

                # Pause for 3s after every 5 celebrities to stay polite.
                if i % 5 == 0:
                    print(f"\n 已完成 {i}/{len(celebrities)}, 休息3秒...")
                    time.sleep(3)

            except Exception as e:
                print(f" 处理 {celebrity['name']} 时出错: {e}")
                failed_celebrities.append(celebrity['name'])

        # Final statistics.
        print("\n" + "=" * 60)
        print("爬取完成!")
        print("=" * 60)
        print(f"总明星数: {len(celebrities)}")
        print(f"成功爬取: {len(celebrities) - len(failed_celebrities)}")
        print(f"失败数量: {len(failed_celebrities)}")
        print(f"总图片数: {total_images}")
        print(f"保存位置: {self.output_dir}")

        if failed_celebrities:
            print(f"\n失败的明星: {', '.join(failed_celebrities)}")
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# Usage example
def _run_demo() -> None:
    """Crawl one image per celebrity listed in celebrity_real_names.txt."""
    # Expected txt format: one "name,profession" (or bare name) per line,
    # e.g. 周杰伦,歌手 / 刘德华,演员 / 范冰冰,演员
    crawler = CelebrityCrawler(output_dir="celebrity_dataset")
    crawler.crawl_all(
        txt_file="celebrity_real_names.txt",  # path to your roster txt file
        max_images_per_celebrity=1,           # images to fetch per celebrity
        search_engine='baidu'                 # 'baidu' or 'bing'
    )


if __name__ == "__main__":
    _run_demo()
|
test/celebrity_crawler.pyc
ADDED
|
Binary file (5.71 kB). View file
|
|
|
test/decode_celeb_dataset.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Decode base64 file names inside the Chinese celeb dataset directory.
|
| 4 |
+
|
| 5 |
+
Default target: /Users/chenchaoyun/Downloads/chinese_celeb_dataset.
|
| 6 |
+
Use --root to override; --dry-run only prints the plan.
|
| 7 |
+
"""
|
| 8 |
+
import argparse
|
| 9 |
+
import base64
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
DEFAULT_ROOT = Path("/Users/chenchaoyun/Downloads/chinese_celeb_dataset")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _decode_basename(encoded: str) -> str:
|
| 17 |
+
padding = "=" * ((4 - len(encoded) % 4) % 4)
|
| 18 |
+
try:
|
| 19 |
+
return base64.urlsafe_b64decode(
|
| 20 |
+
(encoded + padding).encode("ascii")).decode("utf-8")
|
| 21 |
+
except Exception:
|
| 22 |
+
return encoded
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def rename_dataset(root: Path, dry_run: bool = False) -> int:
    """Walk *root* and rename files whose stems are base64-encoded.

    Prints each planned rename; with dry_run=True nothing is applied.
    Returns 0 on success, 1 when *root* is missing or not a directory.
    """
    if not root.exists():
        print(f"Directory does not exist: {root}", file=sys.stderr)
        return 1
    if not root.is_dir():
        print(f"Not a directory: {root}", file=sys.stderr)
        return 1

    renamed_count = 0
    for candidate in sorted(root.rglob("*")):
        if not candidate.is_file():
            continue

        decoded_stem = _decode_basename(candidate.stem)
        if decoded_stem == candidate.stem:
            continue  # stem was not base64 — nothing to do

        target = candidate.with_name(f"{decoded_stem}{candidate.suffix}")
        if target == candidate:
            continue

        # Append a numeric suffix while the decoded target name is taken.
        suffix_index = 1
        while target.exists() and target != candidate:
            target = candidate.with_name(
                f"{decoded_stem}_{suffix_index}{candidate.suffix}"
            )
            suffix_index += 1

        print(f"{candidate} -> {target}")
        if not dry_run:
            candidate.rename(target)
            renamed_count += 1

    print(f"Renamed {renamed_count} files")
    return 0
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface."""
    cli = argparse.ArgumentParser(
        description="Decode chinese_celeb_dataset file names")
    cli.add_argument(
        "--root",
        type=Path,
        default=DEFAULT_ROOT,
        help="Dataset root directory (default: %(default)s)",
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Only print planned renames without applying them",
    )
    return cli.parse_args()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def main() -> int:
    """Entry point: resolve the dataset root and run the rename pass."""
    options = parse_args()
    dataset_root = options.root.expanduser().resolve()
    return rename_dataset(dataset_root, options.dry_run)


if __name__ == "__main__":
    sys.exit(main())
|
test/decode_celeb_dataset.pyc
ADDED
|
Binary file (2.26 kB). View file
|
|
|
test/dow_img.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2

# Source and destination paths for this one-off compression job.
SRC_PATH = "/opt/data/header.png"
DST_PATH = "/opt/data/output_small.webp"

# WebP quality (0-100): lower means stronger compression / worse quality.
quality = 50

# Read the image. cv2.imread returns None (no exception) when the file is
# missing or unreadable, so guard before handing it to cv2.imwrite, which
# would otherwise fail with a cryptic assertion.
img = cv2.imread(SRC_PATH)
if img is None:
    raise SystemExit(f"Failed to read image: {SRC_PATH}")

# Write the compressed copy; the .webp extension selects the WebP encoder.
cv2.imwrite(
    DST_PATH,
    img,
    [int(cv2.IMWRITE_WEBP_QUALITY), quality],
)
|
test/dow_img.pyc
ADDED
|
Binary file (291 Bytes). View file
|
|
|
test/howcuteami.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import math
|
| 3 |
+
import argparse
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# detect face
|
| 9 |
+
def highlightFace(net, frame, conf_threshold=0.95):
    """Run the OpenCV DNN face detector on *frame*.

    Returns a list of square boxes [x1, y1, x2, y2] (via scale()) for
    every detection above conf_threshold.
    """
    working_copy = frame.copy()
    frame_h = working_copy.shape[0]
    frame_w = working_copy.shape[1]
    # 300x300 BGR mean-subtracted blob, as the SSD face detector expects.
    blob = cv2.dnn.blobFromImage(
        working_copy, 1.0, (300, 300), [104, 117, 123], True, False
    )

    net.setInput(blob)
    detections = net.forward()

    boxes = []
    for det_idx in range(detections.shape[2]):
        confidence = detections[0, 0, det_idx, 2]
        if confidence > conf_threshold:
            # Detector outputs are normalised [0, 1]; scale to pixels.
            corners = [
                int(detections[0, 0, det_idx, 3] * frame_w),
                int(detections[0, 0, det_idx, 4] * frame_h),
                int(detections[0, 0, det_idx, 5] * frame_w),
                int(detections[0, 0, det_idx, 6] * frame_h),
            ]
            boxes.append(scale(corners))

    return boxes
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# scale current rectangle to box
|
| 34 |
+
def scale(box):
    """Expand [x1, y1, x2, y2] to a centred square of side max(width, height)."""
    x1, y1, x2, y2 = box
    width = x2 - x1
    height = y2 - y1
    side = max(width, height)
    pad_x = int((side - width) / 2)
    pad_y = int((side - height) / 2)
    return [x1 - pad_x, y1 - pad_y, x2 + pad_x, y2 + pad_y]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# crop image
|
| 46 |
+
def cropImage(image, box):
    """Return the [x1, y1, x2, y2] region of *image* (a numpy view, not a copy)."""
    x1, y1, x2, y2 = box
    return image[y1:y2, x1:x2]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# main: detect faces, then per face predict gender/age/beauty, annotate
# and save crops plus an annotated full image.
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image", type=str, required=False, help="input image")
args = parser.parse_args()

# Create the output directory
output_dir = "../output"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Local Caffe/TensorFlow model files for the four networks.
faceProto = "models/opencv_face_detector.pbtxt"
faceModel = "models/opencv_face_detector_uint8.pb"
ageProto = "models/age_googlenet.prototxt"
ageModel = "models/age_googlenet.caffemodel"
genderProto = "models/gender_googlenet.prototxt"
genderModel = "models/gender_googlenet.caffemodel"
beautyProto = "models/beauty_resnet.prototxt"
beautyModel = "models/beauty_resnet.caffemodel"

# Mean-subtraction values and class labels for the GoogleNet models.
MODEL_MEAN_VALUES = (104, 117, 123)
ageList = [
    "(0-2)",
    "(4-6)",
    "(8-12)",
    "(15-20)",
    "(25-32)",
    "(38-43)",
    "(48-53)",
    "(60-100)",
]
genderList = ["Male", "Female"]

# Drawing colors per predicted gender (BGR format)
gender_colors = {
    "Male": (255, 165, 0),  # orange
    "Female": (255, 0, 255),  # magenta / fuchsia
}

faceNet = cv2.dnn.readNet(faceModel, faceProto)
ageNet = cv2.dnn.readNet(ageModel, ageProto)
genderNet = cv2.dnn.readNet(genderModel, genderProto)
beautyNet = cv2.dnn.readNet(beautyModel, beautyProto)

# Read the input image (default sample when -i is not given)
image_path = args.image if args.image else "images/charlize.jpg"
frame = cv2.imread(image_path)

if frame is None:
    print(f"无法读取图片: {image_path}")
    exit()

faceBoxes = highlightFace(faceNet, frame)
if not faceBoxes:
    print("No face detected")
    exit()

print(f"检测到 {len(faceBoxes)} 张人脸")

for i, faceBox in enumerate(faceBoxes):
    # Crop the face region and resize to the classifiers' 224x224 input
    face = cropImage(frame, faceBox)
    face_resized = cv2.resize(face, (224, 224))

    # gender net
    blob = cv2.dnn.blobFromImage(
        face_resized, 1.0, (224, 224), MODEL_MEAN_VALUES, swapRB=False
    )
    genderNet.setInput(blob)
    genderPreds = genderNet.forward()
    gender = genderList[genderPreds[0].argmax()]
    print(f"Gender: {gender}")

    # age net (reuses the same blob as gender)
    ageNet.setInput(blob)
    agePreds = ageNet.forward()
    age = ageList[agePreds[0].argmax()]
    print(f"Age: {age[1:-1]} years")

    # beauty net: needs 1/255-scaled input, hence a second blob
    blob = cv2.dnn.blobFromImage(
        face_resized, 1.0 / 255, (224, 224), MODEL_MEAN_VALUES, swapRB=False
    )
    beautyNet.setInput(blob)
    beautyPreds = beautyNet.forward()
    # Fold the prediction vector into a single 0-10 score.
    beauty = round(2.0 * sum(beautyPreds[0]), 1)
    print(f"Beauty: {beauty}/10.0")

    # Pick the drawing color from the predicted gender
    color = gender_colors[gender]

    # Save the raw face crop with cv2.imwrite
    face_filename = f"{output_dir}/face_{i+1}.webp"
    cv2.imwrite(face_filename, face, [cv2.IMWRITE_WEBP_QUALITY, 95])
    print(f"人脸图片已保存: {face_filename}")

    # Save a copy of the crop annotated with the predictions (optional)
    face_with_text = face.copy()
    cv2.putText(
        face_with_text, f"{gender}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2
    )
    cv2.putText(
        face_with_text,
        f"{age[1:-1]} years",
        (10, 60),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        color,
        2,
    )
    cv2.putText(
        face_with_text,
        f"{beauty}/10.0",
        (10, 90),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        color,
        2,
    )

    annotated_filename = f"{output_dir}/face_{i+1}_annotated.webp"
    cv2.imwrite(annotated_filename, face_with_text, [cv2.IMWRITE_WEBP_QUALITY, 95])
    print(f"标注人脸已保存: {annotated_filename}")

    # Draw the face box and predictions onto the original image
    cv2.rectangle(
        frame,
        (faceBox[0], faceBox[1]),
        (faceBox[2], faceBox[3]),
        color,
        int(round(frame.shape[0] / 400)),
        8,
    )
    cv2.putText(
        frame,
        f"{gender}, {age}, {beauty}",
        (faceBox[0], faceBox[1] - 10),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.25,
        color,
        2,
        cv2.LINE_AA,
    )

# Save the fully annotated image
result_filename = f"{output_dir}/result_full.webp"
cv2.imwrite(result_filename, frame, [cv2.IMWRITE_WEBP_QUALITY, 95])
print(f"完整结果图片已保存: {result_filename}")

# Show the result in a window (blocks until a key is pressed)
cv2.imshow("howbeautifulami", frame)
cv2.waitKey(0)
cv2.destroyAllWindows()
|
test/howcuteami.pyc
ADDED
|
Binary file (4.13 kB). View file
|
|
|
test/import_history_images.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
导入历史图片文件到数据库的脚本
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import hashlib
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
# 添加项目根目录到Python路径
|
| 15 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 16 |
+
|
| 17 |
+
from database import record_image_creation, fetch_records_by_paths
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def calculate_file_hash(file_path):
    """Return the hex MD5 digest of the file at *file_path*, streamed in 4 KB chunks."""
    digest = hashlib.md5()
    with open(file_path, "rb") as handle:
        while True:
            block = handle.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def infer_category_from_filename(filename):
    """Infer a record category from the trailing "_<type>" token of *filename*.

    Returns 'anime_style' if the dedicated marker appears anywhere in the name,
    the recognized type token otherwise, or 'other' when nothing matches.
    """
    name = filename.lower()

    # Anime-stylized outputs carry an explicit marker anywhere in the name.
    if '_anime_style_' in name:
        return 'anime_style'

    # Token between the last underscore and the first dot after it.
    underscore_at = name.rfind('_')
    dot_at = name.find('.', underscore_at)

    if underscore_at != -1 and dot_at != -1 and underscore_at < dot_at:
        token = name[underscore_at + 1:dot_at]
        # The original mapping was an identity map; a membership set is equivalent.
        known = {
            'restore', 'upcolor', 'grayscale', 'upscale', 'compress',
            'id_photo', 'grid', 'rvm', 'celebrity', 'face', 'original',
        }
        return token if token in known else 'other'

    # No usable "_<type>.<ext>" pattern found.
    return 'other'
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
async def import_history_images(source_dir, nickname="system_import"):
    """Scan *source_dir* (non-recursively) for image files and register each in the database.

    Files whose bare name already has a DB record are skipped; failures on
    individual files are reported and do not abort the batch.
    """
    source_path = Path(source_dir)

    if not source_path.exists():
        print(f"错误: 目录 {source_dir} 不存在")
        return

    # Supported image extensions; both lower- and upper-case variants are globbed below.
    image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif', '.tiff',
                        '.tif'}

    # Collect every matching file directly under source_path (no recursion).
    image_files = []
    for ext in image_extensions:
        image_files.extend(source_path.glob(f"*{ext}"))
        image_files.extend(source_path.glob(f"*{ext.upper()}"))

    print(f"找到 {len(image_files)} 个图片文件")

    imported_count = 0
    skipped_count = 0

    for image_path in image_files:
        try:
            file_name = image_path.name

            # Dedup check keyed on the bare file name (not the full path) --
            # assumes the DB stores names the same way; TODO confirm against callers.
            records = await fetch_records_by_paths([file_name])

            if file_name in records:
                print(f"跳过已存在的文件: {file_name}")
                skipped_count += 1
                continue

            # MD5 of the file contents, kept as metadata for later uniqueness checks.
            file_hash = calculate_file_hash(str(image_path))

            # Category is inferred from the trailing "_<type>" token in the name.
            category = infer_category_from_filename(file_name)

            await record_image_creation(
                file_path=file_name,  # store the bare file name, not the full path
                nickname=nickname,
                category=category,
                bos_uploaded=False,  # historical files were typically never uploaded to BOS
                score=0.0,  # historical files default to a zero score
                extra_metadata={
                    "source": "history_import",
                    "original_path": str(image_path),
                    "file_hash": file_hash,
                    "import_time": datetime.now().isoformat()
                }
            )

            imported_count += 1
            print(f"成功导入: {file_name} (类别: {category})")

        except Exception as e:
            # Best-effort batch import: report the failure and continue with the next file.
            print(f"导入文件失败 {image_path.name}: {str(e)}")
            continue

    print(f"\n导入完成!")
    print(f"成功导入: {imported_count} 个文件")
    print(f"跳过: {skipped_count} 个文件")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
async def main():
    """CLI wrapper: validate argv, run the import, and report elapsed time."""
    if len(sys.argv) < 2:
        print("用法: python import_history_images.py <图片目录路径> [昵称]")
        print(
            "示例: python import_history_images.py ~/app/data/images")
        print(
            "示例: python import_history_images.py ~/app/data/images \"历史导入\"")
        sys.exit(1)

    source_directory = sys.argv[1]
    # Optional second argument: nickname recorded as the importing user.
    nickname = sys.argv[2] if len(sys.argv) > 2 else "system_import"

    print(f"开始导入图片文件...")
    print(f"源目录: {source_directory}")
    print(f"用户昵称: {nickname}")
    print("-" * 50)

    start_time = time.time()
    await import_history_images(source_directory, nickname)
    end_time = time.time()

    print(f"\n总耗时: {end_time - start_time:.2f} 秒")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == "__main__":
|
| 162 |
+
asyncio.run(main())
|
test/import_history_images.pyc
ADDED
|
Binary file (3.76 kB). View file
|
|
|
test/remove_duplicate_celeb_images.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
遍历指定目录,根据文件内容(MD5)查找重复项,如果发现重复则只保留一个。
|
| 4 |
+
默认目标目录为 /opt/data/chinese_celeb_dataset,可用 --target-dir 覆盖。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import hashlib
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Dict
|
| 15 |
+
|
| 16 |
+
DEFAULT_TARGET_DIR = Path("/opt/data/chinese_celeb_dataset")
|
| 17 |
+
CHUNK_SIZE = 4 * 1024 * 1024 # 4MB
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def compute_md5(file_path: Path) -> str:
    """Stream the file through MD5 in CHUNK_SIZE blocks so large files never load whole."""
    hasher = hashlib.md5()
    with file_path.open("rb") as stream:
        while True:
            block = stream.read(CHUNK_SIZE)
            if not block:
                break
            hasher.update(block)
    return hasher.hexdigest()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def deduplicate(target_dir: Path, dry_run: bool = False) -> int:
    """Walk *target_dir*, keep the first file seen per MD5, delete later copies.

    Returns the number of files actually deleted (always 0 in dry-run mode).
    """
    if not target_dir.exists():
        print(f"[error] 目标目录不存在: {target_dir}", file=sys.stderr)
        return 0
    if not target_dir.is_dir():
        print(f"[error] 目标路径不是目录: {target_dir}", file=sys.stderr)
        return 0

    seen: Dict[str, Path] = {}
    removed = 0
    scanned = 0

    # Sorted traversal keeps the lexicographically-first copy deterministically.
    for candidate in sorted(target_dir.rglob("*")):
        if candidate.is_symlink() or not candidate.is_file():
            continue

        scanned += 1
        try:
            checksum = compute_md5(candidate)
        except Exception as exc:
            print(f"[warn] 计算 MD5 失败: {candidate} -> {exc}", file=sys.stderr)
            continue

        # setdefault registers the first path for a checksum; later paths are duplicates.
        keeper = seen.setdefault(checksum, candidate)
        if keeper is candidate:
            continue

        if dry_run:
            print(f"[dry-run] {candidate} 与 {keeper} 内容相同,将被删除")
            continue

        try:
            os.remove(candidate)
        except Exception as exc:
            print(f"[error] 删除失败: {candidate} -> {exc}", file=sys.stderr)
        else:
            removed += 1
            print(f"[remove] 删除重复文件: {candidate} (原始: {keeper})")

    print(
        f"[summary] 扫描文件: {scanned}, 保留唯一文件: {len(seen)}, 删除重复文件: {removed}{' (dry-run)' if dry_run else ''}"
    )
    return removed
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def parse_args() -> argparse.Namespace:
    """Build the CLI (--target-dir, --dry-run) and parse sys.argv."""
    cli = argparse.ArgumentParser(description="按 MD5 删除重复文件,仅保留一个副本。")
    cli.add_argument(
        "--target-dir",
        type=Path,
        default=DEFAULT_TARGET_DIR,
        help=f"需要去重的目录(默认: {DEFAULT_TARGET_DIR})",
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="只输出将删除的文件,不实际删除。",
    )
    return cli.parse_args()
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def main() -> int:
    """CLI entry point: parse options, normalize the target path, run dedup."""
    options = parse_args()
    resolved = options.target_dir.expanduser().resolve()
    deduplicate(resolved, dry_run=options.dry_run)
    return 0
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
|
| 99 |
+
raise SystemExit(main())
|
test/remove_duplicate_celeb_images.pyc
ADDED
|
Binary file (3.18 kB). View file
|
|
|
test/remove_faceless_images.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
遍历 /opt/data/chinese_celeb_dataset 下的图片,使用 YOLO 人脸检测并删除没有检测到人脸的图片。
|
| 4 |
+
|
| 5 |
+
用法示例:
|
| 6 |
+
python test/remove_faceless_images.py --dry-run
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Iterable, List, Optional
|
| 15 |
+
|
| 16 |
+
import config
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from ultralytics import YOLO
|
| 20 |
+
except ImportError as exc: # pragma: no cover - 运行期缺依赖提示
|
| 21 |
+
raise SystemExit("缺少 ultralytics,请先执行 pip install ultralytics") from exc
|
| 22 |
+
|
| 23 |
+
# 默认数据集与模型配置
|
| 24 |
+
DEFAULT_DATASET_DIR = Path("/opt/data/chinese_celeb_dataset")
|
| 25 |
+
MODEL_DIR = Path(config.MODELS_PATH)
|
| 26 |
+
YOLO_MODEL_NAME = config.YOLO_MODEL
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def parse_args() -> argparse.Namespace:
    """Define and parse the CLI options for the face-screening sweep."""
    cli = argparse.ArgumentParser(
        description="使用 YOLO 检测 /opt/data/chinese_celeb_dataset 中的图片并删除无脸图片"
    )
    cli.add_argument(
        "--dataset-dir",
        type=Path,
        default=DEFAULT_DATASET_DIR,
        help="需要检查的根目录(默认:/opt/data/chinese_celeb_dataset)",
    )
    cli.add_argument(
        "--extensions",
        type=str,
        default=".jpg,.jpeg,.png,.webp,.bmp",
        help="需要检查的图片扩展名,逗号分隔",
    )
    cli.add_argument(
        "--confidence",
        type=float,
        default=config.FACE_CONFIDENCE,
        help="YOLO 检测的人脸置信度阈值",
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="仅输出将被删除的文件,不真正删除,便于先预览结果",
    )
    cli.add_argument(
        "--verbose",
        action="store_true",
        help="输出更多调试信息",
    )
    return cli.parse_args()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def load_yolo_model() -> YOLO:
    """Load YOLO weights, preferring the configured local models directory.

    Falls back to the bare model name (which triggers an auto-download);
    raises RuntimeError if every candidate fails.
    """
    local_weights = MODEL_DIR / YOLO_MODEL_NAME
    candidates: List[str] = []
    if local_weights.exists():
        candidates.append(str(local_weights))
    candidates.append(YOLO_MODEL_NAME)

    last_error: Optional[Exception] = None
    for source in candidates:
        try:
            config.logger.info("尝试加载 YOLO 模型:%s", source)
            return YOLO(source)
        except Exception as exc:  # pragma: no cover
            last_error = exc
            config.logger.warning("加载 YOLO 模型失败:%s -> %s", source, exc)

    raise RuntimeError(f"无法加载 YOLO 模型:{YOLO_MODEL_NAME}") from last_error
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def iter_image_files(root: Path, extensions: Iterable[str]) -> Iterable[Path]:
    """Yield every file under *root* (recursively) whose suffix matches one of
    *extensions*, case-insensitively. Blank extension entries are ignored."""
    wanted = tuple(ext.strip().lower() for ext in extensions if ext.strip())
    for candidate in root.rglob("*"):
        if candidate.is_file() and candidate.suffix.lower() in wanted:
            yield candidate
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def has_face(model: YOLO, image_path: Path, confidence: float, verbose: bool = False) -> bool:
    """
    Run YOLO on *image_path* and report whether any detection box is present.

    A single box above *confidence* counts as "has a face" (the model is
    presumably a face-detection checkpoint -- see load_yolo_model).
    Detection errors are logged and treated as "no face" (returns False).
    """
    try:
        results = model(image_path, conf=confidence, verbose=False)
    except Exception as exc:  # pragma: no cover
        # Fail-safe: an unreadable/broken image is reported as faceless here;
        # the caller's own try/except decides what to do with it.
        config.logger.error("检测失败,跳过 %s:%s", image_path, exc)
        return False

    for result in results:
        # `boxes` may be absent on some result objects; guard with getattr.
        boxes = getattr(result, "boxes", None)
        if boxes is None:
            continue
        if len(boxes) > 0:
            if verbose:
                # Collect (class id, confidence) per box purely for logging.
                faces = []
                for box in boxes:
                    cls_id = int(box.cls[0]) if getattr(box, "cls", None) is not None else -1
                    score = float(box.conf[0]) if getattr(box, "conf", None) is not None else 0.0
                    faces.append({"cls": cls_id, "conf": score})
                config.logger.info("检测到人脸:%s -> %s", image_path, faces)
            return True
    return False
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def main() -> None:
    """Sweep the dataset directory and delete (or just list, with --dry-run)
    every image in which YOLO finds no face."""
    args = parse_args()
    dataset_dir: Path = args.dataset_dir.expanduser().resolve()
    if not dataset_dir.exists():
        raise SystemExit(f"目录不存在:{dataset_dir}")

    model = load_yolo_model()
    image_paths = list(iter_image_files(dataset_dir, args.extensions.split(",")))
    total = len(image_paths)
    if total == 0:
        print(f"目录 {dataset_dir} 下没有匹配到图片文件")
        return

    removed = 0
    errored = 0
    for idx, image_path in enumerate(image_paths, start=1):
        # Progress line every 100 images (or every image when --verbose).
        if idx % 100 == 0 or args.verbose:
            print(f"[{idx}/{total}] 正在处理 {image_path}")

        try:
            if has_face(model, image_path, args.confidence, args.verbose):
                continue
        except Exception as exc:  # pragma: no cover
            errored += 1
            config.logger.error("检测过程中发生异常,跳过 %s:%s", image_path, exc)
            continue

        if args.dry_run:
            print(f"[DRY-RUN] 将删除:{image_path}")
        else:
            try:
                image_path.unlink()
                print(f"已删除:{image_path}")
            except Exception as exc:  # pragma: no cover
                errored += 1
                config.logger.error("删除失败 %s:%s", image_path, exc)
                continue
        # NOTE(review): `removed` also counts would-be deletions in dry-run mode,
        # so the summary's "删除" figure is the candidate count when --dry-run is set.
        removed += 1

    print(
        f"扫描完成,检测图片 {total} 张,删除 {removed} 张无脸图片,异常 {errored} 张,数据保存在:{dataset_dir}"
    )
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
if __name__ == "__main__":
|
| 166 |
+
try:
|
| 167 |
+
main()
|
| 168 |
+
except KeyboardInterrupt: # pragma: no cover
|
| 169 |
+
sys.exit("用户中断")
|
test/remove_faceless_images.pyc
ADDED
|
Binary file (5.17 kB). View file
|
|
|
test/test_deepface.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
import time
from deepface import DeepFace

# Directory holding the sample face images used by both demos below.
images_path = "/opt/data/face"

# ========== 2. Face similarity verification ==========
start_time = time.time()
result_verification = DeepFace.verify(
    img1_path=images_path + "/4.webp",
    img2_path=images_path + "/5.webp",
    model_name="ArcFace",  # embedding model
    detector_backend="yolov11n",  # face detector: retinaface / yolov8 / opencv / ssd / mediapipe
    distance_metric="cosine"  # similarity metric
)
end_time = time.time()
print(f"🕒 人脸比对耗时: {end_time - start_time:.3f} 秒")

# Print the verification result as pretty JSON.
print(json.dumps(result_verification, ensure_ascii=False, indent=2))


# ========== 1. Face recognition (search against a local image DB) ==========

start_time = time.time()
result_recognition = DeepFace.find(
    img_path=images_path + "/1.jpg",  # probe face
    db_path=images_path,  # database directory
    model_name="ArcFace",  # embedding model
    detector_backend="yolov11n",  # face detector
    distance_metric="cosine"  # similarity metric
)
end_time = time.time()
print(f"🕒 人脸识别耗时: {end_time - start_time:.3f} 秒")

# Uncomment to dump the recognition matches:
# df = result_recognition[0]
# print(df.to_json(orient="records", force_ascii=False))
|
test/test_deepface.pyc
ADDED
|
Binary file (769 Bytes). View file
|
|
|
test/test_main.http
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Test your FastAPI endpoints
|
| 2 |
+
|
| 3 |
+
GET http://127.0.0.1:8000/
|
| 4 |
+
Accept: application/json
|
| 5 |
+
|
| 6 |
+
###
|
| 7 |
+
|
| 8 |
+
GET http://127.0.0.1:8000/hello/User
|
| 9 |
+
Accept: application/json
|
| 10 |
+
|
| 11 |
+
###
|
test/test_rvm_infer.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time

import cv2
import numpy as np
import torch
from torchvision import transforms

device = "cpu"

# Input image and RGBA output destination.
input_path = "/opt/data/face/yang.webp"
output_path = "/opt/data/face/output_alpha.webp"

# Load the pretrained RobustVideoMatting model (resnet50 variant) from torch hub.
model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50").to(device).eval()

# Start timing (model download/load is deliberately excluded).
start = time.time()

# Read image and flip BGR -> RGB; .copy() makes the negative-stride view contiguous.
img = cv2.imread(input_path)[:, :, ::-1].copy()
src = transforms.ToTensor()(img).unsqueeze(0).to(device)

# Single-frame inference; rec holds the model's recurrent state (None on first frame).
rec = [None] * 4
with torch.no_grad():
    fgr, pha, *rec = model(src, *rec, downsample_ratio=0.25)

# Convert foreground (H,W,3) and alpha matte (H,W) to uint8 numpy arrays.
fgr = (fgr[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)  # (H,W,3)
pha = (pha[0, 0].cpu().numpy() * 255).astype(np.uint8)  # (H,W)

# Stack into an RGBA image.
rgba = np.dstack((fgr, pha))  # (H,W,4)

# Save as WebP with transparency; reorder channels to BGRA for OpenCV.
cv2.imwrite(output_path, rgba[:, :, [2,1,0,3]], [cv2.IMWRITE_WEBP_QUALITY, 100])

# Stop timing.
elapsed = time.time() - start

# Console summary.
print(f"✅ RVM 抠图完成 (透明背景)")
print(f" 输入文件: {input_path}")
print(f" 输出文件: {output_path}")
print(f" 耗时: {elapsed:.3f} 秒 (设备: {device})")
|
test/test_rvm_infer.pyc
ADDED
|
Binary file (1.2 kB). View file
|
|
|
test/test_score.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from retinaface import RetinaFace
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def default_converter(o):
    """json.dumps `default=` hook: unwrap NumPy arrays/scalars, stringify anything else."""
    if isinstance(o, np.ndarray):
        return o.tolist()
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, np.floating):
        return float(o)
    # Last resort: make the object JSON-serializable via its string form.
    return str(o)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Configure logging for this ad-hoc test script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# NOTE(review): "~" is passed through literally here -- RetinaFace/cv2 do not
# expand it; verify this path resolves (os.path.expanduser) before relying on it.
resp = RetinaFace.detect_faces("~/Downloads/chounan.jpeg")

# default_converter handles the NumPy types embedded in the detection result.
logger.info(
    "search results: " + json.dumps(resp, ensure_ascii=False, default=default_converter)
)
|
test/test_score.pyc
ADDED
|
Binary file (701 Bytes). View file
|
|
|
test/test_score_adjustment_demo.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def adjust_score(score, threshold, gamma):
    """Compress scores below *threshold* toward it by factor *gamma*.

    Scores at or above the threshold pass through unchanged; adjusted values
    are clamped to [0.0, 10.0] and rounded to one decimal place.
    """
    if score >= threshold:
        return score
    lifted = threshold - gamma * (threshold - score)
    clamped = max(0.0, min(10.0, lifted))
    return round(clamped, 1)
|
| 7 |
+
|
| 8 |
+
# Default parameters (T=9.0, γ=0.5)
default_threshold = 9.0
default_gamma = 0.5

# Variant 1 (T=9, γ=0.9)  -- NOTE: earlier comment said T=8.0, γ=0.5; values below are authoritative.
new_threshold_1 = 9
new_gamma_1 = 0.9

# Variant 2 (T=9, γ=0.8)  -- NOTE: earlier comment said T=8.0, γ=0.3; values below are authoritative.
new_threshold_2 = 9
new_gamma_2 = 0.8

print(f"原始分\tT={default_threshold},y={default_gamma}\tT={new_threshold_1},γ={new_gamma_1}\tT={new_threshold_2},γ={new_gamma_2}")
print("-----\t----------\t----------\t----------")

# Sweep raw scores from 1.0 to 10.0 in 0.1 steps and print each adjustment side by side.
for i in range(10, 101):
    score = i / 10.0
    default_adjusted = adjust_score(score, default_threshold, default_gamma)
    new_adjusted_1 = adjust_score(score, new_threshold_1, new_gamma_1)
    new_adjusted_2 = adjust_score(score, new_threshold_2, new_gamma_2)
    # Force one decimal place so integral scores still show the decimal point.
    print(f"{score:.1f}\t\t\t{default_adjusted:.1f}\t\t\t\t\t{new_adjusted_1:.1f}\t\t\t\t\t{new_adjusted_2:.1f}")
|
test/test_score_adjustment_demo.pyc
ADDED
|
Binary file (910 Bytes). View file
|
|
|
test/test_sky.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp

import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Sky-replacement demo: paste the sky from `sky_image` into `scene_image`.
image_skychange = pipeline(Tasks.image_skychange,
                           model='iic/cv_hrnetocr_skychange')

# BUG FIX: neither cv2 nor the pipeline expands "~" -- the original passed
# '~/Downloads/...' literally, so cv2.imwrite silently failed (it returns
# False instead of raising). Expand user paths explicitly.
sky_path = osp.expanduser('~/Downloads/sky_image.jpg')
scene_path = '/opt/data/face/NXEo0zusSaNB2fa232c84898e92ff165e2dfee59cb54.jpg'
output_path = osp.expanduser('~/Downloads/result.png')

result = image_skychange({'sky_image': sky_path,
                          'scene_image': scene_path})

# cv2.imwrite signals failure via its return value -- check it.
if not cv2.imwrite(output_path, result[OutputKeys.OUTPUT_IMG]):
    raise RuntimeError(f'Failed to write {output_path}')
# Report the path actually written (the original printed abspath("result.png"),
# i.e. a file in the CWD that was never created).
print(f'Output written to {osp.abspath(output_path)}')
|
test/test_sky.pyc
ADDED
|
Binary file (683 Bytes). View file
|
|
|
test_tensorflow.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Environment diagnostic: report Python / deepface / TensorFlow / Keras versions,
# and whether standalone `keras` or bundled `tf.keras` is in use.
import deepface
import sys
import tensorflow as tf

try:
    # Prefer the standalone keras package when importable.
    import keras

    keras_pkg = "keras (standalone)"
    keras_ver = keras.__version__
except Exception:
    # Fall back to the Keras bundled with TensorFlow.
    from tensorflow import keras

    keras_pkg = "tf.keras"
    keras_ver = keras.__version__

print("py =", sys.version)
print("deepface =", deepface.__version__)
print("tensorflow =", tf.__version__)
print("keras pkg =", keras_pkg, "keras =", keras_ver)
|