# picpocket / config.py
# chawin.chen
# fix
# 56daa9f
import logging
import os
# Work around the OpenMP library conflict by allowing duplicate runtimes.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Size the native thread pools from the CPU core count (capped at 8 threads)
# to improve CPU utilisation.
import multiprocessing

cpu_cores = multiprocessing.cpu_count()
_thread_budget = str(min(cpu_cores, 8))
for _pool_var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"):
    os.environ[_pool_var] = _thread_budget
# torchvision compatibility shim.
#
# When ``torchvision.transforms.functional_tensor`` cannot be imported, build a
# stand-in module that re-exports the commonly used functions from
# ``torchvision.transforms.functional`` and register it under the old name so
# that ``import torchvision.transforms.functional_tensor`` keeps working.
try:
    import torchvision.transforms.functional_tensor
except ImportError:
    try:
        import torchvision.transforms.functional as F
        import torchvision.transforms as transforms
    except ImportError as e:
        # torchvision itself is unavailable, so there is nothing to patch.
        # Report it instead of letting the whole config import crash.
        print(f"Warning: torchvision compatibility patch skipped: {e}")
    else:
        import sys
        from types import ModuleType

        # Create the replacement functional_tensor module.
        functional_tensor = ModuleType('torchvision.transforms.functional_tensor')

        # Map over whichever of the commonly used functions this torchvision has.
        for _fn_name in (
            'rgb_to_grayscale',
            'adjust_brightness',
            'adjust_contrast',
            'adjust_saturation',
            'normalize',
            'resize',
            'crop',
            'pad',
        ):
            if hasattr(F, _fn_name):
                setattr(functional_tensor, _fn_name, getattr(F, _fn_name))

        # Register the module in sys.modules and as an attribute of the package.
        sys.modules['torchvision.transforms.functional_tensor'] = functional_tensor
        transforms.functional_tensor = functional_tensor
# Environment configuration: disable TensorFlow optimisations and GPU usage.
os.environ.update(
    {
        "TF_ENABLE_ONEDNN_OPTS": "0",
        "TF_CPP_MIN_LOG_LEVEL": "2",
        "CUDA_VISIBLE_DEVICES": "-1",  # force CPU execution
        "TF_FORCE_GPU_ALLOW_GROWTH": "false",
    }
)
# PyTorch / PyArrow compatibility patches.
#
# Each patch below is applied independently, so a failure in one (for example
# a torch build without ``torch.onnx._internal.exporter``) no longer prevents
# the remaining patches from being installed.


def _count_tree_values(spec):
    """Return the number of leaves covered by a spec from ``_compat_tree_flatten``.

    ``None`` denotes a single leaf; otherwise ``spec`` is ``(type, children)``
    where ``children`` is a list of ``(key, child_spec)`` pairs.
    """
    if spec is None:
        return 1
    _, children = spec
    return sum(_count_tree_values(child_spec) for _, child_spec in children)


def _compat_register_pytree_node(typ, flatten_fn, unflatten_fn, *, flatten_with_keys_fn=None, **kwargs):
    """Fallback ``register_pytree_node``: accept the registration and do nothing."""


def _compat_tree_flatten(tree, is_leaf=None):
    """Fallback ``tree_flatten`` for nested lists, tuples and dicts.

    Args:
        tree: Arbitrary object; lists, tuples and dicts are traversed, anything
            else is treated as a leaf.
        is_leaf: Optional predicate; a node for which it returns true is kept
            as a leaf instead of being traversed.

    Returns:
        ``(leaves, spec)`` where ``spec`` is ``None`` for a leaf or
        ``(container_type, [(key, child_spec), ...])`` for a container. Dict
        entries are visited in sorted key order.
    """
    if is_leaf is not None and is_leaf(tree):
        return [tree], None
    if isinstance(tree, (list, tuple)):
        container_type, entries = type(tree), enumerate(tree)
    elif isinstance(tree, dict):
        container_type, entries = dict, sorted(tree.items())
    else:
        return [tree], None

    leaves = []
    children = []
    for key, child in entries:
        child_leaves, child_spec = _compat_tree_flatten(child, is_leaf)
        leaves.extend(child_leaves)
        children.append((key, child_spec))
    return leaves, (container_type, children)


