File size: 4,712 Bytes
02e364a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
一次性验证 v6dc 初始化是否正确:
1. direct_cls_head.weight 是否从 foa_cls ckpt 加载
2. 加载后在随机一批 fake 数据上的 cls 准确率是否合理(随机期望 1/63≈1.6%,初始化好应>20%)
3. pre_readout_tokens 是否真的传到了 direct_cls_head

运行方式:
    python debug_v6dc_init.py
"""
import sys
import torch

sys.path.insert(0, "/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats")

from train_spatial_beats import make_ov1_local_spatial_v6dc_classwarmup_config
from spatial_beats import SpatialBEATs

FOA_CLS_CKPT = "checkpoints/beats_ov1_foa_cls_v1/03_full/best.pt"

print("=" * 60)
print("Step 1: build model config")
cfg_wrapper = make_ov1_local_spatial_v6dc_classwarmup_config()
model_cfg = cfg_wrapper.model
print(f"  use_direct_cls={model_cfg.use_direct_cls}")
print(f"  readout_layers={model_cfg.readout_layers}")
print(f"  bypass_spatial_delta={model_cfg.bypass_spatial_delta}")
print(f"  class_finetuned_ckpt={cfg_wrapper.class_finetuned_ckpt}")

print("\nStep 2: build model")
model = SpatialBEATs(model_cfg)

print("\nStep 3: check direct_cls_head exists")
heads = model.local_spatial_prediction_heads
if heads is None:
    print("  ERROR: local_spatial_prediction_heads is None!")
    sys.exit(1)
if not hasattr(heads, "direct_cls_head"):
    print("  ERROR: direct_cls_head does not exist in heads!")
    sys.exit(1)
print(f"  direct_cls_head: {heads.direct_cls_head}")
w_before = heads.direct_cls_head.weight.data.clone()
print(f"  weight norm before load: {w_before.norm().item():.4f}")

print("\nStep 4: load foa_cls checkpoint")
# check ckpt exists
import os
if not os.path.exists(FOA_CLS_CKPT):
    print(f"  ERROR: ckpt not found: {FOA_CLS_CKPT}")
    sys.exit(1)
model.load_event_classifier_checkpoint(FOA_CLS_CKPT)
w_after = heads.direct_cls_head.weight.data.clone()
print(f"  weight norm after load: {w_after.norm().item():.4f}")
weight_changed = not torch.allclose(w_before, w_after)
print(f"  weight changed: {weight_changed}")
if not weight_changed:
    print("  *** PROBLEM: direct_cls_head.weight was NOT loaded! ***")
else:
    print("  OK: direct_cls_head.weight was loaded from foa_cls ckpt")

# Also check what foa_cls classifier weight norm is
print("\nStep 5: compare with foa_cls classifier weight directly")
ckpt = torch.load(FOA_CLS_CKPT, map_location="cpu", weights_only=False)
sd = ckpt.get("model", ckpt)
if "classifier.weight" in sd:
    cls_w = sd["classifier.weight"]
    print(f"  foa_cls classifier.weight shape: {cls_w.shape}, norm: {cls_w.norm().item():.4f}")
    print(f"  direct_cls_head.weight shape: {w_after.shape}, norm: {w_after.norm().item():.4f}")
    match = torch.allclose(cls_w, w_after, atol=1e-5)
    print(f"  weights match exactly: {match}")
else:
    print("  WARNING: 'classifier.weight' key not found in ckpt!")
    print(f"  Keys in ckpt: {list(sd.keys())[:20]}")

print("\nStep 6: forward pass with fake data to check cls logits")
model.eval()
with torch.no_grad():
    B = 4
    T = 16000 * 5  # 5 seconds
    fake_waveform = torch.randn(B, 4, T)
    out = model(fake_waveform)
    mp = out.mono_prediction_output
    if mp is None:
        print("  ERROR: mono_prediction_output is None!")
        sys.exit(1)
    logits = mp.pred_class_logits  # [B, 63]
    probs = logits.softmax(dim=-1)
    max_prob = probs.max(dim=-1).values
    pred_cls = logits.argmax(dim=-1)
    print(f"  logits shape: {logits.shape}")
    print(f"  logit range: [{logits.min().item():.2f}, {logits.max().item():.2f}]")
    print(f"  max prob per sample: {max_prob.tolist()}")
    print(f"  pred classes: {pred_cls.tolist()}")
    # If properly initialized, logits should NOT be near-uniform
    # (uniform would be all ~0 logits since random init)
    logit_std = logits.std().item()
    print(f"  logit std: {logit_std:.4f}  (random init ≈ small, good init >> 0.1)")
    if logit_std < 0.05:
        print("  *** WARNING: logit_std very small — direct_cls_head may be outputting near-uniform ***")

print("\nStep 7: verify pre_readout_tokens path")
# Check readout_layers=0 means ShallowTemporalReadout is identity (only LN)
from spatial_modules import ShallowTemporalReadout
tr = model.temporal_readout
print(f"  temporal_readout encoder: {tr.encoder}")
if tr.encoder is None:
    print("  OK: readout_layers=0, ShallowTemporalReadout is pure LayerNorm")
else:
    print("  *** WARNING: readout_layers>0, there is still a Transformer in ShallowTemporalReadout ***")

print("\n" + "=" * 60)
print("SUMMARY:")
print(f"  weight_changed (ckpt loaded): {weight_changed}")
print(f"  readout_layers=0 (no Transformer): {tr.encoder is None}")
print(f"  logit_std: {logit_std:.4f}")
print("=" * 60)