import torch
import os
from collections import OrderedDict


def prune_and_convert_model():
    """
    Interactive script: load a PyTorch checkpoint, strip the trainer state and
    text_encoder weights, and save the result in a user-selected target
    precision (fp32, fp16, or bf16).
    """
    # --- Interactive input ---
    source_path = input("Enter the full path to the source model file (e.g. D:\\models\\SDMatte_plus.pth):\n> ")
    if not os.path.exists(source_path):
        print(f"\n[ERROR] File not found: {source_path}")
        return

    while True:
        target_precision = input("\nSelect the target precision (enter fp32, fp16, or bf16):\n> ").lower()
        if target_precision in ['fp32', 'fp16', 'bf16']:
            break
        print("[ERROR] Invalid input; please enter one of 'fp32', 'fp16', or 'bf16'.")

    # --- Core logic ---
    # Derive the output file name automatically from the source path
    base, ext = os.path.splitext(source_path)
    output_path = f"{base}_pruned_{target_precision}{ext}"

    print(f"\n--- Processing checkpoint: {os.path.basename(source_path)} ---")
    print(f"Target precision: {target_precision.upper()}")
    print(f"Output will be saved as: {os.path.basename(output_path)}")

    try:
        print("\nLoading the original weight file...")
        full_checkpoint = torch.load(source_path, map_location="cpu", weights_only=False)

        state_dict = None
        if isinstance(full_checkpoint, dict):
            if 'model' in full_checkpoint:
                state_dict = full_checkpoint['model']
                print("Training checkpoint detected; extracted weights from the 'model' key.")
            elif 'state_dict' in full_checkpoint:
                state_dict = full_checkpoint['state_dict']
                print("Training checkpoint detected; extracted weights from the 'state_dict' key.")
            else:
                state_dict = full_checkpoint
                print("The file appears to be a plain weight dict; processing it directly.")
        else:
            print("[ERROR] Unrecognized file format.")
            return

        pruned_state_dict = OrderedDict()
        removed_count = 0
        kept_count = 0

        print("Removing 'text_encoder.' weights...")
        for key, value in state_dict.items():
            if not key.startswith("text_encoder."):
                pruned_state_dict[key] = value
                kept_count += 1
            else:
                removed_count += 1
        print(f"Removed {removed_count} text_encoder weights; kept {kept_count}.")

        # Convert precision according to the user's choice
        if target_precision != 'fp32':
            print(f"Converting the retained weights to {target_precision} ...")
            target_dtype = torch.float16 if target_precision == 'fp16' else torch.bfloat16
            for key, value in pruned_state_dict.items():
                # Only cast floating-point tensors; integer buffers etc. stay untouched
                if isinstance(value, torch.Tensor) and value.is_floating_point():
                    pruned_state_dict[key] = value.to(target_dtype)
        else:
            print("Keeping the original FP32 precision.")

        print(f"Saving the processed model to: {output_path}")
        torch.save(pruned_state_dict, output_path)

        source_size = os.path.getsize(source_path) / (1024 ** 2)
        pruned_size = os.path.getsize(output_path) / (1024 ** 2)
        print("\n--- Success ---")
        print(f"Original file size: {source_size:.2f} MB")
        print(f"Pruned file size ({target_precision.upper()}): {pruned_size:.2f} MB")
        print(f"Space saved: {source_size - pruned_size:.2f} MB")

    except Exception as e:
        print(f"\n[ERROR] An error occurred during processing: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    prune_and_convert_model()
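
# --- Usage sketch (illustrative only, not part of the script's execution) ---
# A pruned file produced above can be loaded back into a model roughly like
# this; `SDMatteModel` is a hypothetical placeholder for the real architecture,
# and the file name follows the naming rule applied by this script:
#
#   import torch
#   model = SDMatteModel()
#   weights = torch.load("SDMatte_plus_pruned_fp16.pth", map_location="cpu")
#   missing, unexpected = model.load_state_dict(weights, strict=False)
#   # strict=False tolerates the intentionally removed text_encoder.* keys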