chawin.chen committed on
Commit
fae1594
·
1 Parent(s): cd5aabe
Files changed (5) hide show
  1. Dockerfile +9 -9
  2. app.py +12 -0
  3. config.py +37 -5
  4. start_local.sh +5 -0
  5. utils.py +175 -0
Dockerfile CHANGED
@@ -1,12 +1,12 @@
1
  FROM python:3.10-slim
2
 
3
  ENV TZ=Asia/Shanghai \
4
- OUTPUT_DIR=/opt/output \
5
- IMAGES_DIR=/opt/images \
6
- MODELS_PATH=/opt/models \
7
- DEEPFACE_HOME=/opt/models \
8
- FAISS_INDEX_DIR=/opt/faiss \
9
- CELEBRITY_SOURCE_DIR=/opt/chinese_celeb_dataset \
10
  GENDER_CONFIDENCE=1 \
11
  UPSCALE_SIZE=2 \
12
  AGE_CONFIDENCE=0.1 \
@@ -27,8 +27,8 @@ ENV TZ=Asia/Shanghai \
27
  ENABLE_ANIME_PRELOAD=false \
28
  ENABLE_LOGGING=true \
29
  BEAUTY_ADJUST_ENABLED=true \
30
- RVM_LOCAL_REPO=/app/RobustVideoMatting \
31
- RVM_WEIGHTS_PATH=/opt/models/torch/hub/checkpoints/rvm_resnet50.pth \
32
  RVM_MODEL=resnet50 \
33
  AUTO_INIT_GFPGAN=false \
34
  AUTO_INIT_DDCOLOR=false \
@@ -43,7 +43,7 @@ ENV TZ=Asia/Shanghai \
43
  FEMALE_AGE_ADJUSTMENT=4 \
44
  HOSTNAME=HG
45
 
46
- RUN mkdir -p /opt/chinese_celeb_dataset /opt/faiss /opt/models /opt/images /opt/output
47
  WORKDIR /app
48
  COPY requirements.txt .
49
  COPY *.py /app/
 
1
  FROM python:3.10-slim
2
 
3
  ENV TZ=Asia/Shanghai \
4
+ OUTPUT_DIR=/opt/data/output \
5
+ IMAGES_DIR=/opt/data/images \
6
+ MODELS_PATH=/opt/data/models \
7
+ DEEPFACE_HOME=/opt/data/models \
8
+ FAISS_INDEX_DIR=/opt/data/faiss \
9
+ CELEBRITY_SOURCE_DIR=/opt/data/chinese_celeb_dataset \
10
  GENDER_CONFIDENCE=1 \
11
  UPSCALE_SIZE=2 \
12
  AGE_CONFIDENCE=0.1 \
 
27
  ENABLE_ANIME_PRELOAD=false \
28
  ENABLE_LOGGING=true \
29
  BEAUTY_ADJUST_ENABLED=true \
30
+ RVM_LOCAL_REPO=/opt/data/RobustVideoMatting \
31
+ RVM_WEIGHTS_PATH=/opt/data/models/torch/hub/checkpoints/rvm_resnet50.pth \
32
  RVM_MODEL=resnet50 \
33
  AUTO_INIT_GFPGAN=false \
34
  AUTO_INIT_DDCOLOR=false \
 
43
  FEMALE_AGE_ADJUSTMENT=4 \
44
  HOSTNAME=HG
45
 
46
+ RUN mkdir -p /opt/data/chinese_celeb_dataset /opt/data/faiss /opt/data/models /opt/data/images /opt/data/output
47
  WORKDIR /app
48
  COPY requirements.txt .
49
  COPY *.py /app/
app.py CHANGED
@@ -17,8 +17,20 @@ from config import (
17
  ENABLE_LOGGING,
18
  )
19
  from database import close_mysql_pool, init_mysql_pool
 
20
 
21
  logger.info("Starting to import api_routes module...")
 
 
 
 
 
 
 
 
 
 
 
22
  try:
23
  t_start = time.perf_counter()
24
  from api_routes import api_router
 
17
  ENABLE_LOGGING,
18
  )
19
  from database import close_mysql_pool, init_mysql_pool
20
+ from utils import ensure_bos_resources
21
 
22
  logger.info("Starting to import api_routes module...")
