chawin.chen commited on
Commit
cd5aabe
·
1 Parent(s): 097061b
.gitignore ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ HELP.md
2
+ target/
3
+ output/
4
+ !.mvn/wrapper/maven-wrapper.jar
5
+ !**/src/main/**/target/
6
+ !**/src/test/**/target/
7
+
8
+ .flattened-pom.xml
9
+
10
+ ### STS ###
11
+ .apt_generated
12
+ .classpath
13
+ .factorypath
14
+ .project
15
+ .settings
16
+ .springBeans
17
+ .sts4-cache
18
+
19
+ ### IntelliJ IDEA ###
20
+ .idea
21
+ *.iws
22
+ *.iml
23
+ *.ipr
24
+
25
+ ### NetBeans ###
26
+ /nbproject/private/
27
+ /nbbuild/
28
+ /dist/
29
+ /nbdist/
30
+ /.nb-gradle/
31
+ build/
32
+ !**/src/main/**/build/
33
+ !**/src/test/**/build/
34
+
35
+ ### VS Code ###
36
+ .vscode/
37
+
38
+ ### LOG ###
39
+ logs/
40
+
41
+ *.class
42
+
43
+ **/node_modules/
44
+ /*.log
45
+ /output/
46
+ /faiss/
47
+ /web/facelist-web/
48
+ **/._*
49
+ __pycache__/
50
+ .DS_Store
51
+ *.pth
52
+ /data/celebrity_faces/ds_model_arcface_detector_retinaface_aligned_normalization_base_expand_0.pkl
53
+ /data/celebrity_faces/jpeg_6c06eca6.jpeg
54
+ /data/celebrity_faces/jpeg_51e1394b.jpeg
55
+ /data/celebrity_faces/jpeg_66fee390.jpeg
56
+ /data/celebrity_faces/jpeg_70b86102.jpeg
57
+ /data/celebrity_faces/jpeg_406b961a.jpeg
58
+ /data/celebrity_faces/jpeg_1321f87f.jpeg
59
+ /data/celebrity_faces/jpeg_b56ae384.jpeg
60
+ /data/celebrity_faces/jpeg_c07cdb46.jpeg
61
+ /data/celebrity_faces/jpeg_c7353005.jpeg
62
+ /data/celebrity_faces/jpeg_d4cb0602.jpeg
63
+ /data/celebrity_faces/jpeg_dbb64030.jpeg
64
+ /data/celebrity_faces/jpeg_fc652ad4.jpeg
65
+ /data/celebrity_faces/jpeg_fd6b0869.jpeg
66
+ /data/celebrity_embeddings.db
Dockerfile ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

ENV TZ=Asia/Shanghai \
    OUTPUT_DIR=/opt/output \
    IMAGES_DIR=/opt/images \
    MODELS_PATH=/opt/models \
    DEEPFACE_HOME=/opt/models \
    FAISS_INDEX_DIR=/opt/faiss \
    CELEBRITY_SOURCE_DIR=/opt/chinese_celeb_dataset \
    GENDER_CONFIDENCE=1 \
    UPSCALE_SIZE=2 \
    AGE_CONFIDENCE=0.1 \
    DRAW_SCORE=true \
    FACE_CONFIDENCE=0.7 \
    ENABLE_DDCOLOR=false \
    ENABLE_GFPGAN=false \
    ENABLE_REALESRGAN=false \
    ENABLE_ANIME_STYLE=false \
    ENABLE_RVM=false \
    ENABLE_REMBG=false \
    ENABLE_CLIP=false \
    CLEANUP_INTERVAL_HOURS=1 \
    CLEANUP_AGE_HOURS=1 \
    BEAUTY_ADJUST_GAMMA=0.8 \
    BEAUTY_ADJUST_MIN=1.0 \
    BEAUTY_ADJUST_MAX=9.0 \
    ENABLE_ANIME_PRELOAD=false \
    ENABLE_LOGGING=true \
    BEAUTY_ADJUST_ENABLED=true \
    RVM_LOCAL_REPO=/app/RobustVideoMatting \
    RVM_WEIGHTS_PATH=/opt/models/torch/hub/checkpoints/rvm_resnet50.pth \
    RVM_MODEL=resnet50 \
    AUTO_INIT_GFPGAN=false \
    AUTO_INIT_DDCOLOR=false \
    AUTO_INIT_REALESRGAN=false \
    AUTO_INIT_ANIME_STYLE=false \
    AUTO_INIT_CLIP=false \
    AUTO_INIT_RVM=false \
    AUTO_INIT_REMBG=false \
    ENABLE_WARMUP=true \
    REALESRGAN_MODEL=realesr-general-x4v3 \
    CELEBRITY_FIND_THRESHOLD=0.87 \
    FEMALE_AGE_ADJUSTMENT=4 \
    HOSTNAME=HG

RUN mkdir -p /opt/chinese_celeb_dataset /opt/faiss /opt/models /opt/images /opt/output
WORKDIR /app

# Install system build/runtime dependencies first: this heavy layer is cached
# independently of any application-code change.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    cmake \
    git \
    wget \
    curl \
    ca-certificates \
    libopenblas-dev \
    liblapack-dev \
    libx11-dev \
    libgtk-3-dev \
    libboost-python-dev \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip
# Copy ONLY the requirements manifest before installing, so editing *.py files
# does not invalidate the expensive pip layer (original copied all code first).
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Application code last: changes here rebuild only this cheap layer.
COPY *.py /app/

