Spaces:
Paused
Paused
import logging
import multiprocessing
import os

# Resolve OpenMP library conflicts by allowing duplicate OpenMP runtimes to load.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Size the native math-library thread pools from the CPU count, capped at 8
# threads, to keep the CPU busy without oversubscribing it.
cpu_cores = multiprocessing.cpu_count()
_thread_budget = str(min(cpu_cores, 8))
for _thread_var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"):
    os.environ[_thread_var] = _thread_budget
# torchvision compatibility: when ``torchvision.transforms.functional_tensor``
# cannot be imported, register a stand-in module that re-exports the matching
# functions from ``torchvision.transforms.functional`` for code that still
# imports the old module path.
try:
    import torchvision.transforms.functional_tensor
except ImportError:
    try:
        import torchvision.transforms.functional as F
        import torchvision.transforms as transforms
    except ImportError:
        # torchvision itself is not installed: there is nothing to patch, and
        # re-raising here would abort the whole configuration module.
        pass
    else:
        import sys
        from types import ModuleType

        # Build the stand-in module and copy over whichever functions exist.
        functional_tensor = ModuleType("torchvision.transforms.functional_tensor")
        for _fn_name in (
            "rgb_to_grayscale",
            "adjust_brightness",
            "adjust_contrast",
            "adjust_saturation",
            "normalize",
            "resize",
            "crop",
            "pad",
        ):
            if hasattr(F, _fn_name):
                setattr(functional_tensor, _fn_name, getattr(F, _fn_name))

        # Register it so both ``import ...functional_tensor`` and attribute
        # access through ``torchvision.transforms`` resolve to the shim.
        sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
        transforms.functional_tensor = functional_tensor
# Environment configuration: turn off TensorFlow's oneDNN optimisations, raise
# its C++ log threshold, and hide CUDA devices to force CPU execution.
for _env_key, _env_value in (
    ("TF_ENABLE_ONEDNN_OPTS", "0"),
    ("TF_CPP_MIN_LOG_LEVEL", "2"),
    ("CUDA_VISIBLE_DEVICES", "-1"),  # force CPU-only execution
    ("TF_FORCE_GPU_ALLOW_GROWTH", "false"),
):
    os.environ[_env_key] = _env_value
# ---------------------------------------------------------------------------
# PyTorch / PyArrow compatibility patches.
#
# Each patch has its own try/except so that one failing (for example an
# AttributeError from a torch build without ``torch.onnx._internal.exporter``)
# does not silently skip the others.
# ---------------------------------------------------------------------------


def _count_tree_values(spec):
    """Return the number of leaf values described by a fallback tree spec.

    ``spec`` is ``None`` for a leaf, or ``(node_type, [(key, child_spec), ...])``
    for a list/tuple/dict node, as produced by ``_compat_tree_flatten``.
    """
    if spec is None:
        return 1
    _node_type, children = spec
    return sum(_count_tree_values(child_spec) for _key, child_spec in children)


def _compat_register_pytree_node(
    typ, flatten_fn, unflatten_fn, *, flatten_with_keys_fn=None, **kwargs
):
    """Fallback ``register_pytree_node`` that accepts and ignores registrations.

    Nothing is recorded, so instances of ``typ`` are treated as leaves by
    ``_compat_tree_flatten``.
    """
    return None


def _compat_tree_flatten(tree, is_leaf=None):
    """Fallback ``tree_flatten`` for nested lists, tuples and dicts.

    Args:
        tree: Nested list/tuple/dict containers; any other object is a leaf.
        is_leaf: Optional predicate. A subtree for which it returns True is
            kept whole as a single leaf instead of being descended into.

    Returns:
        ``(leaves, spec)``; pass ``spec`` together with a same-length leaf
        list to ``_compat_tree_unflatten`` to rebuild the structure.
    """
    if is_leaf is not None and is_leaf(tree):
        return [tree], None
    if isinstance(tree, (list, tuple)):
        node_type = type(tree)
        children = list(enumerate(tree))
    elif isinstance(tree, dict):
        # Dict nodes are always rebuilt as a plain ``dict``.
        node_type = dict
        try:
            # Deterministic key order when the keys are mutually comparable...
            children = sorted(tree.items(), key=lambda item: item[0])
        except TypeError:
            # ...otherwise (e.g. mixed int/str keys) keep insertion order
            # instead of raising.
            children = list(tree.items())
    else:
        return [tree], None

    leaves = []
    spec = []
    for key, child in children:
        child_leaves, child_spec = _compat_tree_flatten(child, is_leaf)
        leaves.extend(child_leaves)
        spec.append((key, child_spec))
    return leaves, (node_type, spec)


