""" 一次性验证 v6dc 初始化是否正确: 1. direct_cls_head.weight 是否从 foa_cls ckpt 加载 2. 加载后在随机一批 fake 数据上的 cls 准确率是否合理(随机期望 1/63≈1.6%,初始化好应>20%) 3. pre_readout_tokens 是否真的传到了 direct_cls_head 运行方式: python debug_v6dc_init.py """ import sys import torch sys.path.insert(0, "/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats") from train_spatial_beats import make_ov1_local_spatial_v6dc_classwarmup_config from spatial_beats import SpatialBEATs FOA_CLS_CKPT = "checkpoints/beats_ov1_foa_cls_v1/03_full/best.pt" print("=" * 60) print("Step 1: build model config") cfg_wrapper = make_ov1_local_spatial_v6dc_classwarmup_config() model_cfg = cfg_wrapper.model print(f" use_direct_cls={model_cfg.use_direct_cls}") print(f" readout_layers={model_cfg.readout_layers}") print(f" bypass_spatial_delta={model_cfg.bypass_spatial_delta}") print(f" class_finetuned_ckpt={cfg_wrapper.class_finetuned_ckpt}") print("\nStep 2: build model") model = SpatialBEATs(model_cfg) print("\nStep 3: check direct_cls_head exists") heads = model.local_spatial_prediction_heads if heads is None: print(" ERROR: local_spatial_prediction_heads is None!") sys.exit(1) if not hasattr(heads, "direct_cls_head"): print(" ERROR: direct_cls_head does not exist in heads!") sys.exit(1) print(f" direct_cls_head: {heads.direct_cls_head}") w_before = heads.direct_cls_head.weight.data.clone() print(f" weight norm before load: {w_before.norm().item():.4f}") print("\nStep 4: load foa_cls checkpoint") # check ckpt exists import os if not os.path.exists(FOA_CLS_CKPT): print(f" ERROR: ckpt not found: {FOA_CLS_CKPT}") sys.exit(1) model.load_event_classifier_checkpoint(FOA_CLS_CKPT) w_after = heads.direct_cls_head.weight.data.clone() print(f" weight norm after load: {w_after.norm().item():.4f}") weight_changed = not torch.allclose(w_before, w_after) print(f" weight changed: {weight_changed}") if not weight_changed: print(" *** PROBLEM: direct_cls_head.weight was NOT loaded! ***") else: print(" OK: direct_cls_head.weight was loaded from foa_cls ckpt") # Also check what foa_cls classifier weight norm is print("\nStep 5: compare with foa_cls classifier weight directly") ckpt = torch.load(FOA_CLS_CKPT, map_location="cpu", weights_only=False) sd = ckpt.get("model", ckpt) if "classifier.weight" in sd: cls_w = sd["classifier.weight"] print(f" foa_cls classifier.weight shape: {cls_w.shape}, norm: {cls_w.norm().item():.4f}") print(f" direct_cls_head.weight shape: {w_after.shape}, norm: {w_after.norm().item():.4f}") match = torch.allclose(cls_w, w_after, atol=1e-5) print(f" weights match exactly: {match}") else: print(" WARNING: 'classifier.weight' key not found in ckpt!") print(f" Keys in ckpt: {list(sd.keys())[:20]}") print("\nStep 6: forward pass with fake data to check cls logits") model.eval() with torch.no_grad(): B = 4 T = 16000 * 5 # 5 seconds fake_waveform = torch.randn(B, 4, T) out = model(fake_waveform) mp = out.mono_prediction_output if mp is None: print(" ERROR: mono_prediction_output is None!") sys.exit(1) logits = mp.pred_class_logits # [B, 63] probs = logits.softmax(dim=-1) max_prob = probs.max(dim=-1).values pred_cls = logits.argmax(dim=-1) print(f" logits shape: {logits.shape}") print(f" logit range: [{logits.min().item():.2f}, {logits.max().item():.2f}]") print(f" max prob per sample: {max_prob.tolist()}") print(f" pred classes: {pred_cls.tolist()}") # If properly initialized, logits should NOT be near-uniform # (uniform would be all ~0 logits since random init) logit_std = logits.std().item() print(f" logit std: {logit_std:.4f} (random init ≈ small, good init >> 0.1)") if logit_std < 0.05: print(" *** WARNING: logit_std very small — direct_cls_head may be outputting near-uniform ***") print("\nStep 7: verify pre_readout_tokens path") # Check readout_layers=0 means ShallowTemporalReadout is identity (only LN) from spatial_modules import ShallowTemporalReadout tr = model.temporal_readout print(f" temporal_readout encoder: {tr.encoder}") if tr.encoder is None: print(" OK: readout_layers=0, ShallowTemporalReadout is pure LayerNorm") else: print(" *** WARNING: readout_layers>0, there is still a Transformer in ShallowTemporalReadout ***") print("\n" + "=" * 60) print("SUMMARY:") print(f" weight_changed (ckpt loaded): {weight_changed}") print(f" readout_layers=0 (no Transformer): {tr.encoder is None}") print(f" logit_std: {logit_std:.4f}") print("=" * 60)