# Spatial-BEATs / debug_v6dc_init.py
# Author: dieKarotte
# Commit 02e364a (verified): "Add files using upload-large-folder tool"
"""
一次性验证 v6dc 初始化是否正确:
1. direct_cls_head.weight 是否从 foa_cls ckpt 加载
2. 加载后在随机一批 fake 数据上的 cls 准确率是否合理(随机期望 1/63≈1.6%,初始化好应>20%)
3. pre_readout_tokens 是否真的传到了 direct_cls_head
运行方式:
python debug_v6dc_init.py
"""
import sys
import torch
sys.path.insert(0, "/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats")
from train_spatial_beats import make_ov1_local_spatial_v6dc_classwarmup_config
from spatial_beats import SpatialBEATs
FOA_CLS_CKPT = "checkpoints/beats_ov1_foa_cls_v1/03_full/best.pt"
print("=" * 60)
print("Step 1: build model config")
cfg_wrapper = make_ov1_local_spatial_v6dc_classwarmup_config()
model_cfg = cfg_wrapper.model
print(f" use_direct_cls={model_cfg.use_direct_cls}")
print(f" readout_layers={model_cfg.readout_layers}")
print(f" bypass_spatial_delta={model_cfg.bypass_spatial_delta}")
print(f" class_finetuned_ckpt={cfg_wrapper.class_finetuned_ckpt}")
print("\nStep 2: build model")
model = SpatialBEATs(model_cfg)
print("\nStep 3: check direct_cls_head exists")
heads = model.local_spatial_prediction_heads
if heads is None:
print(" ERROR: local_spatial_prediction_heads is None!")
sys.exit(1)
if not hasattr(heads, "direct_cls_head"):
print(" ERROR: direct_cls_head does not exist in heads!")
sys.exit(1)
print(f" direct_cls_head: {heads.direct_cls_head}")
w_before = heads.direct_cls_head.weight.data.clone()
print(f" weight norm before load: {w_before.norm().item():.4f}")
print("\nStep 4: load foa_cls checkpoint")
# check ckpt exists
import os
if not os.path.exists(FOA_CLS_CKPT):
print(f" ERROR: ckpt not found: {FOA_CLS_CKPT}")
sys.exit(1)
model.load_event_classifier_checkpoint(FOA_CLS_CKPT)
w_after = heads.direct_cls_head.weight.data.clone()
print(f" weight norm after load: {w_after.norm().item():.4f}")
weight_changed = not torch.allclose(w_before, w_after)
print(f" weight changed: {weight_changed}")
if not weight_changed:
print(" *** PROBLEM: direct_cls_head.weight was NOT loaded! ***")
else:
print(" OK: direct_cls_head.weight was loaded from foa_cls ckpt")
# Step 5: compare direct_cls_head.weight against the checkpoint's own
# classifier.weight, independently of the model's loader.
print("\nStep 5: compare with foa_cls classifier weight directly")
# weights_only=False unpickles arbitrary Python objects: only point this at
# trusted, locally produced checkpoints.
ckpt = torch.load(FOA_CLS_CKPT, map_location="cpu", weights_only=False)
# Training checkpoints may nest the state dict under "model"; otherwise the
# loaded object is used as the state dict itself.
sd = ckpt.get("model", ckpt)
if "classifier.weight" in sd:
    cls_w = sd["classifier.weight"]
    print(f" foa_cls classifier.weight shape: {cls_w.shape}, norm: {cls_w.norm().item():.4f}")
    print(f" direct_cls_head.weight shape: {w_after.shape}, norm: {w_after.norm().item():.4f}")
    # torch.allclose raises on mismatched dtypes and silently broadcasts
    # compatible-but-different shapes. Require identical shapes and compare
    # in float32 on CPU. ("exactly" below means within atol=1e-5.)
    match = cls_w.shape == w_after.shape and torch.allclose(
        cls_w.float(), w_after.detach().float().cpu(), atol=1e-5
    )
    print(f" weights match exactly: {match}")
else:
    print(" WARNING: 'classifier.weight' key not found in ckpt!")
    print(f" Keys in ckpt: {list(sd.keys())[:20]}")
print("\nStep 6: forward pass with fake data to check cls logits")
model.eval()
with torch.no_grad():
B = 4
T = 16000 * 5 # 5 seconds
fake_waveform = torch.randn(B, 4, T)
out = model(fake_waveform)
mp = out.mono_prediction_output
if mp is None:
print(" ERROR: mono_prediction_output is None!")
sys.exit(1)
logits = mp.pred_class_logits # [B, 63]
probs = logits.softmax(dim=-1)
max_prob = probs.max(dim=-1).values
pred_cls = logits.argmax(dim=-1)
print(f" logits shape: {logits.shape}")
print(f" logit range: [{logits.min().item():.2f}, {logits.max().item():.2f}]")
print(f" max prob per sample: {max_prob.tolist()}")
print(f" pred classes: {pred_cls.tolist()}")
# If properly initialized, logits should NOT be near-uniform
# (uniform would be all ~0 logits since random init)
logit_std = logits.std().item()
print(f" logit std: {logit_std:.4f} (random init ≈ small, good init >> 0.1)")
if logit_std < 0.05:
print(" *** WARNING: logit_std very small — direct_cls_head may be outputting near-uniform ***")
print("\nStep 7: verify pre_readout_tokens path")
# Check readout_layers=0 means ShallowTemporalReadout is identity (only LN)
from spatial_modules import ShallowTemporalReadout
tr = model.temporal_readout
print(f" temporal_readout encoder: {tr.encoder}")
if tr.encoder is None:
print(" OK: readout_layers=0, ShallowTemporalReadout is pure LayerNorm")
else:
print(" *** WARNING: readout_layers>0, there is still a Transformer in ShallowTemporalReadout ***")
print("\n" + "=" * 60)
print("SUMMARY:")
print(f" weight_changed (ckpt loaded): {weight_changed}")
print(f" readout_layers=0 (no Transformer): {tr.encoder is None}")
print(f" logit_std: {logit_std:.4f}")
print("=" * 60)