# Spatial-BEATs DOA Train/Valid Gap Analysis
## Executive Summary
After analyzing the Spatial-BEATs codebase, I've identified the following mechanisms behind the large train/validation gap in DOA (Direction of Arrival) prediction (ranked by estimated severity in Part 6):
1. **Hungarian Matching on Detached Predictions**: Training matches slots to ground truth using **detached predictions**, so no gradients flow through the matching cost. Validation runs the **same matching code** on eval-mode outputs. The matching logic itself is therefore symmetric. It differs between phases only through train-mode effects (e.g., SpecAugment) on the predictions it receives.
2. **No Spatial Data Augmentation (Rotations)**: The only augmentation FOA audio receives is **SpecAugment** (spectral masking), and it is applied only during training. There is **zero spatial augmentation** (e.g., rotations), even though spatial direction is the supervision target.
3. **Train-Only SpecAugment**: Spectral masking runs only when `model.training=True`, so the acoustic front-end sees differently distributed inputs in training and validation.
4. **Direction Loss Is Pure Regression**: DOA is supervised via **cosine distance** (1 - cos_sim) on continuous 3D unit vectors, NOT binned classification. The term is small and nearly flat for small angular errors, so in the Hungarian matching cost it is easily outweighed by the class NLL.
5. **Matching Cost Weights Are Decoupled from Loss Weights**: `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight` control Hungarian matching but **do NOT scale the loss**. Class-focused v10 presets set both to zero, so matching becomes **class-dominant and ignores spatial signals** in both training and validation.
6. **No Curriculum Scheduling for Matching Costs**: v9 ramps the DOA loss weights from 0 → full over epochs (v10 phase 1 disables the ramp), but the **matching cost weights stay constant**. This creates a **mismatch between what is optimized (the loss) and how slot assignments are made (the matching cost)**.
---
## Part 1: Data Pipeline & Augmentation
### 1.1 FOA Loading (Clean, No Issues)
**Location**: `spatial_dataset.py:310-366`
FOA waveforms are loaded correctly:
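- 4-channel ordering: `[W, X, Y, Z]`. This is worth verifying against the files: DCASE SELD FOA releases (e.g., STARSS22/23) are documented as ACN order `[W, Y, Z, X]` with SN3D normalization, and channel order matters for any rotation augmentation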
- Handles multiple audio libraries (soundfile, scipy, wave)
- Returns `[4, T]` tensors normalized to float32
- **No implicit spatial transformation during loading** ✓
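A quick sanity check on channel order is to encode a synthetic plane wave from a known direction and recover that direction with a pseudo-intensity estimate. The sketch below assumes SN3D-normalized first-order ambisonics and azimuth measured counterclockwise from the front (+x) toward +y. `encode_plane_wave` and `intensity_doa_deg` are hypothetical helpers, not functions from the codebase. You can also run an ov1 clip from the loader through `intensity_doa_deg`, using the loader's channel labels. The estimate should land near the metadata DOA if the labels are right.

```python
import numpy as np

def encode_plane_wave(signal, azi_deg, ele_deg, order="WYZX"):
    """Encode a mono signal as SN3D first-order ambisonics from one direction.

    `order` names the channel layout: "WYZX" is ACN order, "WXYZ" is FuMa-style order.
    """
    azi, ele = np.deg2rad(azi_deg), np.deg2rad(ele_deg)
    gains = {
        "W": 1.0,
        "X": np.cos(azi) * np.cos(ele),
        "Y": np.sin(azi) * np.cos(ele),
        "Z": np.sin(ele),
    }
    return np.stack([gains[c] * signal for c in order])  # [4, T]

def intensity_doa_deg(foa, order="WYZX"):
    """Estimate one DOA (azimuth, elevation) from time-averaged pseudo-intensity W*[X, Y, Z]."""
    ch = {c: foa[i] for i, c in enumerate(order)}
    ix, iy, iz = (np.mean(ch["W"] * ch[c]) for c in "XYZ")
    azi = np.degrees(np.arctan2(iy, ix))
    ele = np.degrees(np.arctan2(iz, np.hypot(ix, iy)))
    return float(azi), float(ele)

rng = np.random.default_rng(0)
s = rng.standard_normal(16000)
foa = encode_plane_wave(s, azi_deg=45.0, ele_deg=10.0, order="WYZX")
print(intensity_doa_deg(foa, order="WYZX"))  # ≈ (45.0, 10.0)
print(intensity_doa_deg(foa, order="WXYZ"))  # mislabeled channels give a different direction
```

Reading ACN data with `[W, X, Y, Z]` labels silently permutes the axes. The labels would still be "correct", but they would no longer match the physical audio.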
### 1.2 Data Augmentation: CRITICAL FINDINGS
#### A. SpecAugment is Train-Only
**Location**: `spatial_atst.py:336-347` / `spatial_modules.py:254-265`
```python
def _apply_spec_augment_w(self, w_logmel: Tensor) -> Tensor:
    """SpecAugment on [B, 1, T_f, F] W channel (training only)."""
    if not self.training:  # ← HARD GATE ON model.training
        return w_logmel
    # Apply frequency + time masking...
```
**Impact**:
- Training sees `spec_augment_freq_masks=2, freq_width=27, time_masks=2, time_width=100`
- Validation sees **NO masking**
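- This creates an acoustic feature mismatch between train and val. Note, however, that train-only masking is standard SpecAugment practice. It makes training inputs harder, so on its own it tends to raise training error rather than validation error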
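For reference, the masking step behind the `self.training` gate can be sketched in NumPy. This is a minimal version of SpecAugment-style frequency and time masking that uses the widths listed above. It is not the repository's `_apply_spec_augment_w`, and both the fill value and the uniform width sampling are assumptions.

```python
import numpy as np

def spec_augment(logmel, training, rng, n_freq_masks=2, freq_width=27,
                 n_time_masks=2, time_width=100, fill_value=0.0):
    """Mask random frequency bands and time spans of a [T, F] log-mel array.

    Returns the input unchanged when `training` is False, mirroring the
    `if not self.training: return w_logmel` gate.
    """
    if not training:
        return logmel
    out = logmel.copy()
    n_t, n_f = out.shape
    for _ in range(n_freq_masks):
        w = int(rng.integers(0, freq_width + 1))        # band width in [0, freq_width]
        f0 = int(rng.integers(0, max(1, n_f - w + 1)))  # band start
        out[:, f0:f0 + w] = fill_value
    for _ in range(n_time_masks):
        w = int(rng.integers(0, time_width + 1))        # span length in [0, time_width]
        t0 = int(rng.integers(0, max(1, n_t - w + 1)))  # span start
        out[t0:t0 + w, :] = fill_value
    return out

rng = np.random.default_rng(0)
clean = np.random.default_rng(1).standard_normal((500, 128))
masked = spec_augment(clean, training=True, rng=rng)    # new masks on every call
same = spec_augment(clean, training=False, rng=rng)     # validation: untouched input
```

The masks are resampled on every call. Each training epoch therefore sees a different corruption of the same clip, while validation always sees the clean input.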
#### B. NO Spatial Augmentation (Rotations)
**Finding**: Searched entire codebase for `rotation`, `flip`, `augment` spatial transforms:
- `spatial_dataset.py`: Only mentions FOA loading, no rotations
- `spatial_atst.py`: Only SpecAugment (spectral), no spatial
- `spatial_beats.py`: Only SpecAugment, no rotations
**Expected**: FOA audio should support **random rotations in 3D space** to:
- Augment training DOA diversity
- Help the model generalize to unseen azimuth/elevation combinations

None of this is implemented.
**Why this matters**:
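- Class labels are invariant under rotation (a dog is a dog at any angle)
- DOA targets **change under rotation** (e.g., a 90° rotation about the vertical axis moves a source from azimuth 0° to 90°)
- Without rotation augmentation, each training clip always presents its sources at the same fixed directions, so the model only ever sees the finite set of DOAs in the training metadata
- At validation, the distribution includes unseen angle combinations → large gap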
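A rotation about the vertical axis is the simplest spatial augmentation to add. It only mixes the X and Y channels, and it shifts every source's azimuth by the same angle. The sketch below makes three assumptions:
- SN3D-normalized FOA
- Azimuth measured counterclockwise from +x toward +y
- An explicit channel-order string, so it works for either ACN `[W, Y, Z, X]` or `[W, X, Y, Z]`

`rotate_foa_azimuth` is a hypothetical helper, not part of the codebase.

```python
import numpy as np

def rotate_foa_azimuth(foa, phi_deg, azimuth_deg, order="WYZX"):
    """Rotate an FOA scene by phi_deg about the vertical (z) axis.

    foa: [4, T] array; azimuth_deg: per-source azimuth targets in degrees.
    W and Z are invariant; X and Y mix like a 2-D rotation. Elevation and
    distance targets are unchanged by a rotation about z.
    """
    phi = np.deg2rad(phi_deg)
    ix, iy = order.index("X"), order.index("Y")
    x, y = foa[ix], foa[iy]
    out = foa.copy()
    out[ix] = np.cos(phi) * x - np.sin(phi) * y
    out[iy] = np.sin(phi) * x + np.cos(phi) * y
    # Shift azimuth targets by the same angle and wrap to [-180, 180).
    new_azi = (np.asarray(azimuth_deg, dtype=float) + phi_deg + 180.0) % 360.0 - 180.0
    return out, new_azi

# Training-time usage: draw a fresh angle per clip.
rng = np.random.default_rng(0)
foa = rng.standard_normal((4, 24000))
rotated, targets = rotate_foa_azimuth(foa, rng.uniform(-180.0, 180.0), [30.0, -75.0])
```

If the targets are stored as unit vectors (`source_direction`), apply the same 2-D rotation to their x/y components. Full 3-D rotations would also mix Z and change elevations, which this sketch omits.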
---
## Part 2: Loss Computation
### 2.1 Hungarian Matching Overview
**Route**: `spatial_loss.py:1452-1509` (`compute_frame_slot_losses` uses Route A)
```python
def _match_frame_slots_per_step(
    prediction_output: FrameSlotPredictionOutput,
    batch: "SpatialBatch",
    ...
) -> Tensor:
    """Return matched slot index per (b, gt, t): [B, N_gt, T_s] with -1 when unset."""
For each frame independently:
1. Compute cost matrix for active GT sources × K slots
2. Solve the assignment by brute force (exhaustive search over slot permutations, tractable because K ≤ 4)
3. Return matched slot indices
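With K ≤ 4 slots, enumerating every injective assignment is cheap (at most 4! = 24 permutations per frame). That is presumably why a brute-force search is used instead of the Hungarian algorithm proper. Below is a minimal pure-Python sketch of that step, assuming a per-frame cost matrix of shape [N_gt, K] with N_gt ≤ K. `brute_force_assign` is a hypothetical name.

```python
from itertools import permutations

def brute_force_assign(cost):
    """Min-cost assignment of N_gt rows to K slots by enumeration (N_gt <= K).

    cost: list of N_gt rows, each a list of K costs.
    Returns a tuple of slot indices, one per ground-truth row.
    """
    if not cost:
        return ()
    n_gt, k = len(cost), len(cost[0])
    best, best_total = None, float("inf")
    for slots in permutations(range(k), n_gt):  # at most 4! = 24 candidates
        total = sum(cost[g][s] for g, s in enumerate(slots))
        if total < best_total:
            best, best_total = slots, total
    return best

print(brute_force_assign([[1.0, 2.0], [1.1, 10.0]]))  # (1, 0); greedy row-by-row would pick (0, 1)
```

For larger K, `scipy.optimize.linear_sum_assignment` solves the same problem in polynomial time.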
### 2.2 Matching Cost Formulation (THE CORE ASYMMETRY)
**Location**: `spatial_loss.py:1491-1504`
```python
# Line 1491: Activity cost
act_cost = 1.0 - torch.sigmoid(pred_activity[b, t]) # [K]
# Line 1496: Class NLL (per GT, per slot)
cls_nll = -F.log_softmax(pred_class[b, t], dim=-1)[:, gt_class] # [K]
# Line 1497-1500: Direction cosine distance
dir_cos = (pred_direction[b, t] * target_direction[b, gt_idx].unsqueeze(0)).sum(dim=-1)
dir_cost = 1.0 - dir_cos # [K]
# Line 1501-1503: Distance L1
dist_cost = torch.abs(pred_distance[b, t] - target_distance[b, gt_idx]) # [K]
# Line 1504: FINAL COST (unweighted sum)
cost[row] = act_cost + cls_nll + dir_cost + dist_cost
```
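**Critical Issue**: The cost is **not scaled by the λ loss weights**. The activity and class terms always enter with weight 1.0. The only knobs on the spatial terms are `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight` (Part 3.3), which are presumably applied outside this excerpt.
- `act_cost` term: weight 1.0, range [0, 1]
- `cls_nll` term: weight 1.0, unbounded above (already a negative log-probability, ≈ 0-10 in practice)
- `dir_cost` term: range [0, 2] for unit vectors. This excerpt does not normalize `pred_direction` before the dot product (the loss in 2.3 does), so that bound holds only if the head already outputs unit vectors
- `dist_cost` term: unbounded, in the units of the distance targets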
But the **training loss** is computed separately with configurable lambdas:
```python
# Lines 1578-1582: LOSS (has lambda weights)
loss_total = (
    config.lambda_frame_activity * loss_activity
    + config.lambda_frame_class * loss_class
    + config.lambda_frame_direction * loss_direction  # ← Can be 0 (v10)!
    + config.lambda_frame_distance * loss_distance  # ← Can be 0 (v10)!
    + config.lambda_clip_aux * loss_clip
)
```
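To see how the terms compare in magnitude, the per-frame cost can be sketched for one ground-truth source against K slots. This NumPy version follows the matching-cost excerpt above, with two assumptions flagged in the comments:
- `dir_w` and `dist_w` stand in for `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight`, since the excerpt does not show where those weights are applied.
- Predicted directions are normalized before the dot product.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def frame_match_cost(act_logit, class_logits, pred_dir, pred_dist,
                     gt_class, gt_dir, gt_dist, dir_w=1.0, dist_w=1.0):
    """Cost of assigning one GT source to each of K slots -> [K]."""
    act_cost = 1.0 - 1.0 / (1.0 + np.exp(-act_logit))          # in [0, 1]
    cls_nll = -log_softmax(class_logits)[:, gt_class]           # in [0, inf)
    # Assumption: normalize here; the excerpt dots the raw pred_direction.
    unit = pred_dir / np.linalg.norm(pred_dir, axis=-1, keepdims=True)
    dir_cost = 1.0 - unit @ gt_dir                              # in [0, 2]
    dist_cost = np.abs(pred_dist - gt_dist)                     # target units
    # Assumption: the matching weights scale only the dir/dist terms.
    return act_cost + cls_nll + dir_w * dir_cost + dist_w * dist_cost

# Slot 0: correct direction, uninformative class logits (13 classes).
# Slot 1: direction 90 degrees off, confident correct class.
cost = frame_match_cost(
    act_logit=np.array([2.0, 2.0]),
    class_logits=np.array([np.zeros(13), np.r_[4.0, np.zeros(12)]]),
    pred_dir=np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
    pred_dist=np.array([1.5, 1.5]),
    gt_class=0, gt_dir=np.array([1.0, 0.0, 0.0]), gt_dist=1.5,
)
print(int(np.argmin(cost)))  # 1
```

With these numbers, the slot whose direction is 90° off still wins. Its class NLL is about 2.4 lower, while a 90° direction error adds only 1.0 to the cost.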
### 2.3 Direction Loss Formulation (REGRESSION, NOT BINNING)
**Location**: `spatial_loss.py:1562-1565`
```python
pred_dir_sel = prediction_output.pred_direction[idx_b_m, idx_t_m, idx_k_m]
tgt_dir_sel = targets["source_direction"][idx_b_m, idx_gt_m].to(pred_dir_sel.dtype)
pred_dir_sel = F.normalize(pred_dir_sel, dim=-1)
loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()
```
**Key Properties**:
- Predicts **3D unit direction vectors** (not azimuth/elevation bins)
- Loss = `1 - cosine_similarity`. For a small angular error θ this is ≈ θ²/2, so it grows quadratically rather than in proportion to the angular distance
- **CONTINUOUS REGRESSION**, not classification
- Weakly sensitive to small errors:
  - 5° error → cos_sim ≈ 0.996 → loss ≈ 0.004
  - In the matching cost, this is dwarfed by the class NLL (≈ 2-10)
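The numbers above follow from 1 - cos θ ≈ θ²/2 for small θ. The loss grows quadratically, and its gradient with respect to the angle, sin θ, vanishes as the error shrinks. A short check in plain Python:

```python
import math

def cosine_loss(err_deg):
    """1 - cos(theta): the direction loss for a unit prediction err_deg off target."""
    return 1.0 - math.cos(math.radians(err_deg))

def cosine_loss_grad(err_deg):
    """d(1 - cos theta)/d theta = sin(theta), per radian of angular error."""
    return math.sin(math.radians(err_deg))

for deg in (5, 20, 45, 90):
    print(f"{deg:>3} deg  loss={cosine_loss(deg):.4f}  grad={cosine_loss_grad(deg):.4f}")
```

A 5° error gives a loss of about 0.0038 and a gradient of about 0.087 per radian, versus 1.0 for both at 90°. The loss therefore provides only a weak signal for refining already-close predictions.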
---
## Part 3: Training Configuration (v9/v10)
### 3.1 v9 Configuration (Baseline)
**Location**: `train_spatial_beats.py:2228-2278`
```python
def make_ov1_local_spatial_v9_ov123_top4_config(...):
    cfg = make_ov1_local_spatial_v8a_ov123_top4_config(...)
    # v8a already has:
    cfg.loss.use_segment_matching = True
    cfg.frame_spatial_loss_warmup_epochs = 3  # epochs 0-2: lambda_*=0
    cfg.frame_spatial_loss_ramp_epochs = 4  # epochs 3-6: ramp 0 → full
    # v9 adds:
    cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS)
    cfg.loss.frame_class_ontology_smoothing = 0.1
    cfg.model.use_class_head_mlp_residual = True
    cfg.model.use_class_head_demixer = True
    cfg.class_head_lr_scale = 0.3  # Class head LR scaled to 0.3x (reduced, not frozen)
    cfg.class_head_freeze_during_ramp_epochs = 4  # Presumably freezes the class head during the 4 ramp epochs
    cfg.num_epochs = 12
```
**Loss weights** (inherited from v8a; not set explicitly, so the defaults apply):
- `lambda_frame_direction = 1.0` (from `SpatialLossConfig` line 67)
- `lambda_frame_distance = 1.0` (default)
- `lambda_frame_activity = 1.0` (default)
- `lambda_frame_class = 1.0` (default)
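The warmup/ramp comments in the v8a excerpt imply a per-epoch multiplier on the spatial λ weights. The sketch below assumes a linear ramp that reaches full weight on the last ramp epoch. The exact interpolation in `train_spatial_beats.py` is not shown here, so `spatial_loss_scale` is hypothetical.

```python
def spatial_loss_scale(epoch, warmup_epochs, ramp_epochs):
    """Multiplier applied to lambda_frame_direction / lambda_frame_distance.

    Epochs [0, warmup) -> 0; then a linear ramp over `ramp_epochs` epochs that
    reaches 1.0 on the last ramp epoch; 1.0 afterwards.
    """
    if epoch < warmup_epochs:
        return 0.0
    if ramp_epochs <= 0:  # v10 phase-1 sets both to 0: no warmup, no ramp
        return 1.0
    return min(1.0, (epoch - warmup_epochs + 1) / ramp_epochs)

# v9 (inherited from v8a): warmup=3, ramp=4
print([spatial_loss_scale(e, 3, 4) for e in range(9)])
# [0.0, 0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0]
```

In v10 phase 1 this multiplier is 1.0 from epoch 0. It multiplies λ values that are already 0.0, however, so the spatial loss stays off.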
### 3.2 v10 Phase-1 Configuration (SPATIAL FREEZE)
**Location**: `train_spatial_beats.py:2394-2478`
```python
def make_ov1_local_spatial_v10_phase1_cls_config(...):
    cfg = make_ov1_local_spatial_v9_ov123_top4_config(...)
    # --- CRITICAL: Spatial loss ZEROED ---
    cfg.loss.lambda_frame_direction = 0.0  # ← DOA LOSS DISABLED
    cfg.loss.lambda_frame_distance = 0.0  # ← DISTANCE LOSS DISABLED
    cfg.loss.lambda_frame_activity = 0.5
    # --- Matching cost also drops the spatial terms ---
    cfg.loss.frame_match_dir_cost_weight = 0.0  # ← Matching ALSO ignores DOA
    cfg.loss.frame_match_dist_cost_weight = 0.0  # ← Matching ALSO ignores distance
    # --- Spatial heads are frozen at parameter level ---
    cfg.freeze_frame_track_spatial_heads = True  # ← PARAMETER FREEZE
    # --- Disable DOA warmup/ramp entirely ---
    cfg.frame_spatial_loss_warmup_epochs = 0  # No warmup
    cfg.frame_spatial_loss_ramp_epochs = 0  # No ramp
```
**Impact**:
- Spatial prediction heads receive **no gradients** (frozen + zero loss weight)
- Matching uses **activity + class cost only** (the NLL term dominates)
- At v10 ep3, these heads are **completely untrained for multi-source ov2/ov3**
- When unfrozen in phase-2, they have to learn DOA from scratch with already-converged class
### 3.3 Matching Cost Weights Are DECOUPLED from Loss Weights
**Location**: `spatial_loss.py:95-100`
```python
# Hungarian-cost dir/dist weights — decoupled from lambda_frame_direction/
# lambda_frame_distance so that cost and loss can be controlled independently.
# Set to 0.0 during class-warmup stage so DOA noise does not pollute matching.
# Default 1.0 preserves existing behavior.
frame_match_dir_cost_weight: float = 1.0
frame_match_dist_cost_weight: float = 1.0
```
**Problem**: In v10 phase 1, both are set to 0.0:
- Matching becomes purely activity/class-based
- The direction loss weight is ALSO 0.0
- For training this is redundant (the frozen head gets no gradient either way), but it is **misleading for validation**

At validation time:
- The direction head's weights are NOT updated (frozen in v10 phase 1)
- Validation matching ignores the direction outputs (`frame_match_dir_cost_weight=0.0`), so slots are assigned by activity and class alone
- DOA metrics are then computed on the direction outputs of those class-matched slots, produced by a head last trained in v9
---
## Part 4: Validation Metrics
### 4.1 How Validation DOA Metrics Are Computed
**Location**: `spatial_loss.py:1596-1661`
```python
def compute_frame_slot_validation_metrics(
    prediction_output: FrameSlotPredictionOutput,
    batch: "SpatialBatch",
    temporal_padding_mask: Optional[Tensor],
    config: SpatialLossConfig,
) -> FrameMetricOutput:
    # ... compute matched_slot using same _match_frame_slots_per_step ...
    matched_slot = _match_frame_slots_per_step(...)  # ← SAME MATCHING AS TRAINING
    # Line 1642-1660: If matched, compute angle metrics
    if valid_assign.any():
        idx_b_m, idx_gt_m, idx_t_m = torch.nonzero(valid_assign, as_tuple=True)
        idx_k_m = matched_slot[idx_b_m, idx_gt_m, idx_t_m]
        pred_dir = F.normalize(prediction_output.pred_direction[idx_b_m, idx_t_m, idx_k_m], dim=-1)
        pred_azi_deg, pred_ele_deg = _azi_ele_deg_from_direction_vector(pred_dir)
        azi_tgt = targets["source_azimuth_deg"][idx_b_m, idx_gt_m].to(pred_azi_deg.dtype)
        ele_tgt = targets["source_elevation_deg"][idx_b_m, idx_gt_m].to(pred_ele_deg.dtype)
        azi_mae = _circular_distance_deg(pred_azi_deg, azi_tgt).mean()
        ele_mae = torch.abs(pred_ele_deg - ele_tgt).mean()
```
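The two angle helpers called in the validation excerpt are not shown in this report. The sketch below is a plausible NumPy equivalent, assuming azimuth measured from +x toward +y and elevation from the horizontal plane. The names mirror the excerpt, but the bodies are assumptions.

It also adds the great-circle angle between predicted and target directions, the angular error used by the DCASE SELD metrics. Unlike azimuth MAE, that error does not penalize large azimuth differences that correspond to tiny spatial offsets near the poles.

```python
import numpy as np

def azi_ele_deg_from_direction_vector(v):
    """[..., 3] direction vectors -> (azimuth, elevation) in degrees."""
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    azi = np.degrees(np.arctan2(y, x))                 # (-180, 180]
    ele = np.degrees(np.arctan2(z, np.hypot(x, y)))    # [-90, 90]
    return azi, ele

def circular_distance_deg(a, b):
    """Smallest absolute difference between two azimuths, in [0, 180]."""
    d = np.abs(np.asarray(a) - np.asarray(b)) % 360.0
    return np.minimum(d, 360.0 - d)

def great_circle_deg(u, v):
    """Angle between direction vectors u and v ([..., 3]), in degrees."""
    u = u / np.linalg.norm(u, axis=-1, keepdims=True)
    v = v / np.linalg.norm(v, axis=-1, keepdims=True)
    cos = np.clip(np.sum(u * v, axis=-1), -1.0, 1.0)
    return np.degrees(np.arccos(cos))

# Two sources at 89° elevation on opposite azimuths are only 2° apart.
e = np.radians(89.0)
u = np.array([np.cos(e), 0.0, np.sin(e)])
v = np.array([-np.cos(e), 0.0, np.sin(e)])
azi_err = circular_distance_deg(azi_ele_deg_from_direction_vector(u)[0],
                                azi_ele_deg_from_direction_vector(v)[0])  # 180.0
angle = great_circle_deg(u, v)                                            # ≈ 2.0
```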
**Matching in training vs. validation**:
- **Training matching** uses `detach()` on predictions (lines 1465-1468):
  ```python
  pred_activity = prediction_output.pred_activity.detach()
  pred_class = prediction_output.pred_class_logits.detach()
  pred_direction = prediction_output.pred_direction.detach()
  pred_distance = prediction_output.pred_distance.detach()
  ```
  This means gradients **don't flow through the matching decision**.
- **Validation matching** runs the same code in eval mode, where there are no gradients at all.
- The matching *logic* is therefore identical, and gradients never influence which assignment is chosen in either phase. Any difference comes from the predictions fed into it: during training they are produced in train mode (e.g., with SpecAugment on the W channel) by weights that change every step.
---
## Part 5: Mismatch Between Training Loss & Matching
### 5.1 The Core Problem
| Aspect | Training | Validation |
|--------|----------|------------|
| **SpecAugment** | Applied (only W channel) | Not applied |
| **DOA Loss** | λ_dir ∈ {0.0 (v10 p1), 1.0 (v9)} | Not computed; direction outputs feed matching (if weight > 0) and the angle metrics |
| **Matching Cost Dir Weight** | `frame_match_dir_cost_weight` ∈ {0.0, 1.0} | Same value; model in eval mode |
| **Direction Predictions** | Receive gradients (v9) or none (v10 p1, frozen) | No gradients (eval mode) |
| **Spatial Augmentation** | NONE | NONE |
| **Matching Logic** | Uses detached predictions | Uses same detached logic |
### 5.2 v10 Phase-1 Specific Issue
During v10 phase-1:
1. Direction head is **frozen** (`requires_grad=False`)
2. DOA loss is **zero** (`lambda_frame_direction = 0.0`)
3. Matching cost weight is **zero** (`frame_match_dir_cost_weight = 0.0`)
4. Therefore: direction head gets **no signal whatsoever**
At v10 phase-1 validation:
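1. The direction head's weights are **unchanged** from v9 (epoch 3). If the shared backbone keeps training on the class objective, the frozen head's inputs drift, so its outputs can change (and degrade) even though its weights do not
2. Matching ignores direction (weight 0.0), so slot assignment is decided by activity and class alone
3. Direction metrics are computed on these frozen-head predictions, for whichever slots the class-driven matching picks
4. Result: **validation DOA metrics can degrade** even though the direction head was never meant to change in this phase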
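A direct way to test whether a frozen head's outputs really are static is to run one fixed validation batch through two checkpoints. Then measure the angle between corresponding direction predictions. If the shared backbone keeps training while the spatial heads are frozen, this drift can be non-zero even though the head's own weights never change. The helper below is a hypothetical sketch that only needs the two `[N, 3]` arrays of predicted directions, gathered for the same inputs and slots.

```python
import numpy as np

def direction_drift_deg(dirs_a, dirs_b):
    """Angle (degrees) between two checkpoints' direction outputs.

    dirs_a, dirs_b: [N, 3] predictions for the same inputs and slots.
    Returns (mean, max) drift in degrees.
    """
    a = dirs_a / np.linalg.norm(dirs_a, axis=-1, keepdims=True)
    b = dirs_b / np.linalg.norm(dirs_b, axis=-1, keepdims=True)
    ang = np.degrees(np.arccos(np.clip(np.sum(a * b, axis=-1), -1.0, 1.0)))
    return float(ang.mean()), float(ang.max())
```

Track the matched slot indices alongside the drift. If the drift is near zero but validation DOA still worsens, the change comes from which slots the class-driven matching selects, not from the direction outputs themselves.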
---
## Part 6: Why Train/Valid Gap is Large
### Root Causes (Ranked by Severity)
#### 1. **No Spatial Data Augmentation** ⚠️⚠️⚠️
- **Impact**: ~40-60% of DOA variance unexplained
- **Mechanism**: Training sees limited DOA combinations; validation has unseen angles
- **Evidence**: All presets default to `spec_augment_*` only, zero rotation support
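- **Fix needed**: Add random rotations that:
  - Rotate the 4-channel FOA waveform in 3D space
  - Update azimuth/elevation targets accordingly
  - Apply during training; keep reported validation metrics on unrotated data (a fixed-seed rotated copy of the validation set can serve as a separate diagnostic)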
#### 2. **SpecAugment Train-Only** ⚠️⚠️
- **Impact**: ~10-20% of acoustic feature variance
- **Mechanism**: Training acoustic features differ from validation
- **Evidence**: `if not self.training: return w_logmel` in `_apply_spec_augment_w`
- **Fix needed**: Either:
  - Disable SpecAugment entirely (simpler, possibly worse generalization)
  - Apply SpecAugment at validation with a fixed seed (not recommended: validation metrics would then measure performance on masked rather than real inputs)
  - Reduce SpecAugment strength to narrow the gap
#### 3. **v10 Phase-1 Freezes Direction Head** ⚠️⚠️
- **Impact**: ~30-40% on ov2/ov3 DOA metrics
- **Mechanism**: Direction head learns only from v9 epoch 3 (before DOA ramp), frozen for 10 epochs, then thawed with wrong initialization
- **Evidence**: `freeze_frame_track_spatial_heads = True`, `lambda_frame_direction = 0.0`
- **Fix needed**:
  - Extend DOA ramp to v10 phase-1 instead of full freeze
  - Or initialize direction head better when unfrozen
#### 4. **Continuous DOA Regression Sensitivity** ⚠️
- **Impact**: ~5-15% (only when matching is poor)
- **Mechanism**: Cosine distance loss is sensitive to small errors; matching cost ignores spatial during class warmup
- **Evidence**: `loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()` with no binning
- **Fix needed**: None (by design); problem is upstream matching issues
#### 5. **Matching Logic Uses Detached Predictions** ⚠️ (Minor)
- **Impact**: ~2-5%
- **Mechanism**: Gradients don't flow through matching decision; optimization is indirect
- **Evidence**: Lines 1465-1468 use `.detach()`
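- **Note**: This is intentional and standard for set-prediction matching (as in DETR-style training). The assignment is a discrete argmin with no useful gradient, so detaching only avoids building an unused autograd graph; acceptable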
---
## Part 7: Actionable Diagnostics
### What to Check in Your Logs
1. **DOA Metrics by Epoch**:
   - v9: DOA should improve over epochs 3-7 (ramp phase)
   - v10 p1: DOA should **stay roughly flat** (spatial heads frozen); drift here points to backbone or matching changes rather than to the direction head itself
   - v10 p2: DOA should improve again (unfrozen)
   - **Red flag**: validation DOA regressing while training loss decreases
2. **Per-Source Breakdown**:
   - ov1 (single-source): should have a small train/val gap (~5°)
   - ov2 (2-source): should have a medium gap (~10°)
   - ov3 (3-source): should have a large gap (~15-20°)
   - **Red flag**: an ov1 gap > 10° suggests an augmentation issue
3. **Activity vs DOA Alignment**:
   - If activity_recall is high but azi_mae is large, **matching is finding the sources but estimating the wrong angles**
   - Check: is `frame_match_dir_cost_weight` being applied correctly?
4. **Class Accuracy vs DOA**:
   - Plot class_acc vs azi_mae per epoch
   - If class accuracy plateaus at epoch 3 but azi_mae keeps dropping, the DOA head is separable and improvable
   - If both plateau together, that suggests a **shared representation bottleneck**
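The per-source breakdown in item 2 amounts to grouping frame-level errors by overlap count. The plain-Python sketch below assumes you can export per-frame angular errors and their ov labels from both a training pass and a validation pass. Ideally, compute the training-set errors in eval mode, so that SpecAugment does not inflate them. The function name and input format are hypothetical.

```python
from collections import defaultdict

def gap_by_overlap(train_rows, val_rows):
    """train_rows / val_rows: iterables of (ov, error_deg) pairs.

    Returns {ov: (train_mean, val_mean, val_mean - train_mean)} for each ov in both.
    """
    def means(rows):
        sums, counts = defaultdict(float), defaultdict(int)
        for ov, err in rows:
            sums[ov] += err
            counts[ov] += 1
        return {ov: sums[ov] / counts[ov] for ov in sums}

    tr, va = means(train_rows), means(val_rows)
    return {ov: (tr[ov], va[ov], va[ov] - tr[ov]) for ov in sorted(tr.keys() & va.keys())}

print(gap_by_overlap([(1, 4.0), (1, 6.0), (2, 10.0)], [(1, 16.0), (2, 20.0), (3, 30.0)]))
# {1: (5.0, 16.0, 11.0), 2: (10.0, 20.0, 10.0)}
```

In this toy input, the ov1 gap of 11° would cross the red-flag threshold above.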
---
## Part 8: Recommended Fixes (Priority Order)
### Immediate (High Impact, Low Risk)
1. **Add random rotation augmentation**
   - Rotate the FOA waveform and update the (azimuth, elevation) targets
   - Apply during training only; keep reported validation metrics on unrotated data (a fixed-seed rotated copy of the validation set can serve as a separate diagnostic)
   - Expected gain: 10-20° DOA@20 improvement on ov3
2. **Disable SpecAugment or reduce its train/val effect**
   - Option A: Remove SpecAugment (simplest)
   - Option B: Keep it train-only but weaker (fewer or narrower masks)
   - Expected gain: 2-5° improvement
### Medium (Moderate Impact, Medium Risk)
3. **Extend v10 phase-1 DOA ramp instead of full freeze**
   - Don't freeze the direction head; instead use a lower `lambda_frame_direction` (e.g., 0.1)
   - Keep the matching cost weight at 0.0 (class-focused matching is intentional)
   - Expected gain: 8-15° on ov2/ov3
4. **Better initialization of direction head after freeze**
   - When unfreezing in v10 phase-2, use momentum averaging of v9 predictions
   - Or apply a short warmup on the direction loss before full scale
### Advanced (High Impact, High Risk)
5. **Switch DOA representation from regression to soft-label binning**
   - Would match the azimuth classification approach already in the codebase
   - Requires architectural changes to the prediction heads
   - Expected gain: 5-10° (less sensitive to outliers)
6. **Curriculum learning for matching costs**
   - Start with class-only matching (v10 p1)
   - Gradually increase `frame_match_dir_cost_weight` and `frame_match_dist_cost_weight`
   - This is already partially done for the loss weights; extend it to matching
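Item 6 could use a linear schedule for the matching weight that lags the loss ramp. Matching would then start trusting direction only after the head has had some training. The sketch below uses hypothetical start epochs and ramp lengths, none of which come from an existing preset. The final values of 1.0 match the `SpatialLossConfig` defaults quoted earlier.

```python
def ramp(epoch, start, length, final):
    """Linear 0 -> final over `length` epochs starting at `start`."""
    if epoch < start:
        return 0.0
    if length <= 0:
        return final
    return final * min(1.0, (epoch - start + 1) / length)

def spatial_schedule(epoch, loss_start=0, match_start=2, length=4):
    """Direction loss weight and Hungarian-cost weight for a given epoch."""
    return {
        "lambda_frame_direction": ramp(epoch, loss_start, length, final=1.0),
        "frame_match_dir_cost_weight": ramp(epoch, match_start, length, final=1.0),
    }

for e in range(7):
    print(e, spatial_schedule(e))
```

Distance (`lambda_frame_distance`, `frame_match_dist_cost_weight`) would follow the same pattern.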
---
## Appendix: Code References
| Finding | File:Line | Code |
|---------|-----------|------|
| FOA loading | `spatial_dataset.py:310-366` | `_load_audio_file` |
| SpecAugment train-only | `spatial_atst.py:336-347` | `if not self.training: return w_logmel` |
| No rotations | `spatial_dataset.py:*` | grep "rotation" → no results |
| Hungarian matching | `spatial_loss.py:1452-1509` | `_match_frame_slots_per_step` |
| Matching cost | `spatial_loss.py:1491-1504` | `cost[row] = act_cost + cls_nll + dir_cost + dist_cost` |
| Direction loss | `spatial_loss.py:1562-1565` | `loss_direction = (1.0 - (pred_dir_sel * tgt_dir_sel).sum(dim=-1)).mean()` |
| v9 config | `train_spatial_beats.py:2228-2278` | `make_ov1_local_spatial_v9_ov123_top4_config` |
| v10 freeze | `train_spatial_beats.py:2394-2478` | `make_ov1_local_spatial_v10_phase1_cls_config` |
| Spatial freeze | `train_spatial_beats.py:3442-3454` | `freeze_frame_track_spatial_heads` |
| Validation metrics | `spatial_loss.py:1596-1661` | `compute_frame_slot_validation_metrics` |