Upload 2 files
Browse files- null_embedding.pt +3 -0
- prune_model.py +94 -0
null_embedding.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:598bbf18a1568585795933f063f6d2bb71ae6b92715b69425690698f7282494e
|
| 3 |
+
size 316607
|
prune_model.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
|
| 5 |
+
def prune_and_convert_model():
    """Interactively prune a PyTorch checkpoint and save it at a chosen precision.

    Prompts the user for a source checkpoint path and a target precision
    (fp32, fp16, or bf16), extracts the model weights (dropping trainer
    state nested under 'model'/'state_dict'), removes all keys prefixed
    with 'text_encoder.', optionally casts floating-point tensors to the
    target dtype, and writes the result to
    ``<base>_pruned_<precision><ext>`` next to the source file.

    Returns:
        None. All progress and errors are reported via stdout.
    """
    # --- Interactive input ---
    # .strip() removes stray whitespace; .strip('"') tolerates the
    # surrounding quotes that Windows Explorer's "Copy as path" adds.
    source_path = input("请输入源模型文件的完整路径(例如 D:\\models\\SDMatte_plus.pth):\n> ").strip().strip('"')

    if not os.path.exists(source_path):
        print(f"\n[错误] 文件未找到: {source_path}")
        return

    # .strip() here prevents an infinite re-prompt loop when the user
    # types a valid precision with trailing whitespace (e.g. "fp16 ").
    while True:
        target_precision = input("\n请选择目标精度 (输入 fp32, fp16, 或 bf16):\n> ").strip().lower()
        if target_precision in ('fp32', 'fp16', 'bf16'):
            break
        print("[错误] 无效输入,请输入 'fp32', 'fp16', 或 'bf16' 中的一个。")

    # --- Core logic ---
    # Derive the output filename from the source name and target precision.
    base, ext = os.path.splitext(source_path)
    output_path = f"{base}_pruned_{target_precision}{ext}"

    print(f"\n--- 开始处理检查点: {os.path.basename(source_path)} ---")
    print(f"目标精度: {target_precision.upper()}")
    print(f"输出文件将保存为: {os.path.basename(output_path)}")

    try:
        print("\n正在加载原始权重文件...")
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects —
        # only load checkpoints from trusted sources. It is kept because
        # full training checkpoints (optimizer/trainer state) generally
        # cannot be loaded with weights_only=True.
        full_checkpoint = torch.load(source_path, map_location="cpu", weights_only=False)

        # Guard clause: anything that is not a dict is unrecognized.
        if not isinstance(full_checkpoint, dict):
            print("[错误] 文件格式无法识别。")
            return

        # Training checkpoints usually nest weights under 'model' or
        # 'state_dict'; extracting that key also discards trainer state.
        if 'model' in full_checkpoint:
            state_dict = full_checkpoint['model']
            print("检测到训练检查点,已从 'model' 键中提取权重。")
        elif 'state_dict' in full_checkpoint:
            state_dict = full_checkpoint['state_dict']
            print("检测到训练检查点,已从 'state_dict' 键中提取权重。")
        else:
            state_dict = full_checkpoint
            print("文件似乎是一个纯权重字典,将直接处理。")

        print("正在移除 'text_encoder.' 相关权重...")
        # Build the pruned dict in one pass; counts fall out of the sizes.
        pruned_state_dict = OrderedDict(
            (key, value)
            for key, value in state_dict.items()
            if not key.startswith("text_encoder.")
        )
        kept_count = len(pruned_state_dict)
        removed_count = len(state_dict) - kept_count

        print(f"移除 {removed_count} 个 text_encoder 权重,保留 {kept_count} 个。")

        # Optional precision conversion of the surviving tensors.
        if target_precision != 'fp32':
            print(f"开始将保留的权重转换为 {target_precision} ...")
            target_dtype = torch.float16 if target_precision == 'fp16' else torch.bfloat16

            # Iterate over a snapshot of the keys so we never mutate the
            # dict while iterating its items view. Only floating-point
            # tensors are cast; integer buffers and non-tensor entries
            # are kept untouched.
            for key in list(pruned_state_dict.keys()):
                value = pruned_state_dict[key]
                if isinstance(value, torch.Tensor) and value.is_floating_point():
                    pruned_state_dict[key] = value.to(target_dtype)
        else:
            print("保留原始 FP32 精度。")

        print(f"正在保存处理后的模型到: {output_path}")
        torch.save(pruned_state_dict, output_path)

        # Report the size savings in MiB.
        source_size = os.path.getsize(source_path) / (1024**2)
        pruned_size = os.path.getsize(output_path) / (1024**2)

        print("\n--- 成功 ---")
        print(f"原始文件大小: {source_size:.2f} MB")
        print(f"裁剪后文件大小 ({target_precision.upper()}): {pruned_size:.2f} MB")
        print(f"总共节省空间: {source_size - pruned_size:.2f} MB")

    except Exception as e:
        # Interactive tool: report the failure and the full traceback
        # instead of crashing with an unhandled exception.
        print(f"\n[错误] 处理过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
|
| 92 |
+
|
| 93 |
+
# Script entry point: run the interactive pruning tool only when this
# file is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    prune_and_convert_model()
|