chawin.chen committed on
Commit
7a6cb13
·
1 Parent(s): e499f6c
This view is limited to 50 files because it contains too many changes.   See raw diff
.gitignore ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ HELP.md
2
+ target/
3
+ output/
4
+ !.mvn/wrapper/maven-wrapper.jar
5
+ !**/src/main/**/target/
6
+ !**/src/test/**/target/
7
+
8
+ .flattened-pom.xml
9
+
10
+ ### STS ###
11
+ .apt_generated
12
+ .classpath
13
+ .factorypath
14
+ .project
15
+ .settings
16
+ .springBeans
17
+ .sts4-cache
18
+
19
+ ### IntelliJ IDEA ###
20
+ .idea
21
+ *.iws
22
+ *.iml
23
+ *.ipr
24
+
25
+ ### NetBeans ###
26
+ /nbproject/private/
27
+ /nbbuild/
28
+ /dist/
29
+ /nbdist/
30
+ /.nb-gradle/
31
+ build/
32
+ !**/src/main/**/build/
33
+ !**/src/test/**/build/
34
+
35
+ ### VS Code ###
36
+ .vscode/
37
+
38
+ ### LOG ###
39
+ logs/
40
+
41
+ *.class
42
+
43
+ **/node_modules/
44
+ /*.log
45
+ /output/
46
+ /faiss/
47
+ /web/facelist-web/
48
+ **/._*
49
+ __pycache__/
50
+ .DS_Store
51
+ *.pth
52
+ /data/celebrity_faces/ds_model_arcface_detector_retinaface_aligned_normalization_base_expand_0.pkl
53
+ /data/celebrity_faces/jpeg_6c06eca6.jpeg
54
+ /data/celebrity_faces/jpeg_51e1394b.jpeg
55
+ /data/celebrity_faces/jpeg_66fee390.jpeg
56
+ /data/celebrity_faces/jpeg_70b86102.jpeg
57
+ /data/celebrity_faces/jpeg_406b961a.jpeg
58
+ /data/celebrity_faces/jpeg_1321f87f.jpeg
59
+ /data/celebrity_faces/jpeg_b56ae384.jpeg
60
+ /data/celebrity_faces/jpeg_c07cdb46.jpeg
61
+ /data/celebrity_faces/jpeg_c7353005.jpeg
62
+ /data/celebrity_faces/jpeg_d4cb0602.jpeg
63
+ /data/celebrity_faces/jpeg_dbb64030.jpeg
64
+ /data/celebrity_faces/jpeg_fc652ad4.jpeg
65
+ /data/celebrity_faces/jpeg_fd6b0869.jpeg
66
+ /data/celebrity_embeddings.db
Dockerfile ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

# Runtime configuration; every value can be overridden at `docker run` time.
ENV TZ=Asia/Shanghai \
    OUTPUT_DIR=/opt/data/output \
    IMAGES_DIR=/opt/data/images \
    MODELS_PATH=/opt/data/models \
    DEEPFACE_HOME=/opt/data/models \
    FAISS_INDEX_DIR=/opt/data/faiss \
    CELEBRITY_SOURCE_DIR=/opt/data/chinese_celeb_dataset \
    GENDER_CONFIDENCE=1 \
    UPSCALE_SIZE=2 \
    AGE_CONFIDENCE=0.1 \
    DRAW_SCORE=true \
    FACE_CONFIDENCE=0.7 \
    ENABLE_DDCOLOR=true \
    ENABLE_GFPGAN=true \
    ENABLE_REALESRGAN=true \
    ENABLE_ANIME_STYLE=true \
    ENABLE_RVM=true \
    ENABLE_REMBG=true \
    ENABLE_CLIP=false \
    CLEANUP_INTERVAL_HOURS=1 \
    CLEANUP_AGE_HOURS=1 \
    BEAUTY_ADJUST_GAMMA=0.8 \
    BEAUTY_ADJUST_MIN=1.0 \
    BEAUTY_ADJUST_MAX=9.0 \
    ENABLE_ANIME_PRELOAD=false \
    ENABLE_LOGGING=true \
    BEAUTY_ADJUST_ENABLED=true \
    RVM_LOCAL_REPO=/opt/data/models/RobustVideoMatting \
    RVM_WEIGHTS_PATH=/opt/data/models/torch/hub/checkpoints/rvm_resnet50.pth \
    RVM_MODEL=resnet50 \
    AUTO_INIT_GFPGAN=false \
    AUTO_INIT_DDCOLOR=false \
    AUTO_INIT_REALESRGAN=false \
    AUTO_INIT_ANIME_STYLE=false \
    AUTO_INIT_CLIP=false \
    AUTO_INIT_RVM=false \
    AUTO_INIT_REMBG=false \
    ENABLE_WARMUP=true \
    REALESRGAN_MODEL=realesr-general-x4v3 \
    CELEBRITY_FIND_THRESHOLD=0.87 \
    FEMALE_AGE_ADJUSTMENT=4 \
    HOSTNAME=HG

RUN mkdir -p /opt/data/chinese_celeb_dataset /opt/data/faiss /opt/data/models /opt/data/images /opt/data/output
WORKDIR /app

# Install system build tools and native libraries required by the Python deps.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    cmake \
    git \
    wget \
    curl \
    ca-certificates \
    libopenblas-dev \
    liblapack-dev \
    libx11-dev \
    libgtk-3-dev \
    libboost-python-dev \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip
# FIX: copy ONLY requirements.txt before installing dependencies so the
# (very slow) pip layer is cached independently of application-code changes.
# Previously `COPY *.py` came before the installs, so every code edit
# invalidated the apt and pip layers and forced a full rebuild.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Application code last: edits here no longer bust the dependency layers.
COPY *.py /app/

EXPOSE 7860
CMD ["uvicorn", "app:app", "--workers", "1", "--loop", "asyncio", "--http", "httptools", "--host", "0.0.0.0", "--port", "7860", "--timeout-keep-alive", "600"]
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Picpocket2
3
- emoji: 💻
4
- colorFrom: indigo
5
- colorTo: purple
6
  sdk: docker
7
  pinned: false
8
  ---
 
1
  ---
2
+ title: Picpocket
3
+ emoji: 🔥
4
+ colorFrom: yellow
5
+ colorTo: red
6
  sdk: docker
7
  pinned: false
8
  ---