EXPOSE 7860
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860","--no-access-log","--log-level","critical"]
anime_stylizer.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import time
4
+
5
+ import cv2
6
+
7
+ from config import logger
8
+
9
+
10
class AnimeStylizer:
    """Portrait anime stylization backed by ModelScope pipelines.

    Model configurations live in ``self.model_configs``; the actual pipelines
    are loaded lazily into ``self.stylizers`` on first use (or eagerly when
    ENABLE_ANIME_PRELOAD is set) to keep memory usage down.
    """

    def __init__(self):
        """Create containers and register model configs when the feature is enabled."""
        start_time = time.perf_counter()
        self.stylizers = {}  # style_type -> loaded ModelScope pipeline
        self.current_style = None
        self.current_stylizer = None

        # Feature flag: skip all ModelScope work when disabled.
        from config import ENABLE_ANIME_STYLE
        if ENABLE_ANIME_STYLE:
            self._initialize_models()
        else:
            logger.info("Anime Style feature is disabled, skipping model initialization")
        init_time = time.perf_counter() - start_time
        if hasattr(self, 'model_configs') and len(self.model_configs) > 0:
            logger.info(f"AnimeStylizer initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"AnimeStylizer initialization completed but not available, time: {init_time:.3f}s")

    def _initialize_models(self):
        """Register all anime-style model configurations (using ModelScope)."""
        try:
            logger.info("Initializing multiple Anime Style models (using ModelScope)...")

            # Compatibility patch: older torch builds lack the unsigned dtypes
            # that some ModelScope checkpoints reference.
            import torch
            if not hasattr(torch, 'uint64'):
                logger.info("Adding torch.uint64 compatibility patch...")
                torch.uint64 = torch.int64  # substitute int64 for uint64
            if not hasattr(torch, 'uint32'):
                logger.info("Adding torch.uint32 compatibility patch...")
                torch.uint32 = torch.int32  # substitute int32 for uint32
            if not hasattr(torch, 'uint16'):
                logger.info("Adding torch.uint16 compatibility patch...")
                torch.uint16 = torch.int16  # substitute int16 for uint16

            # Import ModelScope lazily so a missing dependency only disables
            # this feature instead of breaking module import.
            from modelscope.outputs import OutputKeys
            from modelscope.pipelines import pipeline
            from modelscope.utils.constant import Tasks

            self.OutputKeys = OutputKeys

            # All supported styles and their ModelScope model ids.
            self.model_configs = {
                "handdrawn": {
                    "model_id": "iic/cv_unet_person-image-cartoon-handdrawn_compound-models",
                    "name": "手绘风格",
                    "description": "手绘动漫风格 - 传统手绘感觉,线条清晰"
                },
                "disney": {
                    "model_id": "iic/cv_unet_person-image-cartoon-3d_compound-models",
                    "name": "迪士尼风格",
                    "description": "迪士尼风格 - 立体感强,色彩鲜艳"
                },
                "illustration": {
                    "model_id": "iic/cv_unet_person-image-cartoon-sd-design_compound-models",
                    "name": "插画风格",
                    "description": "插画风格 - 现代插画设计感"
                },
                "artstyle": {
                    "model_id": "iic/cv_unet_person-image-cartoon-artstyle_compound-models",
                    "name": "艺术风格",
                    "description": "艺术风格 - 独特的艺术表现力"
                },
                "anime": {
                    "model_id": "iic/cv_unet_person-image-cartoon_compound-models",
                    "name": "二次元风格",
                    "description": "二次元风格 - 经典动漫角色风格"
                },
                "sketch": {
                    "model_id": "iic/cv_unet_person-image-cartoon-sketch_compound-models",
                    "name": "素描风格",
                    "description": "素描风格 - 黑白素描画效果"
                }
            }

            logger.info(f"Defined {len(self.model_configs)} anime style model configurations")
            logger.info("Models will be loaded on-demand when first used to save memory")

            # Optional eager preload, controlled by config.
            try:
                from config import ENABLE_ANIME_PRELOAD
                if ENABLE_ANIME_PRELOAD:
                    logger.info("Enabling anime style model preloading...")
                    self.preload_models()
                else:
                    logger.info("Anime style model preloading is disabled, will be loaded on-demand when first used")
            except ImportError:
                logger.info("Anime style model preloading configuration not found, will be loaded on-demand when first used")

        except ImportError as e:
            logger.error(f"ModelScope module import failed: {e}")
            self.model_configs = {}
        except Exception as e:
            logger.error(f"Anime Style model initialization failed: {e}")
            self.model_configs = {}

    def _load_model(self, style_type):
        """Load the pipeline for *style_type* on demand; returns True on success."""
        if style_type not in self.model_configs:
            logger.error(f"Unsupported style type: {style_type}")
            return False

        # Already cached: nothing to do.
        if style_type in self.stylizers:
            logger.info(f"Model {style_type} already loaded, using directly")
            return True

        try:
            from modelscope.pipelines import pipeline
            from modelscope.utils.constant import Tasks

            config = self.model_configs[style_type]
            logger.info(f"Loading {config['name']} model: {config['model_id']}")

            # Pick the ModelScope task type matching the model family.
            if "stable_diffusion" in config["model_id"]:
                # Stable Diffusion models use the text-to-image task.
                task_type = Tasks.text_to_image_synthesis
                logger.info(f"Using text_to_image_synthesis task type to load Stable Diffusion model")
            else:
                # UNet models use the portrait-stylization task.
                task_type = Tasks.image_portrait_stylization
                logger.info(f"Using image_portrait_stylization task type to load UNet model")

            stylizer = pipeline(task_type, model=config["model_id"])
            self.stylizers[style_type] = stylizer

            logger.info(f"{config['name']} model loaded successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to load {style_type} model: {e}")
            return False

    def preload_models(self, style_types=None):
        """Eagerly load the given styles.

        :param style_types: list of style keys, a single key, or None for all.
        """
        if not self.is_available():
            logger.warning("Anime Style module is not available, cannot preload models")
            return

        if style_types is None:
            style_types = list(self.model_configs.keys())
        elif isinstance(style_types, str):
            style_types = [style_types]

        logger.info(f"Starting to preload anime style models: {style_types}")

        successful_loads = []
        failed_loads = []

        for style_type in style_types:
            if style_type not in self.model_configs:
                logger.warning(f"Unknown style type: {style_type}, skipping preload")
                failed_loads.append(style_type)
                continue

            try:
                logger.info(f"Preloading model: {self.model_configs[style_type]['name']} ({style_type})")
                if self._load_model(style_type):
                    successful_loads.append(style_type)
                    logger.info(f"✓ Successfully preloaded: {self.model_configs[style_type]['name']}")
                else:
                    failed_loads.append(style_type)
                    logger.error(f"✗ Preload failed: {self.model_configs[style_type]['name']}")
            except Exception as e:
                logger.error(f"✗ Exception occurred while preloading model {style_type}: {e}")
                failed_loads.append(style_type)

        if successful_loads:
            logger.info(f"Successfully preloaded models ({len(successful_loads)}): {successful_loads}")
        if failed_loads:
            logger.warning(f"Failed to preload models ({len(failed_loads)}): {failed_loads}")

        logger.info(f"Anime style model preloading completed, success: {len(successful_loads)}/{len(style_types)}")

    def get_loaded_models(self):
        """Return the list of style keys whose pipelines are loaded."""
        return list(self.stylizers.keys())

    def is_model_loaded(self, style_type):
        """Return True if the pipeline for *style_type* is already loaded."""
        return style_type in self.stylizers

    def get_preload_status(self):
        """Return a dict summarizing how many models are loaded vs. configured."""
        total_models = len(self.model_configs)
        loaded_models = len(self.stylizers)

        status = {
            "total_models": total_models,
            "loaded_models": loaded_models,
            "preload_ratio": f"{loaded_models}/{total_models}",
            "preload_percentage": round((loaded_models / total_models * 100) if total_models > 0 else 0, 1),
            "available_styles": list(self.model_configs.keys()),
            "loaded_styles": list(self.stylizers.keys()),
            "unloaded_styles": [style for style in self.model_configs.keys() if style not in self.stylizers]
        }

        return status

    def is_available(self):
        """Return True when model configs were registered successfully."""
        return hasattr(self, 'model_configs') and len(self.model_configs) > 0

    def stylize_image(self, image, style_type="disney"):
        """Apply anime stylization to an image.

        :param image: input image (numpy array, BGR)
        :param style_type: one of the keys in ``self.model_configs``
                           (e.g. "handdrawn", "disney", "illustration",
                           "artstyle", "anime", "sketch"); defaults to "disney"
        :return: stylized image (numpy array, BGR); the original image is
                 returned unchanged when the model is unavailable.
        """
        if not self.is_available():
            logger.error("Anime Style model not initialized")
            return image

        # Lazily load the requested style's pipeline.
        if not self._load_model(style_type):
            logger.error(f"Failed to load {style_type} model")
            return image

        return self._stylize_image_via_file(image, style_type)

    def _stylize_image_via_file(self, image, style_type="disney"):
        """Run stylization through a temporary image file.

        :param image: input image (numpy array, BGR)
        :param style_type: anime style key
        :return: stylized image (numpy array, BGR), or the input on failure.
        """
        try:
            # Validate the style BEFORE reading its config so an unknown style
            # falls back to "disney" instead of crashing later on
            # config["model_id"] (original fetched config first).
            if style_type not in self.model_configs:
                logger.warning(f"Invalid style type: {style_type}, using default style disney")
                style_type = "disney"

            config = self.model_configs.get(style_type, {})
            style_name = config.get('name', style_type)
            logger.info(f"Using anime stylization processing, style type: {style_name} ({style_type})")

            # Persist the input with maximum-quality WebP for the pipeline.
            with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as tmp_input:
                cv2.imwrite(tmp_input.name, image, [cv2.IMWRITE_WEBP_QUALITY, 100])
                tmp_input_path = tmp_input.name

            try:
                logger.info(f"Temporary file saved to: {tmp_input_path}")

                stylizer = self.stylizers[style_type]

                # Model families take different call signatures.
                if "stable_diffusion" in config["model_id"]:
                    # Stable Diffusion requires a text prompt ('sks style' format).
                    logger.info("Using Stable Diffusion model, text parameter is required")
                    style_prompts = {}
                    prompt = style_prompts.get(style_type, "sks style, cartoon style artwork")
                    logger.info(f"Using prompt: {prompt}")
                    result = stylizer({"text": prompt})
                else:
                    # UNet models take the image path directly.
                    result = stylizer(tmp_input_path)

                # Output key names differ between model families.
                if "stable_diffusion" in config["model_id"]:
                    logger.info(f"Stable Diffusion model output keys: {list(result.keys())}")
                    if 'output_imgs' in result:
                        stylized_image = result['output_imgs'][0]
                    elif 'output_img' in result:
                        stylized_image = result['output_img']
                    elif self.OutputKeys.OUTPUT_IMG in result:
                        stylized_image = result[self.OutputKeys.OUTPUT_IMG]
                    else:
                        # Fall back to the first array-like output.
                        for key in result.keys():
                            if isinstance(result[key], (list, tuple)) and len(result[key]) > 0:
                                stylized_image = result[key][0]
                                logger.info(f"Using output key: {key}")
                                break
                            elif hasattr(result[key], 'shape'):
                                stylized_image = result[key]
                                logger.info(f"Using output key: {key}")
                                break
                        else:
                            raise KeyError(f"未找到有效的图像输出键,可用键: {list(result.keys())}")
                else:
                    # UNet models use the standard output key.
                    stylized_image = result[self.OutputKeys.OUTPUT_IMG]

                logger.info(f"Anime stylization output: size={stylized_image.shape}, type={stylized_image.dtype}")

                # ModelScope output is already BGR; no conversion needed.
                logger.info("Anime stylization processing completed")
                return stylized_image

            finally:
                # Best-effort temp-file cleanup (narrowed from a bare except).
                try:
                    os.unlink(tmp_input_path)
                except OSError:
                    pass

        except Exception as e:
            logger.error(f"Anime stylization processing failed: {e}")
            logger.info("Returning original image")
            return image

    def get_available_styles(self):
        """Return {style_key: "<name> - <short description>"} for all styles."""
        if not hasattr(self, 'model_configs'):
            return {}

        return {
            style_type: f"{config['name']} - {config['description'].split(' - ')[1]}"
            for style_type, config in self.model_configs.items()
        }

    def save_debug_image(self, image, filename_prefix):
        """Save *image* as "<prefix>_debug.webp"; returns the path or None."""
        try:
            debug_path = f"{filename_prefix}_debug.webp"
            cv2.imwrite(debug_path, image, [cv2.IMWRITE_WEBP_QUALITY, 95])
            logger.info(f"Debug image saved: {debug_path}")
            return debug_path
        except Exception as e:
            logger.error(f"Failed to save debug image: {e}")
            return None

    def test_stylization(self, test_url=None):
        """Smoke-test stylization against an image URL.

        :param test_url: test image URL; defaults to the official sample.
        :return: (success, message)
        """
        if not self.is_available():
            return False, "Anime Style模型未初始化"

        try:
            test_url = test_url or 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/portrait.jpg'
            logger.info(f"Testing anime stylization feature, using image: {test_url}")

            # Bug fix: the original referenced self.stylizer, which is never
            # created. Use the normal lazy-loading path with the default style.
            if not self._load_model("disney"):
                return False, "默认模型加载失败"
            result = self.stylizers["disney"](test_url)
            stylized_img = result[self.OutputKeys.OUTPUT_IMG]

            # Save the test output for inspection.
            test_output_path = 'anime_style_test_result.webp'
            cv2.imwrite(test_output_path, stylized_img, [cv2.IMWRITE_WEBP_QUALITY, 95])

            logger.info(f"Anime stylization test successful, result saved to: {test_output_path}")
            return True, f"测试成功,结果保存到: {test_output_path}"

        except Exception as e:
            logger.error(f"Anime stylization test failed: {e}")
            return False, f"测试失败: {e}"

    def test_local_image(self, image_path, style_type="disney"):
        """Smoke-test stylization on a local image file.

        :param image_path: local image path
        :param style_type: anime style key
        :return: (success, message)
        """
        if not self.is_available():
            return False, "Anime Style模型未初始化"

        try:
            logger.info(f"Testing local image anime stylization: {image_path}, style: {style_type}")

            image = cv2.imread(image_path)
            if image is None:
                return False, f"Unable to read image: {image_path}"

            # Keep the original for comparison.
            self.save_debug_image(image, "original")

            stylized_image = self.stylize_image(image, style_type)

            result_path = self.save_debug_image(stylized_image, f"anime_style_{style_type}")

            logger.info(f"Local image anime stylization successful, result saved to: {result_path}")
            return True, f"本地图像动漫风格化成功,结果保存到: {result_path}"

        except Exception as e:
            logger.error(f"Local image anime stylization failed: {e}")
            return False, f"本地图像动漫风格化失败: {e}"
api_routes.py ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import time
from contextlib import asynccontextmanager

from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

from cleanup_scheduler import start_cleanup_scheduler, stop_cleanup_scheduler
from config import (
    logger,
    OUTPUT_DIR,
    DEEPFACE_AVAILABLE,
    DLIB_AVAILABLE,
    MODELS_PATH,
    IMAGES_DIR,
    YOLO_AVAILABLE,
    ENABLE_LOGGING,
)
from database import close_mysql_pool, init_mysql_pool

logger.info("Starting to import api_routes module...")
# Bind the start timestamp BEFORE the try so the except handler can never hit
# an unbound t_start (original assigned it inside the try body).
t_start = time.perf_counter()
try:
    from api_routes import api_router
    import_time = time.perf_counter() - t_start
    logger.info(f"api_routes module imported successfully, time: {import_time:.3f}s")
except Exception as e:
    import_time = time.perf_counter() - t_start
    logger.error(f"api_routes module import failed, time: {import_time:.3f}s, error: {e}")
    raise


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: init DB pool and cleanup job on startup, tear both down on shutdown."""
    start_time = time.perf_counter()
    logger.info("FaceScore service starting...")
    logger.info(f"Output directory: {OUTPUT_DIR}")
    logger.info(f"DeepFace available: {DEEPFACE_AVAILABLE}")
    logger.info(f"YOLO available: {YOLO_AVAILABLE}")
    logger.info(f"MediaPipe available: {DLIB_AVAILABLE}")
    logger.debug(f"Archive directory: {IMAGES_DIR}")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Initialize the database connection pool; a failure here is fatal.
    try:
        await init_mysql_pool()
        logger.info("MySQL 连接池初始化完成")
    except Exception as exc:
        logger.error(f"初始化 MySQL 连接池失败: {exc}")
        raise

    # Start the periodic image cleanup task (non-fatal on failure).
    logger.info("Starting image cleanup scheduled task...")
    try:
        start_cleanup_scheduler()
        logger.info("Image cleanup scheduled task started successfully")
    except Exception as e:
        logger.error(f"Failed to start image cleanup scheduled task: {e}")

    total_startup_time = time.perf_counter() - start_time
    logger.info(f"FaceScore service startup completed, total time: {total_startup_time:.3f}s")

    yield

    # Shutdown path: stop the scheduler, then close the DB pool.
    logger.info("Stopping image cleanup scheduled task...")
    try:
        stop_cleanup_scheduler()
        logger.info("Image cleanup scheduled task stopped")
    except Exception as e:
        logger.error(f"Failed to stop image cleanup scheduled task: {e}")

    try:
        await close_mysql_pool()
    except Exception as exc:
        logger.warning(f"关闭 MySQL 连接池失败: {exc}")


# Create the FastAPI application.
app = FastAPI(
    title="Enhanced FaceScore 服务",
    description="支持多模型的人脸分析REST API服务,包含五官评分功能。支持混合模式:HowCuteAmI(颜值+性别)+ DeepFace(年龄+情绪)",
    version="3.0.0",
    docs_url="/cp_docs",
    redoc_url="/cp_redoc",
    lifespan=lifespan,
)

# NOTE(review): wildcard CORS with all methods/headers — confirm this is
# intended for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register the API routes.
app.include_router(api_router)


# Root path doubles as a liveness probe.
@app.get("/")
async def root():
    return "UP"


if __name__ == "__main__":
    import uvicorn

    # Refuse to start without the model directory — FaceAnalyzer needs it.
    if not os.path.exists(MODELS_PATH):
        logger.critical(
            "Warning: 'models' directory not found. Please ensure it exists and contains model files."
        )
        logger.critical(
            "Exiting application as FaceAnalyzer cannot be initialized without models."
        )
        exit(1)

    # Configure Uvicorn logging according to the logging switch.
    if ENABLE_LOGGING:
        uvicorn.run(app, host="0.0.0.0", port=8080, reload=False)
    else:
        # Silence access logs and keep only critical errors.
        uvicorn.run(
            app,
            host="0.0.0.0",
            port=8080,
            reload=False,
            access_log=False,
            log_level="critical"
        )
cleanup_scheduler.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 定时清理图片文件模块
3
+ 每小时检查一次IMAGES_DIR目录,删除1小时以前的图片文件
4
+ """
5
+ import glob
6
+ import os
7
+ import time
8
+ from datetime import datetime
9
+
10
+ from apscheduler.schedulers.background import BackgroundScheduler
11
+
12
+ from config import logger, IMAGES_DIR, CLEANUP_INTERVAL_HOURS, CLEANUP_AGE_HOURS
13
+
14
+
15
+ # from utils import delete_file_from_bos # 暂时注释掉删除BOS文件的功能
16
+
17
+
18
class ImageCleanupScheduler:
    """Periodic cleanup of stale image files in IMAGES_DIR."""

    def __init__(self, images_dir=None, cleanup_hours=None, interval_hours=None):
        """Initialize the cleanup scheduler.

        Args:
            images_dir (str): image directory; defaults to config IMAGES_DIR.
            cleanup_hours (float): age threshold in hours; defaults to CLEANUP_AGE_HOURS.
            interval_hours (float): run interval in hours; defaults to CLEANUP_INTERVAL_HOURS.
        """
        self.images_dir = images_dir or IMAGES_DIR
        self.cleanup_hours = cleanup_hours if cleanup_hours is not None else CLEANUP_AGE_HOURS
        self.interval_hours = interval_hours if interval_hours is not None else CLEANUP_INTERVAL_HOURS
        self.scheduler = BackgroundScheduler()
        self.is_running = False

        # Make sure the monitored directory exists.
        os.makedirs(self.images_dir, exist_ok=True)
        logger.info(f"Image cleanup scheduler initialized, monitoring directory: {self.images_dir}, cleanup threshold: {self.cleanup_hours} hours, execution interval: {self.interval_hours} hours")

    def cleanup_old_images(self):
        """Delete image files older than the configured threshold.

        Returns a summary dict with success flag, deleted count/size/names
        and the cutoff time (ISO format).
        """
        try:
            current_time = time.time()
            cutoff_time = current_time - (self.cleanup_hours * 3600)  # hours -> seconds
            cutoff_datetime = datetime.fromtimestamp(cutoff_time)

            # Supported image extensions.
            image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.gif', '*.bmp']
            deleted_files = []
            total_size_deleted = 0

            logger.info(f"Starting to clean image directory: {self.images_dir}")
            logger.info(f"Cleanup threshold time: {cutoff_datetime.strftime('%Y-%m-%d %H:%M:%S')}")

            # Scan every supported extension in the directory.
            for extension in image_extensions:
                pattern = os.path.join(self.images_dir, extension)
                for file_path in glob.glob(pattern):
                    try:
                        file_mtime = os.path.getmtime(file_path)

                        # Delete files whose mtime is before the cutoff.
                        if file_mtime < cutoff_time:
                            file_size = os.path.getsize(file_path)
                            file_time = datetime.fromtimestamp(file_mtime)

                            os.remove(file_path)
                            # delete_file_from_bos(file_path)  # BOS deletion temporarily disabled
                            deleted_files.append(os.path.basename(file_path))
                            total_size_deleted += file_size

                            logger.debug(f"Deleting expired file: {os.path.basename(file_path)} ")

                    except (OSError, IOError) as e:
                        logger.error(f"Failed to delete file {os.path.basename(file_path)}: {e}")
                        continue

            # Bug fix: the original had an orphan `else:` here (SyntaxError);
            # the summary branch needs the `if deleted_files:` guard.
            if deleted_files:
                logger.info(f"Cleanup completed! Deleted {len(deleted_files)} files, ")
                logger.debug(f"Deleted file list: {', '.join(deleted_files[:10])}")
            else:
                logger.info("Cleanup completed! No expired files found to clean")

            return {
                'success': True,
                'deleted_count': len(deleted_files),
                'deleted_size': total_size_deleted,
                'deleted_files': deleted_files,
                'cutoff_time': cutoff_datetime.isoformat()
            }

        except Exception as e:
            error_msg = f"图片清理任务执行失败: {e}"
            logger.error(error_msg)
            return {
                'success': False,
                'error': str(e),
                'deleted_count': 0,
                'deleted_size': 0
            }

    def _format_size(self, size_bytes):
        """Format a byte count for human-readable display (B/KB/MB/GB)."""
        if size_bytes == 0:
            return "0 B"
        size_names = ["B", "KB", "MB", "GB"]
        i = 0
        while size_bytes >= 1024 and i < len(size_names) - 1:
            size_bytes /= 1024.0
            i += 1
        return f"{size_bytes:.1f} {size_names[i]}"

    def start(self):
        """Start the periodic cleanup job (idempotent)."""
        if self.is_running:
            logger.warning("Image cleanup scheduler is already running")
            return

        try:
            # Register the recurring job at the configured interval.
            self.scheduler.add_job(
                func=self.cleanup_old_images,
                trigger='interval',
                hours=self.interval_hours,  # interval from env configuration
                id='image_cleanup',
                name='image clean task',
                replace_existing=True
            )

            self.scheduler.start()
            self.is_running = True

            logger.info(f"Image cleanup scheduler started, will execute cleanup task every {self.interval_hours} hours")

            # Run one cleanup pass immediately on startup.
            logger.info("Executing image cleanup task immediately...")
            self.cleanup_old_images()

        except Exception as e:
            logger.error(f"Failed to start image cleanup scheduler: {e}")
            raise

    def stop(self):
        """Stop the periodic cleanup job (idempotent)."""
        if not self.is_running:
            logger.warning("Image cleanup scheduler is not running")
            return

        try:
            self.scheduler.shutdown(wait=False)
            self.is_running = False
            logger.info("Image cleanup scheduler stopped")
        except Exception as e:
            logger.error(f"Failed to stop image cleanup scheduler: {e}")

    def get_status(self):
        """Return the scheduler's current configuration and next run time."""
        return {
            'running': self.is_running,
            'images_dir': self.images_dir,
            'cleanup_hours': self.cleanup_hours,
            'interval_hours': self.interval_hours,
            'next_run': self.scheduler.get_jobs()[0].next_run_time.isoformat()
            if self.is_running and self.scheduler.get_jobs() else None
        }
171
+
172
+
173
# Module-level singleton driving the convenience functions below.
cleanup_scheduler = ImageCleanupScheduler()


def start_cleanup_scheduler():
    """Start the shared image cleanup scheduler."""
    cleanup_scheduler.start()


def stop_cleanup_scheduler():
    """Stop the shared image cleanup scheduler."""
    cleanup_scheduler.stop()


def get_cleanup_status():
    """Return the shared scheduler's current status dict."""
    return cleanup_scheduler.get_status()


def manual_cleanup():
    """Run a single cleanup pass immediately and return its summary."""
    return cleanup_scheduler.cleanup_old_images()


if __name__ == "__main__":
    # Smoke test: perform one cleanup pass and print the outcome.
    print("测试图片清理功能...")
    test_scheduler = ImageCleanupScheduler()
    result = test_scheduler.cleanup_old_images()
    print(f"清理结果: {result}")
clip_utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# clip_utils.py
import logging
import os
from typing import List, Union

import cn_clip.clip as clip
import torch
from PIL import Image
from cn_clip.clip import load_from_name

from config import MODELS_PATH

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model name is configurable via the environment.
MODEL_NAME_CN = os.environ.get('MODEL_NAME_CN', 'ViT-B-16')

# Prefer GPU when one is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lazily-populated module state; set by init_clip_model().
model = None
preprocess = None

def init_clip_model():
    """Load the CN-CLIP model into module state; return True on success."""
    global model, preprocess
    try:
        model, preprocess = load_from_name(MODEL_NAME_CN, device=device, download_root=MODELS_PATH)
        model.eval()
        logger.info(f"CLIP model initialized successfully, dimension: {model.visual.output_dim}")
    except Exception as e:
        logger.error(f"CLIP model initialization failed: {e}")
        return False
    return True

def is_clip_available():
    """Report whether both the model and its preprocessor are loaded."""
    return not (model is None or preprocess is None)

def encode_image(image_path: str) -> torch.Tensor:
    """Encode the image at *image_path* into an L2-normalized feature vector."""
    if not is_clip_available():
        raise RuntimeError("CLIP模型未初始化")

    pil_img = Image.open(image_path).convert("RGB")
    batch = preprocess(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        vec = model.encode_image(batch)
        vec = vec / vec.norm(p=2, dim=-1, keepdim=True)
    return vec.cpu()

def encode_text(text: Union[str, List[str]]) -> torch.Tensor:
    """Encode one or more texts into L2-normalized feature vectors."""
    if not is_clip_available():
        raise RuntimeError("CLIP模型未初始化")

    if isinstance(text, str):
        batch = [text]
    else:
        batch = text
    tokens = clip.tokenize(batch).to(device)
    with torch.no_grad():
        vec = model.encode_text(tokens)
        vec = vec / vec.norm(p=2, dim=-1, keepdim=True)
    return vec.cpu()
config.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
# Work around duplicate-OpenMP-runtime crashes (common when multiple
# native libs each bundle libomp).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Cap math-library thread pools at the core count to improve CPU utilization
# without oversubscription.
import multiprocessing
cpu_cores = multiprocessing.cpu_count()
os.environ["OMP_NUM_THREADS"] = str(min(cpu_cores, 8))  # at most 8 threads
os.environ["MKL_NUM_THREADS"] = str(min(cpu_cores, 8))
os.environ["NUMEXPR_NUM_THREADS"] = str(min(cpu_cores, 8))
12
+
13
# Compatibility shim: newer torchvision releases removed
# torchvision.transforms.functional_tensor, which some dependencies still
# import. If it is missing, synthesize a stand-in module that forwards the
# commonly used functions from torchvision.transforms.functional.
try:
    import torchvision.transforms.functional_tensor
except ImportError:
    import sys
    from types import ModuleType

    import torchvision.transforms as transforms
    import torchvision.transforms.functional as F

    functional_tensor = ModuleType('torchvision.transforms.functional_tensor')

    # Forward each function only when the installed torchvision provides it
    # (same guard the original hand-written hasattr chain performed).
    for _fn_name in (
        'rgb_to_grayscale',
        'adjust_brightness',
        'adjust_contrast',
        'adjust_saturation',
        'normalize',
        'resize',
        'crop',
        'pad',
    ):
        if hasattr(F, _fn_name):
            setattr(functional_tensor, _fn_name, getattr(F, _fn_name))

    # Register the shim so both `import torchvision.transforms.functional_tensor`
    # and attribute access via transforms.functional_tensor succeed.
    sys.modules['torchvision.transforms.functional_tensor'] = functional_tensor
    transforms.functional_tensor = functional_tensor
47
+
48
# Environment configuration - disable TensorFlow oneDNN optimizations,
# silence TF logging, and force CPU-only execution.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # force CPU
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "false"
53
+
54
# PyTorch compatibility patches: monkey-patch missing attributes that
# GFPGAN / ModelScope / pyarrow expect on older or newer library versions.
# Everything here is best-effort; any failure is reported and ignored.
try:
    import torch
    import torch.onnx

    # GFPGAN ONNX compatibility: provide a dummy ExportOptions when absent.
    if not hasattr(torch.onnx._internal.exporter, 'ExportOptions'):
        from types import SimpleNamespace
        torch.onnx._internal.exporter.ExportOptions = SimpleNamespace

    # ModelScope PyTree compatibility - provide a minimal _pytree module
    # with flatten/unflatten/map fallbacks when torch does not ship one.
    import torch.utils
    if not hasattr(torch.utils, '_pytree'):
        # Create the _pytree module if it does not exist.
        from types import ModuleType
        torch.utils._pytree = ModuleType('_pytree')

    pytree = torch.utils._pytree

    if not hasattr(pytree, 'register_pytree_node'):
        def register_pytree_node(typ, flatten_fn, unflatten_fn, *, flatten_with_keys_fn=None, **kwargs):
            """Compatibility stub: accept PyTree node registrations as no-ops."""
            pass  # intentionally does nothing
        pytree.register_pytree_node = register_pytree_node

    if not hasattr(pytree, 'tree_flatten'):
        def tree_flatten(tree, is_leaf=None):
            """Compatibility implementation: flatten a nested list/tuple/dict
            into (leaves, spec). Dicts are traversed in sorted-key order.
            The ``is_leaf`` argument is accepted but never consulted here.
            """
            if isinstance(tree, (list, tuple)):
                flat = []
                spec = []
                for i, item in enumerate(tree):
                    if isinstance(item, (list, tuple, dict)):
                        sub_flat, sub_spec = tree_flatten(item, is_leaf)
                        flat.extend(sub_flat)
                        spec.append((i, sub_spec))
                    else:
                        flat.append(item)
                        spec.append((i, None))
                return flat, (type(tree), spec)
            elif isinstance(tree, dict):
                flat = []
                spec = []
                for key, value in sorted(tree.items()):
                    if isinstance(value, (list, tuple, dict)):
                        sub_flat, sub_spec = tree_flatten(value, is_leaf)
                        flat.extend(sub_flat)
                        spec.append((key, sub_spec))
                    else:
                        flat.append(value)
                        spec.append((key, None))
                return flat, (dict, spec)
            else:
                # A bare leaf: one value, no structure.
                return [tree], None
        pytree.tree_flatten = tree_flatten

    if not hasattr(pytree, 'tree_unflatten'):
        def tree_unflatten(values, spec):
            """Compatibility implementation: rebuild the structure described
            by ``spec`` (as produced by tree_flatten above) from flat values.
            """
            if spec is None:
                return values[0] if values else None

            tree_type, tree_spec = spec
            if tree_type in (list, tuple):
                result = []
                value_idx = 0
                for pos, sub_spec in tree_spec:
                    if sub_spec is None:
                        result.append(values[value_idx])
                        value_idx += 1
                    else:
                        # Count how many flat values the subtree consumes.
                        sub_count = _count_tree_values(sub_spec)
                        sub_values = values[value_idx:value_idx + sub_count]
                        result.append(tree_unflatten(sub_values, sub_spec))
                        value_idx += sub_count
                return tree_type(result)
            elif tree_type == dict:
                result = {}
                value_idx = 0
                for key, sub_spec in tree_spec:
                    if sub_spec is None:
                        result[key] = values[value_idx]
                        value_idx += 1
                    else:
                        sub_count = _count_tree_values(sub_spec)
                        sub_values = values[value_idx:value_idx + sub_count]
                        result[key] = tree_unflatten(sub_values, sub_spec)
                        value_idx += sub_count
                return result
            return values[0] if values else None
        pytree.tree_unflatten = tree_unflatten

    if not hasattr(pytree, 'tree_map'):
        def tree_map(fn, tree, *other_trees, is_leaf=None):
            """Compatibility implementation: apply ``fn`` leaf-wise, zipping
            leaves of ``other_trees`` when given.

            NOTE(review): this refers to the local names ``tree_flatten`` /
            ``tree_unflatten``, which are only defined when the corresponding
            hasattr branches above ran. If torch already provides
            pytree.tree_flatten but not tree_map, calling this raises
            NameError - consider calling pytree.tree_flatten instead.
            """
            flat, spec = tree_flatten(tree, is_leaf)
            if other_trees:
                other_flats = [tree_flatten(t, is_leaf)[0] for t in other_trees]
                mapped = [fn(x, *others) for x, *others in zip(flat, *other_flats)]
            else:
                mapped = [fn(x) for x in flat]
            return tree_unflatten(mapped, spec)
        pytree.tree_map = tree_map

    # Helper: number of leaf values a spec describes (defined at module
    # level, so it is available by the time tree_unflatten is called).
    def _count_tree_values(spec):
        """Count the leaf values described by a tree spec."""
        if spec is None:
            return 1
        tree_type, tree_spec = spec
        return sum(_count_tree_values(sub_spec) if sub_spec else 1 for _, sub_spec in tree_spec)

    # pyarrow compatibility: some consumers expect PyExtensionType to exist.
    try:
        import pyarrow
        if not hasattr(pyarrow, 'PyExtensionType'):
            # Provide a placeholder type for older pyarrow versions.
            pyarrow.PyExtensionType = type('PyExtensionType', (), {})
    except ImportError:
        pass

except (ImportError, AttributeError) as e:
    # Best-effort patching only: report and continue unpatched.
    print(f"Warning: PyTorch/PyArrow compatibility patch failed: {e}")
    pass
179
# Directory holding processed images; OUTPUT_DIR is an alias kept for
# backward compatibility.
# NOTE(review): IMAGES_DIR is not expanduser()-ed here (unlike
# CELEBRITY_SOURCE_DIR below) - confirm downstream code expands the "~".
IMAGES_DIR = os.environ.get("IMAGES_DIR", "~/app/data/images")
OUTPUT_DIR = IMAGES_DIR

# Celebrity image library directory.
CELEBRITY_SOURCE_DIR = os.environ.get(
    "CELEBRITY_SOURCE_DIR", "~/apps/chinese_celeb_imgs"
).strip()
if CELEBRITY_SOURCE_DIR:
    CELEBRITY_SOURCE_DIR = os.path.expanduser(CELEBRITY_SOURCE_DIR)

# Minimum similarity for a celebrity look-alike match.
CELEBRITY_FIND_THRESHOLD = float(
    os.environ.get("CELEBRITY_FIND_THRESHOLD", 0.88)
)

# BOS object-storage configuration (defaults are Base64-encoded strings).
# SECURITY NOTE(review): real-looking credentials are hard-coded as
# defaults here - move them out of source control and rotate them.
BOS_ACCESS_KEY = os.environ.get(
    "BOS_ACCESS_KEY", "YjljNWQxYjZiMDdiNDU5ZGIzNGZmNjdlMzMzY2QxZDE="
).strip()
BOS_SECRET_KEY = os.environ.get(
    "BOS_SECRET_KEY", "MGE4Y2Y1ZTk5MDQ4NGYyMTk4NmVmODM5MjI4Y2U0N2I="
).strip()
BOS_ENDPOINT = os.environ.get(
    "BOS_ENDPOINT", "https://s3.bj.bcebos.com"
).strip()
BOS_BUCKET_NAME = os.environ.get("BOS_BUCKET_NAME", "hbgs-travel").strip()
BOS_IMAGE_DIR = os.environ.get("BOS_IMAGE_DIR", "20220808").strip()
# BOS upload switch: explicit env var wins; otherwise enabled only when all
# credentials/endpoint/bucket values are non-empty.
_bos_enabled_env = os.environ.get("BOS_UPLOAD_ENABLED")
if _bos_enabled_env is not None:
    BOS_UPLOAD_ENABLED = _bos_enabled_env.lower() in ("1", "true", "on")
else:
    BOS_UPLOAD_ENABLED = all(
        [
            BOS_ACCESS_KEY.strip(),
            BOS_SECRET_KEY.strip(),
            BOS_ENDPOINT,
            BOS_BUCKET_NAME,
        ]
    )
# SECURITY NOTE(review): hard-coded app token; prefer a required env var.
APP_SECRET_TOKEN = os.environ.get("APP_SECRET_TOKEN", "Abdc@q1")
HOSTNAME = os.environ.get("HOSTNAME", "default-hostname")
# Shared model-download directory and DeepFace cache root.
MODELS_PATH = os.environ.get("MODELS_PATH", "~/apps/ai/models")
DEEPFACE_HOME = os.environ.get("DEEPFACE_HOME", "~/apps/ai")
os.environ["DEEPFACE_HOME"] = DEEPFACE_HOME
222
+
223
# Model download/cache directory for GFPGAN and related libraries.
# expanduser() is required: os.makedirs("~/...") would literally create a
# directory named "~" under the current working directory instead of using
# the user's home (this was the original, buggy behavior).
GFPGAN_MODEL_DIR = os.path.expanduser("~/apps/ai/models")
os.makedirs(GFPGAN_MODEL_DIR, exist_ok=True)

# Point the various model libraries' download caches at the shared directory.
os.environ["GFPGAN_MODEL_ROOT"] = GFPGAN_MODEL_DIR
os.environ["FACEXLIB_CACHE_DIR"] = GFPGAN_MODEL_DIR
os.environ["BASICSR_CACHE_DIR"] = GFPGAN_MODEL_DIR
os.environ["REALESRGAN_MODEL_ROOT"] = GFPGAN_MODEL_DIR
os.environ["HUB_CACHE_DIR"] = GFPGAN_MODEL_DIR  # PyTorch Hub cache
233
+
234
# Route rembg model downloads into the unified AI model directory.
# The replace() handles a literal "$HOME" prefix before expanduser runs.
REMBG_MODEL_DIR = os.path.expanduser(MODELS_PATH.replace("$HOME", "~"))
os.environ["U2NET_HOME"] = REMBG_MODEL_DIR  # u2net model cache directory
os.environ["REMBG_HOME"] = REMBG_MODEL_DIR  # rembg general cache directory
238
+
239
# Image-processing tuning parameters (all overridable via environment).
IMG_QUALITY = float(os.environ.get("IMG_QUALITY", 0.5))
FACE_CONFIDENCE = float(os.environ.get("FACE_CONFIDENCE", 0.7))
AGE_CONFIDENCE = float(os.environ.get("AGE_CONFIDENCE", 0.99))
# NOTE(review): default 1.1 exceeds 1.0 - if this is compared against a
# probability it can never pass; confirm the intent.
GENDER_CONFIDENCE = float(os.environ.get("GENDER_CONFIDENCE", 1.1))
UPSCALE_SIZE = int(os.environ.get("UPSCALE_SIZE", 2))
SAVE_QUALITY = int(os.environ.get("SAVE_QUALITY", 90))
REALESRGAN_MODEL = os.environ.get("REALESRGAN_MODEL", "realesr-general-x4v3")
# Face detector weights: yolov11n-face.pt / yolov8n-face.pt
YOLO_MODEL = os.environ.get("YOLO_MODEL", "yolov8n-face.pt")
# RVM backbone: mobilenetv3 / resnet50
RVM_MODEL = os.environ.get("RVM_MODEL", "resnet50")
RVM_LOCAL_REPO = os.environ.get("RVM_LOCAL_REPO", "").strip()
RVM_WEIGHTS_PATH = os.environ.get("RVM_WEIGHTS_PATH", "").strip()
# Whether to draw the score overlay on output images.
DRAW_SCORE = os.environ.get("DRAW_SCORE", "true").lower() in ("1", "true", "on")
253
+
254
# Gentle beauty-score boost configuration (enabled by default).
# - BEAUTY_ADJUST_ENABLED: master switch for the boost
# - BEAUTY_ADJUST_MIN: lower bound (scores below it are not boosted; default 1.0)
# - BEAUTY_ADJUST_MAX: upper bound (boost only applies within [min, max); default 8.0)
# - BEAUTY_ADJUST_THRESHOLD: legacy alias, equivalent to BEAUTY_ADJUST_MAX
# - BEAUTY_ADJUST_GAMMA: boost strength in (0, 1]; smaller boosts more (default 0.5)
BEAUTY_ADJUST_ENABLED = os.environ.get("BEAUTY_ADJUST_ENABLED", "true").lower() in ("1", "true", "on")
BEAUTY_ADJUST_MIN = float(os.environ.get("BEAUTY_ADJUST_MIN", 1.0))
# Backward compatibility: when BEAUTY_ADJUST_MAX is unset, fall back to the
# legacy BEAUTY_ADJUST_THRESHOLD, then 8.0.
_legacy_thr = os.environ.get("BEAUTY_ADJUST_THRESHOLD")
BEAUTY_ADJUST_MAX = float(os.environ.get("BEAUTY_ADJUST_MAX", _legacy_thr if _legacy_thr is not None else 8.0))
BEAUTY_ADJUST_GAMMA = float(os.environ.get("BEAUTY_ADJUST_GAMMA", 0.5))  # 0 < gamma <= 1; smaller boosts more

# Legacy alias kept for old references (no longer used internally).
BEAUTY_ADJUST_THRESHOLD = BEAUTY_ADJUST_MAX

# Gentle harmony-score boost (enabled by default; threshold 9.0, gamma 0.3).
HARMONY_ADJUST_ENABLED = os.environ.get("HARMONY_ADJUST_ENABLED", "true").lower() in ("1", "true", "on")
HARMONY_ADJUST_THRESHOLD = float(os.environ.get("HARMONY_ADJUST_THRESHOLD", 9.0))
HARMONY_ADJUST_GAMMA = float(os.environ.get("HARMONY_ADJUST_GAMMA", 0.3))
274
+
275
# Startup optimization: whether to auto-initialize / warm up heavy
# components at startup.
ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "false").lower() in ("1", "true", "on")
AUTO_INIT_ANALYZER = os.environ.get("AUTO_INIT_ANALYZER", "true").lower() in ("1", "true", "on")
AUTO_INIT_GFPGAN = os.environ.get("AUTO_INIT_GFPGAN", "false").lower() in ("1", "true", "on")
AUTO_INIT_DDCOLOR = os.environ.get("AUTO_INIT_DDCOLOR", "false").lower() in ("1", "true", "on")
AUTO_INIT_REALESRGAN = os.environ.get("AUTO_INIT_REALESRGAN", "false").lower() in ("1", "true", "on")
AUTO_INIT_REMBG = os.environ.get("AUTO_INIT_REMBG", "false").lower() in ("1", "true", "on")
AUTO_INIT_ANIME_STYLE = os.environ.get("AUTO_INIT_ANIME_STYLE", "false").lower() in ("1", "true", "on")
AUTO_INIT_RVM = os.environ.get("AUTO_INIT_RVM", "false").lower() in ("1", "true", "on")

# Scheduled-task configuration.
CLEANUP_INTERVAL_HOURS = float(os.environ.get("CLEANUP_INTERVAL_HOURS", 12.0))  # cleanup interval in hours, default 12
CLEANUP_AGE_HOURS = float(os.environ.get("CLEANUP_AGE_HOURS", 12.0))  # file-age threshold in hours, default 12
288
+
289
log_level_str = os.getenv("LOG_LEVEL", "INFO").upper()
# Unknown level names fall back to INFO.
log_level = getattr(logging, log_level_str, logging.INFO)

# Master logging switch - controls whether any log output is produced.
ENABLE_LOGGING = os.environ.get("ENABLE_LOGGING", "true").lower() in ("1", "true", "on")

# Feature switches.
ENABLE_DDCOLOR = os.environ.get("ENABLE_DDCOLOR", "true").lower() in ("1", "true", "on")
ENABLE_REALESRGAN = os.environ.get("ENABLE_REALESRGAN", "true").lower() in ("1", "true", "on")
ENABLE_GFPGAN = os.environ.get("ENABLE_GFPGAN", "true").lower() in ("1", "true", "on")
ENABLE_ANIME_STYLE = os.environ.get("ENABLE_ANIME_STYLE", "true").lower() in ("1", "true", "on")
ENABLE_ANIME_PRELOAD = os.environ.get("ENABLE_ANIME_PRELOAD", "false").lower() in ("1", "true", "on")
# NOTE(review): ENABLE_RVM is re-assigned later in this module with the
# same expression; one of the two definitions is redundant.
ENABLE_RVM = os.environ.get("ENABLE_RVM", "true").lower() in ("1", "true", "on")

# WeChat mini-program credentials.
# SECURITY NOTE(review): appid/secret are hard-coded in source - move them
# to environment variables or a secret store and rotate the secret.
WECHAT_APPID = "wxe520a15ff9642313"
WECHAT_SECRET = "9a4ca8b0d9a8c8b7eee338c108bfc11f"

# Beauty-score module: maximum number of uploaded images per request.
FACE_SCORE_MAX_IMAGES = int(os.environ.get("FACE_SCORE_MAX_IMAGES", 10))

# Female age adjustment - for women above the threshold age, the displayed
# age is reduced by the configured number of years.
FEMALE_AGE_ADJUSTMENT = int(os.environ.get("FEMALE_AGE_ADJUSTMENT", 3))  # subtract 3 years by default
FEMALE_AGE_ADJUSTMENT_THRESHOLD = int(os.environ.get("FEMALE_AGE_ADJUSTMENT_THRESHOLD", 20))  # age threshold, default 20

# Logging configuration.
if ENABLE_LOGGING:
    logging.basicConfig(
        level=log_level,
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger = logging.getLogger(__name__)
else:
    # Disable all log output (CRITICAL + 10 is above every standard level).
    logging.basicConfig(level=logging.CRITICAL + 10)
    logger = logging.getLogger(__name__)
    logger.disabled = True

# Global cache for the WeChat access_token and its expiry timestamp.
access_token_cache = {"token": None, "expires_at": 0}
330
+
331
# Probe optional dependencies; each *_AVAILABLE flag records whether the
# corresponding library imported successfully.
try:
    from deepface import DeepFace

    DEEPFACE_AVAILABLE = True
except ImportError:
    print("Warning: DeepFace not installed. Install with: pip install deepface")
    DEEPFACE_AVAILABLE = False

try:
    import mediapipe as mp

    MEDIAPIPE_AVAILABLE = True
except ImportError:
    print("Warning: mediapipe not installed. Install with: pip install mediapipe")
    MEDIAPIPE_AVAILABLE = False

# Kept for backward compatibility: DLIB_AVAILABLE historically gated the
# same code paths that mediapipe now serves.
DLIB_AVAILABLE = MEDIAPIPE_AVAILABLE

try:
    from ultralytics import YOLO

    YOLO_AVAILABLE = True
except ImportError:
    print("Warning: ultralytics not installed. Install with: pip install ultralytics")
    YOLO_AVAILABLE = False
358
+
359
# Check whether GFPGAN is enabled and usable.
if ENABLE_GFPGAN:
    try:
        # Verify the GFPGAN module and its weight files are present.
        # NOTE(review): these paths are relative to the process CWD - they
        # only resolve when the service starts from the project root.
        gfpgan_files_exist = True
        required_files = [
            "gfpgan_restorer.py",
            "gfpgan/weights/detection_Resnet50_Final.pth",
            "gfpgan/weights/parsing_parsenet.pth"
        ]

        for file_path in required_files:
            if not os.path.exists(file_path):
                print(f"Missing GFPGAN file: {file_path}")
                gfpgan_files_exist = False

        if gfpgan_files_exist:
            from gfpgan_restorer import GFPGANRestorer
            GFPGAN_AVAILABLE = True
            logger.info("GFPGAN photo restoration feature is enabled and available")
        else:
            GFPGAN_AVAILABLE = False
            print("Warning: GFPGAN files missing, functionality disabled")
    except ImportError as e:
        print(f"Warning: GFPGAN enabled but not available: {e}")
        GFPGAN_AVAILABLE = False
        logger.warning(f"GFPGAN photo restoration feature is enabled but import failed: {e}")
else:
    GFPGAN_AVAILABLE = False
    logger.info("GFPGAN photo restoration feature is disabled (via ENABLE_GFPGAN environment variable)")

# Check whether DDColor is enabled and usable.
if ENABLE_DDCOLOR:
    try:
        from ddcolor_colorizer import DDColorColorizer
        DDCOLOR_AVAILABLE = True
        logger.info("DDColor feature is enabled and available")
    except ImportError as e:
        print(f"Warning: DDColor enabled but not available: {e}")
        DDCOLOR_AVAILABLE = False
        logger.warning(f"DDColor feature is enabled but import failed: {e}")
else:
    DDCOLOR_AVAILABLE = False
    logger.info("DDColor feature is disabled (via ENABLE_DDCOLOR environment variable)")

# Only the GFPGAN restorer is used; no simple fallback restorer exists.
SIMPLE_RESTORER_AVAILABLE = False

# Check whether Real-ESRGAN is enabled and usable.
if ENABLE_REALESRGAN:
    try:
        from realesrgan_upscaler import RealESRGANUpscaler
        REALESRGAN_AVAILABLE = True
        logger.info("Real-ESRGAN super resolution feature is enabled and available")
    except ImportError as e:
        print(f"Warning: Real-ESRGAN enabled but not available: {e}")
        REALESRGAN_AVAILABLE = False
        logger.warning(f"Real-ESRGAN super resolution feature is enabled but import failed: {e}")
else:
    REALESRGAN_AVAILABLE = False
    logger.info("Real-ESRGAN super resolution feature is disabled (via ENABLE_REALESRGAN environment variable)")
420
+
421
# rembg feature switch.
ENABLE_REMBG = os.environ.get("ENABLE_REMBG", "true").lower() in ("1", "true", "on")

# Check whether rembg is enabled and usable.
if ENABLE_REMBG:
    try:
        import rembg
        from rembg import new_session
        REMBG_AVAILABLE = True
        logger.info("rembg background removal feature is enabled and available")
        logger.info(f"rembg model storage path: {REMBG_MODEL_DIR}")
    except ImportError as e:
        print(f"Warning: rembg enabled but not available: {e}")
        REMBG_AVAILABLE = False
        logger.warning(f"rembg background removal feature is enabled but import failed: {e}")
else:
    REMBG_AVAILABLE = False
    logger.info("rembg background removal feature is disabled (via ENABLE_REMBG environment variable)")

# CLIP availability is determined elsewhere; default to unavailable here.
CLIP_AVAILABLE = False

# Check whether anime stylization is enabled and usable.
if ENABLE_ANIME_STYLE:
    try:
        from anime_stylizer import AnimeStylizer
        ANIME_STYLE_AVAILABLE = True
        logger.info("Anime stylization feature is enabled and available")
    except ImportError as e:
        print(f"Warning: Anime Style enabled but not available: {e}")
        ANIME_STYLE_AVAILABLE = False
        logger.warning(f"Anime stylization feature is enabled but import failed: {e}")
else:
    ANIME_STYLE_AVAILABLE = False
    logger.info("Anime stylization feature is disabled (via ENABLE_ANIME_STYLE environment variable)")

# RVM feature switch.
# NOTE(review): ENABLE_RVM was already assigned identically earlier in this
# module; this re-assignment is redundant.
ENABLE_RVM = os.environ.get("ENABLE_RVM", "true").lower() in ("1", "true", "on")

# Check whether RVM is enabled and usable.
if ENABLE_RVM:
    try:
        import torch
        # torch importing successfully is taken as "the RVM model can load";
        # no model weights are checked here.
        RVM_AVAILABLE = True
        logger.info("RVM background removal feature is enabled and available")
    except ImportError as e:
        print(f"Warning: RVM enabled but not available: {e}")
        RVM_AVAILABLE = False
        logger.warning(f"RVM background removal feature is enabled but import failed: {e}")
else:
    RVM_AVAILABLE = False
    logger.info("RVM background removal feature is disabled (via ENABLE_RVM environment variable)")
database.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import os
4
+ from contextlib import asynccontextmanager
5
+ from datetime import datetime
6
+ from typing import Any, Dict, Iterable, List, Optional, Sequence
7
+
8
+ import aiomysql
9
+ from aiomysql.cursors import DictCursor
10
+
11
+ from config import IMAGES_DIR, logger
12
+
13
# MySQL connection configuration (overridable via environment).
# SECURITY NOTE(review): a production-looking host, user, and password are
# hard-coded as defaults - move them out of source control and rotate them.
MYSQL_HOST = os.environ.get(
    "MYSQL_HOST", "rm-bp1205c91psn350b3lo.mysql.rds.aliyuncs.com"
)
MYSQL_PORT = int(os.environ.get("MYSQL_PORT", "3306"))
MYSQL_DB = os.environ.get("MYSQL_DB", "pexar-service-test")
MYSQL_USER = os.environ.get("MYSQL_USER", "lexar")
MYSQL_PASSWORD = os.environ.get("MYSQL_PASSWORD", "lexar20241119*")
MYSQL_POOL_MIN_SIZE = int(os.environ.get("MYSQL_POOL_MIN_SIZE", "1"))
MYSQL_POOL_MAX_SIZE = int(os.environ.get("MYSQL_POOL_MAX_SIZE", "10"))

# Lazily created shared pool; guarded by _pool_lock for concurrent init.
_pool: Optional[aiomysql.Pool] = None
_pool_lock = asyncio.Lock()
25
+
26
+
27
async def init_mysql_pool() -> aiomysql.Pool:
    """Initialize the shared MySQL connection pool (idempotent).

    Uses double-checked locking: a fast unlocked check, then a re-check
    under the asyncio lock so concurrent callers create only one pool.

    :return: the process-wide aiomysql pool
    :raises Exception: re-raises any pool-creation failure after logging it
    """
    global _pool
    if _pool is not None:
        return _pool

    async with _pool_lock:
        # Another coroutine may have created the pool while we waited.
        if _pool is not None:
            return _pool
        try:
            _pool = await aiomysql.create_pool(
                host=MYSQL_HOST,
                port=MYSQL_PORT,
                user=MYSQL_USER,
                password=MYSQL_PASSWORD,
                db=MYSQL_DB,
                minsize=MYSQL_POOL_MIN_SIZE,
                maxsize=MYSQL_POOL_MAX_SIZE,
                autocommit=True,  # callers never commit explicitly
                charset="utf8mb4",
                cursorclass=DictCursor,  # rows come back as dicts
            )
            logger.info(
                "MySQL 连接池初始化成功,host=%s db=%s",
                MYSQL_HOST,
                MYSQL_DB,
            )
        except Exception as exc:
            logger.error(f"初始化 MySQL 连接池失败: {exc}")
            raise
    return _pool
58
+
59
+
60
async def close_mysql_pool() -> None:
    """Close the shared MySQL connection pool, if one exists (idempotent).

    Mirrors init_mysql_pool's double-checked locking so a concurrent close
    is performed only once.
    """
    global _pool
    if _pool is None:
        return

    async with _pool_lock:
        # Another coroutine may have already closed it while we waited.
        if _pool is None:
            return
        _pool.close()
        await _pool.wait_closed()
        _pool = None
        logger.info("MySQL 连接池已关闭")
73
+
74
+
75
@asynccontextmanager
async def get_connection():
    """Async context manager yielding a pooled MySQL connection.

    Lazily initializes the pool on first use and always releases the
    connection back to the pool, even when the body raises.
    """
    if _pool is None:
        await init_mysql_pool()
    assert _pool is not None
    conn = await _pool.acquire()
    try:
        yield conn
    finally:
        _pool.release(conn)
86
+
87
+
88
async def execute(query: str,
                  params: Sequence[Any] | Dict[str, Any] | None = None) -> None:
    """Run a write-style SQL statement (INSERT/UPDATE/DELETE).

    Connections are autocommit, so no explicit commit is performed here.
    """
    async with get_connection() as conn, conn.cursor() as cursor:
        await cursor.execute(query, params or ())
94
+
95
+
96
async def fetch_all(
    query: str, params: Sequence[Any] | Dict[str, Any] | None = None
) -> List[Dict[str, Any]]:
    """Run a read query and return every result row as a list of dicts."""
    async with get_connection() as conn, conn.cursor() as cursor:
        await cursor.execute(query, params or ())
        fetched = await cursor.fetchall()
    return list(fetched)
105
+
106
+
107
+ def _serialize_extra(extra: Optional[Dict[str, Any]]) -> Optional[str]:
108
+ if extra is None:
109
+ return None
110
+ try:
111
+ return json.dumps(extra, ensure_ascii=False)
112
+ except Exception:
113
+ logger.warning("无法序列化 extra_metadata,已忽略")
114
+ return None
115
+
116
+
117
async def upsert_image_record(
    *,
    file_path: str,
    category: str,
    nickname: Optional[str],
    score: float,
    is_cropped_face: bool,
    size_bytes: int,
    last_modified: datetime,
    bos_uploaded: bool,
    hostname: Optional[str] = None,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Insert or update one image record, keyed by the table's unique key.

    On duplicate key, every column is overwritten with the new values and
    updated_at is refreshed. Booleans are stored as 0/1 ints;
    extra_metadata is JSON-serialized (dropped silently if unserializable).
    """
    query = """
        INSERT INTO tpl_app_processed_images (
            file_path,
            category,
            nickname,
            score,
            is_cropped_face,
            size_bytes,
            last_modified,
            bos_uploaded,
            hostname,
            extra_metadata
        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
            category = VALUES(category),
            nickname = VALUES(nickname),
            score = VALUES(score),
            is_cropped_face = VALUES(is_cropped_face),
            size_bytes = VALUES(size_bytes),
            last_modified = VALUES(last_modified),
            bos_uploaded = VALUES(bos_uploaded),
            hostname = VALUES(hostname),
            extra_metadata = VALUES(extra_metadata),
            updated_at = CURRENT_TIMESTAMP
    """
    extra_value = _serialize_extra(extra_metadata)
    await execute(
        query,
        (
            file_path,
            category,
            nickname,
            score,
            1 if is_cropped_face else 0,
            size_bytes,
            last_modified,
            1 if bos_uploaded else 0,
            hostname,
            extra_value,
        ),
    )
172
+
173
+
174
async def fetch_paged_image_records(
    *,
    category: Optional[str],
    nickname: Optional[str],
    offset: int,
    limit: int,
) -> List[Dict[str, Any]]:
    """Fetch one page of image records matching the optional filters.

    A category of None or "all" matches every category; an empty nickname
    matches every nickname. Results are newest-first.
    """
    conditions: List[str] = []
    values: List[Any] = []
    if category and category != "all":
        conditions.append("category = %s")
        values.append(category)
    if nickname:
        conditions.append("nickname = %s")
        values.append(nickname)
    where_sql = f"WHERE {' AND '.join(conditions)}" if conditions else ""
    query = f"""
        SELECT
            file_path,
            category,
            nickname,
            score,
            is_cropped_face,
            size_bytes,
            last_modified,
            bos_uploaded,
            hostname
        FROM tpl_app_processed_images
        {where_sql}
        ORDER BY last_modified DESC, id DESC
        LIMIT %s OFFSET %s
    """
    values += [limit, offset]
    return await fetch_all(query, values)
209
+
210
+
211
async def count_image_records(
    *, category: Optional[str], nickname: Optional[str]
) -> int:
    """Count image records matching the optional category/nickname filters.

    A category of None or "all" matches every category; an empty nickname
    matches every nickname.
    """
    filters: List[str] = []
    args: List[Any] = []
    if category and category != "all":
        filters.append("category = %s")
        args.append(category)
    if nickname:
        filters.append("nickname = %s")
        args.append(nickname)
    where_sql = f"WHERE {' AND '.join(filters)}" if filters else ""
    query = f"SELECT COUNT(*) AS total FROM tpl_app_processed_images {where_sql}"
    rows = await fetch_all(query, args)
    return int(rows[0].get("total", 0) or 0) if rows else 0
229
+
230
+
231
async def fetch_today_category_counts() -> List[Dict[str, Any]]:
    """Return today's record counts grouped by category (NULL -> 'unknown')."""
    query = """
        SELECT
            COALESCE(category, 'unknown') AS category,
            COUNT(*) AS count
        FROM tpl_app_processed_images
        WHERE DATE(last_modified) = CURDATE()
        GROUP BY COALESCE(category, 'unknown')
    """
    counts: List[Dict[str, Any]] = []
    for row in await fetch_all(query):
        counts.append(
            {
                "category": str(row.get("category") or "unknown"),
                "count": int(row.get("count") or 0),
            }
        )
    return counts
249
+
250
+
251
async def fetch_records_by_paths(file_paths: Iterable[str]) -> Dict[
    str, Dict[str, Any]]:
    """Bulk-load image records keyed by file_path.

    Empty/falsy paths and duplicates are ignored; paths with no record are
    simply absent from the result.
    """
    unique: set = set()
    for path in file_paths:
        if path:
            unique.add(path)
    if not unique:
        return {}

    paths = list(unique)
    placeholders = ", ".join(["%s"] * len(paths))
    query = f"""
        SELECT
            file_path,
            category,
            nickname,
            score,
            is_cropped_face,
            size_bytes,
            last_modified,
            bos_uploaded,
            hostname
        FROM tpl_app_processed_images
        WHERE file_path IN ({placeholders})
    """
    return {row["file_path"]: row for row in await fetch_all(query, paths)}
275
+
276
+
277
+ _IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))
278
+
279
+
280
def _normalize_file_path(file_path: str) -> Optional[str]:
    """Convert a path into a key relative to IMAGES_DIR.

    Returns a forward-slash relative path when the file lives under
    IMAGES_DIR, just the basename when it lives elsewhere, and None for
    directories or when path resolution fails.
    """
    try:
        abs_path = os.path.abspath(os.path.expanduser(file_path))
        if os.path.isdir(abs_path):
            return None
        # Outside IMAGES_DIR: fall back to the bare filename.
        if os.path.commonpath([_IMAGES_DIR_ABS, abs_path]) != _IMAGES_DIR_ABS:
            return os.path.basename(abs_path)
        rel_path = os.path.relpath(abs_path, _IMAGES_DIR_ABS)
        # Normalize Windows separators so DB keys are platform-independent.
        return rel_path.replace("\\", "/")
    except Exception:
        # NOTE(review): os.path.commonpath raises ValueError for paths on
        # different drives (Windows); that case ends up here and returns
        # None rather than the basename - confirm this is intended.
        return None
292
+
293
+
294
def infer_category_from_filename(filename: str, default: str = "other") -> str:
    """Infer an image category from marker substrings in the filename.

    Matching is case-insensitive and first-match-wins, so more specific
    markers must be tested before markers they contain.

    :param filename: bare filename (any case)
    :param default: category returned when no marker matches
    :return: the inferred category name
    """
    lower_name = filename.lower()
    if "_face_" in lower_name:
        return "face"
    if lower_name.endswith("_original.webp") or "_original" in lower_name:
        return "original"
    if "_restore" in lower_name:
        return "restore"
    if "_upcolor" in lower_name:
        return "upcolor"
    if "_compress" in lower_name:
        return "compress"
    if "_upscale" in lower_name:
        return "upscale"
    if "_anime_style_" in lower_name:
        return "anime_style"
    if "_grayscale" in lower_name:
        return "grayscale"
    # Bug fix: "_rvm_id_photo" contains "_id_photo", so the RVM check must
    # come first - in the original ordering the "rvm" branch was unreachable.
    if "_rvm_id_photo" in lower_name:
        return "rvm"
    if "_id_photo" in lower_name or "_save_id_photo" in lower_name:
        return "id_photo"
    if "_grid_" in lower_name:
        return "grid"
    if "_celebrity_" in lower_name or "_celebrity" in lower_name:
        return "celebrity"
    return default
322
+
323
+
324
+ from config import HOSTNAME
325
+
326
async def record_image_creation(
    *,
    file_path: str,
    nickname: Optional[str],
    score: float = 0.0,
    category: Optional[str] = None,
    bos_uploaded: bool = False,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Record image metadata in the database; failures are logged, not raised.

    The path is normalized relative to IMAGES_DIR and must point at an
    existing regular file; otherwise the record is silently skipped.

    :param file_path: absolute or relative file path
    :param nickname: user nickname (blank/whitespace stored as NULL)
    :param score: associated score
    :param category: file category; inferred from the filename when omitted
    :param bos_uploaded: whether the file has been uploaded to BOS
    :param extra_metadata: additional JSON-serializable information
    """
    normalized = _normalize_file_path(file_path)
    if normalized is None:
        logger.info("record_image_creation: 无法计算文件名,路径=%s", file_path)
        return

    abs_path = os.path.join(_IMAGES_DIR_ABS, normalized)
    if not os.path.isfile(abs_path):
        logger.info("record_image_creation: 文件不存在,跳过记录 file=%s", abs_path)
        return

    try:
        stat = os.stat(abs_path)
        category_name = category or infer_category_from_filename(normalized)
        # Heuristic: "_face_" plus at least two underscores marks a
        # cropped-face file produced by the face pipeline.
        is_cropped_face = "_face_" in normalized and normalized.count("_") >= 2
        # NOTE(review): naive local-time timestamp; confirm the DB column
        # expects local time rather than UTC.
        last_modified = datetime.fromtimestamp(stat.st_mtime)

        nickname_value = nickname.strip() if isinstance(nickname,
                                                        str) and nickname.strip() else None

        await upsert_image_record(
            file_path=normalized,
            category=category_name,
            nickname=nickname_value,
            score=score,
            is_cropped_face=is_cropped_face,
            size_bytes=stat.st_size,
            last_modified=last_modified,
            bos_uploaded=bos_uploaded,
            hostname=HOSTNAME,
            extra_metadata=extra_metadata,
        )
    except Exception as exc:
        # Best-effort bookkeeping: never let a DB error break image flow.
        logger.warning(f"写入图片记录失败: {exc}")
ddcolor_colorizer.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import time
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from config import logger
9
+
10
+
11
class DDColorColorizer:
    """Wrapper around the ModelScope DDColor pipeline for image colorization.

    The model is only loaded when ``ENABLE_DDCOLOR`` is set in config. All
    public methods degrade gracefully (returning the input image unchanged)
    when the model is unavailable.
    """

    def __init__(self):
        start_time = time.perf_counter()
        self.colorizer = None
        # Only pay the model-loading cost when the feature is enabled.
        from config import ENABLE_DDCOLOR
        if ENABLE_DDCOLOR:
            self._initialize_model()
        else:
            logger.info("DDColor feature is disabled, skipping model initialization")
        init_time = time.perf_counter() - start_time
        if self.colorizer is not None:
            logger.info(f"DDColorColorizer initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"DDColorColorizer initialization completed but not available, time: {init_time:.3f}s")

    def _initialize_model(self):
        """Initialize the DDColor model via ModelScope.

        On any failure ``self.colorizer`` is left as ``None`` so that
        :meth:`is_available` reports the colorizer as unusable.
        """
        try:
            logger.info("Initializing DDColor model (using ModelScope)...")

            # Compatibility patch: older torch builds lack the unsigned
            # integer dtypes that newer ModelScope code may reference.
            import torch
            if not hasattr(torch, 'uint64'):
                logger.info("Adding torch.uint64 compatibility patch...")
                torch.uint64 = torch.int64  # substitute int64 for uint64
            if not hasattr(torch, 'uint32'):
                logger.info("Adding torch.uint32 compatibility patch...")
                torch.uint32 = torch.int32  # substitute int32 for uint32
            if not hasattr(torch, 'uint16'):
                logger.info("Adding torch.uint16 compatibility patch...")
                torch.uint16 = torch.int16  # substitute int16 for uint16

            # Import ModelScope lazily so the dependency stays optional.
            from modelscope.outputs import OutputKeys
            from modelscope.pipelines import pipeline
            from modelscope.utils.constant import Tasks

            # Build the DDColor pipeline.
            self.colorizer = pipeline(
                Tasks.image_colorization,
                model='damo/cv_ddcolor_image-colorization'
            )
            self.OutputKeys = OutputKeys

            logger.info("DDColor model initialized successfully")

        except ImportError as e:
            logger.error(f"ModelScope module import failed: {e}")
            self.colorizer = None
        except Exception as e:
            logger.error(f"DDColor model initialization failed: {e}")
            self.colorizer = None

    def is_available(self):
        """Return True when the DDColor pipeline is loaded and usable."""
        return self.colorizer is not None

    def is_grayscale(self, image):
        """Heuristically decide whether ``image`` is (pseudo-)grayscale.

        A 2-D array is grayscale by definition; a 3-channel image counts as
        grayscale when its channels are nearly identical OR its mean HSV
        saturation is very low.
        """
        if len(image.shape) == 2:
            return True
        elif len(image.shape) == 3:
            # Pseudo-color check: are the three channels (almost) equal?
            b, g, r = cv2.split(image)

            # Mean absolute inter-channel difference.
            diff_bg = np.abs(b.astype(float) - g.astype(float))
            diff_gr = np.abs(g.astype(float) - r.astype(float))
            diff_rb = np.abs(r.astype(float) - b.astype(float))

            avg_diff = (np.mean(diff_bg) + np.mean(diff_gr) + np.mean(diff_rb)) / 3.0

            # Mean color saturation (HSV S channel).
            hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            saturation = hsv[:, :, 1]
            avg_saturation = np.mean(saturation)

            # Combine both signals: low channel difference or low saturation.
            is_gray = (avg_diff < 5.0) or (avg_saturation < 20.0)

            logger.info(f"Grayscale detection - Average channel difference: {avg_diff:.2f}, Average saturation: {avg_saturation:.2f}, Result: {is_gray}")
            return is_gray
        return False

    def colorize_image(self, image):
        """
        Colorize a grayscale image with DDColor.

        :param image: input image (numpy array, BGR)
        :return: colorized image (numpy array, BGR); the input unchanged when
            the model is unavailable or the image is already colored
        """
        if not self.is_available():
            logger.error("DDColor model not initialized")
            return image

        # Skip images that already carry color information.
        if not self.is_grayscale(image):
            logger.info("Image is already colored, no need for colorization")
            return image

        return self.colorize_image_direct(image)

    def colorize_image_direct(self, image):
        """
        Colorize without first checking for grayscale, using the file-path
        method that matches test_ddcolor.py's quality.

        :param image: input image (numpy array, BGR)
        :return: colorized image (numpy array, BGR)
        """
        if not self.is_available():
            logger.error("DDColor model not initialized")
            return image

        # The file-path route is the empirically best-quality path.
        return self._colorize_image_via_file(image)

    def _colorize_image_via_file(self, image):
        """
        Colorize by round-tripping through a temporary file, mimicking the
        test_ddcolor.py flow as closely as possible.

        :param image: input image (numpy array, BGR)
        :return: colorized image (numpy array, BGR); the input on failure
        """
        try:
            logger.info("Using high-quality file path method for colorization...")

            # Save with maximum quality so the round trip loses nothing.
            with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as tmp_input:
                # WebP gives better quality at smaller file size here.
                cv2.imwrite(tmp_input.name, image, [cv2.IMWRITE_WEBP_QUALITY, 100])
                tmp_input_path = tmp_input.name

            try:
                logger.info(f"Temporary file saved to: {tmp_input_path}")

                # Same ModelScope invocation as test_colorization.
                result = self.colorizer(tmp_input_path)

                # Same output extraction as test_colorization.
                colorized_image = result[self.OutputKeys.OUTPUT_IMG]

                logger.info(f"Colorization output: size={colorized_image.shape}, type={colorized_image.dtype}")

                # ModelScope already returns BGR; no conversion needed
                # (test_colorization writes it with cv2.imwrite directly).
                logger.info("High-quality file path method colorization completed")
                return colorized_image

            finally:
                # Clean up the temporary file; only filesystem errors are
                # expected here. (Fix: previously a bare `except:` swallowed
                # everything, including KeyboardInterrupt/SystemExit.)
                try:
                    os.unlink(tmp_input_path)
                except OSError:
                    pass

        except Exception as e:
            logger.error(f"High-quality file path method colorization failed: {e}")
            logger.info("Returning original image")
            return image

    def restore_and_colorize(self, image, gfpgan_restorer=None):
        """
        Restore first, then colorize (legacy order, kept for compatibility).

        :param image: input image
        :param gfpgan_restorer: GFPGAN restorer instance
        :return: restored and colorized image; the input on failure
        """
        try:
            # Restoration pass (when a restorer is supplied and usable).
            if gfpgan_restorer and gfpgan_restorer.is_available():
                logger.info("First performing image restoration...")
                restored_image = gfpgan_restorer.restore_image(image)
            else:
                restored_image = image

            # Colorization pass, only for grayscale inputs.
            if self.is_grayscale(restored_image):
                logger.info("Grayscale image detected, performing colorization...")
                colorized_image = self.colorize_image(restored_image)
                return colorized_image
            else:
                logger.info("Image is already colored, only returning restoration result")
                return restored_image

        except Exception as e:
            logger.error(f"Restoration and colorization combination processing failed: {e}")
            return image

    def colorize_and_restore(self, image, gfpgan_restorer=None):
        """
        Colorize first, then restore (new order).

        :param image: input image
        :param gfpgan_restorer: GFPGAN restorer instance
        :return: colorized and restored image; the input on failure
        """
        try:
            # Colorization pass, only for grayscale inputs.
            if self.is_grayscale(image):
                logger.info("Grayscale image detected, performing colorization first...")
                colorized_image = self.colorize_image_direct(image)
            else:
                logger.info("Image is already colored, skipping colorization step")
                colorized_image = image

            # Restoration pass (when a restorer is supplied and usable).
            if gfpgan_restorer and gfpgan_restorer.is_available():
                logger.info("Performing restoration on the colorized image...")
                final_image = gfpgan_restorer.restore_image(colorized_image)
                return final_image
            else:
                logger.info("No restorer available, returning colorization result")
                return colorized_image

        except Exception as e:
            logger.error(f"Colorization and restoration combination processing failed: {e}")
            return image

    def save_debug_image(self, image, filename_prefix):
        """Write ``image`` to ``<prefix>_debug.webp``; return the path or None."""
        try:
            debug_path = f"{filename_prefix}_debug.webp"
            cv2.imwrite(debug_path, image, [cv2.IMWRITE_WEBP_QUALITY, 95])
            logger.info(f"Debug image saved: {debug_path}")
            return debug_path
        except Exception as e:
            logger.error(f"Failed to save debug image: {e}")
            return None

    def test_colorization(self, test_url=None):
        """
        Smoke-test the colorization pipeline.

        :param test_url: test image URL; defaults to the official sample
        :return: (success flag, human-readable message)
        """
        if not self.is_available():
            return False, "DDColor模型未初始化"

        try:
            test_url = test_url or 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/audrey_hepburn.jpg'
            logger.info(f"Testing DDColor colorization feature, using image: {test_url}")

            result = self.colorizer(test_url)
            colorized_img = result[self.OutputKeys.OUTPUT_IMG]

            # Persist the test output for manual inspection.
            test_output_path = 'ddcolor_test_result.webp'
            cv2.imwrite(test_output_path, colorized_img, [cv2.IMWRITE_WEBP_QUALITY, 95])

            logger.info(f"DDColor test successful, result saved to: {test_output_path}")
            return True, f"测试成功,结果保存到: {test_output_path}"

        except Exception as e:
            logger.error(f"DDColor test failed: {e}")
            return False, f"测试失败: {e}"

    def test_local_image(self, image_path):
        """
        Colorize a local image for comparison/analysis.

        :param image_path: local image path
        :return: (success flag, human-readable message)
        """
        if not self.is_available():
            return False, "DDColor模型未初始化"

        try:
            logger.info(f"Testing local image colorization: {image_path}")

            # Load the local image.
            image = cv2.imread(image_path)
            if image is None:
                return False, f"无法读取图像: {image_path}"

            # Report grayscale status for the log.
            is_gray = self.is_grayscale(image)
            logger.info(f"Local image grayscale detection result: {is_gray}")

            # Keep the original around for side-by-side comparison.
            self.save_debug_image(image, "original")

            # Colorize unconditionally.
            colorized_image = self.colorize_image_direct(image)

            # Persist the colorized result.
            result_path = self.save_debug_image(colorized_image, "local_colorized")

            logger.info(f"Local image colorization successful, result saved to: {result_path}")
            return True, f"本地图像上色成功,结果保存到: {result_path}"

        except Exception as e:
            logger.error(f"Local image colorization failed: {e}")
            return False, f"本地图像上色失败: {e}"
debug_colorize.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 调试上色效果差异的脚本
4
+ """
5
+
6
+ import sys
7
+ import os
8
+ import cv2
9
+ import numpy as np
10
+
11
+ # 添加当前目录到路径
12
+ sys.path.insert(0, os.path.dirname(__file__))
13
+
14
+ from ddcolor_colorizer import DDColorColorizer
15
+ from gfpgan_restorer import GFPGANRestorer
16
+ import logging
17
+
18
+ # 设置日志
19
+ logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
20
+
21
def simulate_api_processing(image_path):
    """
    Simulate the full API processing pipeline (colorize first, then restore)
    and write intermediate debug images for comparison.
    Returns a dict with original/colorized/final images and the strategy
    string, or None when initialization or image loading fails.
    """
    print("\n=== 模拟API接口处理流程 ===")

    # Initialize the processing components.
    print("初始化GFPGAN修复器...")
    try:
        gfpgan_restorer = GFPGANRestorer()
        if not gfpgan_restorer.is_available():
            print("❌ GFPGAN不可用")
            return None
        print("✅ GFPGAN初始化成功")
    except Exception as e:
        print(f"❌ GFPGAN初始化失败: {e}")
        return None

    print("初始化DDColor上色器...")
    try:
        ddcolor_colorizer = DDColorColorizer()
        if not ddcolor_colorizer.is_available():
            print("❌ DDColor不可用")
            return None
        print("✅ DDColor初始化成功")
    except Exception as e:
        print(f"❌ DDColor初始化失败: {e}")
        return None

    # Load the input image.
    print(f"读取图像: {image_path}")
    image = cv2.imread(image_path)
    if image is None:
        print(f"❌ 无法读取图像: {image_path}")
        return None

    print(f"原图尺寸: {image.shape}")

    # Keep the original for side-by-side comparison.
    ddcolor_colorizer.save_debug_image(image, "api_original")

    # Decide the strategy from the grayscale check on the original.
    original_is_grayscale = ddcolor_colorizer.is_grayscale(image)
    print(f"原图灰度检测: {original_is_grayscale}")

    # New pipeline order: colorize first, then restore.
    # Step 1: colorization.
    print("\n步骤1: 上色处理...")
    if original_is_grayscale:
        print("策略: 对原图进行上色")
        colorized_image = ddcolor_colorizer.colorize_image_direct(image)
        ddcolor_colorizer.save_debug_image(colorized_image, "api_colorized")
        strategy = "先上色"
        current_image = colorized_image
    else:
        print("策略: 图像已经是彩色的,跳过上色")
        strategy = "跳过上色"
        current_image = image

    # Step 2: GFPGAN restoration.
    print("\n步骤2: GFPGAN修复...")
    final_image = gfpgan_restorer.restore_image(current_image)
    print(f"修复后图像尺寸: {final_image.shape}")

    # Persist the final result.
    result_path = ddcolor_colorizer.save_debug_image(final_image, "api_final")

    strategy += " -> 再修复"

    print(f"\n✅ API模拟完成")
    print(f" - 处理策略: {strategy}")
    print(f" - 最终结果: {result_path}")

    # 'colorized' is only populated when the grayscale branch ran.
    return {
        'original': image,
        'colorized': colorized_image if original_is_grayscale else None,
        'final': final_image,
        'strategy': strategy
    }
100
+
101
def test_direct_colorization(image_path):
    """Exercise DDColor directly (same flow as test_ddcolor.py): first the
    official sample URL, then the supplied local image."""
    print("\n=== 测试直接上色 ===")

    dd = DDColorColorizer()
    if not dd.is_available():
        print("❌ DDColor不可用")
        return None

    # Official sample URL, identical to test_ddcolor.py.
    print("使用官方示例URL上色...")
    ok, message = dd.test_colorization()
    print(f"✅ URL上色成功: {message}" if ok else f"❌ URL上色失败: {message}")

    # Then colorize the caller-provided local image.
    print(f"对本地图像直接上色: {image_path}")
    ok, message = dd.test_local_image(image_path)
    print(f"✅ 本地图像上色成功: {message}" if ok else f"❌ 本地图像上色失败: {message}")
129
+
130
def compare_results():
    """
    List the generated *_debug.webp images in the working directory and print
    suggestions for comparing them. Returns None.
    """
    print("\n=== 结果对比分析 ===")

    # Collect debug artifacts produced by earlier steps.
    # (Idiom fix: comprehension instead of a manual append loop.)
    debug_files = [name for name in os.listdir(".") if name.endswith("_debug.webp")]

    if debug_files:
        print("生成的调试文件:")
        for f in sorted(debug_files):
            print(f" - {f}")

        print("\n对比建议:")
        print("1. 比较 original_debug.webp 和 api_original_debug.webp")
        print("2. 比较 local_colorized_debug.webp 和 api_final_debug.webp")
        print("3. 检查 api_restored_debug.webp 的修复效果")
        print("4. 观察 ddcolor_test_result.webp 的官方示例效果")
    else:
        print("未找到调试文件")
154
+
155
def analyze_image_quality(image_path):
    """Print basic quality metrics for an image file: size, brightness,
    contrast, saturation, and sharpness."""
    print(f"\n=== 分析图像质量: {image_path} ===")

    if not os.path.exists(image_path):
        print(f"文件不存在: {image_path}")
        return

    img = cv2.imread(image_path)
    if img is None:
        print(f"无法读取图像: {image_path}")
        return

    # Geometry.
    h, w, c = img.shape
    print(f"尺寸: {w}x{h}, 通道数: {c}")

    # Brightness: mean grayscale intensity.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    print(f"平均亮度: {np.mean(gray):.2f}")

    # Contrast: standard deviation of intensities.
    print(f"对比度(标准差): {np.std(gray):.2f}")

    # Color: mean saturation from the HSV S channel.
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    print(f"平均饱和度: {np.mean(hsv[:, :, 1]):.2f}")

    # Sharpness: variance of the Laplacian response.
    print(f"锐度: {np.var(cv2.Laplacian(gray, cv2.CV_64F)):.2f}")
192
+
193
def main():
    """Entry point: run the end-to-end colorization debug workflow.

    Uses sys.argv[1] as the test image path when given; otherwise falls back
    to a placeholder path (which will trigger the URL-only branch).
    """
    print("DDColor 上色效果调试工具")
    print("=" * 60)

    # Default placeholder path; replace with a real path or pass via argv.
    test_image_path = "/path/to/your/test/image.jpg"

    if len(sys.argv) > 1:
        test_image_path = sys.argv[1]

    print(f"测试图像路径: {test_image_path}")

    if not os.path.exists(test_image_path):
        print("⚠️ 测试图像不存在,将只运行URL测试")

        # Only the URL-based direct colorization test can run.
        test_direct_colorization(None)

    else:
        # Quality metrics for the original image.
        analyze_image_quality(test_image_path)

        # Direct colorization (test_ddcolor.py-style).
        test_direct_colorization(test_image_path)

        # Simulate the API pipeline (colorize -> restore).
        # NOTE(review): api_result is currently unused.
        api_result = simulate_api_processing(test_image_path)

        # Quality metrics for the generated result images, when present.
        if os.path.exists("api_final_debug.webp"):
            print("\n--- API处理结果质量分析 ---")
            analyze_image_quality("api_final_debug.webp")

        if os.path.exists("local_colorized_debug.webp"):
            print("\n--- 直接上色结果质量分析 ---")
            analyze_image_quality("local_colorized_debug.webp")

        # Cross-compare the generated artifacts.
        compare_results()

    print("\n调试完成!")
    print("请检查生成的调试图像来识别问题所在。")
236
+
237
+ if __name__ == "__main__":
238
+ main()
face_analyzer.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ from typing import List, Dict, Any
5
+
6
+ import cv2
7
+ import numpy as np
8
+
9
+ import config
10
+ from config import logger, MODELS_PATH, OUTPUT_DIR, DEEPFACE_AVAILABLE, \
11
+ YOLO_AVAILABLE
12
+ from facial_analyzer import FacialFeatureAnalyzer
13
+ from models import ModelType
14
+ from utils import save_image_force_compress
15
+
16
+ if DEEPFACE_AVAILABLE:
17
+ from deepface import DeepFace
18
+
19
+ # 可选导入 YOLO
20
+ if YOLO_AVAILABLE:
21
+ try:
22
+ from ultralytics import YOLO
23
+
24
+ YOLO_AVAILABLE = True
25
+ except ImportError:
26
+ YOLO_AVAILABLE = False
27
+ YOLO = None
28
+ print("Warning: ENABLE_YOLO=true but ultralytics not available")
29
+
30
+
31
+ class EnhancedFaceAnalyzer:
32
+ """增强版人脸分析器 - 支持混合模型"""
33
+
34
    def __init__(self, models_dir: str = MODELS_PATH):
        """
        Initialize the face analyzer: load the HowCuteAmI DNN models, the
        YOLO face detector, and the facial-feature analyzer.

        :param models_dir: directory containing the model files
        """
        start_time = time.perf_counter()
        self.models_dir = models_dir
        # Mean BGR values subtracted during DNN preprocessing.
        self.MODEL_MEAN_VALUES = (104, 117, 123)
        # Age buckets produced by the age GoogLeNet model.
        self.age_list = [
            "(0-2)",
            "(4-6)",
            "(8-12)",
            "(15-20)",
            "(25-32)",
            "(38-43)",
            "(48-53)",
            "(60-100)",
        ]
        self.gender_list = ["Male", "Female"]
        # Per-gender annotation colors (BGR).
        self.gender_colors = {
            "Male": (255, 165, 0),  # orange
            "Female": (255, 0, 255),  # magenta / fuchsia
        }

        # Facial-feature analyzer (eyes, nose, etc.).
        self.facial_analyzer = FacialFeatureAnalyzer()
        # Load the HowCuteAmI models (raises on failure).
        self._load_howcuteami_models()
        # Load the YOLO face-detection model (best effort).
        self._load_yolo_model()

        # Optional model warm-up, gated by config.
        # NOTE(review): _warmup_models is defined elsewhere in this class.
        if getattr(config, "ENABLE_WARMUP", False):
            self._warmup_models()

        init_time = time.perf_counter() - start_time
        logger.info(f"EnhancedFaceAnalyzer initialized successfully, time: {init_time:.3f}s")
72
+
73
+ def _cap_conf(self, value: float) -> float:
74
+ """将置信度限制在 [0, 0.9999] 并保留4位小数。"""
75
+ try:
76
+ v = float(value if value is not None else 0.0)
77
+ except Exception:
78
+ v = 0.0
79
+ if v >= 1.0:
80
+ v = 0.9999
81
+ if v < 0.0:
82
+ v = 0.0
83
+ return round(v, 4)
84
+
85
    def _adjust_beauty_score(self, score: float) -> float:
        """
        Optionally boost a beauty score toward a configured upper bound.

        Scores below BEAUTY_ADJUST_MIN and at/above the upper bound are
        returned unchanged; scores inside [low, high) are pulled toward
        high by factor gamma. Any failure (including missing config
        attributes) silently returns the original score.
        """
        try:
            if not config.BEAUTY_ADJUST_ENABLED:
                return score
            # Read the boost range and strength from config.
            low = float(getattr(config, "BEAUTY_ADJUST_MIN", 6.0))
            high = float(getattr(config, "BEAUTY_ADJUST_MAX", getattr(config, "BEAUTY_ADJUST_THRESHOLD", 8.0)))
            gamma = float(getattr(config, "BEAUTY_ADJUST_GAMMA", 0.3))
            gamma = max(0.0001, min(1.0, gamma))

            # Validate the range before using it.
            if not (0.0 <= low < high <= 10.0):
                return score

            # Below the lower bound: no boost; inside the range: pull toward
            # the upper bound; at/above the upper bound: unchanged.
            if score < low:
                return score
            if score < high:
                # Gentle convergence toward high: adjusted = high - gamma * (high - score)
                adjusted = high - gamma * (high - score)
                adjusted = round(min(10.0, max(0.0, adjusted)), 1)
                try:
                    logger.info(
                        f"beauty_score adjusted: original={score:.1f} -> adjusted={adjusted:.1f} "
                        f"(range=[{low:.1f},{high:.1f}], gamma={gamma:.3f})"
                    )
                except Exception:
                    pass
                return adjusted
            return score
        except Exception:
            # Best-effort: never let score adjustment break analysis.
            return score
117
+
118
    def _load_yolo_model(self):
        """Load the YOLO face-detection model, best effort.

        Leaves self.yolo_model as None when ultralytics is unavailable or
        both the local load and the download attempt fail.
        """
        self.yolo_model = None
        if config.YOLO_AVAILABLE:
            try:
                # Try the locally bundled YOLO face model first.
                yolo_face_path = os.path.join(self.models_dir, config.YOLO_MODEL)

                if os.path.exists(yolo_face_path):
                    self.yolo_model = YOLO(yolo_face_path)
                    logger.info(f"Local YOLO face model loaded successfully: {yolo_face_path}")
                else:
                    # No local copy: let ultralytics download it on first use.
                    logger.info("Local YOLO face model does not exist, attempting to download...")
                    try:
                        # NOTE(review): the filename says yolov11n-face but the
                        # comment/log below mention YOLOv8 — confirm which
                        # model is actually intended.
                        model_name = "yolov11n-face.pt"  # default fallback model
                        self.yolo_model = YOLO(model_name)
                        logger.info(
                            f"YOLOv8 general model loaded successfully (detecting 'person' class as face regions)"
                        )
                    except Exception as e:
                        logger.warning(f"YOLOv model download failed: {e}")

            except Exception as e:
                logger.error(f"YOLOv model loading failed: {e}")
        else:
            logger.warning("ultralytics not installed, cannot use YOLOv")
146
+
147
+ def _load_howcuteami_models(self):
148
+ """加载HowCuteAmI深度学习模型"""
149
+ try:
150
+ # 人脸检测模型
151
+ face_proto = os.path.join(self.models_dir, "opencv_face_detector.pbtxt")
152
+ face_model = os.path.join(self.models_dir, "opencv_face_detector_uint8.pb")
153
+ self.face_net = cv2.dnn.readNet(face_model, face_proto)
154
+
155
+ # 年龄预测模型
156
+ age_proto = os.path.join(self.models_dir, "age_googlenet.prototxt")
157
+ age_model = os.path.join(self.models_dir, "age_googlenet.caffemodel")
158
+ self.age_net = cv2.dnn.readNet(age_model, age_proto)
159
+
160
+ # 性别预测模型
161
+ gender_proto = os.path.join(self.models_dir, "gender_googlenet.prototxt")
162
+ gender_model = os.path.join(self.models_dir, "gender_googlenet.caffemodel")
163
+ self.gender_net = cv2.dnn.readNet(gender_model, gender_proto)
164
+
165
+ # 颜值预测模型
166
+ beauty_proto = os.path.join(self.models_dir, "beauty_resnet.prototxt")
167
+ beauty_model = os.path.join(self.models_dir, "beauty_resnet.caffemodel")
168
+ self.beauty_net = cv2.dnn.readNet(beauty_model, beauty_proto)
169
+
170
+ logger.info("HowCuteAmI model loaded successfully!")
171
+
172
+ except Exception as e:
173
+ logger.error(f"HowCuteAmI model loading failed: {e}")
174
+ raise e
175
+
176
+ # 人脸检测方法
177
+ def _detect_faces(
178
+ self, frame: np.ndarray, conf_threshold: float = config.FACE_CONFIDENCE
179
+ ) -> List[List[int]]:
180
+ """
181
+ 使用YOLO进行人脸检测,如果失败则回退到OpenCV DNN
182
+ """
183
+ # 优先使用YOLO
184
+ face_boxes = []
185
+ if self.yolo_model is not None:
186
+ try:
187
+ results = self.yolo_model(frame, conf=conf_threshold, verbose=False)
188
+ for result in results:
189
+ boxes = result.boxes
190
+ if boxes is not None:
191
+ for box in boxes:
192
+ # 检查类别ID (如果是专门的人脸模型,通常是0;如果是通用模型,person类别通常是0)
193
+ class_id = int(box.cls[0])
194
+ # 获取边界框坐标 (xyxy格式)
195
+ x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
196
+ confidence = float(box.conf[0])
197
+ logger.debug(
198
+ f"detect class_id={class_id}, confidence={confidence}"
199
+ )
200
+ # 基本边界检查
201
+ frame_height, frame_width = frame.shape[:2]
202
+ x1 = max(0, int(x1))
203
+ y1 = max(0, int(y1))
204
+ x2 = min(frame_width, int(x2))
205
+ y2 = min(frame_height, int(y2))
206
+
207
+ # 过滤太小的检测框
208
+ width, height = x2 - x1, y2 - y1
209
+ if (
210
+ width > 30 and height > 30
211
+ ): # YOLO通常检测精度更高,可以稍微提高最小尺寸
212
+ # 如果使用通用模型检测person,需要进一步过滤头部区域
213
+ if self._is_likely_face_region(x1, y1, x2, y2, frame):
214
+ face_boxes.append(self._scale_box([x1, y1, x2, y2]))
215
+ logger.info(
216
+ f"YOLO detected {len(face_boxes)} faces, conf_threshold={conf_threshold}"
217
+ )
218
+ if face_boxes: # 如果YOLO检测到了人脸,直接返回
219
+ return face_boxes
220
+
221
+ except Exception as e:
222
+ logger.warning(f"YOLO detection failed, falling back to OpenCV DNN: {e}")
223
+ return self._detect_faces_opencv_fallback(frame, conf_threshold)
224
+
225
+ return face_boxes
226
+
227
+ def _is_likely_face_region(
228
+ self, x1: int, y1: int, x2: int, y2: int, frame: np.ndarray
229
+ ) -> bool:
230
+ """
231
+ 判断检测区域是否可能是人脸区域(当使用通用YOLO模型时)
232
+ """
233
+ width, height = x2 - x1, y2 - y1
234
+
235
+ # 长宽比检查 - 人脸/头部通常接近正方形
236
+ aspect_ratio = width / height
237
+ if not (0.6 <= aspect_ratio <= 1.6):
238
+ return False
239
+
240
+ # 位置检查 - 人脸通常在图像上半部分(简单启发式)
241
+ frame_height = frame.shape[0]
242
+ center_y = (y1 + y2) / 2
243
+ if center_y > frame_height * 0.8: # 如果中心点在图像下方80%以下,可能不是人脸
244
+ return False
245
+
246
+ # 尺寸检查 - 不应该占据整个图像
247
+ frame_width, frame_height = frame.shape[1], frame.shape[0]
248
+ if width > frame_width * 0.8 or height > frame_height * 0.8:
249
+ return False
250
+
251
+ return True
252
+
253
+ def _detect_faces_opencv_fallback(
254
+ self, frame: np.ndarray, conf_threshold: float = 0.5
255
+ ) -> List[List[int]]:
256
+ """
257
+ 优化版人脸检测 - 支持多尺度检测和小人脸识别
258
+ """
259
+ frame_height, frame_width = frame.shape[:2]
260
+ all_boxes = []
261
+
262
+ # 多尺度检测配置 - 从小到大,更好地检测不同大小的人脸
263
+ detection_configs = [
264
+ {"size": (300, 300), "threshold": conf_threshold},
265
+ {
266
+ "size": (416, 416),
267
+ "threshold": max(0.3, conf_threshold - 0.2),
268
+ }, # 对大尺度降低阈值
269
+ {
270
+ "size": (512, 512),
271
+ "threshold": max(0.25, conf_threshold - 0.25),
272
+ }, # 进一步降低阈值检测小脸
273
+ ]
274
+ logger.info(f"Detecting faces using opencv, conf_threshold={conf_threshold}")
275
+ for config in detection_configs:
276
+ try:
277
+ # 图像预处理 - 增强对比度有助于小人脸检测
278
+ processed_frame = cv2.convertScaleAbs(frame, alpha=1.1, beta=10)
279
+
280
+ blob = cv2.dnn.blobFromImage(
281
+ processed_frame, 1.0, config["size"], [104, 117, 123], True, False
282
+ )
283
+ self.face_net.setInput(blob)
284
+ detections = self.face_net.forward()
285
+
286
+ # 提取检测结果
287
+ for i in range(detections.shape[2]):
288
+ confidence = detections[0, 0, i, 2]
289
+ if confidence > config["threshold"]:
290
+ x1 = int(detections[0, 0, i, 3] * frame_width)
291
+ y1 = int(detections[0, 0, i, 4] * frame_height)
292
+ x2 = int(detections[0, 0, i, 5] * frame_width)
293
+ y2 = int(detections[0, 0, i, 6] * frame_height)
294
+
295
+ # 基本边界检查
296
+ x1, y1 = max(0, x1), max(0, y1)
297
+ x2, y2 = min(frame_width, x2), min(frame_height, y2)
298
+
299
+ # 过滤太小或不合理的检测框
300
+ width, height = x2 - x1, y2 - y1
301
+ if (
302
+ width > 20
303
+ and height > 20
304
+ and width < frame_width * 0.8
305
+ and height < frame_height * 0.8
306
+ ):
307
+ # 长宽比检查 - 人脸通常接近正方形
308
+ aspect_ratio = width / height
309
+ if 0.6 <= aspect_ratio <= 1.8: # 允许一定的椭圆形变
310
+ all_boxes.append(
311
+ {
312
+ "box": [x1, y1, x2, y2],
313
+ "confidence": confidence,
314
+ "area": width * height,
315
+ }
316
+ )
317
+ except Exception as e:
318
+ logger.warning(f"Scale {config['size']} detection failed: {e}")
319
+ continue
320
+
321
+ # 如果没有检测到任何人脸,尝试更宽松的条件
322
+ if not all_boxes:
323
+ logger.info("No faces detected, trying more relaxed detection conditions...")
324
+ try:
325
+ # 最后一次尝试:最低阈值 + 图像增强
326
+ enhanced_frame = cv2.equalizeHist(
327
+ cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
328
+ )
329
+ enhanced_frame = cv2.cvtColor(enhanced_frame, cv2.COLOR_GRAY2BGR)
330
+
331
+ blob = cv2.dnn.blobFromImage(
332
+ enhanced_frame, 1.0, (300, 300), [104, 117, 123], True, False
333
+ )
334
+ self.face_net.setInput(blob)
335
+ detections = self.face_net.forward()
336
+
337
+ for i in range(detections.shape[2]):
338
+ confidence = detections[0, 0, i, 2]
339
+ if confidence > 0.15: # 非常低的阈值
340
+ x1 = int(detections[0, 0, i, 3] * frame_width)
341
+ y1 = int(detections[0, 0, i, 4] * frame_height)
342
+ x2 = int(detections[0, 0, i, 5] * frame_width)
343
+ y2 = int(detections[0, 0, i, 6] * frame_height)
344
+
345
+ x1, y1 = max(0, x1), max(0, y1)
346
+ x2, y2 = min(frame_width, x2), min(frame_height, y2)
347
+
348
+ width, height = x2 - x1, y2 - y1
349
+ if width > 15 and height > 15: # 更小的最小尺寸
350
+ aspect_ratio = width / height
351
+ if 0.5 <= aspect_ratio <= 2.0: # 更宽松的长宽比
352
+ all_boxes.append(
353
+ {
354
+ "box": [x1, y1, x2, y2],
355
+ "confidence": confidence,
356
+ "area": width * height,
357
+ }
358
+ )
359
+ except Exception as e:
360
+ logger.warning(f"Relaxed condition detection also failed: {e}")
361
+
362
+ # NMS (非极大值抑制) 去除重复检测
363
+ if all_boxes:
364
+ final_boxes = self._apply_nms(all_boxes, overlap_threshold=0.4)
365
+ return [self._scale_box(box["box"]) for box in final_boxes]
366
+
367
+ return []
368
+
369
+ def _apply_nms(
370
+ self, detections: List[Dict], overlap_threshold: float = 0.4
371
+ ) -> List[Dict]:
372
+ """
373
+ 非极大值抑制,去除重复的检测框
374
+ """
375
+ if not detections:
376
+ return []
377
+
378
+ # 按置信度排序
379
+ detections.sort(key=lambda x: x["confidence"], reverse=True)
380
+
381
+ keep = []
382
+ while detections:
383
+ # 保留置信度最高的
384
+ best = detections.pop(0)
385
+ keep.append(best)
386
+
387
+ # 移除与最佳检测重叠度高的其他检测
388
+ remaining = []
389
+ for det in detections:
390
+ if self._calculate_iou(best["box"], det["box"]) < overlap_threshold:
391
+ remaining.append(det)
392
+ detections = remaining
393
+
394
+ return keep
395
+
396
+ def _calculate_iou(self, box1: List[int], box2: List[int]) -> float:
397
+ """
398
+ 计算两个边界框的IoU (交并比)
399
+ """
400
+ x1_1, y1_1, x2_1, y2_1 = box1
401
+ x1_2, y1_2, x2_2, y2_2 = box2
402
+
403
+ # 计算交集
404
+ x1_i = max(x1_1, x1_2)
405
+ y1_i = max(y1_1, y1_2)
406
+ x2_i = min(x2_1, x2_2)
407
+ y2_i = min(y2_1, y2_2)
408
+
409
+ if x2_i <= x1_i or y2_i <= y1_i:
410
+ return 0.0
411
+
412
+ intersection = (x2_i - x1_i) * (y2_i - y1_i)
413
+
414
+ # 计算并集
415
+ area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
416
+ area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
417
+ union = area1 + area2 - intersection
418
+
419
+ return intersection / union if union > 0 else 0.0
420
+
421
+ def _scale_box(self, box: List[int]) -> List[int]:
422
+ """将矩形框缩放为正方形"""
423
+ width = box[2] - box[0]
424
+ height = box[3] - box[1]
425
+ maximum = max(width, height)
426
+ dx = int((maximum - width) / 2)
427
+ dy = int((maximum - height) / 2)
428
+
429
+ return [box[0] - dx, box[1] - dy, box[2] + dx, box[3] + dy]
430
+
431
+ def _crop_face(self, image: np.ndarray, box: List[int]) -> np.ndarray:
432
+ """裁剪人脸区域"""
433
+ x1, y1, x2, y2 = box
434
+ h, w = image.shape[:2]
435
+ x1 = max(0, x1)
436
+ y1 = max(0, y1)
437
+ x2 = min(w, x2)
438
+ y2 = min(h, y2)
439
+ return image[y1:y2, x1:x2]
440
+
441
    def _predict_beauty_gender_with_howcuteami(
        self, face: np.ndarray
    ) -> Dict[str, Any]:
        """Predict beauty score and gender (plus age) with the HowCuteAmI nets.

        :param face: BGR face crop; blobFromImage resizes it to 224x224.
        :return: dict with age/gender/beauty fields and the model names used.
        :raises Exception: re-raises any failure after logging it.
        """
        try:
            # One preprocessed blob shared by the gender and age nets
            # (scalefactor 1.0, mean-subtracted, channel order kept BGR).
            blob = cv2.dnn.blobFromImage(
                face, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
            )

            # Gender prediction: argmax over the class scores.
            self.gender_net.setInput(blob)
            gender_preds = self.gender_net.forward()
            gender = self.gender_list[gender_preds[0].argmax()]
            gender_confidence = float(np.max(gender_preds[0]))
            gender_confidence = self._cap_conf(gender_confidence)
            # Age prediction over the same blob.
            self.age_net.setInput(blob)
            age_preds = self.age_net.forward()
            age = self.age_list[age_preds[0].argmax()]
            age_confidence = float(np.max(age_preds[0]))
            # Beauty prediction uses a separate blob with scalefactor
            # 1/255 (unlike the age/gender blob above).
            blob_beauty = cv2.dnn.blobFromImage(
                face, 1.0 / 255, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
            )
            self.beauty_net.setInput(blob_beauty)
            beauty_preds = self.beauty_net.forward()
            # Sum of the net outputs, doubled to map onto a 0-10 scale,
            # clamped, then passed through the configured adjustment.
            beauty_score = round(float(2.0 * np.sum(beauty_preds[0])), 1)
            beauty_score = min(10.0, max(0.0, beauty_score))
            beauty_score = self._adjust_beauty_score(beauty_score)
            raw_score = float(np.sum(beauty_preds[0]))

            return {
                "age": age,
                "age_confidence": round(age_confidence, 4),
                "gender": gender,
                "gender_confidence": gender_confidence,
                "beauty_score": beauty_score,
                "beauty_raw_score": round(raw_score, 4),
                "age_model_used": "HowCuteAmI",
                "gender_model_used": "HowCuteAmI",
                "beauty_model_used": "HowCuteAmI",
            }
        except Exception as e:
            logger.error(f"HowCuteAmI beauty gender prediction failed: {e}")
            raise e
486
+
487
    def _predict_age_emotion_with_deepface(
        self, face_image: np.ndarray
    ) -> Dict[str, Any]:
        """Predict age and emotion with DeepFace (gender is also returned
        so callers can use it as a fallback).

        :param face_image: BGR face crop; DeepFace's own detection is
            skipped, the crop is analysed as-is.
        :return: dict with age, emotion and gender fields.
        :raises ValueError: when ``face_image`` is None or empty.
        """
        if not DEEPFACE_AVAILABLE:
            # DeepFace unavailable: fall back to HowCuteAmI age prediction.
            return self._predict_age_with_howcuteami_fallback(face_image)

        if face_image is None or face_image.size == 0:
            raise ValueError("无效的人脸图像")

        try:
            # DeepFace analysis - progress bars / verbose output disabled.
            result = DeepFace.analyze(
                img_path=face_image,
                actions=["age", "emotion", "gender"],
                enforce_detection=False,
                detector_backend="skip",
                silent=True  # suppress progress-bar output
            )

            # DeepFace may return either a list or a dict.
            if isinstance(result, list):
                result = result[0]

            # Extract the fields we need.
            age = result.get("age", 25)
            emotion = result.get("dominant_emotion", "neutral")
            emotion_scores = result.get("emotion", {})
            # Gender info (used when HowCuteAmI's confidence is low).
            deep_gender = result.get("dominant_gender", "Woman")
            deep_gender_conf = result.get("gender", {}).get(deep_gender, 50.0) / 100.0
            deep_gender_conf = self._cap_conf(deep_gender_conf)
            # Normalise DeepFace's labels to Female/Male.
            if str(deep_gender).lower() in ["woman", "female"]:
                deep_gender = "Female"
            else:
                deep_gender = "Male"

            # NOTE(review): DeepFace returns no age confidence, so a
            # random value in [0.7613, 0.9599] is fabricated here —
            # confirm this is intentional before relying on it downstream.
            age_conf = round(random.uniform(0.7613, 0.9599), 4)
            return {
                "age": str(int(age)),
                "age_confidence": age_conf,
                "emotion": emotion,
                "emotion_analysis": emotion_scores,
                "gender": deep_gender,
                "gender_confidence": deep_gender_conf,
            }
        except Exception as e:
            logger.error(f"DeepFace age emotion prediction failed, falling back to HowCuteAmI: {e}")
            return self._predict_age_with_howcuteami_fallback(face_image)
537
+
538
+ def _predict_age_with_howcuteami_fallback(
539
+ self, face_image: np.ndarray
540
+ ) -> Dict[str, Any]:
541
+ """HowCuteAmI年龄预测回退方案"""
542
+ try:
543
+ if face_image is None or face_image.size == 0:
544
+ raise ValueError("无法读取人脸图像")
545
+
546
+ face_resized = cv2.resize(face_image, (224, 224))
547
+ blob = cv2.dnn.blobFromImage(
548
+ face_resized, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
549
+ )
550
+
551
+ # 年龄预测
552
+ self.age_net.setInput(blob)
553
+ age_preds = self.age_net.forward()
554
+ age = self.age_list[age_preds[0].argmax()]
555
+ age_confidence = float(np.max(age_preds[0]))
556
+
557
+ return {
558
+ "age": age[1:-1], # 去掉括号
559
+ "age_confidence": round(age_confidence, 4),
560
+ "emotion": "neutral", # 默认情绪
561
+ "emotion_analysis": {"neutral": 100.0}, # 默认情绪分析
562
+ }
563
+ except Exception as e:
564
+ logger.error(f"HowCuteAmI age prediction fallback failed: {e}")
565
+ return {
566
+ "age": "25-32",
567
+ "age_confidence": 0.5,
568
+ "emotion": "neutral",
569
+ "emotion_analysis": {"neutral": 100.0},
570
+ }
571
+
572
    def _predict_with_hybrid_model(
        self, face: np.ndarray, face_image: np.ndarray
    ) -> Dict[str, Any]:
        """Hybrid prediction: HowCuteAmI for beauty + gender, DeepFace for
        age + emotion; DeepFace's age wins when HowCuteAmI's age confidence
        falls below ``config.AGE_CONFIDENCE``.

        :param face: 224x224 resized crop for the HowCuteAmI nets.
        :param face_image: unresized crop passed to DeepFace.
        :return: merged prediction dict.
        """
        # Beauty + gender from HowCuteAmI.
        beauty_gender_result = self._predict_beauty_gender_with_howcuteami(face)

        # HowCuteAmI's age/gender confidences.
        howcuteami_age_confidence = beauty_gender_result.get("age_confidence", 0)
        gender_confidence = beauty_gender_result.get("gender_confidence", 0)
        if gender_confidence >= 1:
            gender_confidence = 0.9999
        age = beauty_gender_result["age"]
        # Age/emotion (and fallback gender info) from DeepFace.
        age_emotion_result = self._predict_age_emotion_with_deepface(
            face_image
        )

        # Below the configured threshold, prefer DeepFace's age.
        agec = config.AGE_CONFIDENCE
        if howcuteami_age_confidence < agec:
            deep_age = age_emotion_result["age"]
            logger.info(
                f"HowCuteAmI age confidence ({howcuteami_age_confidence}) below {agec}, value=({age}); using DeepFace for age prediction, value={deep_age}"
            )
            # Merged result carrying DeepFace's age prediction.
            result = {
                "gender": beauty_gender_result["gender"],  # HowCuteAmI first; may be overridden below
                "gender_confidence": self._cap_conf(gender_confidence),
                "beauty_score": beauty_gender_result["beauty_score"],
                "beauty_raw_score": beauty_gender_result["beauty_raw_score"],
                "age": deep_age,
                "age_confidence": age_emotion_result["age_confidence"],
                "emotion": age_emotion_result["emotion"],
                "emotion_analysis": age_emotion_result["emotion_analysis"],
                "model_used": "hybrid_deepface_age",
                "age_model_used": "DeepFace",
                "gender_model_used": "HowCuteAmI",
            }
        else:
            # HowCuteAmI's age confidence is high enough; keep its age.
            logger.info(
                f"HowCuteAmI age confidence ({howcuteami_age_confidence}) is high enough, value={age}; using HowCuteAmI for age prediction"
            )
            # Merged result keeping HowCuteAmI's age prediction.
            result = {
                "gender": beauty_gender_result["gender"],  # HowCuteAmI first; may be overridden below
                "gender_confidence": self._cap_conf(gender_confidence),
                "beauty_score": beauty_gender_result["beauty_score"],
                "beauty_raw_score": beauty_gender_result["beauty_raw_score"],
                "age": beauty_gender_result["age"],
                "age_confidence": beauty_gender_result["age_confidence"],
                "emotion": age_emotion_result["emotion"],
                "emotion_analysis": age_emotion_result["emotion_analysis"],
                "model_used": "hybrid",
                "age_model_used": "HowCuteAmI",
                "gender_model_used": "HowCuteAmI",
            }

        # Unified gender rule: Female if either model says Female;
        # Male only when both models agree on Male.
        try:
            how_gender = beauty_gender_result.get("gender")
            how_conf = float(beauty_gender_result.get("gender_confidence", 0) or 0)
            deep_gender = age_emotion_result.get("gender")
            deep_conf = float(age_emotion_result.get("gender_confidence", 0) or 0)

            final_gender = result.get("gender")
            final_conf = float(result.get("gender_confidence", 0) or 0)
            # Apply the rule.
            if (str(how_gender) == "Female") or (str(deep_gender) == "Female"):
                final_gender = "Female"
                final_conf = max(how_conf if how_gender == "Female" else 0,
                                 deep_conf if deep_gender == "Female" else 0)
                result["gender_model_used"] = "Combined(H+DF)"
            elif (str(how_gender) == "Male") and (str(deep_gender) == "Male"):
                final_gender = "Male"
                final_conf = max(how_conf if how_gender == "Male" else 0,
                                 deep_conf if deep_gender == "Male" else 0)
                result["gender_model_used"] = "Combined(H+DF)"
            # Otherwise keep the original decision.

            result["gender"] = final_gender
            result["gender_confidence"] = self._cap_conf(final_conf)
        except Exception:
            # Best-effort merge; keep the pre-merge result on any failure.
            pass

        return result
659
+
660
    def _predict_with_howcuteami(self, face: np.ndarray) -> Dict[str, Any]:
        """Full prediction (gender, age, beauty) using only HowCuteAmI nets.

        Emotion fields are filled with neutral defaults, since these nets
        do not analyse emotion.

        :param face: BGR face crop; blobFromImage resizes it to 224x224.
        :return: prediction dict.
        :raises Exception: re-raises any failure after logging it.
        """
        try:
            # Gender prediction (blob: scale 1.0, mean-subtracted, BGR kept).
            blob = cv2.dnn.blobFromImage(
                face, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
            )
            self.gender_net.setInput(blob)
            gender_preds = self.gender_net.forward()
            gender = self.gender_list[gender_preds[0].argmax()]
            gender_confidence = float(np.max(gender_preds[0]))
            gender_confidence = self._cap_conf(gender_confidence)

            # Age prediction over the same blob.
            self.age_net.setInput(blob)
            age_preds = self.age_net.forward()
            age = self.age_list[age_preds[0].argmax()]
            age_confidence = float(np.max(age_preds[0]))

            # Beauty prediction - separate blob with scalefactor 1/255.
            blob_beauty = cv2.dnn.blobFromImage(
                face, 1.0 / 255, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False
            )
            self.beauty_net.setInput(blob_beauty)
            beauty_preds = self.beauty_net.forward()
            # Sum of outputs doubled onto a 0-10 scale, clamped, adjusted.
            beauty_score = round(float(2.0 * np.sum(beauty_preds[0])), 1)
            beauty_score = min(10.0, max(0.0, beauty_score))
            beauty_score = self._adjust_beauty_score(beauty_score)
            raw_score = float(np.sum(beauty_preds[0]))

            return {
                "gender": gender,
                "gender_confidence": gender_confidence,
                "age": age[1:-1],  # strip the surrounding parentheses
                "age_confidence": round(age_confidence, 4),
                "beauty_score": beauty_score,
                "beauty_raw_score": round(raw_score, 4),
                "model_used": "HowCuteAmI",
                "emotion": "neutral",  # HowCuteAmI has no emotion analysis
                "emotion_analysis": {"neutral": 100.0},
                "age_model_used": "HowCuteAmI",
                "gender_model_used": "HowCuteAmI",
                "beauty_model_used": "HowCuteAmI",
            }
        except Exception as e:
            logger.error(f"HowCuteAmI prediction failed: {e}")
            raise e
707
+
708
    def _predict_with_deepface(self, face_image: np.ndarray) -> Dict[str, Any]:
        """Full prediction using DeepFace, with a heuristic beauty score.

        :param face_image: BGR face crop; DeepFace detection is skipped.
        :return: prediction dict.
        :raises ValueError: when DeepFace is missing or the image is empty.
        :raises Exception: re-raises any DeepFace failure after logging it.
        """
        if not DEEPFACE_AVAILABLE:
            raise ValueError("DeepFace未安装")

        if face_image is None or face_image.size == 0:
            raise ValueError("无效的人脸图像")

        try:
            # DeepFace analysis - progress bars / verbose output disabled.
            result = DeepFace.analyze(
                img_path=face_image,
                actions=["age", "gender", "emotion"],
                enforce_detection=False,
                detector_backend="skip",
                silent=True  # suppress progress-bar output
            )

            # DeepFace may return either a list or a dict.
            if isinstance(result, list):
                result = result[0]

            # Extract the fields we need.
            age = result.get("age", 25)
            gender = result.get("dominant_gender", "Woman")
            gender_confidence = result.get("gender", {}).get(gender, 0.5) / 100
            gender_confidence = self._cap_conf(gender_confidence)

            # Normalise the gender label to Female/Male.
            if gender.lower() in ["woman", "female"]:
                gender = "Female"
            else:
                gender = "Male"

            # DeepFace has no built-in beauty score; use a simple heuristic
            # based on the emotion distribution and age.
            emotion = result.get("dominant_emotion", "neutral")
            emotion_scores = result.get("emotion", {})

            # Happiness/neutrality fractions feed the heuristic below.
            happiness_score = emotion_scores.get("happy", 0) / 100
            neutral_score = emotion_scores.get("neutral", 0) / 100

            # Simple beauty formula (room for improvement).
            base_beauty = 6.0  # base score
            emotion_bonus = happiness_score * 2 + neutral_score * 1
            age_factor = max(0.5, 1 - abs(age - 25) / 50)  # 25 treated as ideal age

            beauty_score = round(min(10.0, base_beauty + emotion_bonus + age_factor), 2)

            # NOTE(review): DeepFace provides no age confidence; a random
            # value in [0.7613, 0.9599] is fabricated here.
            age_conf = round(random.uniform(0.7613, 0.9599), 4)
            return {
                "gender": gender,
                "gender_confidence": gender_confidence,
                "age": str(int(age)),
                "age_confidence": age_conf,  # DeepFace age confidence (random range)
                "beauty_score": beauty_score,
                "beauty_raw_score": round(beauty_score / 10, 4),
                "model_used": "DeepFace",
                "emotion": emotion,
                "emotion_analysis": emotion_scores,
                "age_model_used": "DeepFace",
                "gender_model_used": "DeepFace",
                "beauty_model_used": "Heuristic",
            }
        except Exception as e:
            logger.error(f"DeepFace prediction failed: {e}")
            raise e
775
+
776
    def analyze_faces(
        self,
        image: np.ndarray,
        original_image_hash: str,
        model_type: ModelType = ModelType.HYBRID,
    ) -> Dict[str, Any]:
        """
        Analyse all faces in an image.

        :param image: input BGR image
        :param original_image_hash: MD5 hash of the original image (used
            to name the saved face crops)
        :param model_type: which model stack to use
        :return: analysis result dict (faces, counts, annotated image)
        :raises ValueError: when ``image`` is None
        """
        if image is None:
            raise ValueError("无效的图像输入")

        # Detect faces first; everything else hangs off the boxes.
        face_boxes = self._detect_faces(image)

        if not face_boxes:
            return {
                "success": False,
                "message": "请尝试上传清晰、无遮挡的正面照片",
                "face_count": 0,
                "faces": [],
                "annotated_image": None,
                "model_used": model_type.value,
            }

        results = {
            "success": True,
            "message": f"成功检测到 {len(face_boxes)} 张人脸",
            "face_count": len(face_boxes),
            "faces": [],
            "model_used": model_type.value,
        }

        # Copy the original image so annotations don't touch the input.
        annotated_image = image.copy()
        logger.debug(
            f"Input annotated_image shape: {annotated_image.shape}, dtype: {annotated_image.dtype}, ndim: {annotated_image.ndim}"
        )
        # Analyse each detected face in turn.
        for i, face_box in enumerate(face_boxes):
            # Crop the face region (clamped to the image bounds).
            face_cropped = self._crop_face(image, face_box)
            if face_cropped.size == 0:
                logger.warning(f"Cropped face {i + 1} is empty, skipping.")
                continue

            face_resized = cv2.resize(face_cropped, (224, 224))
            face_for_deepface = face_cropped.copy()

            # Run the prediction appropriate to the requested model type.
            try:
                if model_type == ModelType.HYBRID:
                    # Hybrid: HowCuteAmI for beauty/gender, DeepFace for age/emotion.
                    prediction_result = self._predict_with_hybrid_model(
                        face_resized, face_for_deepface
                    )
                elif model_type == ModelType.HOWCUTEAMI:
                    prediction_result = self._predict_with_howcuteami(face_resized)
                    # Non-hybrid mode still merges gender: bring in DeepFace's
                    # gender (Female wins if either model says Female;
                    # Male only if both agree). Best-effort.
                    try:
                        age_emotion_result = self._predict_age_emotion_with_deepface(
                            face_for_deepface
                        )
                        how_gender = prediction_result.get("gender")
                        how_conf = float(prediction_result.get("gender_confidence", 0) or 0)
                        deep_gender = age_emotion_result.get("gender")
                        deep_conf = float(age_emotion_result.get("gender_confidence", 0) or 0)
                        final_gender = prediction_result.get("gender")
                        final_conf = float(prediction_result.get("gender_confidence", 0) or 0)
                        if (str(how_gender) == "Female") or (str(deep_gender) == "Female"):
                            final_gender = "Female"
                            final_conf = max(how_conf if how_gender == "Female" else 0,
                                             deep_conf if deep_gender == "Female" else 0)
                            prediction_result["gender_model_used"] = "Combined(H+DF)"
                        elif (str(how_gender) == "Male") and (str(deep_gender) == "Male"):
                            final_gender = "Male"
                            final_conf = max(how_conf if how_gender == "Male" else 0,
                                             deep_conf if deep_gender == "Male" else 0)
                            prediction_result["gender_model_used"] = "Combined(H+DF)"
                        prediction_result["gender"] = final_gender
                        prediction_result["gender_confidence"] = round(float(final_conf), 4)
                    except Exception:
                        pass
                elif model_type == ModelType.DEEPFACE and DEEPFACE_AVAILABLE:
                    prediction_result = self._predict_with_deepface(face_for_deepface)
                    # Non-hybrid mode still merges gender: bring in
                    # HowCuteAmI's gender with the same rule. Best-effort.
                    try:
                        beauty_gender_result = self._predict_beauty_gender_with_howcuteami(
                            face_resized
                        )
                        deep_gender = prediction_result.get("gender")
                        deep_conf = float(prediction_result.get("gender_confidence", 0) or 0)
                        how_gender = beauty_gender_result.get("gender")
                        how_conf = float(beauty_gender_result.get("gender_confidence", 0) or 0)
                        final_gender = prediction_result.get("gender")
                        final_conf = float(prediction_result.get("gender_confidence", 0) or 0)
                        if (str(how_gender) == "Female") or (str(deep_gender) == "Female"):
                            final_gender = "Female"
                            final_conf = max(how_conf if how_gender == "Female" else 0,
                                             deep_conf if deep_gender == "Female" else 0)
                            prediction_result["gender_model_used"] = "Combined(H+DF)"
                        elif (str(how_gender) == "Male") and (str(deep_gender) == "Male"):
                            final_gender = "Male"
                            final_conf = max(how_conf if how_gender == "Male" else 0,
                                             deep_conf if deep_gender == "Male" else 0)
                            prediction_result["gender_model_used"] = "Combined(H+DF)"
                        prediction_result["gender"] = final_gender
                        prediction_result["gender_confidence"] = round(float(final_conf), 4)
                    except Exception:
                        pass
                else:
                    # Unknown/unavailable model type: fall back to hybrid.
                    prediction_result = self._predict_with_hybrid_model(
                        face_resized, face_for_deepface
                    )
                    logger.warning(f"Model {model_type.value} is not available, using hybrid mode")

            except Exception as e:
                # Any prediction failure degrades to neutral defaults so the
                # rest of the pipeline (drawing, saving) still runs.
                logger.error(f"Prediction failed, using default values: {e}")
                prediction_result = {
                    "gender": "Unknown",
                    "gender_confidence": 0.5,
                    "age": "25-32",
                    "age_confidence": 0.5,
                    "beauty_score": 5.0,
                    "beauty_raw_score": 0.5,
                    "emotion": "neutral",
                    "emotion_analysis": {"neutral": 100.0},
                    "model_used": "fallback",
                }

            # Facial-feature analysis (currently disabled).
            # facial_features = self.facial_analyzer.analyze_facial_features(
            #     face_cropped, face_box
            # )

            # Colour per gender; age display is unified (female age
            # adjustment applied below).
            gender = prediction_result.get("gender", "Unknown")
            color_bgr = self.gender_colors.get(gender, (128, 128, 128))
            color_hex = f"#{color_bgr[2]:02x}{color_bgr[1]:02x}{color_bgr[0]:02x}"

            # Age text and optional adjustment.
            raw_age_str = prediction_result.get("age", "Unknown")
            display_age_str = str(raw_age_str)
            age_adjusted_flag = False
            age_adjustment_value = int(getattr(config, "FEMALE_AGE_ADJUSTMENT", 0) or 0)
            age_adjustment_threshold = int(getattr(config, "FEMALE_AGE_ADJUSTMENT_THRESHOLD", 999) or 999)

            # Adjust only for Female predictions at/above the threshold.
            try:
                # Supports both "25-32" range and plain "25" formats.
                if "-" in str(raw_age_str):
                    age_num = int(str(raw_age_str).split("-")[0].strip("() "))
                else:
                    age_num = int(str(raw_age_str).strip())

                if str(gender) == "Female" and age_num >= age_adjustment_threshold and age_adjustment_value > 0:
                    adjusted_age = max(0, age_num - age_adjustment_value)
                    display_age_str = str(adjusted_age)
                    age_adjusted_flag = True
                    try:
                        logger.info(f"Adjusted age for female (draw+data): {age_num} -> {adjusted_age}")
                    except Exception:
                        pass
            except Exception:
                # Unparseable age: keep the original string unchanged.
                pass

            # Save the cropped face (best-effort; filename set to None on failure).
            cropped_face_filename = f"{original_image_hash}_face_{i + 1}.webp"
            cropped_face_path = os.path.join(OUTPUT_DIR, cropped_face_filename)
            try:
                save_image_force_compress(
                    face_cropped, cropped_face_path, max_size_kb=100
                )
                logger.debug(f"cropped face: {cropped_face_path}")
            except Exception as e:
                logger.error(f"Failed to save cropped face {cropped_face_path}: {e}")
                cropped_face_filename = None

            # Draw the bounding box on the annotated image.
            if config.DRAW_SCORE:
                cv2.rectangle(
                    annotated_image,
                    (face_box[0], face_box[1]),
                    (face_box[2], face_box[3]),
                    color_bgr,
                    int(round(image.shape[0] / 400)),
                    8,
                )

            # Label text: "gender, age, beauty score".
            beauty_score = prediction_result.get("beauty_score", 0)
            label = f"{gender}, {display_age_str}, {beauty_score}"

            font_scale = max(
                0.3, min(0.7, image.shape[0] / 800)
            )  # divisor changed 500->800; range changed 0.5-1.0 -> 0.3-0.7
            font_thickness = 2
            font = cv2.FONT_HERSHEY_SIMPLEX
            # Text position: above the box, or inside it near the top edge.
            text_x = face_box[0]
            text_y = face_box[1] - 10 if face_box[1] - 10 > 20 else face_box[1] + 30

            # Measure the text (width/height) for the background rectangle.
            (text_width, text_height), baseline = cv2.getTextSize(label, font, font_scale, font_thickness)

            # Background rectangle slightly larger than the text itself.
            background_tl = (text_x, text_y - text_height - baseline)  # top-left
            background_br = (text_x + text_width, text_y + baseline)  # bottom-right

            if config.DRAW_SCORE:
                cv2.rectangle(
                    annotated_image,
                    background_tl,
                    background_br,
                    color_bgr,  # filled with the gender colour
                    thickness=-1  # filled rectangle
                )
                cv2.putText(
                    annotated_image,
                    label,
                    (text_x, text_y),
                    font,
                    font_scale,
                    (255, 255, 255),
                    font_thickness,
                    cv2.LINE_AA,
                )

            # Build the per-face result entry.
            face_result = {
                "face_id": i + 1,
                "gender": gender,
                "gender_confidence": prediction_result.get("gender_confidence", 0),
                "gender_model_used": prediction_result.get("gender_model_used", prediction_result.get("model_used", model_type.value)),
                "age": display_age_str,
                "age_confidence": prediction_result.get("age_confidence", 0),
                "age_model_used": prediction_result.get("age_model_used", prediction_result.get("model_used", model_type.value)),
                "beauty_score": prediction_result.get("beauty_score", 0),
                "beauty_raw_score": prediction_result.get("beauty_raw_score", 0),
                "emotion": prediction_result.get("emotion", "neutral"),
                "emotion_analysis": prediction_result.get("emotion_analysis", {}),
                # "facial_features": facial_features,  # facial-feature analysis (disabled)
                "bounding_box": {
                    "x1": int(face_box[0]),
                    "y1": int(face_box[1]),
                    "x2": int(face_box[2]),
                    "y2": int(face_box[3]),
                },
                "color": {
                    "bgr": [int(color_bgr[0]), int(color_bgr[1]), int(color_bgr[2])],
                    "hex": color_hex,
                },
                "cropped_face_filename": cropped_face_filename,
                "model_used": prediction_result.get("model_used", model_type.value),
            }

            if age_adjusted_flag:
                face_result["age_adjusted"] = True
                face_result["age_adjustment_value"] = int(age_adjustment_value)

            results["faces"].append(face_result)

        results["annotated_image"] = annotated_image
        return results
1046
+
1047
    def _warmup_models(self):
        """Warm up all models to reduce first-call latency.

        Runs one forward pass through DeepFace (if available) and each of
        the OpenCV DNN nets on a tiny synthetic image. Every step is
        best-effort: failures are logged as warnings, never raised.
        """
        try:
            logger.info("Starting to warm up models...")

            # Small grey test image (64x64).
            test_image = np.ones((64, 64, 3), dtype=np.uint8) * 128

            # Warm up the DeepFace model (if available).
            if DEEPFACE_AVAILABLE:
                try:
                    import tempfile
                    with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as tmp_file:
                        cv2.imwrite(tmp_file.name, test_image, [cv2.IMWRITE_WEBP_QUALITY, 95])
                        # Warm up DeepFace with the standard action set.
                        DeepFace.analyze(
                            img_path=tmp_file.name,
                            actions=["age", "emotion", "gender"],
                            detector_backend="yolov8",
                            enforce_detection=False,
                            silent=True
                        )
                        os.unlink(tmp_file.name)
                    logger.info("DeepFace model warm-up completed")
                except Exception as e:
                    logger.warning(f"DeepFace model warm-up failed: {e}")

            # Warm up the OpenCV DNN models.
            try:
                # Face detection net (300x300 input, detection mean values).
                blob = cv2.dnn.blobFromImage(test_image, 1.0, (300, 300), (104, 117, 123))
                self.face_net.setInput(blob)
                self.face_net.forward()

                # Age prediction net (224x224 input, model mean values).
                test_face = cv2.resize(test_image, (224, 224))
                blob = cv2.dnn.blobFromImage(test_face, 1.0, (224, 224), self.MODEL_MEAN_VALUES, swapRB=False)
                self.age_net.setInput(blob)
                self.age_net.forward()

                # Gender prediction net (same blob).
                self.gender_net.setInput(blob)
                self.gender_net.forward()

                # Beauty scoring net (same blob).
                self.beauty_net.setInput(blob)
                self.beauty_net.forward()

                logger.info("OpenCV DNN model warm-up completed")
            except Exception as e:
                logger.warning(f"OpenCV DNN model warm-up failed: {e}")

            logger.info("Model warm-up completed")
        except Exception as e:
            logger.warning(f"Error occurred during model warm-up: {e}")
facial_analyzer.py ADDED
@@ -0,0 +1,912 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+ from typing import List, Dict, Any
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ import config
8
+ from config import logger, DLIB_AVAILABLE
9
+
10
+ if DLIB_AVAILABLE:
11
+ import mediapipe as mp
12
+
13
+
14
+ class FacialFeatureAnalyzer:
15
+ """五官分析器"""
16
+
17
+ def __init__(self):
18
+ self.face_mesh = None
19
+ if DLIB_AVAILABLE:
20
+ try:
21
+ # 初始化MediaPipe Face Mesh
22
+ mp_face_mesh = mp.solutions.face_mesh
23
+ self.face_mesh = mp_face_mesh.FaceMesh(
24
+ static_image_mode=True,
25
+ max_num_faces=1,
26
+ refine_landmarks=True,
27
+ min_detection_confidence=0.5,
28
+ min_tracking_confidence=0.5
29
+ )
30
+ logger.info("MediaPipe face landmark detector loaded successfully")
31
+ except Exception as e:
32
+ logger.error(f"Failed to load MediaPipe model: {e}")
33
+
34
    def analyze_facial_features(
        self, face_image: np.ndarray, face_box: List[int]
    ) -> Dict[str, Any]:
        """
        Analyse facial features from landmarks.

        Falls back to the basic analysis whenever MediaPipe is missing,
        no landmarks are found, or any step throws.

        :param face_image: BGR face image
        :param face_box: face bounding box [x1, y1, x2, y2]
        :return: facial-feature analysis result
        """
        if not DLIB_AVAILABLE or self.face_mesh is None:
            return self._basic_facial_analysis(face_image)

        try:
            # MediaPipe expects RGB input.
            rgb_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)

            # Detect landmarks.
            results = self.face_mesh.process(rgb_image)

            if not results.multi_face_landmarks:
                logger.warning("No facial landmarks detected")
                return self._basic_facial_analysis(face_image)

            # Use the first detected face's landmarks.
            face_landmarks = results.multi_face_landmarks[0]

            # Convert MediaPipe's 468 landmarks to a dlib-68-like layout.
            points = self._convert_mediapipe_to_dlib_format(face_landmarks, face_image.shape)

            return self._analyze_features_from_landmarks(points, face_image.shape)

        except Exception as e:
            logger.error(f"Facial feature analysis failed: {e}")
            traceback.print_exc()  # full stack trace with exact line numbers
            return self._basic_facial_analysis(face_image)
69
+
70
    def _convert_mediapipe_to_dlib_format(self, face_landmarks, image_shape):
        """
        Convert MediaPipe's 468 landmarks to a dlib-68-point-like layout.

        Uses a fixed index mapping from MediaPipe Face Mesh indices to the
        68 dlib slots; unmapped or out-of-range slots default to the image
        centre. NOTE(review): several MediaPipe indices appear more than
        once in the map (e.g. 318, 17, 378) — confirm the mapping against
        the MediaPipe Face Mesh index chart.
        """
        h, w = image_shape[:2]

        # dlib-slot -> MediaPipe-index mapping (based on the standard
        # MediaPipe Face Mesh indexing).
        mediapipe_to_dlib_map = {
            # Jaw line / face contour (0-16)
            0: 234,  # lowest point of chin
            1: 132,  # lower right cheek
            2: 172,  # right cheek
            3: 136,  # upper right cheek
            4: 150,  # right cheekbone
            5: 149,  # right temple
            6: 176,  # right forehead edge
            7: 148,  # right forehead
            8: 152,  # forehead centre
            9: 377,  # left forehead
            10: 400,  # left forehead edge
            11: 378,  # left temple
            12: 379,  # left cheekbone
            13: 365,  # upper left cheek
            14: 397,  # left cheek
            15: 361,  # lower left cheek
            16: 454,  # left side of chin

            # Right eyebrow (17-21)
            17: 70,  # right eyebrow, outer end
            18: 63,  # right eyebrow
            19: 105,  # right eyebrow
            20: 66,  # right eyebrow
            21: 107,  # right eyebrow, inner end

            # Left eyebrow (22-26)
            22: 336,  # left eyebrow, inner end
            23: 296,  # left eyebrow
            24: 334,  # left eyebrow
            25: 293,  # left eyebrow
            26: 300,  # left eyebrow, outer end

            # Nose bridge (27-30)
            27: 168,  # top of nose bridge
            28: 8,  # nose bridge
            29: 9,  # nose bridge
            30: 10,  # bottom of nose bridge

            # Nose wings (31-35)
            31: 151,  # right nose wing
            32: 134,  # right nostril
            33: 2,  # nose tip
            34: 363,  # left nostril
            35: 378,  # left nose wing

            # Right eye (36-41)
            36: 33,  # right eye, outer corner
            37: 7,  # right upper eyelid
            38: 163,  # right upper eyelid
            39: 144,  # right eye, inner corner
            40: 145,  # right lower eyelid
            41: 153,  # right lower eyelid

            # Left eye (42-47)
            42: 362,  # left eye, inner corner
            43: 382,  # left upper eyelid
            44: 381,  # left upper eyelid
            45: 380,  # left eye, outer corner
            46: 374,  # left lower eyelid
            47: 373,  # left lower eyelid

            # Mouth contour (48-67)
            48: 78,  # right mouth corner
            49: 95,  # right upper lip
            50: 88,  # upper lip, right side
            51: 178,  # upper lip, centre-right
            52: 87,  # upper lip, centre
            53: 14,  # upper lip, centre-left
            54: 317,  # upper lip, left side
            55: 318,  # left upper lip
            56: 308,  # left mouth corner
            57: 324,  # left lower lip
            58: 318,  # lower lip, left side
            59: 16,  # lower lip, centre-left
            60: 17,  # lower lip, centre
            61: 18,  # lower lip, centre-right
            62: 200,  # lower lip, right side
            63: 199,  # right lower lip
            64: 175,  # right mouth corner, inner
            65: 84,  # inner upper lip, right
            66: 17,  # inner lower lip, centre
            67: 314,  # inner upper lip, left
        }

        # Convert each mapped landmark to pixel coordinates.
        points = []
        for i in range(68):
            if i in mediapipe_to_dlib_map:
                mp_idx = mediapipe_to_dlib_map[i]
                if mp_idx < len(face_landmarks.landmark):
                    landmark = face_landmarks.landmark[mp_idx]
                    x = int(landmark.x * w)
                    y = int(landmark.y * h)
                    points.append((x, y))
                else:
                    # Index out of range: fall back to the image centre.
                    points.append((w//2, h//2))
            else:
                # No mapping for this slot: fall back to the image centre.
                points.append((w//2, h//2))

        return points
183
+
184
    def _analyze_features_from_landmarks(
        self, landmarks: List[tuple], image_shape: tuple
    ) -> Dict[str, Any]:
        """Score the facial features from 68 landmark points.

        Any failure falls back to ``self._basic_facial_analysis(None)``.

        :param landmarks: 68 (x, y) points in dlib ordering.
        :param image_shape: shape of the analysed face image.
        :return: per-feature scores plus harmony and overall scores.
        """
        try:
            # Landmark index ranges for each facial region (dlib layout).
            jawline = landmarks[0:17]  # jaw line
            left_eyebrow = landmarks[17:22]  # left eyebrow
            right_eyebrow = landmarks[22:27]  # right eyebrow
            nose = landmarks[27:36]  # nose
            left_eye = landmarks[36:42]  # left eye
            right_eye = landmarks[42:48]  # right eye
            mouth = landmarks[48:68]  # mouth

            # Per-feature scores (simplified; a production system would
            # use a more sophisticated algorithm).
            scores = {
                "eyes": self._score_eyes(left_eye, right_eye, image_shape),
                "nose": self._score_nose(nose, image_shape),
                "mouth": self._score_mouth(mouth, image_shape),
                "eyebrows": self._score_eyebrows(
                    left_eyebrow, right_eyebrow, image_shape
                ),
                "jawline": self._score_jawline(jawline, image_shape),
            }

            # Overall harmony of the features.
            harmony_score = self._calculate_harmony_new(landmarks, image_shape)
            # Gently lift low harmony scores (same strategy as beauty).
            harmony_score = self._adjust_harmony_score(harmony_score)

            return {
                "facial_features": scores,
                "harmony_score": round(harmony_score, 2),
                "overall_facial_score": round(sum(scores.values()) / len(scores), 2),
                "analysis_method": "mediapipe_landmarks",
            }

        except Exception as e:
            logger.error(f"Landmark analysis failed: {e}")
            return self._basic_facial_analysis(None)
+ return self._basic_facial_analysis(None)
224
+
225
+ def _adjust_harmony_score(self, score: float) -> float:
226
+ """整体协调性分值温和拉升:当低于阈值时往阈值靠拢一点。"""
227
+ try:
228
+ if not getattr(config, "HARMONY_ADJUST_ENABLED", False):
229
+ return round(float(score), 2)
230
+ thr = float(getattr(config, "HARMONY_ADJUST_THRESHOLD", 8.0))
231
+ gamma = float(getattr(config, "HARMONY_ADJUST_GAMMA", 0.5))
232
+ gamma = max(0.0001, min(1.0, gamma))
233
+ s = float(score)
234
+ if s < thr:
235
+ s = thr - gamma * (thr - s)
236
+ return round(min(10.0, max(0.0, s)), 2)
237
+ except Exception:
238
+ try:
239
+ return round(float(score), 2)
240
+ except Exception:
241
+ return 6.21
242
+
243
+ def _score_eyes(
244
+ self, left_eye: List[tuple], right_eye: List[tuple], image_shape: tuple
245
+ ) -> float:
246
+ """眼部评分"""
247
+ try:
248
+ # 计算眼部对称性和大小
249
+ left_width = abs(left_eye[3][0] - left_eye[0][0])
250
+ right_width = abs(right_eye[3][0] - right_eye[0][0])
251
+
252
+ # 计算眼部高度
253
+ left_height = abs(left_eye[1][1] - left_eye[5][1])
254
+ right_height = abs(right_eye[1][1] - right_eye[5][1])
255
+
256
+ # 对称性评分 - 宽度对称性
257
+ width_symmetry = 1 - min(
258
+ abs(left_width - right_width) / max(left_width, right_width), 0.5
259
+ )
260
+
261
+ # 高度对称性
262
+ height_symmetry = 1 - min(
263
+ abs(left_height - right_height) / max(left_height, right_height), 0.5
264
+ )
265
+
266
+ # 大小适中性评分 (相对于脸部宽度) - 调整理想比例
267
+ avg_eye_width = (left_width + right_width) / 2
268
+ face_width = image_shape[1]
269
+ ideal_ratio = 0.08 # 调整理想比例,原来0.15太大
270
+ size_score = max(
271
+ 0, 1 - abs(avg_eye_width / face_width - ideal_ratio) / ideal_ratio
272
+ )
273
+
274
+ # 眼部长宽比评分
275
+ avg_eye_height = (left_height + right_height) / 2
276
+ aspect_ratio = avg_eye_width / max(avg_eye_height, 1) # 避免除零
277
+ ideal_aspect = 3.0 # 理想长宽比
278
+ aspect_score = max(0, 1 - abs(aspect_ratio - ideal_aspect) / ideal_aspect)
279
+
280
+ final_score = (
281
+ width_symmetry * 0.3
282
+ + height_symmetry * 0.3
283
+ + size_score * 0.25
284
+ + aspect_score * 0.15
285
+ ) * 10
286
+ return round(max(0, min(10, final_score)), 2)
287
+ except:
288
+ return 6.21
289
+
290
+ def _score_nose(self, nose: List[tuple], image_shape: tuple) -> float:
291
+ """鼻部评分"""
292
+ try:
293
+ # 鼻子关键点
294
+ nose_tip = nose[3] # 鼻尖
295
+ nose_bridge_top = nose[0] # 鼻梁顶部
296
+ left_nostril = nose[1]
297
+ right_nostril = nose[5]
298
+
299
+ # 计算鼻子的直线度 (鼻梁是否挺直)
300
+ straightness = 1 - min(
301
+ abs(nose_tip[0] - nose_bridge_top[0]) / (image_shape[1] * 0.1), 1.0
302
+ )
303
+
304
+ # 鼻宽评分 - 使用鼻翼宽度
305
+ nose_width = abs(right_nostril[0] - left_nostril[0])
306
+ face_width = image_shape[1]
307
+ ideal_nose_ratio = 0.06 # 调整理想比例
308
+ width_score = max(
309
+ 0,
310
+ 1 - abs(nose_width / face_width - ideal_nose_ratio) / ideal_nose_ratio,
311
+ )
312
+
313
+ # 鼻子长度评分
314
+ nose_length = abs(nose_tip[1] - nose_bridge_top[1])
315
+ face_height = image_shape[0]
316
+ ideal_length_ratio = 0.08
317
+ length_score = max(
318
+ 0,
319
+ 1
320
+ - abs(nose_length / face_height - ideal_length_ratio)
321
+ / ideal_length_ratio,
322
+ )
323
+
324
+ final_score = (
325
+ straightness * 0.4 + width_score * 0.35 + length_score * 0.25
326
+ ) * 10
327
+ return round(max(0, min(10, final_score)), 2)
328
+ except:
329
+ return 6.21
330
+
331
+ def _score_mouth(self, mouth: List[tuple], image_shape: tuple) -> float:
332
+ """嘴部评分 - 大幅优化,更宽松的评分标准"""
333
+ try:
334
+ # 嘴角点
335
+ left_corner = mouth[0] # 左嘴角
336
+ right_corner = mouth[6] # 右嘴角
337
+
338
+ # 上唇和下唇中心点
339
+ upper_lip_center = mouth[3] # 上唇中心
340
+ lower_lip_center = mouth[9] # 下唇中心
341
+
342
+ # 基础分数,避免过低
343
+ base_score = 6.0
344
+
345
+ # 1. 嘴宽评分 - 更宽松的标准
346
+ mouth_width = abs(right_corner[0] - left_corner[0])
347
+ face_width = image_shape[1]
348
+ mouth_ratio = mouth_width / face_width
349
+
350
+ # 设置更宽的合理范围 (0.04-0.15)
351
+ if 0.04 <= mouth_ratio <= 0.15:
352
+ width_score = 1.0 # 在合理范围内就给满分
353
+ elif mouth_ratio < 0.04:
354
+ width_score = max(0.3, mouth_ratio / 0.04) # 太小时渐减
355
+ else:
356
+ width_score = max(0.3, 0.15 / mouth_ratio) # 太大时渐减
357
+
358
+ # 2. 唇厚度评分 - 简化并放宽标准
359
+ lip_thickness = abs(lower_lip_center[1] - upper_lip_center[1])
360
+ # 只要厚度不是极端值就给高分
361
+ if lip_thickness > 3: # 像素值,有一定厚度
362
+ thickness_score = min(1.0, lip_thickness / 25) # 25像素为满分
363
+ else:
364
+ thickness_score = 0.5 # 太薄给中等分数
365
+
366
+ # 3. 嘴部对称性评分 - 更宽松
367
+ mouth_center_x = (left_corner[0] + right_corner[0]) / 2
368
+ face_center_x = image_shape[1] / 2
369
+ center_deviation = abs(mouth_center_x - face_center_x) / face_width
370
+
371
+ if center_deviation < 0.02: # 偏差小于2%
372
+ symmetry_score = 1.0
373
+ elif center_deviation < 0.05: # 偏差小于5%
374
+ symmetry_score = 0.8
375
+ else:
376
+ symmetry_score = max(0.5, 1 - center_deviation * 10) # 最低0.5分
377
+
378
+ # 4. 嘴唇形状评分 - 简化
379
+ # 检查嘴角是否在合理位置
380
+ corner_height_diff = abs(left_corner[1] - right_corner[1])
381
+ if corner_height_diff < face_width * 0.02: # 嘴角高度差异小
382
+ shape_score = 1.0
383
+ else:
384
+ shape_score = max(0.6, 1 - corner_height_diff / (face_width * 0.02))
385
+
386
+ # 5. 综合评分 - 调整权重,给基础分更大权重
387
+ feature_score = (
388
+ width_score * 0.3
389
+ + thickness_score * 0.25
390
+ + symmetry_score * 0.25
391
+ + shape_score * 0.2
392
+ )
393
+
394
+ # 最终分数 = 基础分 + 特征分奖励
395
+ final_score = base_score + feature_score * 4 # 最高10分
396
+
397
+ return round(max(4.0, min(10, final_score)), 2) # 最低4分,最高10分
398
+ except Exception as e:
399
+ return 6.21
400
+
401
+ def _score_eyebrows(
402
+ self, left_brow: List[tuple], right_brow: List[tuple], image_shape: tuple
403
+ ) -> float:
404
+ """眉毛评分 - 改进算法"""
405
+ try:
406
+ # 计算眉毛长度
407
+ left_length = abs(left_brow[-1][0] - left_brow[0][0])
408
+ right_length = abs(right_brow[-1][0] - right_brow[0][0])
409
+
410
+ # 长度对称性
411
+ length_symmetry = 1 - min(
412
+ abs(left_length - right_length) / max(left_length, right_length), 0.5
413
+ )
414
+
415
+ # 计算眉毛拱形 - 改进方法
416
+ left_peak_y = min([p[1] for p in left_brow]) # 眉峰(y坐标最小)
417
+ left_ends_y = (left_brow[0][1] + left_brow[-1][1]) / 2 # 眉毛两端平均高度
418
+ left_arch = max(0, left_ends_y - left_peak_y) # 拱形高度
419
+
420
+ right_peak_y = min([p[1] for p in right_brow])
421
+ right_ends_y = (right_brow[0][1] + right_brow[-1][1]) / 2
422
+ right_arch = max(0, right_ends_y - right_peak_y)
423
+
424
+ # 拱形对称性
425
+ arch_symmetry = 1 - min(
426
+ abs(left_arch - right_arch) / max(left_arch, right_arch, 1), 0.5
427
+ )
428
+
429
+ # 眉形适中性评分
430
+ avg_arch = (left_arch + right_arch) / 2
431
+ face_height = image_shape[0]
432
+ ideal_arch_ratio = 0.015 # 理想拱形比例
433
+ arch_ratio = avg_arch / face_height
434
+ arch_score = max(
435
+ 0, 1 - abs(arch_ratio - ideal_arch_ratio) / ideal_arch_ratio
436
+ )
437
+
438
+ # 眉毛浓密度(通过点的密集程度估算)
439
+ density_score = min(1.0, (len(left_brow) + len(right_brow)) / 10)
440
+
441
+ final_score = (
442
+ length_symmetry * 0.3
443
+ + arch_symmetry * 0.3
444
+ + arch_score * 0.25
445
+ + density_score * 0.15
446
+ ) * 10
447
+ return round(max(0, min(10, final_score)), 2)
448
+ except:
449
+ return 6.21
450
+
451
+ def _score_jawline(self, jawline: List[tuple], image_shape: tuple) -> float:
452
+ """下颌线评分 - 改进算法"""
453
+ try:
454
+ jaw_points = [(p[0], p[1]) for p in jawline]
455
+
456
+ # 关键点
457
+ left_jaw = jaw_points[2] # 左下颌角
458
+ jaw_tip = jaw_points[8] # 下巴尖
459
+ right_jaw = jaw_points[14] # 右下颌角
460
+
461
+ # 对称性评分 - 改进计算
462
+ left_dist = (
463
+ (left_jaw[0] - jaw_tip[0]) ** 2 + (left_jaw[1] - jaw_tip[1]) ** 2
464
+ ) ** 0.5
465
+ right_dist = (
466
+ (right_jaw[0] - jaw_tip[0]) ** 2 + (right_jaw[1] - jaw_tip[1]) ** 2
467
+ ) ** 0.5
468
+ symmetry = 1 - min(
469
+ abs(left_dist - right_dist) / max(left_dist, right_dist), 0.5
470
+ )
471
+
472
+ # 下颌角度评分
473
+ left_angle_y = abs(left_jaw[1] - jaw_tip[1])
474
+ right_angle_y = abs(right_jaw[1] - jaw_tip[1])
475
+ avg_angle = (left_angle_y + right_angle_y) / 2
476
+
477
+ # 理想的下颌角度
478
+ face_height = image_shape[0]
479
+ ideal_angle_ratio = 0.08
480
+ angle_ratio = avg_angle / face_height
481
+ angle_score = max(
482
+ 0, 1 - abs(angle_ratio - ideal_angle_ratio) / ideal_angle_ratio
483
+ )
484
+
485
+ # 下颌线清晰度(通过点间距离变化评估)
486
+ smoothness_score = 0.8 # 简化处理,可以根据实际需要改进
487
+
488
+ final_score = (
489
+ symmetry * 0.4 + angle_score * 0.35 + smoothness_score * 0.25
490
+ ) * 10
491
+ return round(max(0, min(10, final_score)), 2)
492
+ except:
493
+ return 6.21
494
+
495
+ def _calculate_harmony(self, landmarks: List[tuple], image_shape: tuple) -> float:
496
+ """计算五官协调性"""
497
+ try:
498
+ # 黄金比例检测 (简化版)
499
+ face_height = max([p[1] for p in landmarks]) - min(
500
+ [p[1] for p in landmarks]
501
+ )
502
+ face_width = max([p[0] for p in landmarks]) - min([p[0] for p in landmarks])
503
+
504
+ # 理想比例约为1.618
505
+ ratio = face_height / face_width if face_width > 0 else 1
506
+ golden_ratio = 1.618
507
+ harmony = 1 - abs(ratio - golden_ratio) / golden_ratio
508
+
509
+ return max(0, min(10, harmony * 10))
510
+ except:
511
+ return 6.21
512
+
513
    def _calculate_harmony_new(
        self, landmarks: List[tuple], image_shape: tuple
    ) -> float:
        """
        Calculate facial harmony - optimized version.

        Combines several aesthetic-proportion and symmetry metrics into a
        single weighted 0-10 score. Returns the 6.21 fallback when fewer
        than 68 landmarks are available or any sub-metric raises.
        """
        try:
            logger.debug(f"face landmarks={len(landmarks)}")
            if len(landmarks) < 68:  # expects the dlib-style 68-point layout
                return 6.21

            # numpy array makes the vector math below simpler
            points = np.array(landmarks)

            # 1. Base facial measurements
            face_measurements = self._get_face_measurements(points)

            # 2. Individual metrics as (name, score, weight) tuples.
            # NOTE(review): the effective weights are the numbers in the
            # tuples below (summing to 1.0); spacing is currently weighted
            # 0, i.e. computed and logged but excluded from the total.
            scores = []

            # Golden-ratio score (weight 0.10)
            golden_score = self._calculate_golden_ratios(face_measurements)
            logger.debug(f"Golden ratio score={golden_score}")
            scores.append(("golden_ratio", golden_score, 0.10))

            # Symmetry score (weight 0.40)
            symmetry_score = self._calculate_facial_symmetry(face_measurements, points)
            logger.debug(f"Symmetry score={symmetry_score}")
            scores.append(("symmetry", symmetry_score, 0.40))

            # "Three courts, five eyes" proportions (weight 0.05)
            proportion_score = self._calculate_classical_proportions(face_measurements)
            logger.debug(f"Three courts five eyes ratio={proportion_score}")
            scores.append(("proportions", proportion_score, 0.05))

            # Feature-spacing harmony (weight 0 - effectively disabled)
            spacing_score = self._calculate_feature_spacing(face_measurements)
            logger.debug(f"Facial feature spacing harmony={spacing_score}")
            scores.append(("spacing", spacing_score, 0))

            # Facial contour harmony (weight 0.05)
            contour_score = self._calculate_contour_harmony(points)
            logger.debug(f"Facial contour harmony={contour_score}")
            scores.append(("contour", contour_score, 0.05))

            # Eye/nose/mouth proportion harmony (weight 0.40)
            feature_score = self._calculate_feature_proportions(face_measurements)
            logger.debug(f"Eye-nose-mouth proportion harmony={feature_score}")
            scores.append(("features", feature_score, 0.40))

            # Weighted average, clamped to [0, 10]
            final_score = sum(score * weight for _, score, weight in scores)
            logger.debug(f"Weighted average final score={final_score}")
            return max(0, min(10, final_score))

        except Exception as e:
            logger.error(f"Error calculating facial harmony: {e}")
            traceback.print_exc()  # full stack trace with exact line numbers
            return 6.21
573
+
574
    def _get_face_measurements(self, points: np.ndarray) -> Dict[str, float]:
        """Extract key facial measurements from a 68x2 landmark array.

        The returned dict mixes scalars (widths/heights/distances) with
        2-vectors: the *_center and nose_tip entries are numpy arrays.
        Indices follow the dlib 68-point layout.
        """
        measurements = {}

        # Face contour points (0-16)
        face_contour = points[0:17]

        # Eyebrow points (17-26) - currently unused below, kept for clarity
        left_eyebrow = points[17:22]
        right_eyebrow = points[22:27]

        # Eye points (36-47)
        left_eye = points[36:42]
        right_eye = points[42:48]

        # Nose points (27-35)
        nose = points[27:36]

        # Mouth points (48-67)
        mouth = points[48:68]

        # Basic measurements
        measurements["face_width"] = np.max(face_contour[:, 0]) - np.min(
            face_contour[:, 0]
        )
        measurements["face_height"] = np.max(points[:, 1]) - np.min(points[:, 1])

        # Eye measurements
        measurements["left_eye_width"] = np.max(left_eye[:, 0]) - np.min(left_eye[:, 0])
        measurements["right_eye_width"] = np.max(right_eye[:, 0]) - np.min(
            right_eye[:, 0]
        )
        measurements["eye_distance"] = np.min(right_eye[:, 0]) - np.max(left_eye[:, 0])
        measurements["left_eye_center"] = np.mean(left_eye, axis=0)
        measurements["right_eye_center"] = np.mean(right_eye, axis=0)

        # Nose measurements
        measurements["nose_width"] = np.max(nose[:, 0]) - np.min(nose[:, 0])
        measurements["nose_height"] = np.max(nose[:, 1]) - np.min(nose[:, 1])
        measurements["nose_tip"] = points[33]  # nose tip in the dlib layout

        # Mouth measurements
        measurements["mouth_width"] = np.max(mouth[:, 0]) - np.min(mouth[:, 0])
        measurements["mouth_height"] = np.max(mouth[:, 1]) - np.min(mouth[:, 1])

        # Key vertical distances (forehead / mid-face / lower face thirds)
        measurements["forehead_height"] = measurements["left_eye_center"][1] - np.min(
            points[:, 1]
        )
        measurements["middle_face_height"] = (
            measurements["nose_tip"][1] - measurements["left_eye_center"][1]
        )
        measurements["lower_face_height"] = (
            np.max(points[:, 1]) - measurements["nose_tip"][1]
        )

        return measurements
631
+
632
+ def _calculate_golden_ratios(self, measurements: Dict[str, float]) -> float:
633
+ """计算黄金比例相关得分"""
634
+ golden_ratio = 1.618
635
+ scores = []
636
+
637
+ # 面部长宽比
638
+ if measurements["face_width"] > 0:
639
+ face_ratio = measurements["face_height"] / measurements["face_width"]
640
+ score = 1 - abs(face_ratio - golden_ratio) / golden_ratio
641
+ scores.append(max(0, score))
642
+
643
+ # 上中下三庭比例
644
+ total_height = (
645
+ measurements["forehead_height"]
646
+ + measurements["middle_face_height"]
647
+ + measurements["lower_face_height"]
648
+ )
649
+
650
+ if total_height > 0:
651
+ upper_ratio = measurements["forehead_height"] / total_height
652
+ middle_ratio = measurements["middle_face_height"] / total_height
653
+ lower_ratio = measurements["lower_face_height"] / total_height
654
+
655
+ # 理想比例约为 1:1:1
656
+ ideal_ratio = 1 / 3
657
+ upper_score = 1 - abs(upper_ratio - ideal_ratio) / ideal_ratio
658
+ middle_score = 1 - abs(middle_ratio - ideal_ratio) / ideal_ratio
659
+ lower_score = 1 - abs(lower_ratio - ideal_ratio) / ideal_ratio
660
+
661
+ scores.extend(
662
+ [max(0, upper_score), max(0, middle_score), max(0, lower_score)]
663
+ )
664
+
665
+ return np.mean(scores) * 10 if scores else 7.0
666
+
667
    def _calculate_facial_symmetry(
        self, measurements: Dict[str, float], points: np.ndarray
    ) -> float:
        """Score (0-10) left/right facial symmetry.

        Compares mirrored landmark pairs against a midline approximated by
        the mean x of all landmarks: horizontal symmetry from each point's
        distance to the midline, vertical symmetry from the pair's height
        difference. Returns 7.0 if no pair could be evaluated.
        """
        # Facial midline approximated by the mean x of all landmarks.
        face_center_x = np.mean(points[:, 0])

        # Mirrored landmark index pairs (dlib 68-point layout).
        symmetry_pairs = [
            (17, 26),  # outer eyebrow ends
            (18, 25),  # eyebrow
            (19, 24),  # eyebrow
            (36, 45),  # outer eye corners
            (39, 42),  # inner eye corners
            (31, 35),  # nostrils
            (48, 54),  # mouth corners
            (4, 12),  # face contour
            (5, 11),  # face contour
            (6, 10),  # face contour
        ]

        symmetry_scores = []

        for left_idx, right_idx in symmetry_pairs:
            if left_idx < len(points) and right_idx < len(points):
                left_point = points[left_idx]
                right_point = points[right_idx]

                # Distance of each point from the midline.
                left_dist = abs(left_point[0] - face_center_x)
                right_dist = abs(right_point[0] - face_center_x)

                # Height difference between the mirrored points.
                vertical_diff = abs(left_point[1] - right_point[1])

                # Combine horizontal and vertical symmetry for this pair;
                # 100 is a fallback face height if the measurement is missing.
                if left_dist + right_dist > 0:
                    horizontal_symmetry = 1 - abs(left_dist - right_dist) / (
                        left_dist + right_dist
                    )
                    vertical_symmetry = 1 - vertical_diff / measurements.get(
                        "face_height", 100
                    )

                    symmetry_scores.append(
                        (horizontal_symmetry + vertical_symmetry) / 2
                    )

        return np.mean(symmetry_scores) * 10 if symmetry_scores else 7.0
716
+
717
+ def _calculate_classical_proportions(self, measurements: Dict[str, float]) -> float:
718
+ """计算经典美学比例 (三庭五眼等)"""
719
+ scores = []
720
+
721
+ # 五眼比例检测
722
+ if measurements["face_width"] > 0:
723
+ eye_width_avg = (
724
+ measurements["left_eye_width"] + measurements["right_eye_width"]
725
+ ) / 2
726
+ ideal_eye_count = 5 # 理想情况下面宽应该等于5个眼宽
727
+ actual_eye_count = (
728
+ measurements["face_width"] / eye_width_avg if eye_width_avg > 0 else 5
729
+ )
730
+
731
+ eye_proportion_score = (
732
+ 1 - abs(actual_eye_count - ideal_eye_count) / ideal_eye_count
733
+ )
734
+ scores.append(max(0, eye_proportion_score))
735
+
736
+ # 眼间距比例
737
+ if measurements.get("left_eye_width", 0) > 0:
738
+ eye_spacing_ratio = (
739
+ measurements["eye_distance"] / measurements["left_eye_width"]
740
+ )
741
+ ideal_spacing_ratio = 1.0 # 理想情况下眼间距约等于一个眼宽
742
+
743
+ spacing_score = (
744
+ 1 - abs(eye_spacing_ratio - ideal_spacing_ratio) / ideal_spacing_ratio
745
+ )
746
+ scores.append(max(0, spacing_score))
747
+
748
+ # 鼻宽与眼宽比例
749
+ if (
750
+ measurements.get("left_eye_width", 0) > 0
751
+ and measurements.get("nose_width", 0) > 0
752
+ ):
753
+ nose_eye_ratio = measurements["nose_width"] / measurements["left_eye_width"]
754
+ ideal_nose_eye_ratio = 0.8 # 理想鼻宽约为眼宽的80%
755
+
756
+ nose_score = (
757
+ 1 - abs(nose_eye_ratio - ideal_nose_eye_ratio) / ideal_nose_eye_ratio
758
+ )
759
+ scores.append(max(0, nose_score))
760
+
761
+ return np.mean(scores) * 10 if scores else 7.0
762
+
763
    def _calculate_feature_spacing(self, measurements: Dict[str, float]) -> float:
        """Score (0-10) vertical spacing between the facial features.

        Returns 7.0 when the face height is unavailable.
        """
        scores = []

        # Eye-to-nose vertical distance relative to the face height.
        eye_nose_distance = abs(
            measurements["left_eye_center"][1] - measurements["nose_tip"][1]
        )
        if measurements.get("face_height", 0) > 0:
            eye_nose_ratio = eye_nose_distance / measurements["face_height"]
            ideal_ratio = 0.15  # ideal proportion
            score = 1 - abs(eye_nose_ratio - ideal_ratio) / ideal_ratio
            scores.append(max(0, score))

        # Nose-to-mouth distance.
        # NOTE(review): np.mean([...]) of a one-element list is just
        # mouth_height, which is a *size*, not a y-coordinate - so this
        # "distance" mixes a coordinate with a size. Looks like a bug;
        # confirm the intended landmark before changing the formula.
        nose_mouth_distance = abs(
            measurements["nose_tip"][1] - np.mean([measurements.get("mouth_height", 0)])
        )
        if measurements.get("face_height", 0) > 0:
            nose_mouth_ratio = nose_mouth_distance / measurements["face_height"]
            ideal_ratio = 0.12  # ideal proportion
            score = 1 - abs(nose_mouth_ratio - ideal_ratio) / ideal_ratio
            scores.append(max(0, score))

        return np.mean(scores) * 10 if scores else 7.0
788
+
789
+ def _calculate_contour_harmony(self, points: np.ndarray) -> float:
790
+ """计算面部轮廓协调性"""
791
+ try:
792
+ face_contour = points[0:17] # 面部轮廓点
793
+
794
+ # 计算轮廓的平滑度
795
+ smoothness_scores = []
796
+
797
+ for i in range(1, len(face_contour) - 1):
798
+ # 计算相邻三点形成的角度
799
+ p1, p2, p3 = face_contour[i - 1], face_contour[i], face_contour[i + 1]
800
+
801
+ v1 = p1 - p2
802
+ v2 = p3 - p2
803
+
804
+ # 计算角度
805
+ cos_angle = np.dot(v1, v2) / (
806
+ np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-8
807
+ )
808
+ angle = np.arccos(np.clip(cos_angle, -1, 1))
809
+
810
+ # 角度越接近平滑曲线越好 (避免过于尖锐的角度)
811
+ smoothness = 1 - abs(angle - np.pi / 2) / (np.pi / 2)
812
+ smoothness_scores.append(max(0, smoothness))
813
+
814
+ return np.mean(smoothness_scores) * 10 if smoothness_scores else 7.0
815
+
816
+ except:
817
+ return 6.21
818
+
819
    def _calculate_feature_proportions(self, measurements: Dict[str, float]) -> float:
        """Score (0-10) internal proportions of eyes, mouth and nose.

        Returns 7.0 when no component could be computed.
        """
        scores = []

        # Eye aspect ratio.
        # NOTE(review): width / max(width * 0.3, 1) equals 1/0.3 = 3.33
        # whenever width * 0.3 >= 1, so both eye scores are effectively
        # the constant ~0.889 - the measurements dict carries no eye
        # *height* to use. Flagging rather than changing, since a fix
        # would shift all harmony scores.
        left_eye_ratio = measurements.get("left_eye_width", 1) / max(
            measurements.get("left_eye_width", 1) * 0.3, 1
        )
        right_eye_ratio = measurements.get("right_eye_width", 1) / max(
            measurements.get("right_eye_width", 1) * 0.3, 1
        )

        # Ideal eye width:height ratio of about 3:1.
        ideal_eye_ratio = 3.0
        left_eye_score = 1 - abs(left_eye_ratio - ideal_eye_ratio) / ideal_eye_ratio
        right_eye_score = 1 - abs(right_eye_ratio - ideal_eye_ratio) / ideal_eye_ratio

        scores.extend([max(0, left_eye_score), max(0, right_eye_score)])

        # Mouth width:height ratio.
        if measurements.get("mouth_height", 0) > 0:
            mouth_ratio = measurements["mouth_width"] / measurements["mouth_height"]
            ideal_mouth_ratio = 3.5  # ideal mouth aspect ratio
            mouth_score = 1 - abs(mouth_ratio - ideal_mouth_ratio) / ideal_mouth_ratio
            scores.append(max(0, mouth_score))

        # Nose height:width ratio.
        if measurements.get("nose_height", 0) > 0:
            nose_ratio = measurements["nose_height"] / measurements["nose_width"]
            ideal_nose_ratio = 1.5  # ideal nose aspect ratio
            nose_score = 1 - abs(nose_ratio - ideal_nose_ratio) / ideal_nose_ratio
            scores.append(max(0, nose_score))

        return np.mean(scores) * 10 if scores else 7.0
853
+
854
+ def _basic_facial_analysis(self, face_image) -> Dict[str, Any]:
855
+ """基础五官分析 (当dlib不可用时)"""
856
+ return {
857
+ "facial_features": {
858
+ "eyes": 7.0,
859
+ "nose": 7.0,
860
+ "mouth": 7.0,
861
+ "eyebrows": 7.0,
862
+ "jawline": 7.0,
863
+ },
864
+ "harmony_score": 7.0,
865
+ "overall_facial_score": 7.0,
866
+ "analysis_method": "basic_estimation",
867
+ }
868
+
869
    def draw_facial_landmarks(self, face_image: np.ndarray) -> np.ndarray:
        """
        Draw detected facial landmarks onto a copy of the face image.

        :param face_image: face image (BGR numpy array)
        :return: annotated copy; an unmodified copy is returned when no
            detector is available or no landmarks are found.
        """
        if not DLIB_AVAILABLE or self.face_mesh is None:
            # No face-mesh detector available - return an untouched copy.
            return face_image.copy()

        try:
            # Draw on a copy so the caller's image is never mutated.
            annotated_image = face_image.copy()

            # MediaPipe expects RGB input.
            rgb_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)

            # Run landmark detection.
            results = self.face_mesh.process(rgb_image)

            if not results.multi_face_landmarks:
                logger.warning("No facial landmarks detected for drawing")
                return annotated_image

            # Use the first detected face only.
            face_landmarks = results.multi_face_landmarks[0]

            # Draw every landmark (normalized coords -> pixel coords).
            h, w = face_image.shape[:2]
            for landmark in face_landmarks.landmark:
                x = int(landmark.x * w)
                y = int(landmark.y * h)
                # Dot for the landmark itself...
                cv2.circle(annotated_image, (x, y), 1, (0, 255, 0), -1)

                # ...plus a small cross for visibility.
                cv2.line(annotated_image, (x-2, y), (x+2, y), (0, 255, 0), 1)
                cv2.line(annotated_image, (x, y-2), (x, y+2), (0, 255, 0), 1)

            return annotated_image

        except Exception as e:
            logger.error(f"Failed to draw facial landmarks: {e}")
            return face_image.copy()
gfpgan_restorer.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ from config import logger, MODELS_PATH
5
+ from gfpgan import GFPGANer
6
+
7
+
8
class GFPGANRestorer:
    """Wrapper around GFPGAN for old-photo face restoration.

    Loads the GFPGANv1.4 model at construction time; if no local copy is
    found, the model URL is handed to GFPGANer, which downloads it.
    """

    def __init__(self):
        """Initialize and time the model load; never raises."""
        start_time = time.perf_counter()
        self.restorer = None
        self._initialize_model()
        init_time = time.perf_counter() - start_time
        if self.restorer is not None:
            logger.info(f"GFPGANRestorer initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"GFPGANRestorer initialization completed but not available, time: {init_time:.3f}s")

    def _initialize_model(self):
        """Locate (or schedule download of) the GFPGAN weights and build GFPGANer."""
        try:
            # Candidate locations for a locally cached model file.
            possible_paths = [
                f"{MODELS_PATH}/GFPGANv1.4.pth",
                f"{MODELS_PATH}/gfpgan/GFPGANv1.4.pth",
                os.path.expanduser("~/.cache/gfpgan/GFPGANv1.4.pth"),
                "./models/GFPGANv1.4.pth"
            ]

            gfpgan_model_path = None
            for path in possible_paths:
                if os.path.exists(path):
                    gfpgan_model_path = path
                    break

            if not gfpgan_model_path:
                logger.warning(f"GFPGAN model file not found, tried paths: {possible_paths}")
                logger.info("Will try to download GFPGAN model from the internet...")
                # Hand GFPGAN a URL so it downloads the weights itself.
                gfpgan_model_path = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'

            logger.info(f"Using GFPGAN model: {gfpgan_model_path}")

            # Build the restorer: 2x upscale, "clean" architecture,
            # no separate background upsampler.
            self.restorer = GFPGANer(
                model_path=gfpgan_model_path,
                upscale=2,
                arch='clean',
                channel_multiplier=2,
                bg_upsampler=None
            )
            logger.info("GFPGAN model initialized successfully")

        except Exception as e:
            logger.error(f"GFPGAN model initialization failed: {e}")
            self.restorer = None

    def is_available(self):
        """Return True when the GFPGAN model loaded successfully."""
        return self.restorer is not None

    def restore_image(self, image):
        """
        Restore an old photo with GFPGAN.

        :param image: input image (numpy array, BGR)
        :return: restored image (numpy array, BGR); the original image is
            returned unchanged when enhancement fails.
        :raises RuntimeError: if the model never initialized.
        """
        if not self.is_available():
            # RuntimeError instead of a bare Exception so callers can catch
            # it specifically; existing `except Exception` handlers still work.
            raise RuntimeError("GFPGAN模型未初始化")

        try:
            logger.info("Starting GFPGAN image restoration...")

            # has_aligned=False: input faces are not pre-aligned
            # only_center_face=False: restore every detected face
            # paste_back=True: paste restored faces back into the frame
            cropped_faces, restored_faces, restored_img = self.restorer.enhance(
                image,
                has_aligned=False,
                only_center_face=False,
                paste_back=True
            )

            if restored_img is not None:
                logger.info(f"GFPGAN restoration completed, detected {len(restored_faces)} faces")
                return restored_img
            else:
                logger.warning("GFPGAN restoration returned empty image, using original image")
                return image

        except Exception as e:
            logger.error(f"GFPGAN image restoration failed: {e}")
            # Best-effort: fall back to the original image instead of raising.
            return image
models.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import List, Optional
3
+
4
+ from pydantic import BaseModel
5
+
6
+
7
class ModelType(str, Enum):
    """Enumeration of the supported analysis model backends."""

    HOWCUTEAMI = "howcuteami"
    DEEPFACE = "deepface"
    HYBRID = "hybrid"  # hybrid: beauty/gender via howcuteami, age/emotion via deepface
13
+
14
+
15
class ImageScoreItem(BaseModel):
    """A scored image entry returned by listing/search endpoints."""

    file_path: str  # path of the image file
    score: float  # score associated with the image
    is_cropped_face: bool = False  # True when the file is a cropped face
    size_bytes: int  # file size in bytes
    size_str: str  # human-readable file size
    last_modified: str  # last-modified timestamp (formatted string)
    nickname: Optional[str] = None  # optional associated nickname
23
+
24
+
25
class SearchRequest(BaseModel):
    """Keyword search request body."""

    keyword: Optional[str] = ""  # free-text search keyword
    searchType: Optional[str] = "face"  # search domain, defaults to "face"
    top_k: Optional[int] = 5  # maximum number of results
    score_threshold: float = 0.0  # minimum score to include a result
    nickname: Optional[str] = None  # restrict results to one nickname
31
+
32
+
33
class ImageSearchRequest(BaseModel):
    """Search-by-image request body."""

    image: Optional[str] = None  # base64-encoded query image
    searchType: Optional[str] = "face"  # search domain, defaults to "face"
    top_k: Optional[int] = 5  # maximum number of results
    score_threshold: float = 0.0  # minimum score to include a result
    nickname: Optional[str] = None  # restrict results to one nickname
39
+
40
+
41
class ImageFileList(BaseModel):
    """Flat list of scored images plus the total count."""

    results: List[ImageScoreItem]
    count: int
44
+
45
class PagedImageFileList(BaseModel):
    """Paginated list of scored images."""

    results: List[ImageScoreItem]  # items on the current page
    count: int  # total item count
    page: int  # current page number
    page_size: int  # items per page
    total_pages: int  # total number of pages
51
+
52
class CelebrityMatchResponse(BaseModel):
    """A single celebrity look-alike match."""

    filename: str  # matched celebrity image filename
    display_name: Optional[str] = None  # human-readable name, if known
    distance: float  # embedding distance (lower = closer)
    similarity: float  # similarity value for the match
    confidence: float  # confidence of the match
    face_filename: Optional[str] = None  # cropped face image, if available
59
+
60
+
61
class CategoryStatItem(BaseModel):
    """Per-category count entry."""

    category: str  # internal category key
    display_name: str  # human-readable category label
    count: int  # number of items in the category
65
+
66
+
67
class CategoryStatsResponse(BaseModel):
    """Category statistics: per-category counts plus the grand total."""

    stats: List[CategoryStatItem]
    total: int
realesrgan_upscaler.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from config import logger, MODELS_PATH, REALESRGAN_MODEL
8
+
9
# Optional Real-ESRGAN dependency: flag availability instead of failing
# at import time so the rest of the service can still run without it.
try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from basicsr.utils.download_util import load_file_from_url
    from realesrgan import RealESRGANer
    from realesrgan.archs.srvgg_arch import SRVGGNetCompact
    import torch

    # PyTorch CPU tuning.
    torch.set_num_threads(min(torch.get_num_threads(), 8))  # cap intra-op threads
    torch.set_num_interop_threads(min(4, torch.get_num_interop_threads()))  # cap inter-op threads

    REALESRGAN_AVAILABLE = True
    logger.info("Real-ESRGAN imported successfully")
except ImportError as e:
    logger.error(f"Real-ESRGAN import failed: {e}")
    REALESRGAN_AVAILABLE = False
25
+
26
+
27
+ class RealESRGANUpscaler:
28
+ """Real-ESRGAN超清放大处理器"""
29
+
30
+ def __init__(self):
31
+ start_time = time.perf_counter()
32
+ self.upsampler = None
33
+ self.model_name = None
34
+ self.scale = 4
35
+ self.denoise_strength = 0.5
36
+ self._initialize()
37
+ init_time = time.perf_counter() - start_time
38
+ if self.upsampler is not None:
39
+ logger.info(f"RealESRGANUpscaler initialized successfully, time: {init_time:.3f}s")
40
+ else:
41
+ logger.info(f"RealESRGANUpscaler initialization completed but not available, time: {init_time:.3f}s")
42
+
43
+ def _initialize(self):
44
+ """初始化Real-ESRGAN模型"""
45
+ if not REALESRGAN_AVAILABLE:
46
+ logger.error("Real-ESRGAN is not available, cannot initialize super resolution processor")
47
+ return
48
+
49
+ try:
50
+ # 模型配置 - 从环境变量读取模型名称
51
+ model_name = REALESRGAN_MODEL
52
+ self.model_name = model_name
53
+
54
+ # 根据模型名称设置默认放大倍数
55
+ if 'x2' in model_name:
56
+ self.scale = 2
57
+ elif 'x4' in model_name:
58
+ self.scale = 4
59
+ else:
60
+ self.scale = 4 # 默认4倍
61
+
62
+ # 模型文件路径
63
+ model_path = None
64
+ if model_name == 'RealESRGAN_x4plus':
65
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
66
+ netscale = 4
67
+ file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
68
+ elif model_name == 'RealESRNet_x4plus':
69
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
70
+ netscale = 4
71
+ file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth'
72
+ elif model_name == 'RealESRGAN_x4plus_anime_6B':
73
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
74
+ netscale = 4
75
+ file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'
76
+ elif model_name == 'RealESRGAN_x2plus':
77
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
78
+ netscale = 2
79
+ file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'
80
+ elif model_name == 'realesr-animevideov3':
81
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
82
+ netscale = 4
83
+ file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth'
84
+ elif model_name == 'realesr-general-x4v3':
85
+ # 最新的通用模型 v0.2.5.0
86
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
87
+ netscale = 4
88
+ file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
89
+ elif model_name == 'realesr-general-wdn-x4v3':
90
+ # 最新的通用模型(带去噪)v0.2.5.0
91
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
92
+ netscale = 4
93
+ file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth'
94
+
95
+ # 确保模型目录存在
96
+ model_dir = os.path.join(MODELS_PATH, 'realesrgan')
97
+ os.makedirs(model_dir, exist_ok=True)
98
+
99
+ # 检查本地是否已有模型文件
100
+ local_model_path = None
101
+ model_filename = f"{model_name}.pth"
102
+ local_pth = os.path.join(MODELS_PATH, model_filename)
103
+
104
+ if os.path.exists(local_pth):
105
+ local_model_path = local_pth
106
+ logger.info(f"Using local model file: {local_model_path}")
107
+
108
+ # 如果本地有模型文件,使用本地文件,��则下载
109
+ if local_model_path:
110
+ model_path = local_model_path
111
+ else:
112
+ # 下载模型
113
+ logger.info(f"Downloading model {model_name} from {file_url}")
114
+ model_path = load_file_from_url(
115
+ url=file_url, model_dir=model_dir, progress=True, file_name=model_filename)
116
+
117
+ # 创建upsampler
118
+ self.upsampler = RealESRGANer(
119
+ scale=netscale,
120
+ model_path=model_path,
121
+ model=model,
122
+ tile=512, # 启用分块处理,减少内存使用并提高CPU效率
123
+ tile_pad=10,
124
+ pre_pad=0,
125
+ half=False, # 使用fp32精度
126
+ gpu_id=None # 使用CPU
127
+ )
128
+
129
+ logger.info(f"Real-ESRGAN super resolution processor initialized successfully, model: {model_name}")
130
+
131
+ except Exception as e:
132
+ logger.error(f"Failed to initialize Real-ESRGAN: {e}")
133
+ self.upsampler = None
134
+
135
+ def is_available(self):
136
+ """检查处理器是否可用"""
137
+ return REALESRGAN_AVAILABLE and self.upsampler is not None
138
+
139
+ def _optimize_input_image(self, image):
140
+ """
141
+ 优化输入图像以提高CPU处理速度
142
+ :param image: 输入图像
143
+ :return: 优化后的图像
144
+ """
145
+ # 确保图像数据类型为uint8(减少计算开销)
146
+ if image.dtype != np.uint8:
147
+ if image.dtype == np.float32 or image.dtype == np.float64:
148
+ image = (image * 255).astype(np.uint8)
149
+ else:
150
+ image = image.astype(np.uint8)
151
+
152
+ # 确保图像是3通道BGR格式
153
+ if len(image.shape) == 2: # 灰度图
154
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
155
+ elif image.shape[2] == 4: # RGBA
156
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
157
+ elif image.shape[2] == 3 and image.shape[2] != 3: # RGB转BGR
158
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
159
+
160
+ return image
161
+
162
+ def upscale_image(self, image, scale=None, denoise_strength=None):
163
+ """
164
+ 对图像进行超清放大
165
+ :param image: 输入图像 (numpy array)
166
+ :param scale: 放大倍数,默认使用模型的放大倍数
167
+ :param denoise_strength: 去噪强度 (0-1),仅对realesr-general-x4v3模型有效
168
+ :return: 超清后的图像
169
+ """
170
+ if not self.is_available():
171
+ raise RuntimeError("Real-ESRGAN超清处理器不可用")
172
+
173
+ try:
174
+ start_time = time.perf_counter()
175
+
176
+ # 预处理优化图像
177
+ image = self._optimize_input_image(image)
178
+
179
+ # 设置去噪强度(仅对特定模型有效)
180
+ if denoise_strength is not None and self.model_name == 'realesr-general-x4v3':
181
+ self.denoise_strength = denoise_strength
182
+
183
+ # 根据图像大小动态调整tile大小以优化CPU性能
184
+ h, w = image.shape[:2]
185
+ pixel_count = h * w
186
+
187
+ # 根据图像大小调整tile大小
188
+ if pixel_count > 2000000: # 大于2MP
189
+ tile_size = 256
190
+ elif pixel_count > 1000000: # 大于1MP
191
+ tile_size = 384
192
+ else:
193
+ tile_size = 512
194
+
195
+ # 动态更新tile大小
196
+ if hasattr(self.upsampler, 'tile'):
197
+ self.upsampler.tile = tile_size
198
+ logger.info(f"Adjusting tile size to: {tile_size} based on image size ({w}x{h})")
199
+
200
+ # 执行超清处理
201
+ logger.info(f"Starting Real-ESRGAN super resolution processing, model: {self.model_name}")
202
+ output, _ = self.upsampler.enhance(image, outscale=scale or self.scale)
203
+
204
+ processing_time = time.perf_counter() - start_time
205
+ logger.info(f"Real-ESRGAN super resolution processing completed, time: {processing_time:.3f}s")
206
+
207
+ return output
208
+
209
+ except Exception as e:
210
+ logger.error(f"Real-ESRGAN super resolution processing failed: {e}")
211
+ raise RuntimeError(f"超清处理失败: {str(e)}")
212
+
213
+ def get_model_info(self):
214
+ """获取模型信息"""
215
+ return {
216
+ "model_name": self.model_name,
217
+ "scale": self.scale,
218
+ "available": self.is_available()
219
+ }
220
+
221
+
222
def get_upscaler():
    """Return a new Real-ESRGAN upscaler instance."""
    # NOTE(review): dead code — this definition is shadowed by the singleton
    # get_upscaler() defined later in this module, so this version (which
    # constructs a fresh RealESRGANUpscaler per call) is never reachable.
    return RealESRGANUpscaler()
225
+
226
+
227
# Module-wide singleton instance (lazy-initialized).
_upscaler_instance = None


def get_upscaler():
    """Return the shared Real-ESRGAN upscaler, creating it on first use."""
    global _upscaler_instance
    if _upscaler_instance is not None:
        return _upscaler_instance
    _upscaler_instance = RealESRGANUpscaler()
    return _upscaler_instance
rembg_processor.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import Optional, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from config import logger, REMBG_AVAILABLE
8
+
9
+ if REMBG_AVAILABLE:
10
+ import rembg
11
+ from rembg import new_session
12
+ from PIL import Image
13
+
14
+
15
class RembgProcessor:
    """Background-removal (matting) processor backed by the rembg library."""

    def __init__(self):
        """Create a rembg session for the default model and record availability."""
        start_time = time.perf_counter()
        # rembg inference session; stays None when rembg is missing or init fails.
        self.session = None
        # True only after a session was created successfully.
        self.available = False
        self.model_name = "u2net"  # default model; suited to portrait matting

        if REMBG_AVAILABLE:
            try:
                # Initialize the rembg session for the chosen model.
                self.session = new_session(self.model_name)
                self.available = True
                logger.info(f"rembg background removal processor initialized successfully, using model: {self.model_name}")
            except Exception as e:
                logger.error(f"rembg background removal processor initialization failed: {e}")
                self.available = False
        else:
            logger.warning("rembg is not available, background removal function will be disabled")
        init_time = time.perf_counter() - start_time
        if self.available:
            logger.info(f"RembgProcessor initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"RembgProcessor initialization completed but not available, time: {init_time:.3f}s")

    def is_available(self) -> bool:
        """Return True when rembg is importable and a session exists."""
        return self.available and self.session is not None

    def remove_background(self, image: np.ndarray, background_color: Optional[Tuple[int, int, int]] = None) -> np.ndarray:
        """
        Remove the background from an image.
        :param image: input OpenCV image (BGR)
        :param background_color: replacement background color (BGR); None keeps a transparent background
        :return: BGR image when a color is given, otherwise BGRA (or BGR if rembg returned no alpha)
        :raises Exception: when the processor is unavailable or matting fails
        """
        if not self.is_available():
            raise Exception("rembg抠图处理器不可用")

        try:
            # OpenCV (BGR) -> PIL (RGB).
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image_rgb)

            # Run rembg to strip the background.
            logger.info("Starting to remove background using rembg...")
            output_image = rembg.remove(pil_image, session=self.session)

            # Convert back to OpenCV conventions.
            if background_color is not None:
                # Compose the cut-out over a solid-color backdrop (BGR -> RGB for PIL).
                background = Image.new('RGB', output_image.size, background_color[::-1])
                # Paste using the cut-out's own alpha channel as the mask.
                background.paste(output_image, mask=output_image)
                result_array = np.array(background)
                result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)
            else:
                # Keep transparency: emit a 4-channel BGRA image.
                result_array = np.array(output_image)
                if result_array.shape[2] == 4:  # RGBA
                    result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGBA2BGRA)
                else:  # RGB (no alpha returned)
                    result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)

            logger.info("rembg background removal completed")
            return result_bgr

        except Exception as e:
            logger.error(f"rembg background removal failed: {e}")
            raise Exception(f"背景移除失败: {str(e)}")

    def create_id_photo(self, image: np.ndarray, background_color: Tuple[int, int, int] = (255, 255, 255)) -> np.ndarray:
        """
        Create an ID photo: remove the background and fill with a solid color.
        :param image: input OpenCV image
        :param background_color: background color (BGR), white by default
        :return: the composed ID photo
        """
        logger.info(f"Starting to create ID photo, background color: {background_color}")

        # Strip the background and composite over the requested color.
        id_photo = self.remove_background(image, background_color)

        logger.info("ID photo creation completed")
        return id_photo

    def get_supported_models(self) -> list:
        """Return the rembg model names this processor can switch between."""
        if not REMBG_AVAILABLE:
            return []

        # Models known to work with rembg.
        return [
            "u2net",  # general-purpose model, good for portraits
            "u2net_human_seg",  # tuned for human segmentation
            "silueta",  # object matting
            "isnet-general-use"  # more precise general-purpose model
        ]

    def switch_model(self, model_name: str) -> bool:
        """
        Switch the active rembg model.
        :param model_name: model name (must be listed by get_supported_models())
        :return: True when the switch succeeded
        """
        if not REMBG_AVAILABLE:
            return False

        try:
            if model_name in self.get_supported_models():
                self.session = new_session(model_name)
                self.model_name = model_name
                logger.info(f"rembg model switched to: {model_name}")
                return True
            else:
                logger.error(f"Unsupported model: {model_name}")
                return False
        except Exception as e:
            logger.error(f"Failed to switch model: {e}")
            return False
+ return False
requirements.txt ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 固定NumPy版本避免兼容性问题 - 必须最先安装
2
+ numpy>=1.24.0,<2.0.0
3
+
4
+ # 基础依赖
5
+ fastapi>=0.104.0
6
+ uvicorn[standard]>=0.24.0
7
+ python-multipart>=0.0.6
8
+ aiofiles>=23.2.1
9
+
10
+ # 图像处理
11
+ opencv-python>=4.8.0
12
+ Pillow>=10.0.0
13
+
14
+ # PyTorch 相关包 - 升级到2.x版本解决依赖冲突
15
+ torch>=2.0.0,<2.9.0
16
+ torchvision>=0.15.0
17
+
18
+ # 机器学习和CV相关
19
+ tf-keras
20
+ aiohttp
21
+ ultralytics
22
+ deepface>=0.0.79
23
+ mediapipe>=0.10.0
24
+ # ModelScope相关包 - 让pip自动解决版本依赖
25
+ modelscope==1.28.2
26
+ datasets==2.21.0
27
+ transformers==4.40.0
28
+ # ModelScope DDColor的额外依赖
29
+ timm==1.0.19
30
+ sortedcontainers==2.4.0
31
+ fsspec==2024.6.1
32
+ multiprocess==0.70.16
33
+ xxhash==3.5.0
34
+ dill==0.3.8
35
+ huggingface-hub==0.34.3
36
+ # 修复pyarrow兼容性问题 - 使用稳定版本
37
+ pyarrow==20.0.0
38
+
39
+ # API相关
40
+ pydantic>=2.4.0
41
+ starlette>=0.27.0
42
+ simplejson==3.20.1
43
+ # 科学计算和工具
44
+ scipy>=1.7.0,<1.13.0
45
+ tqdm
46
+ lmdb
47
+ pyyaml
48
+
49
+ # 定时任务
50
+ apscheduler>=3.10.0
51
+
52
+ # 数据库
53
+ aiomysql>=0.2.0
54
+
55
+ # 对象存储
56
+ boto3>=1.34.0
57
+
58
+ # GFPGAN 和相关包 - 修复依赖兼容性
59
+ basicsr>=1.3.3
60
+ facexlib>=0.2.5
61
+ gfpgan>=1.3.0
62
+ realesrgan>=0.3.0
63
+
64
+ # CLIP 相关依赖
65
+ cn_clip
66
+ faiss-cpu
67
+ onnxruntime
68
+ diffusers
69
+ accelerate
70
+ # rembg 抠图处理
71
+ rembg>=2.0.50
72
+ easydict
rvm_processor.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ from torchvision import transforms
6
+
7
+ import config
8
+ from config import logger
9
+
10
+
11
class RVMProcessor:
    """RVM (Robust Video Matting) background-removal processor; loads model and weights from local paths only."""

    def __init__(self):
        """Build the RVM model from a local torch-hub repo and load local weights."""
        # Torch module; stays None when loading fails.
        self.model = None
        self.available = False
        self.device = "cpu"  # default to CPU; set to "cuda" when a GPU is available

        try:
            # Local-only loading — never touch the network.
            local_repo = getattr(config, 'RVM_LOCAL_REPO', '')
            weights_path = getattr(config, 'RVM_WEIGHTS_PATH', '')

            if not local_repo or not os.path.isdir(local_repo):
                raise RuntimeError("RVM_LOCAL_REPO not set or invalid. Please set env RVM_LOCAL_REPO to local RobustVideoMatting repo path (with hubconf.py)")

            if not weights_path or not os.path.isfile(weights_path):
                raise RuntimeError("RVM_WEIGHTS_PATH not set or file not found. Please set env RVM_WEIGHTS_PATH to local RVM weights file path")

            logger.info(f"Loading RVM model {config.RVM_MODEL} from local repo: {local_repo}")
            # Build from the local hub repo; pretrained=False avoids any download.
            self.model = torch.hub.load(local_repo, config.RVM_MODEL, source='local', pretrained=False)

            # Load local weights; the checkpoint may wrap them in 'state_dict'.
            state = torch.load(weights_path, map_location=self.device)
            if isinstance(state, dict) and 'state_dict' in state:
                state = state['state_dict']
            missing, unexpected = self.model.load_state_dict(state, strict=False)

            # Move to the target device and switch to eval mode.
            self.model = self.model.to(self.device).eval()
            self.available = True
            logger.info("RVM background removal processor initialized successfully (local mode)")
            if missing:
                logger.warning(f"RVM weights missing keys: {list(missing)[:5]}... total={len(missing)}")
            if unexpected:
                logger.warning(f"RVM weights unexpected keys: {list(unexpected)[:5]}... total={len(unexpected)}")

        except Exception as e:
            logger.error(f"RVM background removal processor initialization failed: {e}")
            self.available = False

    def is_available(self) -> bool:
        """Return True when the model was loaded successfully."""
        return self.available and self.model is not None

    def remove_background(self, image: np.ndarray, background_color: tuple = None) -> np.ndarray:
        """
        Remove an image's background with RVM.
        :param image: input OpenCV image (BGR)
        :param background_color: replacement background color (BGR); None keeps transparency
        :return: BGR image when a color is given, otherwise a 4-channel BGRA image
        :raises Exception: when the processor is unavailable or inference fails
        """
        if not self.is_available():
            raise Exception("RVM抠图处理器不可用")

        try:
            logger.info("Starting to remove background using RVM...")

            # Remember the input size; the model may work at a downsampled resolution.
            original_height, original_width = image.shape[:2]

            # OpenCV (BGR) -> RGB.
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # To a (1, C, H, W) float tensor on the target device.
            src = transforms.ToTensor()(image_rgb).unsqueeze(0).to(self.device)

            # Single-frame inference; rec carries the recurrent states (None for a first frame).
            rec = [None] * 4
            with torch.no_grad():
                fgr, pha, *rec = self.model(src, *rec, downsample_ratio=0.25)

            # Back to numpy uint8: fgr is the foreground, pha the alpha matte.
            fgr = (fgr[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)  # (H,W,3)
            pha = (pha[0, 0].cpu().numpy() * 255).astype(np.uint8)  # (H,W)

            # Restore the original resolution if the model resized internally.
            if fgr.shape[:2] != (original_height, original_width):
                fgr = cv2.resize(fgr, (original_width, original_height))
                pha = cv2.resize(pha, (original_width, original_height))

            if background_color is not None:
                # Composite the foreground over a solid-color backdrop.
                # Foreground back to BGR first.
                fgr_bgr = cv2.cvtColor(fgr, cv2.COLOR_RGB2BGR)

                # Solid backdrop in the requested color.
                background = np.full((original_height, original_width, 3), background_color, dtype=np.uint8)

                # Alpha-blend foreground and backdrop.
                alpha = pha.astype(np.float32) / 255.0
                alpha = np.stack([alpha] * 3, axis=-1)

                result = (fgr_bgr * alpha + background * (1 - alpha)).astype(np.uint8)
            else:
                # Keep transparency: append the matte as a 4th channel (BGRA layout).
                fgr_bgr = cv2.cvtColor(fgr, cv2.COLOR_RGB2BGR)
                rgba = np.dstack((fgr_bgr, pha))  # (H,W,4)
                result = rgba

            logger.info("RVM background removal completed")
            return result

        except Exception as e:
            logger.error(f"RVM background removal failed: {e}")
            raise Exception(f"背景移除失败: {str(e)}")

    def create_id_photo(self, image: np.ndarray, background_color: tuple = (255, 255, 255)) -> np.ndarray:
        """
        Create an ID photo (remove the background, fill with a solid color).
        :param image: input OpenCV image
        :param background_color: background color (BGR), white by default
        :return: the composed ID photo
        """
        logger.info(f"Starting to create ID photo, background color: {background_color}")

        # Strip the background and composite over the requested color.
        id_photo = self.remove_background(image, background_color)

        logger.info("ID photo creation completed")
        return id_photo
+ return id_photo
test_tensorflow.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Smoke-check the DeepFace / TensorFlow / Keras install and print version info.
import deepface
import sys
import tensorflow as tf

try:
    # Prefer the standalone keras package when installed (Keras 3 era).
    import keras

    keras_pkg = "keras (standalone)"
    keras_ver = keras.__version__
except Exception:
    # Fall back to the Keras bundled inside TensorFlow.
    from tensorflow import keras

    keras_pkg = "tf.keras"
    keras_ver = keras.__version__

print("py =", sys.version)
print("deepface =", deepface.__version__)
print("tensorflow =", tf.__version__)
print("keras pkg =", keras_pkg, "keras =", keras_ver)
+ print("keras pkg =", keras_pkg, "keras =", keras_ver)
utils.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import hashlib
3
+ import os
4
+ import re
5
+ import shutil
6
+ import threading
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image
11
+
12
+ try:
13
+ import boto3
14
+ from botocore.exceptions import BotoCoreError, ClientError
15
+ except ImportError:
16
+ boto3 = None
17
+ BotoCoreError = ClientError = Exception
18
+
19
+ from config import (
20
+ IMAGES_DIR,
21
+ IMG_QUALITY,
22
+ logger,
23
+ SAVE_QUALITY,
24
+ BOS_ACCESS_KEY,
25
+ BOS_SECRET_KEY,
26
+ BOS_ENDPOINT,
27
+ BOS_BUCKET_NAME,
28
+ BOS_IMAGE_DIR,
29
+ BOS_UPLOAD_ENABLED,
30
+ )
31
+
32
# Lazily-initialized BOS (S3-compatible) client singleton plus its init guard.
_BOS_CLIENT = None
_BOS_CLIENT_INITIALIZED = False
# Serializes client creation across threads.
_BOS_CLIENT_LOCK = threading.Lock()
# Absolute archive root; used to restrict uploads to files under IMAGES_DIR.
_IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))
36
+
37
+
38
+ def _decode_bos_credential(raw_value: str) -> str:
39
+ """将Base64编码的凭证解码为明文,若解码失败则返回原值"""
40
+ if not raw_value:
41
+ return ""
42
+
43
+ value = raw_value.strip()
44
+ if not value:
45
+ return ""
46
+
47
+ try:
48
+ padding = len(value) % 4
49
+ if padding:
50
+ value += "=" * (4 - padding)
51
+ decoded = base64.b64decode(value).decode("utf-8").strip()
52
+ if decoded:
53
+ return decoded
54
+ except Exception:
55
+ pass
56
+ return value
57
+
58
+
59
def _is_path_under_images_dir(file_path: str) -> bool:
    """Return True when *file_path* resolves to a location inside IMAGES_DIR."""
    try:
        candidate = os.path.abspath(file_path)
        shared_root = os.path.commonpath([_IMAGES_DIR_ABS, candidate])
    except ValueError:
        # Raised by commonpath for mixed drives / mixed abs-rel paths.
        return False
    return shared_root == _IMAGES_DIR_ABS
66
+
67
+
68
def _get_bos_client():
    """Return the shared BOS client, creating it at most once (thread-safe).

    Returns None when uploads are disabled, configuration is incomplete,
    boto3 is missing, or client construction failed; the outcome is cached
    either way so the checks run only once per process.
    """
    global _BOS_CLIENT, _BOS_CLIENT_INITIALIZED
    # Fast path: initialization already happened (successfully or not).
    if _BOS_CLIENT_INITIALIZED:
        return _BOS_CLIENT

    with _BOS_CLIENT_LOCK:
        # Double-checked locking: re-test the flag after acquiring the lock.
        if _BOS_CLIENT_INITIALIZED:
            return _BOS_CLIENT

        if not BOS_UPLOAD_ENABLED:
            _BOS_CLIENT_INITIALIZED = True
            _BOS_CLIENT = None
            return None
        # Credentials may be stored Base64-encoded; decode (with plaintext fallback).
        access_key = _decode_bos_credential(BOS_ACCESS_KEY)
        secret_key = _decode_bos_credential(BOS_SECRET_KEY)
        if not all([access_key, secret_key, BOS_ENDPOINT, BOS_BUCKET_NAME]):
            logger.warning("BOS 上传未配置完整,跳过初始化")
            _BOS_CLIENT_INITIALIZED = True
            _BOS_CLIENT = None
            return None

        if boto3 is None:
            logger.warning("未安装 boto3,BOS 上传功能不可用")
            _BOS_CLIENT_INITIALIZED = True
            _BOS_CLIENT = None
            return None

        try:
            # BOS exposes an S3-compatible API, so a plain boto3 S3 client works.
            _BOS_CLIENT = boto3.client(
                "s3",
                aws_access_key_id=access_key,
                aws_secret_access_key=secret_key,
                endpoint_url=BOS_ENDPOINT,
            )
            logger.info("BOS 客户端初始化成功")
        except Exception as e:
            logger.warning(f"初始化 BOS 客户端失败,将跳过上传: {e}")
            _BOS_CLIENT = None
        finally:
            # Mark initialization done even on failure so we never retry.
            _BOS_CLIENT_INITIALIZED = True

    return _BOS_CLIENT
+ return _BOS_CLIENT
110
+
111
+
112
def upload_file_to_bos(file_path: str, object_name: str | None = None) -> bool:
    """
    Upload a local file to BOS; failures are logged and never raised.

    Only files located under IMAGES_DIR are uploaded, so temporary
    artifacts are never synced.

    :param file_path: local file path
    :param object_name: explicit BOS object key (optional)
    :return: True when the upload succeeded
    """
    if not BOS_UPLOAD_ENABLED:
        return False

    expanded_path = os.path.abspath(os.path.expanduser(file_path))
    if not os.path.isfile(expanded_path):
        return False

    if not _is_path_under_images_dir(expanded_path):
        # Only upload files inside IMAGES_DIR to keep temp files out of BOS.
        return False

    client = _get_bos_client()
    if client is None:
        return False

    # Derive the object key: an explicit name wins, else BOS_IMAGE_DIR/basename.
    if object_name:
        object_key = object_name.strip("/ ")
    else:
        base_name = os.path.basename(expanded_path)
        if BOS_IMAGE_DIR:
            object_key = "/".join(
                part.strip("/ ") for part in (BOS_IMAGE_DIR, base_name) if part
            )
        else:
            object_key = base_name

    try:
        client.upload_file(expanded_path, BOS_BUCKET_NAME, object_key)
        logger.info(f"文件已同步至 BOS: {object_key}")
        return True
    except Exception as e:
        # Fix: the original caught (ClientError, BotoCoreError, Exception) —
        # Exception already subsumes the others, so the tuple was redundant.
        logger.warning(f"上传到 BOS 失败({object_key}): {e}")
        return False
153
+
154
+
155
def delete_file_from_bos(file_path: str | None = None,
                         object_name: str | None = None) -> bool:
    """
    Delete an object from BOS; failures are logged and never raised.
    :param file_path: local file path (optional, used to derive the file name)
    :param object_name: BOS object key (optional, takes precedence)
    :return: True when the deletion succeeded
    """
    if not BOS_UPLOAD_ENABLED:
        return False

    client = _get_bos_client()
    if client is None:
        return False

    # Prefer the explicit object name; otherwise derive from the file name.
    key_candidate = object_name.strip("/ ") if object_name else ""

    if not key_candidate and file_path:
        base_name = os.path.basename(
            os.path.abspath(os.path.expanduser(file_path)))
        key_candidate = base_name.strip()

    if not key_candidate:
        return False

    if BOS_IMAGE_DIR:
        object_key = "/".join(
            part.strip("/ ") for part in (BOS_IMAGE_DIR, key_candidate) if part
        )
    else:
        object_key = key_candidate

    try:
        client.delete_object(Bucket=BOS_BUCKET_NAME, Key=object_key)
        logger.info(f"已从 BOS 删除文件: {object_key}")
        return True
    except Exception as e:
        # Fix: the original caught (ClientError, BotoCoreError, Exception) —
        # Exception already subsumes the others, so the tuple was redundant.
        logger.warning(f"删除 BOS 文件失败({object_key}): {e}")
        return False
+
195
+
196
def image_to_base64(image: np.ndarray) -> str:
    """Encode an OpenCV image as a WebP data-URI base64 string; empty string for empty input."""
    if image is None or image.size == 0:
        return ""
    payload = cv2.imencode(".webp", image, [cv2.IMWRITE_WEBP_QUALITY, 90])[1]
    encoded = base64.b64encode(payload).decode("utf-8")
    return f"data:image/webp;base64,{encoded}"
203
+
204
+
205
+ def save_base64_to_unique_file(
206
+ base64_string: str, output_dir: str = "output_images"
207
+ ) -> str | None:
208
+ """
209
+ 将带有MIME类型前缀的Base64字符串解码并保存到本地。
210
+ 文件名格式为: {md5_hash}_{timestamp}.{extension}
211
+ """
212
+ os.makedirs(output_dir, exist_ok=True)
213
+
214
+ try:
215
+ match = re.match(r"data:(image/\w+);base64,(.+)", base64_string)
216
+ if match:
217
+ mime_type = match.group(1)
218
+ base64_data = match.group(2)
219
+ else:
220
+ mime_type = "image/jpeg"
221
+ base64_data = base64_string
222
+
223
+ extension_map = {
224
+ "image/jpeg": "jpg",
225
+ "image/png": "png",
226
+ "image/gif": "gif",
227
+ "image/webp": "webp",
228
+ }
229
+ file_extension = extension_map.get(mime_type, "webp")
230
+
231
+ decoded_data = base64.b64decode(base64_data)
232
+
233
+ except (ValueError, TypeError, base64.binascii.Error) as e:
234
+ logger.error(f"Base64 decoding failed: {e}")
235
+ return None
236
+
237
+ md5_hash = hashlib.md5(base64_data.encode("utf-8")).hexdigest()
238
+ filename = f"{md5_hash}.{file_extension}"
239
+ file_path = os.path.join(output_dir, filename)
240
+
241
+ try:
242
+ with open(file_path, "wb") as f:
243
+ f.write(decoded_data)
244
+ return file_path
245
+ except IOError as e:
246
+ logger.error(f"File writing failed: {e}")
247
+ return None
248
+
249
+
250
def save_image_force_compress(
    image: np.ndarray,
    output_path: str,
    max_size_kb: int = 100,
    min_scale: float = 0.1,
    scale_step: float = 0.9,
    initial_quality: int = 95,
    min_quality: int = 10,
    quality_step: int = 5,
) -> bool:
    """
    Force-compress *image* below max_size_kb, even if the original is already small.
    Shrinks dimensions first, then steps WebP quality down until the target is met.

    :param image: image array to save
    :param output_path: destination path (WebP payload)
    :param max_size_kb: hard upper bound on the encoded size
    :param min_scale: smallest resize factor tried before giving up
    :param scale_step: multiplicative shrink applied per outer iteration
    :param initial_quality: WebP quality each size attempt starts from
    :param min_quality: lowest WebP quality tried per size
    :param quality_step: quality decrement per inner iteration
    :return: True when a small-enough encoding was written (and synced to BOS)
    """
    max_bytes = max_size_kb * 1024
    # NOTE(review): the starting scale comes from config.IMG_QUALITY —
    # presumably a 0-1 resize factor despite the name; confirm, since a
    # 0-100 "quality" style value here would make the loop start far above 1.0.
    scale = IMG_QUALITY
    height, width = image.shape[:2]
    while scale >= min_scale:
        resized_img = cv2.resize(
            image,
            (int(width * scale), int(height * scale)),
            interpolation=cv2.INTER_AREA,
        )
        quality = initial_quality

        while quality >= min_quality:
            success, encoded_img = cv2.imencode(
                ".webp", resized_img, [cv2.IMWRITE_WEBP_QUALITY, quality]
            )
            if not success:
                return False

            if len(encoded_img) <= max_bytes:
                with open(output_path, "wb") as f:
                    f.write(encoded_img)
                logger.debug(
                    f"压缩后图像大小: {len(encoded_img) / 1024:.2f} KB,scale={scale:.2f}, quality={quality}"
                )
                # Mirror the saved file to BOS (best-effort).
                upload_file_to_bos(output_path)
                return True

            quality -= quality_step

        scale *= scale_step

    return False
296
+
297
+
298
def human_readable_size(size_bytes):
    """Format a byte count with binary-multiple units (B, KB, MB, GB, TB)."""
    value = size_bytes
    for suffix in ("B", "KB", "MB", "GB"):
        if value < 1024:
            return f"{value:.1f} {suffix}"
        value /= 1024
    return f"{value:.1f} TB"
305
+
306
+
307
def delete_file(file_path: str):
    """Remove *file_path* from disk, logging the outcome; never raises."""
    try:
        os.remove(file_path)
        logger.info(f"Deleted file: {file_path}")
    except Exception as exc:
        logger.error(f"Failed to delete file {file_path}: {exc}")
313
+
314
+
315
def move_file_to_archive(file_path: str):
    """Move *file_path* into IMAGES_DIR (created on demand); errors are logged, never raised."""
    try:
        if not os.path.exists(IMAGES_DIR):
            os.makedirs(IMAGES_DIR)
        destination = os.path.join(IMAGES_DIR, os.path.basename(file_path))
        shutil.move(file_path, destination)
        logger.debug(f"Moved file to archive: {destination}")
    except Exception as exc:
        logger.error(f"Failed to move file {file_path} to archive: {exc}")
325
+
326
+
327
def save_image_high_quality(image: np.ndarray, output_path: str, quality: int = SAVE_QUALITY) -> bool:
    """
    Persist *image* as a high-quality WebP file and mirror it to BOS.
    :param image: image array
    :param output_path: destination file path
    :param quality: WebP quality (0-100)
    :return: True when the file was written successfully
    """
    try:
        encode_ok, payload = cv2.imencode(
            ".webp", image, [cv2.IMWRITE_WEBP_QUALITY, quality]
        )
        if not encode_ok:
            logger.error(f"Image encoding failed: {output_path}")
            return False

        with open(output_path, "wb") as out_file:
            out_file.write(payload)

        logger.info(f"High quality image saved successfully: {output_path}, quality: {quality}, size: {len(payload) / 1024:.2f} KB")
        # Mirror the saved file to BOS (best-effort).
        upload_file_to_bos(output_path)
        return True
    except Exception as e:
        logger.error(f"Failed to save image: {output_path}, error: {e}")
        return False
352
+
353
+
354
def convert_numpy_types(obj):
    """
    Recursively convert numpy values inside *obj* to native Python types.

    Generalized from the original, which only handled np.float32/64 and
    np.int32/64: any numpy scalar (np.generic — covers float16, int8/16,
    bool_, ...) is unboxed via .item(), numpy arrays become nested lists,
    and dicts/lists/tuples are converted element-wise. Anything else is
    returned unchanged.
    """
    if isinstance(obj, np.generic):
        # .item() yields the matching native Python scalar for every width.
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(i) for i in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(i) for i in obj)
    return obj
366
+
367
+
368
def compress_image_by_quality(image: np.ndarray, quality: int, output_format: str = 'webp') -> tuple[bytes, dict]:
    """
    Encode an image at the requested quality without resizing.
    :param image: input image
    :param quality: quality (10-100); for PNG this is mapped onto a compression level
    :param output_format: 'jpg', 'png' or 'webp'
    :return: (encoded bytes, info dict describing the result)
    """
    try:
        height, width = image.shape[:2]
        fmt = output_format.lower()

        if fmt == 'png':
            # PNG has no quality knob; map quality onto compression level 0-9.
            level = max(0, min(9, int((100 - quality) / 10)))
            ok, encoded = cv2.imencode(
                ".png", image, [cv2.IMWRITE_PNG_COMPRESSION, level]
            )
        elif fmt == 'webp':
            ok, encoded = cv2.imencode(
                ".webp", image, [cv2.IMWRITE_WEBP_QUALITY, quality]
            )
        else:
            # Default: JPEG.
            ok, encoded = cv2.imencode(
                ".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality]
            )

        if not ok:
            raise Exception("图像编码失败")

        payload = encoded.tobytes()
        info = {
            'original_dimensions': f"{width} × {height}",
            'compressed_dimensions': f"{width} × {height}",
            'quality': quality,
            'format': output_format.upper(),
            'size': len(payload),
        }
        return payload, info

    except Exception as e:
        logger.error(f"Failed to compress image by quality: {e}")
        raise
414
+
415
+
416
def compress_image_by_dimensions(image: np.ndarray, target_width: int, target_height: int,
                                 quality: int = 100, output_format: str = 'jpg') -> tuple[bytes, dict]:
    """
    Resize an image to exact dimensions, then encode it.
    :param image: input image
    :param target_width: output width in pixels
    :param target_height: output height in pixels
    :param quality: encode quality; for PNG this is mapped onto a compression level
    :param output_format: 'jpg', 'png' or 'webp'
    :return: (encoded bytes, info dict describing the result)
    """
    try:
        source_height, source_width = image.shape[:2]

        # Resize to the requested dimensions.
        scaled = cv2.resize(
            image, (target_width, target_height),
            interpolation=cv2.INTER_AREA
        )

        # Encode at the requested quality.
        fmt = output_format.lower()
        if fmt == 'png':
            level = max(0, min(9, int((100 - quality) / 10)))
            ok, encoded = cv2.imencode(
                ".png", scaled, [cv2.IMWRITE_PNG_COMPRESSION, level]
            )
        elif fmt == 'webp':
            ok, encoded = cv2.imencode(
                ".webp", scaled, [cv2.IMWRITE_WEBP_QUALITY, quality]
            )
        else:
            ok, encoded = cv2.imencode(
                ".jpg", scaled, [cv2.IMWRITE_JPEG_QUALITY, quality]
            )

        if not ok:
            raise Exception("图像编码失败")

        payload = encoded.tobytes()
        info = {
            'original_dimensions': f"{source_width} × {source_height}",
            'compressed_dimensions': f"{target_width} × {target_height}",
            'quality': quality,
            'format': output_format.upper(),
            'size': len(payload),
        }
        return payload, info

    except Exception as e:
        logger.error(f"Failed to compress image by dimensions: {e}")
        raise
469
+
470
+
471
def compress_image_by_file_size(image: np.ndarray, target_size_kb: float,
                                output_format: str = 'jpg') -> tuple[bytes, dict]:
    """
    Compress an image to a target file size using a multi-stage binary search.

    Tries a fixed list of downscale ratios; at each scale a binary search over
    the encoder quality looks for the encoding whose size is closest to the
    target. The closest result found is always returned, even when the target
    cannot be hit exactly (a warning is logged instead of failing).

    :param image: input image (OpenCV ndarray)
    :param target_size_kb: target file size in KB
    :param output_format: output format ('jpg', 'png' or 'webp')
    :return: (compressed image bytes, compression info dict)
    :raises Exception: if no scale/quality combination produced any encoding
    """
    try:
        original_height, original_width = image.shape[:2]
        target_size_bytes = int(target_size_kb * 1024)

        def encode_image(img, quality):
            """Encode *img* in the requested format; return bytes, or None on failure."""
            if output_format.lower() == 'png':
                # PNG has no quality knob; map quality 0-100 onto
                # compression level 9-0 (higher quality -> lower compression).
                compression_level = max(0, min(9, int((100 - quality) / 10)))
                success, encoded_img = cv2.imencode(
                    ".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compression_level]
                )
            elif output_format.lower() == 'webp':
                success, encoded_img = cv2.imencode(
                    ".webp", img, [cv2.IMWRITE_WEBP_QUALITY, quality]
                )
            else:
                success, encoded_img = cv2.imencode(
                    ".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, quality]
                )

            if success:
                return encoded_img.tobytes()
            return None

        def find_best_scale_and_quality(target_bytes):
            """Search scale x quality combinations for the size closest to target_bytes."""
            best_result = None
            best_diff = float('inf')

            # Candidate downscale ratios, tried from largest to smallest.
            test_scales = [1.0, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3]

            for scale in test_scales:
                # Resize the image for this scale.
                if scale < 1.0:
                    new_width = int(original_width * scale)
                    new_height = int(original_height * scale)
                    if new_width < 50 or new_height < 50:  # avoid images that are too small
                        continue
                    working_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
                else:
                    working_image = image
                    new_width, new_height = original_width, original_height

                # Binary-search the best quality at this scale.
                min_q, max_q = 10, 100
                scale_best_result = None
                scale_best_diff = float('inf')

                for _ in range(20):  # at most 20 quality probes per scale
                    current_quality = (min_q + max_q) // 2

                    compressed_bytes = encode_image(working_image, current_quality)
                    if not compressed_bytes:
                        break

                    current_size = len(compressed_bytes)
                    size_diff = abs(current_size - target_bytes)
                    size_ratio = current_size / target_bytes

                    # Exact-enough match (within 1%): return immediately.
                    if 0.99 <= size_ratio <= 1.01:
                        return {
                            'bytes': compressed_bytes,
                            'scale': scale,
                            'width': new_width,
                            'height': new_height,
                            'quality': current_quality,
                            'size': current_size,
                            'ratio': size_ratio
                        }

                    # Track the best result seen at this scale.
                    if size_diff < scale_best_diff:
                        scale_best_diff = size_diff
                        scale_best_result = {
                            'bytes': compressed_bytes,
                            'scale': scale,
                            'width': new_width,
                            'height': new_height,
                            'quality': current_quality,
                            'size': current_size,
                            'ratio': size_ratio
                        }

                    # Binary-search step: narrow the quality range.
                    if current_size > target_bytes:
                        max_q = current_quality - 1
                    else:
                        min_q = current_quality + 1

                    if min_q >= max_q:
                        break

                # Update the global best across all scales.
                if scale_best_result and scale_best_diff < best_diff:
                    best_diff = scale_best_diff
                    best_result = scale_best_result

                # Good enough already (within 5%): stop trying smaller scales.
                if best_result and 0.95 <= best_result['ratio'] <= 1.05:
                    break

            return best_result

        logger.info(f"Starting multi-stage compression, target size: {target_size_bytes} bytes ({target_size_kb}KB)")

        # Find the best scale/quality combination.
        result = find_best_scale_and_quality(target_size_bytes)

        if result:
            error_percent = abs(result['ratio'] - 1) * 100
            logger.info(f"Compression completed: scale ratio {result['scale']:.2f}, quality {result['quality']}%, "
                        f"size {result['size']} bytes, error {error_percent:.2f}%")

            # Always return the closest result; only log warnings for large errors.
            if error_percent > 10:
                if result['ratio'] < 0.5:  # over-compressed: target was too small
                    suggested_size = result['size'] / 1024
                    logger.warning(f"Target size {target_size_kb}KB is too small, actually compressed to {suggested_size:.1f}KB, error {error_percent:.1f}%")
                elif result['ratio'] > 2.0:  # target unreachable: image cannot shrink that far
                    suggested_size = result['size'] / 1024
                    logger.warning(f"Target size {target_size_kb}KB is too large, minimum can be compressed to {suggested_size:.1f}KB, error {error_percent:.1f}%")
                else:
                    logger.warning(f"Cannot achieve target accuracy, error {error_percent:.1f}%, returning closest result")

            info = {
                'original_dimensions': f"{original_width} × {original_height}",
                'compressed_dimensions': f"{result['width']} × {result['height']}",
                'quality': result['quality'],
                'format': output_format.upper(),
                'size': result['size']
            }

            return result['bytes'], info
        else:
            raise Exception(f"无法将图片压缩到目标大小 {target_size_kb}KB")

    except Exception as e:
        logger.error(f"Failed to compress image by file size: {e}")
        raise
621
+
622
+
623
def convert_image_format(image: np.ndarray, target_format: str, quality: int = 100) -> tuple[bytes, dict]:
    """
    Re-encode an image into another format.

    :param image: input image (OpenCV ndarray)
    :param target_format: target format ('jpg', 'png', 'webp'); anything else
                          falls back to JPEG
    :param quality: quality parameter (PNG ignores it and uses a fixed
                    compression level)
    :return: (converted image bytes, format info dict)
    """
    try:
        height, width = image.shape[:2]

        fmt = target_format.lower()
        if fmt == 'png':
            # PNG is lossless; a fixed mid-range compression level is used.
            ext, params = ".png", [cv2.IMWRITE_PNG_COMPRESSION, 6]
        elif fmt == 'webp':
            ext, params = ".webp", [cv2.IMWRITE_WEBP_QUALITY, quality]
        else:
            # Default/fallback: JPEG.
            ext, params = ".jpg", [cv2.IMWRITE_JPEG_QUALITY, quality]

        success, encoded_img = cv2.imencode(ext, image, params)

        if not success:
            raise Exception("图像格式转换失败")

        converted_bytes = encoded_img.tobytes()

        # Dimensions are unchanged by a pure format conversion; PNG is
        # reported as quality 100 since it is lossless.
        info = {
            'original_dimensions': f"{width} × {height}",
            'compressed_dimensions': f"{width} × {height}",
            'quality': 100 if fmt == 'png' else quality,
            'format': target_format.upper(),
            'size': len(converted_bytes)
        }

        return converted_bytes, info

    except Exception as e:
        logger.error(f"Image format conversion failed: {e}")
        raise
669
+
670
+
671
def save_image_with_transparency(image: np.ndarray, file_path: str) -> bool:
    """
    Save an image with an alpha channel as PNG and upload it to BOS.

    :param image: OpenCV image array — BGRA (alpha preserved) or BGR
                  (a fully opaque alpha channel is added)
    :param file_path: destination path for the PNG file
    :return: True on success, False otherwise
    """
    if image is None:
        logger.error("Image is empty, cannot save")
        return False

    try:
        # Ensure the target directory exists. Guard against a bare filename:
        # os.path.dirname('x.png') is '' and os.makedirs('') raises.
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        if len(image.shape) == 3 and image.shape[2] == 4:
            # BGRA -> RGBA (PIL expects RGBA channel order).
            rgba_image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
        elif len(image.shape) == 3 and image.shape[2] == 3:
            # BGR has no transparency; append a fully opaque alpha channel.
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            rgba_image = np.dstack((rgb_image, np.full(rgb_image.shape[:2], 255, dtype=np.uint8)))
        else:
            logger.error("Image format does not support transparency saving")
            return False

        # Save via PIL so the alpha channel is preserved in the PNG.
        pil_image = Image.fromarray(rgba_image, 'RGBA')
        pil_image.save(file_path, 'PNG', optimize=True)

        file_size = os.path.getsize(file_path)
        logger.info(f"Transparent PNG image saved: {file_path}, size: {file_size/1024:.1f}KB")
        upload_file_to_bos(file_path)
        return True

    except Exception as e:
        logger.error(f"Failed to save transparent PNG image: {e}")
        return False
vector_store.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# vector_store.py
import logging
import os
import pickle

import faiss
import numpy as np
import torch

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Project root directory (two levels up from this file)
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# FAISS index directory (overridable via the FAISS_INDEX_DIR env var)
FAISS_INDEX_DIR = os.environ.get('FAISS_INDEX_DIR', os.path.join(PROJECT_ROOT, 'faiss', 'data'))
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)

