# Source: kuai/diffusion-dpo-test/compare_checkpoints.py
# Uploaded by Larer — commit 5c19a88 ("Add files using upload-large-folder tool")
#!/usr/bin/env python3
"""对比两个 checkpoint 的 safetensor 文件,检查权重是否真的在变化"""
import sys
from safetensors.torch import load_file
import torch
import os
# ==============================
# Checkpoint paths to compare — edit these by hand.
# Both files live under the same run directory and differ only in their
# checkpoint-N folder, so the shared pieces are factored out below.
# ==============================
_RUN_DIR = r"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4"
_ADAPTER_FILE = r"lora_train_unet/adapter_model.safetensors"
CHECKPOINT_1 = f"{_RUN_DIR}/checkpoint-15/{_ADAPTER_FILE}"
CHECKPOINT_2 = f"{_RUN_DIR}/checkpoint-60/{_ADAPTER_FILE}"
# ==============================
# Default layer-name substrings for the per-layer report. They match module
# names such as ``transformer_blocks.N`` / ``single_transformer_blocks.N``;
# pass ``key_layers=`` to compare_safetensors for other architectures.
_DEFAULT_KEY_LAYERS = (
    'x_embedder',
    'transformer_blocks.0',
    'transformer_blocks.9',
    'single_transformer_blocks.30',
    'proj_out',
)


def _is_state_dict(obj):
    """Return True if *obj* is an already-loaded ``{name: tensor}`` mapping."""
    from collections.abc import Mapping
    return isinstance(obj, Mapping)


def _load_state_dict(source):
    """Return *source* unchanged if it is a mapping, else load it as a safetensors file."""
    return source if _is_state_dict(source) else load_file(source)


def _describe_source(source):
    """Return a short printable label for *source* (never dumps whole tensors)."""
    if _is_state_dict(source):
        return f"<in-memory state dict, {len(source)} tensors>"
    return source


def _pct(count, total):
    """Return *count* as a percentage of *total*; 0.0 when *total* is 0."""
    return count / total * 100 if total else 0.0


def _tensor_summary(tensor):
    """Return ``(mean, max |x|)`` of *tensor* as Python floats; ``(0.0, 0.0)`` if empty."""
    t = tensor.float()
    if t.numel() == 0:
        return 0.0, 0.0
    return t.mean().item(), t.abs().max().item()


def _print_tensor_stats(label, tensor):
    """Print the mean and max |x| of *tensor* on one line prefixed by *label*."""
    mean, absmax = _tensor_summary(tensor)
    print(f" {label}: mean={mean:.6e}, max={absmax:.6e}")


def _diff_stats(tensor1, tensor2):
    """Return ``(max |t2 - t1|, mean |t2 - t1|)`` for two same-shape tensors.

    Both inputs are cast to a common dtype of at least float32 *before*
    subtracting. Subtracting in fp16/bf16 would round the difference itself to
    half precision. Float64 inputs keep their full precision instead of being
    truncated to float32.
    """
    dtype = torch.promote_types(
        torch.promote_types(tensor1.dtype, tensor2.dtype), torch.float32)
    abs_diff = (tensor2.to(dtype) - tensor1.to(dtype)).abs()
    if abs_diff.numel() == 0:
        return 0.0, 0.0
    return abs_diff.max().item(), abs_diff.mean().item()