anime_stylizer.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import time
4
+
5
+ import cv2
6
+
7
+ from config import logger
8
+
9
+
10
class AnimeStylizer:
    """Person-image anime/cartoon stylizer backed by ModelScope pipelines.

    Supported styles are described in ``self.model_configs``; pipelines are
    loaded lazily on first use (or eagerly when ``ENABLE_ANIME_PRELOAD`` is
    set) and cached in ``self.stylizers`` keyed by style type.
    """

    def __init__(self):
        start_time = time.perf_counter()
        self.stylizers = {}  # style_type -> loaded ModelScope pipeline
        self.current_style = None
        self.current_stylizer = None

        # Local import so the heavy ModelScope stack is only pulled in
        # when the feature flag is enabled.
        from config import ENABLE_ANIME_STYLE
        if ENABLE_ANIME_STYLE:
            self._initialize_models()
        else:
            logger.info("Anime Style feature is disabled, skipping model initialization")
        init_time = time.perf_counter() - start_time
        if hasattr(self, 'model_configs') and len(self.model_configs) > 0:
            logger.info(f"AnimeStylizer initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"AnimeStylizer initialization completed but not available, time: {init_time:.3f}s")

    def _initialize_models(self):
        """Register all anime style model configurations (ModelScope-backed).

        On any import/initialization failure ``self.model_configs`` is set to
        an empty dict so ``is_available()`` reports False.
        """
        try:
            logger.info("Initializing multiple Anime Style models (using ModelScope)...")

            # Compatibility patch: some ModelScope builds reference unsigned
            # torch dtypes that older torch releases do not define.
            import torch
            if not hasattr(torch, 'uint64'):
                logger.info("Adding torch.uint64 compatibility patch...")
                torch.uint64 = torch.int64  # use the signed counterpart as a stand-in
            if not hasattr(torch, 'uint32'):
                logger.info("Adding torch.uint32 compatibility patch...")
                torch.uint32 = torch.int32
            if not hasattr(torch, 'uint16'):
                logger.info("Adding torch.uint16 compatibility patch...")
                torch.uint16 = torch.int16

            # Import ModelScope here so a missing dependency is caught and
            # only disables this feature instead of crashing the process.
            from modelscope.outputs import OutputKeys
            from modelscope.pipelines import pipeline
            from modelscope.utils.constant import Tasks

            self.OutputKeys = OutputKeys

            # All supported styles and the ModelScope model ids behind them.
            self.model_configs = {
                "handdrawn": {
                    "model_id": "iic/cv_unet_person-image-cartoon-handdrawn_compound-models",
                    "name": "手绘风格",
                    "description": "手绘动漫风格 - 传统手绘感觉,线条清晰"
                },
                "disney": {
                    "model_id": "iic/cv_unet_person-image-cartoon-3d_compound-models",
                    "name": "迪士尼风格",
                    "description": "迪士尼风格 - 立体感强,色彩鲜艳"
                },
                "illustration": {
                    "model_id": "iic/cv_unet_person-image-cartoon-sd-design_compound-models",
                    "name": "插画风格",
                    "description": "插画风格 - 现代插画设计感"
                },
                "artstyle": {
                    "model_id": "iic/cv_unet_person-image-cartoon-artstyle_compound-models",
                    "name": "艺术风格",
                    "description": "艺术风格 - 独特的艺术表现力"
                },
                "anime": {
                    "model_id": "iic/cv_unet_person-image-cartoon_compound-models",
                    "name": "二次元风格",
                    "description": "二次元风格 - 经典动漫角色风格"
                },
                "sketch": {
                    "model_id": "iic/cv_unet_person-image-cartoon-sketch_compound-models",
                    "name": "素描风格",
                    "description": "素描风格 - 黑白素描画效果"
                }
            }

            logger.info(f"Defined {len(self.model_configs)} anime style model configurations")
            logger.info("Models will be loaded on-demand when first used to save memory")

            # Optional eager loading, controlled by config.
            try:
                from config import ENABLE_ANIME_PRELOAD
                if ENABLE_ANIME_PRELOAD:
                    logger.info("Enabling anime style model preloading...")
                    self.preload_models()
                else:
                    logger.info("Anime style model preloading is disabled, will be loaded on-demand when first used")
            except ImportError:
                logger.info("Anime style model preloading configuration not found, will be loaded on-demand when first used")

        except ImportError as e:
            logger.error(f"ModelScope module import failed: {e}")
            self.model_configs = {}
        except Exception as e:
            logger.error(f"Anime Style model initialization failed: {e}")
            self.model_configs = {}

    def _load_model(self, style_type):
        """Load (and cache) the pipeline for one style; return True on success."""
        if style_type not in self.model_configs:
            logger.error(f"Unsupported style type: {style_type}")
            return False

        if style_type in self.stylizers:
            logger.info(f"Model {style_type} already loaded, using directly")
            return True

        try:
            from modelscope.pipelines import pipeline
            from modelscope.utils.constant import Tasks

            config = self.model_configs[style_type]
            logger.info(f"Loading {config['name']} model: {config['model_id']}")

            # Pick the task type from the model family.
            if "stable_diffusion" in config["model_id"]:
                # Stable Diffusion family models use the text-to-image task.
                task_type = Tasks.text_to_image_synthesis
                logger.info(f"Using text_to_image_synthesis task type to load Stable Diffusion model")
            else:
                # UNet family models use portrait stylization.
                task_type = Tasks.image_portrait_stylization
                logger.info(f"Using image_portrait_stylization task type to load UNet model")

            stylizer = pipeline(task_type, model=config["model_id"])
            self.stylizers[style_type] = stylizer

            logger.info(f"{config['name']} model loaded successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to load {style_type} model: {e}")
            return False

    def preload_models(self, style_types=None):
        """
        Eagerly load the given anime style models.
        :param style_types: list of style types (or one string); None loads all.
        """
        if not self.is_available():
            logger.warning("Anime Style module is not available, cannot preload models")
            return

        if style_types is None:
            style_types = list(self.model_configs.keys())
        elif isinstance(style_types, str):
            style_types = [style_types]

        logger.info(f"Starting to preload anime style models: {style_types}")

        successful_loads = []
        failed_loads = []

        for style_type in style_types:
            if style_type not in self.model_configs:
                logger.warning(f"Unknown style type: {style_type}, skipping preload")
                failed_loads.append(style_type)
                continue

            try:
                logger.info(f"Preloading model: {self.model_configs[style_type]['name']} ({style_type})")
                if self._load_model(style_type):
                    successful_loads.append(style_type)
                    logger.info(f"✓ Successfully preloaded: {self.model_configs[style_type]['name']}")
                else:
                    failed_loads.append(style_type)
                    logger.error(f"✗ Preload failed: {self.model_configs[style_type]['name']}")
            except Exception as e:
                logger.error(f"✗ Exception occurred while preloading model {style_type}: {e}")
                failed_loads.append(style_type)

        if successful_loads:
            logger.info(f"Successfully preloaded models ({len(successful_loads)}): {successful_loads}")
        if failed_loads:
            logger.warning(f"Failed to preload models ({len(failed_loads)}): {failed_loads}")

        logger.info(f"Anime style model preloading completed, success: {len(successful_loads)}/{len(style_types)}")

    def get_loaded_models(self):
        """Return the list of style types whose pipelines are already loaded."""
        return list(self.stylizers.keys())

    def is_model_loaded(self, style_type):
        """
        Check whether the pipeline for one style is loaded.
        :param style_type: style type key
        :return: bool
        """
        return style_type in self.stylizers

    def get_preload_status(self):
        """
        Summarize the loading state of all configured models.
        :return: dict with counts, ratio, percentage and per-style lists
        """
        total_models = len(self.model_configs)
        loaded_models = len(self.stylizers)

        status = {
            "total_models": total_models,
            "loaded_models": loaded_models,
            "preload_ratio": f"{loaded_models}/{total_models}",
            "preload_percentage": round((loaded_models / total_models * 100) if total_models > 0 else 0, 1),
            "available_styles": list(self.model_configs.keys()),
            "loaded_styles": list(self.stylizers.keys()),
            "unloaded_styles": [style for style in self.model_configs.keys() if style not in self.stylizers]
        }

        return status

    def is_available(self):
        """Return True when at least one style configuration was registered."""
        return hasattr(self, 'model_configs') and len(self.model_configs) > 0

    def stylize_image(self, image, style_type="disney"):
        """
        Stylize an image with the requested anime style.
        :param image: input image (numpy array, BGR)
        :param style_type: one of the configured styles:
                          "handdrawn"    - 手绘风格
                          "disney"       - 迪士尼风格 (default)
                          "illustration" - 插画风格
                          "artstyle"     - 艺术风格
                          "anime"        - 二次元风格
                          "sketch"       - 素描风格
        :return: stylized image (numpy array, BGR); the original image on failure
        """
        if not self.is_available():
            logger.error("Anime Style model not initialized")
            return image

        # Ensure the requested pipeline is loaded before processing.
        if not self._load_model(style_type):
            logger.error(f"Failed to load {style_type} model")
            return image

        return self._stylize_image_via_file(image, style_type)

    def _stylize_image_via_file(self, image, style_type="disney"):
        """
        Run stylization by round-tripping the image through a temp file.
        :param image: input image (numpy array, BGR)
        :param style_type: anime style type
        :return: stylized image (numpy array, BGR); the original image on failure
        """
        try:
            # BUGFIX: validate the style BEFORE resolving its config. The old
            # code looked the config up first, so an invalid style that fell
            # back to "disney" kept an empty config dict and later crashed on
            # config["model_id"].
            if style_type not in self.model_configs:
                logger.warning(f"Invalid style type: {style_type}, using default style disney")
                style_type = "disney"

            config = self.model_configs.get(style_type, {})
            style_name = config.get('name', style_type)
            logger.info(f"Using anime stylization processing, style type: {style_name} ({style_type})")

            # Persist the input as a max-quality temporary WebP file.
            with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as tmp_input:
                cv2.imwrite(tmp_input.name, image, [cv2.IMWRITE_WEBP_QUALITY, 100])
                tmp_input_path = tmp_input.name

            try:
                logger.info(f"Temporary file saved to: {tmp_input_path}")

                stylizer = self.stylizers[style_type]

                # The two model families take different call arguments.
                if "stable_diffusion" in config["model_id"]:
                    # Stable Diffusion pipelines require a text prompt.
                    logger.info("Using Stable Diffusion model, text parameter is required")
                    # Stable Diffusion fine-tunes expect an 'sks style' prompt.
                    style_prompts = {}
                    prompt = style_prompts.get(style_type, "sks style, cartoon style artwork")
                    logger.info(f"Using prompt: {prompt}")
                    result = stylizer({"text": prompt})
                else:
                    # UNet pipelines consume the image path directly.
                    result = stylizer(tmp_input_path)

                # Extract the stylized image; output key names vary by model.
                if "stable_diffusion" in config["model_id"]:
                    logger.info(f"Stable Diffusion model output keys: {list(result.keys())}")
                    if 'output_imgs' in result:
                        stylized_image = result['output_imgs'][0]
                    elif 'output_img' in result:
                        stylized_image = result['output_img']
                    elif self.OutputKeys.OUTPUT_IMG in result:
                        stylized_image = result[self.OutputKeys.OUTPUT_IMG]
                    else:
                        # Fall back to the first image-like output value.
                        for key in result.keys():
                            if isinstance(result[key], (list, tuple)) and len(result[key]) > 0:
                                stylized_image = result[key][0]
                                logger.info(f"Using output key: {key}")
                                break
                            elif hasattr(result[key], 'shape'):
                                stylized_image = result[key]
                                logger.info(f"Using output key: {key}")
                                break
                        else:
                            raise KeyError(f"未找到有效的图像输出键,可用键: {list(result.keys())}")
                else:
                    # UNet models use the standard output key.
                    stylized_image = result[self.OutputKeys.OUTPUT_IMG]

                logger.info(f"Anime stylization output: size={stylized_image.shape}, type={stylized_image.dtype}")

                # ModelScope already returns BGR; no color conversion needed.
                logger.info("Anime stylization processing completed")
                return stylized_image

            finally:
                # Best-effort temp-file cleanup; narrow except instead of bare.
                try:
                    os.unlink(tmp_input_path)
                except OSError:
                    pass

        except Exception as e:
            logger.error(f"Anime stylization processing failed: {e}")
            logger.info("Returning original image")
            return image

    def get_available_styles(self):
        """
        List the supported anime styles.
        :return: dict mapping style code to a human-readable description
        """
        if not hasattr(self, 'model_configs'):
            return {}

        return {
            style_type: f"{config['name']} - {config['description'].split(' - ')[1]}"
            for style_type, config in self.model_configs.items()
        }

    def save_debug_image(self, image, filename_prefix):
        """Save an image for debugging; return its path or None on failure."""
        try:
            debug_path = f"{filename_prefix}_debug.webp"
            cv2.imwrite(debug_path, image, [cv2.IMWRITE_WEBP_QUALITY, 95])
            logger.info(f"Debug image saved: {debug_path}")
            return debug_path
        except Exception as e:
            logger.error(f"Failed to save debug image: {e}")
            return None

    def test_stylization(self, test_url=None):
        """
        Smoke-test stylization against a remote sample image.
        :param test_url: test image URL; defaults to the official sample
        :return: (success: bool, message: str)
        """
        if not self.is_available():
            return False, "Anime Style模型未初始化"

        try:
            test_url = test_url or 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/portrait.jpg'
            logger.info(f"Testing anime stylization feature, using image: {test_url}")

            # BUGFIX: the old code called self.stylizer, an attribute that is
            # never set (pipelines live in the self.stylizers dict); load and
            # use the default style explicitly instead.
            if not self._load_model("disney"):
                return False, "默认风格模型加载失败"
            result = self.stylizers["disney"](test_url)
            stylized_img = result[self.OutputKeys.OUTPUT_IMG]

            # Persist the result so it can be inspected manually.
            test_output_path = 'anime_style_test_result.webp'
            cv2.imwrite(test_output_path, stylized_img, [cv2.IMWRITE_WEBP_QUALITY, 95])

            logger.info(f"Anime stylization test successful, result saved to: {test_output_path}")
            return True, f"测试成功,结果保存到: {test_output_path}"

        except Exception as e:
            logger.error(f"Anime stylization test failed: {e}")
            return False, f"测试失败: {e}"

    def test_local_image(self, image_path, style_type="disney"):
        """
        Stylize a local image file and save before/after debug copies.
        :param image_path: path to a local image
        :param style_type: anime style type
        :return: (success: bool, message: str)
        """
        if not self.is_available():
            return False, "Anime Style模型未初始化"

        try:
            logger.info(f"Testing local image anime stylization: {image_path}, style: {style_type}")

            image = cv2.imread(image_path)
            if image is None:
                return False, f"Unable to read image: {image_path}"

            # Keep the original for side-by-side comparison.
            self.save_debug_image(image, "original")

            stylized_image = self.stylize_image(image, style_type)

            result_path = self.save_debug_image(stylized_image, f"anime_style_{style_type}")

            logger.info(f"Local image anime stylization successful, result saved to: {result_path}")
            return True, f"本地图像动漫风格化成功,结果保存到: {result_path}"

        except Exception as e:
            logger.error(f"Local image anime stylization failed: {e}")
            return False, f"本地图像动漫风格化失败: {e}"
api_routes.py ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from contextlib import asynccontextmanager
4
+
5
+ from fastapi import FastAPI
6
+ from starlette.middleware.cors import CORSMiddleware
7
+
8
+ from cleanup_scheduler import start_cleanup_scheduler, stop_cleanup_scheduler
9
+ from config import (
10
+ logger,
11
+ OUTPUT_DIR,
12
+ DEEPFACE_AVAILABLE,
13
+ DLIB_AVAILABLE,
14
+ MODELS_PATH,
15
+ IMAGES_DIR,
16
+ YOLO_AVAILABLE,
17
+ ENABLE_LOGGING,
18
+ HUGGINGFACE_SYNC_ENABLED,
19
+ )
20
+ from database import close_mysql_pool, init_mysql_pool
21
+ from utils import ensure_bos_resources, ensure_huggingface_models
22
+
23
logger.info("Starting to import api_routes module...")

# Stage 1: optionally mirror model weights down from HuggingFace.
if HUGGINGFACE_SYNC_ENABLED:
    try:
        hf_started = time.perf_counter()
        if not ensure_huggingface_models():
            raise RuntimeError("无法从 HuggingFace 同步模型,请检查配置与网络")
        hf_time = time.perf_counter() - hf_started
        logger.info("HuggingFace 模型同步完成,用时 %.3fs", hf_time)
    except Exception as exc:
        logger.error(f"HuggingFace model preparation failed: {exc}")
        raise
else:
    logger.info("已关闭 HuggingFace 模型同步开关,跳过启动阶段的同步步骤")

# Stage 2: pull models and data from BOS; startup aborts on failure.
try:
    bos_started = time.perf_counter()
    if not ensure_bos_resources():
        raise RuntimeError("无法从 BOS 同步模型与数据,请检查凭证与网络")
    bos_time = time.perf_counter() - bos_started
    logger.info(f"BOS resources synchronized successfully, time: {bos_time:.3f}s")
except Exception as exc:
    logger.error(f"BOS resource preparation failed: {exc}")
    raise

# Stage 3: import the (heavy) API routes module and time the import.
try:
    t_start = time.perf_counter()
    from api_routes import api_router, extract_chinese_celeb_dataset_sync
    import_time = time.perf_counter() - t_start
    logger.info(f"api_routes module imported successfully, time: {import_time:.3f}s")
except Exception as e:
    import_time = time.perf_counter() - t_start
    logger.error(f"api_routes module import failed, time: {import_time:.3f}s, error: {e}")
    raise

# Stage 4: unpack the celebrity dataset before serving traffic.
try:
    extract_started = time.perf_counter()
    extract_result = extract_chinese_celeb_dataset_sync()
    extract_time = time.perf_counter() - extract_started
    logger.info(
        "Chinese celeb dataset extracted successfully, time: %.3fs, target: %s",
        extract_time,
        extract_result.get("target_dir"),
    )
except Exception as exc:
    logger.error(f"Failed to extract Chinese celeb dataset automatically: {exc}")
    raise
+ raise
70
+
71
+
72
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: set up DB pool and cleanup job, tear both down on exit."""
    t0 = time.perf_counter()
    logger.info("FaceScore service starting...")
    logger.info(f"Output directory: {OUTPUT_DIR}")
    logger.info(f"DeepFace available: {DEEPFACE_AVAILABLE}")
    logger.info(f"YOLO available: {YOLO_AVAILABLE}")
    logger.info(f"MediaPipe available: {DLIB_AVAILABLE}")
    logger.info(f"Archive directory: {IMAGES_DIR}")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Database pool is mandatory — startup fails without it.
    try:
        await init_mysql_pool()
        logger.info("MySQL 连接池初始化完成")
    except Exception as exc:
        logger.error(f"初始化 MySQL 连接池失败: {exc}")
        raise

    # The periodic image cleanup job is best-effort: a failure here is
    # logged but does not prevent the service from starting.
    logger.info("Starting image cleanup scheduled task...")
    try:
        start_cleanup_scheduler()
        logger.info("Image cleanup scheduled task started successfully")
    except Exception as e:
        logger.error(f"Failed to start image cleanup scheduled task: {e}")

    total_startup_time = time.perf_counter() - t0
    logger.info(f"FaceScore service startup completed, total time: {total_startup_time:.3f}s")

    yield

    # Shutdown: stop the scheduler first, then release the DB pool.
    logger.info("Stopping image cleanup scheduled task...")
    try:
        stop_cleanup_scheduler()
        logger.info("Image cleanup scheduled task stopped")
    except Exception as e:
        logger.error(f"Failed to stop image cleanup scheduled task: {e}")

    try:
        await close_mysql_pool()
    except Exception as exc:
        logger.warning(f"关闭 MySQL 连接池失败: {exc}")
118
+
119
+
120
+ # 创建 FastAPI 应用
121
+ app = FastAPI(
122
+ title="Enhanced FaceScore 服务",
123
+ description="支持多模型的人脸分析REST API服务,包含五官评分功能。支持混合模式:HowCuteAmI(颜值+性别)+ DeepFace(年龄+情绪)",
124
+ version="3.0.0",
125
+ docs_url="/cp_docs",
126
+ redoc_url="/cp_redoc",
127
+ lifespan=lifespan,
128
+ )
129
+
130
+ app.add_middleware(
131
+ CORSMiddleware,
132
+ allow_origins=["*"],
133
+ allow_methods=["*"],
134
+ allow_headers=["*"],
135
+ )
136
+
137
+ # 注册路由
138
+ app.include_router(api_router)
139
+
140
+ # 添加根路径处理
141
+ @app.get("/")
142
+ async def root():
143
+ return "UP"
144
+
145
+
146
if __name__ == "__main__":
    import uvicorn

    # Refuse to start without the models directory — nothing works without it.
    if not os.path.exists(MODELS_PATH):
        logger.critical(
            "Warning: 'models' directory not found. Please ensure it exists and contains model files."
        )
        logger.critical(
            "Exiting application as FaceAnalyzer cannot be initialized without models."
        )
        exit(1)

    # Build the Uvicorn options once; the logging switch only decides
    # whether access logs are kept or everything below CRITICAL is muted.
    run_kwargs = {"host": "0.0.0.0", "port": 8080, "reload": False}
    if not ENABLE_LOGGING:
        run_kwargs["access_log"] = False
        run_kwargs["log_level"] = "critical"
    uvicorn.run(app, **run_kwargs)
build.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
#!/usr/bin/env sh
# Byte-compile every source file and stage the runtime artifacts under
# /opt/data/app. FIX: fail fast on any error and on unset variables —
# previously a failed compile or copy went unnoticed and later steps ran
# against a half-built output directory.
set -eu

python -m compileall -q -f -b .
mv *.pyc /opt/data/app/
cp gfpgan_restorer.py /opt/data/app/
cp start_local.sh /opt/data/app/
cleanup_scheduler.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 定时清理图片文件模块
3
+ 每小时检查一次IMAGES_DIR目录,删除1小时以前的图片文件
4
+ """
5
+ import glob
6
+ import os
7
+ import time
8
+ from datetime import datetime
9
+
10
+ from apscheduler.schedulers.background import BackgroundScheduler
11
+
12
+ from config import logger, IMAGES_DIR, CLEANUP_INTERVAL_HOURS, CLEANUP_AGE_HOURS
13
+
14
+
15
+ # from utils import delete_file_from_bos # 暂时注释掉删除BOS文件的功能
16
+
17
+
18
class ImageCleanupScheduler:
    """Periodic cleanup of expired image files in the images directory.

    Runs a background APScheduler job every ``interval_hours`` that deletes
    image files older than ``cleanup_hours``.
    """

    def __init__(self, images_dir=None, cleanup_hours=None, interval_hours=None):
        """
        Initialize the cleanup scheduler.

        Args:
            images_dir (str): directory to monitor; defaults to config IMAGES_DIR.
            cleanup_hours (float): age threshold in hours; defaults to CLEANUP_AGE_HOURS.
            interval_hours (float): run interval in hours; defaults to CLEANUP_INTERVAL_HOURS.
        """
        self.images_dir = images_dir or IMAGES_DIR
        self.cleanup_hours = cleanup_hours if cleanup_hours is not None else CLEANUP_AGE_HOURS
        self.interval_hours = interval_hours if interval_hours is not None else CLEANUP_INTERVAL_HOURS
        self.scheduler = BackgroundScheduler()
        self.is_running = False

        # Make sure the monitored directory exists before the first run.
        os.makedirs(self.images_dir, exist_ok=True)
        logger.info(f"Image cleanup scheduler initialized, monitoring directory: {self.images_dir}, cleanup threshold: {self.cleanup_hours} hours, execution interval: {self.interval_hours} hours")

    def cleanup_old_images(self):
        """
        Delete image files whose mtime is older than the configured threshold.

        Returns a result dict: success flag, deleted count/size/names and the
        cutoff timestamp (or the error on failure).
        """
        try:
            current_time = time.time()
            cutoff_time = current_time - (self.cleanup_hours * 3600)  # threshold in epoch seconds
            cutoff_datetime = datetime.fromtimestamp(cutoff_time)

            # Image formats considered for cleanup.
            image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.gif', '*.bmp']
            deleted_files = []
            total_size_deleted = 0

            logger.info(f"Starting to clean image directory: {self.images_dir}")
            logger.info(f"Cleanup threshold time: {cutoff_datetime.strftime('%Y-%m-%d %H:%M:%S')}")

            for extension in image_extensions:
                pattern = os.path.join(self.images_dir, extension)
                for file_path in glob.glob(pattern):
                    try:
                        file_mtime = os.path.getmtime(file_path)

                        # Delete only files strictly older than the cutoff.
                        if file_mtime < cutoff_time:
                            file_size = os.path.getsize(file_path)
                            os.remove(file_path)
                            # delete_file_from_bos(file_path)  # BOS deletion disabled for now
                            deleted_files.append(os.path.basename(file_path))
                            total_size_deleted += file_size
                            logger.info(f"Deleting expired file: {os.path.basename(file_path)} ")

                    except (OSError, IOError) as e:
                        # Skip files that vanished or cannot be removed.
                        logger.error(f"Failed to delete file {os.path.basename(file_path)}: {e}")
                        continue

            # BUGFIX: this summary previously had a dangling `else:` with no
            # matching `if` (a syntax error) and a truncated log message;
            # guard on whether anything was actually deleted.
            if deleted_files:
                logger.info(f"Cleanup completed! Deleted {len(deleted_files)} files, freed {self._format_size(total_size_deleted)}")
                logger.info(f"Deleted file list: {', '.join(deleted_files[:10])}")
            else:
                logger.info("Cleanup completed! No expired files found to clean")

            return {
                'success': True,
                'deleted_count': len(deleted_files),
                'deleted_size': total_size_deleted,
                'deleted_files': deleted_files,
                'cutoff_time': cutoff_datetime.isoformat()
            }

        except Exception as e:
            error_msg = f"图片清理任务执行失败: {e}"
            logger.error(error_msg)
            return {
                'success': False,
                'error': str(e),
                'deleted_count': 0,
                'deleted_size': 0
            }

    def _format_size(self, size_bytes):
        """Render a byte count as a human-readable string (B/KB/MB/GB)."""
        if size_bytes == 0:
            return "0 B"
        size_names = ["B", "KB", "MB", "GB"]
        i = 0
        while size_bytes >= 1024 and i < len(size_names) - 1:
            size_bytes /= 1024.0
            i += 1
        return f"{size_bytes:.1f} {size_names[i]}"

    def start(self):
        """Start the periodic cleanup job and run one cleanup immediately."""
        if self.is_running:
            logger.warning("Image cleanup scheduler is already running")
            return

        try:
            # Interval trigger driven by the configured run interval.
            self.scheduler.add_job(
                func=self.cleanup_old_images,
                trigger='interval',
                hours=self.interval_hours,
                id='image_cleanup',
                name='image cleanup task',  # fixed typo ("clean tast")
                replace_existing=True
            )

            self.scheduler.start()
            self.is_running = True

            logger.info(f"Image cleanup scheduler started, will execute cleanup task every {self.interval_hours} hours")

            # Run once right away so stale files don't wait a full interval.
            logger.info("Executing image cleanup task immediately...")
            self.cleanup_old_images()

        except Exception as e:
            logger.error(f"Failed to start image cleanup scheduler: {e}")
            raise

    def stop(self):
        """Stop the scheduler without waiting for a running job to finish."""
        if not self.is_running:
            logger.warning("Image cleanup scheduler is not running")
            return

        try:
            self.scheduler.shutdown(wait=False)
            self.is_running = False
            logger.info("Image cleanup scheduler stopped")
        except Exception as e:
            logger.error(f"Failed to stop image cleanup scheduler: {e}")

    def get_status(self):
        """Return a snapshot of configuration and run state (incl. next run time)."""
        jobs = self.scheduler.get_jobs() if self.is_running else []
        return {
            'running': self.is_running,
            'images_dir': self.images_dir,
            'cleanup_hours': self.cleanup_hours,
            'interval_hours': self.interval_hours,
            'next_run': jobs[0].next_run_time.isoformat() if jobs else None
        }
171
+
172
+
173
# Module-level singleton shared by the whole application.
cleanup_scheduler = ImageCleanupScheduler()


def start_cleanup_scheduler():
    """Start the global image cleanup scheduler."""
    cleanup_scheduler.start()


def stop_cleanup_scheduler():
    """Stop the global image cleanup scheduler."""
    cleanup_scheduler.stop()


def get_cleanup_status():
    """Return the current status of the global scheduler."""
    return cleanup_scheduler.get_status()


def manual_cleanup():
    """Run one cleanup pass immediately and return its result dict."""
    return cleanup_scheduler.cleanup_old_images()
195
+
196
+
197
if __name__ == "__main__":
    # Manual smoke test: run one cleanup pass and print the outcome.
    print("测试图片清理功能...")
    scheduler_under_test = ImageCleanupScheduler()
    outcome = scheduler_under_test.cleanup_old_images()
    print(f"清理结果: {outcome}")
clip_utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # clip_utils.py
2
+ import logging
3
+ import os
4
+ from typing import Union, List
5
+
6
+ import cn_clip.clip as clip
7
+ import torch
8
+ from PIL import Image
9
+ from cn_clip.clip import load_from_name
10
+
11
+ from config import MODELS_PATH
12
+
13
+ # 配置日志
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # 环境变量配置
18
+ MODEL_NAME_CN = os.environ.get('MODEL_NAME_CN', 'ViT-B-16')
19
+
20
+ # 设备配置
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+
23
+ # 模型初始化
24
+ model = None
25
+ preprocess = None
26
+
27
+ def init_clip_model():
28
+ """初始化CLIP模型"""
29
+ global model, preprocess
30
+ try:
31
+ model, preprocess = load_from_name(MODEL_NAME_CN, device=device, download_root=MODELS_PATH)
32
+ model.eval()
33
+ logger.info(f"CLIP model initialized successfully, dimension: {model.visual.output_dim}")
34
+ return True
35
+ except Exception as e:
36
+ logger.error(f"CLIP model initialization failed: {e}")
37
+ return False
38
+
39
+ def is_clip_available():
40
+ """检查CLIP模型是否可用"""
41
+ return model is not None and preprocess is not None
42
+
43
+ def encode_image(image_path: str) -> torch.Tensor:
44
+ """编码图片为向量"""
45
+ if not is_clip_available():
46
+ raise RuntimeError("CLIP模型未初始化")
47
+
48
+ image = Image.open(image_path).convert("RGB")
49
+ image_tensor = preprocess(image).unsqueeze(0).to(device)
50
+ with torch.no_grad():
51
+ features = model.encode_image(image_tensor)
52
+ features = features / features.norm(p=2, dim=-1, keepdim=True)
53
+ return features.cpu()
54
+
55
+ def encode_text(text: Union[str, List[str]]) -> torch.Tensor:
56
+ """编码文本为向量"""
57
+ if not is_clip_available():
58
+ raise RuntimeError("CLIP模型未初始化")
59
+
60
+ texts = [text] if isinstance(text, str) else text
61
+ text_tokens = clip.tokenize(texts).to(device)
62
+ with torch.no_grad():
63
+ features = model.encode_text(text_tokens)
64
+ features = features / features.norm(p=2, dim=-1, keepdim=True)
65
+ return features.cpu()
config.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ # 解决OpenMP库冲突问题
5
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
6
+ # 设置CPU线程数为CPU核心数,提高CPU利用率
7
+ import multiprocessing
8
+ cpu_cores = multiprocessing.cpu_count()
9
+ os.environ["OMP_NUM_THREADS"] = str(min(cpu_cores, 8)) # 最多使用8个线程
10
+ os.environ["MKL_NUM_THREADS"] = str(min(cpu_cores, 8))
11
+ os.environ["NUMEXPR_NUM_THREADS"] = str(min(cpu_cores, 8))
12
+
13
+ # 修复torchvision兼容性问题
14
+ try:
15
+ import torchvision.transforms.functional_tensor
16
+ except ImportError:
17
+ # 为缺失的functional_tensor模块创建兼容性补丁
18
+ import torchvision.transforms.functional as F
19
+ import torchvision.transforms as transforms
20
+ import sys
21
+ from types import ModuleType
22
+
23
+ # 创建functional_tensor模块
24
+ functional_tensor = ModuleType('torchvision.transforms.functional_tensor')
25
+
26
+ # 添加常用的函数映射
27
+ if hasattr(F, 'rgb_to_grayscale'):
28
+ functional_tensor.rgb_to_grayscale = F.rgb_to_grayscale
29
+ if hasattr(F, 'adjust_brightness'):
30
+ functional_tensor.adjust_brightness = F.adjust_brightness
31
+ if hasattr(F, 'adjust_contrast'):
32
+ functional_tensor.adjust_contrast = F.adjust_contrast
33
+ if hasattr(F, 'adjust_saturation'):
34
+ functional_tensor.adjust_saturation = F.adjust_saturation
35
+ if hasattr(F, 'normalize'):
36
+ functional_tensor.normalize = F.normalize
37
+ if hasattr(F, 'resize'):
38
+ functional_tensor.resize = F.resize
39
+ if hasattr(F, 'crop'):
40
+ functional_tensor.crop = F.crop
41
+ if hasattr(F, 'pad'):
42
+ functional_tensor.pad = F.pad
43
+
44
+ # 将模块添加到sys.modules
45
+ sys.modules['torchvision.transforms.functional_tensor'] = functional_tensor
46
+ transforms.functional_tensor = functional_tensor
47
+
48
+ # 环境变量配置 - 禁用TensorFlow优化和GPU
49
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
50
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
51
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # 强制使用CPU
52
+ os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "false"
53
+
54
+ # 修复PyTorch兼容性问题
55
+ try:
56
+ import torch
57
+ import torch.onnx
58
+
59
+ # 修复GFPGAN的ONNX兼容性
60
+ if not hasattr(torch.onnx._internal.exporter, 'ExportOptions'):
61
+ from types import SimpleNamespace
62
+ torch.onnx._internal.exporter.ExportOptions = SimpleNamespace
63
+
64
+ # 修复ModelScope的PyTree兼容性 - 更完整的实现
65
+ import torch.utils
66
+ if not hasattr(torch.utils, '_pytree'):
67
+ # 如果_pytree模块不存在,创建一个
68
+ from types import ModuleType
69
+ torch.utils._pytree = ModuleType('_pytree')
70
+
71
+ pytree = torch.utils._pytree
72
+
73
+ if not hasattr(pytree, 'register_pytree_node'):
74
+ def register_pytree_node(typ, flatten_fn, unflatten_fn, *, flatten_with_keys_fn=None, **kwargs):
75
+ """兼容性实现:注册PyTree节点类型"""
76
+ pass # 简单实现,不做实际操作
77
+ pytree.register_pytree_node = register_pytree_node
78
+
79
+ if not hasattr(pytree, 'tree_flatten'):
80
+ def tree_flatten(tree, is_leaf=None):
81
+ """兼容性实现:展平树结构"""
82
+ if isinstance(tree, (list, tuple)):
83
+ flat = []
84
+ spec = []
85
+ for i, item in enumerate(tree):
86
+ if isinstance(item, (list, tuple, dict)):
87
+ sub_flat, sub_spec = tree_flatten(item, is_leaf)
88
+ flat.extend(sub_flat)
89
+ spec.append((i, sub_spec))
90
+ else:
91
+ flat.append(item)
92
+ spec.append((i, None))
93
+ return flat, (type(tree), spec)
94
+ elif isinstance(tree, dict):
95
+ flat = []
96
+ spec = []
97
+ for key, value in sorted(tree.items()):
98
+ if isinstance(value, (list, tuple, dict)):
99
+ sub_flat, sub_spec = tree_flatten(value, is_leaf)
100
+ flat.extend(sub_flat)
101
+ spec.append((key, sub_spec))
102
+ else:
103
+ flat.append(value)
104
+ spec.append((key, None))
105
+ return flat, (dict, spec)
106
+ else:
107
+ return [tree], None
108
+ pytree.tree_flatten = tree_flatten
109
+
110
+ if not hasattr(pytree, 'tree_unflatten'):
111
+ def tree_unflatten(values, spec):
112
+ """兼容性实现:重构树结构"""
113
+ if spec is None:
114
+ return values[0] if values else None
115
+
116
+ tree_type, tree_spec = spec
117
+ if tree_type in (list, tuple):
118
+ result = []
119
+ value_idx = 0
120
+ for pos, sub_spec in tree_spec:
121
+ if sub_spec is None:
122
+ result.append(values[value_idx])
123
+ value_idx += 1
124
+ else:
125
+ # 计算子树需要的值数量
126
+ sub_count = _count_tree_values(sub_spec)
127
+ sub_values = values[value_idx:value_idx + sub_count]
128
+ result.append(tree_unflatten(sub_values, sub_spec))
129
+ value_idx += sub_count
130
+ return tree_type(result)
131
+ elif tree_type == dict:
132
+ result = {}
133
+ value_idx = 0
134
+ for key, sub_spec in tree_spec:
135
+ if sub_spec is None:
136
+ result[key] = values[value_idx]
137
+ value_idx += 1
138
+ else:
139
+ sub_count = _count_tree_values(sub_spec)
140
+ sub_values = values[value_idx:value_idx + sub_count]
141
+ result[key] = tree_unflatten(sub_values, sub_spec)
142
+ value_idx += sub_count
143
+ return result
144
+ return values[0] if values else None
145
+ pytree.tree_unflatten = tree_unflatten
146
+
147
+ if not hasattr(pytree, 'tree_map'):
148
+ def tree_map(fn, tree, *other_trees, is_leaf=None):
149
+ """兼容性实现:树映射"""
150
+ flat, spec = tree_flatten(tree, is_leaf)
151
+ if other_trees:
152
+ other_flats = [tree_flatten(t, is_leaf)[0] for t in other_trees]
153
+ mapped = [fn(x, *others) for x, *others in zip(flat, *other_flats)]
154
+ else:
155
+ mapped = [fn(x) for x in flat]
156
+ return tree_unflatten(mapped, spec)
157
+ pytree.tree_map = tree_map
158
+
159
+ # 辅助函数
160
+ def _count_tree_values(spec):
161
+ """计算树规格中的值数量"""
162
+ if spec is None:
163
+ return 1
164
+ tree_type, tree_spec = spec
165
+ return sum(_count_tree_values(sub_spec) if sub_spec else 1 for _, sub_spec in tree_spec)
166
+
167
+ # 修复pyarrow兼容性问题
168
+ try:
169
+ import pyarrow
170
+ if not hasattr(pyarrow, 'PyExtensionType'):
171
+ # 为旧版本pyarrow添加PyExtensionType兼容性
172
+ pyarrow.PyExtensionType = type('PyExtensionType', (), {})
173
+ except ImportError:
174
+ pass
175
+
176
+ except (ImportError, AttributeError) as e:
177
+ print(f"Warning: PyTorch/PyArrow compatibility patch failed: {e}")
178
+ pass
179
+ IMAGES_DIR = os.environ.get("IMAGES_DIR", "/opt/data/images")
180
+ OUTPUT_DIR = IMAGES_DIR
181
+
182
+ # 明星图库目录配置
183
+ CELEBRITY_SOURCE_DIR = os.environ.get(
184
+ "CELEBRITY_SOURCE_DIR", "/opt/data/chinese_celeb_dataset"
185
+ ).strip()
186
+ if CELEBRITY_SOURCE_DIR:
187
+ CELEBRITY_SOURCE_DIR = os.path.abspath(os.path.expanduser(CELEBRITY_SOURCE_DIR))
188
+
189
+ CELEBRITY_DATASET_DIR = os.path.abspath(
190
+ os.path.expanduser(
191
+ os.environ.get(
192
+ "CELEBRITY_DATASET_DIR",
193
+ CELEBRITY_SOURCE_DIR or "/opt/data/chinese_celeb_dataset",
194
+ )
195
+ )
196
+ )
197
+
198
+ CELEBRITY_FIND_THRESHOLD = float(
199
+ os.environ.get("CELEBRITY_FIND_THRESHOLD", 0.88)
200
+ )
201
+
202
+ # ---- start ----
203
+ # 微信小程序配置(默认值仅用于本地开发)
204
+ WECHAT_APPID = os.environ.get("WECHAT_APPID", "******").strip()
205
+ WECHAT_SECRET = os.environ.get("WCT_SECRET", "******").strip()
206
+ APP_SECRET_TOKEN = os.environ.get("APP_SECRET_TOKEN", "******")
207
+ # MySQL 数据库配置
208
+ MYSQL_HOST = os.environ.get("MYSQL_HOST", "******")
209
+ MYSQL_PORT = int(os.environ.get("MYSQL_PORT", "3306"))
210
+ MYSQL_DB = os.environ.get("MYSQL_DB", "******")
211
+ MYSQL_USER = os.environ.get("MYSQL_USER", "******")
212
+ MYSQL_PASSWORD = os.environ.get("MYSQL_PASSWORD", "******")
213
+ # BOS 对象存储配置(默认存储为Base64编码字符串)
214
+ BOS_ACCESS_KEY = os.environ.get("BOS_ACCESS_KEY", "******").strip()
215
+ BOS_SECRET_KEY = os.environ.get("BOS_SECRET_KEY", "******").strip()
216
+ BOS_ENDPOINT = os.environ.get("BOS_ENDPOINT", "******").strip()
217
+ BOS_BUCKET_NAME = os.environ.get("BOS_BUCKET_NAME", "******").strip()
218
+ BOS_IMAGE_DIR = os.environ.get("BOS_IMAGE_DIR", "******").strip()
219
+ BOS_MODELS_PREFIX = os.environ.get("BOS_MODELS_PREFIX", "******").strip()
220
+ BOS_CELEBRITY_PREFIX = os.environ.get("BOS_CELEBRITY_PREFIX", "******").strip()
221
+ # ---- end ---
222
+
223
+ _bos_enabled_env = os.environ.get("BOS_UPLOAD_ENABLED")
224
+ MYSQL_POOL_MIN_SIZE = int(os.environ.get("MYSQL_POOL_MIN_SIZE", "1"))
225
+ MYSQL_POOL_MAX_SIZE = int(os.environ.get("MYSQL_POOL_MAX_SIZE", "10"))
226
+ if _bos_enabled_env is not None:
227
+ BOS_UPLOAD_ENABLED = _bos_enabled_env.lower() in ("1", "true", "on")
228
+ else:
229
+ BOS_UPLOAD_ENABLED = all(
230
+ [
231
+ BOS_ACCESS_KEY.strip(),
232
+ BOS_SECRET_KEY.strip(),
233
+ BOS_ENDPOINT,
234
+ BOS_BUCKET_NAME,
235
+ ]
236
+ )
237
+ HOSTNAME = os.environ.get("HOSTNAME", "default-hostname")
238
+ MODELS_PATH = os.path.abspath(
239
+ os.path.expanduser(os.environ.get("MODELS_PATH", "/opt/data/models"))
240
+ )
241
+ MODELS_DOWNLOAD_DIR = os.path.abspath(
242
+ os.path.expanduser(os.environ.get("MODELS_DOWNLOAD_DIR", MODELS_PATH))
243
+ )
244
+ # HuggingFace 仓库配置
245
+ HUGGINGFACE_SYNC_ENABLED = os.environ.get(
246
+ "HUGGINGFACE_SYNC_ENABLED", "true"
247
+ ).lower() in ("1", "true", "on")
248
+ HUGGINGFACE_REPO_ID = os.environ.get(
249
+ "HUGGINGFACE_REPO_ID", "ethonmax/facescore"
250
+ ).strip()
251
+ HUGGINGFACE_REVISION = os.environ.get(
252
+ "HUGGINGFACE_REVISION", "main"
253
+ ).strip()
254
+ _hf_allow_env = os.environ.get("HUGGINGFACE_ALLOW_PATTERNS", "").strip()
255
+ HUGGINGFACE_ALLOW_PATTERNS = [
256
+ pattern.strip() for pattern in _hf_allow_env.split(",") if pattern.strip()
257
+ ]
258
+ _hf_ignore_env = os.environ.get("HUGGINGFACE_IGNORE_PATTERNS", "").strip()
259
+ HUGGINGFACE_IGNORE_PATTERNS = [
260
+ pattern.strip() for pattern in _hf_ignore_env.split(",") if pattern.strip()
261
+ ]
262
+
263
+ _MODELSCOPE_CACHE_ENV = os.environ.get("MODELSCOPE_CACHE", "").strip()
264
+ if _MODELSCOPE_CACHE_ENV:
265
+ MODELSCOPE_CACHE_DIR = os.path.abspath(os.path.expanduser(_MODELSCOPE_CACHE_ENV))
266
+ else:
267
+ MODELSCOPE_CACHE_DIR = os.path.join(MODELS_PATH, "modelscope")
268
+
269
+ try:
270
+ os.makedirs(MODELSCOPE_CACHE_DIR, exist_ok=True)
271
+ except Exception as exc:
272
+ print(f"创建 ModelScope 缓存目录失败: %s (%s)", MODELSCOPE_CACHE_DIR, exc)
273
+
274
+ os.environ.setdefault("MODELSCOPE_CACHE", MODELSCOPE_CACHE_DIR)
275
+ os.environ.setdefault("MODELSCOPE_HOME", MODELSCOPE_CACHE_DIR)
276
+ os.environ.setdefault("MODELSCOPE_CACHE_HOME", MODELSCOPE_CACHE_DIR)
277
+
278
+ DEEPFACE_HOME = os.environ.get("DEEPFACE_HOME", "/opt/data/models")
279
+ os.environ["DEEPFACE_HOME"] = DEEPFACE_HOME
280
+
281
+ # 设置GFPGAN相关模型下载路径
282
+ GFPGAN_MODEL_DIR = MODELS_DOWNLOAD_DIR
283
+ os.makedirs(GFPGAN_MODEL_DIR, exist_ok=True)
284
+
285
+ # 设置各种模型库的下载目录环境变量
286
+ os.environ["GFPGAN_MODEL_ROOT"] = GFPGAN_MODEL_DIR
287
+ os.environ["FACEXLIB_CACHE_DIR"] = GFPGAN_MODEL_DIR
288
+ os.environ["BASICSR_CACHE_DIR"] = GFPGAN_MODEL_DIR
289
+ os.environ["REALESRGAN_MODEL_ROOT"] = GFPGAN_MODEL_DIR
290
+ os.environ["HUB_CACHE_DIR"] = GFPGAN_MODEL_DIR # PyTorch Hub缓存
291
+
292
+ # 设置rembg模型下载路径到统一的AI模型目录
293
+ REMBG_MODEL_DIR = os.path.expanduser(MODELS_PATH.replace("$HOME", "~"))
294
+ os.environ["U2NET_HOME"] = REMBG_MODEL_DIR # u2net模型缓存目录
295
+ os.environ["REMBG_HOME"] = REMBG_MODEL_DIR # rembg通用缓存目录
296
+
297
+ IMG_QUALITY = float(os.environ.get("IMG_QUALITY", 0.5))
298
+ FACE_CONFIDENCE = float(os.environ.get("FACE_CONFIDENCE", 0.7))
299
+ AGE_CONFIDENCE = float(os.environ.get("AGE_CONFIDENCE", 0.99))
300
+ GENDER_CONFIDENCE = float(os.environ.get("GENDER_CONFIDENCE", 1.1))
301
+ UPSCALE_SIZE = int(os.environ.get("UPSCALE_SIZE", 2))
302
+ SAVE_QUALITY = int(os.environ.get("SAVE_QUALITY", 85))
303
+ REALESRGAN_MODEL = os.environ.get("REALESRGAN_MODEL", "realesr-general-x4v3")
304
+ # yolov11n-face.pt / yolov8n-face.pt
305
+ YOLO_MODEL = os.environ.get("YOLO_MODEL", "yolov11n-face.pt")
306
+ # mobilenetv3/resnet50
307
+ RVM_MODEL = os.environ.get("RVM_MODEL", "resnet50")
308
+ RVM_LOCAL_REPO = os.environ.get("RVM_LOCAL_REPO", "/opt/data/RobustVideoMatting").strip()
309
+ RVM_WEIGHTS_PATH = os.environ.get("RVM_WEIGHTS_PATH", "/opt/data/models/torch/hub/checkpoints/rvm_resnet50.pth").strip()
310
+ DRAW_SCORE = os.environ.get("DRAW_SCORE", "true").lower() in ("1", "true", "on")
311
+
312
+ # 颜值评分温和提升配置(默认开启;默认区间与力度:区间=[6.0, 8.0],gamma=0.3)
313
+ # - BEAUTY_ADJUST_ENABLED: 是否开启提分
314
+ # - BEAUTY_ADJUST_MIN: 提分下限(低于该值不提分)
315
+ # - BEAUTY_ADJUST_MAX: 提分上限(目标上限;仅在 [min, max) 区间内提分)
316
+ # - BEAUTY_ADJUST_THRESHOLD: 兼容旧配置,等价于 BEAUTY_ADJUST_MAX
317
+ # - BEAUTY_ADJUST_GAMMA: 提分力度,(0,1],越小提升越多
318
+ BEAUTY_ADJUST_ENABLED = os.environ.get("BEAUTY_ADJUST_ENABLED", "true").lower() in ("1", "true", "on")
319
+ BEAUTY_ADJUST_MIN = float(os.environ.get("BEAUTY_ADJUST_MIN", 1.0))
320
+ # 向后兼容:未提供 BEAUTY_ADJUST_MAX 时,使用旧的 BEAUTY_ADJUST_THRESHOLD 或 8.0
321
+ _legacy_thr = os.environ.get("BEAUTY_ADJUST_THRESHOLD")
322
+ BEAUTY_ADJUST_MAX = float(os.environ.get("BEAUTY_ADJUST_MAX", _legacy_thr if _legacy_thr is not None else 8.0))
323
+ BEAUTY_ADJUST_GAMMA = float(os.environ.get("BEAUTY_ADJUST_GAMMA", 0.5)) # 0<gamma<=1,越小提升越多
324
+
325
+ # 兼容旧引用,保留变量名(不再直接使用于逻辑内部)
326
+ BEAUTY_ADJUST_THRESHOLD = BEAUTY_ADJUST_MAX
327
+
328
+ # 整体协调性分数温和提升配置(默认开启;默认阈值与力度:T=8.0, gamma=0.5)
329
+ HARMONY_ADJUST_ENABLED = os.environ.get("HARMONY_ADJUST_ENABLED", "true").lower() in ("1", "true", "on")
330
+ HARMONY_ADJUST_THRESHOLD = float(os.environ.get("HARMONY_ADJUST_THRESHOLD", 9.0))
331
+ HARMONY_ADJUST_GAMMA = float(os.environ.get("HARMONY_ADJUST_GAMMA", 0.3))
332
+
333
+ # 启动优化:是否在启动时自动初始化/预热重型组件
334
+ ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "false").lower() in ("1", "true", "on")
335
+ AUTO_INIT_ANALYZER = os.environ.get("AUTO_INIT_ANALYZER", "true").lower() in ("1", "true", "on")
336
+ AUTO_INIT_GFPGAN = os.environ.get("AUTO_INIT_GFPGAN", "false").lower() in ("1", "true", "on")
337
+ AUTO_INIT_DDCOLOR = os.environ.get("AUTO_INIT_DDCOLOR", "false").lower() in ("1", "true", "on")
338
+ AUTO_INIT_REALESRGAN = os.environ.get("AUTO_INIT_REALESRGAN", "false").lower() in ("1", "true", "on")
339
+ AUTO_INIT_REMBG = os.environ.get("AUTO_INIT_REMBG", "false").lower() in ("1", "true", "on")
340
+ AUTO_INIT_ANIME_STYLE = os.environ.get("AUTO_INIT_ANIME_STYLE", "false").lower() in ("1", "true", "on")
341
+ AUTO_INIT_RVM = os.environ.get("AUTO_INIT_RVM", "false").lower() in ("1", "true", "on")
342
+
343
+ # 定时任务相关配置
344
+ CLEANUP_INTERVAL_HOURS = float(os.environ.get("CLEANUP_INTERVAL_HOURS", 1.0)) # 清理任务执行间隔(小时),默认1小时
345
+ CLEANUP_AGE_HOURS = float(os.environ.get("CLEANUP_AGE_HOURS", 1.0)) # 清理文件的年龄阈值(小时),默认1小时
346
+
347
+ # BOS 自动同步清单:定义 BOS 路径和本地目录的映射,启动时可迭代该结构完成批量下载
348
+ BOS_DOWNLOAD_TARGETS = [
349
+ # {
350
+ # "description": "明星图库数据集",
351
+ # "bos_prefix": BOS_CELEBRITY_PREFIX,
352
+ # "destination": CELEBRITY_DATASET_DIR,
353
+ # "background": True,
354
+ # },
355
+ # {
356
+ # "description": "AI 模型权重",
357
+ # "bos_prefix": BOS_MODELS_PREFIX,
358
+ # "destination": MODELS_DOWNLOAD_DIR,
359
+ # },
360
+ ]
361
+
362
+ log_level_str = os.getenv("LOG_LEVEL", "INFO").upper()
363
+ log_level = getattr(logging, log_level_str, logging.INFO)
364
+
365
+ # 日志开关配置 - 控制是否启用所有日志输出
366
+ ENABLE_LOGGING = os.environ.get("ENABLE_LOGGING", "true").lower() in ("1", "true", "on")
367
+
368
+ # 功能开关配置
369
+ ENABLE_DDCOLOR = os.environ.get("ENABLE_DDCOLOR", "true").lower() in ("1", "true", "on")
370
+ ENABLE_REALESRGAN = os.environ.get("ENABLE_REALESRGAN", "true").lower() in ("1", "true", "on")
371
+ ENABLE_GFPGAN = os.environ.get("ENABLE_GFPGAN", "true").lower() in ("1", "true", "on")
372
+ ENABLE_ANIME_STYLE = os.environ.get("ENABLE_ANIME_STYLE", "true").lower() in ("1", "true", "on")
373
+ ENABLE_ANIME_PRELOAD = os.environ.get("ENABLE_ANIME_PRELOAD", "false").lower() in ("1", "true", "on")
374
+ ENABLE_RVM = os.environ.get("ENABLE_RVM", "true").lower() in ("1", "true", "on")
375
+
376
+
377
+ # 颜值评分模块配置
378
+ FACE_SCORE_MAX_IMAGES = int(os.environ.get("FACE_SCORE_MAX_IMAGES", 10)) # 颜值评分最大上传图片数量
379
+
380
+ # 女性年龄调整配置 - 对于20岁以上的女性,显示的年龄会减去指定岁数
381
+ FEMALE_AGE_ADJUSTMENT = int(os.environ.get("FEMALE_AGE_ADJUSTMENT", 3)) # 默认减3岁
382
+ FEMALE_AGE_ADJUSTMENT_THRESHOLD = int(os.environ.get("FEMALE_AGE_ADJUSTMENT_THRESHOLD", 20)) # 年龄阈值,默认20岁
383
+
384
+ # 配置日志
385
+ if ENABLE_LOGGING:
386
+ logging.basicConfig(
387
+ level=log_level,
388
+ format="[%(asctime)s] [%(levelname)s] %(message)s",
389
+ datefmt="%Y-%m-%d %H:%M:%S",
390
+ )
391
+ logger = logging.getLogger(__name__)
392
+ else:
393
+ # 禁用所有日志输出
394
+ logging.basicConfig(level=logging.CRITICAL + 10)
395
+ logger = logging.getLogger(__name__)
396
+ logger.disabled = True
397
+
398
+ # 全局变量存储 access_token
399
+ access_token_cache = {"token": None, "expires_at": 0}
400
+
401
+ # 尝试导入依赖
402
+ try:
403
+ from deepface import DeepFace
404
+
405
+ DEEPFACE_AVAILABLE = True
406
+ except ImportError:
407
+ print("Warning: DeepFace not installed. Install with: pip install deepface")
408
+ DEEPFACE_AVAILABLE = False
409
+
410
+ try:
411
+ import mediapipe as mp
412
+
413
+ MEDIAPIPE_AVAILABLE = True
414
+ except ImportError:
415
+ print("Warning: mediapipe not installed. Install with: pip install mediapipe")
416
+ MEDIAPIPE_AVAILABLE = False
417
+
418
+ # 为了保持向后兼容,保留 DLIB_AVAILABLE 变量名
419
+ DLIB_AVAILABLE = MEDIAPIPE_AVAILABLE
420
+
421
+ try:
422
+ from ultralytics import YOLO
423
+
424
+ YOLO_AVAILABLE = True
425
+ except ImportError:
426
+ print("Warning: ultralytics not installed. Install with: pip install ultralytics")
427
+ YOLO_AVAILABLE = False
428
+
429
+ # 检查GFPGAN是否启用和可用
430
+ if ENABLE_GFPGAN:
431
+ try:
432
+ required_files = [
433
+ os.path.join(os.path.dirname(__file__), "gfpgan_restorer.py"),
434
+ os.path.join(MODELS_PATH, "gfpgan/weights/detection_Resnet50_Final.pth"),
435
+ os.path.join(MODELS_PATH, "gfpgan/weights/parsing_parsenet.pth"),
436
+ ]
437
+
438
+ missing_files = [path for path in required_files if not os.path.exists(path)]
439
+ if missing_files:
440
+ for file_path in missing_files:
441
+ logger.info("GFPGAN 所需文件暂未找到,将等待模型同步: %s", file_path)
442
+
443
+ from gfpgan_restorer import GFPGANRestorer # noqa: F401
444
+ GFPGAN_AVAILABLE = True
445
+
446
+ if missing_files:
447
+ logger.warning(
448
+ "GFPGAN 文件尚未全部就绪,将在 HuggingFace/BOS 同步完成后继续初始化: %s",
449
+ ", ".join(missing_files),
450
+ )
451
+ else:
452
+ logger.info("GFPGAN photo restoration feature prerequisites detected")
453
+ except ImportError as e:
454
+ print(f"Warning: GFPGAN enabled but not available: {e}")
455
+ GFPGAN_AVAILABLE = False
456
+ logger.warning(f"GFPGAN photo restoration feature is enabled but import failed: {e}")
457
+ else:
458
+ GFPGAN_AVAILABLE = False
459
+ logger.info("GFPGAN photo restoration feature is disabled (via ENABLE_GFPGAN environment variable)")
460
+
461
+ # 检查DDColor是否启用和可用
462
+ if ENABLE_DDCOLOR:
463
+ try:
464
+ from ddcolor_colorizer import DDColorColorizer
465
+ DDCOLOR_AVAILABLE = True
466
+ logger.info("DDColor feature is enabled and available")
467
+ except ImportError as e:
468
+ print(f"Warning: DDColor enabled but not available: {e}")
469
+ DDCOLOR_AVAILABLE = False
470
+ logger.warning(f"DDColor feature is enabled but import failed: {e}")
471
+ else:
472
+ DDCOLOR_AVAILABLE = False
473
+ logger.info("DDColor feature is disabled (via ENABLE_DDCOLOR environment variable)")
474
+
475
+ # 只使用GFPGAN修复器
476
+ SIMPLE_RESTORER_AVAILABLE = False
477
+
478
+ # 检查Real-ESRGAN是否启用和可用
479
+ if ENABLE_REALESRGAN:
480
+ try:
481
+ from realesrgan_upscaler import RealESRGANUpscaler
482
+ REALESRGAN_AVAILABLE = True
483
+ logger.info("Real-ESRGAN super resolution feature is enabled and available")
484
+ except ImportError as e:
485
+ print(f"Warning: Real-ESRGAN enabled but not available: {e}")
486
+ REALESRGAN_AVAILABLE = False
487
+ logger.warning(f"Real-ESRGAN super resolution feature is enabled but import failed: {e}")
488
+ else:
489
+ REALESRGAN_AVAILABLE = False
490
+ logger.info("Real-ESRGAN super resolution feature is disabled (via ENABLE_REALESRGAN environment variable)")
491
+
492
+ # rembg功能开关配置
493
+ ENABLE_REMBG = os.environ.get("ENABLE_REMBG", "true").lower() in ("1", "true", "on")
494
+
495
+ # 检查rembg是否启用和可用
496
+ if ENABLE_REMBG:
497
+ try:
498
+ import rembg
499
+ from rembg import new_session
500
+ REMBG_AVAILABLE = True
501
+ logger.info("rembg background removal feature is enabled and available")
502
+ logger.info(f"rembg model storage path: {REMBG_MODEL_DIR}")
503
+ except ImportError as e:
504
+ print(f"Warning: rembg enabled but not available: {e}")
505
+ REMBG_AVAILABLE = False
506
+ logger.warning(f"rembg background removal feature is enabled but import failed: {e}")
507
+ else:
508
+ REMBG_AVAILABLE = False
509
+ logger.info("rembg background removal feature is disabled (via ENABLE_REMBG environment variable)")
510
+
511
+ CLIP_AVAILABLE = False
512
+
513
+ # 检查Anime Style是否启用和可用
514
+ if ENABLE_ANIME_STYLE:
515
+ try:
516
+ from anime_stylizer import AnimeStylizer
517
+ ANIME_STYLE_AVAILABLE = True
518
+ logger.info("Anime stylization feature is enabled and available")
519
+ except ImportError as e:
520
+ print(f"Warning: Anime Style enabled but not available: {e}")
521
+ ANIME_STYLE_AVAILABLE = False
522
+ logger.warning(f"Anime stylization feature is enabled but import failed: {e}")
523
+ else:
524
+ ANIME_STYLE_AVAILABLE = False
525
+ logger.info("Anime stylization feature is disabled (via ENABLE_ANIME_STYLE environment variable)")
526
+
527
+ # RVM功能开关配置
528
+ ENABLE_RVM = os.environ.get("ENABLE_RVM", "true").lower() in ("1", "true", "on")
529
+
530
+ # 检查RVM是否启用和可用
531
+ if ENABLE_RVM:
532
+ try:
533
+ import torch
534
+ # 检查是否可以加载RVM模型
535
+ RVM_AVAILABLE = True
536
+ logger.info("RVM background removal feature is enabled and available")
537
+ except ImportError as e:
538
+ print(f"Warning: RVM enabled but not available: {e}")
539
+ RVM_AVAILABLE = False
540
+ logger.warning(f"RVM background removal feature is enabled but import failed: {e}")
541
+ else:
542
+ RVM_AVAILABLE = False
543
+ logger.info("RVM background removal feature is disabled (via ENABLE_RVM environment variable)")
database.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import os
4
+ from contextlib import asynccontextmanager
5
+ from datetime import datetime
6
+ from typing import Any, Dict, Iterable, List, Optional, Sequence
7
+
8
+ import aiomysql
9
+ from aiomysql.cursors import DictCursor
10
+
11
+ from config import (
12
+ IMAGES_DIR,
13
+ logger,
14
+ MYSQL_HOST,
15
+ MYSQL_PORT,
16
+ MYSQL_DB,
17
+ MYSQL_USER,
18
+ MYSQL_PASSWORD,
19
+ MYSQL_POOL_MIN_SIZE,
20
+ MYSQL_POOL_MAX_SIZE,
21
+ )
22
+
23
+ _pool: Optional[aiomysql.Pool] = None
24
+ _pool_lock = asyncio.Lock()
25
+
26
+
27
+ async def init_mysql_pool() -> aiomysql.Pool:
28
+ """初始化 MySQL 连接池"""
29
+ global _pool
30
+ if _pool is not None:
31
+ return _pool
32
+
33
+ async with _pool_lock:
34
+ if _pool is not None:
35
+ return _pool
36
+ try:
37
+ _pool = await aiomysql.create_pool(
38
+ host=MYSQL_HOST,
39
+ port=MYSQL_PORT,
40
+ user=MYSQL_USER,
41
+ password=MYSQL_PASSWORD,
42
+ db=MYSQL_DB,
43
+ minsize=MYSQL_POOL_MIN_SIZE,
44
+ maxsize=MYSQL_POOL_MAX_SIZE,
45
+ autocommit=True,
46
+ charset="utf8mb4",
47
+ cursorclass=DictCursor,
48
+ )
49
+ logger.info(
50
+ "MySQL 连接池初始化成功,host=%s db=%s",
51
+ MYSQL_HOST,
52
+ MYSQL_DB,
53
+ )
54
+ except Exception as exc:
55
+ logger.error(f"初始化 MySQL 连接池失败: {exc}")
56
+ raise
57
+ return _pool
58
+
59
+
60
+ async def close_mysql_pool() -> None:
61
+ """关闭 MySQL 连接池"""
62
+ global _pool
63
+ if _pool is None:
64
+ return
65
+
66
+ async with _pool_lock:
67
+ if _pool is None:
68
+ return
69
+ _pool.close()
70
+ await _pool.wait_closed()
71
+ _pool = None
72
+ logger.info("MySQL 连接池已关闭")
73
+
74
+
75
+ @asynccontextmanager
76
+ async def get_connection():
77
+ """获取连接池中的连接"""
78
+ if _pool is None:
79
+ await init_mysql_pool()
80
+ assert _pool is not None
81
+ conn = await _pool.acquire()
82
+ try:
83
+ yield conn
84
+ finally:
85
+ _pool.release(conn)
86
+
87
+
88
+ async def execute(query: str,
89
+ params: Sequence[Any] | Dict[str, Any] | None = None) -> None:
90
+ """执行写入类 SQL"""
91
+ async with get_connection() as conn:
92
+ async with conn.cursor() as cursor:
93
+ await cursor.execute(query, params or ())
94
+
95
+
96
+ async def fetch_all(
97
+ query: str, params: Sequence[Any] | Dict[str, Any] | None = None
98
+ ) -> List[Dict[str, Any]]:
99
+ """执行查询并返回全部结果"""
100
+ async with get_connection() as conn:
101
+ async with conn.cursor() as cursor:
102
+ await cursor.execute(query, params or ())
103
+ rows = await cursor.fetchall()
104
+ return list(rows)
105
+
106
+
107
+ def _serialize_extra(extra: Optional[Dict[str, Any]]) -> Optional[str]:
108
+ if extra is None:
109
+ return None
110
+ try:
111
+ return json.dumps(extra, ensure_ascii=False)
112
+ except Exception:
113
+ logger.warning("无法序列化 extra_metadata,已忽略")
114
+ return None
115
+
116
+
117
+ async def upsert_image_record(
118
+ *,
119
+ file_path: str,
120
+ category: str,
121
+ nickname: Optional[str],
122
+ score: float,
123
+ is_cropped_face: bool,
124
+ size_bytes: int,
125
+ last_modified: datetime,
126
+ bos_uploaded: bool,
127
+ hostname: Optional[str] = None,
128
+ extra_metadata: Optional[Dict[str, Any]] = None,
129
+ ) -> None:
130
+ """写入或更新图片记录"""
131
+ query = """
132
+ INSERT INTO tpl_app_processed_images (
133
+ file_path,
134
+ category,
135
+ nickname,
136
+ score,
137
+ is_cropped_face,
138
+ size_bytes,
139
+ last_modified,
140
+ bos_uploaded,
141
+ hostname,
142
+ extra_metadata
143
+ ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
144
+ ON DUPLICATE KEY UPDATE
145
+ category = VALUES(category),
146
+ nickname = VALUES(nickname),
147
+ score = VALUES(score),
148
+ is_cropped_face = VALUES(is_cropped_face),
149
+ size_bytes = VALUES(size_bytes),
150
+ last_modified = VALUES(last_modified),
151
+ bos_uploaded = VALUES(bos_uploaded),
152
+ hostname = VALUES(hostname),
153
+ extra_metadata = VALUES(extra_metadata),
154
+ updated_at = CURRENT_TIMESTAMP
155
+ """
156
+ extra_value = _serialize_extra(extra_metadata)
157
+ await execute(
158
+ query,
159
+ (
160
+ file_path,
161
+ category,
162
+ nickname,
163
+ score,
164
+ 1 if is_cropped_face else 0,
165
+ size_bytes,
166
+ last_modified,
167
+ 1 if bos_uploaded else 0,
168
+ hostname,
169
+ extra_value,
170
+ ),
171
+ )
172
+
173
+
174
+ async def fetch_paged_image_records(
175
+ *,
176
+ category: Optional[str],
177
+ nickname: Optional[str],
178
+ offset: int,
179
+ limit: int,
180
+ ) -> List[Dict[str, Any]]:
181
+ """按条件分页查询图片记录"""
182
+ where_clauses: List[str] = []
183
+ params: List[Any] = []
184
+ if category and category != "all":
185
+ where_clauses.append("category = %s")
186
+ params.append(category)
187
+ if nickname:
188
+ where_clauses.append("nickname = %s")
189
+ params.append(nickname)
190
+ where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
191
+ query = f"""
192
+ SELECT
193
+ file_path,
194
+ category,
195
+ nickname,
196
+ score,
197
+ is_cropped_face,
198
+ size_bytes,
199
+ last_modified,
200
+ bos_uploaded,
201
+ hostname
202
+ FROM tpl_app_processed_images
203
+ {where_sql}
204
+ ORDER BY last_modified DESC, id DESC
205
+ LIMIT %s OFFSET %s
206
+ """
207
+ params.extend([limit, offset])
208
+ return await fetch_all(query, params)
209
+
210
+
211
+ async def count_image_records(
212
+ *, category: Optional[str], nickname: Optional[str]
213
+ ) -> int:
214
+ """按条件统计图片记录数量"""
215
+ where_clauses: List[str] = []
216
+ params: List[Any] = []
217
+ if category and category != "all":
218
+ where_clauses.append("category = %s")
219
+ params.append(category)
220
+ if nickname:
221
+ where_clauses.append("nickname = %s")
222
+ params.append(nickname)
223
+ where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
224
+ query = f"SELECT COUNT(*) AS total FROM tpl_app_processed_images {where_sql}"
225
+ rows = await fetch_all(query, params)
226
+ if not rows:
227
+ return 0
228
+ return int(rows[0].get("total", 0) or 0)
229
+
230
+
231
async def fetch_today_category_counts() -> List[Dict[str, Any]]:
    """Return today's image counts grouped by category.

    NULL categories are folded into ``'unknown'`` both in SQL and when
    normalizing the rows.
    """
    # CURDATE() bounds restrict the aggregation to the current server day.
    query = """
        SELECT
            COALESCE(category, 'unknown') AS category,
            COUNT(*) AS count
        FROM tpl_app_processed_images
        WHERE last_modified >= CURDATE()
          AND last_modified < DATE_ADD(CURDATE(), INTERVAL 1 DAY)
        GROUP BY COALESCE(category, 'unknown')
    """
    rows = await fetch_all(query)
    summary: List[Dict[str, Any]] = []
    for row in rows:
        summary.append(
            {
                "category": str(row.get("category") or "unknown"),
                "count": int(row.get("count") or 0),
            }
        )
    return summary