def _compat_tree_unflatten(values, spec):
    """Fallback ``tree_unflatten``: rebuild a structure from leaves and a spec.

    Args:
        values: Flat sequence of leaves in the order ``_compat_tree_flatten``
            produced them.
        spec: Spec returned by ``_compat_tree_flatten``.

    Returns:
        The rebuilt structure. A ``None`` spec yields the first value (or
        ``None`` when ``values`` is empty).
    """
    if spec is None:
        return values[0] if values else None
    tree_type, children = spec
    is_sequence = isinstance(tree_type, type) and issubclass(tree_type, (list, tuple))
    if tree_type is not dict and not is_sequence:
        # Unknown container type: behave like a leaf spec.
        return values[0] if values else None

    rebuilt = []
    cursor = 0
    for key, child_spec in children:
        if child_spec is None:
            child = values[cursor]
            width = 1
        else:
            # A sub-tree consumes as many leaves as its spec covers.
            width = _count_tree_values(child_spec)
            child = _compat_tree_unflatten(values[cursor:cursor + width], child_spec)
        rebuilt.append((key, child))
        cursor += width

    if tree_type is dict:
        return dict(rebuilt)
    items = [child for _, child in rebuilt]
    if hasattr(tree_type, "_fields"):
        # namedtuple classes take their fields as positional arguments.
        return tree_type(*items)
    return tree_type(items)


def _compat_tree_map(fn, tree, *other_trees, is_leaf=None):
    """Fallback ``tree_map``: apply ``fn`` leaf-wise and rebuild ``tree``'s shape.

    When ``other_trees`` are given, ``fn`` receives the corresponding leaf of
    every tree; iteration stops at the shortest leaf list.
    """
    leaves, spec = _compat_tree_flatten(tree, is_leaf)
    other_leaves = [_compat_tree_flatten(other, is_leaf)[0] for other in other_trees]
    mapped = [fn(*args) for args in zip(leaves, *other_leaves)]
    return _compat_tree_unflatten(mapped, spec)


try:
    import torch
    import torch.onnx
    import torch.utils
except ImportError as e:
    print(f"Warning: PyTorch/PyArrow compatibility patch failed: {e}")
else:
    # ONNX compatibility for GFPGAN: provide ExportOptions when it is missing.
    try:
        if not hasattr(torch.onnx._internal.exporter, 'ExportOptions'):
            from types import SimpleNamespace
            torch.onnx._internal.exporter.ExportOptions = SimpleNamespace
    except (ImportError, AttributeError) as e:
        print(f"Warning: PyTorch/PyArrow compatibility patch failed: {e}")

    # PyTree compatibility for ModelScope: create the module if it is absent
    # and fill in whichever helpers this torch build lacks.
    if not hasattr(torch.utils, '_pytree'):
        from types import ModuleType
        torch.utils._pytree = ModuleType('_pytree')
    pytree = torch.utils._pytree
    for _pytree_name, _pytree_fallback in (
        ('register_pytree_node', _compat_register_pytree_node),
        ('tree_flatten', _compat_tree_flatten),
        ('tree_unflatten', _compat_tree_unflatten),
        ('tree_map', _compat_tree_map),
    ):
        if not hasattr(pytree, _pytree_name):
            setattr(pytree, _pytree_name, _pytree_fallback)

# pyarrow compatibility: provide a placeholder PyExtensionType when the
# installed pyarrow does not expose one. Independent of torch availability.
try:
    import pyarrow
except ImportError:
    pass
else:
    if not hasattr(pyarrow, 'PyExtensionType'):
        pyarrow.PyExtensionType = type('PyExtensionType', (), {})
