#!/usr/bin/env python3
"""Compare the safetensors files of two checkpoints to check whether the weights really change.

Usage:
    python <this_script>                  # compares CHECKPOINT_1 and CHECKPOINT_2 below
    python <this_script> <ckpt1> <ckpt2>  # compares the two paths given on the command line
"""

import os
import sys

from safetensors.torch import load_file
import torch

# ==============================
# Fill in the checkpoint paths manually here
# ==============================
CHECKPOINT_1 = r"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors"
CHECKPOINT_2 = r"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors"
# ==============================


def _percent(part, total):
    """Return ``part / total`` as a percentage, or 0.0 when ``total`` is zero."""
    return part / total * 100 if total else 0.0


def _abs_diff_stats(tensor1, tensor2):
    """Return ``(max, mean)`` of ``|tensor2 - tensor1|`` as Python floats.

    Both tensors are cast to float32 *before* subtracting so the difference
    is not rounded in a low-precision storage dtype (fp16 / bf16).
    """
    abs_diff = (tensor2.float() - tensor1.float()).abs()
    return abs_diff.max().item(), abs_diff.mean().item()


def compare_safetensors(path1, path2):
    """Print a report of how the tensors in two safetensors files differ.

    The report contains: how many tensors are identical vs. changed, the
    tensor with the largest absolute change, the ``lora_B`` change for a
    fixed list of key layers, the ten ``lora_B`` layers with the largest
    change, and a final verdict.

    Args:
        path1: Path to the first (earlier) ``.safetensors`` file.
        path2: Path to the second (later) ``.safetensors`` file.

    Returns:
        None. Everything is printed to stdout. If the two files do not have
        the same key set, a warning is printed and the comparison stops.
        Any exception is caught, printed together with its traceback, and
        not re-raised.
    """
    print(f"\n{'='*80}")
    print("对比两个 checkpoint:")
    print(f" Checkpoint 1: {path1}")
    print(f" Checkpoint 2: {path2}")
    print(f"{'='*80}\n")

    try:
        state_dict1 = load_file(path1)
        state_dict2 = load_file(path2)

        # Check that both checkpoints contain the same keys.
        keys1 = set(state_dict1.keys())
        keys2 = set(state_dict2.keys())
        if keys1 != keys2:
            print("⚠️ 警告: 两个 checkpoint 的键不一致!")
            print(f" 只在 checkpoint1 中: {keys1 - keys2}")
            print(f" 只在 checkpoint2 中: {keys2 - keys1}")
            return

        total = len(keys1)
        print(f"✅ 两个 checkpoint 都有 {total} 个参数张量\n")

        # Collect difference statistics.
        identical_count = 0
        different_count = 0
        max_diff_info = None
        max_diff = 0
        layer_diffs = {}

        for key in sorted(keys1):
            tensor1 = state_dict1[key]
            tensor2 = state_dict2[key]

            # Mismatched shapes would either raise or silently broadcast in
            # the subtraction; report them and count the tensor as changed.
            if tensor1.shape != tensor2.shape:
                print(f"⚠️ 警告: 参数 {key} 的形状不一致: "
                      f"{tuple(tensor1.shape)} vs {tuple(tensor2.shape)}")
                different_count += 1
                continue

            max_abs_diff, mean_abs_diff = _abs_diff_stats(tensor1, tensor2)

            if max_abs_diff == 0:
                identical_count += 1
            else:
                different_count += 1

            # Track the tensor with the largest single-element change.
            if max_abs_diff > max_diff:
                max_diff = max_abs_diff
                max_diff_info = {
                    'key': key,
                    'max_diff': max_abs_diff,
                    'mean_diff': mean_abs_diff,
                    'tensor1_max': tensor1.float().abs().max().item(),
                    'tensor2_max': tensor2.float().abs().max().item(),
                }

            # Record one lora_B entry per layer (the first key in sorted order).
            if '.lora_B.' in key:
                layer_name = key.split('.lora_B.')[0]
                if layer_name not in layer_diffs:
                    layer_diffs[layer_name] = {
                        'max_diff': max_abs_diff,
                        'mean_diff': mean_abs_diff,
                        'key': key,
                    }

        identical_pct = _percent(identical_count, total)
        different_pct = _percent(different_count, total)

        print("差异统计:")
        print(f" 完全相同的参数: {identical_count} / {total} ({identical_pct:.2f}%)")
        print(f" 有变化的参数: {different_count} / {total} ({different_pct:.2f}%)")
        print()

        if max_diff_info:
            print("最大权重变化:")
            print(f" 层: {max_diff_info['key']}")
            print(f" 最大绝对差异: {max_diff_info['max_diff']:.6e}")
            print(f" 平均绝对差异: {max_diff_info['mean_diff']:.6e}")
            print(f" Checkpoint1 最大值: {max_diff_info['tensor1_max']:.6e}")
            print(f" Checkpoint2 最大值: {max_diff_info['tensor2_max']:.6e}")
            print()

        # Layers that are always reported; matched by substring against the
        # layer names collected above, at most two matches per entry.
        key_layers = [
            'x_embedder',
            'transformer_blocks.0',
            'transformer_blocks.9',
            'single_transformer_blocks.30',
            'proj_out',
        ]

        print("关键层的 lora_B 权重变化:")
        print("-" * 80)
        for layer_prefix in key_layers:
            matching = [name for name in layer_diffs if layer_prefix in name]
            for layer_name in matching[:2]:
                info = layer_diffs[layer_name]
                print(f"\n层: {layer_name}")
                print(f" 最大差异: {info['max_diff']:.6e}")
                print(f" 平均差异: {info['mean_diff']:.6e}")

                key = info['key']
                t1 = state_dict1[key].float()
                t2 = state_dict2[key].float()
                print(f" Checkpoint1: mean={t1.mean():.6e}, max={t1.abs().max():.6e}")
                print(f" Checkpoint2: mean={t2.mean():.6e}, max={t2.abs().max():.6e}")

        print("\n" + "="*80)
        print("lora_B 权重变化最大的前 10 个层:")
        print("-" * 80)
        sorted_layers = sorted(
            layer_diffs.items(), key=lambda item: item[1]['max_diff'], reverse=True
        )
        for i, (layer_name, info) in enumerate(sorted_layers[:10], 1):
            key = info['key']
            t1 = state_dict1[key].float()
            t2 = state_dict2[key].float()
            print(f"\n{i}. {layer_name}")
            print(f" 最大差异: {info['max_diff']:.6e}, 平均差异: {info['mean_diff']:.6e}")
            print(f" Ckpt1: mean={t1.mean():.6e}, max={t1.abs().max():.6e}")
            print(f" Ckpt2: mean={t2.mean():.6e}, max={t2.abs().max():.6e}")

        # Final verdict.
        print("\n" + "="*80)
        if different_count == 0:
            print("❌ 严重问题: 两个 checkpoint 完全相同,模型没有学习!")
        elif different_count < total * 0.1:
            print(f"⚠️ 警告: 只有 {different_pct:.2f}% 的参数在变化,可能存在梯度阻塞")
        else:
            print(f"✅ 正常: {different_pct:.2f}% 的参数在变化")
            if max_diff < 1e-6:
                print(f"⚠️ 但是: 最大变化只有 {max_diff:.6e},变化幅度可能太小")
        print("="*80)

    except Exception as e:
        # Top-level boundary of a diagnostic script: report and keep going
        # rather than crash without context.
        print(f"❌ 错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    # Two positional arguments override the hard-coded paths; any other
    # argument count falls back to the constants defined above.
    if len(sys.argv) == 3:
        compare_safetensors(sys.argv[1], sys.argv[2])
    else:
        compare_safetensors(CHECKPOINT_1, CHECKPOINT_2)