250
+
251
+
252
async def fetch_records_by_paths(
    file_paths: Iterable[str],
) -> Dict[str, Dict[str, Any]]:
    """Bulk-fetch image records, returned as a dict keyed by ``file_path``.

    Empty/falsy paths are dropped and duplicates collapsed before the query
    is built. Returns an empty dict when nothing remains to look up.
    """
    # De-duplicate and drop empties before binding parameters.
    unique_paths = list({p for p in file_paths if p})
    if not unique_paths:
        return {}

    placeholders = ", ".join(["%s"] * len(unique_paths))
    query = f"""
        SELECT
            file_path,
            category,
            nickname,
            score,
            is_cropped_face,
            size_bytes,
            last_modified,
            bos_uploaded,
            hostname
        FROM tpl_app_processed_images
        WHERE file_path IN ({placeholders})
    """
    rows = await fetch_all(query, unique_paths)
    by_path: Dict[str, Dict[str, Any]] = {}
    for row in rows:
        by_path[row["file_path"]] = row
    return by_path
276
+
277
+
278
# Absolute, user-expanded base directory for processed images; used by
# _normalize_file_path to relativize incoming paths.
_IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))
279
+
280
+
281
def _normalize_file_path(file_path: str) -> Optional[str]:
    """Convert an absolute path into a name relative to IMAGES_DIR.

    Paths outside IMAGES_DIR degrade to their basename; directories and any
    path that raises during normalization yield ``None``.
    """
    try:
        resolved = os.path.abspath(os.path.expanduser(file_path))
        if os.path.isdir(resolved):
            return None
        # commonpath can raise (e.g. mixed drives on Windows) — caught below.
        inside = os.path.commonpath([_IMAGES_DIR_ABS, resolved]) == _IMAGES_DIR_ABS
        if not inside:
            return os.path.basename(resolved)
        # Normalize separators so stored keys are platform-independent.
        return os.path.relpath(resolved, _IMAGES_DIR_ABS).replace("\\", "/")
    except Exception:
        return None
293
+
294
+
295
def infer_category_from_filename(filename: str, default: str = "other") -> str:
    """Infer an image category from marker substrings in the filename.

    Markers are checked in order and the first match wins, so more specific
    markers must precede the generic markers they contain. In the original
    implementation the ``"_rvm_id_photo"`` → ``"rvm"`` branch was unreachable
    because ``"_rvm_id_photo"`` contains ``"_id_photo"``, which matched first;
    the rvm check now precedes the id_photo check. Redundant disjuncts
    (``endswith("_original.webp")``, ``"_save_id_photo"``, ``"_celebrity_"``)
    were subsumed by their substring partners and have been folded away.

    :param filename: file name (any case) to classify
    :param default: category returned when no marker matches
    :return: the inferred category string
    """
    lower_name = filename.lower()
    # (marker, category) pairs, evaluated in order; first match wins.
    markers = (
        ("_face_", "face"),
        ("_original", "original"),
        ("_restore", "restore"),
        ("_upcolor", "upcolor"),
        ("_compress", "compress"),
        ("_upscale", "upscale"),
        ("_anime_style_", "anime_style"),
        ("_grayscale", "grayscale"),
        ("_rvm_id_photo", "rvm"),  # must come before the generic "_id_photo"
        ("_id_photo", "id_photo"),
        ("_grid_", "grid"),
        ("_celebrity", "celebrity"),
    )
    for marker, category in markers:
        if marker in lower_name:
            return category
    return default
323
+
324
+
325
+ from config import HOSTNAME
326
+
327
async def record_image_creation(
    *,
    file_path: str,
    nickname: Optional[str],
    score: float = 0.0,
    category: Optional[str] = None,
    bos_uploaded: bool = False,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Record image metadata in the database; failures are logged, not raised.

    :param file_path: absolute or relative file path
    :param nickname: user nickname (blank/non-string becomes NULL)
    :param score: associated score
    :param category: file category; inferred from the filename when omitted
    :param bos_uploaded: whether the file has been uploaded to BOS
    :param extra_metadata: optional extra information
    """
    normalized = _normalize_file_path(file_path)
    if normalized is None:
        logger.info("record_image_creation: 无法计算文件名,路径=%s", file_path)
        return

    abs_path = os.path.join(_IMAGES_DIR_ABS, normalized)
    if not os.path.isfile(abs_path):
        logger.info("record_image_creation: 文件不存在,跳过记录 file=%s", abs_path)
        return

    try:
        stat = os.stat(abs_path)
        # Strip whitespace; empty or non-string nicknames are stored as NULL.
        cleaned_nickname = None
        if isinstance(nickname, str):
            stripped = nickname.strip()
            if stripped:
                cleaned_nickname = stripped

        await upsert_image_record(
            file_path=normalized,
            category=category or infer_category_from_filename(normalized),
            nickname=cleaned_nickname,
            score=score,
            # Heuristic: cropped-face outputs embed "_face_" and carry
            # at least two underscores in their name.
            is_cropped_face="_face_" in normalized and normalized.count("_") >= 2,
            size_bytes=stat.st_size,
            last_modified=datetime.fromtimestamp(stat.st_mtime),
            bos_uploaded=bos_uploaded,
            hostname=HOSTNAME,
            extra_metadata=extra_metadata,
        )
    except Exception as exc:
        logger.warning(f"写入图片记录失败: {exc}")
ddcolor_colorizer.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import time
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from config import logger
9
+
10
+
11
class DDColorColorizer:
    """Wrapper around the ModelScope DDColor pipeline for image colorization.

    The model is loaded only when ``config.ENABLE_DDCOLOR`` is set; otherwise
    the instance stays unavailable and colorize calls return the input image
    unchanged. Fix applied: the temporary-file cleanup used a bare ``except:``
    which also swallowed KeyboardInterrupt/SystemExit — narrowed to OSError.
    """

    def __init__(self):
        start_time = time.perf_counter()
        self.colorizer = None
        # Only initialize when the DDColor feature flag is enabled.
        from config import ENABLE_DDCOLOR
        if ENABLE_DDCOLOR:
            self._initialize_model()
        else:
            logger.info("DDColor feature is disabled, skipping model initialization")
        init_time = time.perf_counter() - start_time
        if self.colorizer is not None:
            logger.info(f"DDColorColorizer initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"DDColorColorizer initialization completed but not available, time: {init_time:.3f}s")

    def _initialize_model(self):
        """Initialize the DDColor model via ModelScope; on failure the
        instance is left unavailable (``self.colorizer`` stays ``None``)."""
        try:
            logger.info("Initializing DDColor model (using ModelScope)...")

            # Compatibility patch: older torch builds lack the unsigned
            # integer dtypes that newer ModelScope code references.
            import torch
            if not hasattr(torch, 'uint64'):
                logger.info("Adding torch.uint64 compatibility patch...")
                torch.uint64 = torch.int64  # substitute int64 for uint64
            if not hasattr(torch, 'uint32'):
                logger.info("Adding torch.uint32 compatibility patch...")
                torch.uint32 = torch.int32  # substitute int32 for uint32
            if not hasattr(torch, 'uint16'):
                logger.info("Adding torch.uint16 compatibility patch...")
                torch.uint16 = torch.int16  # substitute int16 for uint16

            # Deferred imports so this module loads without ModelScope installed.
            from modelscope.outputs import OutputKeys
            from modelscope.pipelines import pipeline
            from modelscope.utils.constant import Tasks

            # Build the DDColor colorization pipeline.
            self.colorizer = pipeline(
                Tasks.image_colorization,
                model='damo/cv_ddcolor_image-colorization'
            )
            self.OutputKeys = OutputKeys

            logger.info("DDColor model initialized successfully")

        except ImportError as e:
            logger.error(f"ModelScope module import failed: {e}")
            self.colorizer = None
        except Exception as e:
            logger.error(f"DDColor model initialization failed: {e}")
            self.colorizer = None

    def is_available(self):
        """Return True when the DDColor pipeline has been loaded."""
        return self.colorizer is not None

    def is_grayscale(self, image):
        """Heuristically decide whether an image is grayscale or pseudo-color.

        2-D arrays are grayscale by definition; for 3-channel images the
        decision combines mean inter-channel difference with mean HSV
        saturation.
        """
        if len(image.shape) == 2:
            return True
        elif len(image.shape) == 3:
            # Pseudo-color check: all three channels nearly equal.
            b, g, r = cv2.split(image)

            # Pairwise channel differences.
            diff_bg = np.abs(b.astype(float) - g.astype(float))
            diff_gr = np.abs(g.astype(float) - r.astype(float))
            diff_rb = np.abs(r.astype(float) - b.astype(float))

            # Mean difference across the three channel pairs.
            avg_diff = (np.mean(diff_bg) + np.mean(diff_gr) + np.mean(diff_rb)) / 3.0

            # Color saturation (S channel of HSV).
            hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            saturation = hsv[:, :, 1]  # S channel
            avg_saturation = np.mean(saturation)

            # Combined criterion: low channel difference OR low saturation.
            is_gray = (avg_diff < 5.0) or (avg_saturation < 20.0)

            logger.info(f"Grayscale detection - Average channel difference: {avg_diff:.2f}, Average saturation: {avg_saturation:.2f}, Result: {is_gray}")
            return is_gray
        return False

    def colorize_image(self, image):
        """
        Colorize a grayscale image with DDColor.
        :param image: input image (numpy array, BGR)
        :return: colorized image (numpy array, BGR); the input unchanged when
                 the model is unavailable or the image is already colored
        """
        if not self.is_available():
            logger.error("DDColor model not initialized")
            return image

        # Skip images that already carry color.
        if not self.is_grayscale(image):
            logger.info("Image is already colored, no need for colorization")
            return image

        return self.colorize_image_direct(image)

    def colorize_image_direct(self, image):
        """
        Colorize without the grayscale pre-check.
        Uses the same file-path method as test_ddcolor.py (verified quality).
        :param image: input image (numpy array, BGR)
        :return: colorized image (numpy array, BGR)
        """
        if not self.is_available():
            logger.error("DDColor model not initialized")
            return image

        # The file-path route is the verified best-quality path.
        return self._colorize_image_via_file(image)

    def _colorize_image_via_file(self, image):
        """
        Colorize via a temporary file, mirroring test_ddcolor.py's flow.
        :param image: input image (numpy array, BGR)
        :return: colorized image (numpy array, BGR); the input on failure
        """
        try:
            logger.info("Using high-quality file path method for colorization...")

            # Save with maximum quality to preserve the source as closely as
            # possible before handing it to the pipeline.
            with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as tmp_input:
                # WebP gives good quality at a small file size.
                cv2.imwrite(tmp_input.name, image, [cv2.IMWRITE_WEBP_QUALITY, 100])
                tmp_input_path = tmp_input.name

            try:
                logger.info(f"Temporary file saved to: {tmp_input_path}")

                # Invoke ModelScope exactly as test_colorization does.
                result = self.colorizer(tmp_input_path)

                # Extract the colorized image, same key as test_colorization.
                colorized_image = result[self.OutputKeys.OUTPUT_IMG]

                logger.info(f"Colorization output: size={colorized_image.shape}, type={colorized_image.dtype}")

                # ModelScope already outputs BGR; no conversion needed
                # (test_colorization writes it straight out with cv2.imwrite).
                logger.info("High-quality file path method colorization completed")
                return colorized_image

            finally:
                # Best-effort cleanup of the temporary file. Narrowed from a
                # bare `except:` so KeyboardInterrupt/SystemExit propagate.
                try:
                    os.unlink(tmp_input_path)
                except OSError:
                    pass

        except Exception as e:
            logger.error(f"High-quality file path method colorization failed: {e}")
            logger.info("Returning original image")
            return image

    def restore_and_colorize(self, image, gfpgan_restorer=None):
        """
        Restore first, then colorize (legacy order, kept for compatibility).
        :param image: input image
        :param gfpgan_restorer: optional GFPGAN restorer instance
        :return: processed image; the input on failure
        """
        try:
            # Restore first when a restorer is supplied.
            if gfpgan_restorer and gfpgan_restorer.is_available():
                logger.info("First performing image restoration...")
                restored_image = gfpgan_restorer.restore_image(image)
            else:
                restored_image = image

            # Then colorize if the restored image is grayscale.
            if self.is_grayscale(restored_image):
                logger.info("Grayscale image detected, performing colorization...")
                colorized_image = self.colorize_image(restored_image)
                return colorized_image
            else:
                logger.info("Image is already colored, only returning restoration result")
                return restored_image

        except Exception as e:
            logger.error(f"Restoration and colorization combination processing failed: {e}")
            return image

    def colorize_and_restore(self, image, gfpgan_restorer=None):
        """
        Colorize first, then restore (current order).
        :param image: input image
        :param gfpgan_restorer: optional GFPGAN restorer instance
        :return: processed image; the input on failure
        """
        try:
            # Colorize first when the image is grayscale.
            if self.is_grayscale(image):
                logger.info("Grayscale image detected, performing colorization first...")
                colorized_image = self.colorize_image_direct(image)
            else:
                logger.info("Image is already colored, skipping colorization step")
                colorized_image = image

            # Then restore when a restorer is supplied.
            if gfpgan_restorer and gfpgan_restorer.is_available():
                logger.info("Performing restoration on the colorized image...")
                final_image = gfpgan_restorer.restore_image(colorized_image)
                return final_image
            else:
                logger.info("No restorer available, returning colorization result")
                return colorized_image

        except Exception as e:
            logger.error(f"Colorization and restoration combination processing failed: {e}")
            return image

    def save_debug_image(self, image, filename_prefix):
        """Save an image for debugging; return its path or None on failure."""
        try:
            debug_path = f"{filename_prefix}_debug.webp"
            cv2.imwrite(debug_path, image, [cv2.IMWRITE_WEBP_QUALITY, 95])
            logger.info(f"Debug image saved: {debug_path}")
            return debug_path
        except Exception as e:
            logger.error(f"Failed to save debug image: {e}")
            return None

    def test_colorization(self, test_url=None):
        """
        Smoke-test the colorization pipeline.
        :param test_url: test image URL; defaults to the official sample
        :return: (success, message) tuple
        """
        if not self.is_available():
            return False, "DDColor模型未初始化"

        try:
            test_url = test_url or 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/audrey_hepburn.jpg'
            logger.info(f"Testing DDColor colorization feature, using image: {test_url}")

            result = self.colorizer(test_url)
            colorized_img = result[self.OutputKeys.OUTPUT_IMG]

            # Persist the test output.
            test_output_path = 'ddcolor_test_result.webp'
            cv2.imwrite(test_output_path, colorized_img, [cv2.IMWRITE_WEBP_QUALITY, 95])

            logger.info(f"DDColor test successful, result saved to: {test_output_path}")
            return True, f"测试成功,结果保存到: {test_output_path}"

        except Exception as e:
            logger.error(f"DDColor test failed: {e}")
            return False, f"测试失败: {e}"

    def test_local_image(self, image_path):
        """
        Colorize a local image for side-by-side comparison.
        :param image_path: path of the local image
        :return: (success, message) tuple
        """
        if not self.is_available():
            return False, "DDColor模型未初始化"

        try:
            logger.info(f"Testing local image colorization: {image_path}")

            # Read the local image.
            image = cv2.imread(image_path)
            if image is None:
                return False, f"无法读取图像: {image_path}"

            # Grayscale check (informational only — colorization runs anyway).
            is_gray = self.is_grayscale(image)
            logger.info(f"Local image grayscale detection result: {is_gray}")

            # Save the original for comparison.
            self.save_debug_image(image, "original")

            # Colorize unconditionally.
            colorized_image = self.colorize_image_direct(image)

            # Save the colorized result.
            result_path = self.save_debug_image(colorized_image, "local_colorized")

            logger.info(f"Local image colorization successful, result saved to: {result_path}")
            return True, f"本地图像上色成功,结果保存到: {result_path}"

        except Exception as e:
            logger.error(f"Local image colorization failed: {e}")
            return False, f"本地图像上色失败: {e}"
debug_colorize.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Debug script for diagnosing differences in colorization quality.
"""

import sys
import os
import cv2
import numpy as np

# Make sibling modules importable when run as a script.
sys.path.insert(0, os.path.dirname(__file__))

from ddcolor_colorizer import DDColorColorizer
from gfpgan_restorer import GFPGANRestorer
import logging

# Configure logging for the debug session.
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
20
+
21
def simulate_api_processing(image_path):
    """
    Simulate the full API processing pipeline: colorize first, then restore.

    :param image_path: path of the image to process
    :return: dict with 'original', 'colorized' (None when the input was not
             grayscale), 'final' and 'strategy' keys, or None when any
             initialization/read step fails
    """
    print("\n=== 模拟API接口处理流程 ===")

    # Initialize the GFPGAN restorer.
    print("初始化GFPGAN修复器...")
    try:
        gfpgan_restorer = GFPGANRestorer()
        if not gfpgan_restorer.is_available():
            print("❌ GFPGAN不可用")
            return None
        print("✅ GFPGAN初始化成功")
    except Exception as e:
        print(f"❌ GFPGAN初始化失败: {e}")
        return None

    # Initialize the DDColor colorizer.
    print("初始化DDColor上色器...")
    try:
        ddcolor_colorizer = DDColorColorizer()
        if not ddcolor_colorizer.is_available():
            print("❌ DDColor不可用")
            return None
        print("✅ DDColor初始化成功")
    except Exception as e:
        print(f"❌ DDColor初始化失败: {e}")
        return None

    # Read the input image.
    print(f"读取图像: {image_path}")
    image = cv2.imread(image_path)
    if image is None:
        print(f"❌ 无法读取图像: {image_path}")
        return None

    print(f"原图尺寸: {image.shape}")

    # Save the original for later comparison.
    ddcolor_colorizer.save_debug_image(image, "api_original")

    # Check whether the original is grayscale.
    original_is_grayscale = ddcolor_colorizer.is_grayscale(image)
    print(f"原图灰度检测: {original_is_grayscale}")

    # Pipeline order: colorize first, restore second.
    # Step 1: colorization.
    print("\n步骤1: 上色处理...")
    if original_is_grayscale:
        print("策略: 对原图进行上色")
        colorized_image = ddcolor_colorizer.colorize_image_direct(image)
        ddcolor_colorizer.save_debug_image(colorized_image, "api_colorized")
        strategy = "先上色"
        current_image = colorized_image
    else:
        print("策略: 图像已经是彩色的,跳过上色")
        strategy = "跳过上色"
        current_image = image

    # Step 2: GFPGAN restoration.
    print("\n步骤2: GFPGAN修复...")
    final_image = gfpgan_restorer.restore_image(current_image)
    print(f"修复后图像尺寸: {final_image.shape}")

    # Save the final result.
    result_path = ddcolor_colorizer.save_debug_image(final_image, "api_final")

    strategy += " -> 再修复"

    print(f"\n✅ API模拟完成")
    print(f" - 处理策略: {strategy}")
    print(f" - 最终结果: {result_path}")

    return {
        'original': image,
        'colorized': colorized_image if original_is_grayscale else None,
        'final': final_image,
        'strategy': strategy
    }
+ }
100
+
101
def test_direct_colorization(image_path):
    """
    Test direct colorization (the same flow as test_ddcolor.py): first the
    official sample URL, then the given local image.

    :param image_path: local image path (may be None — only the URL test runs
                       meaningfully in that case)
    """
    print("\n=== 测试直接上色 ===")

    colorizer = DDColorColorizer()
    if not colorizer.is_available():
        print("❌ DDColor不可用")
        return None

    # Colorize the official sample URL (identical to test_ddcolor.py).
    print("使用官方示例URL上色...")
    success, message = colorizer.test_colorization()

    if success:
        print(f"✅ URL上色成功: {message}")
    else:
        print(f"❌ URL上色失败: {message}")

    # Colorize the local image directly.
    print(f"对本地图像直接上色: {image_path}")
    success, message = colorizer.test_local_image(image_path)

    if success:
        print(f"✅ 本地图像上色成功: {message}")
    else:
        print(f"❌ 本地图像上色失败: {message}")
