# Spatial-BEATs DOA Train/Valid Gap Analysis

## Executive Summary

After a thorough review of the Spatial-BEATs codebase, I've identified six mechanisms that likely drive the large train/validation gap in DOA (Direction of Arrival) prediction:

1. **Hungarian Matching on Detached Predictions**: Matching decisions are made on **detached predictions**, so no gradient flows through the matching cost. Validation calls the **same matching function**, so the matching logic itself is identical. What differs is that in training the matched pairs then receive gradients, and train-mode stochasticity (e.g., SpecAugment) changes the predictions being matched.

2. **No Spatial Data Augmentation (Rotations)**: FOA audio receives only **SpecAugment** (spectral masking), and only during training. There is **zero spatial augmentation** (rotations), even though direction is one of the supervision targets.

3. **SpecAugment Train-Only Application**: Spectral masking runs only when `model.training=True`, so the acoustic front-end sees differently distributed inputs in training and validation.

4. **Direction Loss is Pure Regression**: DOA is supervised via **cosine distance** (`1 - cos_sim`) on continuous 3D unit vectors, NOT binned classification. Because `1 - cos θ ≈ θ²/2` for small angles, small errors yield very small losses and gradients. The direction term is therefore easily outweighed by the class term in the Hungarian matching cost.

5. **Matching Cost Weights Decoupled from Loss**: `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight` do not scale any loss term; they only change which slot each ground-truth source is assigned to. Class-focused v10 presets set both to zero. This produces **class-dominated matching** (in training and validation alike) that ignores spatial signals.

6. **No Curriculum Scheduling for Matching Costs**: v8a/v9 ramp the DOA loss weights from 0 to full over several epochs, but the **matching cost weights stay constant**. This creates a **mismatch between what is being optimized (loss) and how assignments are made (matching cost)**.

---

## Part 1: Data Pipeline & Augmentation

### 1.1 FOA Loading (No Obvious Issues)

**Location**: `spatial_dataset.py:310-366`

FOA waveforms appear to be loaded correctly:
- 4-channel ordering: `[W, X, Y, Z]`. DCASE SELD datasets distribute FOA in ACN channel order, `[W, Y, Z, X]`, with SN3D normalization. Confirm whether the loader reorders channels: a silent channel permutation would corrupt DOA targets while leaving class labels unaffected
- Handles multiple audio libraries (soundfile, scipy, wave)
- Returns `[4, T]` float32 tensors
- **No implicit spatial transformation during loading** ✓

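DCASE SELD releases store FOA in ACN order `[W, Y, Z, X]`. If the loader's input files use that order, a fixed permutation restores the `[W, X, Y, Z]` order assumed above. The sketch below assumes a NumPy `[4, T]` array. `acn_to_wxyz` is a hypothetical helper, not a function in `spatial_dataset.py`.

```python
import numpy as np

def acn_to_wxyz(foa_acn: np.ndarray) -> np.ndarray:
    """Reorder a [4, T] FOA array from ACN order [W, Y, Z, X] to [W, X, Y, Z].

    Hypothetical helper: only the channel order changes, channel gains are untouched.
    """
    if foa_acn.shape[0] != 4:
        raise ValueError(f"expected 4 channels, got {foa_acn.shape[0]}")
    # ACN index: 0=W, 1=Y, 2=Z, 3=X  ->  select W, X, Y, Z
    return foa_acn[[0, 3, 1, 2]]

acn = np.array([[0.0], [1.0], [2.0], [3.0]])   # rows: W, Y, Z, X
print(acn_to_wxyz(acn).ravel())               # → [0. 3. 1. 2.]
```
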
### 1.2 Data Augmentation: CRITICAL FINDINGS

#### A. SpecAugment is Train-Only

**Location**: `spatial_atst.py:336-347` / `spatial_modules.py:254-265`

```python
def _apply_spec_augment_w(self, w_logmel: Tensor) -> Tensor:
    """SpecAugment on [B, 1, T_f, F] W channel (training only)."""
    if not self.training:  # ← HARD GATE ON model.training
        return w_logmel
    # Apply frequency + time masking...
```

**Impact**:
- Training sees `spec_augment_freq_masks=2, freq_width=27, time_masks=2, time_width=100` (applied to the W channel only)
- Validation sees **no masking**
- Train and val acoustic features therefore come from different distributions. Masking makes training inputs *harder*, so on its own it tends to inflate training loss rather than validation error. The mismatch matters most when train-mode metrics are compared directly with eval-mode metrics

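Here is a minimal NumPy sketch of train-only frequency/time masking that makes the gate concrete. It uses the mask counts and maximum widths listed above. It is an illustration, not the repository's implementation, and several details are assumptions:

- It works on a single `[T, F]` log-mel array rather than `[B, 1, T_f, F]`.
- It draws mask widths uniformly from `[0, width]`.
- It fills masked bins with `0.0`.

```python
import numpy as np

def spec_augment(logmel: np.ndarray, training: bool, rng: np.random.Generator,
                 freq_masks: int = 2, freq_width: int = 27,
                 time_masks: int = 2, time_width: int = 100,
                 fill: float = 0.0) -> np.ndarray:
    """Frequency/time masking of a [T, F] log-mel array, applied only in training."""
    if not training:
        return logmel                       # eval mode: identity, like the gate above
    out = logmel.copy()                     # never modify the caller's array
    n_t, n_f = out.shape
    for _ in range(freq_masks):
        w = int(rng.integers(0, min(freq_width, n_f) + 1))   # width in [0, freq_width]
        f0 = int(rng.integers(0, n_f - w + 1))                # start so the band fits
        out[:, f0:f0 + w] = fill
    for _ in range(time_masks):
        w = int(rng.integers(0, min(time_width, n_t) + 1))   # width in [0, time_width]
        t0 = int(rng.integers(0, n_t - w + 1))
        out[t0:t0 + w, :] = fill
    return out

rng = np.random.default_rng(0)
feats = np.ones((200, 64))                                  # [T_f, F]
val_feats = spec_augment(feats, training=False, rng=rng)    # returned unchanged
train_feats = spec_augment(feats, training=True, rng=rng)   # up to 2 + 2 bands set to 0.0
```

Because the eval branch is an identity, comparing train-mode metrics with eval-mode metrics mixes the SpecAugment effect into the measured gap. Recomputing train metrics in eval mode on a subset of training clips separates the two.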
#### B. NO Spatial Augmentation (Rotations)

**Finding**: A search of the entire codebase for `rotation`, `flip`, and `augment` found no spatial transforms:
- `spatial_dataset.py`: FOA loading only, no rotations
- `spatial_atst.py`: SpecAugment (spectral) only, no spatial transforms
- `spatial_beats.py`: SpecAugment only, no rotations

**Expected**: FOA audio should support **random rotations in 3D space**, with matching target updates, to:
- Increase the diversity of training DOAs
- Help the model generalize to unseen azimuth/elevation combinations

This is **completely absent**.

**Why this matters**:
- Class labels are invariant under rotation (a dog is a dog at any angle)
- DOA targets **change under rotation** (a 90° yaw moves a source at azimuth 0° to azimuth 90°)
- Without rotation augmentation, each training clip contributes the same fixed set of directions every epoch, so the set of DOAs the model observes never grows
- If validation contains azimuth/elevation combinations that are rare or absent in training, the model must extrapolate, which widens the gap

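A yaw (rotation about the vertical axis) is the simplest rotation to add, because for first-order ambisonics it only mixes the X and Y channels. The sketch below rests on several assumptions:

- Channel order is `[W, X, Y, Z]`.
- Plane-wave gains are X ∝ cos(azi)·cos(ele), Y ∝ sin(azi)·cos(ele), Z ∝ sin(ele).
- Azimuth is measured counter-clockwise from +X, in degrees.

Both function names are hypothetical. A full 3D rotation would also mix Z and change the elevation labels.

```python
import numpy as np

def foa_plane_wave(signal, azi_deg, ele_deg):
    """Encode a mono signal as a first-order plane wave in [W, X, Y, Z] order (assumed gains)."""
    azi, ele = np.deg2rad(azi_deg), np.deg2rad(ele_deg)
    gains = np.array([1.0,
                      np.cos(azi) * np.cos(ele),
                      np.sin(azi) * np.cos(ele),
                      np.sin(ele)])
    return gains[:, None] * signal[None, :]          # [4, T]

def rotate_foa_yaw(foa, azi_deg, yaw_deg):
    """Rotate a [4, T] FOA signal about the vertical axis by yaw_deg.

    X/Y are mixed by a 2-D rotation; W and Z are unchanged. Azimuth labels are
    shifted by the same angle and wrapped to [-180, 180). Elevation labels do
    not change under a yaw rotation.
    """
    phi = np.deg2rad(yaw_deg)
    c, s = np.cos(phi), np.sin(phi)
    w, x, y, z = foa                                 # iterate over the channel axis
    rotated = np.stack([w, c * x - s * y, s * x + c * y, z])
    new_azi = (np.asarray(azi_deg, dtype=float) + yaw_deg + 180.0) % 360.0 - 180.0
    return rotated, new_azi

rng = np.random.default_rng(0)
sig = rng.standard_normal(16000)
foa = foa_plane_wave(sig, azi_deg=30.0, ele_deg=10.0)
rot, azi = rotate_foa_yaw(foa, azi_deg=np.array([30.0]), yaw_deg=60.0)
# rot now matches a source encoded at azimuth 90°, elevation 10°
```

The rotation is linear, so applying it to the mixture rotates every source (and the reverberant field) together. Every azimuth label in the clip therefore shifts by the same yaw; the array argument broadcasts over sources.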
---

## Part 2: Loss Computation

### 2.1 Hungarian Matching Overview

**Route**: `spatial_loss.py:1452-1509` (`compute_frame_slot_losses` uses Route A)

```python
def _match_frame_slots_per_step(
    prediction_output: FrameSlotPredictionOutput,
    batch: "SpatialBatch",
    ...
) -> Tensor:
    """Return matched slot index per (b, gt, t): [B, N_gt, T_s] with -1 when unset."""
```

For each frame independently:
1. Compute a cost matrix of active GT sources × K slots
2. Find the minimum-cost assignment by brute-force enumeration (feasible because K ≤ 4); this yields the same optimum as the Hungarian algorithm
3. Return the matched slot indices

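The brute-force step can be sketched with NumPy and `itertools`. `brute_force_assign` is a hypothetical stand-in; it omits the repository's per-frame bookkeeping and `-1` padding. The example matrix shows why an optimal assignment is needed rather than a greedy per-row pick.

```python
import itertools
import numpy as np

def brute_force_assign(cost: np.ndarray) -> np.ndarray:
    """Minimum-cost assignment of rows (active GT sources) to columns (K slots).

    cost: [N_gt, K] with N_gt <= K. Every injective mapping is enumerated
    (at most 4! = 24 candidates for K = 4), which gives the same optimum as the
    Hungarian algorithm. Returns the chosen slot index for each GT row.
    """
    n_gt, k = cost.shape
    if n_gt > k:
        raise ValueError("more active sources than slots")
    if n_gt == 0:
        return np.zeros(0, dtype=int)
    rows = np.arange(n_gt)
    best_perm, best_total = None, np.inf
    for perm in itertools.permutations(range(k), n_gt):
        total = cost[rows, list(perm)].sum()
        if total < best_total:
            best_perm, best_total = perm, total
    return np.array(best_perm, dtype=int)

cost = np.array([[1.0, 2.0],
                 [1.5, 10.0]])
print(brute_force_assign(cost))  # → [1 0]
```

A greedy per-row pick would give slot 0 to GT 0 and force GT 1 onto slot 1, for a total cost of 11.0. The optimal assignment `[1, 0]` costs only 3.5.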
### 2.2 Matching Cost Formulation (THE CORE ASYMMETRY)

**Location**: `spatial_loss.py:1491-1504`

```python
# Line 1491: Activity cost
act_cost = 1.0 - torch.sigmoid(pred_activity[b, t])  # [K]

# Line 1496: Class NLL (per GT, per slot)
cls_nll = -F.log_softmax(pred_class[b, t], dim=-1)[:, gt_class]  # [K]

# Lines 1497-1500: Direction cosine distance
dir_cos = (pred_direction[b, t] * target_direction[b, gt_idx].unsqueeze(0)).sum(dim=-1)
dir_cost = 1.0 - dir_cos  # [K]

# Lines 1501-1503: Distance L1
dist_cost = torch.abs(pred_distance[b, t] - target_distance[b, gt_idx])  # [K]

# Line 1504: FINAL COST (unweighted sum)
cost[row] = act_cost + cls_nll + dir_cost + dist_cost
```

**Critical Issue**: This is a **fixed, unweighted sum**: no loss lambda scales it, and the four terms live on very different ranges:
- `act_cost`: `1 - sigmoid(·)`, in [0, 1]
- `cls_nll`: a negative log-probability, in [0, ∞) and typically ≈ 0-10
- `dir_cost`: `1 - cos`, in [0, 2] for unit vectors (the excerpt does not normalize `pred_direction` before matching, so unless the head already outputs unit vectors, the range also depends on its norm)
- `dist_cost`: an L1 error in distance units, unbounded

The excerpt also shows no `frame_match_dir_cost_weight` / `frame_match_dist_cost_weight` factors (see 3.3). If those weights are applied, it must happen elsewhere, and that should be verified.

By contrast, the **training loss** is computed separately, with configurable lambdas:

```python
# Lines 1578-1582: LOSS (has lambda weights)
loss_total = (
    config.lambda_frame_activity * loss_activity
    + config.lambda_frame_class * loss_class
    + config.lambda_frame_direction * loss_direction  # ← Can be 0 (v10)!
    + config.lambda_frame_distance * loss_distance    # ← Can be 0 (v10)!
    + config.lambda_clip_aux * loss_clip
)
```

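A NumPy re-implementation of one cost row makes the scale mismatch between the four terms easy to check. This sketch mirrors the excerpt for a single GT source. `cost_terms` is hypothetical, and `C = 13` classes is only an example value.

```python
import numpy as np

def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable log-softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def cost_terms(act_logit, class_logits, pred_dir, pred_dist, gt_class, gt_dir, gt_dist):
    """Per-slot matching-cost terms for one GT source, mirroring the excerpt.

    act_logit: [K], class_logits: [K, C], pred_dir: [K, 3], pred_dist: [K].
    Returns a dict of [K] arrays; their plain sum is the excerpt's cost row.
    """
    act = 1.0 - 1.0 / (1.0 + np.exp(-act_logit))            # in [0, 1]
    cls = -log_softmax(class_logits, axis=-1)[:, gt_class]    # in [0, inf)
    dir_ = 1.0 - pred_dir @ gt_dir                            # in [0, 2] for unit vectors
    dist = np.abs(pred_dist - gt_dist)                        # unbounded, distance units
    return {"act": act, "cls": cls, "dir": dir_, "dist": dist}

K, C = 2, 13
terms = cost_terms(
    act_logit=np.zeros(K),
    class_logits=np.zeros((K, C)),               # uninformative class head
    pred_dir=np.array([[1.0, 0.0, 0.0],          # slot 0: exact direction
                       [0.0, 1.0, 0.0]]),        # slot 1: 90 degrees off
    pred_dist=np.array([1.0, 2.0]),
    gt_class=3, gt_dir=np.array([1.0, 0.0, 0.0]), gt_dist=1.5,
)
```

With an uninformative class head, the class term is `ln C ≈ 2.56` for every slot. That already exceeds the whole direction difference that a 90° error produces (1.0).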
### 2.3 Direction Loss Formulation (REGRESSION, NOT BINNING)

**Location**: `spatial_loss.py:1562-1565`

```python
pred_dir_sel = prediction_output.pred_direction[idx_b_m, idx_t_m, idx_k_m]
tgt_dir_sel = targets["source_direction"][idx_b_m, idx_gt_m].to(pred_dir_sel.dtype)
pred_dir_sel = F.normalize(pred_dir_sel, dim=-1)
loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()
```

**Key Properties**:
- Predicts **3D unit direction vectors** (not azimuth/elevation bins)
- Loss = `1 - cos θ`, where θ is the angle between prediction and target. It is monotonic in θ but not equal to it (≈ θ²/2 for small θ)
- **CONTINUOUS REGRESSION**, not classification
- Weak signal for small errors:
  - A 5° error gives cos θ ≈ 0.996, so the loss is only ≈ 0.004. Its gradient with respect to θ (`sin θ`) shrinks toward zero as θ → 0
  - In the matching cost, this is dwarfed by the class NLL term (≈ 2-10)

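A few lines of NumPy make the small-angle behaviour concrete. The helpers below are hypothetical. They evaluate the same `1 - cos θ` expression as the excerpt, for azimuth-only errors in the horizontal plane, and compare it with the θ²/2 approximation.

```python
import numpy as np

def cosine_direction_loss(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Per-row 1 - cos(angle); pred is L2-normalised first, target assumed unit-norm."""
    pred = pred / np.linalg.norm(pred, axis=-1, keepdims=True)
    return 1.0 - np.sum(pred * target, axis=-1)

def unit_from_azimuth(azi_deg: np.ndarray) -> np.ndarray:
    """Horizontal-plane unit vectors (x, y, 0) for the given azimuths in degrees."""
    a = np.deg2rad(azi_deg)
    return np.stack([np.cos(a), np.sin(a), np.zeros_like(a)], axis=-1)

errors_deg = np.array([1.0, 5.0, 20.0, 90.0, 180.0])
loss = cosine_direction_loss(unit_from_azimuth(errors_deg),
                             unit_from_azimuth(np.zeros_like(errors_deg)))
small_angle = np.deg2rad(errors_deg) ** 2 / 2      # θ²/2 approximation
```

The loss is about 0.0038 at 5° and about 0.060 at 20°, against 1.0 at 90° and 2.0 at 180°. Corrections close to the target are therefore driven by a vanishing gradient.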
---

## Part 3: Training Configuration (v9/v10)

### 3.1 v9 Configuration (Baseline)

**Location**: `train_spatial_beats.py:2228-2278`

```python
def make_ov1_local_spatial_v9_ov123_top4_config(...):
    cfg = make_ov1_local_spatial_v8a_ov123_top4_config(...)

    # v8a already has:
    cfg.loss.use_segment_matching = True
    cfg.frame_spatial_loss_warmup_epochs = 3  # epochs 0-2: lambda_*=0
    cfg.frame_spatial_loss_ramp_epochs = 4    # epochs 3-6: ramp 0 → full

    # v9 adds:
    cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS)
    cfg.loss.frame_class_ontology_smoothing = 0.1
    cfg.model.use_class_head_mlp_residual = True
    cfg.model.use_class_head_demixer = True
    cfg.class_head_lr_scale = 0.3                 # Class head at 0.3x base LR (reduced, not frozen)
    cfg.class_head_freeze_during_ramp_epochs = 4  # Class head frozen during the ramp epochs
    cfg.num_epochs = 12
```

**Loss weights** (inherited from v8a; not set explicitly, so the defaults apply):
- `lambda_frame_direction = 1.0` (from `SpatialLossConfig` line 67)
- `lambda_frame_distance = 1.0` (default)
- `lambda_frame_activity = 1.0` (default)
- `lambda_frame_class = 1.0` (default)

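The warmup/ramp comments above can be read as a multiplier on the spatial loss weights. The function below is a hypothetical linear version of that schedule. The exact interpolation in `train_spatial_beats.py` is not shown here, so treat the per-epoch values as an assumption.

```python
def spatial_loss_scale(epoch: int, warmup_epochs: int = 3, ramp_epochs: int = 4) -> float:
    """Hypothetical multiplier for lambda_frame_direction / lambda_frame_distance.

    0.0 during the warmup epochs, then a linear ramp that reaches 1.0 on the
    last ramp epoch, then 1.0. With no ramp, full weight applies right after warmup.
    """
    if epoch < warmup_epochs:
        return 0.0
    if ramp_epochs <= 0:
        return 1.0
    return min(1.0, (epoch - warmup_epochs + 1) / ramp_epochs)

print([spatial_loss_scale(e) for e in range(8)])  # → [0.0, 0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```

Recommendation 6 in Part 8 amounts to applying a schedule of this kind to `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight` as well.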
### 3.2 v10 Phase-1 Configuration (SPATIAL FREEZE)

**Location**: `train_spatial_beats.py:2394-2478`

```python
def make_ov1_local_spatial_v10_phase1_cls_config(...):
    cfg = make_ov1_local_spatial_v9_ov123_top4_config(...)

    # --- CRITICAL: Spatial loss ZEROED ---
    cfg.loss.lambda_frame_direction = 0.0  # ← DOA LOSS DISABLED
    cfg.loss.lambda_frame_distance = 0.0   # ← DISTANCE LOSS DISABLED
    cfg.loss.lambda_frame_activity = 0.5

    # --- Matching cost also drops the spatial terms ---
    cfg.loss.frame_match_dir_cost_weight = 0.0   # ← Matching ALSO ignores DOA
    cfg.loss.frame_match_dist_cost_weight = 0.0  # ← Matching ALSO ignores distance

    # --- Spatial heads are frozen at parameter level ---
    cfg.freeze_frame_track_spatial_heads = True  # ← PARAMETER FREEZE

    # --- Disable DOA warmup/ramp entirely ---
    cfg.frame_spatial_loss_warmup_epochs = 0  # No warmup
    cfg.frame_spatial_loss_ramp_epochs = 0    # No ramp
```

**Impact**:
- Spatial prediction heads receive **no gradients** (frozen, and zero loss weight)
- Matching uses an **activity + class cost** only (no direction or distance terms)
- At the start of v10 phase-1 (initialized, per this analysis, from v9 epoch 3, before the DOA ramp had much effect), these heads are essentially **untrained for multi-source ov2/ov3 scenes**
- When they are unfrozen in phase-2, they must learn DOA almost from scratch on top of an already-converged class representation

### 3.3 Matching Cost Weights Are DECOUPLED from Loss Weights

**Location**: `spatial_loss.py:95-100`

```python
# Hungarian-cost dir/dist weights — decoupled from lambda_frame_direction/
# lambda_frame_distance so that cost and loss can be controlled independently.
# Set to 0.0 during class-warmup stage so DOA noise does not pollute matching.
# Default 1.0 preserves existing behavior.
frame_match_dir_cost_weight: float = 1.0
frame_match_dist_cost_weight: float = 1.0
```

**Problem**: In v10, both are set to 0.0:
- Matching becomes purely activity/class-based
- The direction loss weight is also 0.0 and the head is frozen, so for the direction head itself the zero matching weight is redundant. It still changes which slot each GT is assigned to, and therefore the targets of the class and activity losses
- The effect on validation is easy to misread, because the same setting is used there

At validation time:
- Direction head outputs are NOT updated (the head is frozen in v10 phase-1)
- Validation matching uses the same `frame_match_dir_cost_weight=0.0`, so it ignores those outputs when assigning slots
- DOA metrics are then computed on whichever slot the class/activity matching picked, using **outdated direction predictions from v9**

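A toy frame with one GT source and two slots shows the effect of zeroing the matching weight. The sketch assumes the two weights multiply only the direction and distance terms. The excerpt in 2.2 shows no such factors, so where they are applied should be verified. `match` and the cost values are hypothetical.

```python
import itertools
import numpy as np

def match(act_cost, cls_nll, dir_cost, dist_cost, w_dir=1.0, w_dist=1.0):
    """Brute-force min-cost assignment of GT rows to slot columns.

    All inputs are [N_gt, K]. w_dir / w_dist stand in for
    frame_match_dir_cost_weight / frame_match_dist_cost_weight.
    """
    cost = act_cost + cls_nll + w_dir * dir_cost + w_dist * dist_cost
    n_gt, k = cost.shape
    rows = np.arange(n_gt)
    best = min(itertools.permutations(range(k), n_gt),
               key=lambda p: cost[rows, list(p)].sum())
    return list(best)

act = np.array([[0.1, 0.1]])
cls = np.array([[0.5, 0.9]])      # slot 0 is more confident about the class
dir_c = np.array([[1.5, 0.02]])   # ...but points about 120 degrees away (1 - cos 120° = 1.5)
dist = np.zeros((1, 2))
print(match(act, cls, dir_c, dist, w_dir=1.0))  # → [1]
print(match(act, cls, dir_c, dist, w_dir=0.0))  # → [0]
```

With `w_dir=0.0`, the GT is matched to slot 0. The validation DOA metric then measures that slot's direction, about 120° off, even though slot 1 was within about 12°.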
---

## Part 4: Validation Metrics

### 4.1 How Validation DOA Metrics Are Computed

**Location**: `spatial_loss.py:1596-1661`

```python
def compute_frame_slot_validation_metrics(
    prediction_output: FrameSlotPredictionOutput,
    batch: "SpatialBatch",
    temporal_padding_mask: Optional[Tensor],
    config: SpatialLossConfig,
) -> FrameMetricOutput:
    # ... compute matched_slot using same _match_frame_slots_per_step ...
    matched_slot = _match_frame_slots_per_step(...)  # ← SAME MATCHING AS TRAINING

    # Lines 1642-1660: If matched, compute angle metrics
    if valid_assign.any():
        idx_b_m, idx_gt_m, idx_t_m = torch.nonzero(valid_assign, as_tuple=True)
        idx_k_m = matched_slot[idx_b_m, idx_gt_m, idx_t_m]
        pred_dir = F.normalize(prediction_output.pred_direction[idx_b_m, idx_t_m, idx_k_m], dim=-1)
        pred_azi_deg, pred_ele_deg = _azi_ele_deg_from_direction_vector(pred_dir)
        azi_tgt = targets["source_azimuth_deg"][idx_b_m, idx_gt_m].to(pred_azi_deg.dtype)
        ele_tgt = targets["source_elevation_deg"][idx_b_m, idx_gt_m].to(pred_ele_deg.dtype)
        azi_mae = _circular_distance_deg(pred_azi_deg, azi_tgt).mean()
        ele_mae = torch.abs(pred_ele_deg - ele_tgt).mean()
```

**Matching in training vs. validation**:
- **Training matching** runs on detached predictions (lines 1465-1468):
  ```python
  pred_activity = prediction_output.pred_activity.detach()
  pred_class = prediction_output.pred_class_logits.detach()
  pred_direction = prediction_output.pred_direction.detach()
  pred_distance = prediction_output.pred_distance.detach()
  ```
  Gradients therefore **do not flow through the matching decision**; only the losses on the matched pairs are back-propagated.

- **Validation matching** calls the same function, in eval mode, where no backward pass happens.
- The matching logic is therefore **identical**; only its input differs. In training, the predictions come from train-mode forward passes (e.g., with SpecAugment), and the resulting assignments decide which slots get optimized. Any train/val difference in matching comes from those inputs, not from the matching code.

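The two private helpers called above are not shown. Here is a plausible NumPy version, with conventions that have not been confirmed against the repository: azimuth = atan2(y, x), counter-clockwise from +X; elevation = asin(z). Both function names are hypothetical equivalents.

```python
import numpy as np

def azi_ele_deg_from_direction(v: np.ndarray):
    """Azimuth in (-180, 180] and elevation in [-90, 90], in degrees, from [..., 3] vectors."""
    v = v / np.linalg.norm(v, axis=-1, keepdims=True)
    azi = np.degrees(np.arctan2(v[..., 1], v[..., 0]))
    ele = np.degrees(np.arcsin(np.clip(v[..., 2], -1.0, 1.0)))
    return azi, ele

def circular_distance_deg(a, b) -> np.ndarray:
    """Smallest absolute azimuth difference in degrees, in [0, 180]."""
    d = np.abs(np.asarray(a, dtype=float) - np.asarray(b, dtype=float)) % 360.0
    return np.minimum(d, 360.0 - d)

print(circular_distance_deg(179.0, -179.0))  # → 2.0
```

Azimuth error is ill-conditioned near the poles. At high elevation, a small great-circle error can correspond to a large azimuth difference, so `azi_mae` can look poor for high-elevation sources even when the angular error is small.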
---

## Part 5: Mismatch Between Training Loss & Matching

### 5.1 The Core Problem

| Aspect | Training | Validation |
|--------|----------|------------|
| **SpecAugment** | Applied (W channel only) | Not applied |
| **DOA Loss** | λ_dir ∈ {0.0 (v10 p1), 1.0 (v9)} | Not optimized (metrics only) |
| **Matching Cost Dir Weight** | `frame_match_dir_cost_weight` ∈ {0.0, 1.0} | Same value as training |
| **Direction Predictions** | Receive gradients (v9) or none (v10 p1, frozen) | Static outputs |
| **Spatial Augmentation** | None | None |
| **Matching Logic** | Same function, on detached predictions | Same function, on detached predictions |

### 5.2 v10 Phase-1 Specific Issue

During v10 phase-1:
1. The direction head is **frozen** (`requires_grad=False`)
2. The DOA loss is **zero** (`lambda_frame_direction = 0.0`)
3. The matching cost weight is **zero** (`frame_match_dir_cost_weight = 0.0`)
4. Therefore, the direction head gets **no training signal whatsoever**

At v10 phase-1 validation:
1. Direction head weights are **stale** (from the v9 checkpoint that phase-1 started from; epoch 3 per this analysis)
2. Matching ignores them (weight 0.0), so activity and class decide the assignment
3. Direction metrics are computed with the **frozen, outdated head** on whichever slot was matched
4. Result: **validation DOA metrics can degrade**, even though the head itself never changes. Its weights are fixed, but if the backbone is still trainable in phase-1, its input features keep drifting under the class loss. In addition, class-only matching can select slots whose direction output was never accurate

---

## Part 6: Why the Train/Valid Gap Is Large

### Root Causes (Ranked by Severity)

The impact figures below are rough estimates, not measured ablations.

#### 1. **No Spatial Data Augmentation** ⚠️⚠️⚠️
- **Impact**: ~40-60% of DOA variance unexplained
- **Mechanism**: Training sees a limited set of DOA combinations; validation contains unseen angles
- **Evidence**: All presets configure only `spec_augment_*`; there is no rotation support
- **Fix needed**: Add random rotations that:
  - Rotate the 4-channel FOA waveform in 3D space
  - Update the azimuth/elevation targets accordingly
  - Are applied during training only. Keep validation on the original, unrotated recordings so its metrics stay comparable across runs

#### 2. **SpecAugment Train-Only** ⚠️⚠️
- **Impact**: ~10-20% of acoustic feature variance
- **Mechanism**: Training acoustic features are distributed differently from validation features
- **Evidence**: `if not self.training: return w_logmel` in `_apply_spec_augment_w`
- **Fix needed**: One of:
  - Disable SpecAugment entirely (simpler, possibly worse generalization)
  - Apply SpecAugment with a fixed seed at validation too (validation then no longer measures performance on clean inputs)
  - Reduce SpecAugment strength to narrow the gap

#### 3. **v10 Phase-1 Freezes Direction Head** ⚠️⚠️
- **Impact**: ~30-40% on ov2/ov3 DOA metrics
- **Mechanism**: The direction head is trained only up to v9 epoch 3 (before the DOA ramp), frozen for 10 epochs, and then unfrozen from a poorly suited initialization
- **Evidence**: `freeze_frame_track_spatial_heads = True`, `lambda_frame_direction = 0.0`
- **Fix needed**:
  - Extend the DOA ramp into v10 phase-1 instead of a full freeze
  - Or initialize the direction head better when it is unfrozen

#### 4. **Continuous DOA Regression Sensitivity** ⚠️
- **Impact**: ~5-15% (only when matching is poor)
- **Mechanism**: The cosine-distance loss gives weak signal for small errors (`1 - cos θ ≈ θ²/2`), and the matching cost ignores spatial terms during class warmup
- **Evidence**: `loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()` with no binning
- **Fix needed**: None (by design); the problem lies upstream, in matching

#### 5. **Matching Logic Uses Detached Predictions** ⚠️ (Minor)
- **Impact**: ~2-5%
- **Mechanism**: Gradients don't flow through the matching decision, so its influence on optimization is indirect
- **Evidence**: Lines 1465-1468 use `.detach()`
- **Note**: This is intentional and acceptable. The assignment is a discrete argmin with no useful gradient, and computing it on detached predictions is standard practice for set-prediction losses (as in DETR)

---

## Part 7: Actionable Diagnostics

### What to Check in Your Logs

1. **DOA Metrics by Epoch**:
   - v9: DOA should improve during epochs 3-7 (ramp phase)
   - v10 p1: DOA should **stay flat** (head frozen). If it drifts, the shared backbone is likely changing under the class loss
   - v10 p2: DOA should improve again (head unfrozen)
   - **Red flag**: validation DOA error rising while training loss falls

2. **Per-Source Breakdown** (rough expectations, not measured values):
   - ov1 (single source): small train/val gap (~5°)
   - ov2 (2 sources): medium gap (~10°)
   - ov3 (3 sources): large gap (~15-20°)
   - **Red flag**: an ov1 gap > 10° suggests an augmentation issue

3. **Activity vs DOA Alignment**:
   - If activity recall is high but `azi_mae` is large, matching is finding the sources but the matched slots estimate the wrong angles, or class-only matching is picking slots with poor direction outputs
   - Check: is `frame_match_dir_cost_weight` actually applied in the matching cost?

4. **Class Accuracy vs DOA**:
   - Plot `class_acc` vs `azi_mae` per epoch
   - If class accuracy plateaus at epoch 3 while `azi_mae` keeps dropping, the DOA head is separable and can still improve
   - If both plateau together, that suggests a **shared representation bottleneck**

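Checks 1 and 2 are easy to automate over logged per-epoch metrics. The function below is a hypothetical sketch. The metric names, the definition of the gap (validation minus training azimuth MAE), and the 10° threshold come from the rough guidance above, not from the training script.

```python
from typing import Dict, List

def doa_red_flags(train_loss: List[float], val_azi_mae: List[float],
                  gap_by_overlap: Dict[str, float], ov1_gap_limit: float = 10.0) -> List[str]:
    """Return human-readable red flags for checks 1 and 2.

    train_loss / val_azi_mae: one value per epoch (same length).
    gap_by_overlap: val minus train azimuth MAE in degrees, e.g. {"ov1": 4.0}.
    """
    flags = []
    for e in range(1, len(train_loss)):
        # Check 1: training improves while validation DOA error gets worse
        if train_loss[e] < train_loss[e - 1] and val_azi_mae[e] > val_azi_mae[e - 1]:
            flags.append(f"epoch {e}: train loss down, val azi_mae up")
    # Check 2: a large single-source gap points at augmentation
    if gap_by_overlap.get("ov1", 0.0) > ov1_gap_limit:
        flags.append(f"ov1 gap {gap_by_overlap['ov1']:.1f} deg > {ov1_gap_limit:.0f} deg: check augmentation")
    return flags

print(doa_red_flags([1.0, 0.8, 0.7], [30.0, 28.0, 31.0], {"ov1": 12.5, "ov2": 9.0}))
```
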
---

## Part 8: Recommended Fixes (Priority Order)

### Immediate (High Impact, Low Risk)
1. **Add random rotation augmentation**
   - Rotate the FOA waveform and update the (azimuth, elevation) targets
   - Apply during training only; evaluate on the original, unrotated validation recordings
   - Expected gain (estimate): 10-20° DOA@20 improvement on ov3

2. **Disable SpecAugment or make it train/val consistent**
   - Option A: Remove SpecAugment (simplest)
   - Option B: Apply weak SpecAugment to both train and val (validation then no longer measures clean-input performance)
   - Expected gain (estimate): 2-5° improvement

### Medium (Moderate Impact, Medium Risk)
3. **Extend the DOA ramp into v10 phase-1 instead of a full freeze**
   - Don't freeze the direction head; use a lower `lambda_frame_direction` instead (e.g., 0.1)
   - Keep the matching cost weight at 0.0 (class-focused matching is intentional)
   - Expected gain (estimate): 8-15° on ov2/ov3

4. **Better initialization of the direction head after the freeze**
   - When unfreezing in v10 phase-2, use momentum averaging of v9 predictions
   - Or apply a short warmup on the direction loss before full scale

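Recommendation 3 touches only fields quoted in Part 3, so it could be expressed as a small config override. The helper below is hypothetical (it is not in `train_spatial_beats.py`), and `SimpleNamespace` stands in for the real config objects.

```python
from types import SimpleNamespace

def soften_phase1_spatial_freeze(cfg, lambda_dir: float = 0.1):
    """Hypothetical v10 phase-1 override for recommendation 3.

    Unfreezes the spatial heads and gives the direction loss a small weight,
    while keeping the matching cost class-focused (dir/dist weights stay 0.0).
    """
    cfg.freeze_frame_track_spatial_heads = False
    cfg.loss.lambda_frame_direction = lambda_dir
    cfg.loss.frame_match_dir_cost_weight = 0.0
    cfg.loss.frame_match_dist_cost_weight = 0.0
    return cfg

# Usage with a stand-in for the real v10 phase-1 config:
cfg = SimpleNamespace(
    freeze_frame_track_spatial_heads=True,
    loss=SimpleNamespace(lambda_frame_direction=0.0,
                         frame_match_dir_cost_weight=0.0,
                         frame_match_dist_cost_weight=0.0),
)
cfg = soften_phase1_spatial_freeze(cfg)
```
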
### Advanced (High Impact, High Risk)
5. **Switch DOA representation from regression to soft-label binning**
   - Would match the azimuth classification approach already in the codebase
   - Requires architectural changes to the prediction heads
   - Expected gain (estimate): 5-10° (less sensitive to outliers)

6. **Curriculum learning for matching costs**
   - Start with class-only matching (v10 p1)
   - Gradually increase `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight`
   - This is already partially done for the loss weights; extend it to matching

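For recommendation 5, one common soft-label scheme is a circular Gaussian over azimuth bins, decoded with a circular mean. The NumPy sketch below shows that generic technique; it is not the codebase's azimuth classification code. The bin count (36) and width (σ = 10°) are illustrative assumptions, and elevation would need its own bins.

```python
import numpy as np

def bin_centres_deg(n_bins: int) -> np.ndarray:
    """Evenly spaced azimuth bin centres on [-180, 180)."""
    return -180.0 + np.arange(n_bins) * 360.0 / n_bins

def soft_azimuth_labels(azi_deg, n_bins: int = 36, sigma_deg: float = 10.0) -> np.ndarray:
    """[N] azimuths in degrees -> [N, n_bins] circular-Gaussian targets (rows sum to 1)."""
    centres = bin_centres_deg(n_bins)
    d = np.abs(np.asarray(azi_deg, dtype=float)[:, None] - centres[None, :]) % 360.0
    d = np.minimum(d, 360.0 - d)                  # circular distance to each centre
    w = np.exp(-0.5 * (d / sigma_deg) ** 2)
    return w / w.sum(axis=1, keepdims=True)

def decode_azimuth(probs: np.ndarray) -> np.ndarray:
    """Circular mean of bin centres weighted by probs; degrees in (-180, 180]."""
    centres = np.deg2rad(bin_centres_deg(probs.shape[-1]))
    return np.degrees(np.arctan2(probs @ np.sin(centres), probs @ np.cos(centres)))

targets = np.array([0.0, 179.0, -45.5])
labels = soft_azimuth_labels(targets)
decoded = decode_azimuth(labels)   # recovers the targets up to tiny numerical error
```

The circular mean handles the ±180° wrap that a plain weighted average of bin centres would get wrong: a target at 179° puts most of its mass on the -180° bin.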
---

## Appendix: Code References

| Finding | File:Line | Code |
|---------|-----------|------|
| FOA loading | `spatial_dataset.py:310-366` | `_load_audio_file` |
| SpecAugment train-only | `spatial_atst.py:336-347` | `if not self.training: return w_logmel` |
| No rotations | `spatial_dataset.py:*` | grep "rotation" → no results |
| Hungarian matching | `spatial_loss.py:1452-1509` | `_match_frame_slots_per_step` |
| Matching cost | `spatial_loss.py:1491-1504` | `cost[row] = act_cost + cls_nll + dir_cost + dist_cost` |
| Direction loss | `spatial_loss.py:1562-1565` | `loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()` |
| v9 config | `train_spatial_beats.py:2228-2278` | `make_ov1_local_spatial_v9_ov123_top4_config` |
| v10 freeze | `train_spatial_beats.py:2394-2478` | `make_ov1_local_spatial_v10_phase1_cls_config` |
| Spatial freeze | `train_spatial_beats.py:3442-3454` | `freeze_frame_track_spatial_heads` |
| Validation metrics | `spatial_loss.py:1596-1661` | `compute_frame_slot_validation_metrics` |