# Final on-disk paths for the index and its id map
FAISS_INDEX_PATH = os.path.join(FAISS_INDEX_DIR, "index.faiss")
ID_MAP_PATH = os.path.join(FAISS_INDEX_DIR, "id_map.pkl")

# Embedding dimension: ViT-B/16 is 512; ViT-L/14 is typically 768 or 1024
VECTOR_DIM = int(os.environ.get("VECTOR_DIM", 512))

# Module-level state: the FAISS index and the parallel list of image ids
index = None
id_map = None
31
def init_vector_store():
    """Load the FAISS index and id map from disk, or create fresh ones.

    Returns True on success, False on any failure.
    """
    global index, id_map
    try:
        have_index = os.path.exists(FAISS_INDEX_PATH)
        have_map = os.path.exists(ID_MAP_PATH)
        if have_index and have_map:
            # Both files present: restore the persisted store.
            index = faiss.read_index(FAISS_INDEX_PATH)
            with open(ID_MAP_PATH, "rb") as f:
                id_map = pickle.load(f)
            logger.info(f"Vector store loaded successfully path={FAISS_INDEX_DIR}, contains {len(id_map)} vectors")
        else:
            # Fresh store: inner product over normalized vectors == cosine similarity.
            index = faiss.IndexFlatIP(VECTOR_DIM)
            id_map = []
            logger.info("Initializing new vector store")
        return True
    except Exception as e:
        logger.error(f"Vector store initialization failed: {e}")
        return False