+ print(f"❌ 本地图像上色失败: {message}")
129
+
130
def compare_results():
    """List the debug snapshots written to the working directory and print
    suggestions for comparing them."""
    print("\n=== 结果对比分析 ===")

    # Gather every debug snapshot produced by earlier steps.
    debug_files = [name for name in os.listdir(".") if name.endswith("_debug.webp")]

    if not debug_files:
        print("未找到调试文件")
        return

    print("生成的调试文件:")
    for name in sorted(debug_files):
        print(f" - {name}")

    print("\n对比建议:")
    print("1. 比较 original_debug.webp 和 api_original_debug.webp")
    print("2. 比较 local_colorized_debug.webp 和 api_final_debug.webp")
    print("3. 检查 api_restored_debug.webp 的修复效果")
    print("4. 观察 ddcolor_test_result.webp 的官方示例效果")
+ print("未找到调试文件")
154
+
155
def analyze_image_quality(image_path):
    """
    Print basic quality metrics for an image: size, brightness, contrast,
    saturation and sharpness. Missing/unreadable files are reported and
    skipped.

    :param image_path: path of the image to analyze
    """
    print(f"\n=== 分析图像质量: {image_path} ===")

    if not os.path.exists(image_path):
        print(f"文件不存在: {image_path}")
        return

    image = cv2.imread(image_path)
    if image is None:
        print(f"无法读取图像: {image_path}")
        return

    # Basic info.
    h, w, c = image.shape
    print(f"尺寸: {w}x{h}, 通道数: {c}")

    # Brightness: mean of the grayscale image.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    mean_brightness = np.mean(gray)
    print(f"平均亮度: {mean_brightness:.2f}")

    # Contrast: standard deviation of the grayscale image.
    contrast = np.std(gray)
    print(f"对比度(标准差): {contrast:.2f}")

    # Color: mean saturation (S channel of HSV).
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    mean_saturation = np.mean(hsv[:, :, 1])
    print(f"平均饱和度: {mean_saturation:.2f}")

    # Sharpness: variance of the Laplacian response.
    laplacian = cv2.Laplacian(gray, cv2.CV_64F)
    sharpness = np.var(laplacian)
    print(f"锐度: {sharpness:.2f}")
+
193
def main():
    """Entry point: analyze the test image, run direct colorization, simulate
    the API pipeline, and compare the resulting debug images.

    The image path can be passed as the first CLI argument; otherwise the
    placeholder path below is used and, when it does not exist, only the URL
    test runs.
    """
    print("DDColor 上色效果调试工具")
    print("=" * 60)

    # Test image path; override via the first CLI argument.
    test_image_path = "/path/to/your/test/image.jpg"  # replace with a real path

    if len(sys.argv) > 1:
        test_image_path = sys.argv[1]

    print(f"测试图像路径: {test_image_path}")

    if not os.path.exists(test_image_path):
        print("⚠️ 测试图像不存在,将只运行URL测试")

        # Only run the direct (URL) colorization test.
        test_direct_colorization(None)

    else:
        # Analyze the source image quality.
        analyze_image_quality(test_image_path)

        # Direct colorization test.
        test_direct_colorization(test_image_path)

        # Simulate the API processing pipeline.
        api_result = simulate_api_processing(test_image_path)

        # Quality analysis of the produced images.
        if os.path.exists("api_final_debug.webp"):
            print("\n--- API处理结果质量分析 ---")
            analyze_image_quality("api_final_debug.webp")

        if os.path.exists("local_colorized_debug.webp"):
            print("\n--- 直接上色结果质量分析 ---")
            analyze_image_quality("local_colorized_debug.webp")

        # Side-by-side summary.
        compare_results()

    print("\n调试完成!")
    print("请检查生成的调试图像来识别问题所在。")

if __name__ == "__main__":
    main()
face_analyzer.py ADDED
@@ -0,0 +1,1099 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ from typing import List, Dict, Any
5
+
6
+ import cv2
7
+ import numpy as np
8
+
9
+ import config
10
+ from config import logger, MODELS_PATH, OUTPUT_DIR, DEEPFACE_AVAILABLE, \
11
+ YOLO_AVAILABLE
12
+ from facial_analyzer import FacialFeatureAnalyzer
13
+ from models import ModelType
14
+ from utils import save_image_high_quality
15
+
16
+ if DEEPFACE_AVAILABLE:
17
+ from deepface import DeepFace
18
+
19
+ # 可选导入 YOLO
20
+ if YOLO_AVAILABLE:
21
+ try:
22
+ from ultralytics import YOLO
23
+
24
+ YOLO_AVAILABLE = True
25
+ except ImportError:
26
+ YOLO_AVAILABLE = False
27
+ YOLO = None
28
+ print("Warning: ENABLE_YOLO=true but ultralytics not available")
29
+
30
+
31
+ class EnhancedFaceAnalyzer:
32
+ """增强版人脸分析器 - 支持混合模型"""
33
+
34
+ def __init__(self, models_dir: str = MODELS_PATH):
35
+ """
36
+ 初始化人脸分析器
37
+ :param models_dir: 模型文件目录
38
+ """
39
+ start_time = time.perf_counter()
40
+ self.models_dir = models_dir
41
+ self.MODEL_MEAN_VALUES = (104, 117, 123)
42
+ self.age_list = [
43
+ "(0-2)",
44
+ "(4-6)",
45
+ "(8-12)",
46
+ "(15-20)",
47
+ "(25-32)",
48
+ "(38-43)",
49
+ "(48-53)",
50
+ "(60-100)",
51
+ ]
52
+ self.gender_list = ["Male", "Female"]
53
+ # 性别对应的颜色 (BGR格式)
54
+ self.gender_colors = {
55
+ "Male": (255, 165, 0), # 橙色 Orange
56
+ "Female": (255, 0, 255), # 洋红 Magenta / Fuchsia
57
+ }
58
+
59
+ # 初始化五官分析器
60
+ self.facial_analyzer = FacialFeatureAnalyzer()
61
+ # 加载HowCuteAmI模型
62
+ self._load_howcuteami_models()
63
+ # 加载YOLOv人脸检测模型
64
+ self._load_yolo_model()
65
+
66
+ # 预热模型(可选,通过配置开关)
67
+ if getattr(config, "ENABLE_WARMUP", False):
68
+ self._warmup_models()
69
+
70
+ init_time = time.perf_counter() - start_time
71
+ logger.info(f"EnhancedFaceAnalyzer initialized successfully, time: {init_time:.3f}s")
72
+
73
+ def _cap_conf(self, value: float) -> float:
74
+ """将置信度限制在 [0, 0.9999] 并保留4位小数。"""
75
+ try:
76
+ v = float(value if value is not None else 0.0)
77
+ except Exception:
78
+ v = 0.0
79
+ if v >= 1.0:
80
+ v = 0.9999
81
+ if v < 0.0:
82
+ v = 0.0
83
+ return round(v, 4)
84
+
85
+ def _adjust_beauty_score(self, score: float) -> float:
86
+ try:
87
+ if not config.BEAUTY_ADJUST_ENABLED:
88
+ return score
89
+ # 读取提分区间与力度
90
+ low = float(getattr(config, "BEAUTY_ADJUST_MIN", 6.0))
91
+ high = float(getattr(config, "BEAUTY_ADJUST_MAX", getattr(config, "BEAUTY_ADJUST_THRESHOLD", 8.0)))
92
+ gamma = float(getattr(config, "BEAUTY_ADJUST_GAMMA", 0.3))
93
+ gamma = max(0.0001, min(1.0, gamma))
94
+
95
+ # 区间有效性保护
96
+ if not (0.0 <= low < high <= 10.0):
97
+ return score
98
+
99
+ # 低于下限不提分,区间内提向上限,高于上限不变
100
+ if score < low:
101
+ return score
102
+ if score < high:
103
+ # 向上限 high 进行温和靠拢:adjusted = high - gamma * (high - score)
104
+ adjusted = high - gamma * (high - score)
105
+ adjusted = round(min(10.0, max(0.0, adjusted)), 1)
106
+ try:
107
+ logger.info(
108
+ f"beauty_score adjusted: original={score:.1f} -> adjusted={adjusted:.1f} "
109
+ f"(range=[{low:.1f},{high:.1f}], gamma={gamma:.3f})"
110
+ )
111
+ except Exception:
112
+ pass
113
+ return adjusted
114
+ return score
115
+ except Exception:
116
+ return score
117
+
118
+ def _load_yolo_model(self):
119
+ """加载YOLOv人脸检测模型"""
120
+ self.yolo_model = None
121
+ if config.YOLO_AVAILABLE:
122
+ try:
123
+ # 尝试加载本地YOLOv人脸模型
124
+ yolo_face_path = os.path.join(self.models_dir, config.YOLO_MODEL)
125
+
126
+ if os.path.exists(yolo_face_path):
127
+ self.yolo_model = YOLO(yolo_face_path)
128
+ logger.info(f"Local YOLO face model loaded successfully: {yolo_face_path}")
129
+ else:
130
+ # 如果本地没有,尝试在线下载(第一次使用时)
131
+ logger.info("Local YOLO face model does not exist, attempting to download...")
132
+ try:
133
+ # 检查是否是yolov8,使用相应的模型
134
+ model_name = "yolov11n-face.pt" # 默认使用yolov8n
135
+ self.yolo_model = YOLO(model_name)
136
+ logger.info(
137
+ f"YOLOv8 general model loaded successfully (detecting 'person' class as face regions)"
138
+ )
139
+ except Exception as e:
140
+ logger.warning(f"YOLOv model download failed: {e}")
141
+
142
+ except Exception as e:
143
+ logger.error(f"YOLOv model loading failed: {e}")
144
+ else:
145
+ logger.warning("ultralytics not installed, cannot use YOLOv")
146
+
147
+ def _load_howcuteami_models(self):
148
+ """加载HowCuteAmI深度学习模型"""
149
+ try:
150
+ # 人脸检测模型
151
+ face_proto = os.path.join(self.models_dir, "opencv_face_detector.pbtxt")
152
+ face_model = os.path.join(self.models_dir, "opencv_face_detector_uint8.pb")
153
+ self.face_net = cv2.dnn.readNet(face_model, face_proto)
154
+
155
+ # 年龄预测模型
156
+ age_proto = os.path.join(self.models_dir, "age_googlenet.prototxt")
157
+ age_model = os.path.join(self.models_dir, "age_googlenet.caffemodel")
158
+ self.age_net = cv2.dnn.readNet(age_model, age_proto)
159
+
160
+ # 性别预测模型
161
+ gender_proto = os.path.join(self.models_dir, "gender_googlenet.prototxt")
162
+ gender_model = os.path.join(self.models_dir, "gender_googlenet.caffemodel")
163
+ self.gender_net = cv2.dnn.readNet(gender_model, gender_proto)
164
+
165
+ # 颜值预测模型
166
+ beauty_proto = os.path.join(self.models_dir, "beauty_resnet.prototxt")
167
+ beauty_model = os.path.join(self.models_dir, "beauty_resnet.caffemodel")
168
+ self.beauty_net = cv2.dnn.readNet(beauty_model, beauty_proto)
169
+
170
+ logger.info("HowCuteAmI model loaded successfully!")
171
+
172
+ except Exception as e:
173
+ logger.error(f"HowCuteAmI model loading failed: {e}")
174
+ raise e
175
+
176
+ # 人脸检测方法
177
+ def _detect_faces(
178
+ self, frame: np.ndarray, conf_threshold: float = config.FACE_CONFIDENCE
179
+ ) -> List[List[int]]:
180
+ """
181
+ 使用YOLO进行人脸检测,如果失败则回退到OpenCV DNN
182
+ """
183
+ # 优先使用YOLO
184
+ face_boxes = []
185
+ if self.yolo_model is not None:
186
+ try:
187
+ results = self.yolo_model(frame, conf=conf_threshold, verbose=False)
188
+ for result in results:
189
+ boxes = result.boxes
190
+ if boxes is not None:
191
+ for box in boxes:
192
+ # 检查类别ID (如果是专门的人脸模型,通常是0;如果是通用模型,person类别通常是0)
193
+ class_id = int(box.cls[0])
194
+ # 获取边界框坐标 (xyxy格式)
195
+ x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
196
+ confidence = float(box.conf[0])
197
+ logger.info(
198
+ f"detect class_id={class_id}, confidence={confidence}"
199
+ )
200
+ # 基本边界检查
201
+ frame_height, frame_width = frame.shape[:2]
202
+ x1 = max(0, int(x1))
203
+ y1 = max(0, int(y1))
204
+ x2 = min(frame_width, int(x2))
205
+ y2 = min(frame_height, int(y2))
206
+
207
+ # 过滤太小的检测框
208
+ width, height = x2 - x1, y2 - y1
209
+ if (
210
+ width > 30 and height > 30
211
+ ): # YOLO通常检测精度更高,可以稍微提高最小尺寸
212
+ # 如果使用通用模型检测person,需要进一步过滤头部区域
213
+ if self._is_likely_face_region(x1, y1, x2, y2, frame):
214
+ face_boxes.append(self._scale_box([x1, y1, x2, y2]))
215
+ logger.info(
216
+ f"YOLO detected {len(face_boxes)} faces, conf_threshold={conf_threshold}"
217
+ )
218
+ if face_boxes: # 如果YOLO检测到了人脸,直接返回
219
+ return face_boxes
220
+
221
+ except Exception as e:
222
+ logger.warning(f"YOLO detection failed, falling back to OpenCV DNN: {e}")
223
+ return self._detect_faces_opencv_fallback(frame, conf_threshold)
224
+
225
+ return face_boxes
226
+
227
+ def _is_likely_face_region(
228
+ self, x1: int, y1: int, x2: int, y2: int, frame: np.ndarray
229
+ ) -> bool:
230
+ """
231
+ 判断检测区域是否可能是人脸区域(当使用通用YOLO模型时)
232
+ """
233
+ width, height = x2 - x1, y2 - y1
234
+
235
+ # 长宽比检查 - 人脸/头部通常接近正方形
236
+ aspect_ratio = width / height
237
+ if not (0.6 <= aspect_ratio <= 1.6):
238
+ return False
239
+
240
+ # 位置检查 - 人脸通常在图像上半部分(简单启发式)
241
+ frame_height = frame.shape[0]
242
+ center_y = (y1 + y2) / 2
243
+ if center_y > frame_height * 0.8: # 如果中心点在图像下方80%以下,可能不是人脸
244
+ return False
245
+
246
+ # 尺寸检查 - 不应该占据整个图像
247
+ frame_width, frame_height = frame.shape[1], frame.shape[0]
248
+ if width > frame_width * 0.8 or height > frame_height * 0.8:
249
+ return False
250
+
251
+ return True
252
+
253
def _detect_faces_opencv_fallback(
    self, frame: np.ndarray, conf_threshold: float = 0.5
) -> List[List[int]]:
    """
    Optimized face detection using the OpenCV DNN face model.

    Runs the detector at several input scales (larger scales use lower
    confidence thresholds so small faces are found), optionally retries
    once with histogram equalization and a very low threshold, then
    merges candidates with NMS and squares each surviving box.

    :param frame: BGR image (cv2 convention).
    :param conf_threshold: base confidence threshold for the 300x300 scale.
    :return: list of square face boxes [x1, y1, x2, y2]; empty if none found.
    """
    frame_height, frame_width = frame.shape[:2]
    all_boxes = []

    # Multi-scale configs, small to large; larger inputs get progressively
    # lower thresholds to catch small faces.
    # NOTE(review): the loop variable `config` below shadows the
    # module-level `config` import inside this method.
    detection_configs = [
        {"size": (300, 300), "threshold": conf_threshold},
        {
            "size": (416, 416),
            "threshold": max(0.3, conf_threshold - 0.2),
        },  # lower threshold at the larger scale
        {
            "size": (512, 512),
            "threshold": max(0.25, conf_threshold - 0.25),
        },  # even lower threshold to detect small faces
    ]
    logger.info(f"Detecting faces using opencv, conf_threshold={conf_threshold}")
    for config in detection_configs:
        try:
            # Pre-processing: a mild contrast/brightness boost helps
            # small-face detection.
            processed_frame = cv2.convertScaleAbs(frame, alpha=1.1, beta=10)

            blob = cv2.dnn.blobFromImage(
                processed_frame, 1.0, config["size"], [104, 117, 123], True, False
            )
            self.face_net.setInput(blob)
            detections = self.face_net.forward()

            # Extract detection results.
            for i in range(detections.shape[2]):
                confidence = detections[0, 0, i, 2]
                if confidence > config["threshold"]:
                    # Coordinates are normalized [0, 1]; scale to pixels.
                    x1 = int(detections[0, 0, i, 3] * frame_width)
                    y1 = int(detections[0, 0, i, 4] * frame_height)
                    x2 = int(detections[0, 0, i, 5] * frame_width)
                    y2 = int(detections[0, 0, i, 6] * frame_height)

                    # Clamp to the image bounds.
                    x1, y1 = max(0, x1), max(0, y1)
                    x2, y2 = min(frame_width, x2), min(frame_height, y2)

                    # Drop boxes that are too small or unreasonably large.
                    width, height = x2 - x1, y2 - y1
                    if (
                        width > 20
                        and height > 20
                        and width < frame_width * 0.8
                        and height < frame_height * 0.8
                    ):
                        # Aspect-ratio check - faces are roughly square.
                        aspect_ratio = width / height
                        if 0.6 <= aspect_ratio <= 1.8:  # allow some elongation
                            all_boxes.append(
                                {
                                    "box": [x1, y1, x2, y2],
                                    "confidence": confidence,
                                    "area": width * height,
                                }
                            )
        except Exception as e:
            logger.warning(f"Scale {config['size']} detection failed: {e}")
            continue

    # If nothing was detected, retry once with relaxed conditions.
    if not all_boxes:
        logger.info("No faces detected, trying more relaxed detection conditions...")
        try:
            # Last attempt: lowest threshold + histogram equalization.
            enhanced_frame = cv2.equalizeHist(
                cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            )
            enhanced_frame = cv2.cvtColor(enhanced_frame, cv2.COLOR_GRAY2BGR)

            blob = cv2.dnn.blobFromImage(
                enhanced_frame, 1.0, (300, 300), [104, 117, 123], True, False
            )
            self.face_net.setInput(blob)
            detections = self.face_net.forward()

            for i in range(detections.shape[2]):
                confidence = detections[0, 0, i, 2]
                if confidence > 0.15:  # very low threshold
                    x1 = int(detections[0, 0, i, 3] * frame_width)
                    y1 = int(detections[0, 0, i, 4] * frame_height)
                    x2 = int(detections[0, 0, i, 5] * frame_width)
                    y2 = int(detections[0, 0, i, 6] * frame_height)

                    x1, y1 = max(0, x1), max(0, y1)
                    x2, y2 = min(frame_width, x2), min(frame_height, y2)

                    width, height = x2 - x1, y2 - y1
                    if width > 15 and height > 15:  # smaller minimum size
                        aspect_ratio = width / height
                        if 0.5 <= aspect_ratio <= 2.0:  # looser aspect ratio
                            all_boxes.append(
                                {
                                    "box": [x1, y1, x2, y2],
                                    "confidence": confidence,
                                    "area": width * height,
                                }
                            )
        except Exception as e:
            logger.warning(f"Relaxed condition detection also failed: {e}")

    # NMS (non-maximum suppression) to drop duplicate detections.
    if all_boxes:
        final_boxes = self._apply_nms(all_boxes, overlap_threshold=0.4)
        return [self._scale_box(box["box"]) for box in final_boxes]

    return []
368
+
369
def _apply_nms(
    self, detections: List[Dict], overlap_threshold: float = 0.4
) -> List[Dict]:
    """
    Non-maximum suppression: drop detections that overlap a
    higher-confidence detection by IoU >= overlap_threshold.

    :param detections: dicts with at least "box" ([x1, y1, x2, y2]) and
        "confidence" keys.
    :param overlap_threshold: IoU at or above which a lower-confidence
        box is treated as a duplicate and discarded.
    :return: surviving detections, ordered by descending confidence.
    """
    if not detections:
        return []

    # Sort into a copy so the caller's list is not mutated/reordered
    # (the previous in-place detections.sort() reordered the argument
    # as a side effect).
    candidates = sorted(detections, key=lambda x: x["confidence"], reverse=True)

    keep = []
    while candidates:
        # Keep the highest-confidence remaining detection...
        best = candidates.pop(0)
        keep.append(best)

        # ...and discard everything that overlaps it too much.
        candidates = [
            det
            for det in candidates
            if self._calculate_iou(best["box"], det["box"]) < overlap_threshold
        ]

    return keep
395
+
396
+ def _calculate_iou(self, box1: List[int], box2: List[int]) -> float:
397
+ """
398
+ 计算两个边界框的IoU (交并比)
399
+ """
400
+ x1_1, y1_1, x2_1, y2_1 = box1
401
+ x1_2, y1_2, x2_2, y2_2 = box2
402
+
403
+ # 计算交集
404
+ x1_i = max(x1_1, x1_2)
405
+ y1_i = max(y1_1, y1_2)
406
+ x2_i = min(x2_1, x2_2)
407
+ y2_i = min(y2_1, y2_2)
408
+
409
+ if x2_i <= x1_i or y2_i <= y1_i:
410
+ return 0.0
411
+
412
+ intersection = (x2_i - x1_i) * (y2_i - y1_i)
413
+
414
+ # 计算并集
415
+ area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
416
+ area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
417
+ union = area1 + area2 - intersection
418
+
419
+ return intersection / union if union > 0 else 0.0
420
+
421
+ def _scale_box(self, box: List[int]) -> List[int]:
422
+ """将矩形框缩放为正方形"""
423
+ width = box[2] - box[0]
424
+ height = box[3] - box[1]
425
+ maximum = max(width, height)
426
+ dx = int((maximum - width) / 2)
427
+ dy = int((maximum - height) / 2)
428
+
429
+ return [box[0] - dx, box[1] - dy, box[2] + dx, box[3] + dy]
430
+
431
+ def _crop_face(self, image: np.ndarray, box: List[int]) -> np.ndarray:
432
+ """裁剪人脸区域"""
433
+ x1, y1, x2, y2 = box
434
+ h, w = image.shape[:2]
435
+ x1 = max(0, x1)
436
+ y1 = max(0, y1)
437
+ x2 = min(w, x2)
438
+ y2 = min(h, y2)
439
+ return image[y1:y2, x1:x2]
440
+
441
def _predict_beauty_gender_with_howcuteami(
    self, face: np.ndarray
) -> Dict[str, Any]:
    """Predict beauty score and gender using the HowCuteAmI models.

    :param face: BGR face crop, expected at 224x224 (callers resize).
    :return: dict with age/gender/beauty values, per-field confidences,
        and the model names used for each field.
    :raises Exception: re-raises any model failure after logging it.
    """
    try:
        # One mean-subtracted blob (no 1/255 scaling) feeds both the
        # gender and age nets.
        blob = cv2.dnn.blobFromImage(
            face, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
        )

        # Gender prediction.
        self.gender_net.setInput(blob)
        gender_preds = self.gender_net.forward()
        gender = self.gender_list[gender_preds[0].argmax()]
        gender_confidence = float(np.max(gender_preds[0]))
        gender_confidence = self._cap_conf(gender_confidence)
        # Age prediction (same blob).
        self.age_net.setInput(blob)
        age_preds = self.age_net.forward()
        age = self.age_list[age_preds[0].argmax()]
        age_confidence = float(np.max(age_preds[0]))
        # Beauty prediction uses its own blob scaled by 1/255.
        blob_beauty = cv2.dnn.blobFromImage(
            face, 1.0 / 255, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
        )
        self.beauty_net.setInput(blob_beauty)
        beauty_preds = self.beauty_net.forward()
        # Sum of the net outputs, doubled and rounded, then clamped to
        # [0, 10] and passed through the project-level adjustment.
        beauty_score = round(float(2.0 * np.sum(beauty_preds[0])), 1)
        beauty_score = min(10.0, max(0.0, beauty_score))
        beauty_score = self._adjust_beauty_score(beauty_score)
        raw_score = float(np.sum(beauty_preds[0]))

        return {
            "age": age,
            "age_confidence": round(age_confidence, 4),
            "gender": gender,
            "gender_confidence": gender_confidence,
            "beauty_score": beauty_score,
            "beauty_raw_score": round(raw_score, 4),
            "age_model_used": "HowCuteAmI",
            "gender_model_used": "HowCuteAmI",
            "beauty_model_used": "HowCuteAmI",
        }
    except Exception as e:
        logger.error(f"HowCuteAmI beauty gender prediction failed: {e}")
        raise e
486
+
487
def _predict_age_emotion_with_deepface(
    self, face_image: np.ndarray
) -> Dict[str, Any]:
    """Predict age and emotion with DeepFace (also returns gender info
    usable as a fallback by callers).

    Falls back to the HowCuteAmI age prediction when DeepFace is not
    installed or its analysis raises.

    :param face_image: BGR face crop at original size.
    :return: dict with age, age_confidence, emotion, emotion_analysis,
        gender, gender_confidence.
    :raises ValueError: if the face image is None or empty.
    """
    if not DEEPFACE_AVAILABLE:
        # If DeepFace is unavailable, fall back to HowCuteAmI age prediction.
        return self._predict_age_with_howcuteami_fallback(face_image)

    if face_image is None or face_image.size == 0:
        raise ValueError("无效的人脸图像")

    try:
        # DeepFace analysis - progress bars and verbose output disabled.
        # Detection is skipped because the input is already a face crop.
        result = DeepFace.analyze(
            img_path=face_image,
            actions=["age", "emotion", "gender"],
            enforce_detection=False,
            detector_backend="skip",
            silent=True  # suppress progress-bar output
        )

        # DeepFace may return either a list or a dict.
        if isinstance(result, list):
            result = result[0]

        # Extract the fields we need.
        age = result.get("age", 25)
        emotion = result.get("dominant_emotion", "neutral")
        emotion_scores = result.get("emotion", {})
        # Gender info (used by callers to override HowCuteAmI when its
        # confidence is low). DeepFace scores are percentages.
        deep_gender = result.get("dominant_gender", "Woman")
        deep_gender_conf = result.get("gender", {}).get(deep_gender, 50.0) / 100.0
        deep_gender_conf = self._cap_conf(deep_gender_conf)
        if str(deep_gender).lower() in ["woman", "female"]:
            deep_gender = "Female"
        else:
            deep_gender = "Male"

        # NOTE(review): DeepFace regresses age without a confidence, so
        # this value is fabricated from a uniform random range.
        age_conf = round(random.uniform(0.7613, 0.9599), 4)
        return {
            "age": str(int(age)),
            "age_confidence": age_conf,
            "emotion": emotion,
            "emotion_analysis": emotion_scores,
            "gender": deep_gender,
            "gender_confidence": deep_gender_conf,
        }
    except Exception as e:
        logger.error(f"DeepFace age emotion prediction failed, falling back to HowCuteAmI: {e}")
        return self._predict_age_with_howcuteami_fallback(face_image)
537
+
538
def _predict_age_with_howcuteami_fallback(
    self, face_image: np.ndarray
) -> Dict[str, Any]:
    """Fallback age prediction via the HowCuteAmI age net.

    Used when DeepFace is unavailable or fails. Emotion fields are
    filled with neutral defaults because HowCuteAmI has no emotion
    model. On any internal error a hard-coded default result is
    returned instead of raising.

    :param face_image: BGR face crop at any size (resized to 224x224).
    :return: dict with age, age_confidence, emotion, emotion_analysis.
    """
    try:
        if face_image is None or face_image.size == 0:
            raise ValueError("无法读取人脸图像")

        face_resized = cv2.resize(face_image, (224, 224))
        blob = cv2.dnn.blobFromImage(
            face_resized, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
        )

        # Age prediction.
        self.age_net.setInput(blob)
        age_preds = self.age_net.forward()
        age = self.age_list[age_preds[0].argmax()]
        age_confidence = float(np.max(age_preds[0]))

        return {
            "age": age[1:-1],  # strip the surrounding parentheses
            "age_confidence": round(age_confidence, 4),
            "emotion": "neutral",  # default emotion
            "emotion_analysis": {"neutral": 100.0},  # default emotion breakdown
        }
    except Exception as e:
        logger.error(f"HowCuteAmI age prediction fallback failed: {e}")
        # Last-resort constants so callers always get a usable shape.
        return {
            "age": "25-32",
            "age_confidence": 0.5,
            "emotion": "neutral",
            "emotion_analysis": {"neutral": 100.0},
        }
571
+
572
def _predict_with_hybrid_model(
    self, face: np.ndarray, face_image: np.ndarray
) -> Dict[str, Any]:
    """Hybrid prediction: HowCuteAmI (beauty + gender) plus DeepFace
    (age + emotion), with DeepFace used for age only when HowCuteAmI's
    age confidence falls below config.AGE_CONFIDENCE.

    :param face: 224x224 BGR face crop for the HowCuteAmI DNN nets.
    :param face_image: original-size BGR face crop for DeepFace.
    :return: merged prediction dict (gender, age, beauty, emotion, and
        the model name used for each field).
    """
    # Beauty and gender always come from HowCuteAmI.
    beauty_gender_result = self._predict_beauty_gender_with_howcuteami(face)

    # HowCuteAmI age/gender confidences drive the routing below.
    howcuteami_age_confidence = beauty_gender_result.get("age_confidence", 0)
    gender_confidence = beauty_gender_result.get("gender_confidence", 0)
    if gender_confidence >= 1:
        gender_confidence = 0.9999
    age = beauty_gender_result["age"]

    # DeepFace result is computed lazily; None means the HowCuteAmI age
    # was trusted and DeepFace was never invoked.
    age_emotion_result = None

    # If HowCuteAmI's age confidence is below threshold, use DeepFace age.
    agec = config.AGE_CONFIDENCE
    if howcuteami_age_confidence < agec:
        # DeepFace supplies age/emotion (and gender info for the merge).
        age_emotion_result = self._predict_age_emotion_with_deepface(
            face_image
        )
        deep_age = age_emotion_result["age"]
        logger.info(
            f"HowCuteAmI age confidence ({howcuteami_age_confidence}) below {agec}, value=({age}); using DeepFace for age prediction, value={deep_age}"
        )
        # Merged result with DeepFace's age prediction.
        result = {
            "gender": beauty_gender_result["gender"],  # may be overridden below
            "gender_confidence": self._cap_conf(gender_confidence),
            "beauty_score": beauty_gender_result["beauty_score"],
            "beauty_raw_score": beauty_gender_result["beauty_raw_score"],
            "age": deep_age,
            "age_confidence": age_emotion_result["age_confidence"],
            "emotion": age_emotion_result["emotion"],
            "emotion_analysis": age_emotion_result["emotion_analysis"],
            "model_used": "hybrid_deepface_age",
            "age_model_used": "DeepFace",
            "gender_model_used": "HowCuteAmI",
        }
    else:
        # HowCuteAmI's age confidence is high enough; keep its age.
        logger.info(
            f"HowCuteAmI age confidence ({howcuteami_age_confidence}) is high enough, value={age}; using HowCuteAmI for age prediction"
        )
        result = {
            "gender": beauty_gender_result["gender"],  # may be overridden below
            "gender_confidence": self._cap_conf(gender_confidence),
            "beauty_score": beauty_gender_result["beauty_score"],
            "beauty_raw_score": beauty_gender_result["beauty_raw_score"],
            "age": beauty_gender_result["age"],
            "age_confidence": beauty_gender_result["age_confidence"],
            "emotion": None,
            "emotion_analysis": None,
            "model_used": "hybrid",
            "age_model_used": "HowCuteAmI",
            "gender_model_used": "HowCuteAmI",
        }

    # Combined gender rule: Female if either model says Female; Male only
    # when both say Male. The previous version referenced
    # age_emotion_result unconditionally, so on the high-confidence path
    # it always died with a swallowed NameError and the rule silently
    # never ran; the explicit None guard keeps that effective behavior
    # (no merge without a DeepFace result) while making the control flow
    # intentional.
    if age_emotion_result is not None:
        try:
            how_gender = beauty_gender_result.get("gender")
            how_conf = float(beauty_gender_result.get("gender_confidence", 0) or 0)
            deep_gender = age_emotion_result.get("gender")
            deep_conf = float(age_emotion_result.get("gender_confidence", 0) or 0)

            final_gender = result.get("gender")
            final_conf = float(result.get("gender_confidence", 0) or 0)
            if (str(how_gender) == "Female") or (str(deep_gender) == "Female"):
                final_gender = "Female"
                final_conf = max(how_conf if how_gender == "Female" else 0,
                                 deep_conf if deep_gender == "Female" else 0)
                result["gender_model_used"] = "Combined(H+DF)"
            elif (str(how_gender) == "Male") and (str(deep_gender) == "Male"):
                final_gender = "Male"
                final_conf = max(how_conf if how_gender == "Male" else 0,
                                 deep_conf if deep_gender == "Male" else 0)
                result["gender_model_used"] = "Combined(H+DF)"
            # Otherwise keep the existing decision.

            result["gender"] = final_gender
            result["gender_confidence"] = self._cap_conf(final_conf)
        except Exception:
            pass

    return result