# Image directory; OUTPUT_DIR is an alias of the same location.
IMAGES_DIR = os.getenv("IMAGES_DIR", "/opt/data/images")
OUTPUT_DIR = IMAGES_DIR

# Celebrity gallery directories.
# The source directory is normalised to an absolute path unless it is empty.
_celebrity_source_raw = os.getenv(
    "CELEBRITY_SOURCE_DIR", "/opt/data/chinese_celeb_dataset"
).strip()
CELEBRITY_SOURCE_DIR = (
    os.path.abspath(os.path.expanduser(_celebrity_source_raw))
    if _celebrity_source_raw
    else _celebrity_source_raw
)

# The dataset directory defaults to the source directory (or the stock path
# when the source directory is empty) and is always made absolute.
_celebrity_dataset_raw = os.getenv(
    "CELEBRITY_DATASET_DIR",
    CELEBRITY_SOURCE_DIR or "/opt/data/chinese_celeb_dataset",
)
CELEBRITY_DATASET_DIR = os.path.abspath(os.path.expanduser(_celebrity_dataset_raw))

CELEBRITY_FIND_THRESHOLD = float(os.getenv("CELEBRITY_FIND_THRESHOLD", 0.88))
# ---- start ----
# WeChat mini-program settings (the defaults are for local development only).
WECHAT_APPID = os.getenv("WECHAT_APPID", "******").strip()
# NOTE(review): the secret is read from "WCT_SECRET", not "WECHAT_SECRET" —
# confirm the variable name is intentional.
WECHAT_SECRET = os.getenv("WCT_SECRET", "******").strip()
APP_SECRET_TOKEN = os.getenv("APP_SECRET_TOKEN", "******")

# MySQL database settings (values are used as-is, without stripping).
MYSQL_HOST = os.getenv("MYSQL_HOST", "******")
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
MYSQL_DB = os.getenv("MYSQL_DB", "******")
MYSQL_USER = os.getenv("MYSQL_USER", "******")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "******")

# BOS object-storage settings (the original note says the defaults are stored
# as Base64-encoded strings). Every value is whitespace-stripped.
BOS_ACCESS_KEY = os.getenv("BOS_ACCESS_KEY", "******").strip()
BOS_SECRET_KEY = os.getenv("BOS_SECRET_KEY", "******").strip()
BOS_ENDPOINT = os.getenv("BOS_ENDPOINT", "******").strip()
BOS_BUCKET_NAME = os.getenv("BOS_BUCKET_NAME", "******").strip()
BOS_IMAGE_DIR = os.getenv("BOS_IMAGE_DIR", "******").strip()
BOS_MODELS_PREFIX = os.getenv("BOS_MODELS_PREFIX", "******").strip()
BOS_CELEBRITY_PREFIX = os.getenv("BOS_CELEBRITY_PREFIX", "******").strip()
# ---- end ---
# MySQL connection-pool size bounds.
MYSQL_POOL_MIN_SIZE = int(os.getenv("MYSQL_POOL_MIN_SIZE", "1"))
MYSQL_POOL_MAX_SIZE = int(os.getenv("MYSQL_POOL_MAX_SIZE", "10"))

# BOS upload switch: an explicit BOS_UPLOAD_ENABLED value wins; otherwise
# uploads are enabled only when the access key, secret key, endpoint and
# bucket name are all non-empty.
_bos_enabled_env = os.getenv("BOS_UPLOAD_ENABLED")
if _bos_enabled_env is None:
    BOS_UPLOAD_ENABLED = bool(
        BOS_ACCESS_KEY.strip()
        and BOS_SECRET_KEY.strip()
        and BOS_ENDPOINT
        and BOS_BUCKET_NAME
    )
else:
    BOS_UPLOAD_ENABLED = _bos_enabled_env.lower() in ("1", "true", "on")
HOSTNAME = os.getenv("HOSTNAME", "default-hostname")

# Model root directory, user-expanded and made absolute.
_models_path_raw = os.getenv("MODELS_PATH", "/opt/data/models")
MODELS_PATH = os.path.abspath(os.path.expanduser(_models_path_raw))

