"""Convert a single safetensors checkpoint into the HuggingFace Diffusers format.

Usage:
    python convert_single.py --ckpt epoch-4.safetensors --model_type Wan-T2V-14B --output_path ./output
"""

import argparse

import torch
from safetensors.torch import load_file
from accelerate import init_empty_weights

# Imported from the original conversion script (alternatively, copy the
# relevant dictionaries and functions directly into this file).
from convert_wan import (
    get_transformer_config,
    update_state_dict_,
    DTYPE_MAPPING,
)
from diffusers import WanTransformer3DModel, WanVACETransformer3DModel, WanAnimateTransformer3DModel


def _apply_renames(key: str, rename_dict: dict) -> str:
    """Return *key* after applying every substring replacement in *rename_dict*, in order."""
    renamed = key
    for old_fragment, new_fragment in rename_dict.items():
        renamed = renamed.replace(old_fragment, new_fragment)
    return renamed


def convert_single_checkpoint(ckpt_path: str, model_type: str, dtype: str = "bf16"):
    """Convert one checkpoint file into a Diffusers-format transformer.

    Args:
        ckpt_path: Path to the safetensors file.
        model_type: Model type, e.g. "Wan-T2V-14B" or "Wan-I2V-14B-720p".
            It is passed to ``get_transformer_config`` and also decides
            which transformer class is instantiated.
        dtype: Output precision, used as a key into ``DTYPE_MAPPING``.
            The string "none" skips the final cast.

    Returns:
        The converted transformer model.
    """
    # Step 1: look up the model config and the key-renaming rules.
    config, rename_dict, special_keys_remap = get_transformer_config(model_type)
    diffusers_config = config["diffusers_config"]

    # Step 2: read the original weights.
    state_dict = load_file(ckpt_path)

    # Step 3: rename keys. Iterate over a snapshot of the keys because the
    # dict is modified while we walk it.
    # NOTE(review): update_state_dict_ presumably moves the entry from the old
    # key to the new key in place -- confirm against convert_wan.
    for original_key in list(state_dict.keys()):
        update_state_dict_(state_dict, original_key, _apply_renames(original_key, rename_dict))

    # Step 4: run the special-case handlers on every key that contains one of
    # the registered substrings. Handlers receive the key and the whole dict.
    for current_key in list(state_dict.keys()):
        for marker, handler_fn in special_keys_remap.items():
            if marker not in current_key:
                continue
            handler_fn(current_key, state_dict)

    # Step 5: build the model and load the weights.
    if "Animate" in model_type:
        model_cls = WanAnimateTransformer3DModel
    elif "VACE" in model_type:
        model_cls = WanVACETransformer3DModel
    else:
        model_cls = WanTransformer3DModel

    with init_empty_weights():
        transformer = model_cls.from_config(diffusers_config)

    # strict=True requires the converted keys to match the model exactly;
    # assign=True makes the module adopt the loaded tensors.
    transformer.load_state_dict(state_dict, strict=True, assign=True)

    if dtype == "none":
        return transformer
    return transformer.to(DTYPE_MAPPING[dtype])


def main() -> None:
    """Parse the command line, convert the checkpoint and save it to disk."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True, help="safetensors文件路径")
    parser.add_argument("--model_type", type=str, required=True, help="模型类型")
    parser.add_argument("--output_path", type=str, required=True, help="输出目录")
    parser.add_argument("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16", "none"])
    args = parser.parse_args()

    transformer = convert_single_checkpoint(args.ckpt, args.model_type, args.dtype)
    transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
    print(f"Saved to {args.output_path}")


if __name__ == "__main__":
    main()