# Spatial-BEATs DOA Train/Valid Gap Analysis
## Executive Summary
After analyzing the Spatial-BEATs codebase, I identified six mechanisms that contribute to the large train/validation gap in DOA (Direction of Arrival) prediction:
1. **Hungarian Matching Asymmetry**: Training computes the matching on **detached predictions**, so no gradients flow through the matching cost functions. Validation runs the **same matching logic** on eval-mode outputs.
2. **No Spatial Data Augmentation (Rotations)**: FOA audio receives only **SpecAugment** (spectral masking), and only during training. There is **zero spatial augmentation** (rotations), even though direction is the supervision target.
3. **SpecAugment Train-Only Application**: Spectral augmentation happens only when `model.training` is `True`, creating an artificial train/val distribution mismatch on the acoustic front-end.
4. **Direction Loss is Pure Regression**: DOA is supervised via **cosine distance** (`1 - cos_sim`) on continuous 3D unit vectors, not binned classification. This makes the DOA metrics sensitive to assignment errors made by the Hungarian matching.
5. **Matching Cost Weights Decoupled from Loss**: `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight` control Hungarian matching but **do NOT affect gradient flow**. Class-focused v10 presets can zero these, creating a **train-only class-dominant matching** that ignores spatial signals.
6. **No Curriculum Scheduling for Matching Costs**: v9/v10 ramp the DOA loss weights from 0 → full over epochs, but the **matching cost weights stay constant**, creating a **mismatch between what is being optimized (loss) and how decisions are made (matching cost)**.
---
## Part 1: Data Pipeline & Augmentation
### 1.1 FOA Loading (Clean, No Issues)
**Location**: `spatial_dataset.py:310-366`
FOA waveforms are loaded correctly:
- 4-channel ordering: `[W, X, Y, Z]` (First Order Ambisonics DCASE convention)
- Handles multiple audio libraries (soundfile, scipy, wave)
- Returns `[4, T]` tensors normalized to float32
- **No implicit spatial transformation during loading**
### 1.2 Data Augmentation: CRITICAL FINDINGS
#### A. SpecAugment is Train-Only
**Location**: `spatial_atst.py:336-347` / `spatial_modules.py:254-265`
```python
def _apply_spec_augment_w(self, w_logmel: Tensor) -> Tensor:
    """SpecAugment on [B, 1, T_f, F] W channel (training only)."""
    if not self.training:  # ← HARD GATE ON model.training
        return w_logmel
    # Apply frequency + time masking...
```
**Impact**:
- Training sees `spec_augment_freq_masks=2, freq_width=27, time_masks=2, time_width=100`
- Validation sees **NO masking**
- This causes acoustic feature mismatch between train/val
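To make the train-only gate concrete, here is a minimal, dependency-free sketch of frequency/time masking on a `[T][F]` log-mel stored as nested lists. It is not the repository's `_apply_spec_augment_w`: the function name `spec_augment`, the zero fill value, and the uniform sampling of mask widths are all assumptions made for illustration. The default parameters mirror the values listed above.

```python
import random

def spec_augment(spec, training, rng, freq_masks=2, freq_width=27,
                 time_masks=2, time_width=100):
    """Zero random frequency and time bands of a [T][F] log-mel (list of rows).

    Hypothetical stand-in for the real method; fill value and width
    sampling are assumptions, only the training gate follows the excerpt.
    """
    if not training:              # same hard gate as the excerpt above
        return spec
    n_t, n_f = len(spec), len(spec[0])
    out = [row[:] for row in spec]            # copy so the input stays intact
    for _ in range(freq_masks):
        width = rng.randint(0, min(freq_width, n_f))
        start = rng.randint(0, n_f - width)
        for row in out:
            for f in range(start, start + width):
                row[f] = 0.0
    for _ in range(time_masks):
        width = rng.randint(0, min(time_width, n_t))
        start = rng.randint(0, n_t - width)
        for t in range(start, start + width):
            out[t] = [0.0] * n_f
    return out
```

In eval mode the input is returned untouched, so any feature statistics the model learned from masked inputs are never exercised at validation.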
#### B. NO Spatial Augmentation (Rotations)
**Finding**: Searched entire codebase for `rotation`, `flip`, `augment` spatial transforms:
- `spatial_dataset.py`: Only mentions FOA loading, no rotations
- `spatial_atst.py`: Only SpecAugment (spectral), no spatial
- `spatial_beats.py`: Only SpecAugment, no rotations
**Expected**: FOA audio should support **random rotations in 3D space**, but this is **completely absent**. Rotations would:
- Augment training DOA diversity
- Help the model generalize to unseen azimuth/elevation combinations
**Why this matters**:
- The class label stays invariant under rotation (a dog is a dog at any angle)
- But DOA targets **change under rotation** (azimuth 0° becomes 90° after a 90° rotation about the vertical axis)
- Without rotation augmentation, the model sees each recorded DOA only about once per epoch, always in its original orientation
- At validation, the distribution includes unseen angle combinations, which produces a large gap
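As an illustration of what such an augmentation involves, the sketch below rotates a FOA clip and its direction target about the vertical axis (yaw only). It assumes the `[W, X, Y, Z]` channel order stated in section 1.1, an encoding where `X ∝ cos(azi)·cos(ele)` and `Y ∝ sin(azi)·cos(ele)`, and counter-clockwise azimuth; sign and ordering conventions differ between datasets and must be checked against the actual data. The function name `rotate_foa_yaw` is hypothetical, and plain lists are used only to keep the example dependency-free.

```python
import math

def rotate_foa_yaw(foa, direction, yaw_deg):
    """Rotate a FOA clip and its DOA target about the vertical (z) axis.

    foa:       [W, X, Y, Z] channels, each a list of samples (assumed order)
    direction: (x, y, z) unit vector of the source
    Returns (rotated_foa, rotated_direction); W and Z are yaw-invariant.
    """
    c = math.cos(math.radians(yaw_deg))
    s = math.sin(math.radians(yaw_deg))
    w, x, y, z = foa
    # The same 2D rotation is applied to the (X, Y) channels and the target.
    x_rot = [c * xi - s * yi for xi, yi in zip(x, y)]
    y_rot = [s * xi + c * yi for xi, yi in zip(x, y)]
    dx, dy, dz = direction
    return [w, x_rot, y_rot, z], (c * dx - s * dy, s * dx + c * dy, dz)

# A source straight ahead (azimuth 0): signal appears in W and X only.
sig = [1.0, -0.5, 0.25]
foa = [sig, sig[:], [0.0] * 3, [0.0] * 3]
rotated, new_dir = rotate_foa_yaw(foa, (1.0, 0.0, 0.0), 90.0)
```

After the 90° yaw the signal moves from the X channel to the Y channel and the target vector points along +y, so audio and label stay consistent. A full 3D augmentation would apply one 3×3 rotation matrix to `(X, Y, Z)` and to the target in the same way.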
---
## Part 2: Loss Computation
### 2.1 Hungarian Matching Overview
**Route**: `spatial_loss.py:1452-1509` (`compute_frame_slot_losses` uses Route A)
```python
def _match_frame_slots_per_step(
    prediction_output: FrameSlotPredictionOutput,
    batch: "SpatialBatch",
    ...
) -> Tensor:
    """Return matched slot index per (b, gt, t): [B, N_gt, T_s] with -1 when unset."""
```
For each frame independently:
1. Compute cost matrix for active GT sources × K slots
2. Exhaustive (brute-force) optimal assignment, which yields the same minimum-cost matching as the Hungarian algorithm and is cheap for K ≤ 4
3. Return matched slot indices
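The per-frame procedure above can be sketched as an exhaustive search over slot permutations. This is a minimal illustration, not the repository's `_match_frame_slots_per_step`: the name `match_slots` and the nested-list cost layout are assumptions, and at least one active GT source is assumed.

```python
from itertools import permutations

def match_slots(cost):
    """Minimum-cost GT-to-slot assignment by exhaustive search.

    cost[g][k] is the cost of assigning GT source g to slot k, with
    len(cost) <= K. Returns one distinct slot index per GT source.
    """
    n_gt, n_slots = len(cost), len(cost[0])
    best, best_total = None, float("inf")
    # Every ordered choice of n_gt distinct slots is one candidate assignment.
    for slots in permutations(range(n_slots), n_gt):
        total = sum(cost[g][k] for g, k in enumerate(slots))
        if total < best_total:
            best, best_total = slots, total
    return list(best)

# Both sources prefer slot 0; the joint optimum gives slot 0 to source 1.
assignment = match_slots([[0.1, 0.2],
                          [0.1, 0.9]])
```

With K = 4 there are at most 24 candidates per frame, which is why a brute-force search is a reasonable substitute for a full Hungarian solver here.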
### 2.2 Matching Cost Formulation (THE CORE ASYMMETRY)
**Location**: `spatial_loss.py:1491-1504`
```python
# Line 1491: Activity cost
act_cost = 1.0 - torch.sigmoid(pred_activity[b, t]) # [K]
# Line 1496: Class NLL (per GT, per slot)
cls_nll = -F.log_softmax(pred_class[b, t], dim=-1)[:, gt_class] # [K]
# Line 1497-1500: Direction cosine distance
dir_cos = (pred_direction[b, t] * target_direction[b, gt_idx].unsqueeze(0)).sum(dim=-1)
dir_cost = 1.0 - dir_cos # [K]
# Line 1501-1503: Distance L1
dist_cost = torch.abs(pred_distance[b, t] - target_distance[b, gt_idx]) # [K]
# Line 1504: FINAL COST (unweighted sum)
cost[row] = act_cost + cls_nll + dir_cost + dist_cost
```
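**Critical Issue**: The cost terms are **not scaled by the loss lambdas**. At the default matching weights, each term enters the sum at scale 1.0:
- `act_cost` term: range 0 to 1
- `cls_nll` term: a negative log-probability, typically ≈ 0-10
- `dir_cost` term: range 0 to 2 (`1 - cos`)
- `dist_cost` term: absolute distance error, in the units of the distance target
The only scaling available on the matching side is `frame_match_dir_cost_weight` / `frame_match_dist_cost_weight` (see 3.3), both defaulting to 1.0. The **training loss**, by contrast, is computed separately with configurable lambdas: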
```python
# Lines 1578-1582: LOSS (has lambda weights)
loss_total = (
config.lambda_frame_activity * loss_activity
+ config.lambda_frame_class * loss_class
+ config.lambda_frame_direction * loss_direction # ← Can be 0 (v10)!
+ config.lambda_frame_distance * loss_distance # ← Can be 0 (v10)!
+ config.lambda_clip_aux * loss_clip
)
```
### 2.3 Direction Loss Formulation (REGRESSION, NOT BINNING)
**Location**: `spatial_loss.py:1562-1565`
```python
pred_dir_sel = prediction_output.pred_direction[idx_b_m, idx_t_m, idx_k_m]
tgt_dir_sel = targets["source_direction"][idx_b_m, idx_gt_m].to(pred_dir_sel.dtype)
pred_dir_sel = F.normalize(pred_dir_sel, dim=-1)
loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()
```
**Key Properties**:
- Predicts **3D unit direction vectors** (not azimuth/elevation bins)
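- Loss = `1 - cosine_similarity`, which for a small angular error θ (in radians) is ≈ θ²/2, so it grows quadratically rather than linearly with the angle
- **CONTINUOUS REGRESSION**, not classification
- Small angular errors cost little, so the direction term rarely decides the matching:
  - 5° error → cos_sim ≈ 0.996 → loss ≈ 0.004
  - By comparison, the class NLL term in the matching cost is ≈ 2-10 and dominates the assignment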
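To put numbers on the direction term, the short script below evaluates `1 - cos(θ)` at a few angular errors and compares it with the small-angle approximation θ²/2. The helper name `cosine_loss` is introduced here for illustration only.

```python
import math

def cosine_loss(angle_deg):
    """1 - cos(theta) for an angular error theta between two unit vectors."""
    return 1.0 - math.cos(math.radians(angle_deg))

for angle in (5, 20, 90, 180):
    approx = math.radians(angle) ** 2 / 2      # small-angle approximation
    print(f"{angle:>3} deg  loss={cosine_loss(angle):.4f}  theta^2/2={approx:.4f}")
```

The term is bounded by 2 (opposite directions), while a single class NLL value can exceed that, which is why the class term tends to decide the assignment whenever class predictions are uncertain.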
---
## Part 3: Training Configuration (v9/v10)
### 3.1 v9 Configuration (Baseline)
**Location**: `train_spatial_beats.py:2228-2278`
```python
def make_ov1_local_spatial_v9_ov123_top4_config(...):
    cfg = make_ov1_local_spatial_v8a_ov123_top4_config(...)
    # v8a already has:
    cfg.loss.use_segment_matching = True
    cfg.frame_spatial_loss_warmup_epochs = 3  # epochs 0-2: lambda_*=0
    cfg.frame_spatial_loss_ramp_epochs = 4    # epochs 3-6: ramp 0 → full
    # v9 adds:
    cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS)
    cfg.loss.frame_class_ontology_smoothing = 0.1
    cfg.model.use_class_head_mlp_residual = True
    cfg.model.use_class_head_demixer = True
    cfg.class_head_lr_scale = 0.3                 # class head LR scaled down
    cfg.class_head_freeze_during_ramp_epochs = 4  # class head frozen during the ramp
    cfg.num_epochs = 12
```
**Loss weights** (inherited from v8a; not set explicitly in the preset, so the `SpatialLossConfig` defaults apply):
- `lambda_frame_direction = 1.0` (from `SpatialLossConfig` line 67)
- `lambda_frame_distance = 1.0` (default)
- `lambda_frame_activity = 1.0` (default)
- `lambda_frame_class = 1.0` (default)
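The warmup/ramp settings above imply a per-epoch multiplier on the spatial lambdas. The sketch below shows one plausible reading of the comments in the preset ("epochs 0-2: lambda_*=0", "epochs 3-6: ramp 0 → full"). The linear shape, the choice to reach 1.0 on the last ramp epoch, and the name `spatial_loss_scale` are assumptions; the actual schedule in `train_spatial_beats.py` is not shown in this document.

```python
def spatial_loss_scale(epoch, warmup_epochs=3, ramp_epochs=4):
    """Multiplier applied to the spatial lambdas at a 0-based epoch.

    Assumed behavior: zero during warmup, then a linear ramp that
    reaches 1.0 on the last ramp epoch and stays there.
    """
    if epoch < warmup_epochs:
        return 0.0
    if ramp_epochs <= 0:
        return 1.0
    return min(1.0, (epoch - warmup_epochs + 1) / ramp_epochs)

# Effective direction weight per epoch for the v9 settings (lambda = 1.0).
schedule = [1.0 * spatial_loss_scale(e) for e in range(12)]
```

Under this reading the direction head trains at full weight for only 6 of the 12 epochs (epochs 6-11), which is worth keeping in mind when comparing DOA curves across presets.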
### 3.2 v10 Phase-1 Configuration (SPATIAL FREEZE)
**Location**: `train_spatial_beats.py:2394-2478`
```python
def make_ov1_local_spatial_v10_phase1_cls_config(...):
    cfg = make_ov1_local_spatial_v9_ov123_top4_config(...)
    # --- CRITICAL: spatial loss zeroed ---
    cfg.loss.lambda_frame_direction = 0.0  # ← DOA LOSS DISABLED
    cfg.loss.lambda_frame_distance = 0.0   # ← DISTANCE LOSS DISABLED
    cfg.loss.lambda_frame_activity = 0.5
    # --- Matching cost also ignores spatial signals ---
    cfg.loss.frame_match_dir_cost_weight = 0.0   # ← matching ignores DOA
    cfg.loss.frame_match_dist_cost_weight = 0.0  # ← matching ignores distance
    # --- Spatial heads are frozen at parameter level ---
    cfg.freeze_frame_track_spatial_heads = True  # ← PARAMETER FREEZE
    # --- Disable DOA warmup/ramp entirely ---
    cfg.frame_spatial_loss_warmup_epochs = 0  # no warmup
    cfg.frame_spatial_loss_ramp_epochs = 0    # no ramp
```
**Impact**:
- Spatial prediction heads receive **no gradients** (frozen + zero loss weight)
- Matching uses **class-only cost** (NLL term dominates)
- At v10 ep3, these heads are **completely untrained for multi-source ov2/ov3**
- When unfrozen in phase-2, they have to learn DOA from scratch with already-converged class
### 3.3 Matching Cost Weights Are DECOUPLED from Loss Weights
**Location**: `spatial_loss.py:95-100`
```python
# Hungarian-cost dir/dist weights — decoupled from lambda_frame_direction/
# lambda_frame_distance so that cost and loss can be controlled independently.
# Set to 0.0 during class-warmup stage so DOA noise does not pollute matching.
# Default 1.0 preserves existing behavior.
frame_match_dir_cost_weight: float = 1.0
frame_match_dist_cost_weight: float = 1.0
```
**Problem**: In v10 phase-1, both are set to 0.0:
- Matching becomes class-driven (the activity cost remains, but the NLL term dominates)
- The training loss on direction is also 0.0
- This is redundant for training, but **misleading for validation**
At validation time:
- Direction head outputs are not updated (frozen in v10 phase-1)
- Validation matching ignores them (`frame_match_dir_cost_weight=0.0`), yet the DOA metrics are still computed from them
- So validation metrics reflect **outdated direction predictions from v9**
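A small worked example shows what zeroing the matching weights does when two active sources share a class. It assumes the two weights simply multiply `dir_cost` and `dist_cost` inside the cost sum; the excerpt in 2.2 does not show where they are applied, so treat `match_cost` and `assignment_totals` as hypothetical helpers with made-up input values.

```python
def match_cost(act_prob, cls_nll, dir_cos, dist_err,
               dir_weight=1.0, dist_weight=1.0):
    """Cost of assigning one GT source to one slot (assumed weighting)."""
    return ((1.0 - act_prob) + cls_nll
            + dir_weight * (1.0 - dir_cos) + dist_weight * dist_err)

def assignment_totals(dir_weight):
    """Total cost of the straight (A→0, B→1) and swapped (A→1, B→0) matchings."""
    # Two same-class sources: activity and class NLL cannot tell them apart.
    # dir_cos[g][k] = cosine between GT g's direction and slot k's prediction.
    dir_cos = [[0.98, -0.20], [-0.20, 0.98]]
    cost = [[match_cost(0.9, 0.4, dir_cos[g][k], 0.1, dir_weight=dir_weight)
             for k in range(2)] for g in range(2)]
    return cost[0][0] + cost[1][1], cost[0][1] + cost[1][0]

straight_off, swapped_off = assignment_totals(dir_weight=0.0)
straight_on, swapped_on = assignment_totals(dir_weight=1.0)
```

With the direction weight at 0.0 the two matchings tie, so the assignment is decided by iteration order rather than by geometry; with weight 1.0 the straight matching is clearly cheaper. Same-class overlaps (ov2/ov3) are therefore the cases to inspect first when DOA metrics are computed under class-only matching.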
---
## Part 4: Validation Metrics
### 4.1 How Validation DOA Metrics Are Computed
**Location**: `spatial_loss.py:1596-1661`
```python
def compute_frame_slot_validation_metrics(
    prediction_output: FrameSlotPredictionOutput,
    batch: "SpatialBatch",
    temporal_padding_mask: Optional[Tensor],
    config: SpatialLossConfig,
) -> FrameMetricOutput:
    # ... compute matched_slot using same _match_frame_slots_per_step ...
    matched_slot = _match_frame_slots_per_step(...)  # ← SAME MATCHING AS TRAINING
    # Line 1642-1660: If matched, compute angle metrics
    if valid_assign.any():
        idx_b_m, idx_gt_m, idx_t_m = torch.nonzero(valid_assign, as_tuple=True)
        idx_k_m = matched_slot[idx_b_m, idx_gt_m, idx_t_m]
        pred_dir = F.normalize(prediction_output.pred_direction[idx_b_m, idx_t_m, idx_k_m], dim=-1)
        pred_azi_deg, pred_ele_deg = _azi_ele_deg_from_direction_vector(pred_dir)
        azi_tgt = targets["source_azimuth_deg"][idx_b_m, idx_gt_m].to(pred_azi_deg.dtype)
        ele_tgt = targets["source_elevation_deg"][idx_b_m, idx_gt_m].to(pred_ele_deg.dtype)
        azi_mae = _circular_distance_deg(pred_azi_deg, azi_tgt).mean()
        ele_mae = torch.abs(pred_ele_deg - ele_tgt).mean()
```
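The two helpers used in the excerpt are not reproduced in this document. The sketch below shows what they plausibly compute, for a single vector and a single angle pair. The axis convention (`azimuth = atan2(y, x)`, elevation measured from the horizontal plane) is an assumption, and the function names are illustrative stand-ins for `_azi_ele_deg_from_direction_vector` and `_circular_distance_deg`.

```python
import math

def azi_ele_deg(direction):
    """Azimuth and elevation in degrees from an (x, y, z) direction vector."""
    x, y, z = direction
    azi = math.degrees(math.atan2(y, x))
    ele = math.degrees(math.atan2(z, math.hypot(x, y)))
    return azi, ele

def circular_distance_deg(a, b):
    """Smallest absolute difference between two angles, in [0, 180]."""
    diff = abs(a - b) % 360.0
    return min(diff, 360.0 - diff)

# Azimuth error must wrap around: 350° vs 10° is a 20° error, not 340°.
wrap_error = circular_distance_deg(350.0, 10.0)
```

Elevation is compared with a plain absolute difference in the excerpt because it lives in [-90°, 90°] and does not wrap.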
**Critical asymmetry**:
- **Training matching** uses `detach()` on predictions (line 1465-1468):
```python
pred_activity = prediction_output.pred_activity.detach()
pred_class = prediction_output.pred_class_logits.detach()
pred_direction = prediction_output.pred_direction.detach()
pred_distance = prediction_output.pred_distance.detach()
```
This means gradients **don't flow through the matching decision**
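- **Validation matching** runs the same code on eval-mode outputs, where no gradients exist at all
- Note that `.detach()` leaves tensor values unchanged, so the assignment rule itself is identical in both modes. The train/val difference comes from the inputs to the matching (for example SpecAugment on vs. off) and from the fact that training optimizes the losses conditioned on a fixed assignment rather than optimizing the assignment itself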
---
## Part 5: Mismatch Between Training Loss & Matching
### 5.1 The Core Problem
| Aspect | Training | Validation |
|--------|----------|------------|
| **SpecAugment** | Applied (only W channel) | Not applied |
| **DOA Loss** | λ_dir ∈ {0.0 (v10 p1), 1.0 (v9)} | Not used (only for matching) |
| **Matching Cost Dir Weight** | `frame_match_dir_cost_weight` ∈ {0.0, 1.0} | Same weight, applied to eval-mode outputs |
| **Direction Predictions** | Receive gradients (v9) or none (v10 phase-1) | Static outputs |
| **Spatial Augmentation** | None | None |
| **Matching Logic** | Uses detached predictions | Same code path; no gradients involved |
### 5.2 v10 Phase-1 Specific Issue
During v10 phase-1:
1. Direction head is **frozen** (`requires_grad=False`)
2. DOA loss is **zero** (`lambda_frame_direction = 0.0`)
3. Matching cost weight is **zero** (`frame_match_dir_cost_weight = 0.0`)
4. Therefore: direction head gets **no signal whatsoever**
At v10 phase-1 validation:
1. Direction head outputs are **completely stale** (from v9 epoch 3)
2. Matching gives them weight 0.0, so the class term decides the assignment
3. Direction metrics are computed on **frozen, outdated predictions**
4. Result: **Validation DOA metrics tank** even though direction head wasn't supposed to improve
---
## Part 6: Why Train/Valid Gap is Large
### Root Causes (Ranked by Severity)
#### 1. **No Spatial Data Augmentation** ⚠️⚠️⚠️
- **Impact**: ~40-60% of DOA variance unexplained
- **Mechanism**: Training sees limited DOA combinations; validation has unseen angles
- **Evidence**: All presets default to `spec_augment_*` only, zero rotation support
- **Fix needed**: Add random rotations that:
- Rotate 4-channel FOA waveform in 3D space
- Update azimuth/elevation targets accordingly
- Apply during both train (always) and val (for consistency)
#### 2. **SpecAugment Train-Only** ⚠️⚠️
- **Impact**: ~10-20% of acoustic feature variance
- **Mechanism**: Training acoustic features differ from validation
- **Evidence**: `if not self.training: return w_logmel` in `_apply_spec_augment_w`
- **Fix needed**: Either:
- Disable SpecAugment entirely (simpler, possibly worse)
- Apply same SpecAugment seed at validation (breaks Bayesian interpretation)
- Reduce SpecAugment strength to shrink the gap
#### 3. **v10 Phase-1 Freezes Direction Head** ⚠️⚠️
- **Impact**: ~30-40% on ov2/ov3 DOA metrics
- **Mechanism**: Direction head learns only from v9 epoch 3 (before DOA ramp), frozen for 10 epochs, then thawed with wrong initialization
- **Evidence**: `freeze_frame_track_spatial_heads = True`, `lambda_frame_direction = 0.0`
- **Fix needed**:
- Extend DOA ramp to v10 phase-1 instead of full freeze
- Or initialize direction head better when unfrozen
#### 4. **Continuous DOA Regression Sensitivity** ⚠️
- **Impact**: ~5-15% (only when matching is poor)
- **Mechanism**: The cosine-distance loss has no bins to absorb errors, and the matching cost ignores spatial terms during class warmup, so wrong assignments translate directly into DOA error
- **Evidence**: `loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()` with no binning
- **Fix needed**: None (by design); problem is upstream matching issues
#### 5. **Matching Logic Uses Detached Predictions** ⚠️ (Minor)
- **Impact**: ~2-5%
- **Mechanism**: Gradients don't flow through matching decision; optimization is indirect
- **Evidence**: Lines 1465-1468 use `.detach()`
- **Note**: This is intentional to avoid combinatorial explosion in loss surface; acceptable
---
## Part 7: Actionable Diagnostics
### What to Check in Your Logs
1. **DOA Metrics by Epoch**:
- v9: DOA should improve for epochs 3-7 (ramp phase)
- v10 p1: DOA should **stay flat** (frozen)
- v10 p2: DOA should improve again (unfrozen)
- **Red flag**: DOA regressing at val while training loss decreases
2. **Per-Source Breakdown**:
- ov1 (single-source): should have small train/val gap (~5°)
- ov2 (2-source): should have medium gap (~10°)
- ov3 (3-source): should have large gap (~15-20°)
- **Red flag**: an ov1 gap > 10° suggests an augmentation issue
3. **Activity vs DOA Alignment**:
- If activity_recall is high but azi_mae is large, the **matching is finding sources but the direction head is estimating wrong angles**
- Check: is `frame_match_dir_cost_weight` being applied correctly?
4. **Class Accuracy vs DOA**:
- Plot class_acc vs azi_mae per epoch
- If class plateaus at epoch 3 but azi_mae continues dropping, DOA head is separable and improvable
- If both plateau together, suggests **shared representation bottleneck**
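The per-overlap check in item 2 is easy to automate once train and validation azimuth MAE are logged per overlap count. The helper below is a sketch under assumed inputs: the dict layout, the name `doa_gap_report`, and the default thresholds are all hypothetical, and the numbers in the usage lines are placeholders rather than measured results.

```python
def doa_gap_report(train_mae, val_mae, max_gap=None):
    """Compare per-overlap azimuth MAE (degrees) between train and val.

    train_mae / val_mae: dicts such as {"ov1": 6.0, ...} read from your logs.
    max_gap: per-overlap gap budget in degrees. The defaults are rules of
    thumb loosely based on item 2 above, not measured baselines.
    Returns {overlap: (gap_deg, is_red_flag)}.
    """
    if max_gap is None:
        max_gap = {"ov1": 10.0, "ov2": 10.0, "ov3": 20.0}
    report = {}
    for key in sorted(train_mae):
        gap = val_mae[key] - train_mae[key]
        report[key] = (gap, gap > max_gap.get(key, float("inf")))
    return report

# Placeholder values for illustration only.
report = doa_gap_report({"ov1": 6.0, "ov3": 12.0},
                        {"ov1": 18.5, "ov3": 25.0})
```

Running this after every validation pass makes it obvious whether the gap is concentrated in single-source clips (pointing at augmentation) or only in overlapping ones (pointing at matching).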
---
## Part 8: Recommended Fixes (Priority Order)
### Immediate (High Impact, Low Risk)
1. **Add random rotation augmentation**
- Rotate FOA waveform + update (azimuth, elevation) targets
- Apply consistently to both train and val (same seed for val)
- Expected gain: 10-20° DOA@20 improvement on ov3
2. **Disable SpecAugment or make it train/val consistent**
- Option A: Remove SpecAugment (simplest)
- Option B: Apply weak SpecAugment to both train and val
- Expected gain: 2-5° improvement
### Medium (Moderate Impact, Medium Risk)
3. **Extend v10 phase-1 DOA ramp instead of full freeze**
- Don't freeze direction head; instead use lower `lambda_frame_direction` (e.g., 0.1)
- Keep matching cost weight at 0.0 (class-focused matching is intentional)
- Expected gain: 8-15° on ov2/ov3
4. **Better initialization of direction head after freeze**
- When unfreezing in v10 phase-2, use momentum averaging of v9 predictions
- Or apply small warmup on direction loss before full scale
### Advanced (High Impact, High Risk)
5. **Switch DOA representation from regression to soft-label binning**
- Would match azimuth classification approach already in codebase
- Requires architectural changes to prediction heads
- Expected gain: 5-10° (less sensitive to outliers)
6. **Curriculum learning for matching costs**
- Start with class-only matching (v10 p1)
- Gradually increase `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight`
- The loss weights are already ramped this way; extend the same schedule to the matching cost weights
---
## Appendix: Code References
| Finding | File:Line | Code |
|---------|-----------|------|
| FOA loading | `spatial_dataset.py:310-366` | `_load_audio_file` |
| SpecAugment train-only | `spatial_atst.py:336-347` | `if not self.training: return w_logmel` |
| No rotations | `spatial_dataset.py:*` | grep "rotation" → no results |
| Hungarian matching | `spatial_loss.py:1452-1509` | `_match_frame_slots_per_step` |
| Matching cost | `spatial_loss.py:1491-1504` | `cost[row] = act_cost + cls_nll + dir_cost + dist_cost` |
| Direction loss | `spatial_loss.py:1562-1565` | `loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()` |
| v9 config | `train_spatial_beats.py:2228-2278` | `make_ov1_local_spatial_v9_ov123_top4_config` |
| v10 freeze | `train_spatial_beats.py:2394-2478` | `make_ov1_local_spatial_v10_phase1_cls_config` |
| Spatial freeze | `train_spatial_beats.py:3442-3454` | `freeze_frame_track_spatial_heads` |
| Validation metrics | `spatial_loss.py:1596-1661` | `compute_frame_slot_validation_metrics` |