659
+
660
def _predict_with_howcuteami(self, face: np.ndarray) -> Dict[str, Any]:
    """Full prediction (gender, age, beauty) using only HowCuteAmI models.

    Emotion fields are filled with neutral defaults because HowCuteAmI
    has no emotion model.

    :param face: BGR face crop, expected at 224x224 (callers resize).
    :return: complete prediction dict including model names per field.
    :raises Exception: re-raises any model failure after logging it.
    """
    try:
        # Gender prediction: mean-subtracted blob without 1/255 scaling.
        blob = cv2.dnn.blobFromImage(
            face, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
        )
        self.gender_net.setInput(blob)
        gender_preds = self.gender_net.forward()
        gender = self.gender_list[gender_preds[0].argmax()]
        gender_confidence = float(np.max(gender_preds[0]))
        gender_confidence = self._cap_conf(gender_confidence)

        # Age prediction (same blob).
        self.age_net.setInput(blob)
        age_preds = self.age_net.forward()
        age = self.age_list[age_preds[0].argmax()]
        age_confidence = float(np.max(age_preds[0]))

        # Beauty prediction uses its own blob scaled by 1/255.
        blob_beauty = cv2.dnn.blobFromImage(
            face, 1.0 / 255, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
        )
        self.beauty_net.setInput(blob_beauty)
        beauty_preds = self.beauty_net.forward()
        # Sum of outputs, doubled and rounded, clamped to [0, 10], then
        # passed through the project-level adjustment.
        beauty_score = round(float(2.0 * np.sum(beauty_preds[0])), 1)
        beauty_score = min(10.0, max(0.0, beauty_score))
        beauty_score = self._adjust_beauty_score(beauty_score)
        raw_score = float(np.sum(beauty_preds[0]))

        return {
            "gender": gender,
            "gender_confidence": gender_confidence,
            "age": age[1:-1],  # strip the surrounding parentheses
            "age_confidence": round(age_confidence, 4),
            "beauty_score": beauty_score,
            "beauty_raw_score": round(raw_score, 4),
            "model_used": "HowCuteAmI",
            "emotion": "neutral",  # HowCuteAmI does not support emotion analysis
            "emotion_analysis": {"neutral": 100.0},
            "age_model_used": "HowCuteAmI",
            "gender_model_used": "HowCuteAmI",
            "beauty_model_used": "HowCuteAmI",
        }
    except Exception as e:
        logger.error(f"HowCuteAmI prediction failed: {e}")
        raise e
707
+
708
def _predict_with_deepface(self, face_image: np.ndarray) -> Dict[str, Any]:
    """Full prediction using DeepFace, with a heuristic beauty estimate.

    :param face_image: BGR face crop at original size.
    :return: complete prediction dict including model names per field.
    :raises ValueError: if DeepFace is not installed or the image is
        None/empty.
    :raises Exception: re-raises any DeepFace failure after logging it.
    """
    if not DEEPFACE_AVAILABLE:
        raise ValueError("DeepFace未安装")

    if face_image is None or face_image.size == 0:
        raise ValueError("无效的人脸图像")

    try:
        # DeepFace analysis - progress bars and verbose output disabled;
        # detection skipped because the input is already a face crop.
        result = DeepFace.analyze(
            img_path=face_image,
            actions=["age", "gender", "emotion"],
            enforce_detection=False,
            detector_backend="skip",
            silent=True  # suppress progress-bar output
        )

        # DeepFace may return either a list or a dict.
        if isinstance(result, list):
            result = result[0]

        # Extract fields; DeepFace gender scores are percentages.
        age = result.get("age", 25)
        gender = result.get("dominant_gender", "Woman")
        gender_confidence = result.get("gender", {}).get(gender, 0.5) / 100
        gender_confidence = self._cap_conf(gender_confidence)

        # Normalize the gender label to Female/Male.
        if gender.lower() in ["woman", "female"]:
            gender = "Female"
        else:
            gender = "Male"

        # DeepFace has no built-in beauty score; use a simple heuristic.
        emotion = result.get("dominant_emotion", "neutral")
        emotion_scores = result.get("emotion", {})

        # Simple beauty estimate from emotion and age.
        happiness_score = emotion_scores.get("happy", 0) / 100
        neutral_score = emotion_scores.get("neutral", 0) / 100

        # Naive formula (room for improvement): base score plus an
        # emotion bonus and an age factor peaking at 25.
        base_beauty = 6.0  # base score
        emotion_bonus = happiness_score * 2 + neutral_score * 1
        age_factor = max(0.5, 1 - abs(age - 25) / 50)  # 25 treated as optimum

        beauty_score = round(min(10.0, base_beauty + emotion_bonus + age_factor), 2)

        # NOTE(review): DeepFace regresses age without a confidence, so
        # this value is fabricated from a uniform random range.
        age_conf = round(random.uniform(0.7613, 0.9599), 4)
        return {
            "gender": gender,
            "gender_confidence": gender_confidence,
            "age": str(int(age)),
            "age_confidence": age_conf,  # fabricated DeepFace age confidence
            "beauty_score": beauty_score,
            "beauty_raw_score": round(beauty_score / 10, 4),
            "model_used": "DeepFace",
            "emotion": emotion,
            "emotion_analysis": emotion_scores,
            "age_model_used": "DeepFace",
            "gender_model_used": "DeepFace",
            "beauty_model_used": "Heuristic",
        }
    except Exception as e:
        logger.error(f"DeepFace prediction failed: {e}")
        raise e
775
+
776
def analyze_faces(
    self,
    image: np.ndarray,
    original_image_hash: str,
    model_type: ModelType = ModelType.HYBRID,
) -> Dict[str, Any]:
    """
    Analyze all faces in an image.

    Detects faces, runs the selected prediction pipeline per face,
    applies the configurable female-age adjustment, saves each face crop
    to OUTPUT_DIR, and (when config.DRAW_SCORE is set) draws boxes and
    labels onto a copy of the image.

    :param image: input image (BGR).
    :param original_image_hash: MD5 hash of the original image, used to
        name the saved face crops.
    :param model_type: which model pipeline to use.
    :return: result dict with success flag, message, face_count, per-face
        results, and the annotated image (or None when no face is found).
    :raises ValueError: if image is None.
    """
    if image is None:
        raise ValueError("无效的图像输入")

    # Detect faces.
    face_boxes = self._detect_faces(image)

    if not face_boxes:
        return {
            "success": False,
            "message": "请尝试上传清晰、无遮挡的正面照片",
            "face_count": 0,
            "faces": [],
            "annotated_image": None,
            "model_used": model_type.value,
        }

    results = {
        "success": True,
        "message": f"成功检测到 {len(face_boxes)} 张人脸",
        "face_count": len(face_boxes),
        "faces": [],
        "model_used": model_type.value,
    }

    # Copy the original for annotation drawing.
    annotated_image = image.copy()
    logger.info(
        f"Input annotated_image shape: {annotated_image.shape}, dtype: {annotated_image.dtype}, ndim: {annotated_image.ndim}"
    )
    # Analyze each detected face.
    for i, face_box in enumerate(face_boxes):
        # Crop the face region (clamped to the image bounds).
        face_cropped = self._crop_face(image, face_box)
        if face_cropped.size == 0:
            logger.warning(f"Cropped face {i + 1} is empty, skipping.")
            continue

        # 224x224 copy for the DNN nets; full-size copy for DeepFace.
        face_resized = cv2.resize(face_cropped, (224, 224))
        face_for_deepface = face_cropped.copy()

        # Run the prediction pipeline selected by model_type.
        try:
            if model_type == ModelType.HYBRID:
                # Hybrid: beauty/gender from HowCuteAmI, age/emotion from DeepFace.
                prediction_result = self._predict_with_hybrid_model(
                    face_resized, face_for_deepface
                )
            elif model_type == ModelType.HOWCUTEAMI:
                prediction_result = self._predict_with_howcuteami(face_resized)
                # Non-hybrid mode also merges gender: bring in DeepFace's
                # gender (Female if either says Female; Male only if both do).
                try:
                    age_emotion_result = self._predict_age_emotion_with_deepface(
                        face_for_deepface
                    )
                    how_gender = prediction_result.get("gender")
                    how_conf = float(prediction_result.get("gender_confidence", 0) or 0)
                    deep_gender = age_emotion_result.get("gender")
                    deep_conf = float(age_emotion_result.get("gender_confidence", 0) or 0)
                    final_gender = prediction_result.get("gender")
                    final_conf = float(prediction_result.get("gender_confidence", 0) or 0)
                    if (str(how_gender) == "Female") or (str(deep_gender) == "Female"):
                        final_gender = "Female"
                        final_conf = max(how_conf if how_gender == "Female" else 0,
                                         deep_conf if deep_gender == "Female" else 0)
                        prediction_result["gender_model_used"] = "Combined(H+DF)"
                    elif (str(how_gender) == "Male") and (str(deep_gender) == "Male"):
                        final_gender = "Male"
                        final_conf = max(how_conf if how_gender == "Male" else 0,
                                         deep_conf if deep_gender == "Male" else 0)
                        prediction_result["gender_model_used"] = "Combined(H+DF)"
                    prediction_result["gender"] = final_gender
                    prediction_result["gender_confidence"] = round(float(final_conf), 4)
                except Exception:
                    # Best-effort merge; keep the single-model gender on failure.
                    pass
            elif model_type == ModelType.DEEPFACE and DEEPFACE_AVAILABLE:
                prediction_result = self._predict_with_deepface(face_for_deepface)
                # Non-hybrid mode also merges gender: bring in HowCuteAmI's
                # gender under the same combination rule.
                try:
                    beauty_gender_result = self._predict_beauty_gender_with_howcuteami(
                        face_resized
                    )
                    deep_gender = prediction_result.get("gender")
                    deep_conf = float(prediction_result.get("gender_confidence", 0) or 0)
                    how_gender = beauty_gender_result.get("gender")
                    how_conf = float(beauty_gender_result.get("gender_confidence", 0) or 0)
                    final_gender = prediction_result.get("gender")
                    final_conf = float(prediction_result.get("gender_confidence", 0) or 0)
                    if (str(how_gender) == "Female") or (str(deep_gender) == "Female"):
                        final_gender = "Female"
                        final_conf = max(how_conf if how_gender == "Female" else 0,
                                         deep_conf if deep_gender == "Female" else 0)
                        prediction_result["gender_model_used"] = "Combined(H+DF)"
                    elif (str(how_gender) == "Male") and (str(deep_gender) == "Male"):
                        final_gender = "Male"
                        final_conf = max(how_conf if how_gender == "Male" else 0,
                                         deep_conf if deep_gender == "Male" else 0)
                        prediction_result["gender_model_used"] = "Combined(H+DF)"
                    prediction_result["gender"] = final_gender
                    prediction_result["gender_confidence"] = round(float(final_conf), 4)
                except Exception:
                    # Best-effort merge; keep the single-model gender on failure.
                    pass
            else:
                # Unknown/unavailable model type: fall back to hybrid mode.
                prediction_result = self._predict_with_hybrid_model(
                    face_resized, face_for_deepface
                )
                logger.warning(f"Model {model_type.value} is not available, using hybrid mode")

        except Exception as e:
            # Any pipeline failure yields neutral defaults for this face.
            logger.error(f"Prediction failed, using default values: {e}")
            prediction_result = {
                "gender": "Unknown",
                "gender_confidence": 0.5,
                "age": "25-32",
                "age_confidence": 0.5,
                "beauty_score": 5.0,
                "beauty_raw_score": 0.5,
                "emotion": "neutral",
                "emotion_analysis": {"neutral": 100.0},
                "model_used": "fallback",
            }

        # Facial feature analysis (currently disabled).
        # facial_features = self.facial_analyzer.analyze_facial_features(
        #     face_cropped, face_box
        # )

        # Per-gender color and unified age display (with female age adjustment).
        gender = prediction_result.get("gender", "Unknown")
        color_bgr = self.gender_colors.get(gender, (128, 128, 128))
        color_hex = f"#{color_bgr[2]:02x}{color_bgr[1]:02x}{color_bgr[0]:02x}"

        # Age text and the configurable female adjustment.
        raw_age_str = prediction_result.get("age", "Unknown")
        display_age_str = str(raw_age_str)
        age_adjusted_flag = False
        age_adjustment_value = int(getattr(config, "FEMALE_AGE_ADJUSTMENT", 0) or 0)
        age_adjustment_threshold = int(getattr(config, "FEMALE_AGE_ADJUSTMENT_THRESHOLD", 999) or 999)

        # Adjust only for females at or above the threshold age.
        try:
            # Supports both "25-32" range and plain "25" formats.
            if "-" in str(raw_age_str):
                age_num = int(str(raw_age_str).split("-")[0].strip("() "))
            else:
                age_num = int(str(raw_age_str).strip())

            if str(gender) == "Female" and age_num >= age_adjustment_threshold and age_adjustment_value > 0:
                adjusted_age = max(0, age_num - age_adjustment_value)
                display_age_str = str(adjusted_age)
                age_adjusted_flag = True
                try:
                    logger.info(f"Adjusted age for female (draw+data): {age_num} -> {adjusted_age}")
                except Exception:
                    pass
        except Exception:
            # Keep the raw age text when it cannot be parsed.
            pass

        # Save the cropped face.
        cropped_face_filename = f"{original_image_hash}_face_{i + 1}.webp"
        cropped_face_path = os.path.join(OUTPUT_DIR, cropped_face_filename)
        try:
            save_image_high_quality(face_cropped, cropped_face_path)
            logger.info(f"cropped face: {cropped_face_path}")
        except Exception as e:
            logger.error(f"Failed to save cropped face {cropped_face_path}: {e}")
            cropped_face_filename = None

        # Draw the annotation box on the image.
        if config.DRAW_SCORE:
            cv2.rectangle(
                annotated_image,
                (face_box[0], face_box[1]),
                (face_box[2], face_box[3]),
                color_bgr,
                int(round(image.shape[0] / 400)),
                8,
            )

        # Label text: gender, display age, beauty score.
        beauty_score = prediction_result.get("beauty_score", 0)
        label = f"{gender}, {display_age_str}, {beauty_score}"

        font_scale = max(
            0.3, min(0.7, image.shape[0] / 800)
        )  # scaled from image height; clamped to [0.3, 0.7]
        font_thickness = 2
        font = cv2.FONT_HERSHEY_SIMPLEX
        # Text placement: above the box unless too close to the top edge.
        text_x = face_box[0]
        text_y = face_box[1] - 10 if face_box[1] - 10 > 20 else face_box[1] + 30

        # Measure the text (width/height) for the background rectangle.
        (text_width, text_height), baseline = cv2.getTextSize(label, font, font_scale, font_thickness)

        # Background rectangle slightly larger than the text itself.
        background_tl = (text_x, text_y - text_height - baseline)  # top-left
        background_br = (text_x + text_width, text_y + baseline)  # bottom-right

        if config.DRAW_SCORE:
            cv2.rectangle(
                annotated_image,
                background_tl,
                background_br,
                color_bgr,  # gender-colored filled background
                thickness=-1  # filled
            )
            cv2.putText(
                annotated_image,
                label,
                (text_x, text_y),
                font,
                font_scale,
                (255, 255, 255),
                font_thickness,
                cv2.LINE_AA,
            )

        # Build the per-face result entry.
        face_result = {
            "face_id": i + 1,
            "gender": gender,
            "gender_confidence": prediction_result.get("gender_confidence", 0),
            "gender_model_used": prediction_result.get("gender_model_used", prediction_result.get("model_used", model_type.value)),
            "age": display_age_str,
            "age_confidence": prediction_result.get("age_confidence", 0),
            "age_model_used": prediction_result.get("age_model_used", prediction_result.get("model_used", model_type.value)),
            "beauty_score": prediction_result.get("beauty_score", 0),
            "beauty_raw_score": prediction_result.get("beauty_raw_score", 0),
            "emotion": prediction_result.get("emotion", "neutral"),
            "emotion_analysis": prediction_result.get("emotion_analysis", {}),
            # "facial_features": facial_features,  # facial feature analysis (disabled)
            "bounding_box": {
                "x1": int(face_box[0]),
                "y1": int(face_box[1]),
                "x2": int(face_box[2]),
                "y2": int(face_box[3]),
            },
            "color": {
                "bgr": [int(color_bgr[0]), int(color_bgr[1]), int(color_bgr[2])],
                "hex": color_hex,
            },
            "cropped_face_filename": cropped_face_filename,
            "model_used": prediction_result.get("model_used", model_type.value),
        }

        if age_adjusted_flag:
            face_result["age_adjusted"] = True
            face_result["age_adjustment_value"] = int(age_adjustment_value)

        results["faces"].append(face_result)

    results["annotated_image"] = annotated_image
    return results
1044
+
1045
def _warmup_models(self):
    """Warm up all models to reduce first-call latency.

    Runs a tiny synthetic image through DeepFace (when available) and
    through each OpenCV DNN net. All failures are logged and swallowed
    so startup never aborts because of warm-up.
    """
    try:
        logger.info("Starting to warm up models...")

        # Small uniform-gray test image (64x64).
        test_image = np.ones((64, 64, 3), dtype=np.uint8) * 128

        # Warm up the DeepFace model (if available).
        if DEEPFACE_AVAILABLE:
            try:
                import tempfile
                with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as tmp_file:
                    cv2.imwrite(tmp_file.name, test_image, [cv2.IMWRITE_WEBP_QUALITY, 95])
                    # Warm up DeepFace with the actions used at runtime.
                    DeepFace.analyze(
                        img_path=tmp_file.name,
                        actions=["age", "emotion", "gender"],
                        detector_backend="yolov8",
                        enforce_detection=False,
                        silent=True
                    )
                    os.unlink(tmp_file.name)
                logger.info("DeepFace model warm-up completed")
            except Exception as e:
                logger.warning(f"DeepFace model warm-up failed: {e}")

        # Warm up the OpenCV DNN models.
        try:
            # Face detection net.
            blob = cv2.dnn.blobFromImage(test_image, 1.0, (300, 300), (104, 117, 123))
            self.face_net.setInput(blob)
            self.face_net.forward()

            # Age net.
            test_face = cv2.resize(test_image, (224, 224))
            blob = cv2.dnn.blobFromImage(test_face, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False)
            self.age_net.setInput(blob)
            self.age_net.forward()

            # Gender net (same blob).
            self.gender_net.setInput(blob)
            self.gender_net.forward()

            # Beauty net (same blob).
            self.beauty_net.setInput(blob)
            self.beauty_net.forward()

            logger.info("OpenCV DNN model warm-up completed")
        except Exception as e:
            logger.warning(f"OpenCV DNN model warm-up failed: {e}")

        logger.info("Model warm-up completed")
    except Exception as e:
        logger.warning(f"Error occurred during model warm-up: {e}")
facial_analyzer.py ADDED
@@ -0,0 +1,912 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+ from typing import List, Dict, Any
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ import config
8
+ from config import logger, DLIB_AVAILABLE
9
+
10
+ if DLIB_AVAILABLE:
11
+ import mediapipe as mp
12
+
13
+
14
+ class FacialFeatureAnalyzer:
15
+ """五官分析器"""
16
+
17
def __init__(self):
    """Initialize the analyzer and, when possible, a MediaPipe FaceMesh.

    # NOTE(review): the gating flag is named DLIB_AVAILABLE but guards
    # the mediapipe import/usage here - confirm the flag's semantics in
    # config.
    """
    # FaceMesh instance; stays None when the optional dependency is
    # unavailable or fails to load, triggering the basic fallback path.
    self.face_mesh = None
    if DLIB_AVAILABLE:
        try:
            # Initialize MediaPipe Face Mesh for single-face, static images.
            mp_face_mesh = mp.solutions.face_mesh
            self.face_mesh = mp_face_mesh.FaceMesh(
                static_image_mode=True,
                max_num_faces=1,
                refine_landmarks=True,
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5
            )
            logger.info("MediaPipe face landmark detector loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load MediaPipe model: {e}")
33
+
34
def analyze_facial_features(
    self, face_image: np.ndarray, face_box: List[int]
) -> Dict[str, Any]:
    """
    Analyze facial features from landmarks.

    Uses MediaPipe FaceMesh landmarks when available; otherwise (or on
    any failure) falls back to the basic image-only analysis.

    :param face_image: face image (BGR).
    :param face_box: face bounding box [x1, y1, x2, y2].
    :return: facial-feature analysis result dict.
    """
    if not DLIB_AVAILABLE or self.face_mesh is None:
        return self._basic_facial_analysis(face_image)

    try:
        # MediaPipe expects RGB input.
        rgb_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)

        # Detect landmarks.
        results = self.face_mesh.process(rgb_image)

        if not results.multi_face_landmarks:
            logger.warning("No facial landmarks detected")
            return self._basic_facial_analysis(face_image)

        # Use the first detected face's landmarks.
        face_landmarks = results.multi_face_landmarks[0]

        # Convert MediaPipe's 468 landmarks to a dlib-68-like layout.
        points = self._convert_mediapipe_to_dlib_format(face_landmarks, face_image.shape)

        return self._analyze_features_from_landmarks(points, face_image.shape)

    except Exception as e:
        logger.error(f"Facial feature analysis failed: {e}")
        traceback.print_exc()  # full stack trace with exact line numbers
        return self._basic_facial_analysis(face_image)
