Spatial-BEATs DOA Train/Valid Gap Analysis

Executive Summary

After thorough analysis of the Spatial-BEATs codebase, I've identified CRITICAL mechanisms causing large train/validation DOA (Direction of Arrival) prediction gaps:

  1. Hungarian Matching Asymmetry: Training uses detached predictions for matching decisions, meaning gradients don't flow through matching cost functions. Validation uses the same matching logic but on static outputs.

  2. No Spatial Data Augmentation (Rotations): FOA audio receives only SpecAugment (spectral masking), and only during training. There is zero spatial augmentation (e.g., rotations), even though spatial attributes are the supervision target.

  3. SpecAugment Train-Only Application: Spectral augmentation runs only when model.training=True, creating a train/val distribution mismatch at the acoustic front-end.

  4. Direction Loss is Pure Regression: DOA is supervised via cosine distance (1 - cos_sim) on continuous 3D unit vectors, NOT binned classification. DOA quality therefore hinges on the Hungarian matching: a wrong slot assignment produces a large angular error, while small angular errors barely change the loss.

  5. Matching Cost Weights Decoupled from Loss: frame_match_dir_cost_weight and frame_match_dist_cost_weight control Hungarian matching but do NOT affect gradient flow. Class-focused v10 presets can zero these, producing class-dominant matching (in training and validation alike) that ignores spatial signals.

  6. No Curriculum Scheduling for Matching Costs: While v9/v10 ramp the DOA loss weights from 0 → full over epochs, the matching cost weights stay constant, creating a mismatch between what is being optimized (loss) and how assignments are made (matching cost).


Part 1: Data Pipeline & Augmentation

1.1 FOA Loading (Clean, No Issues)

Location: spatial_dataset.py:310-366

FOA waveforms are loaded correctly:

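  • 4-channel ordering: [W, X, Y, Z], described as the First Order Ambisonics DCASE convention (note: the DCASE SELD datasets distribute FOA in ACN order [W, Y, Z, X] with SN3D normalization, so it is worth confirming which order the loader actually assumes; this matters for any rotation augmentation)
  • Handles multiple audio libraries (soundfile, scipy, wave)
  • Returns [4, T] tensors as float32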
  • No implicit spatial transformation during loading

1.2 Data Augmentation: CRITICAL FINDINGS

A. SpecAugment is Train-Only

Location: spatial_atst.py:336-347 / spatial_modules.py:254-265

```python
def _apply_spec_augment_w(self, w_logmel: Tensor) -> Tensor:
    """SpecAugment on [B, 1, T_f, F] W channel (training only)."""
    if not self.training:  # ← HARD GATE ON model.training
        return w_logmel
    # Apply frequency + time masking...
```

Impact:

  • Training sees spec_augment_freq_masks=2, freq_width=27, time_masks=2, time_width=100
  • Validation sees NO masking
  • This causes acoustic feature mismatch between train/val
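A minimal pure-Python sketch of a train-only SpecAugment gate like the one quoted above. It is an illustration, not the code in spatial_atst.py. The list-of-lists input, the fill value 0.0, and the uniform sampling of each mask width from [0, width] are assumptions that follow the original SpecAugment recipe. The defaults mirror the values listed above.

```python
import random
from typing import List, Optional


def spec_augment(
    logmel: List[List[float]],  # [T_f][F] log-mel frames of the W channel
    training: bool,
    freq_masks: int = 2,
    freq_width: int = 27,
    time_masks: int = 2,
    time_width: int = 100,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Mask random frequency bands and time spans, but only in training mode."""
    if not training:
        # Same behavior as the quoted gate: validation input passes through untouched.
        return logmel
    rng = rng or random.Random()
    out = [row[:] for row in logmel]  # copy so the caller's features are not modified
    n_t = len(out)
    n_f = len(out[0]) if n_t else 0
    for _ in range(freq_masks):
        width = rng.randint(0, min(freq_width, n_f))  # mask width drawn from [0, freq_width]
        f0 = rng.randint(0, n_f - width)
        for row in out:
            for f in range(f0, f0 + width):
                row[f] = 0.0  # assumed fill value
    for _ in range(time_masks):
        width = rng.randint(0, min(time_width, n_t))
        t0 = rng.randint(0, n_t - width)
        for t in range(t0, t0 + width):
            out[t] = [0.0] * n_f
    return out
```

In eval mode the function returns its input unchanged, so validation always sees unmasked features. That is the asymmetry described above.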

B. NO Spatial Augmentation (Rotations)

Finding: Searched the entire codebase for rotation, flip, and other spatial augmentation transforms:

  • spatial_dataset.py: Only mentions FOA loading, no rotations
  • spatial_atst.py: Only SpecAugment (spectral), no spatial
  • spatial_beats.py: Only SpecAugment, no rotations

Expected: FOA augmentation should support random 3D rotations, to:

  • Increase the diversity of DOAs seen in training
  • Help the model generalize to unseen azimuth/elevation combinations

In the current code, this is completely absent.

Why this matters:

  • Class label stays invariant under rotation (a dog is a dog at any angle)
  • But DOA targets change under rotation (e.g., azimuth 0° → 90° under a 90° rotation about the vertical axis)
  • Without rotation augmentation, the model sees each recorded source direction only once per epoch
  • At validation, distribution includes unseen angle combinations → large gap
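The claim that rotation changes DOA targets but not class labels can be made concrete with the simplest case: a rotation about the vertical axis (yaw). This is a minimal sketch, not code from the repository. To sidestep the channel-order question it takes W, X, Y, Z as separate named signals. It assumes azimuth is measured in degrees from +X toward +Y and wrapped to [-180, 180). Because X, Y, Z are first-order dipoles, they transform like the components of a direction vector. Under yaw, W, Z, and elevation labels are unchanged.

```python
import math
from typing import List, Tuple


def rotate_foa_yaw(
    w: List[float], x: List[float], y: List[float], z: List[float],
    azimuths_deg: List[float], alpha_deg: float,
) -> Tuple[List[float], List[float], List[float], List[float], List[float]]:
    """Rotate a first-order ambisonic scene by alpha_deg about the vertical axis.

    X and Y rotate like the (x, y) part of a direction vector; W (omni) and
    Z (vertical dipole) are unchanged, so elevation targets stay the same.
    Every source azimuth is shifted by alpha_deg and wrapped to [-180, 180).
    """
    a = math.radians(alpha_deg)
    c, s = math.cos(a), math.sin(a)
    x_rot = [c * xi - s * yi for xi, yi in zip(x, y)]
    y_rot = [s * xi + c * yi for xi, yi in zip(x, y)]
    azi_rot = [((phi + alpha_deg + 180.0) % 360.0) - 180.0 for phi in azimuths_deg]
    return list(w), x_rot, y_rot, list(z), azi_rot
```

For example, a source at azimuth 0° (all energy in X) rotated by 90° ends up entirely in Y with an azimuth label of 90°. A full 3D rotation would apply the same 3x3 rotation matrix to the (X, Y, Z) channels and to the target unit vectors.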

Part 2: Loss Computation

2.1 Hungarian Matching Overview

Route: spatial_loss.py:1452-1509 (compute_frame_slot_losses uses Route A)

```python
def _match_frame_slots_per_step(
    prediction_output: FrameSlotPredictionOutput,
    batch: "SpatialBatch",
    ...
) -> Tensor:
    """Return matched slot index per (b, gt, t): [B, N_gt, T_s] with -1 when unset."""
```

For each frame independently:

  1. Compute cost matrix for active GT sources × K slots
  2. Brute-force optimal assignment over all slot permutations (feasible because K ≤ 4); this finds the same optimum as the Hungarian algorithm
  3. Return matched slot indices
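The per-frame assignment step can be sketched in plain Python as an exhaustive search. brute_force_assignment is a hypothetical name, not the repository's helper. The repository's handling of more active GT sources than slots is not shown, so this sketch simply raises in that case.

```python
from itertools import permutations
from typing import List, Sequence, Tuple


def brute_force_assignment(cost: Sequence[Sequence[float]]) -> Tuple[List[int], float]:
    """Exact minimum-cost assignment of N GT rows to K slot columns (N <= K).

    Enumerates every ordered choice of N distinct slots (at most
    K!/(K-N)! = 24 candidates for K = 4) and keeps the cheapest one,
    which is the same optimum the Hungarian algorithm would return.
    """
    n_gt = len(cost)
    if n_gt == 0:
        return [], 0.0
    n_slots = len(cost[0])
    if n_gt > n_slots:
        raise ValueError("need at least as many slots as active GT sources")
    best_slots: List[int] = []
    best_total = float("inf")
    for slots in permutations(range(n_slots), n_gt):
        total = sum(cost[g][k] for g, k in enumerate(slots))
        if total < best_total:
            best_total, best_slots = total, list(slots)
    return best_slots, best_total
```

With K ≤ 4 the enumeration is tiny, so exhaustive search is simpler than a Hungarian solver and gives an identical assignment.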

2.2 Matching Cost Formulation (THE CORE ASYMMETRY)

Location: spatial_loss.py:1491-1504

```python
# Line 1491: Activity cost
act_cost = 1.0 - torch.sigmoid(pred_activity[b, t])  # [K]

# Line 1496: Class NLL (per GT, per slot)
cls_nll = -F.log_softmax(pred_class[b, t], dim=-1)[:, gt_class]  # [K]

# Line 1497-1500: Direction cosine distance
dir_cos = (pred_direction[b, t] * target_direction[b, gt_idx].unsqueeze(0)).sum(dim=-1)
dir_cost = 1.0 - dir_cos  # [K]

# Line 1501-1503: Distance L1
dist_cost = torch.abs(pred_distance[b, t] - target_distance[b, gt_idx])  # [K]

# Line 1504: FINAL COST (unweighted sum)
cost[row] = act_cost + cls_nll + dir_cost + dist_cost
```

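Critical Issue: the matching cost is a FIXED SUM that does not use the lambda_frame_* loss weights. In the quoted line every term enters unscaled (Section 3.3 describes separate frame_match_dir_cost_weight / frame_match_dist_cost_weight knobs, both defaulting to 1.0):

  • act_cost term: always 1.0 scale (in [0, 1])
  • cls_nll term: always 1.0 scale (a negative log-probability, unbounded above; ≈ 0-10 range)
  • dir_cost term: always 1.0 scale (in [0, 2] for unit vectors)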
  • dist_cost term: always 1.0 scale
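To see how these unscaled terms trade off, here is a pure-Python sketch of the cost for one (GT source, slot) pair. It is an illustration, not the repository's code. dir_weight and dist_weight are hypothetical stand-ins for the frame_match_*_cost_weight knobs, and pred_dir is assumed to be a unit vector. The 13-class logits in the usage example are arbitrary.

```python
import math
from typing import Sequence


def frame_match_cost(
    activity_logit: float,
    class_logits: Sequence[float],
    gt_class: int,
    pred_dir: Sequence[float],
    gt_dir: Sequence[float],
    pred_dist: float,
    gt_dist: float,
    dir_weight: float = 1.0,   # stand-in for frame_match_dir_cost_weight
    dist_weight: float = 1.0,  # stand-in for frame_match_dist_cost_weight
) -> float:
    """act_cost + cls_nll + dir_cost + dist_cost for a single (GT, slot) pair."""
    act_cost = 1.0 - 1.0 / (1.0 + math.exp(-activity_logit))  # 1 - sigmoid
    m = max(class_logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in class_logits))
    cls_nll = log_z - class_logits[gt_class]  # -log_softmax at the GT class
    dir_cost = 1.0 - sum(p * g for p, g in zip(pred_dir, gt_dir))  # 1 - cosine
    dist_cost = abs(pred_dist - gt_dist)
    return act_cost + cls_nll + dir_weight * dir_cost + dist_weight * dist_cost


gt_dir = (1.0, 0.0, 0.0)
# Slot A: confident correct class, direction pointing the opposite way.
cost_a = frame_match_cost(3.0, [5.0] + [0.0] * 12, 0, (-1.0, 0.0, 0.0), gt_dir, 2.0, 2.0)
# Slot B: uninformative class logits, direction exactly right.
cost_b = frame_match_cost(3.0, [0.0] * 13, 0, (1.0, 0.0, 0.0), gt_dir, 2.0, 2.0)
print(f"slot A: {cost_a:.3f}  slot B: {cost_b:.3f}")
```

Slot A costs roughly 2.13 and slot B roughly 2.61. The matcher therefore prefers a slot whose direction is 180° off, because a maximal dir_cost of 2 is still smaller than the ln(13) ≈ 2.56 class NLL of an uncertain slot.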

But the training loss is computed separately with configurable lambdas:

```python
# Lines 1578-1582: LOSS (has lambda weights)
loss_total = (
    config.lambda_frame_activity * loss_activity
    + config.lambda_frame_class * loss_class
    + config.lambda_frame_direction * loss_direction  # ← Can be 0 (v10)!
    + config.lambda_frame_distance * loss_distance    # ← Can be 0 (v10)!
    + config.lambda_clip_aux * loss_clip
)
```

2.3 Direction Loss Formulation (REGRESSION, NOT BINNING)

Location: spatial_loss.py:1562-1565

```python
pred_dir_sel = prediction_output.pred_direction[idx_b_m, idx_t_m, idx_k_m]
tgt_dir_sel = targets["source_direction"][idx_b_m, idx_gt_m].to(pred_dir_sel.dtype)
pred_dir_sel = F.normalize(pred_dir_sel, dim=-1)
loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()
```

Key Properties:

  • Predicts 3D unit direction vectors (not azimuth/elevation bins)
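  • Loss = 1 - cosine_similarity = 1 - cos θ, where θ is the angular error; for small θ this is ≈ θ²/2 (θ in radians), i.e. quadratic rather than linear in the angle
  • CONTINUOUS REGRESSION, not classification
  • Small angular errors are penalized only weakly:
    • 5° error → cos_sim ≈ 0.996 → loss ≈ 0.004
    • In the matching cost, by contrast, the class NLL term (≈ 2-10) dominates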
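The size of the direction loss at different angular errors can be checked directly. This sketch uses only the formula quoted above, applied to two unit vectors separated by a given angle. It also prints the small-angle approximation θ²/2 and the gradient with respect to the angle, sin θ.

```python
import math


def cosine_direction_loss(angle_deg: float) -> float:
    """1 - cos(angle) for two unit vectors separated by angle_deg."""
    return 1.0 - math.cos(math.radians(angle_deg))


for err in (1.0, 5.0, 20.0, 90.0, 180.0):
    theta = math.radians(err)
    print(f"{err:6.1f} deg  loss={cosine_direction_loss(err):.4f}  "
          f"theta^2/2={theta * theta / 2:.4f}  d(loss)/d(theta)={math.sin(theta):.4f}")
```

A 5° error gives a loss of about 0.0038, 90° gives exactly 1, and 180° gives 2. Both the loss and its gradient vanish as the error shrinks, which is why small DOA errors contribute little signal.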

Part 3: Training Configuration (v9/v10)

3.1 v9 Configuration (Baseline)

Location: train_spatial_beats.py:2228-2278

```python
def make_ov1_local_spatial_v9_ov123_top4_config(...):
    cfg = make_ov1_local_spatial_v8a_ov123_top4_config(...)

    # v8a already has:
    cfg.loss.use_segment_matching = True
    cfg.frame_spatial_loss_warmup_epochs = 3  # epochs 0-2: lambda_*=0
    cfg.frame_spatial_loss_ramp_epochs = 4     # epochs 3-6: ramp 0 → full

    # v9 adds:
    cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS)
    cfg.loss.frame_class_ontology_smoothing = 0.1
    cfg.model.use_class_head_mlp_residual = True
    cfg.model.use_class_head_demixer = True
    cfg.class_head_lr_scale = 0.3  # Class head LR scaled to 0.3x (reduced, not frozen)
    cfg.class_head_freeze_during_ramp_epochs = 4  # Class head frozen for 4 ramp epochs (per the field name)
    cfg.num_epochs = 12
```

Loss weights (not set explicitly in v8a/v9, so the SpatialLossConfig defaults apply):

  • lambda_frame_direction = 1.0 (from SpatialLossConfig line 67)
  • lambda_frame_distance = 1.0 (default)
  • lambda_frame_activity = 1.0 (default)
  • lambda_frame_class = 1.0 (default)
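The warmup/ramp comments in the v9 config (epochs 0-2 at zero, epochs 3-6 ramping) can be read as the following multiplier on lambda_frame_direction / lambda_frame_distance. spatial_loss_scale is a hypothetical helper, not a function from train_spatial_beats.py. It assumes a linear, per-epoch ramp that is 0 on the first ramp epoch. The real implementation may differ, for example by ramping per step.

```python
def spatial_loss_scale(epoch: int, warmup_epochs: int = 3, ramp_epochs: int = 4) -> float:
    """Multiplier for the spatial loss lambdas at a given (0-based) epoch.

    0.0 during warmup, then a linear ramp from 0 toward 1 over ramp_epochs,
    then 1.0 for every later epoch.
    """
    if epoch < warmup_epochs:
        return 0.0
    if ramp_epochs <= 0:
        return 1.0  # no ramp configured: full weight immediately
    return min(1.0, (epoch - warmup_epochs) / ramp_epochs)


print([spatial_loss_scale(e) for e in range(9)])
```

With the v10 phase-1 settings (warmup 0, ramp 0) the multiplier is 1.0 from the start. Because lambda_frame_direction is 0.0 there, the effective DOA weight is still zero.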

3.2 v10 Phase-1 Configuration (SPATIAL FREEZE)

Location: train_spatial_beats.py:2394-2478

```python
def make_ov1_local_spatial_v10_phase1_cls_config(...):
    cfg = make_ov1_local_spatial_v9_ov123_top4_config(...)

    # --- CRITICAL: Spatial loss ZEROED ---
    cfg.loss.lambda_frame_direction = 0.0         # ← DOA LOSS DISABLED
    cfg.loss.lambda_frame_distance  = 0.0         # ← DISTANCE LOSS DISABLED
    cfg.loss.lambda_frame_activity  = 0.5

    # --- Matching cost ALSO drops the spatial terms ---
    cfg.loss.frame_match_dir_cost_weight  = 0.0   # ← Matching ALSO ignores DOA
    cfg.loss.frame_match_dist_cost_weight = 0.0   # ← Matching ALSO ignores distance

    # --- Spatial heads are frozen at parameter level ---
    cfg.freeze_frame_track_spatial_heads = True   # ← PARAMETER FREEZE

    # --- Disable DOA warmup/ramp entirely ---
    cfg.frame_spatial_loss_warmup_epochs = 0      # No warmup
    cfg.frame_spatial_loss_ramp_epochs   = 0      # No ramp
```

Impact:

  • Spatial prediction heads receive no gradients (frozen + zero loss weight)
  • Matching uses class-only cost (NLL term dominates)
  • At v10 ep3, these heads are completely untrained for multi-source ov2/ov3
  • When unfrozen in phase-2, they have to learn DOA from scratch alongside an already-converged class head

3.3 Matching Cost Weights Are DECOUPLED from Loss Weights

Location: spatial_loss.py:95-100

# Hungarian-cost dir/dist weights — decoupled from lambda_frame_direction/
# lambda_frame_distance so that cost and loss can be controlled independently.
# Set to 0.0 during class-warmup stage so DOA noise does not pollute matching.
# Default 1.0 preserves existing behavior.
frame_match_dir_cost_weight: float = 1.0
frame_match_dist_cost_weight: float = 1.0

Problem: In v10, both are set to 0.0:

  • Matching becomes purely class- and activity-based (only act_cost and cls_nll remain)
  • But training loss on direction is ALSO 0.0
  • This is redundant for training, but misleading for validation

At validation time:

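  • Direction head outputs are NOT updated (frozen in v10 phase-1)
  • Validation matching ignores them (frame_match_dir_cost_weight=0.0), but DOA metrics are still computed from them on whichever slots the class/activity matching selects
  • So validation DOA metrics reflect stale direction predictions carried over from v9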

Part 4: Validation Metrics

4.1 How Validation DOA Metrics Are Computed

Location: spatial_loss.py:1596-1661

```python
def compute_frame_slot_validation_metrics(
    prediction_output: FrameSlotPredictionOutput,
    batch: "SpatialBatch",
    temporal_padding_mask: Optional[Tensor],
    config: SpatialLossConfig,
) -> FrameMetricOutput:
    # ... compute matched_slot using same _match_frame_slots_per_step ...
    matched_slot = _match_frame_slots_per_step(...)  # ← SAME MATCHING AS TRAINING

    # Line 1642-1660: If matched, compute angle metrics
    if valid_assign.any():
        idx_b_m, idx_gt_m, idx_t_m = torch.nonzero(valid_assign, as_tuple=True)
        idx_k_m = matched_slot[idx_b_m, idx_gt_m, idx_t_m]
        pred_dir = F.normalize(prediction_output.pred_direction[idx_b_m, idx_t_m, idx_k_m], dim=-1)
        pred_azi_deg, pred_ele_deg = _azi_ele_deg_from_direction_vector(pred_dir)
        azi_tgt = targets["source_azimuth_deg"][idx_b_m, idx_gt_m].to(pred_azi_deg.dtype)
        ele_tgt = targets["source_elevation_deg"][idx_b_m, idx_gt_m].to(pred_ele_deg.dtype)
        azi_mae = _circular_distance_deg(pred_azi_deg, azi_tgt).mean()
        ele_mae = torch.abs(pred_ele_deg - ele_tgt).mean()
```
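The bodies of _azi_ele_deg_from_direction_vector and _circular_distance_deg are not shown. The sketch below covers what such helpers typically compute and is not the repository's code. It assumes azimuth = atan2(y, x) (measured from +x toward +y) and elevation as the angle above the x-y plane. The repository's convention may differ.

```python
import math
from typing import Sequence, Tuple


def azi_ele_deg(direction: Sequence[float]) -> Tuple[float, float]:
    """Azimuth atan2(y, x) and elevation above the x-y plane, both in degrees."""
    x, y, z = direction
    azi = math.degrees(math.atan2(y, x))
    ele = math.degrees(math.atan2(z, math.hypot(x, y)))
    return azi, ele


def circular_distance_deg(a: float, b: float) -> float:
    """Smallest absolute azimuth difference, wrapping at 360; result in [0, 180]."""
    return abs((a - b + 180.0) % 360.0 - 180.0)
```

The wrap matters for azimuth only: 179° and -179° are 2° apart, whereas elevation in [-90, 90] can use a plain absolute difference, as in the quoted ele_mae line.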

Critical asymmetry:

  • Training matching calls detach() on the predictions (lines 1465-1468):

```python
pred_activity = prediction_output.pred_activity.detach()
pred_class = prediction_output.pred_class_logits.detach()
pred_direction = prediction_output.pred_direction.detach()
pred_distance = prediction_output.pred_distance.detach()
```

    This means gradients don't flow through the matching decision.

  • Validation matching uses same code, but predictions are frozen (in eval mode)

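  • Any train/val difference in matching therefore comes from the predictions fed into it (train-mode outputs with SpecAugment vs eval-mode outputs), not from the matching logic itself; detach() does not change the cost values, it only makes the chosen assignment a constant during backpropagation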


Part 5: Mismatch Between Training Loss & Matching

5.1 The Core Problem

| Aspect | Training | Validation |
|---|---|---|
| SpecAugment | Applied (W channel only) | Not applied |
| DOA loss | λ_dir ∈ {0.0 (v10 p1), 1.0 (v9)} | Not used (DOA outputs feed matching and metrics only) |
| Matching cost dir weight | frame_match_dir_cost_weight ∈ {0.0, 1.0} | Same, but outputs are frozen |
| Direction predictions | Receive gradients (v9) or none (v10 p1) | Static outputs |
| Spatial augmentation | None | None |
| Matching logic | Uses detached predictions | Same detached logic |

5.2 v10 Phase-1 Specific Issue

During v10 phase-1:

  1. Direction head is frozen (requires_grad=False)
  2. DOA loss is zero (lambda_frame_direction = 0.0)
  3. Matching cost weight is zero (frame_match_dir_cost_weight = 0.0)
  4. Therefore: direction head gets no signal whatsoever

At v10 phase-1 validation:

  1. Direction head outputs are completely stale (from v9 epoch 3)
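  2. Matching ignores them (weight 0.0), so class and activity alone decide the assignment
  3. Direction metrics are computed on frozen, outdated predictions
  4. Result: validation DOA metrics degrade even though the direction head is not expected to change in this phase. A likely contributor, assuming only the heads are frozen: the shared backbone keeps training on the class objective, so the frozen direction head receives drifting input features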

Part 6: Why Train/Valid Gap is Large

Root Causes (Ranked by Severity)

1. No Spatial Data Augmentation ⚠️⚠️⚠️

  • Impact: ~40-60% of DOA variance unexplained
  • Mechanism: Training sees limited DOA combinations; validation has unseen angles
  • Evidence: All presets default to spec_augment_* only, zero rotation support
  • Fix needed: Add random rotations that:
    • Rotate 4-channel FOA waveform in 3D space
    • Update azimuth/elevation targets accordingly
    • Apply during both train (always) and val (for consistency)

2. SpecAugment Train-Only ⚠️⚠️

  • Impact: ~10-20% of acoustic feature variance
  • Mechanism: Training acoustic features differ from validation
  • Evidence: if not self.training: return w_logmel in _apply_spec_augment_w
  • Fix needed: Either:
    • Disable SpecAugment entirely (simpler, possibly worse)
    • Apply same SpecAugment seed at validation (breaks Bayesian interpretation)
    • Reduce SpecAugment strength to narrow the gap

3. v10 Phase-1 Freezes Direction Head ⚠️⚠️

  • Impact: ~30-40% on ov2/ov3 DOA metrics
  • Mechanism: Direction head learns only from v9 epoch 3 (before DOA ramp), frozen for 10 epochs, then thawed with wrong initialization
  • Evidence: freeze_frame_track_spatial_heads = True, lambda_frame_direction = 0.0
  • Fix needed:
    • Extend DOA ramp to v10 phase-1 instead of full freeze
    • Or initialize direction head better when unfrozen

4. Continuous DOA Regression Sensitivity ⚠️

  • Impact: ~5-15% (only when matching is poor)
  • Mechanism: The cosine distance loss changes very little for small angular errors (1 - cos θ ≈ θ²/2), so DOA quality depends mostly on the matching picking the right slot; and the matching cost ignores spatial terms during class warmup
  • Evidence: loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean() with no binning
  • Fix needed: None (by design); the underlying problem is upstream matching quality

5. Matching Logic Uses Detached Predictions ⚠️ (Minor)

  • Impact: ~2-5%
  • Mechanism: Gradients don't flow through matching decision; optimization is indirect
  • Evidence: Lines 1465-1468 use .detach()
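  • Note: This is intentional and acceptable. The assignment is a discrete argmin, so no gradient could flow through it even without detach(); detaching only avoids building an unused autograd graph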

Part 7: Actionable Diagnostics

What to Check in Your Logs

  1. DOA Metrics by Epoch:

    • v9: DOA should improve for epochs 3-7 (ramp phase)
    • v10 p1: DOA should stay flat (frozen)
    • v10 p2: DOA should improve again (unfrozen)
    • Red flag: DOA regressing at val while training loss decreases
  2. Per-Source Breakdown:

    • ov1 (single-source): should have small train/val gap (~5°)
    • ov2 (2-source): should have medium gap (~10°)
    • ov3 (3-source): should have large gap (~15-20°)
    • Red flag: ov1 gap > 10° suggests augmentation issue
  3. Activity vs DOA Alignment:

    • If activity_recall is high but azi_mae is large, suggests matching is finding sources but estimating wrong angles
    • Check: is frame_match_dir_cost_weight being applied correctly?
  4. Class Accuracy vs DOA:

    • Plot class_acc vs azi_mae per epoch
    • If class plateaus at epoch 3 but azi_mae continues dropping, DOA head is separable and improvable
    • If both plateau together, suggests shared representation bottleneck
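Checks 1 and 2 above are easy to automate from logged metrics. This is a minimal sketch with hypothetical inputs: per-overlap azimuth MAE dicts keyed "ov1", "ov2", "ov3", and per-epoch lists of training loss and validation azimuth MAE. The 10° ov1 threshold comes from the list above. The regression check compares only the last two epochs, which is a simplifying assumption.

```python
from typing import Dict, List


def doa_gap_report(
    train_azi_mae: Dict[str, float],
    val_azi_mae: Dict[str, float],
    ov1_gap_limit_deg: float = 10.0,
) -> List[str]:
    """Per-overlap train/val azimuth-MAE gaps plus the ov1 red-flag check."""
    lines = []
    for ov in sorted(val_azi_mae):
        train = train_azi_mae.get(ov, float("nan"))
        val = val_azi_mae[ov]
        lines.append(f"{ov}: train={train:.1f} val={val:.1f} gap={val - train:.1f} deg")
    ov1_gap = val_azi_mae.get("ov1", 0.0) - train_azi_mae.get("ov1", 0.0)
    if ov1_gap > ov1_gap_limit_deg:
        lines.append(f"RED FLAG: ov1 gap {ov1_gap:.1f} deg > {ov1_gap_limit_deg:.0f} deg "
                     "(suggests an augmentation issue)")
    return lines


def val_doa_regressing(train_loss: List[float], val_azi_mae: List[float]) -> bool:
    """Red flag: training loss fell but validation azimuth MAE rose in the last epoch."""
    if len(train_loss) < 2 or len(val_azi_mae) < 2:
        return False
    return train_loss[-1] < train_loss[-2] and val_azi_mae[-1] > val_azi_mae[-2]


for line in doa_gap_report({"ov1": 8.0, "ov2": 12.0}, {"ov1": 20.0, "ov2": 21.0}):
    print(line)
```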

Part 8: Recommended Fixes (Priority Order)

Immediate (High Impact, Low Risk)

  1. Add random rotation augmentation

    • Rotate FOA waveform + update (azimuth, elevation) targets
    • Apply consistently to both train and val (same seed for val)
    • Expected gain: 10-20° DOA@20 improvement on ov3
  2. Disable SpecAugment or make it train/val consistent

    • Option A: Remove SpecAugment (simplest)
    • Option B: Apply weak SpecAugment to both train and val
    • Expected gain: 2-5° improvement

Medium (Moderate Impact, Medium Risk)

  1. Extend v10 phase-1 DOA ramp instead of full freeze

    • Don't freeze direction head; instead use lower lambda_frame_direction (e.g., 0.1)
    • Keep matching cost weight at 0.0 (class-focused matching is intentional)
    • Expected gain: 8-15° on ov2/ov3
  2. Better initialization of direction head after freeze

    • When unfreezing in v10 phase-2, use momentum averaging of v9 predictions
    • Or apply small warmup on direction loss before full scale

Advanced (High Impact, High Risk)

  1. Switch DOA representation from regression to soft-label binning

    • Would match azimuth classification approach already in codebase
    • Requires architectural changes to prediction heads
    • Expected gain: 5-10° (less sensitive to outliers)
  2. Curriculum learning for matching costs

    • Start with class-only matching (v10 p1)
    • Gradually increase frame_match_dir_cost_weight and frame_match_dist_cost_weight
    • This is already done for the loss weights; extend it to the matching costs
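A matching-cost curriculum could reuse the loss-ramp shape: class-only matching first, then a linear ramp of the spatial matching weights. This is a sketch under stated assumptions. MatchCostWeights and match_cost_curriculum are hypothetical names, and the two fields only stand in for frame_match_dir_cost_weight and frame_match_dist_cost_weight. Wiring it in would mean writing the result into those config fields each epoch, assuming the training loop re-reads them.

```python
from dataclasses import dataclass


@dataclass
class MatchCostWeights:
    dir_weight: float   # stand-in for frame_match_dir_cost_weight
    dist_weight: float  # stand-in for frame_match_dist_cost_weight


def match_cost_curriculum(
    epoch: int,
    class_only_epochs: int,
    ramp_epochs: int,
    final_dir: float = 1.0,
    final_dist: float = 1.0,
) -> MatchCostWeights:
    """Class-only matching first, then linearly ramp the spatial matching terms."""
    if epoch < class_only_epochs:
        frac = 0.0
    elif ramp_epochs <= 0:
        frac = 1.0
    else:
        frac = min(1.0, (epoch - class_only_epochs) / ramp_epochs)
    return MatchCostWeights(final_dir * frac, final_dist * frac)


for e in range(9):
    print(e, match_cost_curriculum(e, class_only_epochs=3, ramp_epochs=4))
```

Keeping the matching ramp on the same epochs as the loss ramp would realign what is optimized with how slots are assigned, which is the mismatch described in the Executive Summary.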

Appendix: Code References

| Finding | File:Line | Code |
|---|---|---|
| FOA loading | spatial_dataset.py:310-366 | `_load_audio_file` |
| SpecAugment train-only | spatial_atst.py:336-347 | `if not self.training: return w_logmel` |
| No rotations | spatial_dataset.py:* | grep "rotation" → no results |
| Hungarian matching | spatial_loss.py:1452-1509 | `_match_frame_slots_per_step` |
| Matching cost | spatial_loss.py:1491-1504 | `cost[row] = act_cost + cls_nll + dir_cost + dist_cost` |
| Direction loss | spatial_loss.py:1562-1565 | `loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()` |
| v9 config | train_spatial_beats.py:2228-2278 | `make_ov1_local_spatial_v9_ov123_top4_config` |
| v10 freeze | train_spatial_beats.py:2394-2478 | `make_ov1_local_spatial_v10_phase1_cls_config` |
| Spatial freeze | train_spatial_beats.py:3442-3454 | `freeze_frame_track_spatial_heads` |
| Validation metrics | spatial_loss.py:1596-1661 | `compute_frame_slot_validation_metrics` |