23
+
24
+ try:
25
+ t_bos_start = time.perf_counter()
26
+ if not ensure_bos_resources():
27
+ raise RuntimeError("无法从 BOS 同步模型与数据,请检查凭证与网络")
28
+ bos_time = time.perf_counter() - t_bos_start
29
+ logger.info(f"BOS resources synchronized successfully, time: {bos_time:.3f}s")
30
+ except Exception as exc:
31
+ logger.error(f"BOS resource preparation failed: {exc}")
32
+ raise
33
+
34
  try:
35
  t_start = time.perf_counter()
36
  from api_routes import api_router
config.py CHANGED
@@ -176,15 +176,24 @@ try:
176
  except (ImportError, AttributeError) as e:
177
  print(f"Warning: PyTorch/PyArrow compatibility patch failed: {e}")
178
  pass
179
- IMAGES_DIR = os.environ.get("IMAGES_DIR", "~/app/data/images")
180
  OUTPUT_DIR = IMAGES_DIR
181
 
182
  # 明星图库目录配置
183
  CELEBRITY_SOURCE_DIR = os.environ.get(
184
- "CELEBRITY_SOURCE_DIR", "~/apps/chinese_celeb_imgs"
185
  ).strip()
186
  if CELEBRITY_SOURCE_DIR:
187
- CELEBRITY_SOURCE_DIR = os.path.expanduser(CELEBRITY_SOURCE_DIR)
 
 
 
 
 
 
 
 
 
188
 
189
  CELEBRITY_FIND_THRESHOLD = float(
190
  os.environ.get("CELEBRITY_FIND_THRESHOLD", 0.88)
@@ -202,6 +211,10 @@ BOS_ENDPOINT = os.environ.get(
202
  ).strip()
203
  BOS_BUCKET_NAME = os.environ.get("BOS_BUCKET_NAME", "hbgs-travel").strip()
204
  BOS_IMAGE_DIR = os.environ.get("BOS_IMAGE_DIR", "20220808").strip()
 
 
 
 
205
  _bos_enabled_env = os.environ.get("BOS_UPLOAD_ENABLED")
206
  if _bos_enabled_env is not None:
207
  BOS_UPLOAD_ENABLED = _bos_enabled_env.lower() in ("1", "true", "on")
@@ -216,12 +229,17 @@ else:
216
  )
217
  APP_SECRET_TOKEN = os.environ.get("APP_SECRET_TOKEN", "Abdc@q1")
218
  HOSTNAME = os.environ.get("HOSTNAME", "default-hostname")
219
- MODELS_PATH = os.environ.get("MODELS_PATH", "~/apps/ai/models")
 
 
 
 
 
220
  DEEPFACE_HOME = os.environ.get("DEEPFACE_HOME", "~/apps/ai")
221
  os.environ["DEEPFACE_HOME"] = DEEPFACE_HOME
222
 
223
  # 设置GFPGAN相关模型下载路径
224
- GFPGAN_MODEL_DIR = "~/apps/ai/models"
225
  os.makedirs(GFPGAN_MODEL_DIR, exist_ok=True)
226
 
227
  # 设置各种模型库的下载目录环境变量
@@ -286,6 +304,20 @@ AUTO_INIT_RVM = os.environ.get("AUTO_INIT_RVM", "false").lower() in ("1", "true"
286
  CLEANUP_INTERVAL_HOURS = float(os.environ.get("CLEANUP_INTERVAL_HOURS", 12.0)) # 清理任务执行间隔(小时),默认1小时
287
  CLEANUP_AGE_HOURS = float(os.environ.get("CLEANUP_AGE_HOURS", 12.0)) # 清理文件的年龄阈值(小时),默认1小时
288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  log_level_str = os.getenv("LOG_LEVEL", "INFO").upper()
290
  log_level = getattr(logging, log_level_str, logging.INFO)
291
 
 
176
  except (ImportError, AttributeError) as e:
177
  print(f"Warning: PyTorch/PyArrow compatibility patch failed: {e}")
178
  pass
179
+ IMAGES_DIR = os.environ.get("IMAGES_DIR", "/opt/data/images")
180
  OUTPUT_DIR = IMAGES_DIR
181
 
182
  # 明星图库目录配置
183
  CELEBRITY_SOURCE_DIR = os.environ.get(
184
+ "CELEBRITY_SOURCE_DIR", "/opt/data/chinese_celeb_dataset"
185
  ).strip()
186
  if CELEBRITY_SOURCE_DIR:
187
+ CELEBRITY_SOURCE_DIR = os.path.abspath(os.path.expanduser(CELEBRITY_SOURCE_DIR))
188
+
189
+ CELEBRITY_DATASET_DIR = os.path.abspath(
190
+ os.path.expanduser(
191
+ os.environ.get(
192
+ "CELEBRITY_DATASET_DIR",
193
+ CELEBRITY_SOURCE_DIR or "/opt/data/chinese_celeb_dataset",
194
+ )
195
+ )
196
+ )
197
 
198
  CELEBRITY_FIND_THRESHOLD = float(
199
  os.environ.get("CELEBRITY_FIND_THRESHOLD", 0.88)
 
211
  ).strip()
212
  BOS_BUCKET_NAME = os.environ.get("BOS_BUCKET_NAME", "hbgs-travel").strip()
213
  BOS_IMAGE_DIR = os.environ.get("BOS_IMAGE_DIR", "20220808").strip()
214
+ BOS_MODELS_PREFIX = os.environ.get("BOS_MODELS_PREFIX", "20220620/models").strip()
215
+ BOS_CELEBRITY_PREFIX = os.environ.get(
216
+ "BOS_CELEBRITY_PREFIX", "20220620/chinese_celeb_dataset"
217
+ ).strip()
218
  _bos_enabled_env = os.environ.get("BOS_UPLOAD_ENABLED")
219
  if _bos_enabled_env is not None:
220
  BOS_UPLOAD_ENABLED = _bos_enabled_env.lower() in ("1", "true", "on")
 
229
  )
