SDMatte_plus-fp16-and-bf16 / prune_model.py
aha2023's picture
Upload 2 files
dd57c3c verified
import torch
import os
from collections import OrderedDict
def prune_and_convert_model():
"""
交互式脚本,用于加载PyTorch检查点,移除trainer状态和text_encoder权重,
并根据用户选择的目标精度(fp32, fp16, bf16)保存。
"""
# --- 交互式输入 ---
source_path = input("请输入源模型文件的完整路径(例如 D:\\models\\SDMatte_plus.pth):\n> ")
if not os.path.exists(source_path):
print(f"\n[错误] 文件未找到: {source_path}")
return
while True:
target_precision = input("\n请选择目标精度 (输入 fp32, fp16, 或 bf16):\n> ").lower()
if target_precision in ['fp32', 'fp16', 'bf16']:
break
else:
print("[错误] 无效输入,请输入 'fp32', 'fp16', 或 'bf16' 中的一个。")
# --- 核心逻辑 ---
# 自动生成输出文件名
base, ext = os.path.splitext(source_path)
output_path = f"{base}_pruned_{target_precision}{ext}"
print(f"\n--- 开始处理检查点: {os.path.basename(source_path)} ---")
print(f"目标精度: {target_precision.upper()}")
print(f"输出文件将保存为: {os.path.basename(output_path)}")
try:
print("\n正在加载原始权重文件...")
full_checkpoint = torch.load(source_path, map_location="cpu", weights_only=False)
state_dict = None
if isinstance(full_checkpoint, dict):
if 'model' in full_checkpoint:
state_dict = full_checkpoint['model']
print("检测到训练检查点,已从 'model' 键中提取权重。")
elif 'state_dict' in full_checkpoint:
state_dict = full_checkpoint['state_dict']
print("检测到训练检查点,已从 'state_dict' 键中提取权重。")
else:
state_dict = full_checkpoint
print("文件似乎是一个纯权重字典,将直接处理。")
else:
print("[错误] 文件格式无法识别。")
return
pruned_state_dict = OrderedDict()
removed_count = 0
kept_count = 0
print("正在移除 'text_encoder.' 相关权重...")
for key, value in state_dict.items():
if not key.startswith("text_encoder."):
pruned_state_dict[key] = value
kept_count += 1
else:
removed_count += 1
print(f"移除 {removed_count} 个 text_encoder 权重,保留 {kept_count} 个。")
# 根据用户选择,进行精度转换
if target_precision != 'fp32':
print(f"开始将保留的权重转换为 {target_precision} ...")
target_dtype = torch.float16 if target_precision == 'fp16' else torch.bfloat16
for key, value in pruned_state_dict.items():
if isinstance(value, torch.Tensor) and value.is_floating_point():
pruned_state_dict[key] = value.to(target_dtype)
else:
print("保留原始 FP32 精度。")
print(f"正在保存处理后的模型到: {output_path}")
torch.save(pruned_state_dict, output_path)
source_size = os.path.getsize(source_path) / (1024**2)
pruned_size = os.path.getsize(output_path) / (1024**2)
print("\n--- 成功 ---")
print(f"原始文件大小: {source_size:.2f} MB")
print(f"裁剪后文件大小 ({target_precision.upper()}): {pruned_size:.2f} MB")
print(f"总共节省空间: {source_size - pruned_size:.2f} MB")
except Exception as e:
print(f"\n[错误] 处理过程中发生错误: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
prune_and_convert_model()