# Download directory; falls back to the model root when not configured.
_models_download_raw = os.getenv("MODELS_DOWNLOAD_DIR", MODELS_PATH)
MODELS_DOWNLOAD_DIR = os.path.abspath(os.path.expanduser(_models_download_raw))
# HuggingFace repository settings.
HUGGINGFACE_SYNC_ENABLED = (
    os.getenv("HUGGINGFACE_SYNC_ENABLED", "true").lower() in ("1", "true", "on")
)
HUGGINGFACE_REPO_ID = os.getenv("HUGGINGFACE_REPO_ID", "ethonmax/facescore").strip()
HUGGINGFACE_REVISION = os.getenv("HUGGINGFACE_REVISION", "main").strip()

# Allow/ignore patterns are comma-separated lists; blank entries are dropped
# and each remaining entry is whitespace-stripped.
_hf_allow_env = os.getenv("HUGGINGFACE_ALLOW_PATTERNS", "").strip()
HUGGINGFACE_ALLOW_PATTERNS = [
    entry for entry in map(str.strip, _hf_allow_env.split(",")) if entry
]
_hf_ignore_env = os.getenv("HUGGINGFACE_IGNORE_PATTERNS", "").strip()
HUGGINGFACE_IGNORE_PATTERNS = [
    entry for entry in map(str.strip, _hf_ignore_env.split(",")) if entry
]
# ModelScope cache directory: use MODELSCOPE_CACHE when it is set to a
# non-blank value, otherwise a "modelscope" folder under the model root.
_MODELSCOPE_CACHE_ENV = os.environ.get("MODELSCOPE_CACHE", "").strip()
if _MODELSCOPE_CACHE_ENV:
    MODELSCOPE_CACHE_DIR = os.path.abspath(os.path.expanduser(_MODELSCOPE_CACHE_ENV))
else:
    MODELSCOPE_CACHE_DIR = os.path.join(MODELS_PATH, "modelscope")

# Creating the directory is best-effort: a failure is reported, not raised.
try:
    os.makedirs(MODELSCOPE_CACHE_DIR, exist_ok=True)
except OSError as exc:
    # The original call passed %s placeholders plus extra positional arguments
    # to print(), so the message was never formatted; interpolate directly.
    print(f"创建 ModelScope 缓存目录失败: {MODELSCOPE_CACHE_DIR} ({exc})")

# Publish the location through the environment without overriding values
# that are already set.
os.environ.setdefault("MODELSCOPE_CACHE", MODELSCOPE_CACHE_DIR)
os.environ.setdefault("MODELSCOPE_HOME", MODELSCOPE_CACHE_DIR)
os.environ.setdefault("MODELSCOPE_CACHE_HOME", MODELSCOPE_CACHE_DIR)
# DeepFace home directory, exported back into the environment.
DEEPFACE_HOME = os.environ.get("DEEPFACE_HOME", "/opt/data/models")
os.environ["DEEPFACE_HOME"] = DEEPFACE_HOME

# GFPGAN-related models are downloaded into the shared download directory.
GFPGAN_MODEL_DIR = MODELS_DOWNLOAD_DIR
try:
    os.makedirs(GFPGAN_MODEL_DIR, exist_ok=True)
except OSError as exc:
    # Best-effort, matching the ModelScope cache handling in this file: an
    # unwritable path is reported instead of aborting the config import.
    print(f"Warning: failed to create model directory {GFPGAN_MODEL_DIR}: {exc}")

# Point the model libraries' download/cache environment variables at it.
# NOTE(review): the original comment labels HUB_CACHE_DIR as the PyTorch Hub
# cache — confirm which library actually reads that variable.
for _model_dir_var in (
    "GFPGAN_MODEL_ROOT",
    "FACEXLIB_CACHE_DIR",
    "BASICSR_CACHE_DIR",
    "REALESRGAN_MODEL_ROOT",
    "HUB_CACHE_DIR",
):
    os.environ[_model_dir_var] = GFPGAN_MODEL_DIR
