aha2023 commited on
Commit
dd57c3c
·
verified ·
1 Parent(s): 8771da7

Upload 2 files

Browse files
Files changed (2) hide show
  1. null_embedding.pt +3 -0
  2. prune_model.py +94 -0
null_embedding.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:598bbf18a1568585795933f063f6d2bb71ae6b92715b69425690698f7282494e
3
+ size 316607
prune_model.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from collections import OrderedDict
4
+
5
def _extract_state_dict(full_checkpoint):
    """Return the raw weight dict inside a loaded checkpoint, or None.

    Handles three layouts: a trainer checkpoint with a 'model' key, one
    with a 'state_dict' key, or a bare weight dict. Anything that is not
    a dict is rejected (with a user-facing error message).
    """
    if not isinstance(full_checkpoint, dict):
        print("[错误] 文件格式无法识别。")
        return None
    if 'model' in full_checkpoint:
        print("检测到训练检查点,已从 'model' 键中提取权重。")
        return full_checkpoint['model']
    if 'state_dict' in full_checkpoint:
        print("检测到训练检查点,已从 'state_dict' 键中提取权重。")
        return full_checkpoint['state_dict']
    print("文件似乎是一个纯权重字典,将直接处理。")
    return full_checkpoint


def _prune_text_encoder(state_dict):
    """Split weights into (kept OrderedDict, removed_count, kept_count).

    Drops every entry whose key starts with 'text_encoder.'; insertion
    order of the surviving keys is preserved.
    """
    pruned = OrderedDict()
    removed = 0
    for key, value in state_dict.items():
        if key.startswith("text_encoder."):
            removed += 1
        else:
            pruned[key] = value
    return pruned, removed, len(pruned)


def _cast_floating_tensors(state_dict, target_dtype):
    """Cast every floating-point tensor in state_dict to target_dtype.

    Iterates over a snapshot of the keys so the dict is never mutated
    while being iterated (the original looped over .items() and wrote
    back into the same dict — only incidentally safe). Non-tensor and
    integer entries are left untouched.
    """
    for key in list(state_dict):
        value = state_dict[key]
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            state_dict[key] = value.to(target_dtype)


def prune_and_convert_model():
    """Interactive tool: prune a PyTorch checkpoint and convert precision.

    Prompts for a checkpoint path and a target precision (fp32/fp16/bf16),
    removes trainer state (by keeping only the model weights) and all
    'text_encoder.*' parameters, optionally casts the remaining floating
    point tensors, and saves the result as
    '<basename>_pruned_<precision><ext>' next to the source file.

    Returns:
        None. Progress and errors are reported via print().
    """
    # --- Interactive input ---
    source_path = input("请输入源模型文件的完整路径(例如 D:\\models\\SDMatte_plus.pth):\n> ")

    if not os.path.exists(source_path):
        print(f"\n[错误] 文件未找到: {source_path}")
        return

    # Re-prompt until one of the three supported precisions is entered.
    while True:
        target_precision = input("\n请选择目标精度 (输入 fp32, fp16, 或 bf16):\n> ").lower()
        if target_precision in ('fp32', 'fp16', 'bf16'):
            break
        print("[错误] 无效输入,请输入 'fp32', 'fp16', 或 'bf16' 中的一个。")

    # --- Core logic ---
    # Derive the output name from the source name and the chosen precision.
    base, ext = os.path.splitext(source_path)
    output_path = f"{base}_pruned_{target_precision}{ext}"

    print(f"\n--- 开始处理检查点: {os.path.basename(source_path)} ---")
    print(f"目标精度: {target_precision.upper()}")
    print(f"输出文件将保存为: {os.path.basename(output_path)}")

    try:
        print("\n正在加载原始权重文件...")
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects, which can execute code on load. Only run this on
        # checkpoints from a trusted source; prefer weights_only=True
        # when the file is known to contain tensors only.
        full_checkpoint = torch.load(source_path, map_location="cpu", weights_only=False)

        state_dict = _extract_state_dict(full_checkpoint)
        if state_dict is None:
            return

        print("正在移除 'text_encoder.' 相关权重...")
        pruned_state_dict, removed_count, kept_count = _prune_text_encoder(state_dict)
        print(f"移除 {removed_count} 个 text_encoder 权重,保留 {kept_count} 个。")

        # Cast to the requested precision ('fp32' means leave as-is).
        if target_precision != 'fp32':
            print(f"开始将保留的权重转换为 {target_precision} ...")
            target_dtype = torch.float16 if target_precision == 'fp16' else torch.bfloat16
            _cast_floating_tensors(pruned_state_dict, target_dtype)
        else:
            print("保留原始 FP32 精度。")

        print(f"正在保存处理后的模型到: {output_path}")
        torch.save(pruned_state_dict, output_path)

        # Report the space saved, in MiB.
        source_size = os.path.getsize(source_path) / (1024**2)
        pruned_size = os.path.getsize(output_path) / (1024**2)

        print("\n--- 成功 ---")
        print(f"原始文件大小: {source_size:.2f} MB")
        print(f"裁剪后文件大小 ({target_precision.upper()}): {pruned_size:.2f} MB")
        print(f"总共节省空间: {source_size - pruned_size:.2f} MB")

    except Exception as e:
        # Top-level boundary for an interactive script: report the error
        # with its traceback rather than crashing with a bare stack dump.
        print(f"\n[错误] 处理过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
92
+
93
# Script entry point: run the interactive pruning workflow only when this
# file is executed directly (not when imported as a module).
if __name__ == "__main__":
    prune_and_convert_model()