aha2023 committed on
Commit
90cf301
·
verified ·
1 Parent(s): b4a3c28

Upload convert_precision.py

Browse files
Files changed (1) hide show
  1. convert_precision.py +59 -55
convert_precision.py CHANGED
@@ -1,56 +1,60 @@
1
- import torch
2
- import os
3
-
4
- # --- 配置 ---
5
- # 1. 设置目标精度: 'fp16' 或 'bf16'
6
- TARGET_PRECISION = 'bf16'
7
-
8
- # 2. 设置原始32位训练检查点文件的路径
9
- fp32_checkpoint_path = r"E:\comfyui\ComfyUI-aki-v1.3\models\SDMatte\SDMatte_plus.pth"
10
-
11
- # --------------------------------------------------------------------
12
-
13
- # 自动生成输出文件名
14
- output_filename = fp32_checkpoint_path.replace('.pth', f'_{TARGET_PRECISION}_inference.pth')
15
-
16
- if not os.path.exists(fp32_checkpoint_path):
17
- print(f"[错误] 文件不存在: {fp32_checkpoint_path}")
18
- else:
19
- try:
20
- print(f"--- 开始处理训练检查点: {fp32_checkpoint_path} ---")
21
- full_checkpoint = torch.load(fp32_checkpoint_path, map_location="cpu", weights_only=False)
22
-
23
- # 检查 'model' 键是否存在,这是包含我们所需权重的部分
24
- if 'model' in full_checkpoint:
25
- # 明确提取出模型的 state_dict
26
- state_dict = full_checkpoint['model']
27
- print("成功提取到 'model' 键中的权重字典。")
28
- else:
29
- # 如果没有 'model' 键,则假定整个文件就是 state_dict
30
- print("[警告] 未在顶层找到 'model' 键,将尝试转换整个文件。")
31
- state_dict = full_checkpoint
32
-
33
- print(f"开始将权重转换为 {TARGET_PRECISION} ...")
34
-
35
- target_dtype = torch.float16 if TARGET_PRECISION == 'fp16' else torch.bfloat16
36
-
37
- # 遍历权重字典中的每一项并进行转换
38
- for key in state_dict:
39
- if isinstance(state_dict[key], torch.Tensor) and state_dict[key].is_floating_point():
40
- state_dict[key] = state_dict[key].to(target_dtype)
41
-
42
- print(f"正在保存纯推理模型到: {output_filename} ...")
43
- # 直接保存处理后的 state_dict,不包含任何训练相关的附加信息
44
- torch.save(state_dict, output_filename)
45
-
46
- # 打印大小对比
47
- original_size_gb = os.path.getsize(fp32_checkpoint_path) / (1024**3)
48
- final_size_gb = os.path.getsize(output_filename) / (1024**3)
49
-
50
- print("\n--- 转换成功 ---")
51
- print(f"原始训练检查点大小: {original_size_gb:.2f} GB")
52
- print(f"生成的纯推理模型大小 ({TARGET_PRECISION.upper()}): {final_size_gb:.2f} GB")
53
- print("说明: 新文件只包含用于推理的核心模型权重,已移除训练相关的优化器状态。")
54
-
55
- except Exception as e:
 
 
 
 
56
  print(f"\n[错误] 处理过程中发生错误: {e}")
 
import torch
import os

# --- Configuration ---
# 1. Target precision: 'fp32', 'fp16', or 'bf16'
#    - 'fp32': strip training-only data, keep full 32-bit precision.
#    - 'fp16': convert to 16-bit half precision, smallest file size.
#    - 'bf16': convert to bfloat16, recommended for RTX 30-series GPUs and newer.
TARGET_PRECISION = 'fp32'

# 2. Path to the original 32-bit training checkpoint,
#    e.g. r"E:\comfyui\ComfyUI-aki-v1.3\models\SDMatte\1\SDMatte_plus.pth"
fp32_checkpoint_path = r"E:\comfyui\ComfyUI-aki-v1.3\models\SDMatte\1\SDMatte_plus.pth"

# --------------------------------------------------------------------

# Map precision names to torch dtypes; None means "leave weights untouched".
# An explicit table also rejects typos (e.g. 'fp8') that previously fell
# through to bfloat16 silently.
_DTYPE_MAP = {'fp32': None, 'fp16': torch.float16, 'bf16': torch.bfloat16}

# Derive the output filename. os.path.splitext is safer than
# str.replace('.pth', ...): replace() on a path that does not end in
# '.pth' returns the input unchanged, so torch.save would silently
# overwrite the original checkpoint.
_base, _ext = os.path.splitext(fp32_checkpoint_path)
output_filename = f"{_base}_{TARGET_PRECISION}_inference{_ext or '.pth'}"

if TARGET_PRECISION not in _DTYPE_MAP:
    print(f"[错误] 不支持的精度: {TARGET_PRECISION},请使用 'fp32'、'fp16' 或 'bf16'。")
elif not os.path.exists(fp32_checkpoint_path):
    print(f"[错误] 文件不存在: {fp32_checkpoint_path}")
else:
    try:
        print(f"--- 开始处理训练检查点: {fp32_checkpoint_path} ---")
        # weights_only=False is required because the checkpoint holds
        # training state beyond raw tensors; only load trusted files.
        full_checkpoint = torch.load(fp32_checkpoint_path, map_location="cpu", weights_only=False)

        # Training checkpoints usually nest the weights under a 'model' key;
        # otherwise assume the whole file is the state_dict itself.
        # The isinstance guard avoids a TypeError on non-dict checkpoints.
        if isinstance(full_checkpoint, dict) and 'model' in full_checkpoint:
            state_dict = full_checkpoint['model']
            print("成功提取到 'model' 键中的权重字典。")
        else:
            print("[警告] 未在顶层找到 'model' 键,将尝试转换整个文件。")
            state_dict = full_checkpoint

        print(f"目标输出精度: {TARGET_PRECISION}")

        # --- MODIFICATION START ---
        # Only cast when the target precision is not 'fp32'.
        target_dtype = _DTYPE_MAP[TARGET_PRECISION]
        if target_dtype is not None:
            print(f"开始将权重转换为 {TARGET_PRECISION} ...")

            # Convert only floating-point tensors; integer buffers (step
            # counters, index tables, ...) must keep their dtype.
            for key in state_dict:
                if isinstance(state_dict[key], torch.Tensor) and state_dict[key].is_floating_point():
                    state_dict[key] = state_dict[key].to(target_dtype)
        else:
            print("保留原始 FP32 精度,仅剥离训练数据。")
        # --- MODIFICATION END ---

        print(f"正在保存纯推理模型到: {output_filename} ...")
        # Save only the state_dict, dropping optimizer state and any other
        # training-only metadata.
        torch.save(state_dict, output_filename)

        # Report the size reduction.
        original_size_gb = os.path.getsize(fp32_checkpoint_path) / (1024**3)
        final_size_gb = os.path.getsize(output_filename) / (1024**3)

        print("\n--- 转换成功 ---")
        print(f"原始训练检查点大小: {original_size_gb:.2f} GB")
        print(f"生成的纯推理模型大小 ({TARGET_PRECISION.upper()}): {final_size_gb:.2f} GB")
        print("说明: 新文件只包含用于推理的核心模型权重,已移除训练相关的优化器状态。")

    except Exception as e:
        # Top-level boundary of a standalone script: report and exit cleanly.
        print(f"\n[错误] 处理过程中发生错误: {e}")