Upload convert_precision.py
Browse files- convert_precision.py +59 -55
convert_precision.py
CHANGED
|
@@ -1,56 +1,60 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
# --- 配置 ---
|
| 5 |
-
# 1. 设置目标精度: 'fp16' 或 'bf16'
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
#
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
#
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
print("[警告] 未在顶层找到 'model' 键,将尝试转换整个文件。")
|
| 31 |
-
state_dict = full_checkpoint
|
| 32 |
-
|
| 33 |
-
print(f"
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
#
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
print(f"\n[错误] 处理过程中发生错误: {e}")
|
|
|
|
import torch
import os

# --- Configuration ---
# 1. Target precision: 'fp32', 'fp16', or 'bf16'
#    - 'fp32': strip training state only, keep full 32-bit precision.
#    - 'fp16': convert to 16-bit half precision (smallest file).
#    - 'bf16': convert to bfloat16 (recommended for RTX 30-series and newer GPUs).
TARGET_PRECISION = 'fp32'

# 2. Path to the original 32-bit training checkpoint, e.g.
#    "E:\comfyui\ComfyUI-aki-v1.3\models\SDMatte\1\SDMatte_plus.pth"
fp32_checkpoint_path = r"E:\comfyui\ComfyUI-aki-v1.3\models\SDMatte\1\SDMatte_plus.pth"

# --------------------------------------------------------------------

# Derive the output filename from the input path.
# NOTE: os.path.splitext is used instead of str.replace('.pth', ...) so that
# a '.pth' substring elsewhere in the path is never rewritten, and a missing
# extension can never make the output path collide with (and overwrite) the
# input checkpoint.
_root, _ext = os.path.splitext(fp32_checkpoint_path)
output_filename = f"{_root}_{TARGET_PRECISION}_inference{_ext or '.pth'}"

if not os.path.exists(fp32_checkpoint_path):
    print(f"[错误] 文件不存在: {fp32_checkpoint_path}")
else:
    try:
        print(f"--- 开始处理训练检查点: {fp32_checkpoint_path} ---")
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # It is required here because training checkpoints carry non-tensor
        # state, but only load files from a trusted source.
        full_checkpoint = torch.load(fp32_checkpoint_path, map_location="cpu", weights_only=False)

        # Training checkpoints usually nest the weights under a 'model' key;
        # otherwise treat the whole file as the state dict.
        if 'model' in full_checkpoint:
            state_dict = full_checkpoint['model']
            print("成功提取到 'model' 键中的权重字典。")
        else:
            print("[警告] 未在顶层找到 'model' 键,将尝试转换整个文件。")
            state_dict = full_checkpoint

        print(f"目标输出精度: {TARGET_PRECISION}")

        # --- MODIFICATION START ---
        # Only cast when the target precision is not 'fp32'; for 'fp32' the
        # script merely strips the training state and re-saves the weights.
        if TARGET_PRECISION != 'fp32':
            print(f"开始将权重转换为 {TARGET_PRECISION} ...")
            target_dtype = torch.float16 if TARGET_PRECISION == 'fp16' else torch.bfloat16

            # Cast only floating-point tensors; integer buffers and any
            # non-tensor entries are left untouched. Reassigning values
            # during items() iteration is safe (no keys are added/removed).
            for key, value in state_dict.items():
                if isinstance(value, torch.Tensor) and value.is_floating_point():
                    state_dict[key] = value.to(target_dtype)
        else:
            print("保留原始 FP32 精度,仅剥离训练数据。")
        # --- MODIFICATION END ---

        print(f"正在保存纯推理模型到: {output_filename} ...")
        torch.save(state_dict, output_filename)

        original_size_gb = os.path.getsize(fp32_checkpoint_path) / (1024**3)
        final_size_gb = os.path.getsize(output_filename) / (1024**3)

        print("\n--- 转换成功 ---")
        print(f"原始训练检查点大小: {original_size_gb:.2f} GB")
        print(f"生成的纯推理模型大小 ({TARGET_PRECISION.upper()}): {final_size_gb:.2f} GB")
        print("说明: 新文件只包含用于推理的核心模型权重,已移除训练相关的优化器状态。")

    except Exception as e:
        print(f"\n[错误] 处理过程中发生错误: {e}")