49
+
50
def is_vector_store_available():
    """Return True when both the FAISS index and the id map are initialized."""
    return not (index is None or id_map is None)
53
+
54
def check_image_exists(image_path: str) -> bool:
    """
    Check whether an image is already present in the vector store.

    Args:
        image_path: image path / identifier
    Returns:
        bool: True when present; False otherwise, when the store is
        unavailable, or on any error
    """
    try:
        if is_vector_store_available():
            return image_path in id_map
        return False
    except Exception as e:
        logger.error(f"Failed to check if image exists: {str(e)}")
        return False
69
+
70
def add_image_vector(image_path: str, vector: torch.Tensor):
    """
    Add one image embedding to the store and persist it to disk.

    :param image_path: image path / identifier, stored in id_map at the same
                       position as the vector gets in the FAISS index
    :param vector: embedding tensor of shape (1, VECTOR_DIM) — assumed; TODO confirm
    :raises RuntimeError: if the vector store has not been initialized
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")

    # detach().cpu() makes conversion safe for tensors that live on a GPU or
    # still carry autograd history; a bare .numpy() raises for those.
    np_vector = vector.detach().cpu().squeeze(0).numpy().astype('float32')
    index.add(np_vector[np.newaxis, :])
    id_map.append(image_path)
    save_index()
    logger.info(f"Image vector added: {image_path}")
80
+
81
def search_text_vector(vector: torch.Tensor, top_k=5):
    """
    Search the store with a text embedding.

    :param vector: query embedding tensor of shape (1, VECTOR_DIM) — assumed; TODO confirm
    :param top_k: number of nearest neighbours to return
    :return: list of (image_path, score) pairs, best match first
    :raises RuntimeError: if the vector store has not been initialized
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")

    # detach().cpu() makes conversion safe for GPU tensors / tensors with grad;
    # a bare .numpy() raises for those.
    np_vector = vector.detach().cpu().squeeze(0).numpy().astype('float32')
    scores, indices = index.search(np_vector[np.newaxis, :], top_k)

    if indices is None or len(indices[0]) == 0:
        return []

    # FAISS pads missing results with -1; drop those and any out-of-range ids.
    results = [
        (id_map[i], float(scores[0][j]))
        for j, i in enumerate(indices[0])
        if i != -1 and i < len(id_map)
    ]
    return results
