Spatial-BEATs DOA Train/Valid Gap Analysis
Executive Summary
After analyzing the Spatial-BEATs codebase, I've identified six mechanisms that likely contribute to the large train/validation gap in DOA (Direction of Arrival) prediction:
1. Hungarian Matching on Detached Predictions: Training matches on detached predictions, so gradients don't flow through the matching cost. Validation runs the same matching logic on outputs computed without gradients. The matcher itself therefore behaves identically on both splits, and on its own this is a minor factor.
2. No Spatial Data Augmentation (Rotations): The only augmentation applied to FOA audio is SpecAugment (spectral masking), and only during training. There is zero spatial augmentation (rotations), even though direction is a supervision target.
3. SpecAugment Train-Only Application: Spectral augmentation runs only when `model.training=True`. The acoustic front-end therefore sees a different input distribution during training than during validation.
4. Direction Loss Is Pure Regression: DOA is supervised via cosine distance (1 - cos_sim) on continuous 3D unit vectors, NOT binned classification. Small angular errors produce a tiny direction term, so the class NLL easily swamps it in the Hungarian matching cost.
5. Matching Cost Weights Decoupled from Loss: `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight` control Hungarian matching but do NOT affect gradient flow. Class-focused v10 presets set both to 0.0. Matching then becomes class-dominant and blind to spatial signals, at training and validation alike.
6. No Curriculum Scheduling for Matching: The v9 schedule ramps DOA loss weights from 0 → full over epochs, but the matching cost weights stay constant. This creates a mismatch between what is optimized (loss) and how assignments are made (matching cost).
Part 1: Data Pipeline & Augmentation
1.1 FOA Loading (Clean, No Issues)
Location: spatial_dataset.py:310-366
FOA waveforms are loaded correctly:
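- 4-channel ordering: `[W, X, Y, Z]` (described as the First Order Ambisonics DCASE convention). Note that DCASE SELD FOA releases use ACN order (W, Y, Z, X) with SN3D normalization, so confirm the loader reorders channels if `[W, X, Y, Z]` is assumed downstream.
- Handles multiple audio libraries (soundfile, scipy, wave)
- Returns `[4, T]` tensors, normalized and cast to float32
- No implicit spatial transformation during loading ✓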
1.2 Data Augmentation: CRITICAL FINDINGS
A. SpecAugment is Train-Only
Location: spatial_atst.py:336-347 / spatial_modules.py:254-265
```python
def _apply_spec_augment_w(self, w_logmel: Tensor) -> Tensor:
    """SpecAugment on [B, 1, T_f, F] W channel (training only)."""
    if not self.training:  # ← HARD GATE ON model.training
        return w_logmel
    # Apply frequency + time masking...
```
Impact:
- Training sees `spec_augment_freq_masks=2, freq_width=27, time_masks=2, time_width=100`
- Validation sees NO masking
- This causes an acoustic feature mismatch between train and val
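The train-only gate can be illustrated with a minimal pure-Python sketch. It assumes a `[T, F]` log-mel stored as nested lists and zero-fill masking. The real module works on `[B, 1, T_f, F]` tensors, and the excerpt does not show its fill value or width sampling. `spec_augment` is a hypothetical name, and its defaults mirror the preset values above.

```python
import random
from typing import List

def spec_augment(logmel: List[List[float]], training: bool, rng: random.Random,
                 freq_masks: int = 2, freq_width: int = 27,
                 time_masks: int = 2, time_width: int = 100,
                 fill: float = 0.0) -> List[List[float]]:
    """Mask random frequency bands and time spans of a [T, F] log-mel (training only)."""
    if not training or not logmel:
        return logmel  # eval path: features pass through unchanged
    out = [row[:] for row in logmel]  # copy so the caller's features are not mutated
    n_t, n_f = len(out), len(out[0])
    for _ in range(freq_masks):
        w = rng.randint(0, min(freq_width, n_f))  # mask width in bins (inclusive range)
        f0 = rng.randint(0, n_f - w)
        for row in out:
            row[f0:f0 + w] = [fill] * w
    for _ in range(time_masks):
        w = rng.randint(0, min(time_width, n_t))  # mask width in frames
        t0 = rng.randint(0, n_t - w)
        for t in range(t0, t0 + w):
            out[t] = [fill] * n_f
    return out
```

Validation calls take the `training=False` branch. Train and validation features therefore differ only through this gate.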
B. NO Spatial Augmentation (Rotations)
Finding: Searched the entire codebase for rotation, flip, and other spatial augmentation transforms:
- spatial_dataset.py: only FOA loading, no rotations
- spatial_atst.py: only SpecAugment (spectral), no spatial transforms
- spatial_beats.py: only SpecAugment, no rotations
Expected: FOA audio should support random 3D rotations to:
- Augment training DOA diversity
- Help the model generalize to unseen azimuth/elevation combinations

This is completely absent.
Why this matters:
- Class labels are invariant under rotation (a dog is a dog at any angle)
- DOA targets change under rotation (e.g., azimuth 0° becomes 90° under a 90° yaw rotation)
- Without rotation augmentation, each training clip always presents the same DOA. The model only ever sees the fixed set of directions in the training data, once per epoch.
- If validation contains angle combinations absent from training, this produces a large gap
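The simplest case of this augmentation is a yaw (azimuth-only) rotation. The sketch below assumes the `[W, X, Y, Z]` layout from 1.1, with X front, Y left, Z up, and azimuth measured counter-clockwise from the front. If the files are ACN-ordered (W, Y, Z, X), reorder the channels first. `rotate_foa_yaw` is hypothetical and not part of the codebase. A full 3D rotation would apply a 3×3 rotation matrix to (X, Y, Z) and would also have to update elevation.

```python
import math
from typing import List, Tuple

def rotate_foa_yaw(foa: List[List[float]], azimuth_deg: float,
                   phi_deg: float) -> Tuple[List[List[float]], float]:
    """Rotate a [4, T] FOA signal ([W, X, Y, Z]) about the vertical axis by phi_deg.

    Assumed convention: a source at azimuth a (elevation e) has X ~ cos(a)cos(e)
    and Y ~ sin(a)cos(e). W and Z, and therefore elevation, are unchanged by yaw.
    """
    w, x, y, z = foa
    c, s = math.cos(math.radians(phi_deg)), math.sin(math.radians(phi_deg))
    x_rot = [c * xi - s * yi for xi, yi in zip(x, y)]  # -> cos(a + phi)cos(e)
    y_rot = [s * xi + c * yi for xi, yi in zip(x, y)]  # -> sin(a + phi)cos(e)
    # Shift the azimuth target by the same angle and wrap into [-180, 180).
    new_azi = (azimuth_deg + phi_deg + 180.0) % 360.0 - 180.0
    return [list(w), x_rot, y_rot, list(z)], new_azi
```

In a multi-source clip, every source's azimuth target must shift by the same `phi_deg` as the audio.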
Part 2: Loss Computation
2.1 Hungarian Matching Overview
Route: spatial_loss.py:1452-1509 (compute_frame_slot_losses uses Route A)
```python
def _match_frame_slots_per_step(
    prediction_output: FrameSlotPredictionOutput,
    batch: "SpatialBatch",
    ...
) -> Tensor:
    """Return matched slot index per (b, gt, t): [B, N_gt, T_s] with -1 when unset."""
```
For each frame independently:
- Compute cost matrix for active GT sources × K slots
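- Solve the assignment by brute-force enumeration of GT→slot permutations. This is cheap because K ≤ 4, and it reaches the same optimum the Hungarian algorithm would.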
- Return matched slot indices
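For K ≤ 4, the per-frame assignment can be solved exactly by enumerating injective GT→slot mappings. This sketch assumes an `[N_gt, K]` cost matrix stored as nested lists. `brute_force_assign` is hypothetical and is not the codebase's implementation.

```python
from itertools import permutations
from typing import List, Sequence

def brute_force_assign(cost: Sequence[Sequence[float]]) -> List[int]:
    """Return, for each GT row, the slot index that minimises total cost (N_gt <= K)."""
    n_gt = len(cost)
    if n_gt == 0:
        return []
    k = len(cost[0])
    best, best_total = [], float("inf")
    # At most K!/(K-N_gt)! = 24 candidates when K = 4.
    for perm in permutations(range(k), n_gt):
        total = sum(cost[g][perm[g]] for g in range(n_gt))
        if total < best_total:
            best, best_total = list(perm), total
    return best
```

Greedy per-row matching would send GT 0 to slot 0 in the cost matrix `[[0.1, 0.2], [0.1, 1.0]]`, for a total of 1.1. Exhaustive search finds `[1, 0]`, for a total of 0.3. For larger K, `scipy.optimize.linear_sum_assignment` solves the same problem in polynomial time.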
2.2 Matching Cost Formulation (THE CORE ASYMMETRY)
Location: spatial_loss.py:1491-1504
```python
# Line 1491: Activity cost
act_cost = 1.0 - torch.sigmoid(pred_activity[b, t])  # [K]
# Line 1496: Class NLL (per GT, per slot)
cls_nll = -F.log_softmax(pred_class[b, t], dim=-1)[:, gt_class]  # [K]
# Lines 1497-1500: Direction cosine distance
dir_cos = (pred_direction[b, t] * target_direction[b, gt_idx].unsqueeze(0)).sum(dim=-1)
dir_cost = 1.0 - dir_cos  # [K]
# Lines 1501-1503: Distance L1
dist_cost = torch.abs(pred_distance[b, t] - target_distance[b, gt_idx])  # [K]
# Line 1504: FINAL COST (unweighted sum)
cost[row] = act_cost + cls_nll + dir_cost + dist_cost
```
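Critical Issue: As excerpted, this is a FIXED SUM with no loss-weight scaling. The `frame_match_dir_cost_weight` / `frame_match_dist_cost_weight` fields (section 3.3) presumably multiply `dir_cost` and `dist_cost`. The excerpt does not show this, so confirm it against the source. The terms also sit on very different scales:
- `act_cost`: weight 1.0, range (0, 1)
- `cls_nll`: weight 1.0, a negative log-probability, unbounded above (≈ 0-10 in practice)
- `dir_cost`: weight 1.0, range [0, 2] for unit vectors
- `dist_cost`: weight 1.0, unbounded L1 error in distance units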
But the training loss is computed separately with configurable lambdas:
```python
# Lines 1578-1582: LOSS (has lambda weights)
loss_total = (
    config.lambda_frame_activity * loss_activity
    + config.lambda_frame_class * loss_class
    + config.lambda_frame_direction * loss_direction  # ← Can be 0 (v10)!
    + config.lambda_frame_distance * loss_distance  # ← Can be 0 (v10)!
    + config.lambda_clip_aux * loss_clip
)
```
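A per-(GT, slot) matching cost in pure Python makes the cost/loss decoupling concrete. The sketch assumes that `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight` (section 3.3) multiply `dir_cost` and `dist_cost`, which the excerpt above does not show. `slot_match_cost` is hypothetical. No `lambda_frame_*` weight appears here, because those only scale `loss_total`.

```python
import math
from typing import Sequence

def slot_match_cost(act_logit: float, class_logits: Sequence[float], gt_class: int,
                    pred_dir: Sequence[float], gt_dir: Sequence[float],
                    pred_dist: float, gt_dist: float,
                    w_dir: float = 1.0, w_dist: float = 1.0) -> float:
    """Matching cost for one (GT, slot) pair; w_dir/w_dist play the frame_match_* role."""
    act_cost = 1.0 - 1.0 / (1.0 + math.exp(-act_logit))      # 1 - sigmoid
    m = max(class_logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in class_logits))
    cls_nll = log_z - class_logits[gt_class]                  # -log_softmax[gt_class]
    dir_cost = 1.0 - sum(p * g for p, g in zip(pred_dir, gt_dir))  # unit vectors assumed
    dist_cost = abs(pred_dist - gt_dist)
    return act_cost + cls_nll + w_dir * dir_cost + w_dist * dist_cost

# Slot A points at the source but is unsure of the class;
# slot B is confident in the class but points 90 degrees away.
gt_dir = (1.0, 0.0, 0.0)
for w_dir in (1.0, 0.0):
    a = slot_match_cost(0.0, [0.0, 0.0], 0, (1.0, 0.0, 0.0), gt_dir, 1.0, 1.0, w_dir=w_dir)
    b = slot_match_cost(0.0, [3.0, 0.0], 0, (0.0, 1.0, 0.0), gt_dir, 1.0, 1.0, w_dir=w_dir)
    print(f"w_dir={w_dir}: cost A={a:.3f}, cost B={b:.3f} -> pick {'A' if a < b else 'B'}")
```

With the direction weight at 1.0, the well-aimed slot A wins. At 0.0 (the v10 phase-1 setting), the matcher picks slot B even though it points 90° away.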
2.3 Direction Loss Formulation (REGRESSION, NOT BINNING)
Location: spatial_loss.py:1562-1565
```python
pred_dir_sel = prediction_output.pred_direction[idx_b_m, idx_t_m, idx_k_m]
tgt_dir_sel = targets["source_direction"][idx_b_m, idx_gt_m].to(pred_dir_sel.dtype)
pred_dir_sel = F.normalize(pred_dir_sel, dim=-1)
loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()
```
Key Properties:
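- Predicts 3D unit direction vectors (not azimuth/elevation bins)
- Loss = 1 - cos θ, where θ is the angle between prediction and target. This is continuous regression, not classification. The loss is monotonic in θ but not proportional to it: it behaves like θ²/2 for small θ (θ in radians).
- Weak signal for small errors:
  - 5° error → cos_sim ≈ 0.996 → loss ≈ 0.004
  - In matching, this is dwarfed by class NLL values of ≈ 2-10, so direction barely influences assignments unless the angular error is large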
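The numbers above follow directly from 1 - cos θ. This short sketch tabulates the loss against the small-angle approximation θ²/2. It is pure Python and uses no codebase names.

```python
import math

def cosine_direction_loss(angle_deg: float) -> float:
    """1 - cos(theta) for two unit vectors separated by angle_deg degrees."""
    return 1.0 - math.cos(math.radians(angle_deg))

for angle in (1, 5, 20, 45, 90, 180):
    loss = cosine_direction_loss(angle)
    approx = math.radians(angle) ** 2 / 2  # small-angle approximation theta^2 / 2
    print(f"{angle:>3} deg  loss={loss:.4f}  theta^2/2={approx:.4f}")
```

The gradient with respect to θ is sin θ, which also vanishes as θ → 0. Near-correct directions therefore get both a small cost and a small gradient.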
Part 3: Training Configuration (v9/v10)
3.1 v9 Configuration (Baseline)
Location: train_spatial_beats.py:2228-2278
```python
def make_ov1_local_spatial_v9_ov123_top4_config(...):
    cfg = make_ov1_local_spatial_v8a_ov123_top4_config(...)
    # v8a already has:
    #   cfg.loss.use_segment_matching = True
    #   cfg.frame_spatial_loss_warmup_epochs = 3  # epochs 0-2: lambda_*=0
    #   cfg.frame_spatial_loss_ramp_epochs = 4    # epochs 3-6: ramp 0 → full
    # v9 adds:
    cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS)
    cfg.loss.frame_class_ontology_smoothing = 0.1
    cfg.model.use_class_head_mlp_residual = True
    cfg.model.use_class_head_demixer = True
    cfg.class_head_lr_scale = 0.3  # Class head LR scaled to 0.3x (reduced, not frozen)
    cfg.class_head_freeze_during_ramp_epochs = 4  # Per the name: class head frozen for 4 epochs during the DOA ramp
    cfg.num_epochs = 12
```
Loss weights (inherited from v8a; not set explicitly, so defaults apply):
- `lambda_frame_direction = 1.0` (from `SpatialLossConfig`, line 67)
- `lambda_frame_distance = 1.0` (default)
- `lambda_frame_activity = 1.0` (default)
- `lambda_frame_class = 1.0` (default)
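The v8a settings above (3 warmup epochs, 4 ramp epochs) imply a warmup/ramp multiplier like the one sketched here. The sketch assumes a linear ramp that reaches full weight at epoch 6, matching the `epochs 3-6` comment. The exact formula in train_spatial_beats.py is not shown, and `spatial_loss_scale` is a hypothetical name.

```python
def spatial_loss_scale(epoch: int, warmup_epochs: int = 3, ramp_epochs: int = 4) -> float:
    """Multiplier applied to lambda_frame_* at a 0-indexed epoch (assumed linear ramp)."""
    if epoch < warmup_epochs:
        return 0.0  # warmup: spatial losses off
    if ramp_epochs <= 0:
        return 1.0  # no ramp configured: full weight immediately
    # Ramp epochs warmup .. warmup+ramp-1 climb in equal steps to 1.0.
    return min(1.0, (epoch - warmup_epochs + 1) / ramp_epochs)

for epoch in range(8):
    print(epoch, spatial_loss_scale(epoch))
```

Only the loss lambdas follow this schedule. Per section 3.3, the matching cost weights stay at their configured values throughout. With the v10 phase-1 settings (warmup 0, ramp 0), the multiplier is 1.0 from the start, but `lambda_frame_direction = 0.0` keeps the effective weight at zero.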
3.2 v10 Phase-1 Configuration (SPATIAL FREEZE)
Location: train_spatial_beats.py:2394-2478
```python
def make_ov1_local_spatial_v10_phase1_cls_config(...):
    cfg = make_ov1_local_spatial_v9_ov123_top4_config(...)
    # --- CRITICAL: Spatial loss ZEROED ---
    cfg.loss.lambda_frame_direction = 0.0  # ← DOA LOSS DISABLED
    cfg.loss.lambda_frame_distance = 0.0  # ← DISTANCE LOSS DISABLED
    cfg.loss.lambda_frame_activity = 0.5
    # --- Matching cost ALSO drops the spatial terms ---
    cfg.loss.frame_match_dir_cost_weight = 0.0  # ← Matching ALSO ignores DOA
    cfg.loss.frame_match_dist_cost_weight = 0.0  # ← Matching ALSO ignores distance
    # --- Spatial heads are frozen at parameter level ---
    cfg.freeze_frame_track_spatial_heads = True  # ← PARAMETER FREEZE
    # --- Disable DOA warmup/ramp entirely ---
    cfg.frame_spatial_loss_warmup_epochs = 0  # No warmup
    cfg.frame_spatial_loss_ramp_epochs = 0  # No ramp
```
Impact:
- Spatial prediction heads receive no gradients (frozen and zero loss weight)
- Matching uses only activity and class costs (the class NLL term dominates)
- At v10 epoch 3, these heads are still untrained for multi-source ov2/ov3
- When unfrozen in phase-2, they have to learn DOA from scratch alongside an already-converged class head
3.3 Matching Cost Weights Are DECOUPLED from Loss Weights
Location: spatial_loss.py:95-100
```python
# Hungarian-cost dir/dist weights — decoupled from lambda_frame_direction/
# lambda_frame_distance so that cost and loss can be controlled independently.
# Set to 0.0 during class-warmup stage so DOA noise does not pollute matching.
# Default 1.0 preserves existing behavior.
frame_match_dir_cost_weight: float = 1.0
frame_match_dist_cost_weight: float = 1.0
```
Problem: In v10, both are set to 0.0:
- Matching becomes purely class-based (plus activity)
- The direction training loss is also 0.0
- This is redundant for training, but misleading for validation

At validation time:
- Direction head parameters are NOT updated (frozen in v10 phase-1)
- Validation matching also runs with `frame_match_dir_cost_weight=0.0`, so direction plays no part in the assignment
- Validation DOA metrics are then computed on outdated direction predictions (head weights from v9), evaluated on whichever slots class-based matching picked
Part 4: Validation Metrics
4.1 How Validation DOA Metrics Are Computed
Location: spatial_loss.py:1596-1661
```python
def compute_frame_slot_validation_metrics(
    prediction_output: FrameSlotPredictionOutput,
    batch: "SpatialBatch",
    temporal_padding_mask: Optional[Tensor],
    config: SpatialLossConfig,
) -> FrameMetricOutput:
    # ... compute matched_slot using same _match_frame_slots_per_step ...
    matched_slot = _match_frame_slots_per_step(...)  # ← SAME MATCHING AS TRAINING
    # Lines 1642-1660: If matched, compute angle metrics
    if valid_assign.any():
        idx_b_m, idx_gt_m, idx_t_m = torch.nonzero(valid_assign, as_tuple=True)
        idx_k_m = matched_slot[idx_b_m, idx_gt_m, idx_t_m]
        pred_dir = F.normalize(prediction_output.pred_direction[idx_b_m, idx_t_m, idx_k_m], dim=-1)
        pred_azi_deg, pred_ele_deg = _azi_ele_deg_from_direction_vector(pred_dir)
        azi_tgt = targets["source_azimuth_deg"][idx_b_m, idx_gt_m].to(pred_azi_deg.dtype)
        ele_tgt = targets["source_elevation_deg"][idx_b_m, idx_gt_m].to(pred_ele_deg.dtype)
        azi_mae = _circular_distance_deg(pred_azi_deg, azi_tgt).mean()
        ele_mae = torch.abs(pred_ele_deg - ele_tgt).mean()
```
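Train vs. validation matching:
Training matching calls `detach()` on predictions (lines 1465-1468):
```python
pred_activity = prediction_output.pred_activity.detach()
pred_class = prediction_output.pred_class_logits.detach()
pred_direction = prediction_output.pred_direction.detach()
pred_distance = prediction_output.pred_distance.detach()
```
As a result, gradients don't flow through the matching decision.
Validation matching uses the same code on outputs produced in eval mode, where no gradients are used anyway. The matching procedure is therefore identical across splits. What differs is its input (e.g., no SpecAugment at validation) and, during training only, the fact that the chosen assignments decide which slots receive gradient. Any train/val difference attributed to matching comes through those inputs, not through the matcher itself.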
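The bodies of `_azi_ele_deg_from_direction_vector` and `_circular_distance_deg` are not shown in the excerpt. The sketch below gives hypothetical scalar stand-ins for them, assuming x front, y left, z up and azimuth counter-clockwise from the front. `azi_ele_deg` and `circular_distance_deg` are illustrative names.

```python
import math
from typing import Sequence, Tuple

def azi_ele_deg(direction: Sequence[float]) -> Tuple[float, float]:
    """Azimuth/elevation in degrees from a 3D vector (assumed axes: x front, y left, z up)."""
    x, y, z = direction
    norm = math.sqrt(x * x + y * y + z * z)
    x, y, z = x / norm, y / norm, z / norm                  # same role as F.normalize
    azi = math.degrees(math.atan2(y, x))                    # in [-180, 180]
    ele = math.degrees(math.asin(max(-1.0, min(1.0, z))))   # clamp guards rounding error
    return azi, ele

def circular_distance_deg(a: float, b: float) -> float:
    """Smallest absolute difference between two azimuths, in [0, 180] degrees."""
    d = abs(a - b) % 360.0
    return min(d, 360.0 - d)

print(circular_distance_deg(170.0, -170.0))  # 20.0, not 340.0
```

The wrap-around keeps a 170° vs. -170° prediction from being scored as a 340° azimuth error. Elevation needs no wrap because it lives in [-90°, 90°].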
Part 5: Mismatch Between Training Loss & Matching
5.1 The Core Problem
| Aspect | Training | Validation |
|---|---|---|
| SpecAugment | Applied (only W channel) | Not applied |
| DOA Loss | λ_dir ∈ {0.0 (v10 p1), 1.0 (v9)} | Not optimized (predictions feed matching and metrics only) |
| Matching Cost Dir Weight | frame_match_dir_cost_weight ∈ {0.0, 1.0} | Same value; model weights fixed (eval mode) |
| Direction Predictions | Receive gradients (v9) or none (v10 p1, frozen) | Static outputs (no gradients) |
| Spatial Augmentation | NONE | NONE |
| Matching Logic | Uses detached predictions | Uses same detached logic |
5.2 v10 Phase-1 Specific Issue
During v10 phase-1:
- Direction head is frozen (`requires_grad=False`)
- DOA loss is zero (`lambda_frame_direction = 0.0`)
- Matching cost weight is zero (`frame_match_dir_cost_weight = 0.0`)
- Therefore, the direction head gets no signal whatsoever
At v10 phase-1 validation:
- Direction head parameters are completely stale (from v9 epoch 3)
- Matching ignores direction outputs (weight 0.0), so slots are assigned on class and activity alone
- Direction metrics are computed on these frozen-head predictions. If the shared backbone keeps updating under the class loss, the frozen head's inputs drift as well.
- Result: Validation DOA metrics tank, even though the direction head was only expected to stay flat
Part 6: Why Train/Valid Gap is Large
Root Causes (Ranked by Severity; impact percentages below are rough estimates, not measurements)
1. No Spatial Data Augmentation ⚠️⚠️⚠️
- Impact: ~40-60% of DOA variance unexplained
- Mechanism: Training sees limited DOA combinations; validation has unseen angles
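- Evidence: All presets default to `spec_augment_*` only; there is zero rotation support
- Fix needed: Add random rotations that:
  - Rotate the FOA sound field in 3D (W is rotation-invariant; X, Y, Z are rotated)
  - Update azimuth/elevation targets accordingly
  - Are applied randomly during training. If also applied at validation, use a fixed seed and report those metrics separately from the unrotated ones so results stay comparable.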
2. SpecAugment Train-Only ⚠️⚠️
- Impact: ~10-20% of acoustic feature variance
- Mechanism: Training acoustic features differ from validation
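- Evidence: `if not self.training: return w_logmel` in `_apply_spec_augment_w`
- Caveat: Masking makes training inputs harder, so on its own it tends to make train metrics look worse than validation, not better. If validation DOA is the worse side of the gap, this factor is unlikely to be the main driver.
- Fix needed: Either:
  - Disable SpecAugment entirely (simpler, possibly worse generalization)
  - Apply the same SpecAugment (fixed seed) at validation (this corrupts validation inputs, so metrics no longer reflect clean-input performance)
  - Reduce SpecAugment strength to shrink the gap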
3. v10 Phase-1 Freezes Direction Head ⚠️⚠️
- Impact: ~30-40% on ov2/ov3 DOA metrics
- Mechanism: The direction head learns only up to v9 epoch 3 (before the DOA ramp). It is then frozen for 10 epochs and thawed from that stale initialization.
- Evidence: `freeze_frame_track_spatial_heads = True`, `lambda_frame_direction = 0.0`
- Fix needed:
  - Extend the DOA ramp into v10 phase-1 instead of a full freeze
  - Or initialize the direction head better when unfreezing
4. Continuous DOA Regression Sensitivity ⚠️
- Impact: ~5-15% (only when matching is poor)
- Mechanism: The cosine-distance term is small for small angular errors (≈ 0.004 at 5°), so the class NLL outweighs it in the matching cost. During class warmup, the matching cost ignores spatial terms entirely.
- Evidence: `loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()`, with no binning
- Fix needed: None (by design); the problem is upstream, in matching
5. Matching Logic Uses Detached Predictions ⚠️ (Minor)
- Impact: ~2-5%
- Mechanism: Gradients don't flow through matching decision; optimization is indirect
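- Evidence: Lines 1465-1468 use `.detach()`
- Note: This is intentional and acceptable. The assignment is a discrete argmin with no useful gradient, and detaching keeps the cost computation out of the autograd graph. DETR-style matchers do the same by running under `torch.no_grad()`.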
Part 7: Actionable Diagnostics
What to Check in Your Logs
DOA Metrics by Epoch:
- v9: DOA should improve for epochs 3-7 (ramp phase)
- v10 p1: DOA should stay flat (frozen)
- v10 p2: DOA should improve again (unfrozen)
- Red flag: DOA regressing at val while training loss decreases
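This small helper automates the red-flag check above. It assumes you have parsed per-epoch training loss (ideally the direction term alone) and validation azimuth MAE from your logs into lists. `doa_red_flag_epochs` is hypothetical, and the example values are made up for illustration.

```python
from typing import List, Sequence

def doa_red_flag_epochs(train_loss: Sequence[float], val_azi_mae: Sequence[float],
                        tol_deg: float = 0.0) -> List[int]:
    """Epochs where training loss fell but validation azimuth MAE rose by more than tol_deg."""
    flagged = []
    for e in range(1, min(len(train_loss), len(val_azi_mae))):
        loss_down = train_loss[e] < train_loss[e - 1]
        mae_up = (val_azi_mae[e] - val_azi_mae[e - 1]) > tol_deg
        if loss_down and mae_up:
            flagged.append(e)
    return flagged

# Made-up per-epoch values; replace with numbers parsed from your own logs.
print(doa_red_flag_epochs([1.0, 0.9, 0.8, 0.85, 0.7], [30.0, 28.0, 31.0, 33.0, 35.0]))  # [2, 4]
```

A nonzero `tol_deg` (e.g., 1°) filters out epoch-to-epoch noise in the validation metric.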
Per-Source Breakdown:
- ov1 (single-source): should have small train/val gap (~5°)
- ov2 (2-source): should have medium gap (~10°)
- ov3 (3-source): should have large gap (~15-20°)
- Red flag: ov1 gap > 10° suggests augmentation issue
Activity vs DOA Alignment:
- If activity_recall is high but azi_mae is large, matching is finding the sources but estimating the wrong angles
- Check: is `frame_match_dir_cost_weight` being applied correctly?
Class Accuracy vs DOA:
- Plot class_acc vs azi_mae per epoch
- If class accuracy plateaus at epoch 3 but azi_mae keeps dropping, the DOA head is learning separately from the class head and still has room to improve
- If both plateau together, that suggests a shared-representation bottleneck
Part 8: Recommended Fixes (Priority Order)
Immediate (High Impact, Low Risk)
Add random rotation augmentation
- Rotate FOA waveform + update (azimuth, elevation) targets
- Apply randomly during training. If also applied at validation, use a fixed seed and report those metrics separately from the unrotated ones.
- Expected gain (estimate): 10-20° DOA@20 improvement on ov3
Disable SpecAugment or make it train/val consistent
- Option A: Remove SpecAugment (simplest)
- Option B: Apply weak SpecAugment to both train and val
- Expected gain: 2-5° improvement
Medium (Moderate Impact, Medium Risk)
Extend v10 phase-1 DOA ramp instead of full freeze
- Don't freeze the direction head; instead use a lower `lambda_frame_direction` (e.g., 0.1)
- Keep the matching cost weight at 0.0 (class-focused matching is intentional)
- Expected gain: 8-15° on ov2/ov3
Better initialization of direction head after freeze
- When unfreezing in v10 phase-2, use momentum averaging of v9 predictions
- Or apply small warmup on direction loss before full scale
Advanced (High Impact, High Risk)
Switch DOA representation from regression to soft-label binning
- Would match azimuth classification approach already in codebase
- Requires architectural changes to prediction heads
- Expected gain: 5-10° (less sensitive to outliers)
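The soft-label targets could be built as in this sketch. It assumes 36 azimuth bins of 10° and a Gaussian with σ = 10° over wrap-around distance; both are illustrative choices, not values from the codebase. The head would then be trained with cross-entropy against these targets. `azimuth_soft_labels` is hypothetical, and elevation would need its own (non-circular) bins.

```python
import math
from typing import List

def azimuth_soft_labels(azi_deg: float, n_bins: int = 36, sigma_deg: float = 10.0) -> List[float]:
    """Gaussian soft target over azimuth bins, using wrap-around (circular) distance."""
    width = 360.0 / n_bins
    weights = []
    for i in range(n_bins):
        center = -180.0 + (i + 0.5) * width   # bin centres span (-180, 180)
        d = abs(azi_deg - center) % 360.0
        d = min(d, 360.0 - d)                 # circular distance in degrees
        weights.append(math.exp(-0.5 * (d / sigma_deg) ** 2))
    total = sum(weights)
    return [w / total for w in weights]       # normalise to a probability distribution
```

A target at 179° puts more mass on the bin centred at -175° than on the one at 165°, so predictions just across the ±180° seam are not penalized as distant errors.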
Curriculum learning for matching costs
- Start with class-only matching (v10 p1)
- Gradually increase `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight`
- This is already partially done for loss weights; extend it to matching
Appendix: Code References
| Finding | File:Line | Code |
|---|---|---|
| FOA loading | spatial_dataset.py:310-366 | `_load_audio_file` |
| SpecAugment train-only | spatial_atst.py:336-347 | `if not self.training: return w_logmel` |
| No rotations | spatial_dataset.py:* | `grep "rotation"` → no results |
| Hungarian matching | spatial_loss.py:1452-1509 | `_match_frame_slots_per_step` |
| Matching cost | spatial_loss.py:1491-1504 | `cost[row] = act_cost + cls_nll + dir_cost + dist_cost` |
| Direction loss | spatial_loss.py:1562-1565 | `loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()` |
| v9 config | train_spatial_beats.py:2228-2278 | `make_ov1_local_spatial_v9_ov123_top4_config` |
| v10 freeze | train_spatial_beats.py:2394-2478 | `make_ov1_local_spatial_v10_phase1_cls_config` |
| Spatial freeze | train_spatial_beats.py:3442-3454 | `freeze_frame_track_spatial_heads` |
| Validation metrics | spatial_loss.py:1596-1661 | `compute_frame_slot_validation_metrics` |