# Keep rembg models in the unified AI model directory. Any literal "$HOME" in
# the path is rewritten to "~" before user expansion.
_rembg_root = MODELS_PATH.replace("$HOME", "~")
REMBG_MODEL_DIR = os.path.expanduser(_rembg_root)
for _rembg_var in (
    "U2NET_HOME",  # u2net model cache directory
    "REMBG_HOME",  # rembg general cache directory
):
    os.environ[_rembg_var] = REMBG_MODEL_DIR
# Quality and confidence thresholds.
IMG_QUALITY = float(os.getenv("IMG_QUALITY", 0.5))
FACE_CONFIDENCE = float(os.getenv("FACE_CONFIDENCE", 0.7))
AGE_CONFIDENCE = float(os.getenv("AGE_CONFIDENCE", 0.99))
# NOTE(review): the default of 1.1 is above 1.0 — confirm this is intentional.
GENDER_CONFIDENCE = float(os.getenv("GENDER_CONFIDENCE", 1.1))

# Whether DeepFace emotion recognition is enabled (on by default; turning it
# off can reduce inference time).
DEEPFACE_EMOTION_ENABLED = (
    os.getenv("DEEPFACE_EMOTION_ENABLED", "true").lower() in ("1", "true", "on")
)

UPSCALE_SIZE = int(os.getenv("UPSCALE_SIZE", 2))
SAVE_QUALITY = int(os.getenv("SAVE_QUALITY", 85))
REALESRGAN_MODEL = os.getenv("REALESRGAN_MODEL", "realesr-general-x4v3")

# Options noted by the original author: yolov11n-face.pt / yolov8n-face.pt
YOLO_MODEL = os.getenv("YOLO_MODEL", "yolov11n-face.pt")

# Options noted by the original author: mobilenetv3 / resnet50
RVM_MODEL = os.getenv("RVM_MODEL", "resnet50")
RVM_LOCAL_REPO = os.getenv("RVM_LOCAL_REPO", "/opt/data/RobustVideoMatting").strip()
RVM_WEIGHTS_PATH = os.getenv(
    "RVM_WEIGHTS_PATH", "/opt/data/models/torch/hub/checkpoints/rvm_resnet50.pth"
).strip()

DRAW_SCORE = os.getenv("DRAW_SCORE", "true").lower() in ("1", "true", "on")
# Gentle boost for the beauty score (enabled by default). Defaults used by
# the code below: range = [1.0, 8.0], gamma = 0.5.
# Meaning of each setting, per the original notes:
# - BEAUTY_ADJUST_ENABLED: whether the boost is applied
# - BEAUTY_ADJUST_MIN: lower bound (scores below it are not boosted)
# - BEAUTY_ADJUST_MAX: upper bound (boost applies only within [min, max))
# - BEAUTY_ADJUST_THRESHOLD: legacy name, equivalent to BEAUTY_ADJUST_MAX
# - BEAUTY_ADJUST_GAMMA: boost strength in (0, 1]; smaller boosts more
BEAUTY_ADJUST_ENABLED = (
    os.getenv("BEAUTY_ADJUST_ENABLED", "true").lower() in ("1", "true", "on")
)
BEAUTY_ADJUST_MIN = float(os.getenv("BEAUTY_ADJUST_MIN", 1.0))

# Backward compatibility: when BEAUTY_ADJUST_MAX is not provided, fall back
# to the legacy BEAUTY_ADJUST_THRESHOLD variable, and then to 8.0.
_legacy_thr = os.getenv("BEAUTY_ADJUST_THRESHOLD")
_beauty_max_fallback = 8.0 if _legacy_thr is None else _legacy_thr
BEAUTY_ADJUST_MAX = float(os.getenv("BEAUTY_ADJUST_MAX", _beauty_max_fallback))
BEAUTY_ADJUST_GAMMA = float(os.getenv("BEAUTY_ADJUST_GAMMA", 0.5))