69
+
70
    def _convert_mediapipe_to_dlib_format(self, face_landmarks, image_shape):
        """
        Convert MediaPipe's 468 normalized landmarks to a dlib-like 68-point
        list of pixel (x, y) tuples.

        NOTE(review): the index mapping below is approximate — several dlib
        slots share one MediaPipe index (e.g. 55/58 -> 318, 60/66 -> 17) and
        the left/right labels have not been verified against the MediaPipe
        topology. Confirm before relying on sided results.
        """
        h, w = image_shape[:2]

        # dlib 68-point slot -> MediaPipe Face Mesh index.
        mediapipe_to_dlib_map = {
            # Face contour (0-16)
            0: 234,   # lowest chin point
            1: 132,   # lower right cheek
            2: 172,   # right cheek
            3: 136,   # upper right cheek
            4: 150,   # right cheekbone
            5: 149,   # right temple
            6: 176,   # right forehead edge
            7: 148,   # right forehead
            8: 152,   # forehead center
            9: 377,   # left forehead
            10: 400,  # left forehead edge
            11: 378,  # left temple
            12: 379,  # left cheekbone
            13: 365,  # upper left cheek
            14: 397,  # left cheek
            15: 361,  # lower left cheek
            16: 454,  # left side of chin

            # Right eyebrow (17-21)
            17: 70,   # outer end
            18: 63,
            19: 105,
            20: 66,
            21: 107,  # inner end

            # Left eyebrow (22-26)
            22: 336,  # inner end
            23: 296,
            24: 334,
            25: 293,
            26: 300,  # outer end

            # Nose bridge (27-30)
            27: 168,  # bridge top
            28: 8,
            29: 9,
            30: 10,   # bridge bottom

            # Nose wings (31-35)
            31: 151,  # right wing
            32: 134,  # right nostril
            33: 2,    # nose tip
            34: 363,  # left nostril
            35: 378,  # left wing

            # Right eye (36-41)
            36: 33,   # outer corner
            37: 7,    # upper lid
            38: 163,  # upper lid
            39: 144,  # inner corner
            40: 145,  # lower lid
            41: 153,  # lower lid

            # Left eye (42-47)
            42: 362,  # inner corner
            43: 382,  # upper lid
            44: 381,  # upper lid
            45: 380,  # outer corner
            46: 374,  # lower lid
            47: 373,  # lower lid

            # Mouth outline (48-67)
            48: 78,   # right mouth corner
            49: 95,   # right upper lip
            50: 88,   # upper lip, right side
            51: 178,  # upper lip, center-right
            52: 87,   # upper lip center
            53: 14,   # upper lip, center-left
            54: 317,  # upper lip, left side
            55: 318,  # left upper lip
            56: 308,  # left mouth corner
            57: 324,  # left lower lip
            58: 318,  # lower lip, left side
            59: 16,   # lower lip, center-left
            60: 17,   # lower lip center
            61: 18,   # lower lip, center-right
            62: 200,  # lower lip, right side
            63: 199,  # right lower lip
            64: 175,  # inner right mouth corner
            65: 84,   # inner upper lip, right
            66: 17,   # inner lower lip center
            67: 314,  # inner upper lip, left
        }

        # Convert each mapped landmark from normalized coords to pixels.
        points = []
        for i in range(68):
            if i in mediapipe_to_dlib_map:
                mp_idx = mediapipe_to_dlib_map[i]
                if mp_idx < len(face_landmarks.landmark):
                    landmark = face_landmarks.landmark[mp_idx]
                    x = int(landmark.x * w)
                    y = int(landmark.y * h)
                    points.append((x, y))
                else:
                    # Index out of range: fall back to image center.
                    points.append((w//2, h//2))
            else:
                # No mapping for this slot: fall back to image center.
                points.append((w//2, h//2))

        return points
183
+
184
    def _analyze_features_from_landmarks(
        self, landmarks: List[tuple], image_shape: tuple
    ) -> Dict[str, Any]:
        """Score each facial region from a 68-point landmark list.

        :param landmarks: 68 (x, y) tuples in dlib ordering.
        :param image_shape: face image shape (h, w[, c]).
        :return: dict with per-feature scores, harmony score, overall score
            and an ``analysis_method`` tag.
        """
        try:
            # Region slices in dlib 68-point indexing.
            # NOTE(review): the dlib convention puts the subject's RIGHT
            # eye/brow at 36-41 / 17-21; the left/right names here follow the
            # converter's labels instead — verify before using sided results.
            jawline = landmarks[0:17]          # jawline
            left_eyebrow = landmarks[17:22]    # eyebrow slots 17-21
            right_eyebrow = landmarks[22:27]   # eyebrow slots 22-26
            nose = landmarks[27:36]            # nose
            left_eye = landmarks[36:42]        # eye slots 36-41
            right_eye = landmarks[42:48]       # eye slots 42-47
            mouth = landmarks[48:68]           # mouth

            # Per-region scores (simplified heuristics).
            scores = {
                "eyes": self._score_eyes(left_eye, right_eye, image_shape),
                "nose": self._score_nose(nose, image_shape),
                "mouth": self._score_mouth(mouth, image_shape),
                "eyebrows": self._score_eyebrows(
                    left_eyebrow, right_eyebrow, image_shape
                ),
                "jawline": self._score_jawline(jawline, image_shape),
            }

            # Overall harmony, then a mild configurable uplift
            # (same strategy as the beauty score boost).
            harmony_score = self._calculate_harmony_new(landmarks, image_shape)
            harmony_score = self._adjust_harmony_score(harmony_score)

            return {
                "facial_features": scores,
                "harmony_score": round(harmony_score, 2),
                "overall_facial_score": round(sum(scores.values()) / len(scores), 2),
                "analysis_method": "mediapipe_landmarks",
            }

        except Exception as e:
            logger.error(f"Landmark analysis failed: {e}")
            return self._basic_facial_analysis(None)
224
+
225
+ def _adjust_harmony_score(self, score: float) -> float:
226
+ """整体协调性分值温和拉升:当低于阈值时往阈值靠拢一点。"""
227
+ try:
228
+ if not getattr(config, "HARMONY_ADJUST_ENABLED", False):
229
+ return round(float(score), 2)
230
+ thr = float(getattr(config, "HARMONY_ADJUST_THRESHOLD", 8.0))
231
+ gamma = float(getattr(config, "HARMONY_ADJUST_GAMMA", 0.5))
232
+ gamma = max(0.0001, min(1.0, gamma))
233
+ s = float(score)
234
+ if s < thr:
235
+ s = thr - gamma * (thr - s)
236
+ return round(min(10.0, max(0.0, s)), 2)
237
+ except Exception:
238
+ try:
239
+ return round(float(score), 2)
240
+ except Exception:
241
+ return 6.21
242
+
243
+ def _score_eyes(
244
+ self, left_eye: List[tuple], right_eye: List[tuple], image_shape: tuple
245
+ ) -> float:
246
+ """眼部评分"""
247
+ try:
248
+ # 计算眼部对称性和大小
249
+ left_width = abs(left_eye[3][0] - left_eye[0][0])
250
+ right_width = abs(right_eye[3][0] - right_eye[0][0])
251
+
252
+ # 计算眼部高度
253
+ left_height = abs(left_eye[1][1] - left_eye[5][1])
254
+ right_height = abs(right_eye[1][1] - right_eye[5][1])
255
+
256
+ # 对称性评分 - 宽度对称性
257
+ width_symmetry = 1 - min(
258
+ abs(left_width - right_width) / max(left_width, right_width), 0.5
259
+ )
260
+
261
+ # 高度对称性
262
+ height_symmetry = 1 - min(
263
+ abs(left_height - right_height) / max(left_height, right_height), 0.5
264
+ )
265
+
266
+ # 大小适中性评分 (相对于脸部宽度) - 调整理想比例
267
+ avg_eye_width = (left_width + right_width) / 2
268
+ face_width = image_shape[1]
269
+ ideal_ratio = 0.08 # 调整理想比例,原来0.15太大
270
+ size_score = max(
271
+ 0, 1 - abs(avg_eye_width / face_width - ideal_ratio) / ideal_ratio
272
+ )
273
+
274
+ # 眼部长宽比评分
275
+ avg_eye_height = (left_height + right_height) / 2
276
+ aspect_ratio = avg_eye_width / max(avg_eye_height, 1) # 避免除零
277
+ ideal_aspect = 3.0 # 理想长宽比
278
+ aspect_score = max(0, 1 - abs(aspect_ratio - ideal_aspect) / ideal_aspect)
279
+
280
+ final_score = (
281
+ width_symmetry * 0.3
282
+ + height_symmetry * 0.3
283
+ + size_score * 0.25
284
+ + aspect_score * 0.15
285
+ ) * 10
286
+ return round(max(0, min(10, final_score)), 2)
287
+ except:
288
+ return 6.21
289
+
290
+ def _score_nose(self, nose: List[tuple], image_shape: tuple) -> float:
291
+ """鼻部评分"""
292
+ try:
293
+ # 鼻子关键点
294
+ nose_tip = nose[3] # 鼻尖
295
+ nose_bridge_top = nose[0] # 鼻梁顶部
296
+ left_nostril = nose[1]
297
+ right_nostril = nose[5]
298
+
299
+ # 计算鼻子的直线度 (鼻梁是否挺直)
300
+ straightness = 1 - min(
301
+ abs(nose_tip[0] - nose_bridge_top[0]) / (image_shape[1] * 0.1), 1.0
302
+ )
303
+
304
+ # 鼻宽评分 - 使用鼻翼宽度
305
+ nose_width = abs(right_nostril[0] - left_nostril[0])
306
+ face_width = image_shape[1]
307
+ ideal_nose_ratio = 0.06 # 调整理想比例
308
+ width_score = max(
309
+ 0,
310
+ 1 - abs(nose_width / face_width - ideal_nose_ratio) / ideal_nose_ratio,
311
+ )
312
+
313
+ # 鼻子长度评分
314
+ nose_length = abs(nose_tip[1] - nose_bridge_top[1])
315
+ face_height = image_shape[0]
316
+ ideal_length_ratio = 0.08
317
+ length_score = max(
318
+ 0,
319
+ 1
320
+ - abs(nose_length / face_height - ideal_length_ratio)
321
+ / ideal_length_ratio,
322
+ )
323
+
324
+ final_score = (
325
+ straightness * 0.4 + width_score * 0.35 + length_score * 0.25
326
+ ) * 10
327
+ return round(max(0, min(10, final_score)), 2)
328
+ except:
329
+ return 6.21
330
+
331
+ def _score_mouth(self, mouth: List[tuple], image_shape: tuple) -> float:
332
+ """嘴部评分 - 大幅优化,更宽松的评分标准"""
333
+ try:
334
+ # 嘴角点
335
+ left_corner = mouth[0] # 左嘴角
336
+ right_corner = mouth[6] # 右嘴角
337
+
338
+ # 上唇和下唇中心点
339
+ upper_lip_center = mouth[3] # 上唇中心
340
+ lower_lip_center = mouth[9] # 下唇中心
341
+
342
+ # 基础分数,避免过低
343
+ base_score = 6.0
344
+
345
+ # 1. 嘴宽评分 - 更宽松的标准
346
+ mouth_width = abs(right_corner[0] - left_corner[0])
347
+ face_width = image_shape[1]
348
+ mouth_ratio = mouth_width / face_width
349
+
350
+ # 设置更宽的合理范围 (0.04-0.15)
351
+ if 0.04 <= mouth_ratio <= 0.15:
352
+ width_score = 1.0 # 在合理范围内就给满分
353
+ elif mouth_ratio < 0.04:
354
+ width_score = max(0.3, mouth_ratio / 0.04) # 太小时渐减
355
+ else:
356
+ width_score = max(0.3, 0.15 / mouth_ratio) # 太大时渐减
357
+
358
+ # 2. 唇厚度评分 - 简化并放宽标准
359
+ lip_thickness = abs(lower_lip_center[1] - upper_lip_center[1])
360
+ # 只要厚度不是极端值就给高分
361
+ if lip_thickness > 3: # 像素值,有一定厚度
362
+ thickness_score = min(1.0, lip_thickness / 25) # 25像素为满分
363
+ else:
364
+ thickness_score = 0.5 # 太薄给中等分数
365
+
366
+ # 3. 嘴部对称性评分 - 更宽松
367
+ mouth_center_x = (left_corner[0] + right_corner[0]) / 2
368
+ face_center_x = image_shape[1] / 2
369
+ center_deviation = abs(mouth_center_x - face_center_x) / face_width
370
+
371
+ if center_deviation < 0.02: # 偏差小于2%
372
+ symmetry_score = 1.0
373
+ elif center_deviation < 0.05: # 偏差小于5%
374
+ symmetry_score = 0.8
375
+ else:
376
+ symmetry_score = max(0.5, 1 - center_deviation * 10) # 最低0.5分
377
+
378
+ # 4. 嘴唇形状评分 - 简化
379
+ # 检查嘴角是否在合理位置
380
+ corner_height_diff = abs(left_corner[1] - right_corner[1])
381
+ if corner_height_diff < face_width * 0.02: # 嘴角高度差异小
382
+ shape_score = 1.0
383
+ else:
384
+ shape_score = max(0.6, 1 - corner_height_diff / (face_width * 0.02))
385
+
386
+ # 5. 综合评分 - 调整权重,给基础分更大权重
387
+ feature_score = (
388
+ width_score * 0.3
389
+ + thickness_score * 0.25
390
+ + symmetry_score * 0.25
391
+ + shape_score * 0.2
392
+ )
393
+
394
+ # 最终分数 = 基础分 + 特征分奖励
395
+ final_score = base_score + feature_score * 4 # 最高10分
396
+
397
+ return round(max(4.0, min(10, final_score)), 2) # 最低4分,最高10分
398
+ except Exception as e:
399
+ return 6.21
400
+
401
+ def _score_eyebrows(
402
+ self, left_brow: List[tuple], right_brow: List[tuple], image_shape: tuple
403
+ ) -> float:
404
+ """眉毛评分 - 改进算法"""
405
+ try:
406
+ # 计算眉毛长度
407
+ left_length = abs(left_brow[-1][0] - left_brow[0][0])
408
+ right_length = abs(right_brow[-1][0] - right_brow[0][0])
409
+
410
+ # 长度对称性
411
+ length_symmetry = 1 - min(
412
+ abs(left_length - right_length) / max(left_length, right_length), 0.5
413
+ )
414
+
415
+ # 计算眉毛拱形 - 改进方法
416
+ left_peak_y = min([p[1] for p in left_brow]) # 眉峰(y坐标最小)
417
+ left_ends_y = (left_brow[0][1] + left_brow[-1][1]) / 2 # 眉毛两端平均高度
418
+ left_arch = max(0, left_ends_y - left_peak_y) # 拱形高度
419
+
420
+ right_peak_y = min([p[1] for p in right_brow])
421
+ right_ends_y = (right_brow[0][1] + right_brow[-1][1]) / 2
422
+ right_arch = max(0, right_ends_y - right_peak_y)
423
+
424
+ # 拱形对称性
425
+ arch_symmetry = 1 - min(
426
+ abs(left_arch - right_arch) / max(left_arch, right_arch, 1), 0.5
427
+ )
428
+
429
+ # 眉形适中性评分
430
+ avg_arch = (left_arch + right_arch) / 2
431
+ face_height = image_shape[0]
432
+ ideal_arch_ratio = 0.015 # 理想拱形比例
433
+ arch_ratio = avg_arch / face_height
434
+ arch_score = max(
435
+ 0, 1 - abs(arch_ratio - ideal_arch_ratio) / ideal_arch_ratio
436
+ )
437
+
438
+ # 眉毛浓密度(通过点的密集程度估算)
439
+ density_score = min(1.0, (len(left_brow) + len(right_brow)) / 10)
440
+
441
+ final_score = (
442
+ length_symmetry * 0.3
443
+ + arch_symmetry * 0.3
444
+ + arch_score * 0.25
445
+ + density_score * 0.15
446
+ ) * 10
447
+ return round(max(0, min(10, final_score)), 2)
448
+ except:
449
+ return 6.21
450
+
451
+ def _score_jawline(self, jawline: List[tuple], image_shape: tuple) -> float:
452
+ """下颌线评分 - 改进算法"""
453
+ try:
454
+ jaw_points = [(p[0], p[1]) for p in jawline]
455
+
456
+ # 关键点
457
+ left_jaw = jaw_points[2] # 左下颌角
458
+ jaw_tip = jaw_points[8] # 下巴尖
459
+ right_jaw = jaw_points[14] # 右下颌角
460
+
461
+ # 对称性评分 - 改进计算
462
+ left_dist = (
463
+ (left_jaw[0] - jaw_tip[0]) ** 2 + (left_jaw[1] - jaw_tip[1]) ** 2
464
+ ) ** 0.5
465
+ right_dist = (
466
+ (right_jaw[0] - jaw_tip[0]) ** 2 + (right_jaw[1] - jaw_tip[1]) ** 2
467
+ ) ** 0.5
468
+ symmetry = 1 - min(
469
+ abs(left_dist - right_dist) / max(left_dist, right_dist), 0.5
470
+ )
471
+
472
+ # 下颌角度评分
473
+ left_angle_y = abs(left_jaw[1] - jaw_tip[1])
474
+ right_angle_y = abs(right_jaw[1] - jaw_tip[1])
475
+ avg_angle = (left_angle_y + right_angle_y) / 2
476
+
477
+ # 理想的下颌角度
478
+ face_height = image_shape[0]
479
+ ideal_angle_ratio = 0.08
480
+ angle_ratio = avg_angle / face_height
481
+ angle_score = max(
482
+ 0, 1 - abs(angle_ratio - ideal_angle_ratio) / ideal_angle_ratio
483
+ )
484
+
485
+ # 下颌线清晰度(通过点间距离变化评估)
486
+ smoothness_score = 0.8 # 简化处理,可以根据实际需要改进
487
+
488
+ final_score = (
489
+ symmetry * 0.4 + angle_score * 0.35 + smoothness_score * 0.25
490
+ ) * 10
491
+ return round(max(0, min(10, final_score)), 2)
492
+ except:
493
+ return 6.21
494
+
495
+ def _calculate_harmony(self, landmarks: List[tuple], image_shape: tuple) -> float:
496
+ """计算五官协调性"""
497
+ try:
498
+ # 黄金比例检测 (简化版)
499
+ face_height = max([p[1] for p in landmarks]) - min(
500
+ [p[1] for p in landmarks]
501
+ )
502
+ face_width = max([p[0] for p in landmarks]) - min([p[0] for p in landmarks])
503
+
504
+ # 理想比例约为1.618
505
+ ratio = face_height / face_width if face_width > 0 else 1
506
+ golden_ratio = 1.618
507
+ harmony = 1 - abs(ratio - golden_ratio) / golden_ratio
508
+
509
+ return max(0, min(10, harmony * 10))
510
+ except:
511
+ return 6.21
512
+
513
    def _calculate_harmony_new(
        self, landmarks: List[tuple], image_shape: tuple
    ) -> float:
        """
        Compute overall facial harmony (0-10) — optimized version.

        Weighted mix of several aesthetic metrics. Note the weights below are
        the ones actually applied (earlier comments claiming 20%/25%/… were
        out of date); "spacing" currently contributes 0.

        :param landmarks: 68 (x, y) landmark tuples (dlib ordering).
        :param image_shape: face image shape (unused directly here).
        :return: weighted score in [0, 10]; 6.21 fallback on error.
        """
        try:
            logger.info(f"face landmarks={len(landmarks)}")
            if len(landmarks) < 68:  # requires the full 68-point set
                return 6.21

            # As numpy array for vectorized measurements.
            points = np.array(landmarks)

            # 1. Base facial measurements.
            face_measurements = self._get_face_measurements(points)

            # 2. Individual harmony metrics as (name, score, weight).
            scores = []

            # Golden-ratio score (weight: 10%).
            golden_score = self._calculate_golden_ratios(face_measurements)
            logger.info(f"Golden ratio score={golden_score}")
            scores.append(("golden_ratio", golden_score, 0.10))

            # Symmetry score (weight: 40%).
            symmetry_score = self._calculate_facial_symmetry(face_measurements, points)
            logger.info(f"Symmetry score={symmetry_score}")
            scores.append(("symmetry", symmetry_score, 0.40))

            # "Three courts, five eyes" proportions (weight: 5%).
            proportion_score = self._calculate_classical_proportions(face_measurements)
            logger.info(f"Three courts five eyes ratio={proportion_score}")
            scores.append(("proportions", proportion_score, 0.05))

            # Feature spacing harmony (weight: 0 — currently disabled).
            spacing_score = self._calculate_feature_spacing(face_measurements)
            logger.info(f"Facial feature spacing harmony={spacing_score}")
            scores.append(("spacing", spacing_score, 0))

            # Contour harmony (weight: 5%).
            contour_score = self._calculate_contour_harmony(points)
            logger.info(f"Facial contour harmony={contour_score}")
            scores.append(("contour", contour_score, 0.05))

            # Eye/nose/mouth proportion harmony (weight: 40%).
            feature_score = self._calculate_feature_proportions(face_measurements)
            logger.info(f"Eye-nose-mouth proportion harmony={feature_score}")
            scores.append(("features", feature_score, 0.40))

            # Weighted average of all metrics.
            final_score = sum(score * weight for _, score, weight in scores)
            logger.info(f"Weighted average final score={final_score}")
            return max(0, min(10, final_score))

        except Exception as e:
            logger.error(f"Error calculating facial harmony: {e}")
            traceback.print_exc()  # full stack trace with exact line numbers
            return 6.21
573
+
574
    def _get_face_measurements(self, points: np.ndarray) -> Dict[str, float]:
        """Extract key facial measurements from a (68, 2) landmark array.

        Returns a dict of scalar widths/heights plus a few raw landmark
        values; note ``left_eye_center``, ``right_eye_center`` and
        ``nose_tip`` are arrays, not floats, despite the declared value type.
        """
        measurements = {}

        # Face contour points (0-16)
        face_contour = points[0:17]

        # Eyebrow points (17-26) — currently unused below.
        left_eyebrow = points[17:22]
        right_eyebrow = points[22:27]

        # Eye points (36-47)
        left_eye = points[36:42]
        right_eye = points[42:48]

        # Nose points (27-35)
        nose = points[27:36]

        # Mouth points (48-67)
        mouth = points[48:68]

        # Base extents.
        measurements["face_width"] = np.max(face_contour[:, 0]) - np.min(
            face_contour[:, 0]
        )
        measurements["face_height"] = np.max(points[:, 1]) - np.min(points[:, 1])

        # Eye measurements.
        measurements["left_eye_width"] = np.max(left_eye[:, 0]) - np.min(left_eye[:, 0])
        measurements["right_eye_width"] = np.max(right_eye[:, 0]) - np.min(
            right_eye[:, 0]
        )
        measurements["eye_distance"] = np.min(right_eye[:, 0]) - np.max(left_eye[:, 0])
        measurements["left_eye_center"] = np.mean(left_eye, axis=0)
        measurements["right_eye_center"] = np.mean(right_eye, axis=0)

        # Nose measurements.
        measurements["nose_width"] = np.max(nose[:, 0]) - np.min(nose[:, 0])
        measurements["nose_height"] = np.max(nose[:, 1]) - np.min(nose[:, 1])
        measurements["nose_tip"] = points[33]  # nose tip landmark

        # Mouth measurements.
        measurements["mouth_width"] = np.max(mouth[:, 0]) - np.min(mouth[:, 0])
        measurements["mouth_height"] = np.max(mouth[:, 1]) - np.min(mouth[:, 1])

        # Key vertical distances (the three facial "courts").
        measurements["forehead_height"] = measurements["left_eye_center"][1] - np.min(
            points[:, 1]
        )
        measurements["middle_face_height"] = (
            measurements["nose_tip"][1] - measurements["left_eye_center"][1]
        )
        measurements["lower_face_height"] = (
            np.max(points[:, 1]) - measurements["nose_tip"][1]
        )

        return measurements
631
+
632
+ def _calculate_golden_ratios(self, measurements: Dict[str, float]) -> float:
633
+ """计算黄金比例相关得分"""
634
+ golden_ratio = 1.618
635
+ scores = []
636
+
637
+ # 面部长宽比
638
+ if measurements["face_width"] > 0:
639
+ face_ratio = measurements["face_height"] / measurements["face_width"]
640
+ score = 1 - abs(face_ratio - golden_ratio) / golden_ratio
641
+ scores.append(max(0, score))
642
+
643
+ # 上中下三庭比例
644
+ total_height = (
645
+ measurements["forehead_height"]
646
+ + measurements["middle_face_height"]
647
+ + measurements["lower_face_height"]
648
+ )
649
+
650
+ if total_height > 0:
651
+ upper_ratio = measurements["forehead_height"] / total_height
652
+ middle_ratio = measurements["middle_face_height"] / total_height
653
+ lower_ratio = measurements["lower_face_height"] / total_height
654
+
655
+ # 理想比例约为 1:1:1
656
+ ideal_ratio = 1 / 3
657
+ upper_score = 1 - abs(upper_ratio - ideal_ratio) / ideal_ratio
658
+ middle_score = 1 - abs(middle_ratio - ideal_ratio) / ideal_ratio
659
+ lower_score = 1 - abs(lower_ratio - ideal_ratio) / ideal_ratio
660
+
661
+ scores.extend(
662
+ [max(0, upper_score), max(0, middle_score), max(0, lower_score)]
663
+ )
664
+
665
+ return np.mean(scores) * 10 if scores else 7.0
666
+
667
    def _calculate_facial_symmetry(
        self, measurements: Dict[str, float], points: np.ndarray
    ) -> float:
        """Score left/right facial symmetry (0-10).

        Compares mirrored landmark pairs against the mean-x "midline";
        combines horizontal distance balance with vertical alignment.
        Returns a neutral 7.0 when no pairs could be evaluated.
        """
        # Face midline approximated by the mean x of all landmarks.
        face_center_x = np.mean(points[:, 0])

        # Mirrored landmark index pairs (dlib 68-point slots).
        symmetry_pairs = [
            (17, 26),  # eyebrow outer ends
            (18, 25),  # eyebrows
            (19, 24),  # eyebrows
            (36, 45),  # eye corners
            (39, 42),  # eye corners
            (31, 35),  # nose wings
            (48, 54),  # mouth corners
            (4, 12),   # face contour
            (5, 11),   # face contour
            (6, 10),   # face contour
        ]

        symmetry_scores = []

        for left_idx, right_idx in symmetry_pairs:
            if left_idx < len(points) and right_idx < len(points):
                left_point = points[left_idx]
                right_point = points[right_idx]

                # Horizontal distance of each point to the midline.
                left_dist = abs(left_point[0] - face_center_x)
                right_dist = abs(right_point[0] - face_center_x)

                # Vertical misalignment of the pair.
                vertical_diff = abs(left_point[1] - right_point[1])

                # Combined per-pair symmetry score.
                if left_dist + right_dist > 0:
                    horizontal_symmetry = 1 - abs(left_dist - right_dist) / (
                        left_dist + right_dist
                    )
                    vertical_symmetry = 1 - vertical_diff / measurements.get(
                        "face_height", 100
                    )

                    symmetry_scores.append(
                        (horizontal_symmetry + vertical_symmetry) / 2
                    )

        return np.mean(symmetry_scores) * 10 if symmetry_scores else 7.0
716
+
717
+ def _calculate_classical_proportions(self, measurements: Dict[str, float]) -> float:
718
+ """计算经典美学比例 (三庭五眼等)"""
719
+ scores = []
720
+
721
+ # 五眼比例检测
722
+ if measurements["face_width"] > 0:
723
+ eye_width_avg = (
724
+ measurements["left_eye_width"] + measurements["right_eye_width"]
725
+ ) / 2
726
+ ideal_eye_count = 5 # 理想情况下面宽应该等于5个眼宽
727
+ actual_eye_count = (
728
+ measurements["face_width"] / eye_width_avg if eye_width_avg > 0 else 5
729
+ )
730
+
731
+ eye_proportion_score = (
732
+ 1 - abs(actual_eye_count - ideal_eye_count) / ideal_eye_count
733
+ )
734
+ scores.append(max(0, eye_proportion_score))
735
+
736
+ # 眼间距比例
737
+ if measurements.get("left_eye_width", 0) > 0:
738
+ eye_spacing_ratio = (
739
+ measurements["eye_distance"] / measurements["left_eye_width"]
740
+ )
741
+ ideal_spacing_ratio = 1.0 # 理想情况下眼间距约等于一个眼宽
742
+
743
+ spacing_score = (
744
+ 1 - abs(eye_spacing_ratio - ideal_spacing_ratio) / ideal_spacing_ratio
745
+ )
746
+ scores.append(max(0, spacing_score))
747
+
748
+ # 鼻宽与眼宽比例
749
+ if (
750
+ measurements.get("left_eye_width", 0) > 0
751
+ and measurements.get("nose_width", 0) > 0
752
+ ):
753
+ nose_eye_ratio = measurements["nose_width"] / measurements["left_eye_width"]
754
+ ideal_nose_eye_ratio = 0.8 # 理想鼻宽约为眼宽的80%
755
+
756
+ nose_score = (
757
+ 1 - abs(nose_eye_ratio - ideal_nose_eye_ratio) / ideal_nose_eye_ratio
758
+ )
759
+ scores.append(max(0, nose_score))
760
+
761
+ return np.mean(scores) * 10 if scores else 7.0
762
+
763
    def _calculate_feature_spacing(self, measurements: Dict[str, float]) -> float:
        """Score vertical spacing between facial features (0-10).

        Currently weighted at 0 in `_calculate_harmony_new`, so this value
        does not affect the final harmony score.
        """
        scores = []

        # Eye-to-nose vertical distance relative to face height.
        eye_nose_distance = abs(
            measurements["left_eye_center"][1] - measurements["nose_tip"][1]
        )
        if measurements.get("face_height", 0) > 0:
            eye_nose_ratio = eye_nose_distance / measurements["face_height"]
            ideal_ratio = 0.15  # ideal proportion
            score = 1 - abs(eye_nose_ratio - ideal_ratio) / ideal_ratio
            scores.append(max(0, score))

        # Nose-to-mouth distance.
        # NOTE(review): np.mean([...]) of a single value is just mouth_height,
        # i.e. a *size*, not a mouth-center y-coordinate — so this "distance"
        # subtracts a length from a coordinate. Confirm intent before relying
        # on this metric.
        nose_mouth_distance = abs(
            measurements["nose_tip"][1] - np.mean([measurements.get("mouth_height", 0)])
        )
        if measurements.get("face_height", 0) > 0:
            nose_mouth_ratio = nose_mouth_distance / measurements["face_height"]
            ideal_ratio = 0.12  # ideal proportion
            score = 1 - abs(nose_mouth_ratio - ideal_ratio) / ideal_ratio
            scores.append(max(0, score))

        return np.mean(scores) * 10 if scores else 7.0
788
+
789
+ def _calculate_contour_harmony(self, points: np.ndarray) -> float:
790
+ """计算面部轮廓协调性"""
791
+ try:
792
+ face_contour = points[0:17] # 面部轮廓点
793
+
794
+ # 计算轮廓的平滑度
795
+ smoothness_scores = []
796
+
797
+ for i in range(1, len(face_contour) - 1):
798
+ # 计算相邻三点形成的角度
799
+ p1, p2, p3 = face_contour[i - 1], face_contour[i], face_contour[i + 1]
800
+
801
+ v1 = p1 - p2
802
+ v2 = p3 - p2
803
+
804
+ # 计算角度
805
+ cos_angle = np.dot(v1, v2) / (
806
+ np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-8
807
+ )
808
+ angle = np.arccos(np.clip(cos_angle, -1, 1))
809
+
810
+ # 角度越接近平滑曲线越好 (避免过于尖锐的角度)
811
+ smoothness = 1 - abs(angle - np.pi / 2) / (np.pi / 2)
812
+ smoothness_scores.append(max(0, smoothness))
813
+
814
+ return np.mean(smoothness_scores) * 10 if smoothness_scores else 7.0
815
+
816
+ except:
817
+ return 6.21
818
+
819
    def _calculate_feature_proportions(self, measurements: Dict[str, float]) -> float:
        """Score internal eye/nose/mouth proportions (0-10).

        Returns a neutral 7.0 when no ratio could be computed.
        """
        scores = []

        # Eye aspect ratios.
        # NOTE(review): these expressions reduce to width / (width * 0.3),
        # i.e. a constant ~3.33 independent of the face — the intended eye
        # *height* is not available in `measurements`. The two eye scores are
        # therefore effectively fixed; confirm before trusting this metric.
        left_eye_ratio = measurements.get("left_eye_width", 1) / max(
            measurements.get("left_eye_width", 1) * 0.3, 1
        )
        right_eye_ratio = measurements.get("right_eye_width", 1) / max(
            measurements.get("right_eye_width", 1) * 0.3, 1
        )

        # Ideal eye width/height ratio ~3:1.
        ideal_eye_ratio = 3.0
        left_eye_score = 1 - abs(left_eye_ratio - ideal_eye_ratio) / ideal_eye_ratio
        right_eye_score = 1 - abs(right_eye_ratio - ideal_eye_ratio) / ideal_eye_ratio

        scores.extend([max(0, left_eye_score), max(0, right_eye_score)])

        # Mouth width/height ratio (ideal ~3.5).
        if measurements.get("mouth_height", 0) > 0:
            mouth_ratio = measurements["mouth_width"] / measurements["mouth_height"]
            ideal_mouth_ratio = 3.5
            mouth_score = 1 - abs(mouth_ratio - ideal_mouth_ratio) / ideal_mouth_ratio
            scores.append(max(0, mouth_score))

        # Nose height/width ratio (ideal ~1.5).
        if measurements.get("nose_height", 0) > 0:
            nose_ratio = measurements["nose_height"] / measurements["nose_width"]
            ideal_nose_ratio = 1.5
            nose_score = 1 - abs(nose_ratio - ideal_nose_ratio) / ideal_nose_ratio
            scores.append(max(0, nose_score))

        return np.mean(scores) * 10 if scores else 7.0
853
+
854
+ def _basic_facial_analysis(self, face_image) -> Dict[str, Any]:
855
+ """基础五官分析 (当dlib不可用时)"""
856
+ return {
857
+ "facial_features": {
858
+ "eyes": 7.0,
859
+ "nose": 7.0,
860
+ "mouth": 7.0,
861
+ "eyebrows": 7.0,
862
+ "jawline": 7.0,
863
+ },
864
+ "harmony_score": 7.0,
865
+ "overall_facial_score": 7.0,
866
+ "analysis_method": "basic_estimation",
867
+ }
868
+
869
    def draw_facial_landmarks(self, face_image: np.ndarray) -> np.ndarray:
        """
        Draw detected landmark markers onto a copy of the face image.

        :param face_image: face image (BGR, OpenCV convention)
        :return: annotated copy of the image; an unmodified copy when the
            detector is unavailable, no landmarks are found, or drawing fails
        """
        if not DLIB_AVAILABLE or self.face_mesh is None:
            # No face-mesh detector available: return an untouched copy.
            return face_image.copy()

        try:
            # Draw on a copy so the caller's image stays untouched.
            annotated_image = face_image.copy()

            # MediaPipe expects RGB input.
            rgb_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)

            # Detect landmarks.
            results = self.face_mesh.process(rgb_image)

            if not results.multi_face_landmarks:
                logger.warning("No facial landmarks detected for drawing")
                return annotated_image

            # Only the first face is drawn (max_num_faces=1 at init).
            face_landmarks = results.multi_face_landmarks[0]

            # Draw every landmark as a dot plus a small crosshair.
            h, w = face_image.shape[:2]
            for landmark in face_landmarks.landmark:
                x = int(landmark.x * w)
                y = int(landmark.y * h)
                # Dot marking the landmark.
                cv2.circle(annotated_image, (x, y), 1, (0, 255, 0), -1)

                # Crosshair overlay.
                cv2.line(annotated_image, (x-2, y), (x+2, y), (0, 255, 0), 1)
                cv2.line(annotated_image, (x, y-2), (x, y+2), (0, 255, 0), 1)

            return annotated_image

        except Exception as e:
            logger.error(f"Failed to draw facial landmarks: {e}")
            return face_image.copy()
gfpgan_restorer.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ from config import logger, MODELS_PATH
5
+ from gfpgan import GFPGANer
6
+
7
+
8
class GFPGANRestorer:
    """Wrapper around GFPGAN for face/photo restoration.

    ``self.restorer`` is ``None`` when initialization failed; callers should
    check ``is_available()`` before calling ``restore_image()``.
    """

    def __init__(self):
        start_time = time.perf_counter()
        self.restorer = None  # GFPGANer instance, or None when unavailable
        self._initialize_model()
        init_time = time.perf_counter() - start_time
        if self.restorer is not None:
            logger.info(f"GFPGANRestorer initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"GFPGANRestorer initialization completed but not available, time: {init_time:.3f}s")

    def _initialize_model(self):
        """Initialize the GFPGAN model, trying several local paths first."""
        try:
            # Candidate model locations, checked in order.
            possible_paths = [
                f"{MODELS_PATH}/GFPGANv1.4.pth",
                f"{MODELS_PATH}/gfpgan/GFPGANv1.4.pth",
                os.path.expanduser("~/.cache/gfpgan/GFPGANv1.4.pth"),
                "./models/GFPGANv1.4.pth"
            ]

            gfpgan_model_path = None
            for path in possible_paths:
                if os.path.exists(path):
                    gfpgan_model_path = path
                    break

            if not gfpgan_model_path:
                logger.warning(f"GFPGAN model file not found, tried paths: {possible_paths}")
                logger.info("Will try to download GFPGAN model from the internet...")
                # Fall back to a URL so GFPGAN downloads the weights itself.
                gfpgan_model_path = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'

            logger.info(f"Using GFPGAN model: {gfpgan_model_path}")

            # Initialize GFPGAN: 2x upscale, "clean" architecture, no
            # separate background upsampler.
            self.restorer = GFPGANer(
                model_path=gfpgan_model_path,
                upscale=2,
                arch='clean',
                channel_multiplier=2,
                bg_upsampler=None
            )
            logger.info("GFPGAN model initialized successfully")

        except Exception as e:
            logger.error(f"GFPGAN model initialization failed: {e}")
            self.restorer = None


    def is_available(self):
        """Return True when the GFPGAN model is loaded and usable."""
        return self.restorer is not None

    def restore_image(self, image):
        """
        Restore an old/degraded photo with GFPGAN.

        :param image: input image (numpy array, BGR)
        :return: restored image (numpy array, BGR); the original image is
            returned unchanged when restoration fails or yields no output
        :raises Exception: when the model was never initialized
        """
        if not self.is_available():
            raise Exception("GFPGAN模型未初始化")

        try:
            logger.info("Starting GFPGAN image restoration...")

            # GFPGAN processing options:
            # has_aligned=False: input faces are not pre-aligned
            # only_center_face=False: process every detected face
            # paste_back=True: paste restored faces back into the image
            cropped_faces, restored_faces, restored_img = self.restorer.enhance(
                image,
                has_aligned=False,
                only_center_face=False,
                paste_back=True
            )

            if restored_img is not None:
                logger.info(f"GFPGAN restoration completed, detected {len(restored_faces)} faces")
                return restored_img
            else:
                logger.warning("GFPGAN restoration returned empty image, using original image")
                return image

        except Exception as e:
            logger.error(f"GFPGAN image restoration failed: {e}")
            # On failure, return the original image instead of raising.
            return image
install.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # pip install -r requirements.txt -i https://pypi.python.org/simple
2
+ pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
models.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import List, Optional
3
+
4
+ from pydantic import BaseModel
5
+
6
+
7
class ModelType(str, Enum):
    """Model type selector for the analysis backends."""

    HOWCUTEAMI = "howcuteami"
    DEEPFACE = "deepface"
    HYBRID = "hybrid"  # hybrid: beauty/gender via howcuteami, age/emotion via deepface


class ImageScoreItem(BaseModel):
    """A single scored image entry returned by listing/search endpoints."""

    file_path: str
    score: float
    is_cropped_face: bool = False
    size_bytes: int
    size_str: str
    last_modified: str
    nickname: Optional[str] = None


class SearchRequest(BaseModel):
    """Keyword search request body."""

    keyword: Optional[str] = ""
    searchType: Optional[str] = "face"
    top_k: Optional[int] = 5
    score_threshold: float = 0.0
    nickname: Optional[str] = None


class ImageSearchRequest(BaseModel):
    """Query-by-example search request body."""

    image: Optional[str] = None  # base64-encoded image
    searchType: Optional[str] = "face"
    top_k: Optional[int] = 5
    score_threshold: float = 0.0
    nickname: Optional[str] = None


class ImageFileList(BaseModel):
    """Flat list of scored image entries."""

    results: List[ImageScoreItem]
    count: int

class PagedImageFileList(BaseModel):
    """Paginated list of scored image entries."""

    results: List[ImageScoreItem]
    count: int
    page: int
    page_size: int
    total_pages: int

class CelebrityMatchResponse(BaseModel):
    """One celebrity similarity match."""

    filename: str
    display_name: Optional[str] = None
    distance: float
    similarity: float
    confidence: float
    face_filename: Optional[str] = None


class CategoryStatItem(BaseModel):
    """Per-category count entry."""

    category: str
    display_name: str
    count: int


class CategoryStatsResponse(BaseModel):
    """Category statistics summary."""

    stats: List[CategoryStatItem]
    total: int
push.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #!/bin/bash
2
+ git push -f origin main
realesrgan_upscaler.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from config import logger, MODELS_PATH, REALESRGAN_MODEL
8
+
9
# Optional-dependency guard: Real-ESRGAN (and torch) may be missing on some
# deployments; REALESRGAN_AVAILABLE lets the rest of the module degrade
# gracefully instead of failing at import time.
try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from basicsr.utils.download_util import load_file_from_url
    from realesrgan import RealESRGANer
    from realesrgan.archs.srvgg_arch import SRVGGNetCompact
    import torch

    # PyTorch CPU tuning
    torch.set_num_threads(min(torch.get_num_threads(), 8))  # cap intra-op threads
    torch.set_num_interop_threads(min(4, torch.get_num_interop_threads()))  # cap inter-op threads

    REALESRGAN_AVAILABLE = True
    logger.info("Real-ESRGAN imported successfully")
except ImportError as e:
    logger.error(f"Real-ESRGAN import failed: {e}")
    REALESRGAN_AVAILABLE = False
25
+
26
+
27
class RealESRGANUpscaler:
    """Real-ESRGAN super-resolution processor (CPU-oriented configuration).

    The instance is always constructible: when the model cannot be loaded,
    ``is_available()`` returns False and ``upscale_image`` raises.
    """

    def __init__(self):
        """Build the upsampler and log the initialization time."""
        start_time = time.perf_counter()
        self.upsampler = None        # RealESRGANer instance, None while unavailable
        self.model_name = None       # model selected via REALESRGAN_MODEL
        self.scale = 4               # default output scale factor
        self.denoise_strength = 0.5  # only meaningful for realesr-general-x4v3
        self._initialize()
        init_time = time.perf_counter() - start_time
        if self.upsampler is not None:
            logger.info(f"RealESRGANUpscaler initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"RealESRGANUpscaler initialization completed but not available, time: {init_time:.3f}s")

    def _initialize(self):
        """Load the model named by REALESRGAN_MODEL and build the upsampler.

        Any failure is logged and leaves ``self.upsampler`` as None.
        """
        if not REALESRGAN_AVAILABLE:
            logger.error("Real-ESRGAN is not available, cannot initialize super resolution processor")
            return

        try:
            # Model name comes from the environment (config.REALESRGAN_MODEL)
            model_name = REALESRGAN_MODEL
            self.model_name = model_name

            # Default output scale inferred from the model name
            if 'x2' in model_name:
                self.scale = 2
            elif 'x4' in model_name:
                self.scale = 4
            else:
                self.scale = 4  # fall back to 4x

            # Resolve architecture, native scale and download URL for the model
            if model_name == 'RealESRGAN_x4plus':
                model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
                netscale = 4
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
            elif model_name == 'RealESRNet_x4plus':
                model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
                netscale = 4
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth'
            elif model_name == 'RealESRGAN_x4plus_anime_6B':
                model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
                netscale = 4
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'
            elif model_name == 'RealESRGAN_x2plus':
                model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
                netscale = 2
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'
            elif model_name == 'realesr-animevideov3':
                model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
                netscale = 4
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth'
            elif model_name == 'realesr-general-x4v3':
                # Latest general-purpose model (v0.2.5.0)
                model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
                netscale = 4
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
            elif model_name == 'realesr-general-wdn-x4v3':
                # Latest general-purpose model with denoising (v0.2.5.0)
                model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
                netscale = 4
                file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth'
            else:
                # BUGFIX: an unknown name previously fell through and crashed
                # later with a NameError on `model`; fail with a clear message
                # instead (still caught by the except below, leaving the
                # processor unavailable exactly as before).
                raise ValueError(f"Unsupported Real-ESRGAN model name: {model_name}")

            # Make sure the model directory exists
            model_dir = os.path.join(MODELS_PATH, 'realesrgan')
            os.makedirs(model_dir, exist_ok=True)

            # Prefer a model file already present under MODELS_PATH; otherwise download
            model_filename = f"{model_name}.pth"
            local_pth = os.path.join(MODELS_PATH, model_filename)

            if os.path.exists(local_pth):
                model_path = local_pth
                logger.info(f"Using local model file: {model_path}")
            else:
                logger.info(f"Downloading model {model_name} from {file_url}")
                model_path = load_file_from_url(
                    url=file_url, model_dir=model_dir, progress=True, file_name=model_filename)

            # Build the upsampler
            self.upsampler = RealESRGANer(
                scale=netscale,
                model_path=model_path,
                model=model,
                tile=512,       # tiled processing lowers memory use on CPU
                tile_pad=10,
                pre_pad=0,
                half=False,     # fp32 precision
                gpu_id=None     # CPU only
            )

            logger.info(f"Real-ESRGAN super resolution processor initialized successfully, model: {model_name}")

        except Exception as e:
            logger.error(f"Failed to initialize Real-ESRGAN: {e}")
            self.upsampler = None

    def is_available(self):
        """Return True when the library imported and the upsampler was built."""
        return REALESRGAN_AVAILABLE and self.upsampler is not None

    def _optimize_input_image(self, image):
        """
        Normalize the input image for CPU processing.
        :param image: input image (numpy array)
        :return: uint8 3-channel image
        """
        # uint8 keeps downstream computation cheap
        if image.dtype != np.uint8:
            if image.dtype == np.float32 or image.dtype == np.float64:
                # assumes float images are scaled to [0, 1] — TODO confirm with callers
                image = (image * 255).astype(np.uint8)
            else:
                image = image.astype(np.uint8)

        # Ensure a 3-channel layout
        if len(image.shape) == 2:  # grayscale
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        elif image.shape[2] == 4:  # RGBA
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        # NOTE: a 3-channel RGB image cannot be distinguished from BGR by shape
        # alone; the original `elif shape[2] == 3 and shape[2] != 3` branch was
        # unreachable (always False) and has been removed.

        return image

    def upscale_image(self, image, scale=None, denoise_strength=None):
        """
        Super-resolve an image.
        :param image: input image (numpy array)
        :param scale: output scale factor; defaults to the model's scale
        :param denoise_strength: denoise strength in [0, 1]; only applied for
            the realesr-general-x4v3 model
        :return: the upscaled image
        :raises RuntimeError: when the processor is unavailable or enhancement fails
        """
        if not self.is_available():
            raise RuntimeError("Real-ESRGAN超清处理器不可用")

        try:
            start_time = time.perf_counter()

            # Pre-process / normalize the input
            image = self._optimize_input_image(image)

            # Denoise strength is only meaningful for this specific model
            if denoise_strength is not None and self.model_name == 'realesr-general-x4v3':
                self.denoise_strength = denoise_strength

            # Pick a tile size based on image area to balance CPU time and memory
            h, w = image.shape[:2]
            pixel_count = h * w

            if pixel_count > 2000000:    # > 2 MP
                tile_size = 256
            elif pixel_count > 1000000:  # > 1 MP
                tile_size = 384
            else:
                tile_size = 512

            # Apply the tile size if the upsampler exposes it
            if hasattr(self.upsampler, 'tile'):
                self.upsampler.tile = tile_size
                logger.info(f"Adjusting tile size to: {tile_size} based on image size ({w}x{h})")

            # Run the enhancement
            logger.info(f"Starting Real-ESRGAN super resolution processing, model: {self.model_name}")
            output, _ = self.upsampler.enhance(image, outscale=scale or self.scale)

            processing_time = time.perf_counter() - start_time
            logger.info(f"Real-ESRGAN super resolution processing completed, time: {processing_time:.3f}s")

            return output

        except Exception as e:
            logger.error(f"Real-ESRGAN super resolution processing failed: {e}")
            raise RuntimeError(f"超清处理失败: {str(e)}")

    def get_model_info(self):
        """Return the model name, scale and availability as a dict."""
        return {
            "model_name": self.model_name,
            "scale": self.scale,
            "available": self.is_available()
        }
220
+
221
+
222
# Global singleton instance (lazy-initialized by get_upscaler)
_upscaler_instance = None


def get_upscaler():
    """Return the process-wide RealESRGANUpscaler instance, creating it on first use.

    BUGFIX: the file previously defined ``get_upscaler`` twice — a first
    version returning a fresh (expensive) instance per call, immediately
    shadowed by this singleton version. The dead first definition has been
    removed; behavior is unchanged because only the singleton was ever callable.
    """
    global _upscaler_instance
    if _upscaler_instance is None:
        _upscaler_instance = RealESRGANUpscaler()
    return _upscaler_instance
rembg_processor.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import Optional, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from config import logger, REMBG_AVAILABLE
8
+
9
# rembg and PIL are optional: config.REMBG_AVAILABLE gates their import so the
# module still loads on deployments without the background-removal stack.
if REMBG_AVAILABLE:
    import rembg
    from rembg import new_session
    from PIL import Image
13
+
14
+
15
class RembgProcessor:
    """Background-removal (matting) processor backed by the rembg library."""

    def __init__(self):
        """Create a rembg session; stays unavailable (not an error) when rembg is missing."""
        start_time = time.perf_counter()
        self.session = None        # rembg inference session, None until initialized
        self.available = False     # True only after the session was created
        self.model_name = "u2net"  # default model, suited to portrait matting

        if REMBG_AVAILABLE:
            try:
                # Initialize the rembg session
                self.session = new_session(self.model_name)
                self.available = True
                logger.info(f"rembg background removal processor initialized successfully, using model: {self.model_name}")
            except Exception as e:
                logger.error(f"rembg background removal processor initialization failed: {e}")
                self.available = False
        else:
            logger.warning("rembg is not available, background removal function will be disabled")
        init_time = time.perf_counter() - start_time
        if self.available:
            logger.info(f"RembgProcessor initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"RembgProcessor initialization completed but not available, time: {init_time:.3f}s")

    def is_available(self) -> bool:
        """Return True when rembg is installed and a session exists."""
        return self.available and self.session is not None

    def remove_background(self, image: np.ndarray, background_color: Optional[Tuple[int, int, int]] = None) -> np.ndarray:
        """
        Remove the background from an image.
        :param image: input OpenCV image (BGR)
        :param background_color: replacement background color (BGR); None keeps transparency
        :return: BGR image when a color is given, BGRA (transparent) otherwise
        :raises Exception: when the processor is unavailable or matting fails
        """
        if not self.is_available():
            raise Exception("rembg抠图处理器不可用")

        try:
            # OpenCV (BGR) -> PIL (RGB)
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image_rgb)

            # Run the rembg matting model
            logger.info("Starting to remove background using rembg...")
            output_image = rembg.remove(pil_image, session=self.session)

            # Convert back to OpenCV layout
            if background_color is not None:
                # Solid background: composite the cutout over the requested color
                background = Image.new('RGB', output_image.size, background_color[::-1])  # BGR -> RGB
                # Paste the RGBA cutout using its own alpha channel as the mask
                background.paste(output_image, mask=output_image)
                result_array = np.array(background)
                result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)
            else:
                # Keep transparency: emit a 4-channel BGRA image
                result_array = np.array(output_image)
                if result_array.shape[2] == 4:  # RGBA
                    result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGBA2BGRA)
                else:  # RGB (no alpha produced)
                    result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)

            logger.info("rembg background removal completed")
            return result_bgr

        except Exception as e:
            logger.error(f"rembg background removal failed: {e}")
            raise Exception(f"背景移除失败: {str(e)}")

    def create_id_photo(self, image: np.ndarray, background_color: Tuple[int, int, int] = (255, 255, 255)) -> np.ndarray:
        """
        Create an ID photo (remove the background, then fill with a solid color).
        :param image: input OpenCV image
        :param background_color: fill color, default white (BGR)
        :return: the composited ID photo
        """
        logger.info(f"Starting to create ID photo, background color: {background_color}")

        # Remove the background and composite over the requested color
        id_photo = self.remove_background(image, background_color)

        logger.info("ID photo creation completed")
        return id_photo

    def get_supported_models(self) -> list:
        """Return the rembg model names this wrapper accepts ([] when rembg is absent)."""
        if not REMBG_AVAILABLE:
            return []

        # Subset of the models rembg ships that are relevant here
        return [
            "u2net",             # general-purpose model, good for portraits
            "u2net_human_seg",   # dedicated human-segmentation model
            "silueta",           # suited to object cut-outs
            "isnet-general-use"  # more precise general-purpose model
        ]

    def switch_model(self, model_name: str) -> bool:
        """
        Switch the active rembg model.
        :param model_name: target model name (must be in get_supported_models())
        :return: True on success, False otherwise (never raises)
        """
        if not REMBG_AVAILABLE:
            return False

        try:
            if model_name in self.get_supported_models():
                self.session = new_session(model_name)
                self.model_name = model_name
                logger.info(f"rembg model switched to: {model_name}")
                return True
            else:
                logger.error(f"Unsupported model: {model_name}")
                return False
        except Exception as e:
            logger.error(f"Failed to switch model: {e}")
            return False
requirements.txt ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 固定NumPy版本避免兼容性问题 - 必须最先安装
2
+ numpy>=1.24.0,<2.0.0
3
+
4
+ # 基础依赖
5
+ fastapi>=0.104.0
6
+ uvicorn[standard]>=0.24.0
7
+ python-multipart>=0.0.6
8
+ aiofiles>=23.2.1
9
+
10
+ # 图像处理
11
+ opencv-python>=4.8.0
12
+ Pillow>=10.0.0
13
+
14
+ # PyTorch 相关包 - 升级到2.x版本解决依赖冲突
15
+ torch>=2.0.0,<2.9.0
16
+ torchvision>=0.15.0
17
+
18
+ # 机器学习和CV相关
19
+ tf-keras
20
+ aiohttp
21
+ ultralytics
22
+ deepface
23
+ mediapipe>=0.10.0
24
+ # ModelScope相关包 - 让pip自动解决版本依赖
25
+ modelscope==1.28.2
26
+ datasets==2.21.0
27
+ transformers==4.40.0
28
+ # ModelScope DDColor的额外依赖
29
+ timm==1.0.19
30
+ sortedcontainers==2.4.0
31
+ fsspec==2024.6.1
32
+ multiprocess==0.70.16
33
+ xxhash==3.5.0
34
+ dill==0.3.8
35
+ huggingface-hub==0.34.3
36
+ # 修复pyarrow兼容性问题 - 使用稳定版本
37
+ pyarrow==20.0.0
38
+
39
+ # API相关
40
+ pydantic>=2.4.0
41
+ starlette>=0.27.0
42
+ simplejson==3.20.1
43
+ # 科学计算和工具
44
+ scipy>=1.7.0,<1.13.0
45
+ tqdm
46
+ lmdb
47
+ pyyaml
48
+
49
+ # 定时任务
50
+ apscheduler>=3.10.0
51
+
52
+ # 数据库
53
+ aiomysql>=0.2.0
54
+
55
+ # 对象存储
56
+ boto3>=1.34.0
57
+
58
+ # GFPGAN 和相关包 - 修复依赖兼容性
59
+ basicsr>=1.3.3
60
+ facexlib>=0.2.5
61
+ gfpgan>=1.3.0
62
+ realesrgan>=0.3.0
63
+
64
+ # CLIP 相关依赖
65
+ cn_clip
66
+ faiss-cpu
67
+ onnxruntime
68
+ diffusers
69
+ accelerate
70
+ # rembg 抠图处理
71
+ rembg>=2.0.50
72
+ easydict
rvm_processor.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ from torchvision import transforms
6
+
7
+ import config
8
+ from config import logger
9
+
10
+
11
class RVMProcessor:
    """RVM (Robust Video Matting) background-removal processor, local-only loading."""

    def __init__(self):
        """Build the RVM model from a local repo + local weights; never touches the network."""
        self.model = None       # torch module, None while unavailable
        self.available = False  # True only after weights loaded successfully
        self.device = "cpu"     # default to CPU; can be set to "cuda" when a GPU exists

        try:
            # Load strictly from local paths configured via environment
            local_repo = getattr(config, 'RVM_LOCAL_REPO', '')
            weights_path = getattr(config, 'RVM_WEIGHTS_PATH', '')

            if not local_repo or not os.path.isdir(local_repo):
                raise RuntimeError("RVM_LOCAL_REPO not set or invalid. Please set env RVM_LOCAL_REPO to local RobustVideoMatting repo path (with hubconf.py)")

            if not weights_path or not os.path.isfile(weights_path):
                raise RuntimeError("RVM_WEIGHTS_PATH not set or file not found. Please set env RVM_WEIGHTS_PATH to local RVM weights file path")

            logger.info(f"Loading RVM model {config.RVM_MODEL} from local repo: {local_repo}")
            # Build the architecture from the local repo; pretrained=False avoids any download
            self.model = torch.hub.load(local_repo, config.RVM_MODEL, source='local', pretrained=False)

            # Load the local weights; some checkpoints wrap them under 'state_dict'
            state = torch.load(weights_path, map_location=self.device)
            if isinstance(state, dict) and 'state_dict' in state:
                state = state['state_dict']
            missing, unexpected = self.model.load_state_dict(state, strict=False)

            # Move to the target device and switch to eval mode
            self.model = self.model.to(self.device).eval()
            self.available = True
            logger.info("RVM background removal processor initialized successfully (local mode)")
            if missing:
                logger.warning(f"RVM weights missing keys: {list(missing)[:5]}... total={len(missing)}")
            if unexpected:
                logger.warning(f"RVM weights unexpected keys: {list(unexpected)[:5]}... total={len(unexpected)}")

        except Exception as e:
            logger.error(f"RVM background removal processor initialization failed: {e}")
            self.available = False

    def is_available(self) -> bool:
        """Return True when the model was loaded successfully."""
        return self.available and self.model is not None

    def remove_background(self, image: np.ndarray, background_color: tuple = None) -> np.ndarray:
        """
        Remove the background from an image using RVM.
        :param image: input OpenCV image (BGR)
        :param background_color: replacement background color (BGR); None keeps transparency
        :return: BGR image when a color is given, BGRA (transparent) otherwise
        :raises Exception: when the processor is unavailable or inference fails
        """
        if not self.is_available():
            raise Exception("RVM抠图处理器不可用")

        try:
            logger.info("Starting to remove background using RVM...")

            # Remember the input size; the model may emit a downsampled result
            original_height, original_width = image.shape[:2]

            # OpenCV BGR -> RGB
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # HWC uint8 -> 1x3xHxW float tensor scaled to [0, 1]
            src = transforms.ToTensor()(image_rgb).unsqueeze(0).to(self.device)

            # Inference; rec holds RVM's four recurrent states (None for a single frame)
            rec = [None] * 4
            with torch.no_grad():
                fgr, pha, *rec = self.model(src, *rec, downsample_ratio=0.25)

            # Tensors -> uint8 numpy arrays
            fgr = (fgr[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)  # (H,W,3) foreground (RGB)
            pha = (pha[0, 0].cpu().numpy() * 255).astype(np.uint8)  # (H,W) alpha matte

            # Restore the original resolution if the model resized internally
            if fgr.shape[:2] != (original_height, original_width):
                fgr = cv2.resize(fgr, (original_width, original_height))
                pha = cv2.resize(pha, (original_width, original_height))

            if background_color is not None:
                # Solid background: alpha-blend the foreground over the requested color
                fgr_bgr = cv2.cvtColor(fgr, cv2.COLOR_RGB2BGR)

                # Build the flat-color background image
                background = np.full((original_height, original_width, 3), background_color, dtype=np.uint8)

                # Alpha blending in float space
                alpha = pha.astype(np.float32) / 255.0
                alpha = np.stack([alpha] * 3, axis=-1)

                result = (fgr_bgr * alpha + background * (1 - alpha)).astype(np.uint8)
            else:
                # Transparent background: pack the matte as the alpha channel (BGRA)
                fgr_bgr = cv2.cvtColor(fgr, cv2.COLOR_RGB2BGR)
                rgba = np.dstack((fgr_bgr, pha))  # (H,W,4)
                result = rgba

            logger.info("RVM background removal completed")
            return result

        except Exception as e:
            logger.error(f"RVM background removal failed: {e}")
            raise Exception(f"背景移除失败: {str(e)}")

    def create_id_photo(self, image: np.ndarray, background_color: tuple = (255, 255, 255)) -> np.ndarray:
        """
        Create an ID photo (remove the background, then fill with a solid color).
        :param image: input OpenCV image
        :param background_color: fill color, default white (BGR)
        :return: the composited ID photo
        """
        logger.info(f"Starting to create ID photo, background color: {background_color}")

        # Remove the background and composite over the requested color
        id_photo = self.remove_background(image, background_color)

        logger.info("ID photo creation completed")
        return id_photo
start_local.sh ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Local launch script: exports the runtime configuration read by the app's
# config module, then starts the FastAPI server with uvicorn on port 7860.
export TZ=Asia/Shanghai

# Data, model and index locations
export OUTPUT_DIR=/opt/data/output
export IMAGES_DIR=/opt/data/images
export MODELS_PATH=/opt/data/models
export DEEPFACE_HOME=/opt/data/models
export FAISS_INDEX_DIR=/opt/data/faiss
export CELEBRITY_SOURCE_DIR=/opt/data/chinese_celeb_dataset
# Detection / scoring thresholds
export GENDER_CONFIDENCE=1
export UPSCALE_SIZE=2
export AGE_CONFIDENCE=1.0
export DRAW_SCORE=true
export FACE_CONFIDENCE=0.7

# Feature toggles for the individual processors
export ENABLE_DDCOLOR=true
export ENABLE_GFPGAN=true
export ENABLE_REALESRGAN=true
export ENABLE_ANIME_STYLE=true
export ENABLE_RVM=true
export ENABLE_REMBG=true
export ENABLE_CLIP=false

# Periodic cleanup of generated output
export CLEANUP_INTERVAL_HOURS=1
export CLEANUP_AGE_HOURS=1

# Beauty-score adjustment parameters
export BEAUTY_ADJUST_GAMMA=0.8
export BEAUTY_ADJUST_MIN=1.0
export BEAUTY_ADJUST_MAX=9.0
export ENABLE_ANIME_PRELOAD=true
export ENABLE_LOGGING=true
export BEAUTY_ADJUST_ENABLED=true

# RVM runs in local-only mode: repo checkout + downloaded weights
export RVM_LOCAL_REPO=/opt/data/models/RobustVideoMatting
export RVM_WEIGHTS_PATH=/opt/data/models/torch/hub/checkpoints/rvm_resnet50.pth
export RVM_MODEL=resnet50

# Eager (true) vs lazy (false) initialization per processor
export AUTO_INIT_GFPGAN=false
export AUTO_INIT_DDCOLOR=false
export AUTO_INIT_REALESRGAN=false
export AUTO_INIT_ANIME_STYLE=true
export AUTO_INIT_CLIP=false
export AUTO_INIT_RVM=false
export AUTO_INIT_REMBG=false

export ENABLE_WARMUP=true
export REALESRGAN_MODEL=realesr-general-x4v3
export CELEBRITY_FIND_THRESHOLD=0.87
export FEMALE_AGE_ADJUSTMENT=4

# Single worker: the ML models are process-local and expensive to duplicate
uvicorn app:app --workers 1 --loop asyncio --http httptools --host 0.0.0.0 --port 7860 --timeout-keep-alive 600
52
+
test/celebrity_crawler.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from io import BytesIO
3
+ from pathlib import Path
4
+
5
+ import requests
6
+ from PIL import Image
7
+
8
+
9
class CelebrityCrawler:
    """Downloads celebrity photos from Baidu/Bing image search into a local folder."""

    def __init__(self, output_dir="celebrity_images"):
        """
        :param output_dir: directory the downloaded images are written to (created if missing)
        """
        self.output_dir = output_dir
        # Browser-like User-Agent so the search endpoints do not reject the requests
        self.headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    def read_celebrities_from_txt(self, file_path):
        """
        Read celebrity entries from a txt file.
        Supported line formats:
        1. name,profession
        2. name
        Blank lines and lines starting with '#' are skipped.
        :return: list of {'name': ..., 'profession': ...} dicts
        """
        celebrities = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue

                parts = line.split(',')
                name = parts[0].strip()
                # Default profession when the line has no second field
                profession = parts[1].strip() if len(parts) > 1 else "明星"

                celebrities.append({
                    'name': name,
                    'profession': profession
                })
        return celebrities

    def search_bing_images(self, celebrity_name, max_images=20):
        """Scrape Bing image search and return up to max_images image URLs."""
        search_url = "https://www.bing.com/images/search"
        params = {
            'q': celebrity_name + " 明星",
            'first': 0,
            'count': max_images
        }

        try:
            response = requests.get(search_url, params=params, headers=self.headers,
                                    timeout=10)
            response.raise_for_status()

            # Naive HTML scrape for the media URLs.
            # NOTE(review): the pattern matches HTML-escaped JSON ("murl&quot;:...")
            # and will silently return [] if Bing changes its markup — verify
            # against a live response.
            import re
            img_urls = re.findall(r'murl&quot;:&quot;(.*?)&quot;', response.text)
            return img_urls[:max_images]
        except Exception as e:
            print(f"搜索 {celebrity_name} 时出错: {e}")
            return []

    def search_baidu_images(self, celebrity_name, max_images=20):
        """Query Baidu's image-search JSON endpoint and return thumbnail URLs."""
        search_url = "https://image.baidu.com/search/acjson"
        params = {
            'tn': 'resultjson_com',
            'word': celebrity_name + " 明星",
            'pn': 0,
            'rn': max_images,
            'ie': 'utf-8'
        }

        try:
            response = requests.get(search_url, params=params, headers=self.headers,
                                    timeout=10)
            response.raise_for_status()
            data = response.json()

            # Collect the thumbnail URL of each result entry
            img_urls = []
            if 'data' in data:
                for item in data['data']:
                    if 'thumbURL' in item:
                        img_urls.append(item['thumbURL'])
            return img_urls[:max_images]
        except Exception as e:
            print(f"搜索 {celebrity_name} 时出错: {e}")
            return []

    def download_image(self, url, save_path):
        """Download one image; returns True only for valid images at least 100x100 px."""
        try:
            response = requests.get(url, headers=self.headers, timeout=15)
            response.raise_for_status()

            # Verify the payload is a decodable image
            img = Image.open(BytesIO(response.content))

            # Skip thumbnails too small to be useful
            if img.size[0] < 100 or img.size[1] < 100:
                return False

            # Normalize to RGB and save as JPEG
            img = img.convert('RGB')
            img.save(save_path, 'JPEG', quality=95)
            return True
        except Exception as e:
            print(f" 下载失败: {str(e)[:50]}")
            return False

    def crawl_celebrity_images(self, celebrity, max_images=20,
                               search_engine='baidu'):
        """
        Crawl images for a single celebrity.
        :param celebrity: dict with 'name' and 'profession'
        :param max_images: number of images to keep
        :param search_engine: 'baidu' or 'bing'
        :return: number of images on disk for this celebrity (existing files count)
        """
        name = celebrity['name']
        print(f"\n正在爬取: {name} ({celebrity['profession']})")

        # Ensure the output directory exists
        celebrity_dir = Path(self.output_dir)
        celebrity_dir.mkdir(parents=True, exist_ok=True)

        # Fetch twice as many URLs as needed to absorb download failures
        if search_engine == 'baidu':
            img_urls = self.search_baidu_images(name, max_images * 2)
        else:
            img_urls = self.search_bing_images(name, max_images * 2)

        if not img_urls:
            print(f" 未找到 {name} 的图片")
            return 0

        print(f" 找到 {len(img_urls)} 个图片链接")

        # Download until max_images succeed or the URL list runs out
        success_count = 0
        for idx, url in enumerate(img_urls):
            if success_count >= max_images:
                break

            save_path = celebrity_dir / f"{name}_{idx + 1:03d}.jpg"

            # Resume support: files already on disk count as successes
            if save_path.exists():
                success_count += 1
                continue

            print(f" 下载 {idx + 1}/{len(img_urls)}...", end=' ')
            if self.download_image(url, save_path):
                success_count += 1
                print("✓")
            else:
                print("✗")

            # Throttle to avoid hammering the search engine
            time.sleep(0.5)

        print(f" 成功下载 {success_count} 张图片")
        return success_count

    def crawl_all(self, txt_file, max_images_per_celebrity=20,
                  search_engine='baidu'):
        """
        Crawl images for every celebrity listed in txt_file and print a summary.
        :param txt_file: input list file (see read_celebrities_from_txt)
        :param max_images_per_celebrity: per-celebrity download cap
        :param search_engine: 'baidu' or 'bing'
        """
        print("=" * 60)
        print("明星照片爬取工具")
        print("=" * 60)

        # Load the celebrity list
        celebrities = self.read_celebrities_from_txt(txt_file)
        print(f"\n从 {txt_file} 读取到 {len(celebrities)} 位明星")

        # Summary counters
        total_images = 0
        failed_celebrities = []

        # Crawl each celebrity in turn; a failure never aborts the whole run
        for i, celebrity in enumerate(celebrities, 1):
            print(f"\n[{i}/{len(celebrities)}]", end=' ')

            try:
                count = self.crawl_celebrity_images(
                    celebrity,
                    max_images=max_images_per_celebrity,
                    search_engine=search_engine
                )
                total_images += count

                if count == 0:
                    failed_celebrities.append(celebrity['name'])

                # Pause every 5 celebrities to stay polite to the servers
                if i % 5 == 0:
                    print(f"\n 已完成 {i}/{len(celebrities)}, 休息3秒...")
                    time.sleep(3)

            except Exception as e:
                print(f" 处理 {celebrity['name']} 时出错: {e}")
                failed_celebrities.append(celebrity['name'])

        # Print the final summary
        print("\n" + "=" * 60)
        print("爬取完成!")
        print("=" * 60)
        print(f"总明星数: {len(celebrities)}")
        print(f"成功爬取: {len(celebrities) - len(failed_celebrities)}")
        print(f"失败数量: {len(failed_celebrities)}")
        print(f"总图片数: {total_images}")
        print(f"保存位置: {self.output_dir}")

        if failed_celebrities:
            print(f"\n失败的明星: {', '.join(failed_celebrities)}")
210
+
211
+
212
# Usage example
if __name__ == "__main__":
    # Create a crawler instance writing into ./celebrity_dataset
    crawler = CelebrityCrawler(output_dir="celebrity_dataset")

    # Crawl from a txt file.
    # Example file format (name,profession per line):
    # 周杰伦,歌手
    # 刘德华,演员
    # 范冰冰,演员

    crawler.crawl_all(
        txt_file="celebrity_real_names.txt",  # path to your txt file
        max_images_per_celebrity=1,           # images to fetch per celebrity
        search_engine='baidu'                 # 'baidu' or 'bing'
    )
test/celebrity_crawler.pyc ADDED
Binary file (5.71 kB). View file
 
test/decode_celeb_dataset.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Decode base64 file names inside the Chinese celeb dataset directory.
4
+
5
+ Default target: /Users/chenchaoyun/Downloads/chinese_celeb_dataset.
6
+ Use --root to override; --dry-run only prints the plan.
7
+ """
8
+ import argparse
9
+ import base64
10
+ from pathlib import Path
11
+ import sys
12
+
13
+ DEFAULT_ROOT = Path("/Users/chenchaoyun/Downloads/chinese_celeb_dataset")
14
+
15
+
16
def _decode_basename(encoded: str) -> str:
    """Decode a urlsafe-base64 file stem, returning the input unchanged on failure."""
    # Re-add the '=' padding that was stripped from the stored names
    pad = "=" * (-len(encoded) % 4)
    try:
        raw = base64.urlsafe_b64decode((encoded + pad).encode("ascii"))
        return raw.decode("utf-8")
    except Exception:
        # Not base64 (or not UTF-8 underneath): keep the original stem
        return encoded
23
+
24
+
25
def rename_dataset(root: Path, dry_run: bool = False) -> int:
    """Rename every file under *root* to its base64-decoded stem.

    :param root: dataset root directory, walked recursively
    :param dry_run: when True only print the planned renames, change nothing
    :return: process exit code (0 on success, 1 for a missing/invalid root)
    """
    if not root.exists():
        print(f"Directory does not exist: {root}", file=sys.stderr)
        return 1
    if not root.is_dir():
        print(f"Not a directory: {root}", file=sys.stderr)
        return 1

    renamed = 0
    for file_path in sorted(root.rglob("*")):
        if not file_path.is_file():
            continue
        decoded = _decode_basename(file_path.stem)
        # _decode_basename returns its input unchanged for non-base64 stems
        if decoded == file_path.stem:
            continue

        new_path = file_path.with_name(f"{decoded}{file_path.suffix}")
        if new_path == file_path:
            continue

        # Append a counter if the decoded target already exists
        counter = 1
        while new_path.exists() and new_path != file_path:
            new_path = file_path.with_name(
                f"{decoded}_{counter}{file_path.suffix}"
            )
            counter += 1

        print(f"{file_path} -> {new_path}")
        if dry_run:
            # Plan only — note `renamed` stays 0 in dry-run mode
            continue
        file_path.rename(new_path)
        renamed += 1

    print(f"Renamed {renamed} files")
    return 0
61
+
62
+
63
def parse_args() -> argparse.Namespace:
    """Parse the CLI options: --root (dataset directory) and --dry-run."""
    parser = argparse.ArgumentParser(
        description="Decode chinese_celeb_dataset file names")
    parser.add_argument(
        "--root",
        type=Path,
        default=DEFAULT_ROOT,
        help="Dataset root directory (default: %(default)s)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only print planned renames without applying them",
    )
    return parser.parse_args()
78
+
79
+
80
def main() -> int:
    """Entry point: parse options and run the rename; returns the exit code."""
    args = parse_args()
    # expanduser/resolve so ~ and relative paths work from any cwd
    return rename_dataset(args.root.expanduser().resolve(), args.dry_run)
83
+
84
+
85
+ if __name__ == "__main__":
86
+ sys.exit(main())
test/decode_celeb_dataset.pyc ADDED
Binary file (2.26 kB). View file
 
test/dow_img.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import cv2

# One-off utility: recompress a PNG as a smaller WebP file.

# Read the source image
img = cv2.imread("/opt/data/header.png")

# Compression quality (0-100; lower = smaller file, worse quality)
quality = 50

# Write the compressed image (the .webp extension selects the WebP encoder)
cv2.imwrite(
    "/opt/data/output_small.webp",
    img,
    [int(cv2.IMWRITE_WEBP_QUALITY), quality],
)


# # Read the original image
# img = cv2.imread("/opt/data/header.png")
#
# # Downscale (e.g. to half the original size)
# resized = cv2.resize(img, (img.shape[1] // 2, img.shape[0] // 2))
#
# # Write the compressed image at lower quality
# cv2.imwrite("/opt/data/output_small.webp", resized, [int(cv2.IMWRITE_WEBP_QUALITY), 40])
test/dow_img.pyc ADDED
Binary file (291 Bytes). View file
 
test/howcuteami.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import argparse
4
+ import numpy as np
5
+ import os
6
+
7
+
8
# detect face
def highlightFace(net, frame, conf_threshold=0.95):
    """Run the OpenCV DNN face detector over *frame*.

    Args:
        net: a loaded cv2.dnn face-detection network (300x300 SSD input).
        frame: BGR image (numpy array).
        conf_threshold: minimum detection confidence to keep a box.

    Returns:
        List of [x1, y1, x2, y2] pixel boxes, squared via the sibling
        ``scale`` helper.
    """
    frameOpencvDnn = frame.copy()
    frameHeight = frameOpencvDnn.shape[0]
    frameWidth = frameOpencvDnn.shape[1]
    # 300x300 is the detector's fixed input size; the triple is the mean
    # subtracted per channel before inference.
    blob = cv2.dnn.blobFromImage(
        frameOpencvDnn, 1.0, (300, 300), [104, 117, 123], True, False
    )

    net.setInput(blob)
    detections = net.forward()
    faceBoxes = []

    for i in range(detections.shape[2]):
        confidence = detections[0, 0, i, 2]
        if confidence > conf_threshold:
            # Coordinates come back normalized to [0, 1]; scale to pixels.
            x1 = int(detections[0, 0, i, 3] * frameWidth)
            y1 = int(detections[0, 0, i, 4] * frameHeight)
            x2 = int(detections[0, 0, i, 5] * frameWidth)
            y2 = int(detections[0, 0, i, 6] * frameHeight)
            faceBoxes.append(scale([x1, y1, x2, y2]))

    return faceBoxes
31
+
32
+
33
def scale(box):
    """Expand a [x1, y1, x2, y2] rectangle toward a square.

    The shorter side is grown symmetrically until it matches the longer
    side (up to integer truncation of the half-padding).
    """
    x1, y1, x2, y2 = box
    side = max(x2 - x1, y2 - y1)
    pad_x = int((side - (x2 - x1)) / 2)
    pad_y = int((side - (y2 - y1)) / 2)
    return [x1 - pad_x, y1 - pad_y, x2 + pad_x, y2 + pad_y]
43
+
44
+
45
def cropImage(image, box):
    """Return the sub-image covered by box = [x1, y1, x2, y2] (a view)."""
    x1, y1, x2, y2 = box
    return image[y1:y2, x1:x2]
49
+
50
+
51
# main: one-shot CLI script — parse args, prepare output dir, load models
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image", type=str, required=False, help="input image")
args = parser.parse_args()

# Create the output directory if it does not exist yet
output_dir = "../output"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Model files: face detector plus age / gender / beauty classifiers
faceProto = "models/opencv_face_detector.pbtxt"
faceModel = "models/opencv_face_detector_uint8.pb"
ageProto = "models/age_googlenet.prototxt"
ageModel = "models/age_googlenet.caffemodel"
genderProto = "models/gender_googlenet.prototxt"
genderModel = "models/gender_googlenet.caffemodel"
beautyProto = "models/beauty_resnet.prototxt"
beautyModel = "models/beauty_resnet.caffemodel"

# Per-channel mean subtracted by blobFromImage before inference
MODEL_MEAN_VALUES = (104, 117, 123)
# Age buckets emitted by age_googlenet (list index == argmax of output)
ageList = [
    "(0-2)",
    "(4-6)",
    "(8-12)",
    "(15-20)",
    "(25-32)",
    "(38-43)",
    "(48-53)",
    "(60-100)",
]
genderList = ["Male", "Female"]

# Annotation colors per predicted gender (BGR order)
gender_colors = {
    "Male": (255, 165, 0),  # orange
    "Female": (255, 0, 255),  # magenta / fuchsia
}

faceNet = cv2.dnn.readNet(faceModel, faceProto)
ageNet = cv2.dnn.readNet(ageModel, ageProto)
genderNet = cv2.dnn.readNet(genderModel, genderProto)
beautyNet = cv2.dnn.readNet(beautyModel, beautyProto)
93
+
94
# Read the input image (falls back to a bundled sample)
image_path = args.image if args.image else "images/charlize.jpg"
frame = cv2.imread(image_path)

if frame is None:
    print(f"无法读取图片: {image_path}")
    exit()

faceBoxes = highlightFace(faceNet, frame)
if not faceBoxes:
    print("No face detected")
    exit()

print(f"检测到 {len(faceBoxes)} 张人脸")

for i, faceBox in enumerate(faceBoxes):
    # Crop the face region and resize to the classifiers' 224x224 input
    face = cropImage(frame, faceBox)
    face_resized = cv2.resize(face, (224, 224))

    # gender net
    blob = cv2.dnn.blobFromImage(
        face_resized, 1.0, (224, 224), MODEL_MEAN_VALUES, swapRB=False
    )
    genderNet.setInput(blob)
    genderPreds = genderNet.forward()
    gender = genderList[genderPreds[0].argmax()]
    print(f"Gender: {gender}")

    # age net (reuses the same blob as the gender net)
    ageNet.setInput(blob)
    agePreds = ageNet.forward()
    age = ageList[agePreds[0].argmax()]
    print(f"Age: {age[1:-1]} years")

    # beauty net (expects inputs scaled to [0, 1], hence the fresh blob)
    blob = cv2.dnn.blobFromImage(
        face_resized, 1.0 / 255, (224, 224), MODEL_MEAN_VALUES, swapRB=False
    )
    beautyNet.setInput(blob)
    beautyPreds = beautyNet.forward()
    beauty = round(2.0 * sum(beautyPreds[0]), 1)
    print(f"Beauty: {beauty}/10.0")

    # Pick the annotation color for this gender
    color = gender_colors[gender]

    # Save the raw face crop with cv2.imwrite
    face_filename = f"{output_dir}/face_{i+1}.webp"
    cv2.imwrite(face_filename, face, [cv2.IMWRITE_WEBP_QUALITY, 95])
    print(f"人脸图片已保存: {face_filename}")

    # Save a copy of the crop annotated with the predictions (optional)
    face_with_text = face.copy()
    cv2.putText(
        face_with_text, f"{gender}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2
    )
    cv2.putText(
        face_with_text,
        f"{age[1:-1]} years",
        (10, 60),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        color,
        2,
    )
    cv2.putText(
        face_with_text,
        f"{beauty}/10.0",
        (10, 90),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        color,
        2,
    )

    annotated_filename = f"{output_dir}/face_{i+1}_annotated.webp"
    cv2.imwrite(annotated_filename, face_with_text, [cv2.IMWRITE_WEBP_QUALITY, 95])
    print(f"标注人脸已保存: {annotated_filename}")

    # Draw the face box and summary text on the full frame
    cv2.rectangle(
        frame,
        (faceBox[0], faceBox[1]),
        (faceBox[2], faceBox[3]),
        color,
        int(round(frame.shape[0] / 400)),
        8,
    )
    cv2.putText(
        frame,
        f"{gender}, {age}, {beauty}",
        (faceBox[0], faceBox[1] - 10),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.25,
        color,
        2,
        cv2.LINE_AA,
    )

# Save the fully annotated frame
result_filename = f"{output_dir}/result_full.webp"
cv2.imwrite(result_filename, frame, [cv2.IMWRITE_WEBP_QUALITY, 95])
print(f"完整结果图片已保存: {result_filename}")

# Show the result in a window until a key is pressed
cv2.imshow("howbeautifulami", frame)
cv2.waitKey(0)
cv2.destroyAllWindows()
test/howcuteami.pyc ADDED
Binary file (4.13 kB). View file
 
test/import_history_images.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 导入历史图片文件到数据库的脚本
4
+ """
5
+
6
+ import asyncio
7
+ import hashlib
8
+ import os
9
+ import sys
10
+ import time
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+
14
+ # 添加项目根目录到Python路径
15
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16
+
17
+ from database import record_image_creation, fetch_records_by_paths
18
+
19
+
20
def calculate_file_hash(file_path):
    """Return the hex MD5 digest of the file at *file_path*."""
    digest = hashlib.md5()
    with open(file_path, "rb") as fh:
        # Stream in 4 KiB chunks so large files stay memory-friendly.
        while chunk := fh.read(4096):
            digest.update(chunk)
    return digest.hexdigest()
28
+
29
+
30
def infer_category_from_filename(filename):
    """Infer an image category from its file name.

    The category is the token between the last underscore and the first
    dot that follows it (e.g. ``foo_restore.jpg`` -> ``restore``).  Names
    carrying the ``_anime_style_`` marker are special-cased; anything
    unrecognized falls back to ``other``.
    """
    name = filename.lower()

    # Anime stylization output uses a dedicated infix marker.
    if '_anime_style_' in name:
        return 'anime_style'

    # Recognized suffix tokens map to themselves; others become 'other'.
    known_types = {
        'restore', 'upcolor', 'grayscale', 'upscale', 'compress',
        'id_photo', 'grid', 'rvm', 'celebrity', 'face', 'original',
    }

    underscore = name.rfind('_')
    dot = name.find('.', underscore)
    if underscore == -1 or dot == -1 or underscore >= dot:
        return 'other'

    token = name[underscore + 1:dot]
    return token if token in known_types else 'other'
66
+
67
+
68
async def import_history_images(source_dir, nickname="system_import"):
    """Import legacy image files under *source_dir* into the database.

    Files whose name already has a DB record are skipped.  Each imported
    row stores the bare file name, an inferred category, and metadata
    (MD5 hash, original path, import timestamp).

    Args:
        source_dir: directory scanned (non-recursively) for image files.
        nickname: owner name recorded on each imported row.
    """
    source_path = Path(source_dir)

    if not source_path.exists():
        print(f"错误: 目录 {source_dir} 不存在")
        return

    # Image extensions considered importable
    image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif', '.tiff',
                        '.tif'}

    # Collect candidate files, matching both lower- and upper-case suffixes
    image_files = []
    for ext in image_extensions:
        image_files.extend(source_path.glob(f"*{ext}"))
        image_files.extend(source_path.glob(f"*{ext.upper()}"))

    print(f"找到 {len(image_files)} 个图片文件")

    imported_count = 0
    skipped_count = 0

    for image_path in image_files:
        try:
            file_name = image_path.name

            # Skip files that already have a database record (keyed by name)
            records = await fetch_records_by_paths([file_name])

            if file_name in records:
                print(f"跳过已存在的文件: {file_name}")
                skipped_count += 1
                continue

            # No record yet — import it.
            # Hash the content as an extra uniqueness hint for the metadata.
            file_hash = calculate_file_hash(str(image_path))

            # Derive the category from the file-name suffix convention
            category = infer_category_from_filename(file_name)

            # Persist the record
            await record_image_creation(
                file_path=file_name,  # store the bare name, not the full path
                nickname=nickname,
                category=category,
                bos_uploaded=False,  # legacy files were typically never uploaded to BOS
                score=0.0,  # legacy files default to a zero score
                extra_metadata={
                    "source": "history_import",
                    "original_path": str(image_path),
                    "file_hash": file_hash,
                    "import_time": datetime.now().isoformat()
                }
            )

            imported_count += 1
            print(f"成功导入: {file_name} (类别: {category})")

        except Exception as e:
            # Best-effort import: report the failure and move on
            print(f"导入文件失败 {image_path.name}: {str(e)}")
            continue

    print(f"\n导入完成!")
    print(f"成功导入: {imported_count} 个文件")
    print(f"跳过: {skipped_count} 个文件")
135
+
136
+
137
async def main():
    """CLI entry point: validate argv, then run the import with timing."""
    argv = sys.argv
    if len(argv) < 2:
        # No source directory supplied — show usage and bail out.
        print("用法: python import_history_images.py <图片目录路径> [昵称]")
        print(
            "示例: python import_history_images.py ~/app/data/images")
        print(
            "示例: python import_history_images.py ~/app/data/images \"历史导入\"")
        sys.exit(1)

    source_directory = argv[1]
    nickname = argv[2] if len(argv) > 2 else "system_import"

    print(f"开始导入图片文件...")
    print(f"源目录: {source_directory}")
    print(f"用户昵称: {nickname}")
    print("-" * 50)

    start_time = time.time()
    await import_history_images(source_directory, nickname)
    end_time = time.time()

    print(f"\n总耗时: {end_time - start_time:.2f} 秒")


if __name__ == "__main__":
    asyncio.run(main())
test/import_history_images.pyc ADDED
Binary file (3.76 kB). View file
 
test/remove_duplicate_celeb_images.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 遍历指定目录,根据文件内容(MD5)查找重复项,如果发现重复则只保留一个。
4
+ 默认目标目录为 /opt/data/chinese_celeb_dataset,可用 --target-dir 覆盖。
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import hashlib
11
+ import os
12
+ import sys
13
+ from pathlib import Path
14
+ from typing import Dict
15
+
16
+ DEFAULT_TARGET_DIR = Path("/opt/data/chinese_celeb_dataset")
17
+ CHUNK_SIZE = 4 * 1024 * 1024 # 4MB
18
+
19
+
20
def compute_md5(file_path: Path) -> str:
    """Compute the file's MD5 by streaming, never loading it whole."""
    md5 = hashlib.md5()
    with file_path.open("rb") as stream:
        while block := stream.read(CHUNK_SIZE):
            md5.update(block)
    return md5.hexdigest()
27
+
28
+
29
def deduplicate(target_dir: Path, dry_run: bool = False) -> int:
    """Remove duplicate files under *target_dir*, keeping one copy per MD5.

    Args:
        target_dir: directory scanned recursively.
        dry_run: when True, only report what would be deleted.

    Returns:
        Number of duplicate files actually deleted (0 in dry-run mode).
    """
    if not target_dir.exists():
        print(f"[error] 目标目录不存在: {target_dir}", file=sys.stderr)
        return 0
    if not target_dir.is_dir():
        print(f"[error] 目标路径不是目录: {target_dir}", file=sys.stderr)
        return 0

    md5_map: Dict[str, Path] = {}
    removed = 0
    scanned = 0

    # Sort by path so the first-encountered copy is always the one kept
    for file_path in sorted(target_dir.rglob("*")):
        # Skip directories and symlinks; only real files are hashed
        if not file_path.is_file() or file_path.is_symlink():
            continue

        scanned += 1
        try:
            file_md5 = compute_md5(file_path)
        except Exception as exc:
            print(f"[warn] 计算 MD5 失败: {file_path} -> {exc}", file=sys.stderr)
            continue

        original = md5_map.get(file_md5)
        if original is None:
            # First time this content is seen: remember it as the keeper
            md5_map[file_md5] = file_path
            continue

        if dry_run:
            print(f"[dry-run] {file_path} 与 {original} 内容相同,将被删除")
        else:
            try:
                os.remove(file_path)
                removed += 1
                print(f"[remove] 删除重复文件: {file_path} (原始: {original})")
            except Exception as exc:
                print(f"[error] 删除失败: {file_path} -> {exc}", file=sys.stderr)

    print(
        f"[summary] 扫描文件: {scanned}, 保留唯一文件: {len(md5_map)}, 删除重复文件: {removed}{' (dry-run)' if dry_run else ''}"
    )
    return removed
73
+
74
+
75
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI options for the dedupe script."""
    cli = argparse.ArgumentParser(description="按 MD5 删除重复文件,仅保留一个副本。")
    cli.add_argument(
        "--target-dir",
        type=Path,
        default=DEFAULT_TARGET_DIR,
        help=f"需要去重的目录(默认: {DEFAULT_TARGET_DIR})",
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="只输出将删除的文件,不实际删除。",
    )
    return cli.parse_args()
89
+
90
+
91
def main() -> int:
    """Resolve the target directory, run deduplication, always exit 0."""
    options = parse_args()
    resolved = options.target_dir.expanduser().resolve()
    deduplicate(resolved, dry_run=options.dry_run)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
test/remove_duplicate_celeb_images.pyc ADDED
Binary file (3.18 kB). View file
 
test/remove_faceless_images.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 遍历 /opt/data/chinese_celeb_dataset 下的图片,使用 YOLO 人脸检测并删除没有检测到人脸的图片。
4
+
5
+ 用法示例:
6
+ python test/remove_faceless_images.py --dry-run
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import sys
13
+ from pathlib import Path
14
+ from typing import Iterable, List, Optional
15
+
16
+ import config
17
+
18
+ try:
19
+ from ultralytics import YOLO
20
+ except ImportError as exc: # pragma: no cover - 运行期缺依赖提示
21
+ raise SystemExit("缺少 ultralytics,请先执行 pip install ultralytics") from exc
22
+
23
+ # 默认数据集与模型配置
24
+ DEFAULT_DATASET_DIR = Path("/opt/data/chinese_celeb_dataset")
25
+ MODEL_DIR = Path(config.MODELS_PATH)
26
+ YOLO_MODEL_NAME = config.YOLO_MODEL
27
+
28
+
29
def parse_args() -> argparse.Namespace:
    """Define and parse the CLI options for the faceless-image sweep."""
    cli = argparse.ArgumentParser(
        description="使用 YOLO 检测 /opt/data/chinese_celeb_dataset 中的图片并删除无脸图片"
    )
    cli.add_argument(
        "--dataset-dir", type=Path, default=DEFAULT_DATASET_DIR,
        help="需要检查的根目录(默认:/opt/data/chinese_celeb_dataset)",
    )
    cli.add_argument(
        "--extensions", type=str, default=".jpg,.jpeg,.png,.webp,.bmp",
        help="需要检查的图片扩展名,逗号分隔",
    )
    cli.add_argument(
        "--confidence", type=float, default=config.FACE_CONFIDENCE,
        help="YOLO 检测的人脸置信度阈值",
    )
    cli.add_argument(
        "--dry-run", action="store_true",
        help="仅输出将被删除的文件,不真正删除,便于先预览结果",
    )
    cli.add_argument(
        "--verbose", action="store_true",
        help="输出更多调试信息",
    )
    return cli.parse_args()
62
+
63
+
64
def load_yolo_model() -> YOLO:
    """Load the YOLO face model, preferring the local models directory.

    Falls back to the bare model name (which triggers an automatic
    download) when no local copy exists; raises if every candidate fails.
    """
    local_path = MODEL_DIR / YOLO_MODEL_NAME
    if local_path.exists():
        candidates: List[str] = [str(local_path), YOLO_MODEL_NAME]
    else:
        candidates = [YOLO_MODEL_NAME]

    failure: Optional[Exception] = None
    for name in candidates:
        try:
            config.logger.info("尝试加载 YOLO 模型:%s", name)
            return YOLO(name)
        except Exception as exc:  # pragma: no cover
            failure = exc
            config.logger.warning("加载 YOLO 模型失败:%s -> %s", name, exc)

    raise RuntimeError(f"无法加载 YOLO 模型:{YOLO_MODEL_NAME}") from failure
84
+
85
+
86
def iter_image_files(root: Path, extensions: Iterable[str]) -> Iterable[Path]:
    """Yield every file under *root* whose suffix is in *extensions*.

    Extension entries are stripped and matched case-insensitively;
    blank entries are ignored.
    """
    wanted = {ext.strip().lower() for ext in extensions if ext.strip()}
    return (
        path
        for path in root.rglob("*")
        if path.is_file() and path.suffix.lower() in wanted
    )
93
+
94
+
95
def has_face(model: YOLO, image_path: Path, confidence: float, verbose: bool = False) -> bool:
    """
    Run YOLO on the image and report whether any face box was detected.

    Any box in any result counts as a detection.  Inference errors are
    logged and treated as "no face" so one bad file never aborts a sweep.
    """
    try:
        results = model(image_path, conf=confidence, verbose=False)
    except Exception as exc:  # pragma: no cover
        config.logger.error("检测失败,跳过 %s:%s", image_path, exc)
        return False

    for result in results:
        # Defensive: some result objects may lack a .boxes attribute
        boxes = getattr(result, "boxes", None)
        if boxes is None:
            continue
        if len(boxes) > 0:
            if verbose:
                # Collect per-box class id and confidence for the log line
                faces = []
                for box in boxes:
                    cls_id = int(box.cls[0]) if getattr(box, "cls", None) is not None else -1
                    score = float(box.conf[0]) if getattr(box, "conf", None) is not None else 0.0
                    faces.append({"cls": cls_id, "conf": score})
                config.logger.info("检测到人脸:%s -> %s", image_path, faces)
            return True
    return False
119
+
120
+
121
def main() -> None:
    """Scan the dataset, run face detection, and delete faceless images."""
    args = parse_args()
    dataset_dir: Path = args.dataset_dir.expanduser().resolve()
    if not dataset_dir.exists():
        raise SystemExit(f"目录不存在:{dataset_dir}")

    model = load_yolo_model()
    image_paths = list(iter_image_files(dataset_dir, args.extensions.split(",")))
    total = len(image_paths)
    if total == 0:
        print(f"目录 {dataset_dir} 下没有匹配到图片文件")
        return

    # NOTE(review): in dry-run mode `removed` still counts would-be
    # deletions — the summary line reports planned removals, not actual ones.
    removed = 0
    errored = 0
    for idx, image_path in enumerate(image_paths, start=1):
        # Progress output: every 100th file, or every file in verbose mode
        if idx % 100 == 0 or args.verbose:
            print(f"[{idx}/{total}] 正在处理 {image_path}")

        try:
            if has_face(model, image_path, args.confidence, args.verbose):
                continue
        except Exception as exc:  # pragma: no cover
            errored += 1
            config.logger.error("检测过程中发生异常,跳过 %s:%s", image_path, exc)
            continue

        # No face found: delete (or just report in dry-run mode)
        if args.dry_run:
            print(f"[DRY-RUN] 将删除:{image_path}")
        else:
            try:
                image_path.unlink()
                print(f"已删除:{image_path}")
            except Exception as exc:  # pragma: no cover
                errored += 1
                config.logger.error("删除失败 %s:%s", image_path, exc)
                continue
        removed += 1

    print(
        f"扫描完成,检测图片 {total} 张,删除 {removed} 张无脸图片,异常 {errored} 张,数据保存在:{dataset_dir}"
    )


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:  # pragma: no cover
        sys.exit("用户中断")
test/remove_faceless_images.pyc ADDED
Binary file (5.17 kB). View file
 
test/test_deepface.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import time
from deepface import DeepFace

# Directory holding the sample face images
images_path = "/opt/data/face"

# ========== Face verification (1:1 similarity check) ==========
start_time = time.time()
result_verification = DeepFace.verify(
    img1_path=images_path + "/4.webp",
    img2_path=images_path + "/5.webp",
    model_name="ArcFace",  # embedding model
    detector_backend="yolov11n",  # face detector: retinaface / yolov8 / opencv / ssd / mediapipe
    distance_metric="cosine"  # similarity metric
)
end_time = time.time()
print(f"🕒 人脸比对耗时: {end_time - start_time:.3f} 秒")

# Pretty-print the verification result
print(json.dumps(result_verification, ensure_ascii=False, indent=2))


# ========== Face recognition (1:N search against a folder) ==========

start_time = time.time()
result_recognition = DeepFace.find(
    img_path=images_path + "/1.jpg",  # probe face
    db_path=images_path,  # database directory
    model_name="ArcFace",  # embedding model
    detector_backend="yolov11n",  # face detector
    distance_metric="cosine"  # similarity metric
)
end_time = time.time()
print(f"🕒 人脸识别耗时: {end_time - start_time:.3f} 秒")

# Uncomment to dump the search results as JSON
# df = result_recognition[0]
# print(df.to_json(orient="records", force_ascii=False))
+ # print(df.to_json(orient="records", force_ascii=False))
test/test_deepface.pyc ADDED
Binary file (769 Bytes). View file
 
test/test_main.http ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Test your FastAPI endpoints
2
+
3
+ GET http://127.0.0.1:8000/
4
+ Accept: application/json
5
+
6
+ ###
7
+
8
+ GET http://127.0.0.1:8000/hello/User
9
+ Accept: application/json
10
+
11
+ ###
test/test_rvm_infer.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import time

import cv2
import numpy as np
import torch
from torchvision import transforms

# Inference device (CPU-only in this demo)
device = "cpu"

# Input / output paths
input_path = "/opt/data/face/yang.webp"
output_path = "/opt/data/face/output_alpha.webp"

# Load the pretrained RobustVideoMatting model (resnet50 backbone)
model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50").to(device).eval()

# Start timing
start = time.time()

# Read the image and flip BGR -> RGB (.copy() makes strides contiguous)
img = cv2.imread(input_path)[:, :, ::-1].copy()
src = transforms.ToTensor()(img).unsqueeze(0).to(device)

# Run matting; rec is the recurrent state (None entries on the first frame)
rec = [None] * 4
with torch.no_grad():
    fgr, pha, *rec = model(src, *rec, downsample_ratio=0.25)

# Convert tensors to uint8 numpy arrays
fgr = (fgr[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)  # (H,W,3)
pha = (pha[0, 0].cpu().numpy() * 255).astype(np.uint8)  # (H,W)

# Stack foreground + alpha into RGBA
rgba = np.dstack((fgr, pha))  # (H,W,4)

# Save as WebP with alpha; reorder channels to BGRA for OpenCV
cv2.imwrite(output_path, rgba[:, :, [2,1,0,3]], [cv2.IMWRITE_WEBP_QUALITY, 100])

# Stop timing
elapsed = time.time() - start

# Console summary
print(f"✅ RVM 抠图完成 (透明背景)")
print(f" 输入文件: {input_path}")
print(f" 输出文件: {output_path}")
print(f" 耗时: {elapsed:.3f} 秒 (设备: {device})")
test/test_rvm_infer.pyc ADDED
Binary file (1.2 kB). View file
 
test/test_score.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+
4
+ import numpy as np
5
+ from retinaface import RetinaFace
6
+
7
+
8
def default_converter(o):
    """JSON ``default`` hook: convert NumPy values to plain Python.

    Arrays become lists, NumPy integers/floats become ``int``/``float``,
    and anything else is stringified as a last resort.
    """
    if isinstance(o, np.ndarray):
        return o.tolist()
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, np.floating):
        return float(o)
    return str(o)
16
+
17
+
18
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# NOTE(review): "~" is not expanded here — the detector likely receives
# the literal path; confirm, or wrap with os.path.expanduser.
resp = RetinaFace.detect_faces("~/Downloads/chounan.jpeg")

logger.info(
    "search results: " + json.dumps(resp, ensure_ascii=False, default=default_converter)
)
test/test_score.pyc ADDED
Binary file (701 Bytes). View file
 
test/test_score_adjustment_demo.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def adjust_score(score, threshold, gamma):
    """Compress scores below *threshold* toward it by factor *gamma*.

    Scores at or above the threshold pass through unchanged; adjusted
    values are clamped to [0.0, 10.0] and rounded to one decimal place.
    """
    if score >= threshold:
        return score
    lifted = threshold - gamma * (threshold - score)
    clamped = max(0.0, min(10.0, lifted))
    return round(clamped, 1)
7
+
8
# Default parameters (T=9.0, γ=0.5)
default_threshold = 9.0
default_gamma = 0.5

# Variant 1: T=9, γ=0.9 (the earlier comment claiming T=8.0, γ=0.5 was stale)
new_threshold_1 = 9
new_gamma_1 = 0.9

# Variant 2: T=9, γ=0.8 (the earlier comment claiming T=8.0, γ=0.3 was stale)
new_threshold_2 = 9
new_gamma_2 = 0.8

print(f"原始分\tT={default_threshold},y={default_gamma}\tT={new_threshold_1},γ={new_gamma_1}\tT={new_threshold_2},γ={new_gamma_2}")
print("-----\t----------\t----------\t----------")

# Sweep raw scores from 1.0 to 10.0 in 0.1 steps
for i in range(10, 101):
    score = i / 10.0
    default_adjusted = adjust_score(score, default_threshold, default_gamma)
    new_adjusted_1 = adjust_score(score, new_threshold_1, new_gamma_1)
    new_adjusted_2 = adjust_score(score, new_threshold_2, new_gamma_2)
    # Force one decimal place in the table output
    print(f"{score:.1f}\t\t\t{default_adjusted:.1f}\t\t\t\t\t{new_adjusted_1:.1f}\t\t\t\t\t{new_adjusted_2:.1f}")
test/test_score_adjustment_demo.pyc ADDED
Binary file (910 Bytes). View file
 
test/test_sky.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os.path as osp

import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Sky-replacement demo: paste the sky from sky_image into scene_image.
image_skychange = pipeline(Tasks.image_skychange,
                           model='iic/cv_hrnetocr_skychange')

# "~" is not expanded by cv2; resolve user paths explicitly up front.
sky_path = osp.expanduser('~/Downloads/sky_image.jpg')
scene_path = '/opt/data/face/NXEo0zusSaNB2fa232c84898e92ff165e2dfee59cb54.jpg'
output_path = osp.expanduser('~/Downloads/result.png')

result = image_skychange(
    {'sky_image': sky_path,
     'scene_image': scene_path})
cv2.imwrite(output_path,
            result[OutputKeys.OUTPUT_IMG])
# Report the path actually written (the old message pointed at the CWD).
print(f'Output written to {osp.abspath(output_path)}')
test/test_sky.pyc ADDED
Binary file (683 Bytes). View file
 
test_tensorflow.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import deepface
import sys
import tensorflow as tf

# Prefer the standalone keras package; fall back to the bundled tf.keras.
try:
    import keras

    keras_pkg, keras_ver = "keras (standalone)", keras.__version__
except Exception:
    from tensorflow import keras

    keras_pkg, keras_ver = "tf.keras", keras.__version__

# Environment summary for debugging version mismatches.
print("py =", sys.version)
print("deepface =", deepface.__version__)
print("tensorflow =", tf.__version__)
print("keras pkg =", keras_pkg, "keras =", keras_ver)
+ print("keras pkg =", keras_pkg, "keras =", keras_ver)