def _compat_tree_unflatten(values, spec):
    """Fallback ``tree_unflatten``: rebuild a tree from leaves and a spec.

    Inverse of ``_compat_tree_flatten``. Dict nodes come back as ``dict``;
    list/tuple nodes keep their original type, including subclasses such as
    namedtuples. Returns ``None`` for a leaf spec with no values.
    """
    if spec is None:
        return values[0] if values else None
    node_type, children = spec

    rebuilt = []
    offset = 0
    for key, child_spec in children:
        width = _count_tree_values(child_spec)
        subtree = _compat_tree_unflatten(values[offset:offset + width], child_spec)
        rebuilt.append((key, subtree))
        offset += width

    if issubclass(node_type, dict):
        return dict(rebuilt)
    items = [value for _key, value in rebuilt]
    if issubclass(node_type, tuple) and hasattr(node_type, "_fields"):
        # namedtuple constructors take the fields positionally, not an iterable.
        return node_type(*items)
    if issubclass(node_type, (list, tuple)):
        return node_type(items)
    return values[0] if values else None


def _compat_tree_map(fn, tree, *other_trees, is_leaf=None):
    """Fallback ``tree_map``: apply ``fn`` leaf-wise and rebuild ``tree``'s shape.

    With ``other_trees``, ``fn`` receives the matching leaf of each tree as
    extra positional arguments (leaves are paired in flatten order).
    """
    leaves, spec = _compat_tree_flatten(tree, is_leaf)
    if other_trees:
        other_leaves = [_compat_tree_flatten(t, is_leaf)[0] for t in other_trees]
        mapped = [fn(*group) for group in zip(leaves, *other_leaves)]
    else:
        mapped = [fn(leaf) for leaf in leaves]
    return _compat_tree_unflatten(mapped, spec)


def _install_pytree_fallbacks():
    """Ensure ``torch.utils._pytree`` exists and exposes the helpers ModelScope uses.

    Only missing attributes are filled in; existing torch implementations are
    never replaced.
    """
    import sys
    from types import ModuleType

    import torch.utils

    try:
        # Import the real module if this torch build ships one; a plain
        # ``hasattr`` check could miss it before it is first imported.
        import torch.utils._pytree as pytree
    except ImportError:
        # No _pytree module in this torch build: create one and register it in
        # sys.modules so ``import torch.utils._pytree`` resolves as well.
        pytree = ModuleType("torch.utils._pytree")
        torch.utils._pytree = pytree
        sys.modules["torch.utils._pytree"] = pytree

    fallbacks = {
        "register_pytree_node": _compat_register_pytree_node,
        "tree_flatten": _compat_tree_flatten,
        "tree_unflatten": _compat_tree_unflatten,
        "tree_map": _compat_tree_map,
    }
    for attr_name, implementation in fallbacks.items():
        if not hasattr(pytree, attr_name):
            setattr(pytree, attr_name, implementation)


try:
    import torch
    import torch.onnx
except ImportError as e:
    print(f"Warning: PyTorch/PyArrow compatibility patch failed: {e}")
else:
    # GFPGAN ONNX compatibility: provide ExportOptions when the exporter lacks it.
    try:
        _onnx_exporter = torch.onnx._internal.exporter
        if not hasattr(_onnx_exporter, "ExportOptions"):
            from types import SimpleNamespace

            _onnx_exporter.ExportOptions = SimpleNamespace
    except AttributeError as e:
        print(f"Warning: PyTorch/PyArrow compatibility patch failed: {e}")

    # ModelScope PyTree compatibility.
    try:
        _install_pytree_fallbacks()
    except (ImportError, AttributeError) as e:
        print(f"Warning: PyTorch/PyArrow compatibility patch failed: {e}")