def _collect_diffs(state_dict1, state_dict2, keys):
    """Compare the tensors named in *keys* and return aggregate statistics.

    Returns a dict with:
        identical_count / different_count: number of tensors with zero /
            non-zero maximum absolute difference (shape mismatches count as
            different).
        shape_mismatch: keys whose tensor shapes differ between checkpoints.
        max_diff, max_diff_info: the largest absolute difference seen and
            details of the tensor it came from (``None`` if nothing changed).
        layer_diffs: for each ``.lora_B.`` layer prefix, the stats of its
            most-changed lora_B tensor.
    """
    identical_count = 0
    different_count = 0
    shape_mismatch = []
    max_diff = 0.0
    max_diff_info = None
    layer_diffs = {}

    for key in keys:
        tensor1 = state_dict1[key]
        tensor2 = state_dict2[key]
        if tensor1.shape != tensor2.shape:
            # Subtracting would silently broadcast (or raise); a shape change
            # is itself a change, so record it and count it as different.
            shape_mismatch.append(key)
            different_count += 1
            continue

        max_abs_diff, mean_abs_diff = _diff_stats(tensor1, tensor2)
        if max_abs_diff == 0:
            identical_count += 1
        else:
            different_count += 1

        if max_abs_diff > max_diff:
            max_diff = max_abs_diff
            max_diff_info = {
                'key': key,
                'max_diff': max_abs_diff,
                'mean_diff': mean_abs_diff,
                'tensor1_max': _tensor_summary(tensor1)[1],
                'tensor2_max': _tensor_summary(tensor2)[1],
            }

        # Per-layer view restricted to lora_B tensors. Presumably lora_B is
        # tracked because its change is the clearest learning signal; confirm.
        if '.lora_B.' in key:
            layer_name = key.split('.lora_B.')[0]
            previous = layer_diffs.get(layer_name)
            if previous is None or max_abs_diff > previous['max_diff']:
                layer_diffs[layer_name] = {
                    'max_diff': max_abs_diff,
                    'mean_diff': mean_abs_diff,
                    'key': key,
                }

    return {
        'identical_count': identical_count,
        'different_count': different_count,
        'shape_mismatch': shape_mismatch,
        'max_diff': max_diff,
        'max_diff_info': max_diff_info,
        'layer_diffs': layer_diffs,
    }


def _print_report(state_dict1, state_dict2, stats, total, key_layers, top_k):
    """Print the human-readable comparison report for statistics from _collect_diffs."""
    identical_count = stats['identical_count']
    different_count = stats['different_count']
    layer_diffs = stats['layer_diffs']
    max_diff_info = stats['max_diff_info']

    print("差异统计:")
    print(f" 完全相同的参数: {identical_count} / {total} ({_pct(identical_count, total):.2f}%)")
    print(f" 有变化的参数: {different_count} / {total} ({_pct(different_count, total):.2f}%)")
    print()

    if stats['shape_mismatch']:
        print(f"⚠️ 警告: {len(stats['shape_mismatch'])} 个参数的形状不一致(计为有变化):")
        for key in stats['shape_mismatch']:
            print(f" {key}: {tuple(state_dict1[key].shape)} vs {tuple(state_dict2[key].shape)}")
        print()

    if max_diff_info:
        print("最大权重变化:")
        print(f" 层: {max_diff_info['key']}")
        print(f" 最大绝对差异: {max_diff_info['max_diff']:.6e}")
        print(f" 平均绝对差异: {max_diff_info['mean_diff']:.6e}")
        print(f" Checkpoint1 最大值: {max_diff_info['tensor1_max']:.6e}")
        print(f" Checkpoint2 最大值: {max_diff_info['tensor2_max']:.6e}")
        print()

    # Selected layers: at most two matching lora_B layers per substring.
    print("关键层的 lora_B 权重变化:")
    print("-" * 80)
    for layer_prefix in key_layers:
        matching = [name for name in layer_diffs if layer_prefix in name]
        for layer_name in matching[:2]:
            info = layer_diffs[layer_name]
            print(f"\n层: {layer_name}")
            print(f" 最大差异: {info['max_diff']:.6e}")
            print(f" 平均差异: {info['mean_diff']:.6e}")
            _print_tensor_stats("Checkpoint1", state_dict1[info['key']])
            _print_tensor_stats("Checkpoint2", state_dict2[info['key']])

    # Layers ranked by their largest lora_B change.
    print("\n" + "=" * 80)
    print(f"lora_B 权重变化最大的前 {top_k} 个层:")
    print("-" * 80)
    ranked = sorted(layer_diffs.items(), key=lambda item: item[1]['max_diff'], reverse=True)
    for rank, (layer_name, info) in enumerate(ranked[:top_k], 1):
        print(f"\n{rank}. {layer_name}")
        print(f" 最大差异: {info['max_diff']:.6e}, 平均差异: {info['mean_diff']:.6e}")
        _print_tensor_stats("Ckpt1", state_dict1[info['key']])
        _print_tensor_stats("Ckpt2", state_dict2[info['key']])

    # Overall verdict.
    print("\n" + "=" * 80)
    changed_pct = _pct(different_count, total)
    if different_count == 0:
        print("❌ 严重问题: 两个 checkpoint 完全相同,模型没有学习!")
    elif different_count < total * 0.1:
        print(f"⚠️ 警告: 只有 {changed_pct:.2f}% 的参数在变化,可能存在梯度阻塞")
    else:
        print(f"✅ 正常: {changed_pct:.2f}% 的参数在变化")
    # The "change too small" hint only makes sense when some tensor actually
    # changed by a non-zero amount (max_diff_info is set only in that case).
    if max_diff_info is not None and stats['max_diff'] < 1e-6:
        print(f"⚠️ 但是: 最大变化只有 {stats['max_diff']:.6e},变化幅度可能太小")
    print("=" * 80)