# Legacy alias kept so that older references keep resolving.
BEAUTY_ADJUST_THRESHOLD = BEAUTY_ADJUST_MAX

# Gentle boost for the overall harmony score (enabled by default). Defaults
# used by the code below: threshold = 9.0, gamma = 0.3.
HARMONY_ADJUST_ENABLED = (
    os.getenv("HARMONY_ADJUST_ENABLED", "true").lower() in ("1", "true", "on")
)
HARMONY_ADJUST_THRESHOLD = float(os.getenv("HARMONY_ADJUST_THRESHOLD", 9.0))
HARMONY_ADJUST_GAMMA = float(os.getenv("HARMONY_ADJUST_GAMMA", 0.3))
# Startup tuning: whether heavy components are initialised / warmed up
# automatically at startup. Only the analyzer defaults to on.
ENABLE_WARMUP = os.getenv("ENABLE_WARMUP", "false").lower() in ("1", "true", "on")
AUTO_INIT_ANALYZER = os.getenv("AUTO_INIT_ANALYZER", "true").lower() in ("1", "true", "on")
AUTO_INIT_GFPGAN = os.getenv("AUTO_INIT_GFPGAN", "false").lower() in ("1", "true", "on")
AUTO_INIT_DDCOLOR = os.getenv("AUTO_INIT_DDCOLOR", "false").lower() in ("1", "true", "on")
AUTO_INIT_REALESRGAN = os.getenv("AUTO_INIT_REALESRGAN", "false").lower() in ("1", "true", "on")
AUTO_INIT_REMBG = os.getenv("AUTO_INIT_REMBG", "false").lower() in ("1", "true", "on")
AUTO_INIT_ANIME_STYLE = os.getenv("AUTO_INIT_ANIME_STYLE", "false").lower() in ("1", "true", "on")
AUTO_INIT_RVM = os.getenv("AUTO_INIT_RVM", "false").lower() in ("1", "true", "on")

# Scheduled-task settings, both in hours with a default of 1 hour:
# how often the cleanup task runs, and the file-age threshold for cleanup.
CLEANUP_INTERVAL_HOURS = float(os.getenv("CLEANUP_INTERVAL_HOURS", 1.0))
CLEANUP_AGE_HOURS = float(os.getenv("CLEANUP_AGE_HOURS", 1.0))
# BOS auto-sync manifest: maps BOS prefixes to local directories so that the
# structure can be iterated at startup to perform batch downloads.
# The list is currently empty; the commented-out entries below show the entry
# shape (description, bos_prefix, destination, optional background flag).
BOS_DOWNLOAD_TARGETS = [
    # {
    #     "description": "Celebrity gallery dataset",
    #     "bos_prefix": BOS_CELEBRITY_PREFIX,
    #     "destination": CELEBRITY_DATASET_DIR,
    #     "background": True,
    # },
    # {
    #     "description": "AI model weights",
    #     "bos_prefix": BOS_MODELS_PREFIX,
    #     "destination": MODELS_DOWNLOAD_DIR,
    # },
]
# Log level: taken from LOG_LEVEL (upper-cased); names that are not
# attributes of the logging module fall back to INFO.
log_level_str = os.environ.get("LOG_LEVEL", "INFO").upper()
log_level = getattr(logging, log_level_str, logging.INFO)

# Logging master switch: controls whether log output is enabled at all.
ENABLE_LOGGING = os.getenv("ENABLE_LOGGING", "true").lower() in ("1", "true", "on")

# Feature switches.
ENABLE_DDCOLOR = os.getenv("ENABLE_DDCOLOR", "true").lower() in ("1", "true", "on")
ENABLE_REALESRGAN = os.getenv("ENABLE_REALESRGAN", "true").lower() in ("1", "true", "on")
ENABLE_GFPGAN = os.getenv("ENABLE_GFPGAN", "true").lower() in ("1", "true", "on")
ENABLE_ANIME_STYLE = os.getenv("ENABLE_ANIME_STYLE", "true").lower() in ("1", "true", "on")
ENABLE_ANIME_PRELOAD = os.getenv("ENABLE_ANIME_PRELOAD", "false").lower() in ("1", "true", "on")
ENABLE_RVM = os.getenv("ENABLE_RVM", "true").lower() in ("1", "true", "on")