# PyArrow compatibility: provide a placeholder ``pyarrow.PyExtensionType`` when
# the installed pyarrow does not define it, so references to the name resolve.
try:
    import pyarrow
except ImportError:
    pass
else:
    if not hasattr(pyarrow, "PyExtensionType"):
        pyarrow.PyExtensionType = type("PyExtensionType", (), {})
# Image working directory; OUTPUT_DIR is an alias of IMAGES_DIR.
IMAGES_DIR = os.getenv("IMAGES_DIR", "/opt/data/images")
OUTPUT_DIR = IMAGES_DIR

# Celebrity gallery directories. A blank CELEBRITY_SOURCE_DIR stays "";
# any other value is "~"-expanded and made absolute.
_celeb_source = os.getenv(
    "CELEBRITY_SOURCE_DIR", "/opt/data/chinese_celeb_dataset"
).strip()
CELEBRITY_SOURCE_DIR = (
    os.path.abspath(os.path.expanduser(_celeb_source)) if _celeb_source else _celeb_source
)
# The dataset directory defaults to the source directory, or to the stock
# dataset path when the source directory is blank.
_celeb_dataset = os.getenv(
    "CELEBRITY_DATASET_DIR",
    CELEBRITY_SOURCE_DIR or "/opt/data/chinese_celeb_dataset",
)
CELEBRITY_DATASET_DIR = os.path.abspath(os.path.expanduser(_celeb_dataset))
CELEBRITY_FIND_THRESHOLD = float(os.getenv("CELEBRITY_FIND_THRESHOLD", 0.88))
# ---- start ----
# WeChat Mini Program settings (the defaults are only for local development).
WECHAT_APPID = os.getenv("WECHAT_APPID", "******").strip()
# NOTE(review): the secret is read from "WCT_SECRET", not "WECHAT_SECRET";
# confirm that this env-var name is intentional.
WECHAT_SECRET = os.getenv("WCT_SECRET", "******").strip()
APP_SECRET_TOKEN = os.getenv("APP_SECRET_TOKEN", "******")

# MySQL connection settings.
MYSQL_HOST = os.getenv("MYSQL_HOST", "******")
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
MYSQL_DB = os.getenv("MYSQL_DB", "******")
MYSQL_USER = os.getenv("MYSQL_USER", "******")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "******")

# BOS object-storage settings (per the original note, defaults are stored as
# Base64-encoded strings). Every value is whitespace-stripped.
BOS_ACCESS_KEY = os.getenv("BOS_ACCESS_KEY", "******").strip()
BOS_SECRET_KEY = os.getenv("BOS_SECRET_KEY", "******").strip()
BOS_ENDPOINT = os.getenv("BOS_ENDPOINT", "******").strip()
BOS_BUCKET_NAME = os.getenv("BOS_BUCKET_NAME", "******").strip()
BOS_IMAGE_DIR = os.getenv("BOS_IMAGE_DIR", "******").strip()
BOS_MODELS_PREFIX = os.getenv("BOS_MODELS_PREFIX", "******").strip()
BOS_CELEBRITY_PREFIX = os.getenv("BOS_CELEBRITY_PREFIX", "******").strip()
# ---- end ---
# BOS upload switch: an explicit BOS_UPLOAD_ENABLED ("1"/"true"/"on",
# case-insensitive) wins; otherwise uploads are enabled only when the access
# key, secret key, endpoint and bucket are all non-empty.
_bos_enabled_env = os.getenv("BOS_UPLOAD_ENABLED")

# MySQL connection-pool bounds.
MYSQL_POOL_MIN_SIZE = int(os.getenv("MYSQL_POOL_MIN_SIZE", "1"))
MYSQL_POOL_MAX_SIZE = int(os.getenv("MYSQL_POOL_MAX_SIZE", "10"))