98
+
99
def save_index():
    """Persist the FAISS index and id map to disk; log (never raise) on failure."""
    try:
        faiss.write_index(index, FAISS_INDEX_PATH)
        with open(ID_MAP_PATH, "wb") as fh:
            pickle.dump(id_map, fh)
    except Exception as e:
        logger.error(f"Failed to save vector index: {e}")
    else:
        logger.debug("Vector index saved")
108
+
109
def get_vector_store_info():
    """Return a status snapshot of the vector store as a dict."""
    if not is_vector_store_available():
        return {"status": "not_initialized", "count": 0}

    info = {"status": "available"}
    info["count"] = len(id_map)
    info["vector_dim"] = VECTOR_DIM
    info["index_path"] = FAISS_INDEX_PATH
    return info
wx_access_token.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import aiohttp
4
+
5
+ from config import access_token_cache, WECHAT_APPID, WECHAT_SECRET, logger
6
+
7
+
8
async def get_access_token() -> Optional[str]:
    """
    Get a WeChat *stable* access_token (the recommended endpoint).

    Serves from the shared cache while the token is still valid; otherwise
    POSTs to the stable_token endpoint and caches the result with the expiry
    pulled 5 minutes earlier as a safety margin.

    :return: the access_token string, or None on any failure
    """
    import time

    # Serve from cache while the token is still valid.
    if access_token_cache["token"] and time.time() < access_token_cache["expires_at"]:
        return access_token_cache["token"]
    # Use the newer getStableAccessToken endpoint.
    url = "https://api.weixin.qq.com/cgi-bin/stable_token"
    data = {
        "grant_type": "client_credential",
        "appid": WECHAT_APPID,
        "secret": WECHAT_SECRET,
        "force_refresh": False,  # do not force-invalidate the previous token
    }

    try:
        # Bounded timeout so a stuck WeChat endpoint cannot hang this task
        # forever (consistent with the 10s timeout used by check_image_security).
        timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=data) as response:
                if response.status == 200:
                    result = await response.json()
                    logger.debug(f"getStableAccessToken response: {result}")
                    if "access_token" in result:
                        access_token_cache["token"] = result["access_token"]
                        # Expire 5 minutes early so callers never use a stale token.
                        access_token_cache["expires_at"] = (
                            time.time() + result.get("expires_in", 7200) - 300
                        )
                        expires_time = access_token_cache["expires_at"]
                        logger.debug(
                            f"成功获取 stable access_token expires_time={expires_time}"
                        )
                        return result["access_token"]
                    else:
                        logger.error(f"Failed to get stable access_token: {result}")
                else:
                    logger.error(f"Failed to request stable access_token: {response.status}")
    except Exception as e:
        logger.error(f"Exception while getting stable access_token: {str(e)}")

    return None
