import torch
import os

# --- Configuration ---
# 1. Set the target precision: 'fp32', 'fp16', or 'bf16'
#    - 'fp32': strip the training data and keep full 32-bit precision.
#    - 'fp16': convert to 16-bit half precision; smallest file size.
#    - 'bf16': convert to bfloat16; recommended for RTX 30-series GPUs and newer.
TARGET_PRECISION = 'fp32'

# 2. Path to the original 32-bit training checkpoint file, e.g. "E:\comfyui\ComfyUI-aki-v1.3\models\SDMatte\1\SDMatte_plus.pth"
fp32_checkpoint_path = r"E:\comfyui\ComfyUI-aki-v1.3\models\SDMatte\1\SDMatte_plus.pth"

# --------------------------------------------------------------------

# Automatically derive the output file name
output_filename = fp32_checkpoint_path.replace('.pth', f'_{TARGET_PRECISION}_inference.pth')

if not os.path.exists(fp32_checkpoint_path):
    print(f"[错误] 文件不存在: {fp32_checkpoint_path}")
else:
    try:
        print(f"--- 开始处理训练检查点: {fp32_checkpoint_path} ---")
        full_checkpoint = torch.load(fp32_checkpoint_path, map_location="cpu", weights_only=False)

        if 'model' in full_checkpoint:
            state_dict = full_checkpoint['model']
            print("成功提取到 'model' 键中的权重字典。")
        else:
            print("[警告] 未在顶层找到 'model' 键,将尝试转换整个文件。")
            state_dict = full_checkpoint
        
        print(f"目标输出精度: {TARGET_PRECISION}")

        # --- MODIFICATION START ---
        # Only perform the dtype conversion when the target precision is not 'fp32'
        if TARGET_PRECISION != 'fp32':
            print(f"开始将权重转换为 {TARGET_PRECISION} ...")
            target_dtype = torch.float16 if TARGET_PRECISION == 'fp16' else torch.bfloat16

            for key in state_dict:
                if isinstance(state_dict[key], torch.Tensor) and state_dict[key].is_floating_point():
                    state_dict[key] = state_dict[key].to(target_dtype)
        else:
            print("保留原始 FP32 精度,仅剥离训练数据。")
        # --- MODIFICATION END ---

        print(f"正在保存纯推理模型到: {output_filename} ...")
        torch.save(state_dict, output_filename)
        
        original_size_gb = os.path.getsize(fp32_checkpoint_path) / (1024**3)
        final_size_gb = os.path.getsize(output_filename) / (1024**3)

        print("\n--- 转换成功 ---")
        print(f"原始训练检查点大小: {original_size_gb:.2f} GB")
        print(f"生成的纯推理模型大小 ({TARGET_PRECISION.upper()}): {final_size_gb:.2f} GB")
        print("说明: 新文件只包含用于推理的核心模型权重,已移除训练相关的优化器状态。")

    except Exception as e:
        print(f"\n[错误] 处理过程中发生错误: {e}")