if _bos_enabled_env is None:
    BOS_UPLOAD_ENABLED = all(
        (
            BOS_ACCESS_KEY.strip(),
            BOS_SECRET_KEY.strip(),
            BOS_ENDPOINT,
            BOS_BUCKET_NAME,
        )
    )
else:
    BOS_UPLOAD_ENABLED = _bos_enabled_env.lower() in {"1", "true", "on"}
HOSTNAME = os.environ.get("HOSTNAME", "default-hostname")

# Root directory for model weights, and the download directory (defaults to
# MODELS_PATH). Environment-variable references such as "$HOME" and a leading
# "~" are expanded before the path is made absolute, so "$HOME/models"
# resolves under the home directory rather than to a literal "$HOME" folder
# relative to the working directory.
MODELS_PATH = os.path.abspath(
    os.path.expanduser(
        os.path.expandvars(os.environ.get("MODELS_PATH", "/opt/data/models"))
    )
)
MODELS_DOWNLOAD_DIR = os.path.abspath(
    os.path.expanduser(
        os.path.expandvars(os.environ.get("MODELS_DOWNLOAD_DIR", MODELS_PATH))
    )
)
# HuggingFace model-repository sync settings.
HUGGINGFACE_SYNC_ENABLED = os.getenv(
    "HUGGINGFACE_SYNC_ENABLED", "true"
).lower() in {"1", "true", "on"}
HUGGINGFACE_REPO_ID = os.getenv("HUGGINGFACE_REPO_ID", "ethonmax/facescore").strip()
HUGGINGFACE_REVISION = os.getenv("HUGGINGFACE_REVISION", "main").strip()

# Optional comma-separated pattern lists; entries are trimmed and blanks dropped.
_hf_allow_env = os.getenv("HUGGINGFACE_ALLOW_PATTERNS", "").strip()
HUGGINGFACE_ALLOW_PATTERNS = list(
    filter(None, (entry.strip() for entry in _hf_allow_env.split(",")))
)
_hf_ignore_env = os.getenv("HUGGINGFACE_IGNORE_PATTERNS", "").strip()
HUGGINGFACE_IGNORE_PATTERNS = list(
    filter(None, (entry.strip() for entry in _hf_ignore_env.split(",")))
)
# ModelScope cache directory: MODELSCOPE_CACHE when set (made absolute, "~"
# expanded), otherwise <MODELS_PATH>/modelscope.
_MODELSCOPE_CACHE_ENV = os.environ.get("MODELSCOPE_CACHE", "").strip()
if _MODELSCOPE_CACHE_ENV:
    MODELSCOPE_CACHE_DIR = os.path.abspath(os.path.expanduser(_MODELSCOPE_CACHE_ENV))
else:
    MODELSCOPE_CACHE_DIR = os.path.join(MODELS_PATH, "modelscope")

try:
    os.makedirs(MODELSCOPE_CACHE_DIR, exist_ok=True)
except OSError as exc:
    # Best effort: startup continues without the cache directory. The module
    # logger is not configured yet at this point, so report via print. The
    # values are interpolated directly; passing them as extra print() args
    # would leave the "%s" placeholders unformatted.
    print(f"创建 ModelScope 缓存目录失败: {MODELSCOPE_CACHE_DIR} ({exc})")

# Export the cache location under every variable name ModelScope may read,
# without overriding values the operator already set.
os.environ.setdefault("MODELSCOPE_CACHE", MODELSCOPE_CACHE_DIR)
os.environ.setdefault("MODELSCOPE_HOME", MODELSCOPE_CACHE_DIR)
os.environ.setdefault("MODELSCOPE_CACHE_HOME", MODELSCOPE_CACHE_DIR)
# DeepFace home directory, written back to the environment.
DEEPFACE_HOME = os.getenv("DEEPFACE_HOME", "/opt/data/models")
os.environ["DEEPFACE_HOME"] = DEEPFACE_HOME

