File size: 3,936 Bytes
dd57c3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import os
from collections import OrderedDict

def _extract_state_dict(full_checkpoint):
    """Return the weight dict inside a loaded checkpoint, or None if unrecognized.

    Handles three layouts: a trainer checkpoint with a 'model' key, one with a
    'state_dict' key, or a bare weight dict. Prints (in Chinese, matching the
    tool's user-facing language) which layout was detected.
    """
    if not isinstance(full_checkpoint, dict):
        return None
    if 'model' in full_checkpoint:
        print("检测到训练检查点,已从 'model' 键中提取权重。")
        return full_checkpoint['model']
    if 'state_dict' in full_checkpoint:
        print("检测到训练检查点,已从 'state_dict' 键中提取权重。")
        return full_checkpoint['state_dict']
    print("文件似乎是一个纯权重字典,将直接处理。")
    return full_checkpoint


def _prune_text_encoder(state_dict):
    """Copy state_dict into a new OrderedDict, dropping 'text_encoder.*' keys.

    Returns (pruned_dict, kept_count, removed_count). Insertion order of the
    surviving keys is preserved.
    """
    print("正在移除 'text_encoder.' 相关权重...")
    pruned = OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        if not key.startswith("text_encoder.")
    )
    return pruned, len(pruned), len(state_dict) - len(pruned)


def _convert_precision(state_dict, target_precision):
    """Cast floating-point tensors in state_dict in place to fp16/bf16.

    No-op (beyond a status message) for 'fp32'. Non-tensor entries and
    integer tensors (e.g. buffers) are left untouched.
    """
    if target_precision == 'fp32':
        print("保留原始 FP32 精度。")
        return
    print(f"开始将保留的权重转换为 {target_precision} ...")
    target_dtype = torch.float16 if target_precision == 'fp16' else torch.bfloat16
    # Snapshot the items: the original iterated .items() while reassigning
    # values, which only works because no keys change — iterate a list copy
    # so the in-place update is unambiguously safe.
    for key, value in list(state_dict.items()):
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            state_dict[key] = value.to(target_dtype)


def prune_and_convert_model():
    """Interactive tool: load a PyTorch checkpoint, strip trainer state and
    text_encoder weights, and save at a user-chosen precision.

    Prompts for the source path and a target precision (fp32/fp16/bf16),
    then writes the pruned state dict next to the source file with a
    ``_pruned_<precision>`` filename suffix. Prints a size-savings summary.
    All user-facing messages are intentionally in Chinese.
    """
    # --- Interactive input ---
    source_path = input("请输入源模型文件的完整路径(例如 D:\\models\\SDMatte_plus.pth):\n> ")

    if not os.path.exists(source_path):
        print(f"\n[错误] 文件未找到: {source_path}")
        return

    while True:
        target_precision = input("\n请选择目标精度 (输入 fp32, fp16, 或 bf16):\n> ").lower()
        if target_precision in ('fp32', 'fp16', 'bf16'):
            break
        print("[错误] 无效输入,请输入 'fp32', 'fp16', 或 'bf16' 中的一个。")

    # --- Core logic ---
    # Derive the output filename from the source path and chosen precision.
    base, ext = os.path.splitext(source_path)
    output_path = f"{base}_pruned_{target_precision}{ext}"

    print(f"\n--- 开始处理检查点: {os.path.basename(source_path)} ---")
    print(f"目标精度: {target_precision.upper()}")
    print(f"输出文件将保存为: {os.path.basename(output_path)}")

    try:
        print("\n正在加载原始权重文件...")
        # SECURITY NOTE: weights_only=False runs the full pickle machinery and
        # can execute arbitrary code — only load checkpoints you trust.
        full_checkpoint = torch.load(source_path, map_location="cpu", weights_only=False)

        state_dict = _extract_state_dict(full_checkpoint)
        if state_dict is None:
            print("[错误] 文件格式无法识别。")
            return

        pruned_state_dict, kept_count, removed_count = _prune_text_encoder(state_dict)
        print(f"移除 {removed_count} 个 text_encoder 权重,保留 {kept_count} 个。")

        # Convert precision per the user's choice (in place; fp32 is a no-op).
        _convert_precision(pruned_state_dict, target_precision)

        print(f"正在保存处理后的模型到: {output_path}")
        torch.save(pruned_state_dict, output_path)

        source_size = os.path.getsize(source_path) / (1024**2)
        pruned_size = os.path.getsize(output_path) / (1024**2)

        print("\n--- 成功 ---")
        print(f"原始文件大小: {source_size:.2f} MB")
        print(f"裁剪后文件大小 ({target_precision.upper()}): {pruned_size:.2f} MB")
        print(f"总共节省空间: {source_size - pruned_size:.2f} MB")

    except Exception as e:
        # Top-level boundary: report and dump the traceback rather than crash.
        print(f"\n[错误] 处理过程中发生错误: {e}")
        import traceback
        traceback.print_exc()

# Run the interactive pruning tool only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    prune_and_convert_model()