# Face-score module: maximum number of images per upload.
FACE_SCORE_MAX_IMAGES = int(os.getenv("FACE_SCORE_MAX_IMAGES", 10))

# Female age adjustment, per the original notes: for women above the
# threshold age (default 20), the displayed age is reduced by the adjustment
# (default 3 years).
FEMALE_AGE_ADJUSTMENT = int(os.getenv("FEMALE_AGE_ADJUSTMENT", 3))
FEMALE_AGE_ADJUSTMENT_THRESHOLD = int(os.getenv("FEMALE_AGE_ADJUSTMENT_THRESHOLD", 20))
# Logging setup. The module logger is created up front; the root logger is
# then configured either with the selected level and format, or with a level
# above CRITICAL (and the module logger disabled) when logging is turned off.
logger = logging.getLogger(__name__)
if not ENABLE_LOGGING:
    logging.basicConfig(level=logging.CRITICAL + 10)
    logger.disabled = True
else:
    logging.basicConfig(
        level=log_level,
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
# Module-level cache for the access token and its expiry timestamp.
access_token_cache = dict(token=None, expires_at=0)
# Probe optional dependencies; each *_AVAILABLE flag records whether the
# corresponding import succeeded.
try:
    from deepface import DeepFace
except ImportError:
    DEEPFACE_AVAILABLE = False
    print("Warning: DeepFace not installed. Install with: pip install deepface")
else:
    DEEPFACE_AVAILABLE = True

try:
    import mediapipe as mp
except ImportError:
    MEDIAPIPE_AVAILABLE = False
    print("Warning: mediapipe not installed. Install with: pip install mediapipe")
else:
    MEDIAPIPE_AVAILABLE = True

# Backward compatibility: keep the DLIB_AVAILABLE name as an alias.
DLIB_AVAILABLE = MEDIAPIPE_AVAILABLE

try:
    from ultralytics import YOLO
except ImportError:
    YOLO_AVAILABLE = False
    print("Warning: ultralytics not installed. Install with: pip install ultralytics")
else:
    YOLO_AVAILABLE = True
# GFPGAN availability check. The feature counts as available as soon as the
# restorer module imports; missing weight files are only logged, because they
# may still arrive through a later HuggingFace/BOS sync.
if not ENABLE_GFPGAN:
    GFPGAN_AVAILABLE = False
    logger.info("GFPGAN photo restoration feature is disabled (via ENABLE_GFPGAN environment variable)")
else:
    try:
        _gfpgan_weights_dir = os.path.join(MODELS_PATH, "gfpgan/weights")
        required_files = [
            os.path.join(os.path.dirname(__file__), "gfpgan_restorer.py"),
            os.path.join(_gfpgan_weights_dir, "detection_Resnet50_Final.pth"),
            os.path.join(_gfpgan_weights_dir, "parsing_parsenet.pth"),
        ]
        missing_files = list(filter(lambda candidate: not os.path.exists(candidate), required_files))
        for file_path in missing_files:
            logger.info("GFPGAN 所需文件暂未找到,将等待模型同步: %s", file_path)
        from gfpgan_restorer import GFPGANRestorer  # noqa: F401
    except ImportError as e:
        print(f"Warning: GFPGAN enabled but not available: {e}")
        GFPGAN_AVAILABLE = False
        logger.warning(f"GFPGAN photo restoration feature is enabled but import failed: {e}")
    else:
        GFPGAN_AVAILABLE = True
        if not missing_files:
            logger.info("GFPGAN photo restoration feature prerequisites detected")
        else:
            logger.warning(
                "GFPGAN 文件尚未全部就绪,将在 HuggingFace/BOS 同步完成后继续初始化: %s",
                ", ".join(missing_files),
            )
# DDColor availability check: available only when enabled and importable.
if not ENABLE_DDCOLOR:
    DDCOLOR_AVAILABLE = False
    logger.info("DDColor feature is disabled (via ENABLE_DDCOLOR environment variable)")
else:
    try:
        from ddcolor_colorizer import DDColorColorizer
    except ImportError as e:
        print(f"Warning: DDColor enabled but not available: {e}")
        DDCOLOR_AVAILABLE = False
        logger.warning(f"DDColor feature is enabled but import failed: {e}")
    else:
        DDCOLOR_AVAILABLE = True
        logger.info("DDColor feature is enabled and available")
# Only the GFPGAN restorer is used, so the simple restorer is flagged as
# unavailable unconditionally.
SIMPLE_RESTORER_AVAILABLE = False
# Real-ESRGAN availability check: available only when enabled and importable.
if not ENABLE_REALESRGAN:
    REALESRGAN_AVAILABLE = False
    logger.info("Real-ESRGAN super resolution feature is disabled (via ENABLE_REALESRGAN environment variable)")
else:
    try:
        from realesrgan_upscaler import RealESRGANUpscaler
    except ImportError as e:
        print(f"Warning: Real-ESRGAN enabled but not available: {e}")
        REALESRGAN_AVAILABLE = False
        logger.warning(f"Real-ESRGAN super resolution feature is enabled but import failed: {e}")
    else:
        REALESRGAN_AVAILABLE = True
        logger.info("Real-ESRGAN super resolution feature is enabled and available")
# rembg feature switch.
ENABLE_REMBG = os.getenv("ENABLE_REMBG", "true").lower() in ("1", "true", "on")

# rembg availability check: available only when enabled and importable.
if not ENABLE_REMBG:
    REMBG_AVAILABLE = False
    logger.info("rembg background removal feature is disabled (via ENABLE_REMBG environment variable)")
else:
    try:
        import rembg
        from rembg import new_session
    except ImportError as e:
        print(f"Warning: rembg enabled but not available: {e}")
        REMBG_AVAILABLE = False
        logger.warning(f"rembg background removal feature is enabled but import failed: {e}")
    else:
        REMBG_AVAILABLE = True
        logger.info("rembg background removal feature is enabled and available")
        logger.info(f"rembg model storage path: {REMBG_MODEL_DIR}")
# CLIP availability flag; set to False unconditionally in this module.
CLIP_AVAILABLE = False
# Anime stylization availability check: available only when enabled and
# importable.
if not ENABLE_ANIME_STYLE:
    ANIME_STYLE_AVAILABLE = False
    logger.info("Anime stylization feature is disabled (via ENABLE_ANIME_STYLE environment variable)")
else:
    try:
        from anime_stylizer import AnimeStylizer
    except ImportError as e:
        print(f"Warning: Anime Style enabled but not available: {e}")
        ANIME_STYLE_AVAILABLE = False
        logger.warning(f"Anime stylization feature is enabled but import failed: {e}")
    else:
        ANIME_STYLE_AVAILABLE = True
        logger.info("Anime stylization feature is enabled and available")
# RVM feature switch.
ENABLE_RVM = os.getenv("ENABLE_RVM", "true").lower() in ("1", "true", "on")

# RVM availability check. Only the torch import is verified here; the RVM
# model itself is not loaded at this point.
if not ENABLE_RVM:
    RVM_AVAILABLE = False
    logger.info("RVM background removal feature is disabled (via ENABLE_RVM environment variable)")
else:
    try:
        import torch
    except ImportError as e:
        print(f"Warning: RVM enabled but not available: {e}")
        RVM_AVAILABLE = False
        logger.warning(f"RVM background removal feature is enabled but import failed: {e}")
    else:
        RVM_AVAILABLE = True
        logger.info("RVM background removal feature is enabled and available")