230
  APP_SECRET_TOKEN = os.environ.get("APP_SECRET_TOKEN", "Abdc@q1")
231
  HOSTNAME = os.environ.get("HOSTNAME", "default-hostname")
232
+ MODELS_PATH = os.path.abspath(
233
+ os.path.expanduser(os.environ.get("MODELS_PATH", "/opt/data/models"))
234
+ )
235
+ MODELS_DOWNLOAD_DIR = os.path.abspath(
236
+ os.path.expanduser(os.environ.get("MODELS_DOWNLOAD_DIR", MODELS_PATH))
237
+ )
238
  DEEPFACE_HOME = os.environ.get("DEEPFACE_HOME", "~/apps/ai")
239
  os.environ["DEEPFACE_HOME"] = DEEPFACE_HOME
240
 
241
  # 设置GFPGAN相关模型下载路径
242
+ GFPGAN_MODEL_DIR = MODELS_DOWNLOAD_DIR
243
  os.makedirs(GFPGAN_MODEL_DIR, exist_ok=True)
244
 
245
  # 设置各种模型库的下载目录环境变量
 
304
  CLEANUP_INTERVAL_HOURS = float(os.environ.get("CLEANUP_INTERVAL_HOURS", 12.0)) # 清理任务执行间隔(小时),默认1小时
305
  CLEANUP_AGE_HOURS = float(os.environ.get("CLEANUP_AGE_HOURS", 12.0)) # 清理文件的年龄阈值(小时),默认1小时
306
 
307
+ # BOS 自动同步清单:定义 BOS 路径和本地目录的映射,启动时可迭代该结构完成批量下载
308
+ BOS_DOWNLOAD_TARGETS = [
309
+ {
310
+ "description": "明星图库数据集",
311
+ "bos_prefix": BOS_CELEBRITY_PREFIX,
312
+ "destination": CELEBRITY_DATASET_DIR,
313
+ },
314
+ {
315
+ "description": "AI 模型权重",
316
+ "bos_prefix": BOS_MODELS_PREFIX,
317
+ "destination": MODELS_DOWNLOAD_DIR,
318
+ },
319
+ ]
320
+
321
  log_level_str = os.getenv("LOG_LEVEL", "INFO").upper()
322
  log_level = getattr(logging, log_level_str, logging.INFO)
323
 
start_local.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ export TZ=Asia/Shanghai
3
+ export HOSTNAME=HG
4
+ uvicorn app:app --workers 1 --loop asyncio --http httptools --host 0.0.0.0 --port 7860 --timeout-keep-alive 600
5
+
utils.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  import re
5
  import shutil
6
  import threading
 
7
 
8
  import cv2
9
  import numpy as np
@@ -27,11 +28,14 @@ from config import (
27
  BOS_BUCKET_NAME,
28
  BOS_IMAGE_DIR,
29
  BOS_UPLOAD_ENABLED,
 
30
  )
31
 
32
  _BOS_CLIENT = None
33
  _BOS_CLIENT_INITIALIZED = False
34
  _BOS_CLIENT_LOCK = threading.Lock()
 
 
35
  _IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))
36
 
37
 