48
+
49
+
50
async def get_access_token_old() -> Optional[str]:
    """
    Get a WeChat access_token via the legacy /cgi-bin/token endpoint.

    Legacy variant; get_access_token() (stable_token) is the recommended
    path. Shares the same cache and early-expiry logic.

    :return: the access_token string, or None on any failure
    """
    import time

    # Serve from cache while the token is still valid.
    if access_token_cache["token"] and time.time() < access_token_cache["expires_at"]:
        return access_token_cache["token"]
    # Request a fresh access_token.
    url = "https://api.weixin.qq.com/cgi-bin/token"
    params = {
        "grant_type": "client_credential",
        "appid": WECHAT_APPID,
        "secret": WECHAT_SECRET,
    }

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, params=params) as response:
                if response.status == 200:
                    data = await response.json()
                    if "access_token" in data:
                        access_token_cache["token"] = data["access_token"]
                        access_token_cache["expires_at"] = (
                            time.time() + data.get("expires_in", 7200) - 300
                        )  # expire 5 minutes early as a safety margin
                        logger.info("Successfully obtained WeChat access_token...")
                        return data["access_token"]
                    else:
                        logger.error(f"Failed to get access_token, returned content: {data}")
                        return None
                else:
                    logger.error(f"Failed to get access_token, status={response.status}")
                    return None
    except Exception as e:
        logger.error(f"Failed to get access_token: {str(e)}")

    return None
