| |
| """Compare the safetensors files of two checkpoints and check whether the weights are actually changing.""" |
| import sys |
| from safetensors.torch import load_file |
| import torch |
| import os |
|
|
| |
| |
| |
# The two adapter_model.safetensors files that the __main__ block hands to
# compare_safetensors(): checkpoint-15 and checkpoint-60 of the same
# results directory.
| CHECKPOINT_1 = r"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors" |
| CHECKPOINT_2 = r"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors" |
| |
|
|
|
|
# Layer-name fragments reported in the "key layers" section when the caller
# does not pass its own list to compare_safetensors().
_DEFAULT_KEY_LAYERS = (
    'x_embedder',
    'transformer_blocks.0',
    'transformer_blocks.9',
    'single_transformer_blocks.30',
    'proj_out',
)

# How many layers the "largest lora_B change" ranking lists.
_TOP_LAYER_COUNT = 10


def _abs_diff_stats(tensor1, tensor2):
    """Return ``(max, mean)`` of ``|tensor2 - tensor1|`` as Python floats."""
    # Upcast *before* subtracting so half-precision checkpoints are not
    # rounded a second time by low-precision arithmetic.
    abs_diff = (tensor2.float() - tensor1.float()).abs()
    if abs_diff.numel() == 0:
        # max() raises on a zero-element tensor; nothing can differ here.
        return 0.0, 0.0
    return abs_diff.max().item(), abs_diff.mean().item()


def _mean_and_abs_max(tensor):
    """Return ``(mean, max(|x|))`` of *tensor* as floats (NaN if it is empty)."""
    values = tensor.float()
    if values.numel() == 0:
        return float("nan"), float("nan")
    return values.mean().item(), values.abs().max().item()


def _collect_diffs(state_dict1, state_dict2, keys):
    """Compare every tensor named in *keys* and gather the statistics.

    Returns a dict with the identical/different counters, the largest
    absolute difference and its description, and the per-layer lora_B
    statistics keyed by the part of the tensor name before ``.lora_B.``.
    """
    identical_count = 0
    different_count = 0
    max_diff = 0
    max_diff_info = None
    layer_diffs = {}

    for key in sorted(keys):
        tensor1 = state_dict1[key]
        tensor2 = state_dict2[key]

        if tensor1.shape != tensor2.shape:
            # Subtracting would either broadcast silently (misleading
            # numbers) or raise and abort the whole report.
            different_count += 1
            print(f"⚠️ 警告: {key} 的形状不一致: "
                  f"{tuple(tensor1.shape)} vs {tuple(tensor2.shape)}")
            continue

        max_abs_diff, mean_abs_diff = _abs_diff_stats(tensor1, tensor2)

        if max_abs_diff == 0:
            identical_count += 1
        else:
            different_count += 1

        if max_abs_diff > max_diff:
            max_diff = max_abs_diff
            max_diff_info = {
                'key': key,
                'max_diff': max_abs_diff,
                'mean_diff': mean_abs_diff,
                'tensor1_max': _mean_and_abs_max(tensor1)[1],
                'tensor2_max': _mean_and_abs_max(tensor2)[1],
            }

        if '.lora_B.' in key:
            layer_name = key.split('.lora_B.')[0]
            recorded = layer_diffs.get(layer_name)
            # A layer may own several lora_B tensors; keep the one that
            # changed most instead of whichever sorts first.
            if recorded is None or max_abs_diff > recorded['max_diff']:
                layer_diffs[layer_name] = {
                    'max_diff': max_abs_diff,
                    'mean_diff': mean_abs_diff,
                    'key': key,
                }

    return {
        'identical_count': identical_count,
        'different_count': different_count,
        'max_diff': max_diff,
        'max_diff_info': max_diff_info,
        'layer_diffs': layer_diffs,
    }


