| |
| """ |
| 脚本:从整体的Stable Diffusion XL权重文件中提取VAE和UNet组件并单独保存 |
| |
| 功能: |
| 1. 使用StableDiffusionXLPipeline.from_single_file加载完整模型 |
| 2. 提取VAE和UNet组件 |
3. 将它们的权重保存为独立的safetensors文件,并将各自的配置保存为JSON文件
| |
| 作者:Assistant |
| 日期:2025年7月16日 |
| """ |
|
|
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Optional, Tuple
|
|
| import torch |
| from safetensors.torch import save_file |
|
|
| |
# Put a sibling diffusers checkout (<script dir>/../diffusers/src) at the front
# of sys.path so it takes precedence over any installed copy of the package.
script_dir = Path(__file__).parent.absolute()
diffusers_src = script_dir.parent.joinpath("diffusers", "src")
sys.path.insert(0, f"{diffusers_src}")
|
|
| from diffusers import StableDiffusionXLPipeline |
|
|
|
|
def _save_component(component, display_name: str, output_dir: Path, stem: str) -> Tuple[Path, Path]:
    """Save one pipeline component's weights (safetensors) and config (JSON).

    Args:
        component: A pipeline sub-model such as ``pipeline.vae`` or
            ``pipeline.unet``; only its ``state_dict()`` and ``config`` are used.
        display_name: Name shown in progress messages ("VAE", "UNet"); its
            lower-cased form is used as the file-name suffix.
        output_dir: Directory the two files are written into.
        stem: Base file name (the input checkpoint's stem).

    Returns:
        ``(weights_path, config_path)`` of the two files written.
    """
    suffix = display_name.lower()

    # safetensors' save_file needs CPU tensors and rejects non-contiguous ones;
    # .cpu() / .contiguous() are no-ops for tensors that already qualify.
    state_dict = {
        key: tensor.cpu().contiguous()
        for key, tensor in component.state_dict().items()
    }
    weights_path = output_dir / f"{stem}_{suffix}.safetensors"
    save_file(state_dict, str(weights_path))
    print(f"✓ {display_name}已保存到: {weights_path}")

    config_path = output_dir / f"{stem}_{suffix}_config.json"
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(component.config, f, indent=2, ensure_ascii=False)
    print(f"✓ {display_name}配置已保存到: {config_path}")

    return weights_path, config_path


def extract_and_save_components(
    model_path: str,
    output_dir: Optional[str] = None,
    *,
    torch_dtype: torch.dtype = torch.float16,
) -> bool:
    """Extract the VAE and UNet from a single-file SDXL checkpoint and save them.

    For each component two files are written to *output_dir*:
    ``<stem>_<component>.safetensors`` (weights) and
    ``<stem>_<component>_config.json`` (config), where ``<stem>`` is the
    input file name without its extension.

    Args:
        model_path: Path to the input single-file checkpoint.
        output_dir: Output directory; defaults to the checkpoint's own
            directory. Created (with parents) if it does not exist.
        torch_dtype: dtype the pipeline is loaded in, and therefore the dtype
            of the saved weights. Defaults to ``torch.float16``.

    Returns:
        True on success; False if loading the pipeline or saving a component
        failed (the error is printed).

    Raises:
        FileNotFoundError: If *model_path* does not exist.
    """
    model_path = Path(model_path)
    if not model_path.exists():
        raise FileNotFoundError(f"模型文件不存在: {model_path}")

    output_dir = model_path.parent if output_dir is None else Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"正在加载模型: {model_path}")
    print("这可能需要一些时间...")

    try:
        pipeline = StableDiffusionXLPipeline.from_single_file(
            str(model_path),
            torch_dtype=torch_dtype,
            use_safetensors=True,
        )
    except Exception as e:
        # CLI boundary: report and signal failure to the caller via the return value.
        print(f"✗ 模型加载失败: {e}")
        return False
    print("✓ 模型加载成功!")

    written = []
    for display_name, attr in (("VAE", "vae"), ("UNet", "unet")):
        print(f"\n正在提取{display_name}组件...")
        try:
            written.extend(
                _save_component(getattr(pipeline, attr), display_name, output_dir, model_path.stem)
            )
        except Exception as e:
            print(f"✗ {display_name}提取失败: {e}")
            return False

    # Drop the pipeline reference before asking PyTorch to release cached GPU memory.
    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"\n🎉 提取完成! 文件已保存到: {output_dir}")
    print("\n生成的文件:")
    for path in written:
        print(f" - {path.name}")

    return True
|
|
|
|
def print_model_info(model_path: str) -> None:
    """Print basic information about a model file: path, size and suffix.

    If the file does not exist, a message is printed and the function returns
    without raising.

    Args:
        model_path: Path to the model file.
    """
    model_path = Path(model_path)
    if not model_path.exists():
        print(f"模型文件不存在: {model_path}")
        return

    # Size is computed in units of 1024**3 bytes.
    size_gb = model_path.stat().st_size / (1024 ** 3)

    print("模型文件信息:")
    print(f" 路径: {model_path}")
    print(f" 大小: {size_gb:.2f} GB")
    print(f" 格式: {model_path.suffix}")
|
|
|
|
def main():
    """Command-line entry point.

    Prints basic file info, then (unless ``--info-only``) extracts the VAE and
    UNet. Exits with status 1 if the model file is missing or extraction fails.
    """
    parser = argparse.ArgumentParser(
        description="从Stable Diffusion XL模型文件中提取VAE和UNet组件",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例用法:
python extract_vae_unet.py ../models/waiNSFWIllustrious_v140.safetensors
python extract_vae_unet.py ../models/waiNSFWIllustrious_v140.safetensors --output-dir ./extracted_components
python extract_vae_unet.py ../models/waiNSFWIllustrious_v140.safetensors --info-only
""",
    )

    parser.add_argument(
        "model_path",
        help="输入的safetensors模型文件路径",
    )

    parser.add_argument(
        "--output-dir", "-o",
        help="输出目录 (默认为模型文件所在目录)",
    )

    parser.add_argument(
        "--info-only",
        action="store_true",
        help="仅显示模型信息,不进行提取",
    )

    args = parser.parse_args()

    print_model_info(args.model_path)

    # print_model_info only reports a missing file. Stop here with a failure
    # status instead of exiting 0 (--info-only) or letting
    # extract_and_save_components raise a FileNotFoundError traceback.
    if not Path(args.model_path).exists():
        sys.exit(1)

    if args.info_only:
        return

    print("\n" + "=" * 60)
    print("开始提取VAE和UNet组件...")
    print("=" * 60)

    success = extract_and_save_components(args.model_path, args.output_dir)

    if success:
        print("\n✅ 所有组件提取成功!")
    else:
        print("\n❌ 提取过程中出现错误!")
        sys.exit(1)
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
|
|