# GFPGAN-related download directory; it is created eagerly and exported under
# each library's cache variable so all of them share MODELS_DOWNLOAD_DIR.
GFPGAN_MODEL_DIR = MODELS_DOWNLOAD_DIR
os.makedirs(GFPGAN_MODEL_DIR, exist_ok=True)
# NOTE(review): torch.hub itself reads TORCH_HOME; confirm which library
# consumes HUB_CACHE_DIR.
for _cache_var in (
    "GFPGAN_MODEL_ROOT",
    "FACEXLIB_CACHE_DIR",
    "BASICSR_CACHE_DIR",
    "REALESRGAN_MODEL_ROOT",
    "HUB_CACHE_DIR",
):
    os.environ[_cache_var] = GFPGAN_MODEL_DIR

# rembg / u2net model cache directory, under the shared models path.
# NOTE(review): MODELS_PATH is already absolute, so any "~" introduced by this
# replace is not leading and expanduser leaves it unchanged; confirm whether
# the substitution is still needed.
REMBG_MODEL_DIR = os.path.expanduser(MODELS_PATH.replace("$HOME", "~"))
for _cache_var in ("U2NET_HOME", "REMBG_HOME"):
    os.environ[_cache_var] = REMBG_MODEL_DIR
# Image-quality and detection confidence settings.
IMG_QUALITY = float(os.getenv("IMG_QUALITY", 0.5))
FACE_CONFIDENCE = float(os.getenv("FACE_CONFIDENCE", 0.7))
AGE_CONFIDENCE = float(os.getenv("AGE_CONFIDENCE", 0.99))
# NOTE(review): the 1.1 default exceeds 1.0; confirm that is intended.
GENDER_CONFIDENCE = float(os.getenv("GENDER_CONFIDENCE", 1.1))

# DeepFace emotion recognition (on by default; turning it off cuts inference time).
DEEPFACE_EMOTION_ENABLED = os.getenv(
    "DEEPFACE_EMOTION_ENABLED", "true"
).lower() in {"1", "true", "on"}

# Upscaling / output settings.
UPSCALE_SIZE = int(os.getenv("UPSCALE_SIZE", 2))
SAVE_QUALITY = int(os.getenv("SAVE_QUALITY", 85))
REALESRGAN_MODEL = os.getenv("REALESRGAN_MODEL", "realesr-general-x4v3")

# Face detector weights: yolov11n-face.pt or yolov8n-face.pt.
YOLO_MODEL = os.getenv("YOLO_MODEL", "yolov11n-face.pt")

# RobustVideoMatting backbone (mobilenetv3 or resnet50) and its local paths.
RVM_MODEL = os.getenv("RVM_MODEL", "resnet50")
RVM_LOCAL_REPO = os.getenv("RVM_LOCAL_REPO", "/opt/data/RobustVideoMatting").strip()
RVM_WEIGHTS_PATH = os.getenv(
    "RVM_WEIGHTS_PATH", "/opt/data/models/torch/hub/checkpoints/rvm_resnet50.pth"
).strip()

DRAW_SCORE = os.getenv("DRAW_SCORE", "true").lower() in {"1", "true", "on"}
# Gentle beauty-score uplift. Defaults: enabled, range [1.0, 8.0), gamma 0.5.
# - BEAUTY_ADJUST_ENABLED: whether the uplift is applied.
# - BEAUTY_ADJUST_MIN: lower bound; scores below it are not raised.
# - BEAUTY_ADJUST_MAX: upper bound/target; only scores in [min, max) are raised.
# - BEAUTY_ADJUST_THRESHOLD: legacy alias, used as MAX when BEAUTY_ADJUST_MAX is unset.
# - BEAUTY_ADJUST_GAMMA: strength in (0, 1]; smaller values raise scores more.
BEAUTY_ADJUST_ENABLED = os.getenv(
    "BEAUTY_ADJUST_ENABLED", "true"
).lower() in {"1", "true", "on"}
BEAUTY_ADJUST_MIN = float(os.getenv("BEAUTY_ADJUST_MIN", 1.0))
_legacy_thr = os.getenv("BEAUTY_ADJUST_THRESHOLD")
_beauty_max_fallback = 8.0 if _legacy_thr is None else _legacy_thr
BEAUTY_ADJUST_MAX = float(os.getenv("BEAUTY_ADJUST_MAX", _beauty_max_fallback))
BEAUTY_ADJUST_GAMMA = float(os.getenv("BEAUTY_ADJUST_GAMMA", 0.5))
# Legacy name kept for old references; no longer used directly by the logic.
BEAUTY_ADJUST_THRESHOLD = BEAUTY_ADJUST_MAX

