picpocket / anime_stylizer.py
chawin.chen
init
cd5aabe
raw
history blame
18.4 kB
import os
import tempfile
import time
import cv2
from config import logger
class AnimeStylizer:
    """Anime-style portrait stylization backed by ModelScope pipelines.

    Model configurations are registered at construction time when
    ``config.ENABLE_ANIME_STYLE`` is truthy. The pipelines themselves are
    loaded lazily on first use, or eagerly when ``config.ENABLE_ANIME_PRELOAD``
    is truthy. Every failure path of :meth:`stylize_image` returns the input
    image object unchanged (errors are logged, not raised).
    """

    def __init__(self):
        start_time = time.perf_counter()
        self.stylizers = {}  # style key -> loaded ModelScope pipeline
        # Not read anywhere in this class; kept for external callers.
        self.current_style = None
        self.current_stylizer = None
        # Define these up front so every public method works even when the
        # feature is disabled or model initialization fails.
        self.model_configs = {}
        self.OutputKeys = None
        # Check whether the Anime Style feature is enabled.
        from config import ENABLE_ANIME_STYLE
        if ENABLE_ANIME_STYLE:
            self._initialize_models()
        else:
            logger.info("Anime Style feature is disabled, skipping model initialization")
        init_time = time.perf_counter() - start_time
        if self.is_available():
            logger.info(f"AnimeStylizer initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"AnimeStylizer initialization completed but not available, time: {init_time:.3f}s")

    def _initialize_models(self):
        """Register all Anime Style model configurations (ModelScope).

        Only the configuration table is built here. Pipelines are loaded on
        demand by ``_load_model`` unless ``config.ENABLE_ANIME_PRELOAD`` is
        enabled. On any failure ``self.model_configs`` is reset to ``{}``,
        which makes ``is_available()`` return False.
        """
        try:
            logger.info("Initializing multiple Anime Style models (using ModelScope)...")
            # Compatibility patch: define unsigned torch dtypes missing from the
            # installed torch by aliasing them to the same-width signed types
            # (presumably referenced by ModelScope or a dependency - TODO confirm).
            import torch
            if not hasattr(torch, 'uint64'):
                logger.info("Adding torch.uint64 compatibility patch...")
                torch.uint64 = torch.int64
            if not hasattr(torch, 'uint32'):
                logger.info("Adding torch.uint32 compatibility patch...")
                torch.uint32 = torch.int32
            if not hasattr(torch, 'uint16'):
                logger.info("Adding torch.uint16 compatibility patch...")
                torch.uint16 = torch.int16
            # Import ModelScope. pipeline/Tasks are unused here, but importing
            # them makes an unusable ModelScope install fail early (ImportError
            # below) instead of on first use.
            from modelscope.outputs import OutputKeys
            from modelscope.pipelines import pipeline  # noqa: F401
            from modelscope.utils.constant import Tasks  # noqa: F401
            self.OutputKeys = OutputKeys
            # All available models and styles.
            self.model_configs = {
                "handdrawn": {
                    "model_id": "iic/cv_unet_person-image-cartoon-handdrawn_compound-models",
                    "name": "手绘风格",
                    "description": "手绘动漫风格 - 传统手绘感觉,线条清晰"
                },
                "disney": {
                    "model_id": "iic/cv_unet_person-image-cartoon-3d_compound-models",
                    "name": "迪士尼风格",
                    "description": "迪士尼风格 - 立体感强,色彩鲜艳"
                },
                "illustration": {
                    "model_id": "iic/cv_unet_person-image-cartoon-sd-design_compound-models",
                    "name": "插画风格",
                    "description": "插画风格 - 现代插画设计感"
                },
                "artstyle": {
                    "model_id": "iic/cv_unet_person-image-cartoon-artstyle_compound-models",
                    "name": "艺术风格",
                    "description": "艺术风格 - 独特的艺术表现力"
                },
                "anime": {
                    "model_id": "iic/cv_unet_person-image-cartoon_compound-models",
                    "name": "二次元风格",
                    "description": "二次元风格 - 经典动漫角色风格"
                },
                "sketch": {
                    "model_id": "iic/cv_unet_person-image-cartoon-sketch_compound-models",
                    "name": "素描风格",
                    "description": "素描风格 - 黑白素描画效果"
                }
            }
            logger.info(f"Defined {len(self.model_configs)} anime style model configurations")
            logger.info("Models will be loaded on-demand when first used to save memory")
            # Optional eager preloading, controlled by config.
            try:
                from config import ENABLE_ANIME_PRELOAD
                if ENABLE_ANIME_PRELOAD:
                    logger.info("Enabling anime style model preloading...")
                    self.preload_models()
                else:
                    logger.info("Anime style model preloading is disabled, will be loaded on-demand when first used")
            except ImportError:
                logger.info("Anime style model preloading configuration not found, will be loaded on-demand when first used")
        except ImportError as e:
            logger.error(f"ModelScope module import failed: {e}")
            self.model_configs = {}
        except Exception as e:
            logger.error(f"Anime Style model initialization failed: {e}")
            self.model_configs = {}

    def _load_model(self, style_type):
        """Load the pipeline for ``style_type`` on demand.

        :param style_type: key into ``self.model_configs``
        :return: True if the model is (now) loaded; False if the style is
            unknown or loading failed (the error is logged, not raised).
        """
        if style_type not in self.model_configs:
            logger.error(f"Unsupported style type: {style_type}")
            return False
        if style_type in self.stylizers:
            logger.info(f"Model {style_type} already loaded, using directly")
            return True
        try:
            from modelscope.pipelines import pipeline
            from modelscope.utils.constant import Tasks
            config = self.model_configs[style_type]
            logger.info(f"Loading {config['name']} model: {config['model_id']}")
            # Choose the task type from the model id.
            if "stable_diffusion" in config["model_id"]:
                # Stable Diffusion models are loaded as text-to-image pipelines.
                task_type = Tasks.text_to_image_synthesis
                logger.info("Using text_to_image_synthesis task type to load Stable Diffusion model")
            else:
                # UNet cartoon models use the portrait stylization task.
                task_type = Tasks.image_portrait_stylization
                logger.info("Using image_portrait_stylization task type to load UNet model")
            self.stylizers[style_type] = pipeline(task_type, model=config["model_id"])
            logger.info(f"{config['name']} model loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to load {style_type} model: {e}")
            return False

    def preload_models(self, style_types=None):
        """Eagerly load anime style models.

        :param style_types: a single style key, an iterable of style keys, or
            None to preload every configured style. Unknown keys are logged
            and counted as failures.
        """
        if not self.is_available():
            logger.warning("Anime Style module is not available, cannot preload models")
            return
        if style_types is None:
            style_types = list(self.model_configs.keys())
        elif isinstance(style_types, str):
            style_types = [style_types]
        else:
            # Materialize any iterable (tuple, set, generator) so it can be
            # iterated and counted below.
            style_types = list(style_types)
        logger.info(f"Starting to preload anime style models: {style_types}")
        successful_loads = []
        failed_loads = []
        for style_type in style_types:
            if style_type not in self.model_configs:
                logger.warning(f"Unknown style type: {style_type}, skipping preload")
                failed_loads.append(style_type)
                continue
            try:
                logger.info(f"Preloading model: {self.model_configs[style_type]['name']} ({style_type})")
                if self._load_model(style_type):
                    successful_loads.append(style_type)
                    logger.info(f"✓ Successfully preloaded: {self.model_configs[style_type]['name']}")
                else:
                    failed_loads.append(style_type)
                    logger.error(f"✗ Preload failed: {self.model_configs[style_type]['name']}")
            except Exception as e:
                logger.error(f"✗ Exception occurred while preloading model {style_type}: {e}")
                failed_loads.append(style_type)
        if successful_loads:
            logger.info(f"Successfully preloaded models ({len(successful_loads)}): {successful_loads}")
        if failed_loads:
            logger.warning(f"Failed to preload models ({len(failed_loads)}): {failed_loads}")
        logger.info(f"Anime style model preloading completed, success: {len(successful_loads)}/{len(style_types)}")

    def get_loaded_models(self):
        """Return the list of style keys whose pipelines are loaded."""
        return list(self.stylizers.keys())

    def is_model_loaded(self, style_type):
        """Return True if the pipeline for ``style_type`` is loaded."""
        return style_type in self.stylizers

    def get_preload_status(self):
        """Summarize how many configured models are loaded.

        :return: dict with total/loaded counts, a "loaded/total" ratio string,
            a percentage rounded to one decimal, and the available, loaded and
            unloaded style keys.
        """
        total_models = len(self.model_configs)
        loaded_models = len(self.stylizers)
        status = {
            "total_models": total_models,
            "loaded_models": loaded_models,
            "preload_ratio": f"{loaded_models}/{total_models}",
            "preload_percentage": round((loaded_models / total_models * 100) if total_models > 0 else 0, 1),
            "available_styles": list(self.model_configs.keys()),
            "loaded_styles": list(self.stylizers.keys()),
            "unloaded_styles": [style for style in self.model_configs.keys() if style not in self.stylizers]
        }
        return status

    def is_available(self):
        """Return True if at least one model configuration is registered."""
        return hasattr(self, 'model_configs') and len(self.model_configs) > 0

    def stylize_image(self, image, style_type="disney"):
        """Apply an anime style to an image.

        :param image: input image (numpy array, BGR)
        :param style_type: style key; one of the keys of ``self.model_configs``:
            "handdrawn"    - hand-drawn style
            "disney"       - Disney / 3D style (default)
            "illustration" - illustration style
            "artstyle"     - art style
            "anime"        - anime (2D) style
            "sketch"       - sketch style
        :return: stylized image (numpy array, BGR). On any failure (module not
            available, unknown style, model load or inference error) the input
            ``image`` object itself is returned.
        """
        if not self.is_available():
            logger.error("Anime Style model not initialized")
            return image
        # Load (or reuse) the model for the requested style.
        if not self._load_model(style_type):
            logger.error(f"Failed to load {style_type} model")
            return image
        return self._stylize_image_via_file(image, style_type)

    def _stylize_image_via_file(self, image, style_type="disney"):
        """Stylize ``image`` by round-tripping it through a temporary file.

        The UNet pipelines are called with a file path, so the image is first
        written to a temporary WebP file that is always removed afterwards.

        :param image: input image (numpy array, BGR)
        :param style_type: style key; unknown keys fall back to "disney"
        :return: stylized image, or the input image unchanged if anything
            fails (the error is logged).
        """
        try:
            # Validate before looking up the config so the config always
            # belongs to the style actually used.
            if style_type not in self.model_configs:
                logger.warning(f"Invalid style type: {style_type}, using default style disney")
                style_type = "disney"
            config = self.model_configs[style_type]
            style_name = config.get('name', style_type)
            logger.info(f"Using anime stylization processing, style type: {style_name} ({style_type})")
            is_stable_diffusion = "stable_diffusion" in config["model_id"]
            # The (fallback) style may not have been loaded yet.
            if style_type not in self.stylizers and not self._load_model(style_type):
                raise RuntimeError(f"Failed to load {style_type} model")
            stylizer = self.stylizers[style_type]

            # mkstemp creates the file; close our descriptor so OpenCV can
            # reopen the path for writing.
            fd, tmp_input_path = tempfile.mkstemp(suffix='.webp')
            os.close(fd)
            try:
                # WebP quality 100 is OpenCV's highest lossy quality setting.
                # cv2.imwrite signals many failures via its return value.
                if not cv2.imwrite(tmp_input_path, image, [cv2.IMWRITE_WEBP_QUALITY, 100]):
                    raise IOError(f"Failed to write temporary image: {tmp_input_path}")
                logger.info(f"Temporary file saved to: {tmp_input_path}")
                if is_stable_diffusion:
                    logger.info("Using Stable Diffusion model, text parameter is required")
                    # Stable Diffusion models need an 'sks style' prompt.
                    prompt = "sks style, cartoon style artwork"
                    logger.info(f"Using prompt: {prompt}")
                    result = stylizer({"text": prompt})
                else:
                    # UNet models take the image path directly.
                    result = stylizer(tmp_input_path)
                stylized_image = self._extract_stylized_image(result, is_stable_diffusion)
                logger.info(f"Anime stylization output: size={stylized_image.shape}, type={stylized_image.dtype}")
                # ModelScope output is already BGR; no conversion needed.
                logger.info("Anime stylization processing completed")
                return stylized_image
            finally:
                # Best-effort cleanup of the temporary input file.
                try:
                    os.unlink(tmp_input_path)
                except OSError as e:
                    logger.warning(f"Failed to remove temporary file {tmp_input_path}: {e}")
        except Exception as e:
            logger.error(f"Anime stylization processing failed: {e}")
            logger.info("Returning original image")
            return image

    def _extract_stylized_image(self, result, is_stable_diffusion):
        """Pick the output image out of a pipeline result dict.

        UNet models use the standard ``OutputKeys.OUTPUT_IMG`` key. Stable
        Diffusion pipelines may use other keys, so a few known names are tried
        before falling back to the first non-empty sequence or array-like value.

        :raises KeyError: if no image-like output can be found.
        """
        if not is_stable_diffusion:
            return result[self.OutputKeys.OUTPUT_IMG]
        logger.info(f"Stable Diffusion model output keys: {list(result.keys())}")
        if 'output_imgs' in result:
            return result['output_imgs'][0]
        if 'output_img' in result:
            return result['output_img']
        if self.OutputKeys.OUTPUT_IMG in result:
            return result[self.OutputKeys.OUTPUT_IMG]
        for key, value in result.items():
            if isinstance(value, (list, tuple)) and len(value) > 0:
                logger.info(f"Using output key: {key}")
                return value[0]
            if hasattr(value, 'shape'):
                logger.info(f"Using output key: {key}")
                return value
        raise KeyError(f"未找到有效的图像输出键,可用键: {list(result.keys())}")

    def get_available_styles(self):
        """Return ``{style_key: "<name> - <details>"}`` for every configured style.

        ``<details>`` is the part of the description after its first " - "
        (up to the next one); if the description has no " - " separator, the
        whole description is used.
        """
        styles = {}
        for style_type, config in getattr(self, 'model_configs', {}).items():
            description = config['description']
            parts = description.split(' - ')
            details = parts[1] if len(parts) > 1 else description
            styles[style_type] = f"{config['name']} - {details}"
        return styles

    def save_debug_image(self, image, filename_prefix):
        """Save ``image`` as ``<filename_prefix>_debug.webp`` (quality 95).

        :return: the written path, or None if writing failed (logged).
        """
        try:
            debug_path = f"{filename_prefix}_debug.webp"
            # cv2.imwrite reports many failures via its return value, not an exception.
            if not cv2.imwrite(debug_path, image, [cv2.IMWRITE_WEBP_QUALITY, 95]):
                logger.error(f"Failed to save debug image: {debug_path}")
                return None
            logger.info(f"Debug image saved: {debug_path}")
            return debug_path
        except Exception as e:
            logger.error(f"Failed to save debug image: {e}")
            return None

    def test_stylization(self, test_url=None, style_type="disney"):
        """Smoke-test stylization on a (remote) image.

        :param test_url: image URL; defaults to the official ModelScope sample
        :param style_type: style to test (default "disney")
        :return: ``(success, message)``; the result is written to
            ``anime_style_test_result.webp``.
        """
        if not self.is_available():
            return False, "Anime Style模型未初始化"
        try:
            test_url = test_url or 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/portrait.jpg'
            logger.info(f"Testing anime stylization feature, using image: {test_url}")
            # Models are loaded lazily; make sure the tested style is loaded.
            if not self._load_model(style_type):
                return False, f"测试失败: 无法加载模型 {style_type}"
            result = self.stylizers[style_type](test_url)
            stylized_img = result[self.OutputKeys.OUTPUT_IMG]
            test_output_path = 'anime_style_test_result.webp'
            if not cv2.imwrite(test_output_path, stylized_img, [cv2.IMWRITE_WEBP_QUALITY, 95]):
                raise IOError(f"无法保存测试结果: {test_output_path}")
            logger.info(f"Anime stylization test successful, result saved to: {test_output_path}")
            return True, f"测试成功,结果保存到: {test_output_path}"
        except Exception as e:
            logger.error(f"Anime stylization test failed: {e}")
            return False, f"测试失败: {e}"

    def test_local_image(self, image_path, style_type="disney"):
        """Stylize a local image file and save original + result debug images.

        :param image_path: path of the local image
        :param style_type: style key
        :return: ``(success, message)``
        """
        if not self.is_available():
            return False, "Anime Style模型未初始化"
        try:
            logger.info(f"Testing local image anime stylization: {image_path}, style: {style_type}")
            image = cv2.imread(image_path)
            if image is None:
                return False, f"Unable to read image: {image_path}"
            # Keep the original for side-by-side comparison.
            self.save_debug_image(image, "original")
            stylized_image = self.stylize_image(image, style_type)
            # stylize_image() returns the very same input object on every
            # failure path, so identity means stylization did not happen.
            if stylized_image is image:
                return False, f"本地图像动漫风格化失败: 风格化未生效 ({style_type})"
            result_path = self.save_debug_image(stylized_image, f"anime_style_{style_type}")
            if result_path is None:
                return False, "本地图像动漫风格化失败: 无法保存结果"
            logger.info(f"Local image anime stylization successful, result saved to: {result_path}")
            return True, f"本地图像动漫风格化成功,结果保存到: {result_path}"
        except Exception as e:
            logger.error(f"Local image anime stylization failed: {e}")
            return False, f"本地图像动漫风格化失败: {e}"