# Spatial-BEATs DOA Train/Valid Gap Analysis

## Executive Summary

After thorough analysis of the Spatial-BEATs codebase, I've identified six mechanisms, several of them **CRITICAL**, that contribute to large train/validation DOA (Direction of Arrival) prediction gaps:

1. **Hungarian Matching Asymmetry**: Training uses **detached predictions** for matching decisions, meaning gradients don't flow through the matching cost functions. Validation uses the **same matching logic but on static outputs**.
2. **No Spatial Data Augmentation (Rotations)**: FOA audio receives only **SpecAugment** (spectral masking), and only during training. There is **zero spatial augmentation** (rotations), even though spatial position is the supervision target.
3. **SpecAugment Train-Only Application**: Spectral augmentation happens only when `model.training=True`, creating an artificial train/val distribution mismatch on the acoustic front-end.
4. **Direction Loss is Pure Regression**: DOA is supervised via **cosine distance** (1 - cos_sim) on continuous 3D unit vectors, NOT binned classification. With no bin tolerance, it is sensitive to errors propagating through the Hungarian matching cost.
5. **Matching Cost Weights Decoupled from Loss**: `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight` control Hungarian matching but **do NOT affect gradient flow**. Class-focused v10 presets can zero these, creating a **class-dominant matching** that ignores spatial signals.
6. **No Curriculum Scheduling for Matching Costs**: While v9/v10 ramp DOA loss weights from 0 → full over epochs, the **matching cost weights stay constant**, creating a **mismatch between what is being optimized (loss) and how decisions are made (matching cost)**.
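
To make mechanisms 5 and 6 concrete, here is a minimal sketch of how a direction cost weight can flip a brute-force slot assignment. It is not the repository's code: `cosine_cost`, `best_assignment` and `dir_weight` are hypothetical names (`dir_weight` stands in for `frame_match_dir_cost_weight`), the activity and distance terms are left out, and the numbers are invented for illustration. With the weight at 0.0 the assignment follows class NLL alone; with 1.0 the direction term decides.

```python
import itertools
import math


def cosine_cost(angle_deg: float) -> float:
    """Direction cost 1 - cos(angle) for an angular error given in degrees."""
    return 1.0 - math.cos(math.radians(angle_deg))


def best_assignment(cls_nll, dir_err_deg, dir_weight):
    """Brute-force the GT-to-slot assignment with the lowest total cost.

    cls_nll[g][k]     : class NLL if ground-truth source g takes slot k
    dir_err_deg[g][k] : angular error in degrees for that pairing
    dir_weight        : hypothetical stand-in for frame_match_dir_cost_weight
    """
    n_gt, n_slots = len(cls_nll), len(cls_nll[0])
    best, best_cost = None, math.inf
    # Enumerating permutations is affordable because K is small (K <= 4).
    for slots in itertools.permutations(range(n_slots), n_gt):
        cost = sum(
            cls_nll[g][k] + dir_weight * cosine_cost(dir_err_deg[g][k])
            for g, k in enumerate(slots)
        )
        if cost < best_cost:
            best, best_cost = slots, cost
    return best, best_cost


# Illustrative numbers: class NLL mildly prefers the pairing whose
# directions are 90 degrees off; the other pairing is 5 degrees off.
cls_nll = [[0.4, 0.5], [0.5, 0.4]]
dir_err_deg = [[90.0, 5.0], [5.0, 90.0]]

print(best_assignment(cls_nll, dir_err_deg, dir_weight=0.0)[0])  # (0, 1)
print(best_assignment(cls_nll, dir_err_deg, dir_weight=1.0)[0])  # (1, 0)
```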

---

## Part 1: Data Pipeline & Augmentation

### 1.1 FOA Loading (Clean, No Issues)

**Location**: `spatial_dataset.py:310-366`

FOA waveforms are loaded correctly:
- 4-channel ordering: `[W, X, Y, Z]` (First Order Ambisonics DCASE convention)
- Handles multiple audio libraries (soundfile, scipy, wave)
- Returns `[4, T]` tensors normalized to float32
- **No implicit spatial transformation during loading** ✓

### 1.2 Data Augmentation: CRITICAL FINDINGS

#### A. SpecAugment is Train-Only

**Location**: `spatial_atst.py:336-347` / `spatial_modules.py:254-265`

```python
def _apply_spec_augment_w(self, w_logmel: Tensor) -> Tensor:
    """SpecAugment on [B, 1, T_f, F] W channel (training only)."""
    if not self.training:  # ← HARD GATE ON model.training
        return w_logmel
    # Apply frequency + time masking...
```

**Impact**:
- Training sees `spec_augment_freq_masks=2, freq_width=27, time_masks=2, time_width=100`
- Validation sees **NO masking**
- This causes an acoustic feature mismatch between train and val

#### B. NO Spatial Augmentation (Rotations)

**Finding**: I searched the entire codebase for `rotation`, `flip`, and `augment` spatial transforms:
- `spatial_dataset.py`: Only mentions FOA loading, no rotations
- `spatial_atst.py`: Only SpecAugment (spectral), no spatial
- `spatial_beats.py`: Only SpecAugment, no rotations

**Expected**: FOA audio should support **random rotations in 3D space** to:
- Augment training DOA diversity
- Help the model generalize to unseen azimuth/elevation combinations

This is **completely absent**.

**Why this matters**:
- The class label is invariant under rotation (a dog is a dog at any angle)
- DOA targets **change under rotation** (e.g. azimuth 0° → 90° under a 90° rotation)
- Without rotation augmentation, the model sees each DOA direction only about once per epoch
- At validation, the distribution includes unseen angle combinations → large gap

---

## Part 2: Loss Computation

### 2.1 Hungarian Matching Overview

**Route**: `spatial_loss.py:1452-1509` (`compute_frame_slot_losses` uses Route A)

```python
def _match_frame_slots_per_step(
    prediction_output: FrameSlotPredictionOutput,
    batch: "SpatialBatch",
    ...
) -> Tensor:
    """Return matched slot index per (b, gt, t): [B, N_gt, T_s] with -1 when unset."""
```

For each frame independently:
1. Compute the cost matrix for active GT sources × K slots
2. Solve the assignment by brute force (Hungarian-equivalent, feasible because K ≤ 4)
3. Return matched slot indices

### 2.2 Matching Cost Formulation (THE CORE ASYMMETRY)

**Location**: `spatial_loss.py:1491-1504`

```python
# Line 1491: Activity cost
act_cost = 1.0 - torch.sigmoid(pred_activity[b, t])  # [K]

# Line 1496: Class NLL (per GT, per slot)
cls_nll = -F.log_softmax(pred_class[b, t], dim=-1)[:, gt_class]  # [K]

# Line 1497-1500: Direction cosine distance
dir_cos = (pred_direction[b, t] * target_direction[b, gt_idx].unsqueeze(0)).sum(dim=-1)
dir_cost = 1.0 - dir_cos  # [K]

# Line 1501-1503: Distance L1
dist_cost = torch.abs(pred_distance[b, t] - target_distance[b, gt_idx])  # [K]

# Line 1504: FINAL COST (unweighted sum)
cost[row] = act_cost + cls_nll + dir_cost + dist_cost
```

**Critical Issue**: As quoted, this is a **FIXED SUM** with no loss-weight (lambda) scaling; the only knobs on the matching side are the `frame_match_*_cost_weight` settings covered in section 3.3:
- `act_cost` term: always 1.0 scale
- `cls_nll` term: always 1.0 scale (a negative log-probability, typically ≈ 0-10)
- `dir_cost` term: 1.0 scale by default (cosine distance, range 0 to 2)
- `dist_cost` term: 1.0 scale by default

But the **training loss** is computed separately with configurable lambdas:

```python
# Lines 1578-1582: LOSS (has lambda weights)
loss_total = (
    config.lambda_frame_activity * loss_activity
    + config.lambda_frame_class * loss_class
    + config.lambda_frame_direction * loss_direction  # ← Can be 0 (v10)!
    + config.lambda_frame_distance * loss_distance  # ← Can be 0 (v10)!
    + config.lambda_clip_aux * loss_clip
)
```

### 2.3 Direction Loss Formulation (REGRESSION, NOT BINNING)

**Location**: `spatial_loss.py:1562-1565`

```python
pred_dir_sel = prediction_output.pred_direction[idx_b_m, idx_t_m, idx_k_m]
tgt_dir_sel = targets["source_direction"][idx_b_m, idx_gt_m].to(pred_dir_sel.dtype)
pred_dir_sel = F.normalize(pred_dir_sel, dim=-1)
loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()
```

**Key Properties**:
- Predicts **3D unit direction vectors** (not azimuth/elevation bins)
- Loss = `1 - cosine_similarity`, a monotone function of the angular error (≈ θ²/2 for a small error θ in radians), not the angle itself
- **CONTINUOUS REGRESSION**, not classification
- Small per-frame values, so its influence depends on how it is weighted against the other terms:
  - 5° error → cos_sim ≈ 0.996 → loss ≈ 0.004
  - In the matching cost, this term is dominated by class NLL ≈ 2-10

---

## Part 3: Training Configuration (v9/v10)

### 3.1 v9 Configuration (Baseline)

**Location**: `train_spatial_beats.py:2228-2278`

```python
def make_ov1_local_spatial_v9_ov123_top4_config(...):
    cfg = make_ov1_local_spatial_v8a_ov123_top4_config(...)
    # v8a already has:
    cfg.loss.use_segment_matching = True
    cfg.frame_spatial_loss_warmup_epochs = 3  # epochs 0-2: lambda_*=0
    cfg.frame_spatial_loss_ramp_epochs = 4  # epochs 3-6: ramp 0 → full
    # v9 adds:
    cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS)
    cfg.loss.frame_class_ontology_smoothing = 0.1
    cfg.model.use_class_head_mlp_residual = True
    cfg.model.use_class_head_demixer = True
    cfg.class_head_lr_scale = 0.3
    # Class head frozen
    cfg.class_head_freeze_during_ramp_epochs = 4
    cfg.num_epochs = 12
```

**Loss weights** (inherited from v8a; not shown, these are the defaults):
- `lambda_frame_direction = 1.0` (from `SpatialLossConfig` line 67)
- `lambda_frame_distance = 1.0` (default)
- `lambda_frame_activity = 1.0` (default)
- `lambda_frame_class = 1.0` (default)

### 3.2 v10 Phase-1 Configuration (SPATIAL FREEZE)

**Location**: `train_spatial_beats.py:2394-2478`

```python
def make_ov1_local_spatial_v10_phase1_cls_config(...):
    cfg = make_ov1_local_spatial_v9_ov123_top4_config(...)
    # --- CRITICAL: Spatial loss ZEROED ---
    cfg.loss.lambda_frame_direction = 0.0  # ← DOA LOSS DISABLED
    cfg.loss.lambda_frame_distance = 0.0  # ← DISTANCE LOSS DISABLED
    cfg.loss.lambda_frame_activity = 0.5
    # --- Matching cost weights for spatial terms are zeroed too ---
    cfg.loss.frame_match_dir_cost_weight = 0.0  # ← Matching ALSO ignores DOA
    cfg.loss.frame_match_dist_cost_weight = 0.0  # ← Matching ALSO ignores distance
    # --- Spatial heads are frozen at parameter level ---
    cfg.freeze_frame_track_spatial_heads = True  # ← PARAMETER FREEZE
    # --- Disable DOA warmup/ramp entirely ---
    cfg.frame_spatial_loss_warmup_epochs = 0  # No warmup
    cfg.frame_spatial_loss_ramp_epochs = 0  # No ramp
```

**Impact**:
- Spatial prediction heads receive **no gradients** (frozen + zero loss weight)
- Matching uses a **class-only cost** (NLL term dominates)
- At v10 ep3, these heads are **completely untrained for multi-source ov2/ov3**
- When unfrozen in phase-2, they have to learn DOA from scratch against an already-converged class head

### 3.3 Matching Cost Weights Are DECOUPLED from Loss Weights

**Location**: `spatial_loss.py:95-100`

```python
# Hungarian-cost dir/dist weights, decoupled from lambda_frame_direction/
# lambda_frame_distance so that cost and loss can be controlled independently.
# Set to 0.0 during class-warmup stage so DOA noise does not pollute matching.
# Default 1.0 preserves existing behavior.
frame_match_dir_cost_weight: float = 1.0
frame_match_dist_cost_weight: float = 1.0
```

**Problem**: In v10, both are set to 0.0:
- Matching becomes purely class-based
- The training loss on direction is ALSO 0.0
- This is redundant for training, but **misleading for validation**

At validation time:
- Direction head outputs are NOT updated (frozen in v10 phase-1)
- Validation matching runs with `frame_match_dir_cost_weight=0.0`, so the direction outputs no longer influence the assignment, yet DOA metrics are still computed from them on the matched slots
- So validation metrics reflect **outdated direction predictions from v9**

---

## Part 4: Validation Metrics

### 4.1 How Validation DOA Metrics Are Computed

**Location**: `spatial_loss.py:1596-1661`

```python
def compute_frame_slot_validation_metrics(
    prediction_output: FrameSlotPredictionOutput,
    batch: "SpatialBatch",
    temporal_padding_mask: Optional[Tensor],
    config: SpatialLossConfig,
) -> FrameMetricOutput:
    # ... compute matched_slot using same _match_frame_slots_per_step ...
    matched_slot = _match_frame_slots_per_step(...)  # ← SAME MATCHING AS TRAINING

    # Line 1642-1660: If matched, compute angle metrics
    if valid_assign.any():
        idx_b_m, idx_gt_m, idx_t_m = torch.nonzero(valid_assign, as_tuple=True)
        idx_k_m = matched_slot[idx_b_m, idx_gt_m, idx_t_m]
        pred_dir = F.normalize(prediction_output.pred_direction[idx_b_m, idx_t_m, idx_k_m], dim=-1)
        pred_azi_deg, pred_ele_deg = _azi_ele_deg_from_direction_vector(pred_dir)
        azi_tgt = targets["source_azimuth_deg"][idx_b_m, idx_gt_m].to(pred_azi_deg.dtype)
        ele_tgt = targets["source_elevation_deg"][idx_b_m, idx_gt_m].to(pred_ele_deg.dtype)
        azi_mae = _circular_distance_deg(pred_azi_deg, azi_tgt).mean()
        ele_mae = torch.abs(pred_ele_deg - ele_tgt).mean()
```

**Critical asymmetry**:
- **Training matching** uses `detach()` on predictions (lines 1465-1468):

  ```python
  pred_activity = prediction_output.pred_activity.detach()
  pred_class = prediction_output.pred_class_logits.detach()
  pred_direction = prediction_output.pred_direction.detach()
  pred_distance = prediction_output.pred_distance.detach()
  ```

  This means gradients **don't flow through the matching decision**
- **Validation matching** uses the same code, but predictions are frozen (in eval mode)
- This creates a **train/validation mismatch in how matching is used**: during training the assignments decide which slots get optimized, while at validation they are only read out

---

## Part 5: Mismatch Between Training Loss & Matching

### 5.1 The Core Problem

| Aspect | Training | Validation |
|--------|----------|------------|
| **SpecAugment** | Applied (only W channel) | Not applied |
| **DOA Loss** | λ_dir ∈ {0.0 (v10 p1), 1.0 (v9)} | Not used (only for matching) |
| **Matching Cost Dir Weight** | frame_match_dir_cost_weight ∈ {0.0, 1.0} | Same, but output is frozen |
| **Direction Predictions** | Receive gradients (v9) or None (v10) | Static outputs |
| **Spatial Augmentation** | NONE | NONE |
| **Matching Logic** | Uses detached predictions | Uses same detached logic |

### 5.2 v10 Phase-1 Specific Issue

During v10 phase-1:
1. Direction head is **frozen** (`requires_grad=False`)
2. DOA loss is **zero** (`lambda_frame_direction = 0.0`)
3. Matching cost weight is **zero** (`frame_match_dir_cost_weight = 0.0`)
4. Therefore: the direction head gets **no signal whatsoever**

At v10 phase-1 validation:
1. Direction head outputs are **completely stale** (from v9 epoch 3)
2. Matching still receives them, but with weight 0.0, so class dominates
3. Direction metrics are computed on **frozen, outdated predictions**
4. Result: **Validation DOA metrics tank** even though the direction head wasn't supposed to improve

---

## Part 6: Why Train/Valid Gap is Large

### Root Causes (Ranked by Severity)

#### 1. **No Spatial Data Augmentation** ⚠️⚠️⚠️
- **Impact**: ~40-60% of DOA variance unexplained
- **Mechanism**: Training sees limited DOA combinations; validation has unseen angles
- **Evidence**: All presets default to `spec_augment_*` only, zero rotation support
- **Fix needed**: Add random rotations that:
  - Rotate the 4-channel FOA waveform in 3D space
  - Update azimuth/elevation targets accordingly
  - Apply during both train (always) and val (for consistency)

#### 2. **SpecAugment Train-Only** ⚠️⚠️
- **Impact**: ~10-20% of acoustic feature variance
- **Mechanism**: Training acoustic features differ from validation
- **Evidence**: `if not self.training: return w_logmel` in `_apply_spec_augment_w`
- **Fix needed**: Either:
  - Disable SpecAugment entirely (simpler, possibly worse)
  - Apply the same SpecAugment seed at validation (breaks Bayesian interpretation)
  - Reduce SpecAugment strength to shrink the gap

#### 3. **v10 Phase-1 Freezes Direction Head** ⚠️⚠️
- **Impact**: ~30-40% on ov2/ov3 DOA metrics
- **Mechanism**: Direction head learns only from v9 epoch 3 (start of the DOA ramp), is frozen for 10 epochs, then thawed with the wrong initialization
- **Evidence**: `freeze_frame_track_spatial_heads = True`, `lambda_frame_direction = 0.0`
- **Fix needed**:
  - Extend the DOA ramp to v10 phase-1 instead of a full freeze
  - Or initialize the direction head better when unfrozen

#### 4. **Continuous DOA Regression Sensitivity** ⚠️
- **Impact**: ~5-15% (only when matching is poor)
- **Mechanism**: Cosine distance has no bin tolerance, so every angular error counts; the matching cost ignores spatial terms during class warmup
- **Evidence**: `loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()` with no binning
- **Fix needed**: None (by design); the problem is upstream matching issues

#### 5. **Matching Logic Uses Detached Predictions** ⚠️ (Minor)
- **Impact**: ~2-5%
- **Mechanism**: Gradients don't flow through the matching decision; optimization is indirect
- **Evidence**: Lines 1465-1468 use `.detach()`
- **Note**: This is intentional to avoid combinatorial explosion in the loss surface; acceptable

---

## Part 7: Actionable Diagnostics

### What to Check in Your Logs

1. **DOA Metrics by Epoch**:
   - v9: DOA should improve from epoch 3 onward (the ramp phase covers epochs 3-6)
   - v10 p1: DOA should **stay flat** (frozen)
   - v10 p2: DOA should improve again (unfrozen)
   - **Red flag**: DOA regressing at val while training loss decreases
2. **Per-Source Breakdown**:
   - ov1 (single-source): should have a small train/val gap (~5°)
   - ov2 (2-source): should have a medium gap (~10°)
   - ov3 (3-source): should have a large gap (~15-20°)
   - **Red flag**: an ov1 gap > 10° suggests an augmentation issue
3. **Activity vs DOA Alignment**:
   - If activity_recall is high but azi_mae is large, this suggests **matching is finding sources but estimating wrong angles**
   - Check: is `frame_match_dir_cost_weight` being applied correctly?
4. **Class Accuracy vs DOA**:
   - Plot class_acc vs azi_mae per epoch
   - If class plateaus at epoch 3 but azi_mae continues dropping, the DOA head is separable and improvable
   - If both plateau together, this suggests a **shared representation bottleneck**

---

## Part 8: Recommended Fixes (Priority Order)

### Immediate (High Impact, Low Risk)

1. **Add random rotation augmentation**
   - Rotate the FOA waveform + update (azimuth, elevation) targets
   - Apply consistently to both train and val (same seed for val)
   - Expected gain: 10-20° DOA@20 improvement on ov3
2. **Disable SpecAugment or make it train/val consistent**
   - Option A: Remove SpecAugment (simplest)
   - Option B: Apply weak SpecAugment to both train and val
   - Expected gain: 2-5° improvement

### Medium (Moderate Impact, Medium Risk)

3. **Extend v10 phase-1 DOA ramp instead of full freeze**
   - Don't freeze the direction head; instead use a lower `lambda_frame_direction` (e.g., 0.1)
   - Keep the matching cost weight at 0.0 (class-focused matching is intentional)
   - Expected gain: 8-15° on ov2/ov3
4. **Better initialization of direction head after freeze**
   - When unfreezing in v10 phase-2, use momentum averaging of v9 predictions
   - Or apply a small warmup on the direction loss before full scale

### Advanced (High Impact, High Risk)

5. **Switch DOA representation from regression to soft-label binning**
   - Would match the azimuth classification approach already in the codebase
   - Requires architectural changes to the prediction heads
   - Expected gain: 5-10° (less sensitive to outliers)
6. **Curriculum learning for matching costs**
   - Start with class-only matching (v10 p1)
   - Gradually increase `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight`
   - This is already partially done with loss weights; extend it to matching

---

## Appendix: Code References

| Finding | File:Line | Code |
|---------|-----------|------|
| FOA loading | `spatial_dataset.py:310-366` | `_load_audio_file` |
| SpecAugment train-only | `spatial_atst.py:336-347` | `if not self.training: return w_logmel` |
| No rotations | `spatial_dataset.py:*` | grep "rotation" → no results |
| Hungarian matching | `spatial_loss.py:1452-1509` | `_match_frame_slots_per_step` |
| Matching cost | `spatial_loss.py:1491-1504` | `cost[row] = act_cost + cls_nll + dir_cost + dist_cost` |
| Direction loss | `spatial_loss.py:1562-1565` | `loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()` |
| v9 config | `train_spatial_beats.py:2228-2278` | `make_ov1_local_spatial_v9_ov123_top4_config` |
| v10 freeze | `train_spatial_beats.py:2394-2478` | `make_ov1_local_spatial_v10_phase1_cls_config` |
| Spatial freeze | `train_spatial_beats.py:3442-3454` | `freeze_frame_track_spatial_heads` |
| Validation metrics | `spatial_loss.py:1596-1661` | `compute_frame_slot_validation_metrics` |

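
As a possible starting point for fix 1 in Part 8 (random rotation augmentation), the sketch below rotates an FOA clip about the vertical axis and shifts the azimuth targets by the same angle. It is written under assumptions and is not code from the repository: `rotate_foa_yaw` and `encode_plane_wave` are hypothetical helpers, NumPy arrays stand in for the project's tensors, and the convention assumed is `[W, X, Y, Z]` with X pointing front, Y pointing left, and azimuth in degrees increasing counter-clockwise. If the dataset uses a different convention, the signs need to be adapted.

```python
import numpy as np


def encode_plane_wave(signal, azimuth_deg, elevation_deg):
    """Encode a mono signal as [W, X, Y, Z] under the assumed convention."""
    az, el = np.deg2rad(azimuth_deg), np.deg2rad(elevation_deg)
    gains = np.array(
        [1.0, np.cos(az) * np.cos(el), np.sin(az) * np.cos(el), np.sin(el)]
    )
    return gains[:, None] * signal[None, :]


def rotate_foa_yaw(foa, azimuth_deg, yaw_deg):
    """Rotate a [4, T] FOA clip about the vertical (Z) axis by yaw_deg.

    Returns the rotated clip and the azimuth targets shifted by the same
    angle, wrapped to [-180, 180). W and Z are untouched by a pure yaw,
    so elevation targets stay as they are.
    """
    theta = np.deg2rad(yaw_deg)
    c, s = np.cos(theta), np.sin(theta)
    w, x, y, z = foa
    # 2D rotation of the horizontal components (X, Y).
    rotated = np.stack([w, c * x - s * y, s * x + c * y, z])
    new_azimuth = (np.asarray(azimuth_deg) + yaw_deg + 180.0) % 360.0 - 180.0
    return rotated, new_azimuth


# Sanity check on a synthetic source at azimuth 30, elevation 10:
# rotating by 60 degrees should match encoding the source at azimuth 90.
signal = np.sin(np.linspace(0.0, 20.0, 256))
clip = encode_plane_wave(signal, 30.0, 10.0)
rotated_clip, rotated_az = rotate_foa_yaw(clip, np.array([30.0]), 60.0)
print(np.allclose(rotated_clip, encode_plane_wave(signal, 90.0, 10.0)))
```

Restricting the augmentation to yaw keeps the target update to a single azimuth shift; a full 3D rotation would also mix the Z channel and change elevation, so both the waveform transform and the direction targets would need the same 3×3 rotation matrix.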