# Gentle overall-harmony score uplift. Defaults: enabled, threshold 9.0, gamma 0.3.
HARMONY_ADJUST_ENABLED = os.getenv(
    "HARMONY_ADJUST_ENABLED", "true"
).lower() in {"1", "true", "on"}
HARMONY_ADJUST_THRESHOLD = float(os.getenv("HARMONY_ADJUST_THRESHOLD", 9.0))
HARMONY_ADJUST_GAMMA = float(os.getenv("HARMONY_ADJUST_GAMMA", 0.3))
# Startup tuning: whether heavy components are initialised / warmed up at launch.
# Only the analyzer is auto-initialised by default.
ENABLE_WARMUP = os.getenv("ENABLE_WARMUP", "false").lower() in {"1", "true", "on"}
AUTO_INIT_ANALYZER = os.getenv("AUTO_INIT_ANALYZER", "true").lower() in {"1", "true", "on"}
AUTO_INIT_GFPGAN = os.getenv("AUTO_INIT_GFPGAN", "false").lower() in {"1", "true", "on"}
AUTO_INIT_DDCOLOR = os.getenv("AUTO_INIT_DDCOLOR", "false").lower() in {"1", "true", "on"}
AUTO_INIT_REALESRGAN = os.getenv("AUTO_INIT_REALESRGAN", "false").lower() in {"1", "true", "on"}
AUTO_INIT_REMBG = os.getenv("AUTO_INIT_REMBG", "false").lower() in {"1", "true", "on"}
AUTO_INIT_ANIME_STYLE = os.getenv("AUTO_INIT_ANIME_STYLE", "false").lower() in {"1", "true", "on"}
AUTO_INIT_RVM = os.getenv("AUTO_INIT_RVM", "false").lower() in {"1", "true", "on"}
# Scheduled cleanup job, both values in hours (default 1 hour each).
CLEANUP_INTERVAL_HOURS = float(os.getenv("CLEANUP_INTERVAL_HOURS", 1.0))  # how often the job runs
CLEANUP_AGE_HOURS = float(os.getenv("CLEANUP_AGE_HOURS", 1.0))  # age threshold for files to clean

# BOS auto-sync manifest: each entry maps a BOS prefix to a local directory so
# startup code can iterate it for batch downloads. Empty by default; the
# commented entries below are templates.
BOS_DOWNLOAD_TARGETS = [
    # {
    #     "description": "明星图库数据集",
    #     "bos_prefix": BOS_CELEBRITY_PREFIX,
    #     "destination": CELEBRITY_DATASET_DIR,
    #     "background": True,
    # },
    # {
    #     "description": "AI 模型权重",
    #     "bos_prefix": BOS_MODELS_PREFIX,
    #     "destination": MODELS_DOWNLOAD_DIR,
    # },
]
# Log level from LOG_LEVEL (case-insensitive); names that are not attributes
# of the logging module fall back to INFO.
log_level_str = os.getenv("LOG_LEVEL", "INFO").upper()
log_level = getattr(logging, log_level_str, logging.INFO)

# Master logging switch.
ENABLE_LOGGING = os.getenv("ENABLE_LOGGING", "true").lower() in {"1", "true", "on"}

