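"""
One-off check that the v6dc initialisation is correct:
1. Whether direct_cls_head.weight is loaded from the foa_cls checkpoint.
2. Whether the cls accuracy on a random batch of fake data is reasonable
   after loading (chance level is 1/63 ≈ 1.6%; a good initialisation
   should be > 20%).
3. Whether pre_readout_tokens really reach direct_cls_head.

Note: the fake batch in Step 6 has no labels, so the script prints logit
statistics (range, max probability, predicted class, std) rather than an
accuracy figure. Step 7 checks item 3 only by reporting whether
model.temporal_readout.encoder is None.

Usage:
    python debug_v6dc_init.py
"""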
import os
import sys

import torch

# The project imports below are resolved through this path entry, so it must
# be inserted before them.
sys.path.insert(0, "/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats")

from train_spatial_beats import make_ov1_local_spatial_v6dc_classwarmup_config
from spatial_beats import SpatialBEATs
|
|
# Checkpoint path used by Step 4 (load into the model) and Step 5 (read the
# raw state dict for a direct comparison).
FOA_CLS_CKPT = "checkpoints/beats_ov1_foa_cls_v1/03_full/best.pt"

# ---------------------------------------------------------------------------
# Step 1: build the v6dc class-warmup config and echo the fields this check
# depends on.
# ---------------------------------------------------------------------------
print("=" * 60)
print("Step 1: build model config")
cfg_wrapper = make_ov1_local_spatial_v6dc_classwarmup_config()
model_cfg = cfg_wrapper.model
for field_name in ("use_direct_cls", "readout_layers", "bypass_spatial_delta"):
    print(f" {field_name}={getattr(model_cfg, field_name)}")
# This field lives on the wrapper config, not on the model config.
print(f" class_finetuned_ckpt={cfg_wrapper.class_finetuned_ckpt}")

# ---------------------------------------------------------------------------
# Step 2: instantiate the model from the model config.
# ---------------------------------------------------------------------------
print("\nStep 2: build model")
model = SpatialBEATs(model_cfg)
|
|
# ---------------------------------------------------------------------------
# Step 3: make sure the prediction heads expose a usable direct_cls_head and
# snapshot its weight so Step 4 can tell whether loading changed it.
# ---------------------------------------------------------------------------
print("\nStep 3: check direct_cls_head exists")
heads = model.local_spatial_prediction_heads
if heads is None:
    print(" ERROR: local_spatial_prediction_heads is None!")
    sys.exit(1)
# getattr with a default covers both "attribute missing" and "attribute
# present but set to None"; a plain hasattr() check lets the latter through
# and the script would then fail with AttributeError on `.weight`.
direct_cls_head = getattr(heads, "direct_cls_head", None)
if direct_cls_head is None:
    print(" ERROR: direct_cls_head does not exist in heads!")
    sys.exit(1)
print(f" direct_cls_head: {direct_cls_head}")
# detach().clone() gives an independent copy that later in-place loading
# cannot alias.
w_before = direct_cls_head.weight.detach().clone()
print(f" weight norm before load: {w_before.norm().item():.4f}")
|
|
# ---------------------------------------------------------------------------
# Step 4: load the foa_cls checkpoint through the model's own loader and
# check that direct_cls_head.weight actually changed.
# ---------------------------------------------------------------------------
print("\nStep 4: load foa_cls checkpoint")

# Fail early with a readable message instead of a traceback from the loader.
if not os.path.exists(FOA_CLS_CKPT):
    print(f" ERROR: ckpt not found: {FOA_CLS_CKPT}")
    sys.exit(1)
model.load_event_classifier_checkpoint(FOA_CLS_CKPT)

# Re-read the weight from the live module and compare against the Step 3
# snapshot.
w_after = heads.direct_cls_head.weight.data.clone()
print(f" weight norm after load: {w_after.norm().item():.4f}")
weight_changed = not torch.allclose(w_before, w_after)
print(f" weight changed: {weight_changed}")
if not weight_changed:
    print(" *** PROBLEM: direct_cls_head.weight was NOT loaded! ***")
else:
    print(" OK: direct_cls_head.weight was loaded from foa_cls ckpt")
|
|
| |
# ---------------------------------------------------------------------------
# Step 5: read the checkpoint directly and compare its classifier weight with
# the weight now held by direct_cls_head.
# ---------------------------------------------------------------------------
print("\nStep 5: compare with foa_cls classifier weight directly")
# NOTE(review): weights_only=False unpickles arbitrary Python objects. This is
# only safe for checkpoints from a trusted source -- confirm that holds for
# FOA_CLS_CKPT.
ckpt = torch.load(FOA_CLS_CKPT, map_location="cpu", weights_only=False)
# Use the "model" entry when the checkpoint has one, otherwise treat the
# checkpoint itself as the state dict.
sd = ckpt.get("model", ckpt)
if "classifier.weight" in sd:
    cls_w = sd["classifier.weight"]
    print(f" foa_cls classifier.weight shape: {cls_w.shape}, norm: {cls_w.norm().item():.4f}")
    print(f" direct_cls_head.weight shape: {w_after.shape}, norm: {w_after.norm().item():.4f}")
    if cls_w.shape != w_after.shape:
        # torch.allclose broadcasts its inputs: incompatible shapes raise,
        # and compatible-but-different shapes could report a false match.
        print(" *** PROBLEM: shape mismatch, weights cannot be identical ***")
        match = False
    else:
        # torch.allclose also requires matching dtypes, so align dtype and
        # device with the live weight before comparing.
        match = torch.allclose(
            cls_w.to(device=w_after.device, dtype=w_after.dtype),
            w_after,
            atol=1e-5,
        )
    print(f" weights match exactly: {match}")
else:
    print(" WARNING: 'classifier.weight' key not found in ckpt!")
    # Show a sample of the keys to help locate the right one.
    print(f" Keys in ckpt: {list(sd)[:20]}")
|
|
# ---------------------------------------------------------------------------
# Step 6: run one forward pass on random input and inspect the class logits.
# ---------------------------------------------------------------------------
print("\nStep 6: forward pass with fake data to check cls logits")
model.eval()
with torch.no_grad():
    # Random batch of shape (4, 4, 80000).
    # NOTE(review): presumably 4 channels x 5 s at 16 kHz -- confirm.
    batch_size, num_samples = 4, 16000 * 5
    fake_waveform = torch.randn(batch_size, 4, num_samples)
    outputs = model(fake_waveform)

    mono_pred = outputs.mono_prediction_output
    if mono_pred is None:
        print(" ERROR: mono_prediction_output is None!")
        sys.exit(1)

    cls_logits = mono_pred.pred_class_logits
    cls_probs = cls_logits.softmax(dim=-1)
    top_prob, _ = cls_probs.max(dim=-1)
    top_class = cls_logits.argmax(dim=-1)

    lo, hi = cls_logits.min().item(), cls_logits.max().item()
    print(f" logits shape: {cls_logits.shape}")
    print(f" logit range: [{lo:.2f}, {hi:.2f}]")
    print(f" max prob per sample: {top_prob.tolist()}")
    print(f" pred classes: {top_class.tolist()}")

    # The spread of the logits is the quick signal for whether the head
    # produces differentiated outputs; it is reused in the final summary.
    logit_std = cls_logits.std().item()
    print(f" logit std: {logit_std:.4f} (random init ≈ small, good init >> 0.1)")
    if logit_std < 0.05:
        print(" *** WARNING: logit_std very small — direct_cls_head may be outputting near-uniform ***")
|
|
# ---------------------------------------------------------------------------
# Step 7: report whether the temporal readout still has an encoder.
# ---------------------------------------------------------------------------
print("\nStep 7: verify pre_readout_tokens path")

# Kept from the original script; the name is not referenced below.
from spatial_modules import ShallowTemporalReadout

tr = model.temporal_readout
print(f" temporal_readout encoder: {tr.encoder}")
readout_verdict = (
    " OK: readout_layers=0, ShallowTemporalReadout is pure LayerNorm"
    if tr.encoder is None
    else " *** WARNING: readout_layers>0, there is still a Transformer in ShallowTemporalReadout ***"
)
print(readout_verdict)
|
|
# ---------------------------------------------------------------------------
# Final summary: the three headline results from Steps 4, 7 and 6.
# ---------------------------------------------------------------------------
summary_lines = [
    "SUMMARY:",
    f" weight_changed (ckpt loaded): {weight_changed}",
    f" readout_layers=0 (no Transformer): {tr.encoder is None}",
    f" logit_std: {logit_std:.4f}",
]
print("\n" + "=" * 60)
print("\n".join(summary_lines))
print("=" * 60)
|
|