# Provenance: uploaded to Hugging Face by lsmpp via the upload-large-folder tool (commit d477207, verified).
#!/usr/bin/env python3
"""
脚本:从整体的Stable Diffusion XL权重文件中提取VAE和UNet组件并单独保存
功能:
1. 使用StableDiffusionXLPipeline.from_single_file加载完整模型
2. 提取VAE和UNet组件
3. 将它们保存为独立的safetensors文件
作者:Assistant
日期:2025年7月16日
"""
import argparse
import json
import os
import sys
from pathlib import Path

import torch
from safetensors.torch import save_file
# 添加diffusers到Python路径
script_dir = Path(__file__).parent.absolute()
diffusers_src = script_dir.parent / "diffusers" / "src"
sys.path.insert(0, str(diffusers_src))
from diffusers import StableDiffusionXLPipeline
def extract_and_save_components(model_path: str, output_dir: str = None):
    """
    Extract the VAE and UNet from a monolithic SDXL checkpoint and save
    each as a standalone safetensors file plus a JSON config.

    Args:
        model_path (str): path to the input safetensors checkpoint.
        output_dir (str): output directory; defaults to the checkpoint's
            own directory. Created if it does not exist.

    Returns:
        bool: True on success, False if loading or extraction failed.

    Raises:
        FileNotFoundError: if model_path does not exist.
    """
    model_path = Path(model_path)
    if not model_path.exists():
        raise FileNotFoundError(f"模型文件不存在: {model_path}")

    output_dir = model_path.parent if output_dir is None else Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"正在加载模型: {model_path}")
    print("这可能需要一些时间...")
    try:
        # float16 halves the memory footprint of the (multi-GB) load.
        pipeline = StableDiffusionXLPipeline.from_single_file(
            str(model_path),
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        print("✓ 模型加载成功!")
    except Exception as e:
        print(f"✗ 模型加载失败: {e}")
        return False

    # VAE and UNet are saved identically; drive the shared helper.
    for display_name, file_tag, component in (
        ("VAE", "vae", pipeline.vae),
        ("UNet", "unet", pipeline.unet),
    ):
        print(f"\n正在提取{display_name}组件...")
        try:
            _save_component(component, display_name, file_tag,
                            model_path.stem, output_dir)
        except Exception as e:
            print(f"✗ {display_name}提取失败: {e}")
            return False

    # Drop the pipeline before reporting; release cached GPU memory if any.
    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"\n🎉 提取完成! 文件已保存到: {output_dir}")
    print("\n生成的文件:")
    for suffix in ("vae.safetensors", "vae_config.json",
                   "unet.safetensors", "unet_config.json"):
        print(f" - {model_path.stem}_{suffix}")
    return True


def _save_component(component, display_name: str, file_tag: str,
                    stem: str, output_dir: Path) -> None:
    """Save one pipeline module's weights (safetensors) and config (JSON).

    Raises on any failure; the caller reports the error and aborts.
    """
    # Move tensors to CPU (dtype unchanged) so save_file never touches CUDA memory.
    state_dict_cpu = {k: v.cpu() for k, v in component.state_dict().items()}
    weights_path = output_dir / f"{stem}_{file_tag}.safetensors"
    save_file(state_dict_cpu, str(weights_path))
    print(f"✓ {display_name}已保存到: {weights_path}")

    # component.config is a diffusers FrozenDict; copy to a plain dict for JSON.
    config_path = output_dir / f"{stem}_{file_tag}_config.json"
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(dict(component.config), f, indent=2, ensure_ascii=False)
    print(f"✓ {display_name}配置已保存到: {config_path}")
def print_model_info(model_path: str):
    """Print basic information (path, size in GB, suffix) for a model file.

    Prints a not-found message and returns early when the path is missing.
    """
    path = Path(model_path)
    if not path.exists():
        print(f"模型文件不存在: {path}")
        return

    size_gb = path.stat().st_size / (1024 ** 3)
    info_lines = (
        "模型文件信息:",
        f" 路径: {path}",
        f" 大小: {size_gb:.2f} GB",
        f" 格式: {path.suffix}",
    )
    for line in info_lines:
        print(line)
def main():
    """CLI entry point: parse arguments, show model info, then extract components.

    Exits with status 1 when extraction fails.
    """
    parser = argparse.ArgumentParser(
        description="从Stable Diffusion XL模型文件中提取VAE和UNet组件",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例用法:
python extract_vae_unet.py ../models/waiNSFWIllustrious_v140.safetensors
python extract_vae_unet.py ../models/waiNSFWIllustrious_v140.safetensors --output-dir ./extracted_components
python extract_vae_unet.py ../models/waiNSFWIllustrious_v140.safetensors --info-only
"""
    )
    parser.add_argument("model_path",
                        help="输入的safetensors模型文件路径")
    parser.add_argument("--output-dir", "-o",
                        help="输出目录 (默认为模型文件所在目录)")
    parser.add_argument("--info-only", action="store_true",
                        help="仅显示模型信息,不进行提取")
    args = parser.parse_args()

    # Always show the basic file info first.
    print_model_info(args.model_path)
    if args.info_only:
        return

    banner = "=" * 60
    print("\n" + banner)
    print("开始提取VAE和UNet组件...")
    print(banner)

    if extract_and_save_components(args.model_path, args.output_dir):
        print("\n✅ 所有组件提取成功!")
    else:
        print("\n❌ 提取过程中出现错误!")
        sys.exit(1)
# Run the CLI only when executed as a script (not when imported as a module).
if __name__ == "__main__":
    main()