# Feature switches: all default on except anime-model preloading.
ENABLE_DDCOLOR = os.getenv("ENABLE_DDCOLOR", "true").lower() in {"1", "true", "on"}
ENABLE_REALESRGAN = os.getenv("ENABLE_REALESRGAN", "true").lower() in {"1", "true", "on"}
ENABLE_GFPGAN = os.getenv("ENABLE_GFPGAN", "true").lower() in {"1", "true", "on"}
ENABLE_ANIME_STYLE = os.getenv("ENABLE_ANIME_STYLE", "true").lower() in {"1", "true", "on"}
ENABLE_ANIME_PRELOAD = os.getenv("ENABLE_ANIME_PRELOAD", "false").lower() in {"1", "true", "on"}
ENABLE_RVM = os.getenv("ENABLE_RVM", "true").lower() in {"1", "true", "on"}

# Face-score module: maximum number of images per upload.
FACE_SCORE_MAX_IMAGES = int(os.getenv("FACE_SCORE_MAX_IMAGES", 10))

# Female age adjustment: per the original note, for women above the threshold
# age the displayed age is reduced by FEMALE_AGE_ADJUSTMENT years.
FEMALE_AGE_ADJUSTMENT = int(os.getenv("FEMALE_AGE_ADJUSTMENT", 3))  # years subtracted
FEMALE_AGE_ADJUSTMENT_THRESHOLD = int(os.getenv("FEMALE_AGE_ADJUSTMENT_THRESHOLD", 20))  # age threshold

# Logging setup. When logging is disabled, the root level is set above
# CRITICAL (basicConfig only applies if the root logger has no handlers yet)
# and this module's logger is disabled outright.
if not ENABLE_LOGGING:
    logging.basicConfig(level=logging.CRITICAL + 10)