@@ -109,6 +113,177 @@ def _get_bos_client():
109
  return _BOS_CLIENT
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  def upload_file_to_bos(file_path: str, object_name: str | None = None) -> bool:
113
  """
114
  将指定文件上传到 BOS,失败不会抛出异常。
 
4
  import re
5
  import shutil
6
  import threading
7
+ from typing import Optional
8
 
9
  import cv2
10
  import numpy as np
 
28
  BOS_BUCKET_NAME,
29
  BOS_IMAGE_DIR,
30
  BOS_UPLOAD_ENABLED,
31
+ BOS_DOWNLOAD_TARGETS,
32
  )
33
 
34
  _BOS_CLIENT = None
35
  _BOS_CLIENT_INITIALIZED = False
36
  _BOS_CLIENT_LOCK = threading.Lock()
37
+ _BOS_DOWNLOAD_LOCK = threading.Lock()
38
+ _BOS_DOWNLOAD_COMPLETED = False
39
  _IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))
40
 
41
 
 
113
  return _BOS_CLIENT
114
 
115
 
116
+ def _normalize_bos_prefix(prefix: Optional[str]) -> str:
117
+ value = (prefix or "").strip()
118
+ if not value:
119
+ return ""
120
+ value = value.strip("/")
121
+ if not value:
122
+ return ""
123
+ return f"{value}/" if not value.endswith("/") else value
124
+
125
+
126
+ def _directory_has_files(path: str) -> bool:
127
+ try:
128
+ for _root, _dirs, files in os.walk(path):
129
+ if files:
130
+ return True
131
+ except Exception:
132
+ return False
133
+ return False
134
+
135
+
136
def download_bos_directory(prefix: str, destination_dir: str, *, force_download: bool = False) -> bool:
    """Mirror every object under a BOS prefix into a local directory.

    Unless ``force_download`` is set, the whole sync is skipped when the
    destination already contains any file, and individual objects are skipped
    when a local file of the same size already exists. Each object is written
    to a ``.download`` temporary file and moved into place only after the
    transfer finishes, so an interrupted transfer never leaves a truncated
    file under its final name.

    :param prefix: BOS object prefix, e.g. ``'models/'`` or ``'20220620/models'``.
    :param destination_dir: Local target directory (``~`` is expanded).
    :param force_download: Re-download everything, ignoring existing local files.
    :return: ``True`` when the directory is usable: either it already had
        files, or at least one object was found and every object was
        downloaded or skipped as up to date. ``False`` when the client is
        unavailable, the directory cannot be created, listing fails, nothing
        matches the prefix, or any single download fails.
    """
    client = _get_bos_client()
    if client is None:
        logger.warning("BOS 客户端不可用,无法下载资源(prefix=%s)", prefix)
        return False

    dest_dir = os.path.abspath(os.path.expanduser(destination_dir))
    try:
        os.makedirs(dest_dir, exist_ok=True)
    except OSError as exc:
        logger.error("创建本地目录失败: %s (%s)", dest_dir, exc)
        return False

    # Either "" (bucket root) or a prefix that already ends with exactly one
    # "/", so it can be used as-is for listing and for stripping from keys.
    normalized_prefix = _normalize_bos_prefix(prefix)

    # A populated directory is treated as already synchronized to avoid
    # re-downloading large assets on every start.
    if not force_download and _directory_has_files(dest_dir):
        logger.info("本地目录已存在文件,跳过下载: %s -> %s", normalized_prefix or "<root>", dest_dir)
        return True

    paginate_kwargs = {"Bucket": BOS_BUCKET_NAME}
    if normalized_prefix:
        paginate_kwargs["Prefix"] = normalized_prefix

    prefix_len = len(normalized_prefix)
    found_any = False
    downloaded = 0
    skipped = 0
    failed = 0

    try:
        paginator = client.get_paginator("list_objects_v2")
        for page in paginator.paginate(**paginate_kwargs):
            for obj in page.get("Contents", []):
                key = obj.get("Key")
                # startswith("") is always true, so the root case needs no
                # special handling here.
                if not key or not key.startswith(normalized_prefix):
                    continue

                relative_key = key[prefix_len:]
                # Skip the prefix placeholder itself and "directory" markers.
                if not relative_key or relative_key.endswith("/"):
                    continue
                found_any = True

                target_path = os.path.abspath(os.path.join(dest_dir, relative_key))
                # Refuse keys such as "../x" that would resolve outside the
                # destination directory.
                if os.path.commonpath([dest_dir, target_path]) != dest_dir:
                    logger.warning("对象键越出目标目录,已跳过: %s", key)
                    continue
                os.makedirs(os.path.dirname(target_path), exist_ok=True)

                expected_size = obj.get("Size")
                if (
                    not force_download
                    and expected_size is not None
                    and os.path.exists(target_path)
                    and os.path.getsize(target_path) == expected_size
                ):
                    skipped += 1
                    logger.info("文件已存在且大小一致,跳过下载: %s", relative_key)
                    continue

                tmp_path = f"{target_path}.download"
                try:
                    size_mb = (expected_size or 0) / (1024 * 1024)
                    logger.info("开始下载: %s (%.2f MB)", relative_key, size_mb)
                    client.download_file(Bucket=BOS_BUCKET_NAME, Key=key, Filename=tmp_path)
                    os.replace(tmp_path, target_path)
                    downloaded += 1
                    logger.info("下载完成: %s", relative_key)
                except Exception as exc:
                    # One bad object must not abort the rest of the sync, but
                    # it is counted so the caller learns the set is incomplete.
                    failed += 1
                    logger.warning("下载失败: %s (%s)", key, exc)
                    try:
                        os.remove(tmp_path)
                    except FileNotFoundError:
                        pass
                    except OSError as cleanup_exc:
                        logger.debug("清理临时文件失败: %s (%s)", tmp_path, cleanup_exc)
    except Exception as exc:
        logger.warning("遍历 BOS 目录失败: %s", exc)
        return False

    if not found_any:
        logger.warning("在 BOS 桶 %s 中未找到前缀 '%s' 的内容", BOS_BUCKET_NAME, normalized_prefix or "<root>")
        return False

    logger.info(
        "BOS 同步完成 prefix=%s -> %s 下载=%d 跳过=%d",
        normalized_prefix or "<root>",
        dest_dir,
        downloaded,
        skipped,
    )
    if failed:
        # A partial mirror is reported as not ready; otherwise the
        # "directory already has files" shortcut would hide the gap forever.
        logger.warning("BOS 同步存在失败文件 prefix=%s 失败=%d", normalized_prefix or "<root>", failed)
        return False
    return downloaded > 0 or skipped > 0
236
+
237
+
238
def ensure_bos_resources(force_download: bool = False) -> bool:
    """Synchronize every entry of ``BOS_DOWNLOAD_TARGETS`` to local disk.

    The work runs under ``_BOS_DOWNLOAD_LOCK``. Once a run has succeeded, later
    calls return immediately unless ``force_download`` is set. Every target is
    attempted even if an earlier one fails.

    :param force_download: Re-synchronize all resources even if a previous
        call already completed successfully.
    :return: ``True`` when all configured targets are ready (or none are
        configured), ``False`` if any target is malformed or fails to sync.
    """
    global _BOS_DOWNLOAD_COMPLETED

    with _BOS_DOWNLOAD_LOCK:
        if not force_download and _BOS_DOWNLOAD_COMPLETED:
            return True

        pending = BOS_DOWNLOAD_TARGETS or []
        if not pending:
            logger.info("未配置 BOS 下载目标,跳过资源同步")
            _BOS_DOWNLOAD_COMPLETED = True
            return True

        everything_ready = True
        for entry in pending:
            # Malformed entries count as failures but do not stop the loop.
            if not isinstance(entry, dict):
                logger.warning("无效的 BOS 下载配置项: %r", entry)
                everything_ready = False
                continue

            bos_prefix = entry.get("bos_prefix")
            local_dir = entry.get("destination")
            label = entry.get("description") or bos_prefix or "<unnamed>"

            if not (bos_prefix and local_dir):
                logger.warning("缺少必要字段,无法处理 BOS 下载配置: %r", entry)
                everything_ready = False
                continue

            logger.info("准备同步 BOS 资源: %s (prefix=%s -> %s)", label, bos_prefix, local_dir)
            if download_bos_directory(bos_prefix, local_dir, force_download=force_download):
                logger.info("BOS 资源已就绪: %s", label)
            else:
                logger.warning("BOS 资源同步失败: %s", label)
                everything_ready = False

        # Only a fully successful run is remembered; a failed run is retried
        # on the next call.
        if everything_ready:
            _BOS_DOWNLOAD_COMPLETED = True

        return everything_ready
285
+
286
+
287
  def upload_file_to_bos(file_path: str, object_name: str | None = None) -> bool:
288
  """
289
  将指定文件上传到 BOS,失败不会抛出异常。