def _print_report(state_dict1, state_dict2, total, stats, key_layers):
    """Print the statistics gathered by ``_collect_diffs`` and the verdict."""
    identical_count = stats['identical_count']
    different_count = stats['different_count']
    max_diff = stats['max_diff']
    max_diff_info = stats['max_diff_info']
    layer_diffs = stats['layer_diffs']

    print("差异统计:")
    print(f" 完全相同的参数: {identical_count} / {total} ({identical_count/total*100:.2f}%)")
    print(f" 有变化的参数: {different_count} / {total} ({different_count/total*100:.2f}%)")
    print()

    if max_diff_info:
        print("最大权重变化:")
        print(f" 层: {max_diff_info['key']}")
        print(f" 最大绝对差异: {max_diff_info['max_diff']:.6e}")
        print(f" 平均绝对差异: {max_diff_info['mean_diff']:.6e}")
        print(f" Checkpoint1 最大值: {max_diff_info['tensor1_max']:.6e}")
        print(f" Checkpoint2 最大值: {max_diff_info['tensor2_max']:.6e}")
        print()

    print("关键层的 lora_B 权重变化:")
    print("-" * 80)
    for layer_prefix in key_layers:
        # NOTE: substring match, so e.g. 'transformer_blocks.0' also hits
        # 'single_transformer_blocks.0'; at most two layers are shown per
        # fragment.
        matching = [name for name in layer_diffs if layer_prefix in name]
        for layer_name in matching[:2]:
            info = layer_diffs[layer_name]
            print(f"\n层: {layer_name}")
            print(f" 最大差异: {info['max_diff']:.6e}")
            print(f" 平均差异: {info['mean_diff']:.6e}")

            mean1, max1 = _mean_and_abs_max(state_dict1[info['key']])
            mean2, max2 = _mean_and_abs_max(state_dict2[info['key']])
            print(f" Checkpoint1: mean={mean1:.6e}, max={max1:.6e}")
            print(f" Checkpoint2: mean={mean2:.6e}, max={max2:.6e}")

    print("\n" + "="*80)
    print("lora_B 权重变化最大的前 10 个层:")
    print("-" * 80)
    sorted_layers = sorted(
        layer_diffs.items(), key=lambda item: item[1]['max_diff'], reverse=True
    )

    for i, (layer_name, info) in enumerate(sorted_layers[:_TOP_LAYER_COUNT], 1):
        mean1, max1 = _mean_and_abs_max(state_dict1[info['key']])
        mean2, max2 = _mean_and_abs_max(state_dict2[info['key']])
        print(f"\n{i}. {layer_name}")
        print(f" 最大差异: {info['max_diff']:.6e}, 平均差异: {info['mean_diff']:.6e}")
        print(f" Ckpt1: mean={mean1:.6e}, max={max1:.6e}")
        print(f" Ckpt2: mean={mean2:.6e}, max={max2:.6e}")

    print("\n" + "="*80)
    if different_count == 0:
        print("❌ 严重问题: 两个 checkpoint 完全相同,模型没有学习!")
    elif different_count < total * 0.1:
        print(f"⚠️ 警告: 只有 {different_count/total*100:.2f}% 的参数在变化,可能存在梯度阻塞")
    else:
        print(f"✅ 正常: {different_count/total*100:.2f}% 的参数在变化")
        if max_diff < 1e-6:
            print(f"⚠️ 但是: 最大变化只有 {max_diff:.6e},变化幅度可能太小")
    print("="*80)


def _compare_and_report(path1, path2, key_layers):
    """Load both files, check that their keys match, then diff and report."""
    state_dict1 = load_file(path1)
    state_dict2 = load_file(path2)

    keys1 = set(state_dict1)
    keys2 = set(state_dict2)

    if keys1 != keys2:
        print("⚠️ 警告: 两个 checkpoint 的键不一致!")
        print(f" 只在 checkpoint1 中: {keys1 - keys2}")
        print(f" 只在 checkpoint2 中: {keys2 - keys1}")
        return

    total = len(keys1)
    print(f"✅ 两个 checkpoint 都有 {total} 个参数张量\n")
    if total == 0:
        # Without this guard the percentage lines divide by zero.
        print("⚠️ 警告: checkpoint 中没有任何参数张量,无法对比")
        return

    stats = _collect_diffs(state_dict1, state_dict2, keys1)
    _print_report(state_dict1, state_dict2, total, stats, key_layers)


def compare_safetensors(path1, path2, key_layers=None):
    """Print a report on how the tensors of two safetensors files differ.

    Both files are loaded with ``load_file``. If their key sets differ, the
    keys unique to each file are printed and the comparison stops. Otherwise
    every tensor pair is compared element-wise in float32 and the report
    lists: how many tensors are identical or changed, the tensor with the
    largest absolute change, the ``lora_B`` change of selected layers, the
    ten layers with the largest ``lora_B`` change, and a final verdict.

    Args:
        path1: Path of the first safetensors file.
        path2: Path of the second safetensors file.
        key_layers: Optional iterable of name fragments; layers whose name
            contains a fragment are shown in the "key layers" section (at
            most two per fragment). Defaults to the built-in list.

    Returns:
        None. All results are written to stdout. Any exception raised while
        loading or comparing is reported with its traceback, not propagated.
    """
    if key_layers is None:
        key_layers = _DEFAULT_KEY_LAYERS

    print(f"\n{'='*80}")
    print("对比两个 checkpoint:")
    print(f" Checkpoint 1: {path1}")
    print(f" Checkpoint 2: {path2}")
    print(f"{'='*80}\n")

    try:
        _compare_and_report(path1, path2, key_layers)
    except Exception as e:
        # Diagnostic entry point: report the failure in full rather than
        # propagate it, so callers keep the original never-raises contract.
        print(f"❌ 错误: {e}")
        import traceback
        traceback.print_exc()
|
|
|
|
if __name__ == "__main__":
    # Two optional command-line arguments override the checkpoint paths;
    # with no arguments the module-level CHECKPOINT_1/CHECKPOINT_2 are used.
    cli_paths = sys.argv[1:]
    if not cli_paths:
        compare_safetensors(CHECKPOINT_1, CHECKPOINT_2)
    elif len(cli_paths) == 2:
        compare_safetensors(cli_paths[0], cli_paths[1])
    else:
        # A string argument makes sys.exit print it to stderr and exit 1.
        sys.exit(f"用法: {sys.argv[0]} [CHECKPOINT_1 CHECKPOINT_2]")
|
|