else:
    logging.basicConfig(
        level=log_level,
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
logger = logging.getLogger(__name__)
if not ENABLE_LOGGING:
    logger.disabled = True
# Process-wide cache for the access_token and its "expires_at" value.
access_token_cache = dict(token=None, expires_at=0)

# Optional third-party dependencies: each *_AVAILABLE flag records whether the
# import succeeded, and a hint is printed when it did not.
try:
    from deepface import DeepFace
except ImportError:
    print("Warning: DeepFace not installed. Install with: pip install deepface")
    DEEPFACE_AVAILABLE = False
else:
    DEEPFACE_AVAILABLE = True

try:
    import mediapipe as mp
except ImportError:
    print("Warning: mediapipe not installed. Install with: pip install mediapipe")
    MEDIAPIPE_AVAILABLE = False
else:
    MEDIAPIPE_AVAILABLE = True

# Legacy name kept for backward compatibility; it now mirrors MEDIAPIPE_AVAILABLE.
DLIB_AVAILABLE = MEDIAPIPE_AVAILABLE

try:
    from ultralytics import YOLO
except ImportError:
    print("Warning: ultralytics not installed. Install with: pip install ultralytics")
    YOLO_AVAILABLE = False
else:
    YOLO_AVAILABLE = True
# GFPGAN photo restoration: available when enabled and gfpgan_restorer imports.
# Missing helper/weight files are only reported, because the log messages
# expect them to arrive later through HuggingFace/BOS model sync.
if not ENABLE_GFPGAN:
    GFPGAN_AVAILABLE = False
    logger.info("GFPGAN photo restoration feature is disabled (via ENABLE_GFPGAN environment variable)")
else:
    try:
        required_files = [
            os.path.join(os.path.dirname(__file__), "gfpgan_restorer.py"),
            os.path.join(MODELS_PATH, "gfpgan/weights/detection_Resnet50_Final.pth"),
            os.path.join(MODELS_PATH, "gfpgan/weights/parsing_parsenet.pth"),
        ]
        missing_files = [candidate for candidate in required_files if not os.path.exists(candidate)]
        for file_path in missing_files:
            logger.info("GFPGAN 所需文件暂未找到,将等待模型同步: %s", file_path)
        from gfpgan_restorer import GFPGANRestorer  # noqa: F401
    except ImportError as e:
        print(f"Warning: GFPGAN enabled but not available: {e}")
        GFPGAN_AVAILABLE = False
        logger.warning(f"GFPGAN photo restoration feature is enabled but import failed: {e}")
    else:
        GFPGAN_AVAILABLE = True
        if not missing_files:
            logger.info("GFPGAN photo restoration feature prerequisites detected")
        else:
            logger.warning(
                "GFPGAN 文件尚未全部就绪,将在 HuggingFace/BOS 同步完成后继续初始化: %s",
                ", ".join(missing_files),
            )
# DDColor colorization: available when enabled and ddcolor_colorizer imports.
if not ENABLE_DDCOLOR:
    DDCOLOR_AVAILABLE = False
    logger.info("DDColor feature is disabled (via ENABLE_DDCOLOR environment variable)")
else:
    try:
        from ddcolor_colorizer import DDColorColorizer
    except ImportError as e:
        print(f"Warning: DDColor enabled but not available: {e}")
        DDCOLOR_AVAILABLE = False
        logger.warning(f"DDColor feature is enabled but import failed: {e}")
    else:
        DDCOLOR_AVAILABLE = True
        logger.info("DDColor feature is enabled and available")
# Restoration uses GFPGAN only; the simple fallback restorer is switched off.
SIMPLE_RESTORER_AVAILABLE = False

# Real-ESRGAN super resolution: available when enabled and realesrgan_upscaler imports.
if not ENABLE_REALESRGAN:
    REALESRGAN_AVAILABLE = False
    logger.info("Real-ESRGAN super resolution feature is disabled (via ENABLE_REALESRGAN environment variable)")
else:
    try:
        from realesrgan_upscaler import RealESRGANUpscaler
    except ImportError as e:
        print(f"Warning: Real-ESRGAN enabled but not available: {e}")
        REALESRGAN_AVAILABLE = False
        logger.warning(f"Real-ESRGAN super resolution feature is enabled but import failed: {e}")
    else:
        REALESRGAN_AVAILABLE = True
        logger.info("Real-ESRGAN super resolution feature is enabled and available")
# rembg background removal switch (on by default).
ENABLE_REMBG = os.getenv("ENABLE_REMBG", "true").lower() in {"1", "true", "on"}

# rembg is available when enabled and both rembg and rembg.new_session import.
if not ENABLE_REMBG:
    REMBG_AVAILABLE = False
    logger.info("rembg background removal feature is disabled (via ENABLE_REMBG environment variable)")
else:
    try:
        import rembg
        from rembg import new_session
    except ImportError as e:
        print(f"Warning: rembg enabled but not available: {e}")
        REMBG_AVAILABLE = False
        logger.warning(f"rembg background removal feature is enabled but import failed: {e}")
    else:
        REMBG_AVAILABLE = True
        logger.info("rembg background removal feature is enabled and available")
        logger.info(f"rembg model storage path: {REMBG_MODEL_DIR}")
# CLIP support is switched off.
CLIP_AVAILABLE = False

# Anime stylization: available when enabled and anime_stylizer imports.
if not ENABLE_ANIME_STYLE:
    ANIME_STYLE_AVAILABLE = False
    logger.info("Anime stylization feature is disabled (via ENABLE_ANIME_STYLE environment variable)")
else:
    try:
        from anime_stylizer import AnimeStylizer
    except ImportError as e:
        print(f"Warning: Anime Style enabled but not available: {e}")
        ANIME_STYLE_AVAILABLE = False
        logger.warning(f"Anime stylization feature is enabled but import failed: {e}")
    else:
        ANIME_STYLE_AVAILABLE = True
        logger.info("Anime stylization feature is enabled and available")
# RobustVideoMatting (RVM) background removal switch (on by default).
ENABLE_RVM = os.getenv("ENABLE_RVM", "true").lower() in {"1", "true", "on"}

# RVM counts as available when enabled and torch imports; no model weights
# are loaded at this point.
if not ENABLE_RVM:
    RVM_AVAILABLE = False
    logger.info("RVM background removal feature is disabled (via ENABLE_RVM environment variable)")
else:
    try:
        import torch
    except ImportError as e:
        print(f"Warning: RVM enabled but not available: {e}")
        RVM_AVAILABLE = False
        logger.warning(f"RVM background removal feature is enabled but import failed: {e}")
    else:
        RVM_AVAILABLE = True
        logger.info("RVM background removal feature is enabled and available")