def compare_safetensors(path1, path2, *, key_layers=_DEFAULT_KEY_LAYERS, top_k=10):
    """Compare two checkpoints tensor by tensor and print how much the weights changed.

    Args:
        path1, path2: paths to ``.safetensors`` files (loaded with
            ``safetensors.torch.load_file``), or already-loaded
            ``{name: tensor}`` mappings.
        key_layers: substrings selecting lora_B layers to report
            individually (at most two matching layers per substring).
        top_k: number of lora_B layers with the largest change to list.

    Returns:
        A summary dict with ``total``, ``identical_count``,
        ``different_count``, ``shape_mismatch``, ``max_diff``,
        ``max_diff_info``, ``max_diff_key`` and ``layer_diffs``. Returns
        ``None`` if the two checkpoints have different tensor names, or if
        any error occurs. Errors are printed with a traceback, not raised.
    """
    print(f"\n{'=' * 80}")
    print("对比两个 checkpoint:")
    print(f" Checkpoint 1: {_describe_source(path1)}")
    print(f" Checkpoint 2: {_describe_source(path2)}")
    print(f"{'=' * 80}\n")
    try:
        state_dict1 = _load_state_dict(path1)
        state_dict2 = _load_state_dict(path2)

        # Both checkpoints must contain exactly the same tensor names.
        keys1 = set(state_dict1.keys())
        keys2 = set(state_dict2.keys())
        if keys1 != keys2:
            print("⚠️ 警告: 两个 checkpoint 的键不一致!")
            print(f" 只在 checkpoint1 中: {keys1 - keys2}")
            print(f" 只在 checkpoint2 中: {keys2 - keys1}")
            return None
        total = len(keys1)
        print(f"✅ 两个 checkpoint 都有 {total} 个参数张量\n")

        stats = _collect_diffs(state_dict1, state_dict2, sorted(keys1))
        summary = {
            'total': total,
            **stats,
            'max_diff_key': stats['max_diff_info']['key'] if stats['max_diff_info'] else None,
        }
        if total == 0:
            # Nothing to compare; avoid reporting "identical, model did not learn".
            print("⚠️ 警告: checkpoint 中没有任何参数张量,无法比较")
            return summary

        _print_report(state_dict1, state_dict2, stats, total, key_layers, top_k)
        return summary
    except Exception as e:
        # Script-level boundary: report the failure with a traceback instead of crashing.
        print(f"❌ 错误: {e}")
        import traceback
        traceback.print_exc()
        return None
if __name__ == "__main__":
    # Usage: compare_checkpoints.py [CHECKPOINT_1 CHECKPOINT_2]
    # With no arguments, fall back to the hard-coded CHECKPOINT_1/CHECKPOINT_2
    # paths; with exactly two arguments, compare those files instead.
    if len(sys.argv) == 3:
        compare_safetensors(sys.argv[1], sys.argv[2])
    elif len(sys.argv) == 1:
        compare_safetensors(CHECKPOINT_1, CHECKPOINT_2)
    else:
        sys.exit(f"usage: {sys.argv[0]} [CHECKPOINT_1 CHECKPOINT_2]")