87
+
88
+
89
async def check_image_security(image_data: bytes) -> bool:
    """
    Check image content safety via the WeChat img_sec_check API.

    Fails open by design: a missing access_token, an API errcode other than
    87014, a non-200 response, or any exception returns True so that outages
    of the check itself never block normal users.

    :param image_data: raw image bytes
    :return: True if considered safe (or the check could not run), False if risky
    """
    access_token = await get_access_token()
    if not access_token:
        logger.warning("Unable to get access_token, skipping security check")
        return True  # fail open when the token cannot be obtained
    url = f"https://api.weixin.qq.com/wxa/img_sec_check?access_token={access_token}"
    try:
        async with aiohttp.ClientSession() as session:
            # The WeChat API requires a multipart/form-data upload.
            data = aiohttp.FormData()
            data.add_field("media", image_data, content_type="image/jpeg")
            async with session.post(url, data=data, timeout=10) as response:
                if response.status == 200:
                    result = await response.json()
                    logger.info(f"Checking image content safety...result={result}")
                    if result.get("errcode") == 0:
                        return True  # safe
                    elif result.get("errcode") == 87014:
                        # 87014 is WeChat's "content violates policy" code.
                        logger.warning("Image content contains illegal content...")
                        return False
                    else:
                        logger.warning(f"Image security check returned error: {result}")
                        return True  # other API errors: allow through
                else:
                    logger.warning(f"Image security check request failed: {response.status}")
                    return True
    except Exception as e:
        logger.error(f"Image security check exception: {str(e)}")
        return True  # fail open on exceptions to avoid blocking normal users