Add files using upload-large-folder tool
Browse files- .gitattributes +3 -0
- Evaluation_Results/Comparing_Different_Pre-Training_Targets.png +3 -0
- Evaluation_Results/Comparing_with_the_SOTA_Ensemble_Models.png +3 -0
- Evaluation_Results/Comparing_with_the_SOTA_Single_Models.png +3 -0
- __pycache__/BEATs.cpython-310.pyc +0 -0
- __pycache__/BEATs.cpython-312.pyc +0 -0
- __pycache__/modules.cpython-311.pyc +0 -0
- __pycache__/modules.cpython-312.pyc +0 -0
- __pycache__/spatial_beats.cpython-310.pyc +0 -0
- __pycache__/spatial_dataset.cpython-311.pyc +0 -0
- __pycache__/test_vectorized_matching.cpython-311.pyc +0 -0
- docs/00_START_HERE.md +228 -0
- docs/0427_v11_series.md +184 -0
- docs/0429_v11a_with_dynamic.md +475 -0
- docs/V11_QUICK_START.md +345 -0
- docs/gemini.md +63 -0
- docs/spatial_beats_implementation_spec.md +706 -0
- docs/spatial_beats_training_overview.md +608 -0
- docs/v13_honest_postmortem.md +170 -0
- docs/v13_spatial_beats_design.md +528 -0
- docs/v13d_spatial_beats_design.md +333 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Evaluation_Results/Comparing_with_the_SOTA_Single_Models.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
Evaluation_Results/Comparing_Different_Pre-Training_Targets.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
Evaluation_Results/Comparing_with_the_SOTA_Ensemble_Models.png filter=lfs diff=lfs merge=lfs -text
|
Evaluation_Results/Comparing_Different_Pre-Training_Targets.png
ADDED
|
Git LFS Details
|
Evaluation_Results/Comparing_with_the_SOTA_Ensemble_Models.png
ADDED
|
Git LFS Details
|
Evaluation_Results/Comparing_with_the_SOTA_Single_Models.png
ADDED
|
Git LFS Details
|
__pycache__/BEATs.cpython-310.pyc
ADDED
|
Binary file (4.13 kB). View file
|
|
|
__pycache__/BEATs.cpython-312.pyc
ADDED
|
Binary file (7.34 kB). View file
|
|
|
__pycache__/modules.cpython-311.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
__pycache__/modules.cpython-312.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
__pycache__/spatial_beats.cpython-310.pyc
ADDED
|
Binary file (49.1 kB). View file
|
|
|
__pycache__/spatial_dataset.cpython-311.pyc
ADDED
|
Binary file (76 kB). View file
|
|
|
__pycache__/test_vectorized_matching.cpython-311.pyc
ADDED
|
Binary file (19 kB). View file
|
|
|
docs/00_START_HERE.md
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚀 START HERE — Spatial-BEATs Documentation Guide
|
| 2 |
+
|
| 3 |
+
**Welcome!** You have access to comprehensive analysis of the Spatial-BEATs codebase. This guide will direct you to exactly what you need.
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## ⚡ Quick Pick Your Task
|
| 8 |
+
|
| 9 |
+
### "I have 5 minutes"
|
| 10 |
+
→ Read: [`SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md`](SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md)
|
| 11 |
+
|
| 12 |
+
### "I have 15 minutes"
|
| 13 |
+
→ Read: [`README_DOCUMENTATION_INDEX.md`](README_DOCUMENTATION_INDEX.md) then [`ANALYSIS_COMPLETION_SUMMARY.md`](ANALYSIS_COMPLETION_SUMMARY.md)
|
| 14 |
+
|
| 15 |
+
### "I have 30 minutes"
|
| 16 |
+
→ Choose one:
|
| 17 |
+
- **New to codebase?** Read: Part 1 of [`SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md`](SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md)
|
| 18 |
+
- **Debugging DOA gap?** Read: [`doa_train_valid_gap_analysis.md`](doa_train_valid_gap_analysis.md) Executive Summary
|
| 19 |
+
- **Planning experiments?** Read: [`0427_v11_series.md`](0427_v11_series.md) Section 1-2
|
| 20 |
+
|
| 21 |
+
### "I have 1-2 hours"
|
| 22 |
+
→ Full reading path for your role:
|
| 23 |
+
- **Researcher**: QUICK_REF → 0427_v11_series.md → ANALYSIS Part 2-3
|
| 24 |
+
- **Contributor**: QUICK_REF → ANALYSIS Part 1-2 → Pick component → read code
|
| 25 |
+
- **Investigator**: DOA_GAP Executive → Part 6 → Part 8 → Appendix
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## 📚 The Five Documents
|
| 30 |
+
|
| 31 |
+
### 1. **README_DOCUMENTATION_INDEX.md**
|
| 32 |
+
🏠 **Navigation hub** — Where to find what
|
| 33 |
+
- Use case lookup (choose your problem)
|
| 34 |
+
- Code component quick reference
|
| 35 |
+
- Reading order for different roles
|
| 36 |
+
- Cross-reference guide
|
| 37 |
+
|
| 38 |
+
**👉 Read this first if**: You're not sure where to start
|
| 39 |
+
|
| 40 |
+
---
|
| 41 |
+
|
| 42 |
+
### 2. **SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md**
|
| 43 |
+
⚡ **Practitioner's card** — Fast lookup
|
| 44 |
+
- Framework table (4 frameworks, 1 page)
|
| 45 |
+
- Route A/B/C comparison
|
| 46 |
+
- Version series highlights
|
| 47 |
+
- Code locations by component
|
| 48 |
+
- Loss weight patterns
|
| 49 |
+
- When to use each configuration
|
| 50 |
+
|
| 51 |
+
**👉 Read this for**: Quick answers, practitioner reference
|
| 52 |
+
|
| 53 |
+
---
|
| 54 |
+
|
| 55 |
+
### 3. **SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md**
|
| 56 |
+
📖 **Deep technical reference** — Architecture bible
|
| 57 |
+
- Part 1: Four spatial frameworks (Spatial-AST, DCASE SELD, EINV2, DETR)
|
| 58 |
+
- Part 2: Routes A/B/C with full specifications
|
| 59 |
+
- Part 3: Version evolution (v7→v11)
|
| 60 |
+
- Part 4-10: Implementation, configs, metrics, future work
|
| 61 |
+
- Appendix: Code reference table
|
| 62 |
+
|
| 63 |
+
**👉 Read this for**: Deep understanding, architecture details, code paths
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
### 4. **doa_train_valid_gap_analysis.md**
|
| 68 |
+
🔍 **Diagnostic & fix guide** — Root cause analysis
|
| 69 |
+
- Executive Summary: 6 critical mechanisms
|
| 70 |
+
- Part 1: Data pipeline analysis
|
| 71 |
+
- Part 2: Loss computation asymmetry
|
| 72 |
+
- Part 3: Training configuration (v9/v10)
|
| 73 |
+
- Part 4: Validation metrics
|
| 74 |
+
- **Part 6: Root causes ranked by severity**
|
| 75 |
+
- **Part 7: Diagnostic checklist**
|
| 76 |
+
- **Part 8: Recommended fixes (prioritized)**
|
| 77 |
+
- Appendix: Code reference with exact line numbers
|
| 78 |
+
|
| 79 |
+
**👉 Read this for**: Debugging train/val gaps, understanding root causes
|
| 80 |
+
|
| 81 |
+
---
|
| 82 |
+
|
| 83 |
+
### 5. **ANALYSIS_COMPLETION_SUMMARY.md**
|
| 84 |
+
📋 **Executive overview** — What was found
|
| 85 |
+
- Deliverables summary (5 docs, 1,883 lines)
|
| 86 |
+
- Key findings (frameworks, routes, v11 series)
|
| 87 |
+
- Next steps (immediate vs experimental)
|
| 88 |
+
- How to use documents
|
| 89 |
+
- Verification checklist
|
| 90 |
+
|
| 91 |
+
**👉 Read this for**: Overview, decision-making, what comes next
|
| 92 |
+
|
| 93 |
+
---
|
| 94 |
+
|
| 95 |
+
## 🎯 Choose Your Path
|
| 96 |
+
|
| 97 |
+
### Path 1: "I want to understand the architecture (60 min)"
|
| 98 |
+
1. SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md (5 min)
|
| 99 |
+
2. SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md Part 1-2 (25 min)
|
| 100 |
+
3. Pick a component from Part 6 Appendix, find in code (20 min)
|
| 101 |
+
4. spatial_beats_ov123_frame_routes.md if curious (10 min)
|
| 102 |
+
|
| 103 |
+
**Outcome**: Can navigate codebase, understand paradigms, modify code confidently
|
| 104 |
+
|
| 105 |
+
---
|
| 106 |
+
|
| 107 |
+
### Path 2: "I need to debug a train/val gap (30 min)"
|
| 108 |
+
1. doa_train_valid_gap_analysis.md Executive Summary (2 min)
|
| 109 |
+
2. Part 6: Check which mechanisms apply to your situation (5 min)
|
| 110 |
+
3. Part 7: Diagnostics—check your logs (10 min)
|
| 111 |
+
4. Part 8: Pick a fix priority (5 min)
|
| 112 |
+
5. Appendix: Get code locations (2 min)
|
| 113 |
+
6. SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md if you need to modify (optional)
|
| 114 |
+
|
| 115 |
+
**Outcome**: Root cause identified, fix strategy chosen, code locations ready
|
| 116 |
+
|
| 117 |
+
---
|
| 118 |
+
|
| 119 |
+
### Path 3: "I want to run an experiment (v11 series) (45 min)"
|
| 120 |
+
1. SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md (5 min)
|
| 121 |
+
2. 0427_v11_series.md Section 1-2 (15 min)
|
| 122 |
+
3. SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md Part 3 (15 min)
|
| 123 |
+
4. 0427_v11_series.md Part 4 (verification method) (5 min)
|
| 124 |
+
5. Copy shell script from QUICK_REF (5 min)
|
| 125 |
+
|
| 126 |
+
**Outcome**: Experiment ready to launch, understanding of success metrics
|
| 127 |
+
|
| 128 |
+
---
|
| 129 |
+
|
| 130 |
+
### Path 4: "I'm new, I want to understand everything (2 hours)"
|
| 131 |
+
1. SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md (10 min)
|
| 132 |
+
2. SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md Part 1-3 (40 min)
|
| 133 |
+
3. spatial_beats_ov123_frame_routes.md (25 min)
|
| 134 |
+
4. spatial_beats_training_overview.md (20 min)
|
| 135 |
+
5. Pick component, trace through code with Part 6 references (20 min)
|
| 136 |
+
6. doa_train_valid_gap_analysis.md Part 6 for context (5 min)
|
| 137 |
+
|
| 138 |
+
**Outcome**: Comprehensive understanding of system, ready to contribute
|
| 139 |
+
|
| 140 |
+
---
|
| 141 |
+
|
| 142 |
+
## 🔑 Key Findings at a Glance
|
| 143 |
+
|
| 144 |
+
### 4 Spatial Frameworks in Codebase
|
| 145 |
+
- **Spatial-AST**: Task tokens (pre-trunk)
|
| 146 |
+
- **DCASE SELD**: Per-class activity+DOA
|
| 147 |
+
- **EINV2**: Learnable track queries
|
| 148 |
+
- **DETR**: Per-frame K-slot allocation
|
| 149 |
+
|
| 150 |
+
### 3 Parallel Routes (A/B/C)
|
| 151 |
+
- **Route A**: Per-frame K-slot, per-step Hungarian
|
| 152 |
+
- **Route B**: Learnable queries, clip-level Hungarian (PRODUCTION v9)
|
| 153 |
+
- **Route C**: Per-class vectors (PROTOTYPE, v11c test)
|
| 154 |
+
|
| 155 |
+
### DOA Train/Val Gap Root Causes
|
| 156 |
+
1. ⚠️⚠️⚠️ **ZERO spatial augmentation (rotations)** — 40-60% of variance
|
| 157 |
+
2. ⚠️⚠️ **SpecAugment train-only** — 10-20% variance
|
| 158 |
+
3. ⚠️⚠️ **v10 freezes direction head** — 30-40% on multi-source
|
| 159 |
+
4. ⚠️ **Regression sensitivity** — 5-15% variance
|
| 160 |
+
5. ⚠️ **Detached prediction asymmetry** — 2-5% variance
|
| 161 |
+
|
| 162 |
+
### v11 Experiments (Parallel Runs)
|
| 163 |
+
- **v11a**: DOA demixer → ov2 angles ↓ 5pp+
|
| 164 |
+
- **v11b**: LocalSpatial pre-pool → test IV necessity
|
| 165 |
+
- **v11c**: ACCDOA paradigm → ov3 binding ↓ 5pp+
|
| 166 |
+
- **v11d**: Post-hoc calibration → ov1 ranking ↑ 5pp+
|
| 167 |
+
|
| 168 |
+
---
|
| 169 |
+
|
| 170 |
+
## 📞 FAQ
|
| 171 |
+
|
| 172 |
+
**Q: Where do I find the direction head loss?**
|
| 173 |
+
A: `SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md` Appendix → search "direction loss" → `spatial_loss.py:1562-1565`
|
| 174 |
+
|
| 175 |
+
**Q: What's the difference between routes?**
|
| 176 |
+
A: Compare table in `SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md` or detailed Part 2 of `SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md`
|
| 177 |
+
|
| 178 |
+
**Q: Should I implement fix #1, #2, or #3?**
|
| 179 |
+
A: Read `doa_train_valid_gap_analysis.md` Part 6, pick based on your gap size and risk tolerance.
|
| 180 |
+
|
| 181 |
+
**Q: How do I run v11a?**
|
| 182 |
+
A: Shell script in `SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md` v11 section + spec in `0427_v11_series.md` Section 2.2
|
| 183 |
+
|
| 184 |
+
**Q: I'm stuck on a component. Where's the code?**
|
| 185 |
+
A: `SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md` Part 6 has complete reference table with file:line for every component.
|
| 186 |
+
|
| 187 |
+
---
|
| 188 |
+
|
| 189 |
+
## 🎁 You Now Have
|
| 190 |
+
|
| 191 |
+
✅ **Navigation guide** for all documents
|
| 192 |
+
✅ **Quick reference card** with all the essentials
|
| 193 |
+
✅ **Architecture bible** with code paths
|
| 194 |
+
✅ **Diagnostic guide** for train/val gaps
|
| 195 |
+
✅ **Experimental specifications** for v11 series
|
| 196 |
+
✅ **Comprehensive metadata** (1,883 lines, 77KB)
|
| 197 |
+
✅ **All findings tied to exact code locations**
|
| 198 |
+
|
| 199 |
+
---
|
| 200 |
+
|
| 201 |
+
## 🚀 Next Steps
|
| 202 |
+
|
| 203 |
+
1. **Choose your path above** based on how much time you have
|
| 204 |
+
2. **Follow the reading order** in that path
|
| 205 |
+
3. **Use cross-references** when you need more detail
|
| 206 |
+
4. **Check Appendices** for exact code locations
|
| 207 |
+
5. **Reference Part 6/Part 8** when implementing
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
## 📊 Document Overview
|
| 212 |
+
|
| 213 |
+
| Document | Size | Time | Purpose |
|
| 214 |
+
|----------|------|------|---------|
|
| 215 |
+
| README_DOCUMENTATION_INDEX | 12KB | 5-10m | Navigation hub |
|
| 216 |
+
| SPATIAL_FRAMEWORKS_QUICK_REFERENCE | 7KB | 5-10m | Quick lookup |
|
| 217 |
+
| SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS | 28KB | 30-45m | Deep reference |
|
| 218 |
+
| doa_train_valid_gap_analysis | 19KB | 20-30m | Diagnostics |
|
| 219 |
+
| ANALYSIS_COMPLETION_SUMMARY | 11KB | 10m | Executive summary |
|
| 220 |
+
| **TOTAL** | **77KB** | **2-4 hours** | **Complete set** |
|
| 221 |
+
|
| 222 |
+
---
|
| 223 |
+
|
| 224 |
+
**Status**: ✅ Complete and ready for use
|
| 225 |
+
**Created**: 2026-04-27
|
| 226 |
+
**Next update**: After v11 experiments
|
| 227 |
+
|
| 228 |
+
👉 **Pick your path above and start reading!**
|
docs/0427_v11_series.md
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 2026-04-27 — v11 系列实验:DOA demixer / ACCDOA 范式 / 校准对照
|
| 2 |
+
|
| 3 |
+
本文档对应「v9 之后做什么」的对话定型。把 docs/0424.md 的诊断结论拆成四个独立、可单独评估、可并行跑的实验:v11a / v11b / v11c / v11d。
|
| 4 |
+
|
| 5 |
+
## 1. 上一轮诊断回顾
|
| 6 |
+
|
| 7 |
+
参考 docs/0423.md(v9 design)+ docs/0424.md(v9 real dump 拆解)。三个真实 split 的失败模式互不重合:
|
| 8 |
+
|
| 9 |
+
| split | 主要症状 | 责任层 |
|
| 10 |
+
| -------- | --------------------------------------------- | ------------------------------- |
|
| 11 |
+
| real_ov1 | raw 4-track 100% 同类候选,但 act>=0.5 后 37% GT 丢同类预测 | 排序 / activity 校准(**非架构**) |
|
| 12 |
+
| real_ov2 | 73.9% 的预测「同类有,但角度 >20°」 | direction head 本体(**架构**) |
|
| 13 |
+
| real_ov3 | 24.5% raw 层无同类候选 + avg_pred 1.82 < avg_gt 2.88 | binding(query→source)+ 少亮轨(**架构**) |
|
| 14 |
+
|
| 15 |
+
v9 已经给 class head 加过 `ClassHeadSpectralDemixer`(Fix C:track latent → BEATs trunk pre-pool grid 的频率轴 cross-attn,零门控残差)。但 **direction / distance head 没有对应通路**——它们只看 `track_time_features` 这一个 D 维向量,输入再往前是 `fused_spatial_embeddings`,已经被 `FrequencyPool(mean)` 平均掉。多源情况下,频率维被平均后单 D 向量无法同时表达两个方向,这是 real_ov2 的物理根因。
|
| 16 |
+
|
| 17 |
+
## 2. 实验定义
|
| 18 |
+
|
| 19 |
+
### v11a — DOA / 距离对称 demixer(最小改动)
|
| 20 |
+
|
| 21 |
+
**目的**:直接验证「v9 Fix C 没补到 DOA」是否就是 real_ov2 angle 错的根因。
|
| 22 |
+
|
| 23 |
+
**做法**:
|
| 24 |
+
|
| 25 |
+
* 在 `FrameTrackPredictionHeads` 里再加一个 `ClassHeadSpectralDemixer` 实例(参数独立,结构完全相同),命名 `spatial_head_demixer`,作用于 `direction_head` 和 `distance_head` 的输入。
|
| 26 |
+
* KV 与 class demixer 共用:BEATs trunk 的 pre-pool tokens `[B, T_p*F_p, D]` + grid_size `(T_p, F_p)` + pre-pool 时间 mask。
|
| 27 |
+
* 零门控初始化:`out_proj.weight = out_proj.bias = 0`,`gate = 1e-2`,所以 epoch-0 forward 与 v9 bit-equivalent。
|
| 28 |
+
|
| 29 |
+
**热启动**:`RESUME_CKPT` = v9 best.pt;`strict=False`;新增 13 个参数走默认零门控初值。`--no-resume-optimizer --reset-epoch-on-resume --reset-best-on-resume`。
|
| 30 |
+
|
| 31 |
+
**预期**:
|
| 32 |
+
|
| 33 |
+
* real_ov2:`class_right_angle_wrong` 从 73.9% 显著下降;`mean_best_angle_when_same_class_exists` 缩小。
|
| 34 |
+
* sim split(`valid__hm3d__`)的 `F20 / LE_CD / ocls` 不退化(零门控保证 epoch-0 安全,训练只能修不能炸)。
|
| 35 |
+
* real_ov1 / real_ov3 不一定改善——它们卡的不是 DOA 本身。
|
| 36 |
+
|
| 37 |
+
**证明的事**:v9 在 class 上加 demixer 是有效的,但 DOA 也需要同样的通路才能拿到等价收益;后 frequency_pool 单向量是 DOA 多源 demix 的硬瓶颈。
|
| 38 |
+
|
| 39 |
+
**入口**:`run_ov1_v11a_ov123_top4.sh` → preset `ov1_local_spatial_v11a_ov123_top4`。
|
| 40 |
+
|
| 41 |
+
### v11b — DOA demixer 的 KV 换成 LocalSpatial pre-pool
|
| 42 |
+
|
| 43 |
+
**目的**:进一步追问——v11a 的 KV 是 BEATs 的 mono fbank pre-pool,本身没有方向信息(IV 是后面 `local_spatial_fuser` 才混进来的);如果让 DOA demixer 直接 attend 到 7 通道 FOA + IV 的 CNN pre-pool,会不会比 v11a 更好?
|
| 44 |
+
|
| 45 |
+
**做法**:
|
| 46 |
+
|
| 47 |
+
* 让 `LocalSpatialEncoder.forward(foa_feat, return_pre_pool=True)` 额外返回 4D CNN 特征 `[B, D_s, T_f, F_cnn]`(在 `mean(dim=-1)` 频率塌缩之前)。
|
| 48 |
+
* `build_local_spatial_fusion(..., return_local_pre_pool=True)` 透传,并 reshape 到 `[B, T_f*F_cnn, D_s]` + grid `(T_f, F_cnn)`。
|
| 49 |
+
* 新增 `local_spatial_pre_pool_proj: Linear(D_s -> D=768)`,xavier 初始化乘以 `local_spatial_proj_scale_init`,bias=0。
|
| 50 |
+
* `FrameTrackPredictionHeads.forward` 接 `spatial_pre_pool_features / spatial_pre_pool_grid_size / spatial_pre_pool_time_mask`;当传入时,DOA demixer 用这条 KV,否则回落到 v11a 的 BEATs trunk pre-pool。
|
| 51 |
+
|
| 52 |
+
**热启动**:同样 v9 best.pt + strict=False。`spatial_head_demixer` 仍然零门控;`local_spatial_pre_pool_proj` 是新参数,但 demixer gate=1e-2,`out_proj=0` 保证 epoch-0 数值与 v9 一致。
|
| 53 |
+
|
| 54 |
+
**预期**:
|
| 55 |
+
|
| 56 |
+
* 如果 v11b 比 v11a 在 real_ov2 上明显更好 → 物理 IV 信号确实是 DOA 必要输入,BEATs trunk pre-pool 信息不足。
|
| 57 |
+
* 如果 v11b ≈ v11a → BEATs trunk pre-pool 已经携带足够 spatial 上下文(`local_spatial_fuser` 把 IV 混回去了),DOA 头的瓶颈只在 demixer 本身。
|
| 58 |
+
* 如果 v11b < v11a → 新 KV 引入太多噪声 / projection 没充分训。
|
| 59 |
+
|
| 60 |
+
**证明的事**:架构里「方向先验来自哪里」的问题——是 fuser 后已经够,还是必须 pre-fuser 直读 IV。
|
| 61 |
+
|
| 62 |
+
**入口**:`run_ov1_v11b_ov123_top4.sh` → preset `ov1_local_spatial_v11b_ov123_top4`。
|
| 63 |
+
|
| 64 |
+
### v11c — ACCDOA 范式对照
|
| 65 |
+
|
| 66 |
+
**目的**:质询「K-track DETR 范式」本身。v11a/v11b 都在补 head;但 real_ov3 24.5% 的 GT 在 raw 4-track 里就找不到同类候选——这意味着问题不在 head,而在 **��哪个 query 该负责哪个 source」** 的 binding 阶段。把范式整体换掉看是否绕得过去。
|
| 67 |
+
|
| 68 |
+
**做法**:
|
| 69 |
+
|
| 70 |
+
* `readout_scheme = local_spatial_accdoa`,per-class 3D 向量场 `v_c` :`||v_c||` = activity_c,`v_c/||v_c||` = DOA_c。无 query、无 Hungarian。
|
| 71 |
+
* ov2/ov3 的 same-class-in-same-frame 几乎为零,所以 per-class 输出是无歧义的。
|
| 72 |
+
* 接现有 `accdoa_heads` + `compute_frame_accdoa_losses`(仓库已实现)。
|
| 73 |
+
|
| 74 |
+
**冷启动**:拓扑与 v9 不兼容(无 `source_query_decoder`、无 `FrameTrackPredictionHeads`)。改用 `--init-from-spatial-ckpt` 从 ov1 local_spatial warmup ckpt 初始化(继承 BEATs trunk + LocalSpatialEncoder + fuser),strict=False。
|
| 75 |
+
|
| 76 |
+
**调度**:24 epochs,`lr=3e-5`(默认 1e-4 是 ov1 单源用的,没在多源上调过;3e-5 与 v9 同档,更安全)。
|
| 77 |
+
|
| 78 |
+
**预期**:
|
| 79 |
+
|
| 80 |
+
* real_ov3 `no_same_class_pred_but_other_preds_exist` 显著下降:query binding 不存在了,每个类天然有自己的 vector slot。
|
| 81 |
+
* real_ov2 也可能改善——因为 DOA 来自 per-class 向量,已经按类分离,避免多源 head 共享单 D 向量。
|
| 82 |
+
* sim ov1 上 ocls 可能略低于 v9(ACCDOA 把 activity/DOA 耦合,class 信号被向量幅值稀释);这个代价是已知的。
|
| 83 |
+
|
| 84 |
+
**证明的事**:query-binding 阶段是不是 real_ov3 的真正瓶颈。如果 v11c 就把 ov3 拉起来了,说明 head 修补(v11a/v11b)治不好 ov3;如果 v11c 也救不了 ov3,说明问题在更前——可能是 `LocalSpatialEncoder` 时空分辨率不够。
|
| 85 |
+
|
| 86 |
+
**入口**:`run_ov1_v11c_ov123_accdoa.sh` → preset `ov1_local_spatial_v11c_ov123_accdoa`。
|
| 87 |
+
|
| 88 |
+
### v11d — activity 校准 + Top-K̂ 解码(纯后处理)
|
| 89 |
+
|
| 90 |
+
**目的**:real_ov1 的失血点在「raw 100% 有同类候选 → act>=0.5 后 37% 丢失」,是阈值 / 排序问题,**不是模型**。所以不改架构、不重训,只改 decode。同步给 v9 / v11a / v11b 出个 pareto,避免后续把 ranking 收益错算到 head 改动上。
|
| 91 |
+
|
| 92 |
+
**做法**:
|
| 93 |
+
|
| 94 |
+
`scripts/calibrate_activity.py` 重读已 dump 的 `*__pred.csv`(这些 CSV 里 `eval_v7k_real_valid.py` 已经写入 `activity_prob` 和——v10 head 在的话——`num_active_pred`)。三种 decode 模式:
|
| 95 |
+
|
| 96 |
+
* `threshold`:固定阈值(扫 0.3 / 0.4 / 0.5 / 0.6)。
|
| 97 |
+
* `topk_hat`:每帧按 `activity_prob` 降序取前 K̂ 个,K̂ = `num_active_pred`(v10 的 num_active_head argmax)。
|
| 98 |
+
* `topk_hat_min`:上面两条的 AND(K̂ 之内还要过最低阈值)。
|
| 99 |
+
|
| 100 |
+
每个 (split, mode, thr) 计算与 `analyze_csv_dump.py` 同口径的指标:`hit_share`、`class_right_angle_wrong`、`matched_tp_precision/recall`、`mean_best_angle_when_same_class_exists`。再用 `pick_best` 给每个 split 找最优配置。
|
| 101 |
+
|
| 102 |
+
**输入**:任何已经包含 `__pred.csv / __gt.csv` 的目录(v9 best 的 dump、v11a/b/c 任一 epoch 的 dump 都行)。
|
| 103 |
+
|
| 104 |
+
**预期**:
|
| 105 |
+
|
| 106 |
+
* real_ov1:降阈值或换 `topk_hat` 后 `no_same_class_pred_but_other_preds_exist` 显著下降;存在「降阈值 → recall 大涨而 precision 小掉」的清晰拐点。
|
| 107 |
+
* real_ov2 / real_ov3:阈值调节收益有限——它们的问题不在 ranking。
|
| 108 |
+
* 出一个表,给每个 split 单独决定上线 decode 配置(线上不必三 split 共用同一阈值)。
|
| 109 |
+
|
| 110 |
+
**证明的事**:real_ov1 的 37% 丢失是 decode 过紧 / activity 分布漂移的可校正问题,**不需要拿训练侧资源去解**;同时给 v11a/v11b/v11c 报告时分离「ranking 收益」与「head 改动收益」,避免归因混淆。
|
| 111 |
+
|
| 112 |
+
**入口**:`scripts/calibrate_activity.py --dump-dir <csv_dir>`。
|
| 113 |
+
|
| 114 |
+
## 3. 改动清单
|
| 115 |
+
|
| 116 |
+
代码改动集中在四个文件:
|
| 117 |
+
|
| 118 |
+
* `spatial_modules.py`
|
| 119 |
+
* `LocalSpatialEncoder.forward` 新增 `return_pre_pool=False` 参数。
|
| 120 |
+
* `FrameTrackPredictionHeads.__init__` 新增 5 个 spatial demixer 配置项;`forward` 新增 3 个 `spatial_pre_pool_*` kwargs,复用 `ClassHeadSpectralDemixer` 类作 `spatial_head_demixer`。
|
| 121 |
+
* `spatial_beats.py`
|
| 122 |
+
* `SpatialBEATsConfig` 新增 5 个字段:`use_spatial_head_demixer / spatial_head_demixer_layers / heads / dropout / spatial_demixer_use_local_spatial_kv`。
|
| 123 |
+
* `__init__` 两处 `FrameTrackPredictionHeads` 构造点透传新 kwargs;同处按需创建 `local_spatial_pre_pool_proj`。
|
| 124 |
+
* `build_local_spatial_fusion` 新增 `return_local_pre_pool` 选项,6-tuple 返回;reshape 4D CNN feature 到 `[B, T_f*F_cnn, D_s]`。
|
| 125 |
+
* `forward` 两条 readout 分支(`local_spatial`、`local_spatial_track`)按配置切换 6-tuple,并把 `local_spatial_pre_pool_proj` 投影后的张量传给 frame_track head。
|
| 126 |
+
* `train_spatial_beats.py`
|
| 127 |
+
* 新增 3 个 preset factory:`make_ov1_local_spatial_v11a_ov123_top4_config`(继承 v9 + `use_spatial_head_demixer=True`),`v11b`(v11a + `spatial_demixer_use_local_spatial_kv=True`),`v11c`(包装 `make_ov123_local_spatial_accdoa_config` + 24 epochs + lr=3e-5)。
|
| 128 |
+
* 三处 dispatch 分支 + argparse `--preset` choices。
|
| 129 |
+
* `scripts/calibrate_activity.py`(新文件,纯 stdlib)
|
| 130 |
+
* 复用 `analyze_csv_dump.py` 的几何 / 计数逻辑,但读 `num_active_pred` 列做 Top-K̂ 解码;输出 per-split 最优配置。
|
| 131 |
+
|
| 132 |
+
新增 shell 入口:
|
| 133 |
+
|
| 134 |
+
* `run_ov1_v11a_ov123_top4.sh`(master_port 29561)
|
| 135 |
+
* `run_ov1_v11b_ov123_top4.sh`(master_port 29562)
|
| 136 |
+
* `run_ov1_v11c_ov123_accdoa.sh`(master_port 29563)
|
| 137 |
+
|
| 138 |
+
## 4. 验证方法
|
| 139 |
+
|
| 140 |
+
每个实验跑完后用同一套口径回验:
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
# 1. dump real valid CSV
|
| 144 |
+
python3 scripts/eval_v7k_real_valid.py \
|
| 145 |
+
--ckpt <ckpt_path> \
|
| 146 |
+
--dump-pred-dir <dump_dir> \
|
| 147 |
+
--dump-splits real_ov1,real_ov2,real_ov3 \
|
| 148 |
+
--activity-threshold 0.5
|
| 149 |
+
|
| 150 |
+
# 2. 阈值口径下的指标
|
| 151 |
+
python3 scripts/analyze_csv_dump.py \
|
| 152 |
+
--dump-dir <dump_dir> \
|
| 153 |
+
--threshold 0.5 \
|
| 154 |
+
--threshold-sweep 0.3 0.4 0.6
|
| 155 |
+
|
| 156 |
+
# 3. (v11d 用) 全 decode 模式扫描
|
| 157 |
+
python3 scripts/calibrate_activity.py \
|
| 158 |
+
--dump-dir <dump_dir> \
|
| 159 |
+
--thresholds 0.3 0.4 0.5 0.6 \
|
| 160 |
+
--modes threshold topk_hat topk_hat_min \
|
| 161 |
+
--json-out calibration.json
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
判 pass 的硬指标:
|
| 165 |
+
|
| 166 |
+
| 实验 | 主要观测 | 副作用看护 |
|
| 167 |
+
| ---- | ---------------------------------------------------------- | -------------------------------- |
|
| 168 |
+
| v11a | real_ov2 `class_right_angle_wrong` 下降 ≥ 5 pp | sim ov1 ocls / F20 不退化 |
|
| 169 |
+
| v11b | real_ov2 / real_ov3 vs v11a 是否更优 | 同 v11a + projection 不爆显存 |
|
| 170 |
+
| v11c | real_ov3 `no_same_class_pred_but_other_preds_exist` 下降 ≥ 5 pp | sim ov1 ocls 退化 < 3 pp |
|
| 171 |
+
| v11d | real_ov1 `hit_share` 在最佳 decode 下提升 ≥ 5 pp | precision 不崩(>= 0.6×baseline) |
|
| 172 |
+
|
| 173 |
+
## 5. 跑的顺序建议
|
| 174 |
+
|
| 175 |
+
1. **v11a**(最便宜、最可能直接命中 real_ov2,hot-start 完整保留)。
|
| 176 |
+
2. **v11d** 并行:用 v9 best 的现成 dump 出 ranking pareto,拿 real_ov1 的「decode 收益基线」。这样 v11a 的 dump 出来后能立刻分离「DOA 收益」vs「decode 收益」。
|
| 177 |
+
3. **v11b**:仅在 v11a 收益不达标、或想验「IV 直读 vs fuser 后」时启动。
|
| 178 |
+
4. **v11c**:作为范式对照单独跑,主要看 real_ov3。和 v11a/v11b 不可比同一基础(拓扑不同),但 sim ov1 应保持在可接受退化内。
|
| 179 |
+
|
| 180 |
+
## 6. 不在本轮范围内的事
|
| 181 |
+
|
| 182 |
+
* K=4 → 6/8 + query 正交正则(计划里的 P3)。代价高且需要重启动;放在 v11c 结果出来之后再决定要不要做。
|
| 183 |
+
* `SourceQueryDecoder` memory 升级为 `[B, T_s*F_p, D]` pre-freq-pool(计划里的 P2a)。改动面太大,先用 v11a/v11b 的 head-side demixer 取等效收益。
|
| 184 |
+
* 真实数据 finetune 调度(`run_ov1_v9_real_balanced_*.sh` 已经在跑),与 v11 系列正交。
|
docs/0429_v11a_with_dynamic.md
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 0429 · v11a_with_dynamic_10hz —— 动态 DOA 监督 + 真实/QA 数据扩展
|
| 2 |
+
|
| 3 |
+
> 记录当前 `ov1_local_spatial_v11a_with_dynamic_10hz` preset 的完整链路:
|
| 4 |
+
> 数据 → loader → 模型 → loss;并标明相对于 `v11a_real_balanced_10hz` 的每一处差异。
|
| 5 |
+
> 对应代码入口:
|
| 6 |
+
> - Preset: `train_spatial_beats.py::make_ov1_local_spatial_v11a_with_dynamic_10hz_config`
|
| 7 |
+
> - Run script: `run_ov1_v11a_with_dynamic_10hz.sh`
|
| 8 |
+
> - DCASE 转换器: `tools/dcase_starss_to_jsonl.py`
|
| 9 |
+
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
## 1. 一句话定位
|
| 13 |
+
|
| 14 |
+
**v11a_with_dynamic_10hz = v11a_real_balanced_10hz 的训练数据扩展 + loader/loss 升级到逐帧 target。**
|
| 15 |
+
|
| 16 |
+
- **模型结构不变**:仍然是 `local_spatial_track` + `SourceQueryDecoder (K=4)` + spatial_head_demixer
|
| 17 |
+
(v11a 的新组件),10 Hz token rate,ov123 top4 目录。
|
| 18 |
+
- **预测侧不变**:`FrameTrackPredictionOutput` 仍然是 `[B, K, T_s, ...]` 的逐帧四元组
|
| 19 |
+
`(activity, class, direction, distance)`。
|
| 20 |
+
- **监督侧升级**:target 张量从 `[B, N_gt]` 扩展为 `[B, N_gt, T_s]`——静态源沿 T_s 轴广播
|
| 21 |
+
(和旧行为一致),动态源按每帧轨迹线性插值到 10 Hz 栅格。
|
| 22 |
+
- **数据扩展**:新增 5 个训练 manifest(qa_moving / qa_counting / qa_lr_pair /
|
| 23 |
+
qa_same_doa / dcase_starss_foa.train)和 1 个验证 manifest(dcase_starss_foa.valid)。
|
| 24 |
+
- **Hot-start**:默认从 `v11a_real_balanced_10hz/03_ov123_top4/best.pt` 继续训练,`strict=False`
|
| 25 |
+
且不继承 optimizer/epoch/best。上游 ov123 静态 clip 的 epoch 0 loss 应与 v11a 吻合
|
| 26 |
+
(逐帧 target 对静态源退化为标量广播)。
|
| 27 |
+
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
## 2. 数据:新增的 manifest 与样本格式
|
| 31 |
+
|
| 32 |
+
### 2.1 训练集组成
|
| 33 |
+
|
| 34 |
+
| manifest | 记录数 | 类型 | DOA 来源 | distance 有效 | 复制次数 |
|
| 35 |
+
| --- | --- | --- | --- | --- | --- |
|
| 36 |
+
| `ov1_foa.jsonl` (sim) | — | ov1 静态 | scalar | ✓ | 1 |
|
| 37 |
+
| `ov2_foa.jsonl` (sim) | — | ov2 静态 | scalar | ✓ | 3 |
|
| 38 |
+
| `ov3_foa.jsonl` (sim) | — | ov3 静态 | scalar | ✓ | 3 |
|
| 39 |
+
| `ov1_real_static_foa_mapped.jsonl` | — | ov1 real 静态 | scalar | ✗ (null) | 4 |
|
| 40 |
+
| `ov2_real_static_foa_mapped.jsonl` | — | ov2 real 静态 | scalar | ✗ | 8 |
|
| 41 |
+
| `ov3_real_static_foa_mapped.jsonl` | — | ov3 real 静态 | scalar | ✗ | 8 |
|
| 42 |
+
| **`qa_moving.jsonl`** | 19 597 | QA sim 动态(单源平滑轨迹) | `frames[]` per-frame | ✓ | **2** |
|
| 43 |
+
| **`qa_counting.jsonl`** | 2 428 | QA sim 静态(多源 2-5 个) | scalar | ✓ | **1** |
|
| 44 |
+
| **`qa_lr_pair.jsonl`** | 6 631 | QA sim 静态(左右成对) | scalar | ✓ | **1** |
|
| 45 |
+
| **`qa_same_doa.jsonl`** | 7 896 | QA sim 静态(同一方向多源) | scalar | ✓ | **1** |
|
| 46 |
+
| **`dcase_starss_foa.train.jsonl`** | 12 805 | DCASE 真录 20s 多源动态 | `frames[]` per-frame | ✗ (-1) | **2** |
|
| 47 |
+
|
| 48 |
+
对应的 `train_manifest_replication = (1, 3, 3, 4, 8, 8, 2, 1, 1, 1, 2)`。
|
| 49 |
+
- 总训练 clip 数(未复制)= ov123sim + ov123real + 5 × 新 manifest ≈ ov123 基础 + 49 357。
|
| 50 |
+
- 验证集相比 v11a 增加 `dcase_starss_foa.valid.jsonl` (4 560 clips),作为真实录音的统一评估入口。
|
| 51 |
+
|
| 52 |
+
### 2.2 manifest schema(以 qa_moving / DCASE 为例)
|
| 53 |
+
|
| 54 |
+
**qa_moving.jsonl**(来自 `build_qa_foa_moving.py` 合成管道)
|
| 55 |
+
|
| 56 |
+
```jsonc
|
| 57 |
+
{
|
| 58 |
+
"scene_id": "...",
|
| 59 |
+
"output_foa_path": "/abs/.../foa.wav",
|
| 60 |
+
"output_duration_seconds": 10.0,
|
| 61 |
+
"sample_rate": 16000,
|
| 62 |
+
"frame_rate": 10.0,
|
| 63 |
+
"num_frames": 100,
|
| 64 |
+
"sources": [
|
| 65 |
+
{
|
| 66 |
+
"source_index": 0,
|
| 67 |
+
"is_moving": true,
|
| 68 |
+
"mono_target_label": "speech", // FSD50K 63-class 名字
|
| 69 |
+
"active_time": [0.0, 10.0],
|
| 70 |
+
"doa": null, // ← 动态源 scalar doa 为空
|
| 71 |
+
"distance_cm": 150.0, // clip 级 fallback 距离
|
| 72 |
+
"trajectory": "...", // sweep_arc / lshape / ...
|
| 73 |
+
"frames": [
|
| 74 |
+
{"frame_idx": 0, "doa": {"azimuth_deg": 175.6, "elevation_deg": -3.2},
|
| 75 |
+
"distance_cm": 150.2},
|
| 76 |
+
{"frame_idx": 1, "doa": {"azimuth_deg": 177.9, "elevation_deg": -3.1},
|
| 77 |
+
"distance_cm": 150.1},
|
| 78 |
+
...
|
| 79 |
+
]
|
| 80 |
+
}
|
| 81 |
+
]
|
| 82 |
+
}
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
**dcase_starss_foa.{train,valid,test}.jsonl**(由 `tools/dcase_starss_to_jsonl.py` 生成)
|
| 86 |
+
|
| 87 |
+
```jsonc
|
| 88 |
+
{
|
| 89 |
+
"scene_id": "fold1_starss22__fold4_room10_mix001_0",
|
| 90 |
+
"dataset_source": "dcase_starss",
|
| 91 |
+
"split": "train",
|
| 92 |
+
"output_foa_path": "/abs/.../foa.wav",
|
| 93 |
+
"output_duration_seconds": 20.0,
|
| 94 |
+
"sample_rate": 16000,
|
| 95 |
+
"frame_rate": 10.0,
|
| 96 |
+
"sources": [
|
| 97 |
+
{
|
| 98 |
+
"source_index": 0,
|
| 99 |
+
"is_moving": true,
|
| 100 |
+
"dcase_class_idx": 5,
|
| 101 |
+
"dcase_source_idx": 3,
|
| 102 |
+
"mono_target_label": "speech", // 经 DCASE_TO_FSD50K 重映射
|
| 103 |
+
"mono_primary_label": "male_speech",
|
| 104 |
+
"active_time": [1.3, 4.8],
|
| 105 |
+
"full_time": [0.0, 20.0],
|
| 106 |
+
"doa": null,
|
| 107 |
+
"distance_cm": -1, // DCASE 没有距离
|
| 108 |
+
"distance_valid": false,
|
| 109 |
+
"frames": [
|
| 110 |
+
{"frame_idx": 13, "time_s": 1.3,
|
| 111 |
+
"doa": {"azimuth_deg": -45.0, "elevation_deg": 10.0},
|
| 112 |
+
"distance_cm": -1},
|
| 113 |
+
...
|
| 114 |
+
]
|
| 115 |
+
},
|
| 116 |
+
... // 可能 4+ 个 track,但逐帧同时 active 的 ≤ 4(K=4 由 matcher 保证)
|
| 117 |
+
]
|
| 118 |
+
}
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
### 2.3 生成 DCASE manifest 的一次性步骤
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
python tools/dcase_starss_to_jsonl.py \
|
| 125 |
+
--dcase-root /apdcephfs_cq10/.../DCASE2024_seld_baseline/prepared_datasets/starss23_foa_plus_29cls_20s \
|
| 126 |
+
--output /apdcephfs_cq10/.../data/metadata/dcase_starss_foa.jsonl \
|
| 127 |
+
--per-split-output
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
- 扫描 `metadata_dev/<dataset>/<stem>.csv`,每行 `frame_idx, class_idx, source_idx, az_deg, el_deg, dist_cm`。
|
| 131 |
+
- 按 `(class_idx, source_idx)` 分 track,若相邻 labelled frame 间隔 > `gap_split_frames` (默认 50 帧 = 5s),
|
| 132 |
+
就把 track 切成多段 `SourceEvent`,避免在静默区间乱插值。
|
| 133 |
+
- **类别空间压缩**:DCASE 29 类 → FSD50K 63 类的语义最近邻映射
|
| 134 |
+
(`DCASE_TO_FSD50K` dict,见文件头部);碰不上 FSD50K 词表的 DCASE 类(如 `unknown_*`)整 track 丢弃。
|
| 135 |
+
- 输出统计:18 061 CSV → train 12 805 / valid 4 560 / test 505。
|
| 136 |
+
|
| 137 |
+
### 2.4 FSD50K 63 类别名
|
| 138 |
+
|
| 139 |
+
qa_*/DCASE manifest 里可能出现细粒度标签(`male_singing` / `female_singing`),
|
| 140 |
+
但 v11a 的 vocab 只有压缩后的 63 类(含 `singing` 不含性别变体)。在 `spatial_dataset.py` 的
|
| 141 |
+
`_resolve_class_index` / `_resolve_class_label` 之前先跑一次 `_LABEL_ALIASES.get(raw, raw)` 归一化:
|
| 142 |
+
|
| 143 |
+
```python
|
| 144 |
+
_LABEL_ALIASES = {
|
| 145 |
+
"male_singing": "singing",
|
| 146 |
+
"female_singing": "singing",
|
| 147 |
+
}
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
这样不用重新生成 jsonl,就能把 508 条 `male_singing` + 527 条 `female_singing` 折到 `singing` 上。
|
| 151 |
+
|
| 152 |
+
---
|
| 153 |
+
|
| 154 |
+
## 3. Loader:`spatial_dataset.py` 的逐帧化改造
|
| 155 |
+
|
| 156 |
+
### 3.1 `SourceEvent` 新增 5 个可选字段
|
| 157 |
+
|
| 158 |
+
```python
|
| 159 |
+
@dataclass
|
| 160 |
+
class SourceEvent:
|
| 161 |
+
class_index: int
|
| 162 |
+
class_label: str
|
| 163 |
+
azimuth_deg: float # 静态 scalar;动态时是 frames[0] 的 fallback
|
| 164 |
+
elevation_deg: float
|
| 165 |
+
distance: float # 动态时是第一个 valid frame 的 fallback
|
| 166 |
+
distance_valid: bool
|
| 167 |
+
start_time_seconds: float
|
| 168 |
+
end_time_seconds: float
|
| 169 |
+
# ---- 动态轨迹(仅动态源设置)----
|
| 170 |
+
frame_times_s: Optional[Tensor] = None # [N_f] 秒,相对 clip 起点
|
| 171 |
+
frame_azi_deg: Optional[Tensor] = None # [N_f] 度,未 unwrap
|
| 172 |
+
frame_ele_deg: Optional[Tensor] = None # [N_f] 度
|
| 173 |
+
frame_distance_m: Optional[Tensor] = None # [N_f] 米
|
| 174 |
+
frame_distance_valid: Optional[Tensor] = None # [N_f] bool
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
### 3.2 `_parse_frame_trajectory`
|
| 178 |
+
|
| 179 |
+
从 manifest 的 `frames[]` 里抽出 5 个 1D tensor,同时支持两种 layout:
|
| 180 |
+
|
| 181 |
+
1. **qa_moving**:每帧带 `frame_idx`(不带 `time_s`),clip 级 `frame_rate` 用于换算 `time_s = frame_idx / frame_rate`。
|
| 182 |
+
2. **DCASE 转换器输出**:每帧直接给 `time_s`,跳过 `frame_idx / frame_rate` 换算。
|
| 183 |
+
|
| 184 |
+
距离单位处理:优先读 `distance_cm`(除以 100 得米),`-1` 或缺失标记为 `distance_valid=False`;
|
| 185 |
+
其次读 `distance_m`。
|
| 186 |
+
|
| 187 |
+
### 3.3 `_build_source_event_from_nested_entry` 的 fallback
|
| 188 |
+
|
| 189 |
+
- 动态源 top-level `doa` 通常是 `null`,所以把 `_get_float` 换成 `_maybe_get_float`,
|
| 190 |
+
再用 `frames[0]` 的 DOA 补 `azimuth_deg` / `elevation_deg`(scalar fallback,只有在
|
| 191 |
+
loss 层碰到静态路径时才会用到)。
|
| 192 |
+
- 距离同理:若 source-level `distance_valid=False`,但 `frames[]` 里有至少一个 `distance_cm >= 0`,
|
| 193 |
+
就用第一个 valid frame 的距离作为 scalar fallback;否则保留 `distance_valid=False`。
|
| 194 |
+
|
| 195 |
+
### 3.4 `_maybe_crop_sample` 的轨迹裁剪
|
| 196 |
+
|
| 197 |
+
随机/中心裁剪时,除了裁 waveform 和更新 `start/end_time_seconds`,还要:
|
| 198 |
+
|
| 199 |
+
- 用 `new_start/new_end` 窗口过滤 `frame_times_s`,把留下来的帧时间重置到新 clip 起点(`- crop_start_seconds`)。
|
| 200 |
+
- `frame_azi_deg` / `frame_ele_deg` / `frame_distance_m` / `frame_distance_valid` 一起按索引截断。
|
| 201 |
+
|
| 202 |
+
保证裁剪后的 `SourceEvent` 时间轴仍和 waveform 同源。
|
| 203 |
+
|
| 204 |
+
### 3.5 Collate:`[B, N_gt, T_s]` 逐帧 target
|
| 205 |
+
|
| 206 |
+
`collate_spatial_batch` 相对旧实现的关键变化:
|
| 207 |
+
|
| 208 |
+
```python
|
| 209 |
+
t_s_max = int(target_num_steps.max()) # batch 内最大 token 数
|
| 210 |
+
source_azimuth_deg = zeros(B, N_gt_max, t_s_max) # 原来是 (B, N_gt_max)
|
| 211 |
+
source_elevation_deg = zeros(B, N_gt_max, t_s_max)
|
| 212 |
+
source_distance = zeros(B, N_gt_max, t_s_max)
|
| 213 |
+
source_distance_valid = ones (B, N_gt_max, t_s_max, dtype=bool) # 默认 True
|
| 214 |
+
|
| 215 |
+
for b, sample in enumerate(samples):
|
| 216 |
+
t_axis = arange(t_s_i) / target_token_rate # 该 sample 的有效时间轴
|
| 217 |
+
for s, source in enumerate(sample.sources):
|
| 218 |
+
azi_row, ele_row, dist_row, dist_valid_row = _build_per_frame_targets(
|
| 219 |
+
source=source, t_axis=t_axis, t_s_max=t_s_max,
|
| 220 |
+
)
|
| 221 |
+
source_azimuth_deg[b, s] = azi_row
|
| 222 |
+
source_elevation_deg[b, s] = ele_row
|
| 223 |
+
source_distance[b, s] = dist_row
|
| 224 |
+
source_distance_valid[b, s] = dist_valid_row
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
`_build_per_frame_targets` 的两条路径:
|
| 228 |
+
|
| 229 |
+
- **静态源**(`frame_times_s is None`���:在 `[0:t_s_i)` 填入 scalar;`[t_s_i:t_s_max)` 填零(padding)。
|
| 230 |
+
行为等价于旧版广播。
|
| 231 |
+
- **动态源**:对 `t_axis` 做线性插值。方位角先用 `_unwrap_deg` 去掉 ±180° 的跳变
|
| 232 |
+
(qa_moving / DCASE 都可能有 170° → -170° 这样跨接的情况),插值后再 wrap 回 `[-180, 180]`;
|
| 233 |
+
elevation / distance 直接线性插值;`distance_valid` 用**两端都 valid 才 valid**的逻辑
|
| 234 |
+
(`_linear_interp_valid_mask`),避免在未知距离段里猜出假的 valid。
|
| 235 |
+
|
| 236 |
+
### 3.6 `SpatialBatch` 的契约变化
|
| 237 |
+
|
| 238 |
+
```python
|
| 239 |
+
@dataclass
|
| 240 |
+
class SpatialBatch:
|
| 241 |
+
...
|
| 242 |
+
source_azimuth_deg: Tensor # [B, N_gt_max, T_s_max] ← 原 [B, N_gt_max]
|
| 243 |
+
source_elevation_deg: Tensor # [B, N_gt_max, T_s_max]
|
| 244 |
+
source_distance: Tensor # [B, N_gt_max, T_s_max]
|
| 245 |
+
source_distance_valid: Tensor # [B, N_gt_max, T_s_max] 新字段
|
| 246 |
+
source_class_indices: Tensor # [B, N_gt_max] (class 仍是 clip 级)
|
| 247 |
+
source_start_time_seconds: Tensor # [B, N_gt_max]
|
| 248 |
+
source_end_time_seconds: Tensor # [B, N_gt_max]
|
| 249 |
+
source_valid_mask: Tensor # [B, N_gt_max]
|
| 250 |
+
```
|
| 251 |
+
|
| 252 |
+
`source_class_indices` 保持 clip 级:v11a 没有「同一 source 换类」的需求,且对应 track 内 class 恒定。
|
| 253 |
+
|
| 254 |
+
---
|
| 255 |
+
|
| 256 |
+
## 4. 模型:和 v11a 完全一致
|
| 257 |
+
|
| 258 |
+
**没改**,为了让 hot-start 生效。这里简要记录一下 v11a 已有的配置,便于对照:
|
| 259 |
+
|
| 260 |
+
```
|
| 261 |
+
FOA 4-ch waveform @ 16 kHz
|
| 262 |
+
│
|
| 263 |
+
▼
|
| 264 |
+
SpatialBEATsPreprocessor (mel, iv feat)
|
| 265 |
+
│
|
| 266 |
+
┌───────────────────┴──────────────────┐
|
| 267 |
+
▼ ▼
|
| 268 |
+
SpatialPatchEmbedding SpatialDeltaPatchAdapter
|
| 269 |
+
(mel → 768-d patch tokens) (+IV residual contribution)
|
| 270 |
+
└───────────────┬──────────────────────┘
|
| 271 |
+
▼
|
| 272 |
+
BEATs TransformerEncoder (12 层,冻结)
|
| 273 |
+
│
|
| 274 |
+
▼
|
| 275 |
+
LocalSpatialEncoder (IV-aware conv over (T_p, F_p))
|
| 276 |
+
│
|
| 277 |
+
▼
|
| 278 |
+
TemporalResampler → fused_spatial_embeddings [B, T_s, 768]
|
| 279 |
+
(T_s @ 10 Hz ,cfg.target_token_rate=10)
|
| 280 |
+
│
|
| 281 |
+
▼
|
| 282 |
+
SourceQueryDecoder (K=4 queries × T_s 次 decode,两段式)
|
| 283 |
+
• track_latents: [B, K, D]
|
| 284 |
+
• track_time_feat:[B, K, T_s, D]
|
| 285 |
+
│
|
| 286 |
+
▼
|
| 287 |
+
FrameTrackHeads (+ SpatialHeadDemixer 1 层 attn refine, heads=8)
|
| 288 |
+
• pred_activity: [B, K, T_s]
|
| 289 |
+
• pred_class_logits: [B, K, T_s, 63]
|
| 290 |
+
• pred_direction: [B, K, T_s, 3] L2-normed
|
| 291 |
+
• pred_distance: [B, K, T_s] softplus 米
|
| 292 |
+
• pred_num_active_logits: [B, T_s, K+1] (v10 num_active head)
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
v11a 相对 v9 的关键新组件(均保留):
|
| 296 |
+
- `use_spatial_head_demixer=True`(1 层 self-attn,8 heads,dropout 0.1)—— 在 FrameTrack head 输出后做一次
|
| 297 |
+
track 维解相关。
|
| 298 |
+
- `local_spatial_lr_scale=1.0` —— LocalSpatialEncoder 和 head 用相同 LR(v9 默认 0.3 偏低)。
|
| 299 |
+
|
| 300 |
+
---
|
| 301 |
+
|
| 302 |
+
## 5. Loss:逐帧 target + distance valid mask
|
| 303 |
+
|
| 304 |
+
入口 `compute_frame_track_losses(prediction_output, batch, temporal_padding_mask, config)`。
|
| 305 |
+
|
| 306 |
+
### 5.1 target 抽取
|
| 307 |
+
|
| 308 |
+
```python
|
| 309 |
+
targets = _frame_source_target_tensors(batch, t_s_max, device)
|
| 310 |
+
# 返回:
|
| 311 |
+
# window_mask: [B, N_gt, T_s] (active_time 内为 True)
|
| 312 |
+
# source_valid: [B, N_gt]
|
| 313 |
+
# source_class: [B, N_gt]
|
| 314 |
+
# source_direction: [B, N_gt, T_s, 3] ← 逐帧 unit vector
|
| 315 |
+
# source_distance: [B, N_gt, T_s] ← 逐帧米
|
| 316 |
+
# source_distance_valid: [B, N_gt, T_s] ← 逐帧 bool
|
| 317 |
+
```
|
| 318 |
+
|
| 319 |
+
对于来自 loader 的 `source_azimuth_deg` / `source_elevation_deg`,`_align_t` 处理长度不匹配:
|
| 320 |
+
batch 内 `t_s_max` 可能与 loader 构造时的大小不同(不同 DataLoader 的 collate 边界),
|
| 321 |
+
短则 pad 末帧的值,长则截断。
|
| 322 |
+
|
| 323 |
+
### 5.2 Hungarian 匹配(代价按每帧)
|
| 324 |
+
|
| 325 |
+
`_match_frame_tracks`(`per_frame` 或 `segment` 两种策略,preset 里走 `segment`)
|
| 326 |
+
在 `[B, N, K, T]` 的代价张量上做匹配:
|
| 327 |
+
|
| 328 |
+
```
|
| 329 |
+
cost[b, n, k, t] = class_cost_w * NLL(pred_class, target_class[b, n])
|
| 330 |
+
+ dir_cost_w * (1 - pred_direction[b,k,t] · target_direction[b,n,t])
|
| 331 |
+
+ dist_cost_w * |pred_distance[b,k,t] - target_distance[b,n,t]|
|
| 332 |
+
+ (1 - σ(pred_activity[b,k,t])) # include_activity_cost
|
| 333 |
+
```
|
| 334 |
+
|
| 335 |
+
**关键点**:`target_direction` / `target_distance` 从 `[B, N, T, *]` 广播到 `[B, N, 1, T, *]`��之前是 clip 级标量),
|
| 336 |
+
代价按每帧独立累加,所以动态源在不同帧的 best-match track 可以不同。segment matching
|
| 337 |
+
额外加了一个 −2.0 的 continuity bonus,让同一 GT 在连续的相同 active-set segment 里尽量停留在同一 track。
|
| 338 |
+
|
| 339 |
+
### 5.3 监督张量的构建
|
| 340 |
+
|
| 341 |
+
```python
|
| 342 |
+
matched_track: [B, N_gt, T_s] (匹配结果 k∈[0,K) 或 -1)
|
| 343 |
+
valid_assign = matched_track >= 0
|
| 344 |
+
idx_b, idx_gt, idx_t = valid_assign.nonzero(as_tuple=True)
|
| 345 |
+
idx_k = matched_track[idx_b, idx_gt, idx_t]
|
| 346 |
+
|
| 347 |
+
activity_target[idx_b, idx_k, idx_t] = 1.0
|
| 348 |
+
class_target [idx_b, idx_k, idx_t] = targets["source_class"][idx_b, idx_gt]
|
| 349 |
+
direction_target[idx_b, idx_k, idx_t] = targets["source_direction"][idx_b, idx_gt, idx_t] # ← 3D 索引
|
| 350 |
+
distance_target [idx_b, idx_k, idx_t] = targets["source_distance" ][idx_b, idx_gt, idx_t] # ← 3D 索引
|
| 351 |
+
dist_supervise_mask[idx_b, idx_k, idx_t] = targets["source_distance_valid"][idx_b, idx_gt, idx_t]
|
| 352 |
+
```
|
| 353 |
+
|
| 354 |
+
相对旧版(`targets["source_direction"][idx_b, idx_gt]` 是 `[M, 3]`)的差别是把第三个 axis 替换为具体
|
| 355 |
+
的 `idx_t`,真正拿到逐帧 GT。`supervise_mask` 是 activity-winning 的 mask,`dist_supervise_mask`
|
| 356 |
+
在其基础上再 AND 一个**逐帧 distance validity**:STARSS/DCASE 整源为 False 时,
|
| 357 |
+
distance loss 就不会回传任何梯度。
|
| 358 |
+
|
| 359 |
+
### 5.4 各项损失
|
| 360 |
+
|
| 361 |
+
| 项 | 公式 | mask |
|
| 362 |
+
| --- | --- | --- |
|
| 363 |
+
| activity | `BCE_with_logits(pred_activity, activity_target, pos_weight=dyn)` | `valid_time` 扩到 [B, K, T_s] |
|
| 364 |
+
| num_active (v10) | `CE(pred_num_active_logits, active_count)` | `valid_time` |
|
| 365 |
+
| class | `CE(pred_class_logits, class_target)` + 可选 ontology smoothing | `supervise_mask` |
|
| 366 |
+
| direction | `mean(1 - pred · target)` | `supervise_mask` |
|
| 367 |
+
| distance | `smooth_l1(pred, target)` | `dist_supervise_mask` ← **逐帧 validity** |
|
| 368 |
+
|
| 369 |
+
**ADPIT duplicate** & **nonwinner soft activity**(v9/v10 的两个辅助)同样全部走逐帧 `source_direction[..., t]`
|
| 370 |
+
和 `source_distance_valid[..., t]` 索引;旧版的 `batch.source_azimuth_deg[:, 0]` 之类的 2D 访问被替换成
|
| 371 |
+
`[:, 0, 0]`(共 44 处)以避免 shape 冲突。
|
| 372 |
+
|
| 373 |
+
最终汇总:
|
| 374 |
+
```
|
| 375 |
+
loss_total = λ_act · loss_activity
|
| 376 |
+
+ λ_cls · loss_class
|
| 377 |
+
+ λ_dir · loss_direction
|
| 378 |
+
+ λ_dist · loss_distance
|
| 379 |
+
+ λ_na · loss_num_active
|
| 380 |
+
```
|
| 381 |
+
λ 与 v11a 完全相同(在 `SpatialLossConfig` 里走 v10 phase-2 的基线数值)。
|
| 382 |
+
|
| 383 |
+
### 5.5 静态源的退化等价性
|
| 384 |
+
|
| 385 |
+
因为 loader 把静态源沿 T_s 轴广播,`target_direction[b, gt, 0:T_s_i]` 每一帧都一致,
|
| 386 |
+
Hungarian 代价 `(1 - pred·target)` 和 per-clip 版本逐项相等;distance 同理。因此
|
| 387 |
+
**ov123 sim/real clip 的 epoch 0 loss 与 v11a 数值吻合**,这也是为什么可以直接从 v11a best.pt 热启动。
|
| 388 |
+
|
| 389 |
+
---
|
| 390 |
+
|
| 391 |
+
## 6. 训练配置(preset diff)
|
| 392 |
+
|
| 393 |
+
```python
|
| 394 |
+
def make_ov1_local_spatial_v11a_with_dynamic_10hz_config(...):
|
| 395 |
+
cfg = make_ov1_local_spatial_v11a_real_balanced_10hz_config(...)
|
| 396 |
+
|
| 397 |
+
# —— 只改了数据和轮次 ——
|
| 398 |
+
cfg.train_manifest_paths = (
|
| 399 |
+
ov1_sim, ov2_sim, ov3_sim,
|
| 400 |
+
ov1_real, ov2_real, ov3_real,
|
| 401 |
+
qa_moving, qa_counting, qa_lr_pair, qa_same_doa,
|
| 402 |
+
dcase_starss_train,
|
| 403 |
+
)
|
| 404 |
+
cfg.train_manifest_replication = (1, 3, 3, 4, 8, 8, 2, 1, 1, 1, 2)
|
| 405 |
+
|
| 406 |
+
cfg.val_manifest_paths = (
|
| 407 |
+
ov1_sim, ov2_sim, ov3_sim,
|
| 408 |
+
ov1_real, ov2_real, ov3_real,
|
| 409 |
+
dcase_starss_valid,
|
| 410 |
+
)
|
| 411 |
+
cfg.test_manifest_paths = cfg.val_manifest_paths
|
| 412 |
+
|
| 413 |
+
cfg.num_epochs = 15
|
| 414 |
+
cfg.output_dir = "checkpoints/spatial_beats_ov1_local_spatial_v11a_with_dynamic_10hz_exp/03_ov123_top4"
|
| 415 |
+
return cfg
|
| 416 |
+
```
|
| 417 |
+
|
| 418 |
+
运行入口(`run_ov1_v11a_with_dynamic_10hz.sh`)默认:
|
| 419 |
+
```
|
| 420 |
+
GPUS=8 BATCH_SIZE=4 NUM_WORKERS=8
|
| 421 |
+
SPATIAL_EPOCHS=15 SPATIAL_LR=1.5e-5 AMP=fp32
|
| 422 |
+
RESUME_CKPT=checkpoints/spatial_beats_ov1_local_spatial_v11a_real_balanced_10hz_exp/03_ov123_top4/best.pt
|
| 423 |
+
--no-resume-optimizer --reset-epoch-on-resume --reset-best-on-resume
|
| 424 |
+
```
|
| 425 |
+
|
| 426 |
+
---
|
| 427 |
+
|
| 428 |
+
## 7. 与 v11a_real_balanced_10hz 的差异一览
|
| 429 |
+
|
| 430 |
+
| 维度 | v11a_real_balanced_10hz | **v11a_with_dynamic_10hz** |
|
| 431 |
+
| --- | --- | --- |
|
| 432 |
+
| 训练 manifest 数 | 6 (ov123 sim + ov123 real) | **11**(+5 个 QA/DCASE) |
|
| 433 |
+
| 训练 clip 量(未复制) | ~O(20k) | +49 357 |
|
| 434 |
+
| 真实录音监督 | ov123 real 静态 scalar | **+DCASE STARSS 20s 逐帧** |
|
| 435 |
+
| 动态源监督 | 无 | **qa_moving + DCASE 动态 track** |
|
| 436 |
+
| `SourceEvent` 字段 | 无 `frame_*` | +5 个 `frame_times_s/azi/ele/dist/distance_valid` |
|
| 437 |
+
| Loader target shape | `source_*: [B, N_gt]` | **`source_*: [B, N_gt, T_s]`**(静态广播,动态插值) |
|
| 438 |
+
| `source_distance_valid` | `[B, N_gt]` | **`[B, N_gt, T_s]`**(逐帧) |
|
| 439 |
+
| Hungarian 代价 | dir/dist 代价用 clip 级标量 target | **用逐帧 `[B, N, T, *]` target 广播** |
|
| 440 |
+
| distance 监督 mask | 按源 `distance_valid` | **按 (源, 帧) `dist_supervise_mask`**,STARSS/DCASE 不回传梯度 |
|
| 441 |
+
| Class 词表别名 | 无 | **`_LABEL_ALIASES` 处理 `male_singing`/`female_singing → singing`** |
|
| 442 |
+
| 验证集 | ov123 sim+real | **+ DCASE valid (4 560 clips)** |
|
| 443 |
+
| `num_epochs` | 15 (继承 v9) | 15(不变) |
|
| 444 |
+
| 模型结构 | local_spatial_track + K=4 + demixer | **完全一致** |
|
| 445 |
+
| 热启动 | v9_real_balanced_10hz best.pt | **v11a_real_balanced_10hz best.pt**, strict=False |
|
| 446 |
+
| 输出目录 | `.../v11a_real_balanced_10hz_exp/03_ov123_top4` | `.../v11a_with_dynamic_10hz_exp/03_ov123_top4` |
|
| 447 |
+
|
| 448 |
+
---
|
| 449 |
+
|
| 450 |
+
## 8. 已知坑 & 兼容性备忘
|
| 451 |
+
|
| 452 |
+
1. **`pred_*` 的 t_s 与 loader 的 `T_s_max` 有可能不一致**。loss 侧 `_align_t` 会 pad/truncate GT 的最后一维;
|
| 453 |
+
loader 侧已保证 `target_num_steps = round(duration × target_token_rate)` 与模型 temporal resampler 对齐,
|
| 454 |
+
正常只有在 batch 内最长样本决定 `t_s_max` 而 frame-track head 输出更短时需要截断。
|
| 455 |
+
2. **方位角跨 ±180°**:qa_moving 里观测到 `175.6° → -177.3°` 的自然轨迹,
|
| 456 |
+
`_unwrap_deg` 会把它解包成 `175.6° → 182.7°` 插值,最后 wrap 回 `[-180, 180]`,不会出现 -340° 的错误跨越。
|
| 457 |
+
3. **DCASE 一个 clip 可以出现 6+ 个 track**,但任意帧同时 active 的 ≤ 4(DCASE 规范约束);
|
| 458 |
+
`_match_frame_tracks_per_frame` 里 `active_count.clamp(max=K)` 是一层防御。
|
| 459 |
+
4. **distance=-1 的 clip** 在 loss 内走 `dist_supervise_mask` 全 False 的分支:
|
| 460 |
+
`loss_distance = pred_distance.sum() * 0.0`,梯度为 0,这一 clip 只贡献 activity/class/direction loss。
|
| 461 |
+
5. **`male_singing/female_singing`** 是 loader 侧别名,若以后又有新细粒度标签冒出来,直接在
|
| 462 |
+
`spatial_dataset.py` 的 `_LABEL_ALIASES` 里加映射即可,不需要重跑数据。
|
| 463 |
+
6. **ov123 静态 clip 的 epoch 0 loss 必须与 v11a 一致**——这是验证 loader/loss 升级没有引入
|
| 464 |
+
回归的烟雾测试;曾经跑过 smoke script 确认 variance=0 for 静态 target、variance>0 for qa_moving target。
|
| 465 |
+
|
| 466 |
+
---
|
| 467 |
+
|
| 468 |
+
## 9. 相关脚本索引
|
| 469 |
+
|
| 470 |
+
- `tools/dcase_starss_to_jsonl.py` —— DCASE CSV → jsonl + FSD50K 类别映射
|
| 471 |
+
- `run_ov1_v11a_with_dynamic_10hz.sh` —— 训练入口(含 6 个动态 manifest 存在性 warning)
|
| 472 |
+
- `spatial_dataset.py::_parse_frame_trajectory` / `_build_per_frame_targets` —— 动态 target 构建
|
| 473 |
+
- `spatial_loss.py::_frame_source_target_tensors` / `compute_frame_track_losses` —— 逐帧 loss
|
| 474 |
+
- Preset: `train_spatial_beats.py::make_ov1_local_spatial_v11a_with_dynamic_10hz_config`
|
| 475 |
+
(CLI 名: `--preset ov1_local_spatial_v11a_with_dynamic_10hz`)
|
docs/V11_QUICK_START.md
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# v11 Architecture - Quick Start Guide
|
| 2 |
+
|
| 3 |
+
## What is v11?
|
| 4 |
+
|
| 5 |
+
The v11 series represents a major architectural enhancement addressing the classification accuracy plateau at ~51% in the Spatial-BEATs model. It introduces:
|
| 6 |
+
|
| 7 |
+
1. **SpatialDeltaPatchAdapterV2**: Enhanced front-end spatial encoder (17.4M params)
|
| 8 |
+
2. **SpatialAdapterLayer**: In-trunk spatial conditioning (1.2M params total)
|
| 9 |
+
3. **Multiple routing options**: Route A/B (track-based) or Route C (ACCDOA class-based)
|
| 10 |
+
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
## Four v11 Variants
|
| 14 |
+
|
| 15 |
+
### 1. v11_phase1_cls - Phase 1 Classification Refinement
|
| 16 |
+
|
| 17 |
+
**Use this first** to diagnose if the new adapter improves classification accuracy.
|
| 18 |
+
|
| 19 |
+
**What it does:**
|
| 20 |
+
- Enables SpatialDeltaPatchAdapterV2 only
|
| 21 |
+
- Freezes direction/distance heads
|
| 22 |
+
- Trains classification + num_active heads
|
| 23 |
+
- Hot-starts from v10 phase-1 best checkpoint
|
| 24 |
+
|
| 25 |
+
**Command:**
|
| 26 |
+
```bash
|
| 27 |
+
./run_ov1_v11_phase1_cls.sh
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
**Environment variables:**
|
| 31 |
+
```bash
|
| 32 |
+
SPATIAL_EPOCHS=10 # Default: 10 epochs
|
| 33 |
+
SPATIAL_LR=7.5e-6 # Default: 7.5e-6
|
| 34 |
+
BATCH_SIZE=8 # Default: 8
|
| 35 |
+
GPUS=8 # Default: 8
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
**Expected output:**
|
| 39 |
+
```
|
| 40 |
+
Epoch 1: cls_acc=0.720 (should be at least v10 level)
|
| 41 |
+
Epoch 5: cls_acc=0.755 (expect improvement trend)
|
| 42 |
+
Epoch 10: cls_acc=0.78+ (best of phase-1)
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
**What to look for:**
|
| 46 |
+
- Does cls_acc improve beyond v10 phase-1 peak (0.78)?
|
| 47 |
+
- How quickly does it converge?
|
| 48 |
+
- Does val_loss plateau or continue improving?
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
### 2. v11a_ov123_top4 - Route B + Spatial Demixer (Full Architecture)
|
| 53 |
+
|
| 54 |
+
**Use after v11_phase1_cls confirms improvement.**
|
| 55 |
+
|
| 56 |
+
**What it does:**
|
| 57 |
+
- Enables SpatialDeltaPatchAdapterV2 + trunk adapters + spatial_head_demixer
|
| 58 |
+
- Trains all heads (activity, class, direction, distance)
|
| 59 |
+
- Uses demixer for both class AND spatial heads
|
| 60 |
+
- Hot-starts from v9 ov123_top4 best checkpoint
|
| 61 |
+
|
| 62 |
+
**Command:**
|
| 63 |
+
```bash
|
| 64 |
+
./run_ov1_v11a_ov123_top4.sh
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
**Key metrics:**
|
| 68 |
+
- `azi_mae_deg`: Azimuth mean absolute error (primary DOA metric)
|
| 69 |
+
- `class_acc`: Matched-source class accuracy
|
| 70 |
+
- `activity_f1`: Source presence F1-score
|
| 71 |
+
|
| 72 |
+
**Expected improvement:**
|
| 73 |
+
```
|
| 74 |
+
Metric v9 Baseline v11a Target Expected Delta
|
| 75 |
+
────────────────────────────────────────────────────────────────────
|
| 76 |
+
azi_mae_deg (train) 10° 8-9° -1 to -2°
|
| 77 |
+
azi_mae_deg (val) 30° 24-26° -4 to -6°
|
| 78 |
+
class_acc (val) 73% 75%+ +2%
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
**What to look for:**
|
| 82 |
+
- Validation azimuth error should be significantly lower
|
| 83 |
+
- Train/val gap should narrow (from ~20° toward ~15°)
|
| 84 |
+
- No collapse in accuracy metrics
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
### 3. v11b_ov123_top4 - Route B + LocalSpatial Demixer KV
|
| 89 |
+
|
| 90 |
+
**Use for comparison with v11a.**
|
| 91 |
+
|
| 92 |
+
**What it does:**
|
| 93 |
+
- Same as v11a, BUT
|
| 94 |
+
- Demixer attends to LocalSpatial's 7-channel pre-pool (FOA + IV)
|
| 95 |
+
- Instead of BEATs mono mel-filterbank features
|
| 96 |
+
- Hypothesis: Spatial features better for DOA decomposition
|
| 97 |
+
|
| 98 |
+
**Command:**
|
| 99 |
+
```bash
|
| 100 |
+
./run_ov1_v11b_ov123_top4.sh
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
**Comparison with v11a:**
|
| 104 |
+
```
|
| 105 |
+
Aspect v11a (BEATs KV) v11b (LocalSpatial KV)
|
| 106 |
+
─────────────────────────────────────────────────────────────────
|
| 107 |
+
Demixer KV source BEATs trunk LocalSpatial pre-pool
|
| 108 |
+
Channels 1 (mono fbank) 7 (4 FOA + 3 IV)
|
| 109 |
+
Prior knowledge Semantic Spatial physics
|
| 110 |
+
Expected advantage Better for class Better for direction
|
| 111 |
+
Computational cost Lower Higher
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
**When to pick v11b over v11a:**
|
| 115 |
+
- If DOA error (azi_mae_deg) is more important than class accuracy
|
| 116 |
+
- If you have GPU budget for extra feature processing
|
| 117 |
+
- For acoustic scenes where spatial features matter more
|
| 118 |
+
|
| 119 |
+
---
|
| 120 |
+
|
| 121 |
+
### 4. v11c_ov123_accdoa - Paradigm Shift to ACCDOA (Route C)
|
| 122 |
+
|
| 123 |
+
**Use as a "simplicity first" baseline.**
|
| 124 |
+
|
| 125 |
+
**What it does:**
|
| 126 |
+
- Enables SpatialDeltaPatchAdapterV2 + trunk adapters
|
| 127 |
+
- Replaces query decoder + Hungarian matching with per-class ACCDOA heads
|
| 128 |
+
- Each class gets its own spatial slot (no matching needed)
|
| 129 |
+
- Activity encoded in vector magnitude, direction in unit vector
|
| 130 |
+
|
| 131 |
+
**Command:**
|
| 132 |
+
```bash
|
| 133 |
+
./run_ov1_v11c_ov123_accdoa.sh
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
**Key differences from v11a:**
|
| 137 |
+
```
|
| 138 |
+
Aspect v11a (Route B, Track) v11c (Route C, ACCDOA)
|
| 139 |
+
─────────────────────────────────────────────────────────────────────
|
| 140 |
+
Paradigm K learnable tracks Per-class slots
|
| 141 |
+
Matching Hungarian (clip-level) None (inherent per-class)
|
| 142 |
+
Activity loss Binary cross-entropy MSE on magnitude
|
| 143 |
+
Direction repr. L2 normalized vector Unit vector (normalized)
|
| 144 |
+
Scalability O(K×T_s) per-frame O(num_classes×T_s)
|
| 145 |
+
ov2/ov3 fit Good (overlap ambiguity) Better (same-class=0)
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
**When to pick v11c:**
|
| 149 |
+
- For DCASE evaluation (uses official SELD metrics)
|
| 150 |
+
- If Hungarian matching is a bottleneck
|
| 151 |
+
- For datasets with no overlapping same-class sources (ov2/ov3 constraints)
|
| 152 |
+
- For interpretability (each class = one direction)
|
| 153 |
+
|
| 154 |
+
---
|
| 155 |
+
|
| 156 |
+
## Decision Tree: Which v11 to Run?
|
| 157 |
+
|
| 158 |
+
```
|
| 159 |
+
START
|
| 160 |
+
│
|
| 161 |
+
├─→ "Do I want to diagnose if new adapters help classification?"
|
| 162 |
+
│ └─→ YES: Run v11_phase1_cls
|
| 163 |
+
│ ↓ (wait for results)
|
| 164 |
+
│ Does cls_acc improve?
|
| 165 |
+
│ ├─→ YES ✓
|
| 166 |
+
│ │ └─→ Proceed to multi-head experiments
|
| 167 |
+
│ └─→ NO ✗
|
| 168 |
+
│ └─→ Back to drawing board (architecture issue)
|
| 169 |
+
│
|
| 170 |
+
├─→ "Is direction-of-arrival (DOA) error my primary concern?"
|
| 171 |
+
│ ├─→ YES: Need DOA focus
|
| 172 |
+
│ │ ├─→ "Do I have GPU budget for LocalSpatial features?"
|
| 173 |
+
│ │ │ ├─→ YES: Run v11b_ov123_top4
|
| 174 |
+
│ │ │ └─→ NO: Run v11a_ov123_top4
|
| 175 |
+
│ └─→ NO: Skip v11a/v11b
|
| 176 |
+
│
|
| 177 |
+
└─→ "Am I targeting DCASE evaluation / ov2/ov3 constraints?"
|
| 178 |
+
├─→ YES: Run v11c_ov123_accdoa
|
| 179 |
+
└─→ NO: Run v11a_ov123_top4 (default full-featured)
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
---
|
| 183 |
+
|
| 184 |
+
## Monitoring Experiments
|
| 185 |
+
|
| 186 |
+
### Key Metrics to Track
|
| 187 |
+
|
| 188 |
+
**Classification**:
|
| 189 |
+
- `class_acc`: Top-1 accuracy on matched sources
|
| 190 |
+
- `class_precision`: Per-class precision
|
| 191 |
+
- `class_recall`: Per-class recall
|
| 192 |
+
|
| 193 |
+
**Direction (DOA)**:
|
| 194 |
+
- `azi_mae_deg`: **Primary metric** - azimuth mean absolute error
|
| 195 |
+
- `ele_mae_deg`: Elevation mean absolute error
|
| 196 |
+
- `azi_std_deg`: Azimuth error standard deviation
|
| 197 |
+
|
| 198 |
+
**Distance**:
|
| 199 |
+
- `dist_mae_m`: Distance mean absolute error
|
| 200 |
+
|
| 201 |
+
**Activity**:
|
| 202 |
+
- `activity_f1`: Source presence F1-score
|
| 203 |
+
- `num_active_mae`: Mean absolute error in source count
|
| 204 |
+
|
| 205 |
+
**Gap Analysis**:
|
| 206 |
+
- `train_azi_mae_deg`: Training set azimuth error
|
| 207 |
+
- `val_azi_mae_deg`: Validation set azimuth error
|
| 208 |
+
- `gap = val - train`: **Gap should decrease with v11**
|
| 209 |
+
|
| 210 |
+
### TensorBoard Visualization
|
| 211 |
+
|
| 212 |
+
```bash
|
| 213 |
+
tensorboard --logdir=checkpoints/spatial_beats_v11_phase1_cls_exp/ov123_top4 --port=6006
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
**Plots to monitor:**
|
| 217 |
+
- `metrics/val_azi_mae_deg`: Should decrease smoothly
|
| 218 |
+
- `metrics/train_azi_mae_deg`: Should decrease with training
|
| 219 |
+
- `loss/total`: Should follow training dynamics (may oscillate)
|
| 220 |
+
- `loss/frame_direction`: DOA-specific loss component
|
| 221 |
+
|
| 222 |
+
---
|
| 223 |
+
|
| 224 |
+
## Checkpoint Management
|
| 225 |
+
|
| 226 |
+
### Hot-Start Strategy
|
| 227 |
+
|
| 228 |
+
Each v11 variant is designed to hot-start from a previous checkpoint:
|
| 229 |
+
|
| 230 |
+
**v11_phase1_cls**:
|
| 231 |
+
```
|
| 232 |
+
Loads from: v10_phase1_cls best.pt
|
| 233 |
+
Missing params: V2 adapter + trunk adapters
|
| 234 |
+
Initialize with: Zero-init adapters (identity at epoch-0)
|
| 235 |
+
Benefit: Inherits v10's frozen classification features
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
**v11a_ov123_top4**:
|
| 239 |
+
```
|
| 240 |
+
Loads from: v9_ov123_top4 best.pt
|
| 241 |
+
Missing params: V2 + trunk adapters + spatial_demixer (added to heads)
|
| 242 |
+
Initialize with: Zero-init everything (identity at epoch-0)
|
| 243 |
+
Benefit: Inherits v9's proven multi-head balance
|
| 244 |
+
```
|
| 245 |
+
|
| 246 |
+
**v11b_ov123_top4**:
|
| 247 |
+
```
|
| 248 |
+
Same as v11a, but adds LocalSpatial pre-pool processing
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
+
**v11c_ov123_accdoa**:
|
| 252 |
+
```
|
| 253 |
+
Loads from: ov1_local_spatial baseline (v9 incompatible)
|
| 254 |
+
Missing params: ACCDOAHeads (entire head replacement)
|
| 255 |
+
Initialize with: Zero-init (no class/spatial heads to inherit)
|
| 256 |
+
Benefit: Simpler routing = faster convergence
|
| 257 |
+
```
|
| 258 |
+
|
| 259 |
+
### How to Load a Checkpoint Manually
|
| 260 |
+
|
| 261 |
+
```python
|
| 262 |
+
import torch
|
| 263 |
+
from train_spatial_beats import make_ov1_local_spatial_v11a_ov123_top4_config
|
| 264 |
+
from spatial_beats import SpatialBEATs
|
| 265 |
+
|
| 266 |
+
# Create model with v11a config
|
| 267 |
+
cfg = make_ov1_local_spatial_v11a_ov123_top4_config()
|
| 268 |
+
model = SpatialBEATs(cfg)
|
| 269 |
+
|
| 270 |
+
# Load v9 checkpoint (strict=False ignores new params)
|
| 271 |
+
ckpt = torch.load('checkpoints/.../v9_best.pt')
|
| 272 |
+
model.load_state_dict(ckpt['model'], strict=False)
|
| 273 |
+
|
| 274 |
+
# New params are zero-initialized (identity behavior)
|
| 275 |
+
# Ready to train!
|
| 276 |
+
model.train()
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
---
|
| 280 |
+
|
| 281 |
+
## Troubleshooting
|
| 282 |
+
|
| 283 |
+
### Issue: "CUDA out of memory"
|
| 284 |
+
**Solution**: Reduce batch size or sequence length
|
| 285 |
+
```bash
|
| 286 |
+
BATCH_SIZE=4 ./run_ov1_v11a_ov123_top4.sh
|
| 287 |
+
```
|
| 288 |
+
|
| 289 |
+
### Issue: "ClassHeadSpectralDemixer not initialized"
|
| 290 |
+
**Solution**: Ensure config enables it:
|
| 291 |
+
```python
|
| 292 |
+
cfg.use_class_head_demixer = True # For v11a
|
| 293 |
+
cfg.use_spatial_head_demixer = True # For v11a (added in v11)
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
### Issue: "Large train/val gap not shrinking"
|
| 297 |
+
**Diagnosis steps**:
|
| 298 |
+
1. Check if Dropout is OFF during evaluation
|
| 299 |
+
2. Verify SpecAugment is applied only during training
|
| 300 |
+
3. Run diagnostic: evaluate same checkpoint in train/eval modes
|
| 301 |
+
```bash
|
| 302 |
+
python -c "
|
| 303 |
+
model.eval()
|
| 304 |
+
val_error_no_dropout = evaluate(model, val_loader)
|
| 305 |
+
model.train()
|
| 306 |
+
val_error_with_dropout = evaluate(model, val_loader)
|
| 307 |
+
print(f'Dropout effect: {val_error_with_dropout - val_error_no_dropout:.1f}°')
|
| 308 |
+
"
|
| 309 |
+
```
|
| 310 |
+
|
| 311 |
+
### Issue: "Trunk adapters not being applied"
|
| 312 |
+
**Check**: Verify config flag is True
|
| 313 |
+
```python
|
| 314 |
+
if not cfg.use_trunk_spatial_adapters:
|
| 315 |
+
print("WARNING: Trunk adapters disabled!")
|
| 316 |
+
cfg.use_trunk_spatial_adapters = True
|
| 317 |
+
```
|
| 318 |
+
|
| 319 |
+
---
|
| 320 |
+
|
| 321 |
+
## Next Steps After v11 Experiments
|
| 322 |
+
|
| 323 |
+
1. **Analyze results** (docs/V11_IMPLEMENTATION_SUMMARY.md contains diagnostic templates)
|
| 324 |
+
2. **Pick best variant** based on your primary metric
|
| 325 |
+
3. **Fine-tune hyperparameters** (learning rate, dropout rate if you modify later)
|
| 326 |
+
4. **Run official evaluation** on test set using DCASE metrics
|
| 327 |
+
5. **Consider multi-stage training**:
|
| 328 |
+
- Stage 1: Classification only (v11_phase1_cls)
|
| 329 |
+
- Stage 2: Full pipeline (v11a/b/c)
|
| 330 |
+
- Stage 3: Fine-tuning (reduce LR, increase epochs)
|
| 331 |
+
|
| 332 |
+
---
|
| 333 |
+
|
| 334 |
+
## Citation & References
|
| 335 |
+
|
| 336 |
+
This architecture is built on:
|
| 337 |
+
- **BEATs** (Microsoft): Base semantic encoder (https://arxiv.org/abs/2212.09058)
|
| 338 |
+
- **DCASE SELD**: Official evaluation metrics (https://github.com/sharathadavanne/seld-dcase2023)
|
| 339 |
+
- **EINV2 paradigm**: Track-based source modeling
|
| 340 |
+
- **Spatial audio physics**: FOA (First-Order Ambisonics) + Intensity Vectors
|
| 341 |
+
|
| 342 |
+
For detailed technical justification, see:
|
| 343 |
+
- docs/V11_IMPLEMENTATION_SUMMARY.md
|
| 344 |
+
- docs/doa_train_valid_gap_analysis.md
|
| 345 |
+
- SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS_COMPREHENSIVE.md
|
docs/gemini.md
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Spatial-BEATs 最终实施指南 (Reference Implementation Guide)
|
| 2 |
+
|
| 3 |
+
本文档定义了 `Spatial-BEATs` 的模型架构、特征工程与训练流程的最终技术细节,作为代码实现的唯一参照。
|
| 4 |
+
|
| 5 |
+
## 1. 模型架构细节 (Architecture Specification)
|
| 6 |
+
|
| 7 |
+
### 1.1 输入前端 (Stem)
|
| 8 |
+
- **输入特征图**: $7 \times 128 \times 1024$ (Channels $\times$ Mel-bins $\times$ Time-frames)。
|
| 9 |
+
- **通道定义**:
|
| 10 |
+
- `[0:4]`: W, X, Y, Z 的 Log-mel。
|
| 11 |
+
- `[4:7]`: IVx, IVy, IVz (Intensity Vector),按时间/频率对齐。
|
| 12 |
+
- **Patch Embedding**:
|
| 13 |
+
- 结构: `nn.Conv2d(7, embed_dim, kernel_size=16, stride=16)`。
|
| 14 |
+
- 初始化: 通道 0 (W) 复用 BEATs 预训练权重,通道 1-6 随机初始化。
|
| 15 |
+
|
| 16 |
+
### 1.2 空间 Token 提取 (Source Queries)
|
| 17 |
+
- **Token 数量 ($K$)**: 4 个。
|
| 18 |
+
- **实现方式**:
|
| 19 |
+
- 定义 `nn.Parameter(torch.randn(1, 4, embed_dim))` 作为 Source Queries。
|
| 20 |
+
- 使用 2 层 Transformer Decoder 层。
|
| 21 |
+
- **Query**: Source Queries。
|
| 22 |
+
- **Key/Value**: BEATs Trunk 的输出序列 (Dense Patch Tokens)。
|
| 23 |
+
- **输出**: 4 个维度为 `embed_dim` 的 `Spatial Tokens`。
|
| 24 |
+
|
| 25 |
+
### 1.3 预测头 (Prediction Heads)
|
| 26 |
+
每个 Spatial Token 独立连接以下 MLP 层:
|
| 27 |
+
- **Objectness**: `Linear -> Sigmoid` (1 unit)。
|
| 28 |
+
- **Azimuth**: `Linear -> tanh` (2 units: $\sin, \cos$)。计算角度使用 `atan2`。
|
| 29 |
+
- **Elevation**: `Linear -> tanh` (2 units: $\sin, \cos$)。
|
| 30 |
+
- **Distance**: `Linear` (1 unit, 单位:**Centimeters**)。
|
| 31 |
+
- **Class**: `Linear -> Softmax` (N units, 对应 FSD50k 类别)。
|
| 32 |
+
|
| 33 |
+
## 2. 坐标系与物理特征 (Spatial Physics)
|
| 34 |
+
|
| 35 |
+
### 2.1 坐标系 (DCASE Standard)
|
| 36 |
+
- **轴向**: +x 前, +y 左, +z 上。
|
| 37 |
+
- **方位角 (Azimuth)**: $[-180, 180]$,逆时针增加。+90 度为左,-90 度为右。
|
| 38 |
+
- **仰角 (Elevation)**: $[-90, 90]$,向上增加。
|
| 39 |
+
- **距离 (Distance)**: 以 **厘米 (cm)** 为单位进行回归。
|
| 40 |
+
|
| 41 |
+
### 2.2 IV 计算 (Intensity Vector)
|
| 42 |
+
在特征提取阶段,按以下逻辑计算 IV:
|
| 43 |
+
- $I_x = \text{Re}\{W^* \cdot X\}$
|
| 44 |
+
- $I_y = \text{Re}\{W^* \cdot Y\}$
|
| 45 |
+
- $I_z = \text{Re}\{W^* \cdot Z\}$
|
| 46 |
+
- 所有的 $I$ 均通过 Mel 滤波器组进行映射,以匹配 Log-mel 的分辨率。
|
| 47 |
+
|
| 48 |
+
## 3. 训练策略 (Training Recipe)
|
| 49 |
+
|
| 50 |
+
### 3.1 损失函数 (Hungarian Loss)
|
| 51 |
+
- **匹配算法**: 使用 `scipy.optimize.linear_sum_assignment` (Hungarian Matching) 匹配 4 个预测 Token 与 $N$ 个 GT 声源 ($N \le 4$)。
|
| 52 |
+
- **匹配代价 (Matching Cost)**: 综合位置误差 (Az/El/Dist)、类别误差和 Objectness 分数。
|
| 53 |
+
- **总损失**:
|
| 54 |
+
- 对匹配成功的 Token:计算 $L_{MSE}(pos) + L_{BCE}(obj) + L_{CrossEntropy}(cls)$。
|
| 55 |
+
- 对未匹配的 Token:计算 $L_{BCE}(obj, 0)$。
|
| 56 |
+
|
| 57 |
+
### 3.2 训练阶段
|
| 58 |
+
1. **Stage 1 (Stem & Head Warmup)**: 冻结 BEATs Trunk (Transformer 层),仅训练新 Patch Embedding 和 Spatial Decoder/Heads。
|
| 59 |
+
2. **Stage 2 (Joint Fine-tuning)**: 以 $1 \times 10^{-5}$ 的低学习率解冻整个 Trunk 进行微调。
|
| 60 |
+
|
| 61 |
+
## 4. LLM 接入接口 (LLM Interface)
|
| 62 |
+
- 提取后的 4 个 `Spatial Tokens` 将通过一个 `Linear` 投影层对齐到 LLM 的隐藏层空间。
|
| 63 |
+
- 在 Prompt 中,这 4 个 tokens 将按 object-wise 顺序排列,代表音频中的空间实体。
|
docs/spatial_beats_implementation_spec.md
ADDED
|
@@ -0,0 +1,706 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Spatial-BEATs 实现规格
|
| 2 |
+
|
| 3 |
+
## 1. 目标
|
| 4 |
+
|
| 5 |
+
本规格文档用于将前期讨论收敛为一个可以直接实施的 `Spatial-BEATs` 方案。
|
| 6 |
+
|
| 7 |
+
目标是构建一个独立的 `Spatial Encoder`:
|
| 8 |
+
|
| 9 |
+
- 输入为完整 `FOA` 音频及其派生空间特征
|
| 10 |
+
- 完整的 `FOA` 特征经过 `BEATs backbone`
|
| 11 |
+
- 最大化复用 `BEATs` 预训练权重
|
| 12 |
+
- 输出一组 `source-level spatial tokens`
|
| 13 |
+
- 这些 token 作为独立模态输入给 LLM
|
| 14 |
+
- 原有语义 audio encoder 保持不动
|
| 15 |
+
|
| 16 |
+
这里的关键原则是:
|
| 17 |
+
|
| 18 |
+
> 不是让 `W-only` 走主干,再外挂一个小空间 adapter;而是让完整 FOA 空间特征真正进入 BEATs 主干,并在主干之后产出结构化空间 token。
|
| 19 |
+
|
| 20 |
+
## 2. 最终任务定义
|
| 21 |
+
|
| 22 |
+
### 2.1 核心任务
|
| 23 |
+
|
| 24 |
+
`Spatial-BEATs` 的主任务定义为:
|
| 25 |
+
|
| 26 |
+
- 给定一个多源 `FOA` 音频片段
|
| 27 |
+
- 预测其中最多 `K` 个潜在声源的空间表示
|
| 28 |
+
- 每个表示对应一个 `source token`
|
| 29 |
+
|
| 30 |
+
每个 source token 至少承载:
|
| 31 |
+
|
| 32 |
+
- `objectness`
|
| 33 |
+
- `azimuth`
|
| 34 |
+
- `elevation`
|
| 35 |
+
- `distance`
|
| 36 |
+
|
| 37 |
+
可选承载:
|
| 38 |
+
|
| 39 |
+
- `source class auxiliary logits`
|
| 40 |
+
- `source embedding`
|
| 41 |
+
|
| 42 |
+
### 2.2 推荐监督形式
|
| 43 |
+
|
| 44 |
+
如果训练数据中每个源都有标注,则推荐采用:
|
| 45 |
+
|
| 46 |
+
- `set prediction`
|
| 47 |
+
- `K` 个预测 token 对 `N` 个 GT sources
|
| 48 |
+
- 用 `Hungarian matching` 做一一匹配
|
| 49 |
+
|
| 50 |
+
不建议采用:
|
| 51 |
+
|
| 52 |
+
- 单一 scene-level spatial token
|
| 53 |
+
- 仅回归整段音频的全局空间摘要
|
| 54 |
+
|
| 55 |
+
原因是这会损失多源结构,不利于后续 LLM 做关系推理。
|
| 56 |
+
|
| 57 |
+
## 3. 最终架构
|
| 58 |
+
|
| 59 |
+
推荐最终架构:
|
| 60 |
+
|
| 61 |
+
```text
|
| 62 |
+
FOA waveform
|
| 63 |
+
-> SpatialBEATsPreprocessor
|
| 64 |
+
-> FOA feature map [B, C_foa, T, F]
|
| 65 |
+
-> FOA patch embedding
|
| 66 |
+
-> BEATs trunk
|
| 67 |
+
-> Spatial query decoder
|
| 68 |
+
-> K source tokens
|
| 69 |
+
-> Spatial prediction heads
|
| 70 |
+
-> LLM projector
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
为了最大化复用 BEATs 主干,本方案尽量不改 trunk 内部的 Transformer 结构。
|
| 74 |
+
|
| 75 |
+
## 4. 输入特征定义
|
| 76 |
+
|
| 77 |
+
### 4.1 默认推荐特征
|
| 78 |
+
|
| 79 |
+
第一版推荐输入通道:
|
| 80 |
+
|
| 81 |
+
- `W_logmel`
|
| 82 |
+
- `X_logmel`
|
| 83 |
+
- `Y_logmel`
|
| 84 |
+
- `Z_logmel`
|
| 85 |
+
- `IVx`
|
| 86 |
+
- `IVy`
|
| 87 |
+
- `IVz`
|
| 88 |
+
|
| 89 |
+
即:
|
| 90 |
+
|
| 91 |
+
- `C_foa = 7`
|
| 92 |
+
|
| 93 |
+
这是默认推荐方案。
|
| 94 |
+
|
| 95 |
+
### 4.2 备选输入特征
|
| 96 |
+
|
| 97 |
+
若希望先降低复杂度,可以使用:
|
| 98 |
+
|
| 99 |
+
- `WXYZ logmel`
|
| 100 |
+
|
| 101 |
+
即:
|
| 102 |
+
|
| 103 |
+
- `C_foa = 4`
|
| 104 |
+
|
| 105 |
+
但这只适合最小原型。
|
| 106 |
+
如果目标是稳定学习空间方向与结构,优先使用 `WXYZ + IV`。
|
| 107 |
+
|
| 108 |
+
### 4.3 前端参数建议
|
| 109 |
+
|
| 110 |
+
为了最大化复用 BEATs 主干,推荐保持与 BEATs 接近的时频分辨率:
|
| 111 |
+
|
| 112 |
+
- sample rate:优先 `16k`
|
| 113 |
+
- mel bins:`128`
|
| 114 |
+
- frame length:`25 ms`
|
| 115 |
+
- frame shift:`10 ms`
|
| 116 |
+
|
| 117 |
+
原因:
|
| 118 |
+
|
| 119 |
+
- 这能让 trunk 看到与原始 BEATs 更接近的 patch 几何结构
|
| 120 |
+
- patch embedding 和后续序列长度更容易保持一致
|
| 121 |
+
- 预训练权重复用更稳定
|
| 122 |
+
|
| 123 |
+
### 4.4 为什么不沿用 Spatial-AST 的 binaural 前端
|
| 124 |
+
|
| 125 |
+
Spatial-AST 采用的是:
|
| 126 |
+
|
| 127 |
+
- 双耳 log-mel
|
| 128 |
+
- IPD
|
| 129 |
+
|
| 130 |
+
这适合 binaural,不适合直接迁移到 FOA。
|
| 131 |
+
|
| 132 |
+
FOA 下应优先利用:
|
| 133 |
+
|
| 134 |
+
- ambisonic 通道本身
|
| 135 |
+
- intensity vector
|
| 136 |
+
- 其他 FOA 物理特征
|
| 137 |
+
|
| 138 |
+
## 5. 对 BEATs 具体修改哪些模块
|
| 139 |
+
|
| 140 |
+
下面按模块说明修改方案。
|
| 141 |
+
|
| 142 |
+
### 5.1 保留不动的模块
|
| 143 |
+
|
| 144 |
+
建议尽量保留:
|
| 145 |
+
|
| 146 |
+
- `TransformerEncoder`
|
| 147 |
+
- `TransformerSentenceEncoderLayer`
|
| 148 |
+
- `MultiheadAttention`
|
| 149 |
+
- `conv_pos`
|
| 150 |
+
- `LayerNorm`
|
| 151 |
+
- `FFN`
|
| 152 |
+
- `post_extract_proj`
|
| 153 |
+
|
| 154 |
+
也就是 `backbone.py` 内的主干结构和 `BEATs.py` 中的 trunk 逻辑尽量不动。
|
| 155 |
+
|
| 156 |
+
### 5.2 必须修改的模块
|
| 157 |
+
|
| 158 |
+
必须重做:
|
| 159 |
+
|
| 160 |
+
1. `preprocess`
|
| 161 |
+
2. `patch_embedding`
|
| 162 |
+
3. `extract_features` 输出头部逻辑
|
| 163 |
+
4. 下游 `predictor`
|
| 164 |
+
|
| 165 |
+
### 5.3 推荐新增的模块
|
| 166 |
+
|
| 167 |
+
建议新增:
|
| 168 |
+
|
| 169 |
+
1. `SpatialBEATsPreprocessor`
|
| 170 |
+
2. `SpatialPatchEmbedding`
|
| 171 |
+
3. `SpatialQueryDecoder`
|
| 172 |
+
4. `SpatialPredictionHead`
|
| 173 |
+
5. `SpatialTokenProjector`
|
| 174 |
+
6. `HungarianMatcher`
|
| 175 |
+
7. `SpatialSetCriterion`
|
| 176 |
+
|
| 177 |
+
## 6. 代码级映射建议
|
| 178 |
+
|
| 179 |
+
### 6.1 现有文件建议
|
| 180 |
+
|
| 181 |
+
建议保留和复用:
|
| 182 |
+
|
| 183 |
+
- [BEATs.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/BEATs.py)
|
| 184 |
+
- [backbone.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/backbone.py)
|
| 185 |
+
|
| 186 |
+
建议新增:
|
| 187 |
+
|
| 188 |
+
- `spatial_beats.py`
|
| 189 |
+
- `spatial_modules.py`
|
| 190 |
+
- `spatial_loss.py`
|
| 191 |
+
- `spatial_dataset.py`
|
| 192 |
+
- `train_spatial_beats.py`
|
| 193 |
+
|
| 194 |
+
### 6.2 `spatial_beats.py` 建议包含
|
| 195 |
+
|
| 196 |
+
建议实现:
|
| 197 |
+
|
| 198 |
+
- `SpatialBEATsConfig`
|
| 199 |
+
- `SpatialBEATs`
|
| 200 |
+
- `SpatialBEATs.extract_spatial_tokens()`
|
| 201 |
+
- `SpatialBEATs.forward()`
|
| 202 |
+
|
| 203 |
+
### 6.3 `spatial_modules.py` 建议包含
|
| 204 |
+
|
| 205 |
+
建议实现:
|
| 206 |
+
|
| 207 |
+
- `SpatialBEATsPreprocessor`
|
| 208 |
+
- `SpatialPatchEmbedding`
|
| 209 |
+
- `SpatialQueryDecoder`
|
| 210 |
+
- `SpatialPredictionHead`
|
| 211 |
+
- `SpatialTokenProjector`
|
| 212 |
+
|
| 213 |
+
### 6.4 `spatial_loss.py` 建议包含
|
| 214 |
+
|
| 215 |
+
建议实现:
|
| 216 |
+
|
| 217 |
+
- `HungarianMatcher`
|
| 218 |
+
- `SpatialSetCriterion`
|
| 219 |
+
|
| 220 |
+
## 7. 预训练权重如何复用
|
| 221 |
+
|
| 222 |
+
## 7.1 默认推荐权重
|
| 223 |
+
|
| 224 |
+
默认推荐:
|
| 225 |
+
|
| 226 |
+
- `BEATs_iter3+ (AS2M) pre-trained`
|
| 227 |
+
|
| 228 |
+
而不是:
|
| 229 |
+
|
| 230 |
+
- fine-tuned checkpoints
|
| 231 |
+
|
| 232 |
+
原因:
|
| 233 |
+
|
| 234 |
+
- `pre-trained` 更适合作为 trunk 初始化
|
| 235 |
+
- `fine-tuned` 更偏向 AudioSet 分类判别
|
| 236 |
+
- 你这里的 spatial encoder 应与原语义 encoder 职责分离
|
| 237 |
+
|
| 238 |
+
### 7.2 必须直接加载的层
|
| 239 |
+
|
| 240 |
+
这些层建议直接加载原 BEATs checkpoint:
|
| 241 |
+
|
| 242 |
+
- `post_extract_proj`
|
| 243 |
+
- `encoder.pos_conv`
|
| 244 |
+
- `encoder.layers.*`
|
| 245 |
+
- `encoder.layer_norm`
|
| 246 |
+
- `layer_norm`
|
| 247 |
+
|
| 248 |
+
即除了输入 stem 和输出头,主干参数都尽量继承。
|
| 249 |
+
|
| 250 |
+
### 7.3 需要特殊初始化的层
|
| 251 |
+
|
| 252 |
+
以下层因为 shape 不同,不能直接 strict load:
|
| 253 |
+
|
| 254 |
+
- `patch_embedding`
|
| 255 |
+
- 新增的 `query decoder`
|
| 256 |
+
- 新增的 `spatial heads`
|
| 257 |
+
- 新增的 `LLM projector`
|
| 258 |
+
|
| 259 |
+
### 7.4 新 patch embedding 的初始化策略
|
| 260 |
+
|
| 261 |
+
原 BEATs stem 是:
|
| 262 |
+
|
| 263 |
+
- `Conv2d(1, embed_dim, kernel_size=patch, stride=patch)`
|
| 264 |
+
|
| 265 |
+
新 stem 建议是:
|
| 266 |
+
|
| 267 |
+
- `Conv2d(C_foa, embed_dim, kernel_size=patch, stride=patch)`
|
| 268 |
+
|
| 269 |
+
推荐初始化策略:
|
| 270 |
+
|
| 271 |
+
#### 方案 A:保守初始化,默认推荐
|
| 272 |
+
|
| 273 |
+
- `W_logmel` 通道继承原 stem 权重
|
| 274 |
+
- 其他空间通道初始化为 `0` 或较小随机值
|
| 275 |
+
|
| 276 |
+
优点:
|
| 277 |
+
|
| 278 |
+
- 最大程度保留原 BEATs 初始分布
|
| 279 |
+
- trunk 适配更稳
|
| 280 |
+
|
| 281 |
+
缺点:
|
| 282 |
+
|
| 283 |
+
- 训练初期空间通道利用较慢
|
| 284 |
+
|
| 285 |
+
#### 方案 B:通道 inflation
|
| 286 |
+
|
| 287 |
+
- 把原 stem 权重复制到全部输入通道
|
| 288 |
+
- 再按通道数做归一化
|
| 289 |
+
|
| 290 |
+
优点:
|
| 291 |
+
|
| 292 |
+
- 所有通道一开始都能进入主干
|
| 293 |
+
|
| 294 |
+
缺点:
|
| 295 |
+
|
| 296 |
+
- 初始统计更可能偏离原 BEATs
|
| 297 |
+
|
| 298 |
+
最终推荐:
|
| 299 |
+
|
| 300 |
+
- 第一版用 `方案 A`
|
| 301 |
+
- 后续做 ablation 再比较 `方案 B`
|
| 302 |
+
|
| 303 |
+
## 8. Spatial token 模块的最终设计
|
| 304 |
+
|
| 305 |
+
### 8.1 为什么不用全局池化
|
| 306 |
+
|
| 307 |
+
原始 BEATs 的输出方式更接近:
|
| 308 |
+
|
| 309 |
+
- patch sequence
|
| 310 |
+
- mean pooling
|
| 311 |
+
- clip-level prediction
|
| 312 |
+
|
| 313 |
+
这不适合多源空间任务。
|
| 314 |
+
|
| 315 |
+
### 8.2 最终推荐:Query Decoder
|
| 316 |
+
|
| 317 |
+
在 trunk 输出后新增:
|
| 318 |
+
|
| 319 |
+
- `K` 个 learnable source queries
|
| 320 |
+
- 一个轻量 `cross-attention decoder`
|
| 321 |
+
|
| 322 |
+
输入:
|
| 323 |
+
|
| 324 |
+
- encoder memory:`H in R^{B x T x D}`
|
| 325 |
+
- source queries:`Q in R^{B x K x D}`
|
| 326 |
+
|
| 327 |
+
输出:
|
| 328 |
+
|
| 329 |
+
- `Z in R^{B x K x D}`
|
| 330 |
+
|
| 331 |
+
这里的 `Z[:, i, :]` 即第 `i` 个 `source token`
|
| 332 |
+
|
| 333 |
+
### 8.3 为什么 query decoder 是当前最优解
|
| 334 |
+
|
| 335 |
+
它的优点:
|
| 336 |
+
|
| 337 |
+
- 不改 trunk 内部结构
|
| 338 |
+
- 仍然让完整 FOA 特征经过 backbone
|
| 339 |
+
- 适合多源 set prediction
|
| 340 |
+
- 最利于最大化复用 trunk 权重
|
| 341 |
+
|
| 342 |
+
## 9. 输出头设计
|
| 343 |
+
|
| 344 |
+
对每个 source token `z_i`,预测:
|
| 345 |
+
|
| 346 |
+
- `objectness`
|
| 347 |
+
- `azimuth`
|
| 348 |
+
- `elevation`
|
| 349 |
+
- `distance`
|
| 350 |
+
- 可选 `class_aux`
|
| 351 |
+
|
| 352 |
+
### 9.1 离散还是连续
|
| 353 |
+
|
| 354 |
+
第一版推荐全部使用离散分类头:
|
| 355 |
+
|
| 356 |
+
- `azimuth`: 360 bins
|
| 357 |
+
- `elevation`: 180 bins
|
| 358 |
+
- `distance`: 按数据分桶,例如 `0.5m` 一档
|
| 359 |
+
|
| 360 |
+
原因:
|
| 361 |
+
|
| 362 |
+
- 与已有 Spatial-AST/BAT 经验一致
|
| 363 |
+
- 分类头更稳
|
| 364 |
+
- 更便于构造离散坐标 embedding
|
| 365 |
+
|
| 366 |
+
### 9.2 objectness 头
|
| 367 |
+
|
| 368 |
+
推荐增加:
|
| 369 |
+
|
| 370 |
+
- `objectness_head: D -> 1`
|
| 371 |
+
|
| 372 |
+
用于:
|
| 373 |
+
|
| 374 |
+
- 判断当前 token 是否对应真实声源
|
| 375 |
+
- 作为 Hungarian matching 的一部分
|
| 376 |
+
- 推理时做 token 保留/裁剪
|
| 377 |
+
|
| 378 |
+
### 9.3 类别头
|
| 379 |
+
|
| 380 |
+
类别头建议作为:
|
| 381 |
+
|
| 382 |
+
- `auxiliary head`
|
| 383 |
+
|
| 384 |
+
而不是最终 LLM 的主要输入内容。
|
| 385 |
+
|
| 386 |
+
这样做的作用:
|
| 387 |
+
|
| 388 |
+
- 让 query token 更容易学会 source slot 对齐
|
| 389 |
+
- 但不把 Spatial-BEATs 变成第二个强语义 encoder
|
| 390 |
+
|
| 391 |
+
## 10. Loss 设计
|
| 392 |
+
|
| 393 |
+
推荐总损失:
|
| 394 |
+
|
| 395 |
+
```text
|
| 396 |
+
L_total =
|
| 397 |
+
lambda_obj * L_obj
|
| 398 |
+
+ lambda_azi * L_azi
|
| 399 |
+
+ lambda_ele * L_ele
|
| 400 |
+
+ lambda_dist * L_dist
|
| 401 |
+
+ lambda_cls * L_cls_aux
|
| 402 |
+
```
|
| 403 |
+
|
| 404 |
+
### 10.1 匹配方式
|
| 405 |
+
|
| 406 |
+
使用 `Hungarian matching`:
|
| 407 |
+
|
| 408 |
+
- 预测:`K` 个 token
|
| 409 |
+
- GT:`N` 个 sources
|
| 410 |
+
- 成本由以下项构成:
|
| 411 |
+
- objectness cost
|
| 412 |
+
- azimuth cost
|
| 413 |
+
- elevation cost
|
| 414 |
+
- distance cost
|
| 415 |
+
- optional class cost
|
| 416 |
+
|
| 417 |
+
### 10.2 损失项定义
|
| 418 |
+
|
| 419 |
+
推荐:
|
| 420 |
+
|
| 421 |
+
- `L_obj`: BCE 或 focal loss
|
| 422 |
+
- `L_azi`: cross entropy
|
| 423 |
+
- `L_ele`: cross entropy
|
| 424 |
+
- `L_dist`: cross entropy
|
| 425 |
+
- `L_cls_aux`: cross entropy 或 BCE
|
| 426 |
+
|
| 427 |
+
### 10.3 初始 loss 权重建议
|
| 428 |
+
|
| 429 |
+
第一版建议从以下权重起步:
|
| 430 |
+
|
| 431 |
+
```text
|
| 432 |
+
lambda_obj = 1.0
|
| 433 |
+
lambda_azi = 2.0
|
| 434 |
+
lambda_ele = 2.0
|
| 435 |
+
lambda_dist = 1.0
|
| 436 |
+
lambda_cls = 0.25
|
| 437 |
+
```
|
| 438 |
+
|
| 439 |
+
解释:
|
| 440 |
+
|
| 441 |
+
- 方向任务通常更关键
|
| 442 |
+
- 距离次之
|
| 443 |
+
- objectness 必须稳定
|
| 444 |
+
- 类别监督只作为辅助
|
| 445 |
+
|
| 446 |
+
### 10.4 不建议的做法
|
| 447 |
+
|
| 448 |
+
第一版不建议:
|
| 449 |
+
|
| 450 |
+
- 重分类损失压倒空间损失
|
| 451 |
+
- 直接照搬 Spatial-AST 的 `1250 * cls`
|
| 452 |
+
|
| 453 |
+
原因:
|
| 454 |
+
|
| 455 |
+
- Spatial-AST 的目标之一是保住 sound event detection
|
| 456 |
+
- 这里 `Spatial-BEATs` 的主要目标是空间 token
|
| 457 |
+
- 原项目已有独立语义 encoder
|
| 458 |
+
|
| 459 |
+
## 11. 训练策略
|
| 460 |
+
|
| 461 |
+
### 11.1 第一阶段是否需要 SSL
|
| 462 |
+
|
| 463 |
+
当前最终结论:
|
| 464 |
+
|
| 465 |
+
- 第一版 **不需要** 重新做 BEATs 式 SSL
|
| 466 |
+
|
| 467 |
+
因为当前已经有:
|
| 468 |
+
|
| 469 |
+
- 多源监督
|
| 470 |
+
- 每个源的空间标注
|
| 471 |
+
- 可复用的 BEATs 主干预训练
|
| 472 |
+
|
| 473 |
+
所以第一阶段应优先做:
|
| 474 |
+
|
| 475 |
+
- `supervised multi-source spatial training`
|
| 476 |
+
|
| 477 |
+
### 11.2 分阶段训练建议
|
| 478 |
+
|
| 479 |
+
#### Stage A:Warmup
|
| 480 |
+
|
| 481 |
+
冻结:
|
| 482 |
+
|
| 483 |
+
- 大部分 trunk
|
| 484 |
+
|
| 485 |
+
只训练:
|
| 486 |
+
|
| 487 |
+
- FOA preprocessor
|
| 488 |
+
- patch embedding
|
| 489 |
+
- query decoder
|
| 490 |
+
- spatial heads
|
| 491 |
+
- LLM projector
|
| 492 |
+
|
| 493 |
+
目的:
|
| 494 |
+
|
| 495 |
+
- 让新输入 stem 和新输出头稳定接入 trunk
|
| 496 |
+
|
| 497 |
+
#### Stage B:Upper-trunk finetune
|
| 498 |
+
|
| 499 |
+
解冻:
|
| 500 |
+
|
| 501 |
+
- trunk 上层若干层
|
| 502 |
+
|
| 503 |
+
目的:
|
| 504 |
+
|
| 505 |
+
- 让主干逐步适应 FOA 空间任务
|
| 506 |
+
|
| 507 |
+
#### Stage C:Near-full finetune
|
| 508 |
+
|
| 509 |
+
进一步解冻:
|
| 510 |
+
|
| 511 |
+
- 更多 encoder layers
|
| 512 |
+
|
| 513 |
+
目的:
|
| 514 |
+
|
| 515 |
+
- 提升空间表示上限
|
| 516 |
+
|
| 517 |
+
### 11.3 学习率建议
|
| 518 |
+
|
| 519 |
+
推荐:
|
| 520 |
+
|
| 521 |
+
- trunk:较小 lr
|
| 522 |
+
- 新模块:较大学习率
|
| 523 |
+
|
| 524 |
+
例如:
|
| 525 |
+
|
| 526 |
+
```text
|
| 527 |
+
lr_trunk = 1e-5 ~ 5e-5
|
| 528 |
+
lr_new = 1e-4 ~ 5e-4
|
| 529 |
+
```
|
| 530 |
+
|
| 531 |
+
并配合:
|
| 532 |
+
|
| 533 |
+
- layer-wise lr decay
|
| 534 |
+
|
| 535 |
+
## 12. 最终输出给 LLM 的 spatial token 形式
|
| 536 |
+
|
| 537 |
+
这是本项目最关键的接口定义之一。
|
| 538 |
+
|
| 539 |
+
### 12.1 内部 token 形式
|
| 540 |
+
|
| 541 |
+
`Spatial-BEATs` 内部输出:
|
| 542 |
+
|
| 543 |
+
- `Z in R^{B x K x D}`
|
| 544 |
+
|
| 545 |
+
其中:
|
| 546 |
+
|
| 547 |
+
- `B`: batch size
|
| 548 |
+
- `K`: source token 数
|
| 549 |
+
- `D`: Spatial-BEATs hidden dim,建议与 BEATs trunk 一致
|
| 550 |
+
|
| 551 |
+
### 12.2 不建议直接把 raw logits 喂给 LLM
|
| 552 |
+
|
| 553 |
+
不建议直接给 LLM:
|
| 554 |
+
|
| 555 |
+
- azimuth logits
|
| 556 |
+
- elevation logits
|
| 557 |
+
- distance logits
|
| 558 |
+
- objectness logits
|
| 559 |
+
|
| 560 |
+
这些是监督头,不是最终模态表示。
|
| 561 |
+
|
| 562 |
+
### 12.3 最终推荐的 LLM spatial token 形式
|
| 563 |
+
|
| 564 |
+
最终推荐送给 LLM 的每个 token 形式为:
|
| 565 |
+
|
| 566 |
+
```text
|
| 567 |
+
s_i = Proj([z_i ; e_azi(i) ; e_ele(i) ; e_dist(i) ; e_obj(i)])
|
| 568 |
+
```
|
| 569 |
+
|
| 570 |
+
其中:
|
| 571 |
+
|
| 572 |
+
- `z_i`: query decoder 输出的 latent token
|
| 573 |
+
- `e_azi(i)`: 由预测 azimuth bin 查表得到的 embedding
|
| 574 |
+
- `e_ele(i)`: 由预测 elevation bin 查表得到的 embedding
|
| 575 |
+
- `e_dist(i)`: 由预测 distance bin 查表得到的 embedding
|
| 576 |
+
- `e_obj(i)`: 由 objectness/confidence 产生的 embedding
|
| 577 |
+
- `Proj`: 投影到 LLM hidden size 的 MLP/Linear
|
| 578 |
+
|
| 579 |
+
最终:
|
| 580 |
+
|
| 581 |
+
- `s_i in R^{d_llm}`
|
| 582 |
+
|
| 583 |
+
### 12.4 为什么采用“latent + structured embedding”的混合形式
|
| 584 |
+
|
| 585 |
+
原因:
|
| 586 |
+
|
| 587 |
+
1. `z_i` 保留丰富的隐式空间结构信息
|
| 588 |
+
2. `坐标 embedding` 给 LLM 显式离散空间线索
|
| 589 |
+
3. `confidence` 有助于 LLM 区分可靠/不可靠 token
|
| 590 |
+
|
| 591 |
+
这比单纯只传:
|
| 592 |
+
|
| 593 |
+
- raw latent token
|
| 594 |
+
|
| 595 |
+
或者只传:
|
| 596 |
+
|
| 597 |
+
- 显式坐标 one-hot / scalar
|
| 598 |
+
|
| 599 |
+
都更合适。
|
| 600 |
+
|
| 601 |
+
### 12.5 最终序列形式
|
| 602 |
+
|
| 603 |
+
送入 LLM 时推荐:
|
| 604 |
+
|
| 605 |
+
```text
|
| 606 |
+
<SPATIAL_START>, s_1, s_2, ..., s_K, <SPATIAL_END>
|
| 607 |
+
```
|
| 608 |
+
|
| 609 |
+
并且:
|
| 610 |
+
|
| 611 |
+
- 按 `objectness` 从高到低排序
|
| 612 |
+
- 对低置信 token 可直接截断或 mask
|
| 613 |
+
|
| 614 |
+
### 12.6 是否保留全部 K 个 token
|
| 615 |
+
|
| 616 |
+
默认推荐:
|
| 617 |
+
|
| 618 |
+
- 训练时保留全部 `K`
|
| 619 |
+
- 推理时按 `objectness` 过滤
|
| 620 |
+
|
| 621 |
+
例如:
|
| 622 |
+
|
| 623 |
+
- 保留前 `K_keep`
|
| 624 |
+
- 或保留 `obj > threshold` 的 token
|
| 625 |
+
|
| 626 |
+
## 13. 与原语义 audio encoder 的关系
|
| 627 |
+
|
| 628 |
+
为了避免“两个 encoder 在做同样的事”,推荐如下职责划分:
|
| 629 |
+
|
| 630 |
+
- 原语义 audio encoder:负责 `what`
|
| 631 |
+
- Spatial-BEATs:负责 `where / spatial structure / relations`
|
| 632 |
+
|
| 633 |
+
### 13.1 是否允许 Spatial-BEATs 学类别
|
| 634 |
+
|
| 635 |
+
允许,但只作为辅助。
|
| 636 |
+
|
| 637 |
+
建议:
|
| 638 |
+
|
| 639 |
+
- 类别头只用于训练
|
| 640 |
+
- 最终输入给 LLM 的空间 token 不直接暴露完整类别 logits
|
| 641 |
+
|
| 642 |
+
### 13.2 是否需要和语义 encoder 做对齐
|
| 643 |
+
|
| 644 |
+
第一版不是必须。
|
| 645 |
+
|
| 646 |
+
若后续希望更强的 source grounding,可进一步加入:
|
| 647 |
+
|
| 648 |
+
- semantic distillation
|
| 649 |
+
- cross-encoder alignment
|
| 650 |
+
- source-wise contrastive loss
|
| 651 |
+
|
| 652 |
+
但这些应放到第二阶段。
|
| 653 |
+
|
| 654 |
+
## 14. 第一版推荐配置
|
| 655 |
+
|
| 656 |
+
第一版默认建议:
|
| 657 |
+
|
| 658 |
+
- 输入特征:`WXYZ + IVxyz`
|
| 659 |
+
- `C_foa = 7`
|
| 660 |
+
- 采样率:`16k`
|
| 661 |
+
- mel bins:`128`
|
| 662 |
+
- patch 配置:与 BEATs 保持一致
|
| 663 |
+
- 预训练权重:`BEATs_iter3+ AS2M pre-trained`
|
| 664 |
+
- trunk:最大化加载
|
| 665 |
+
- patch stem:`W` 继承,其余通道小初始化
|
| 666 |
+
- 输出:`K` 个 source tokens
|
| 667 |
+
- token 解码:轻量 query decoder
|
| 668 |
+
- 监督:Hungarian matching + 多头空间分类
|
| 669 |
+
- LLM 输入:`latent + structured coordinate embedding` 的混合 token
|
| 670 |
+
|
| 671 |
+
## 15. 实现优先级
|
| 672 |
+
|
| 673 |
+
推荐按如下优先级推进:
|
| 674 |
+
|
| 675 |
+
1. 实现 `FOA preprocessor`
|
| 676 |
+
2. 实现多通道 `patch embedding`
|
| 677 |
+
3. 完成 trunk ckpt 加载
|
| 678 |
+
4. 实现 `query decoder`
|
| 679 |
+
5. 实现 `objectness / azi / ele / dist` heads
|
| 680 |
+
6. 实现 `Hungarian matcher + criterion`
|
| 681 |
+
7. 实现 `LLM projector`
|
| 682 |
+
8. 完成训练脚本
|
| 683 |
+
|
| 684 |
+
## 16. 当前仍需用户确认的问题
|
| 685 |
+
|
| 686 |
+
以下问题会直接影响第一版实现细节:
|
| 687 |
+
|
| 688 |
+
1. `FOA` 数据当前主要采样率是多少?是 `16k`、`24k`、`32k` 还是 `48k`?
|
| 689 |
+
2. 每个样本中 `最大同时源数` 大概是多少?这会影响 `K` 的默认设定。
|
| 690 |
+
3. 每个源是否都有 `source-level class label`?如果有,类别头和匹配会更稳。
|
| 691 |
+
4. 你希望 `distance` 是离散分类还是连续回归?当前默认推荐离散分类。
|
| 692 |
+
5. 下游 LLM 的 hidden size 是多少?是否已有固定的 audio token projector?
|
| 693 |
+
6. 你是否希望 Spatial-BEATs 在第一版就具备一定的 source semantic 辅助能力,还是严格只做空间?
|
| 694 |
+
|
| 695 |
+
## 17. 结论
|
| 696 |
+
|
| 697 |
+
当前最终方案已经明确:
|
| 698 |
+
|
| 699 |
+
- **完整 FOA 特征进入 BEATs 主干**
|
| 700 |
+
- **最大化复用 trunk 预训练**
|
| 701 |
+
- **重做输入 stem**
|
| 702 |
+
- **重做输出为多源 spatial tokens**
|
| 703 |
+
- **第一版采用监督式 set prediction**
|
| 704 |
+
- **最终给 LLM 的不是 raw logits,而是融合 latent 与坐标 embedding 的 spatial tokens**
|
| 705 |
+
|
| 706 |
+
这是当前最符合项目目标、也最稳妥的 `Spatial-BEATs` 方案。
|
docs/spatial_beats_training_overview.md
ADDED
|
@@ -0,0 +1,608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Spatial-BEATs Training And Architecture Overview
|
| 2 |
+
|
| 3 |
+
This document summarizes the current `Spatial-BEATs` implementation in this repository:
|
| 4 |
+
|
| 5 |
+
- model architecture
|
| 6 |
+
- tensor shape flow
|
| 7 |
+
- dataset contract
|
| 8 |
+
- variable-length batching
|
| 9 |
+
- supervision and losses
|
| 10 |
+
- stage-1 training setup
|
| 11 |
+
- current `ov1/ov2/ov3` presets
|
| 12 |
+
|
| 13 |
+
The implementation described here corresponds to:
|
| 14 |
+
|
| 15 |
+
- [spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats.py)
|
| 16 |
+
- [spatial_modules.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_modules.py)
|
| 17 |
+
- [spatial_dataset.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_dataset.py)
|
| 18 |
+
- [spatial_loss.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_loss.py)
|
| 19 |
+
- [train_spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/train_spatial_beats.py)
|
| 20 |
+
- [spatial_beats_ov123_stage1_config.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats_ov123_stage1_config.py)
|
| 21 |
+
|
| 22 |
+
## 1. Goal
|
| 23 |
+
|
| 24 |
+
`Spatial-BEATs` is a separate spatial encoder for FOA audio.
|
| 25 |
+
|
| 26 |
+
It is designed to:
|
| 27 |
+
|
| 28 |
+
- reuse the BEATs backbone and pretrained weights
|
| 29 |
+
- take full FOA input instead of only the `W` channel
|
| 30 |
+
- learn spatial structure through explicit supervision
|
| 31 |
+
- output fixed-rate spatial tokens for an LLM
|
| 32 |
+
- stay separate from the original audio encoder used for semantic audio understanding
|
| 33 |
+
|
| 34 |
+
The current implementation follows the simplified design:
|
| 35 |
+
|
| 36 |
+
- the main objective is to train the FOA front-end and BEATs trunk to produce spatially informative embeddings
|
| 37 |
+
- the supervision heads are lightweight readout heads
|
| 38 |
+
- the final LLM tokens are taken from the encoder-side spatial embeddings, not from the final logits
|
| 39 |
+
|
| 40 |
+
## 2. High-Level Architecture
|
| 41 |
+
|
| 42 |
+
The end-to-end model path is:
|
| 43 |
+
|
| 44 |
+
```text
|
| 45 |
+
FOA waveform
|
| 46 |
+
-> FOA spatial preprocessor
|
| 47 |
+
-> multi-channel patch embedding
|
| 48 |
+
-> BEATs trunk
|
| 49 |
+
-> frequency pooling
|
| 50 |
+
-> temporal resampling to 2.5 Hz
|
| 51 |
+
-> shallow temporal readout
|
| 52 |
+
-> spatial embeddings
|
| 53 |
+
-> fixed-slot supervision heads
|
| 54 |
+
-> projector
|
| 55 |
+
-> LLM spatial tokens
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
More concretely:
|
| 59 |
+
|
| 60 |
+
```text
|
| 61 |
+
[B, 4, T]
|
| 62 |
+
-> [B, 7, T_f, 128]
|
| 63 |
+
-> [B, N_p, 512]
|
| 64 |
+
-> [B, N_p, 768]
|
| 65 |
+
-> [B, T_p, 768]
|
| 66 |
+
-> [B, T_s_max, 768]
|
| 67 |
+
-> [B, T_s_max, 768]
|
| 68 |
+
-> [B, T_s_max, 4, 768]
|
| 69 |
+
-> [B, T_s_max, d_llm]
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
Where:
|
| 73 |
+
|
| 74 |
+
- `B`: batch size
|
| 75 |
+
- `T`: waveform length in samples
|
| 76 |
+
- `T_f`: acoustic frame count before patching
|
| 77 |
+
- `N_p`: number of BEATs patches
|
| 78 |
+
- `T_p`: time-axis patch count after frequency pooling
|
| 79 |
+
- `T_s_max`: padded token count in the batch after resampling to `2.5 Hz`
|
| 80 |
+
- `d_llm`: spatial token width sent to the LLM
|
| 81 |
+
|
| 82 |
+
## 3. Input And Front-End
|
| 83 |
+
|
| 84 |
+
### 3.1 Input audio
|
| 85 |
+
|
| 86 |
+
The model expects:
|
| 87 |
+
|
| 88 |
+
- FOA waveform
|
| 89 |
+
- shape `[B, 4, T]`
|
| 90 |
+
- channel order: `W, X, Y, Z`
|
| 91 |
+
- sample rate: `16 kHz`
|
| 92 |
+
|
| 93 |
+
### 3.2 Qwen-like low-level mel setup
|
| 94 |
+
|
| 95 |
+
The current front-end is aligned to the Qwen-2.5-Omni audio tower style low-level parameters:
|
| 96 |
+
|
| 97 |
+
- `sample_rate = 16000`
|
| 98 |
+
- `num_mel_bins = 128`
|
| 99 |
+
- `n_fft = 400`
|
| 100 |
+
- `win_length = 400`
|
| 101 |
+
- `hop_length = 160`
|
| 102 |
+
- `dither = 0.0`
|
| 103 |
+
|
| 104 |
+
These parameters are shared between:
|
| 105 |
+
|
| 106 |
+
- `SpatialBEATsConfig`
|
| 107 |
+
- `SpatialDatasetConfig`
|
| 108 |
+
|
| 109 |
+
This keeps the data pipeline and the model front-end consistent.
|
| 110 |
+
|
| 111 |
+
### 3.3 FOA feature construction
|
| 112 |
+
|
| 113 |
+
The preprocessor converts FOA waveform into a 7-channel feature map:
|
| 114 |
+
|
| 115 |
+
- `W_logmel`
|
| 116 |
+
- `X_logmel`
|
| 117 |
+
- `Y_logmel`
|
| 118 |
+
- `Z_logmel`
|
| 119 |
+
- `IVx`
|
| 120 |
+
- `IVy`
|
| 121 |
+
- `IVz`
|
| 122 |
+
|
| 123 |
+
Output shape:
|
| 124 |
+
|
| 125 |
+
- `foa_feat: [B, 7, T_f, 128]`
|
| 126 |
+
|
| 127 |
+
This allows the whole FOA structure to enter the backbone instead of relying on only `W`.
|
| 128 |
+
|
| 129 |
+
## 4. Backbone And Spatial Embedding Path
|
| 130 |
+
|
| 131 |
+
### 4.1 Spatial patch embedding
|
| 132 |
+
|
| 133 |
+
The model replaces the original single-channel patch stem with a 7-channel patch embedding:
|
| 134 |
+
|
| 135 |
+
- input: `foa_feat [B, 7, T_f, 128]`
|
| 136 |
+
- output: `patch_tokens [B, N_p, 512]`
|
| 137 |
+
- also returns `grid_size = (T_p, F_p)`
|
| 138 |
+
|
| 139 |
+
This is the first modified entry point for reusing BEATs on FOA input.
|
| 140 |
+
|
| 141 |
+
### 4.2 Reused BEATs trunk
|
| 142 |
+
|
| 143 |
+
The trunk reuses BEATs pretrained components:
|
| 144 |
+
|
| 145 |
+
- `layer_norm`
|
| 146 |
+
- `post_extract_proj`
|
| 147 |
+
- `encoder.pos_conv`
|
| 148 |
+
- all transformer layers
|
| 149 |
+
- `encoder.layer_norm`
|
| 150 |
+
|
| 151 |
+
Flow:
|
| 152 |
+
|
| 153 |
+
- input: `patch_tokens [B, N_p, 512]`
|
| 154 |
+
- output: `encoder_memory [B, N_p, 768]`
|
| 155 |
+
|
| 156 |
+
### 4.3 Frequency pooling
|
| 157 |
+
|
| 158 |
+
The patch sequence is reshaped back into a patch grid and pooled over the frequency axis:
|
| 159 |
+
|
| 160 |
+
- input: `encoder_memory [B, N_p, 768]` with `grid_size=(T_p, F_p)`
|
| 161 |
+
- reshaped internally to `[B, T_p, F_p, 768]`
|
| 162 |
+
- pooled output: `temporal_patch_tokens [B, T_p, 768]`
|
| 163 |
+
|
| 164 |
+
This produces a time-aligned sequence before the final token-rate conversion.
|
| 165 |
+
|
| 166 |
+
### 4.4 Temporal resampling
|
| 167 |
+
|
| 168 |
+
The temporal resampler converts the patch-rate sequence into the final spatial token rate:
|
| 169 |
+
|
| 170 |
+
- target token rate: `2.5 Hz`
|
| 171 |
+
- per-sample target length:
|
| 172 |
+
|
| 173 |
+
```text
|
| 174 |
+
T_s_i = round(duration_i * 2.5)
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
Batch handling:
|
| 178 |
+
|
| 179 |
+
- each sample is resampled independently
|
| 180 |
+
- the batch is padded to `T_s_max = max_i(T_s_i)`
|
| 181 |
+
- a temporal mask is produced
|
| 182 |
+
|
| 183 |
+
Outputs:
|
| 184 |
+
|
| 185 |
+
- `temporal_tokens: [B, T_s_max, 768]`
|
| 186 |
+
- `temporal_padding_mask: [B, T_s_max]`
|
| 187 |
+
|
| 188 |
+
Mask convention:
|
| 189 |
+
|
| 190 |
+
- `False`: valid time step
|
| 191 |
+
- `True`: padded time step
|
| 192 |
+
|
| 193 |
+
### 4.5 Shallow temporal readout
|
| 194 |
+
|
| 195 |
+
The shallow temporal readout refines the resampled sequence with a lightweight transformer encoder:
|
| 196 |
+
|
| 197 |
+
- input: `temporal_tokens [B, T_s_max, 768]`
|
| 198 |
+
- output: `spatial_embeddings [B, T_s_max, 768]`
|
| 199 |
+
|
| 200 |
+
This is the main representation used for both:
|
| 201 |
+
|
| 202 |
+
- spatial supervision
|
| 203 |
+
- final projection to LLM tokens
|
| 204 |
+
|
| 205 |
+
## 5. Supervision Heads
|
| 206 |
+
|
| 207 |
+
The current stage-1 design does not use a heavy decoder.
|
| 208 |
+
|
| 209 |
+
Instead, it uses a fixed-slot readout for supervision only.
|
| 210 |
+
|
| 211 |
+
### 5.1 Fixed-slot readout
|
| 212 |
+
|
| 213 |
+
The readout expands each time step into a small number of internal supervision slots:
|
| 214 |
+
|
| 215 |
+
- max slots per step: `K = 4`
|
| 216 |
+
- input: `spatial_embeddings [B, T_s_max, 768]`
|
| 217 |
+
- output: `slot_latents [B, T_s_max, 4, 768]`
|
| 218 |
+
|
| 219 |
+
Important:
|
| 220 |
+
|
| 221 |
+
- `K=4` is only a supervision capacity
|
| 222 |
+
- it does not change the final LLM token count
|
| 223 |
+
- the final LLM-visible token rate is still `2.5 Hz`
|
| 224 |
+
|
| 225 |
+
### 5.2 Prediction heads
|
| 226 |
+
|
| 227 |
+
Each supervision slot predicts:
|
| 228 |
+
|
| 229 |
+
- `pred_activity: [B, T_s_max, 4]`
|
| 230 |
+
- `pred_azi_logits: [B, T_s_max, 4, 360]`
|
| 231 |
+
- `pred_ele_logits: [B, T_s_max, 4, 180]`
|
| 232 |
+
- `pred_dist: [B, T_s_max, 4, 1]`
|
| 233 |
+
- `pred_class_logits: [B, T_s_max, 4, C]`
|
| 234 |
+
|
| 235 |
+
Where:
|
| 236 |
+
|
| 237 |
+
- `C = 65`
|
| 238 |
+
- the class vocabulary comes from:
|
| 239 |
+
- `/apdcephfs_cq12/share_302080740/user/schmittzhu/data/fsd50k/FSD50K.ground_truth/final_vocabulary.csv`
|
| 240 |
+
|
| 241 |
+
These heads are used to supply explicit training loss and push the front-end plus BEATs trunk to learn spatial structure.
|
| 242 |
+
|
| 243 |
+
## 6. LLM Spatial Tokens
|
| 244 |
+
|
| 245 |
+
The final LLM tokens are not taken from slot logits.
|
| 246 |
+
|
| 247 |
+
They are projected from the encoder-side spatial embeddings:
|
| 248 |
+
|
| 249 |
+
- input: `spatial_embeddings [B, T_s_max, 768]`
|
| 250 |
+
- output: `llm_spatial_tokens [B, T_s_max, d_llm]`
|
| 251 |
+
|
| 252 |
+
Therefore:
|
| 253 |
+
|
| 254 |
+
- `2.5 Hz` means final LLM-visible tokens arrive at `2.5 tokens/second`
|
| 255 |
+
- a `20 s` clip produces about `50` spatial tokens
|
| 256 |
+
- a `10 s` clip produces about `25` spatial tokens
|
| 257 |
+
|
| 258 |
+
This is the externally visible spatial token interface.
|
| 259 |
+
|
| 260 |
+
## 7. Pretrained Weight Reuse
|
| 261 |
+
|
| 262 |
+
The model initializes from `BEATs_iter3+ AS2M`.
|
| 263 |
+
|
| 264 |
+
Current pretrained loading logic:
|
| 265 |
+
|
| 266 |
+
- selectively load BEATs trunk modules
|
| 267 |
+
- skip task-specific components that do not match
|
| 268 |
+
- inflate the old single-channel patch embedding into the new 7-channel stem
|
| 269 |
+
|
| 270 |
+
Patch stem initialization rule:
|
| 271 |
+
|
| 272 |
+
- original BEATs patch weight is copied into channel `0` of the new 7-channel stem
|
| 273 |
+
- remaining channels start from zero
|
| 274 |
+
|
| 275 |
+
This is a conservative initialization intended to preserve BEATs trunk stability while enabling FOA adaptation.
|
| 276 |
+
|
| 277 |
+
## 8. Dataset Contract
|
| 278 |
+
|
| 279 |
+
### 8.1 Supported manifests
|
| 280 |
+
|
| 281 |
+
The dataset loader currently supports:
|
| 282 |
+
|
| 283 |
+
- `ov1_foa.jsonl`
|
| 284 |
+
- `ov2_foa.jsonl`
|
| 285 |
+
- `ov3_foa.jsonl`
|
| 286 |
+
|
| 287 |
+
It handles:
|
| 288 |
+
|
| 289 |
+
- single-source top-level manifest style
|
| 290 |
+
- nested multi-source manifest style with `sources`
|
| 291 |
+
|
| 292 |
+
### 8.2 Required scene-level data
|
| 293 |
+
|
| 294 |
+
At scene level the dataset expects one FOA path, typically:
|
| 295 |
+
|
| 296 |
+
- `output_foa_path`
|
| 297 |
+
|
| 298 |
+
or compatible fallback names already handled in the parser.
|
| 299 |
+
|
| 300 |
+
### 8.3 Required source-level data
|
| 301 |
+
|
| 302 |
+
For each source, the loader extracts:
|
| 303 |
+
|
| 304 |
+
- source class
|
| 305 |
+
- azimuth
|
| 306 |
+
- elevation
|
| 307 |
+
- distance
|
| 308 |
+
- weak time window
|
| 309 |
+
|
| 310 |
+
Internally each source is converted into a `SourceEvent` containing:
|
| 311 |
+
|
| 312 |
+
- `class_index`
|
| 313 |
+
- `class_label`
|
| 314 |
+
- `azimuth_deg`
|
| 315 |
+
- `elevation_deg`
|
| 316 |
+
- `distance_m`
|
| 317 |
+
- `start_time_seconds`
|
| 318 |
+
- `end_time_seconds`
|
| 319 |
+
|
| 320 |
+
### 8.4 Vocabulary mapping
|
| 321 |
+
|
| 322 |
+
Source labels are mapped to `final_vocabulary.csv`.
|
| 323 |
+
|
| 324 |
+
The loader supports several field aliases, including:
|
| 325 |
+
|
| 326 |
+
- `mono_target_label`
|
| 327 |
+
- `mono_primary_label`
|
| 328 |
+
- `final_label`
|
| 329 |
+
- `source_label`
|
| 330 |
+
- `label`
|
| 331 |
+
|
| 332 |
+
and several id-style aliases if an integer class index is already present.
|
| 333 |
+
|
| 334 |
+
## 9. Variable-Length Batching
|
| 335 |
+
|
| 336 |
+
Handling mixed-length FOA clips is a core part of the current implementation.
|
| 337 |
+
|
| 338 |
+
### 9.1 Waveform padding
|
| 339 |
+
|
| 340 |
+
At batch time:
|
| 341 |
+
|
| 342 |
+
- each waveform is padded to the batch maximum waveform length
|
| 343 |
+
- the padded tensor has shape `[B, 4, T_max]`
|
| 344 |
+
- a waveform padding mask is created:
|
| 345 |
+
|
| 346 |
+
```text
|
| 347 |
+
waveform_padding_mask: [B, T_max]
|
| 348 |
+
```
|
| 349 |
+
|
| 350 |
+
Mask convention:
|
| 351 |
+
|
| 352 |
+
- `False`: valid waveform sample
|
| 353 |
+
- `True`: padded sample
|
| 354 |
+
|
| 355 |
+
dui
|
| 356 |
+
### 9.2 Temporal token padding
|
| 357 |
+
|
| 358 |
+
After temporal resampling:
|
| 359 |
+
|
| 360 |
+
- each sample has its own `T_s_i = round(duration_i * 2.5)`
|
| 361 |
+
- the batch is padded to `T_s_max`
|
| 362 |
+
- the model returns:
|
| 363 |
+
|
| 364 |
+
```text
|
| 365 |
+
temporal_padding_mask: [B, T_s_max]
|
| 366 |
+
target_num_steps: [B]
|
| 367 |
+
```
|
| 368 |
+
|
| 369 |
+
All temporal supervision, matching, and loss computation respect these lengths.
|
| 370 |
+
|
| 371 |
+
### 9.3 Long clip truncation
|
| 372 |
+
|
| 373 |
+
The current training presets cap clip duration at:
|
| 374 |
+
|
| 375 |
+
- `20.0 seconds`
|
| 376 |
+
|
| 377 |
+
The dataset applies cropping before batching.
|
| 378 |
+
|
| 379 |
+
Preset crop policy:
|
| 380 |
+
|
| 381 |
+
- `crop_mode = "start"`
|
| 382 |
+
|
| 383 |
+
This means:
|
| 384 |
+
|
| 385 |
+
- clips longer than 20 seconds are truncated from the beginning
|
| 386 |
+
- training and validation follow the same deterministic sequence policy
|
| 387 |
+
|
| 388 |
+
If needed later, the dataset also supports:
|
| 389 |
+
|
| 390 |
+
- `random`
|
| 391 |
+
- `center`
|
| 392 |
+
- `none`
|
| 393 |
+
|
| 394 |
+
## 10. Matching And Losses
|
| 395 |
+
|
| 396 |
+
### 10.1 Weak temporal supervision
|
| 397 |
+
|
| 398 |
+
The model uses weak source windows:
|
| 399 |
+
|
| 400 |
+
- each source provides `start_time_seconds` and `end_time_seconds`
|
| 401 |
+
- these define a valid supervision window, not guaranteed frame-level activity
|
| 402 |
+
|
| 403 |
+
The loss code first converts source windows into a time-window mask:
|
| 404 |
+
|
| 405 |
+
```text
|
| 406 |
+
window_mask: [B, N_gt, T_s_max]
|
| 407 |
+
```
|
| 408 |
+
|
| 409 |
+
### 10.2 Per-step fixed-slot matching
|
| 410 |
+
|
| 411 |
+
Matching is performed per time step:
|
| 412 |
+
|
| 413 |
+
- only on valid temporal positions
|
| 414 |
+
- only within each source's weak time window
|
| 415 |
+
- between active GT sources and the `K=4` slot predictions
|
| 416 |
+
|
| 417 |
+
The current matcher uses a detached cost built from:
|
| 418 |
+
|
| 419 |
+
- activity
|
| 420 |
+
- class
|
| 421 |
+
- azimuth
|
| 422 |
+
- elevation
|
| 423 |
+
- distance
|
| 424 |
+
|
| 425 |
+
The output contains the assigned GT target for each valid slot-time pair.
|
| 426 |
+
|
| 427 |
+
### 10.3 Multi-task loss terms
|
| 428 |
+
|
| 429 |
+
Current loss terms are:
|
| 430 |
+
|
| 431 |
+
- `loss_activity`
|
| 432 |
+
- `loss_azi`
|
| 433 |
+
- `loss_ele`
|
| 434 |
+
- `loss_dist`
|
| 435 |
+
- `loss_cls_aux`
|
| 436 |
+
- `loss_temp`
|
| 437 |
+
|
| 438 |
+
Their roles:
|
| 439 |
+
|
| 440 |
+
- `loss_activity`
|
| 441 |
+
- `BCEWithLogits` on slot activity
|
| 442 |
+
- computed over valid time steps
|
| 443 |
+
- `loss_azi`
|
| 444 |
+
- cross-entropy over 360 azimuth bins
|
| 445 |
+
- `loss_ele`
|
| 446 |
+
- cross-entropy over 180 elevation bins
|
| 447 |
+
- `loss_dist`
|
| 448 |
+
- `SmoothL1Loss` on continuous distance regression
|
| 449 |
+
- `loss_cls_aux`
|
| 450 |
+
- auxiliary source class cross-entropy
|
| 451 |
+
- `loss_temp`
|
| 452 |
+
- temporal smoothness regularization over valid consecutive steps
|
| 453 |
+
|
| 454 |
+
The total loss is the weighted sum defined in `SpatialLossConfig`.
|
| 455 |
+
|
| 456 |
+
## 11. Stage-1 Training Flow
|
| 457 |
+
|
| 458 |
+
The current training entry is stage-1 encoder-focused training.
|
| 459 |
+
|
| 460 |
+
High-level flow per step:
|
| 461 |
+
|
| 462 |
+
```text
|
| 463 |
+
batch
|
| 464 |
+
-> SpatialBEATs.forward()
|
| 465 |
+
-> match_fixed_slots()
|
| 466 |
+
-> compute_spatial_losses()
|
| 467 |
+
-> backward()
|
| 468 |
+
-> optimizer.step()
|
| 469 |
+
```
|
| 470 |
+
|
| 471 |
+
### 11.1 Trainable modules in stage 1
|
| 472 |
+
|
| 473 |
+
By default, stage 1 trains:
|
| 474 |
+
|
| 475 |
+
- `preprocessor`
|
| 476 |
+
- `patch_embedding`
|
| 477 |
+
- `frequency_pool`
|
| 478 |
+
- `temporal_resampler`
|
| 479 |
+
- `temporal_readout`
|
| 480 |
+
- `slot_readout`
|
| 481 |
+
- `prediction_heads`
|
| 482 |
+
|
| 483 |
+
It can also unfreeze the BEATs trunk.
|
| 484 |
+
|
| 485 |
+
The projector is kept frozen by default in stage 1.
|
| 486 |
+
|
| 487 |
+
### 11.2 Optimizer
|
| 488 |
+
|
| 489 |
+
The current trainer uses:
|
| 490 |
+
|
| 491 |
+
- `AdamW`
|
| 492 |
+
|
| 493 |
+
Default preset values:
|
| 494 |
+
|
| 495 |
+
- `batch_size = 4`
|
| 496 |
+
- `num_epochs = 20`
|
| 497 |
+
- `learning_rate = 1e-4`
|
| 498 |
+
- `weight_decay = 0.05`
|
| 499 |
+
|
| 500 |
+
## 12. Current Presets
|
| 501 |
+
|
| 502 |
+
### 12.1 `OV123_STAGE1_CFG`
|
| 503 |
+
|
| 504 |
+
Defined in:
|
| 505 |
+
|
| 506 |
+
- [spatial_beats_ov123_stage1_config.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats_ov123_stage1_config.py)
|
| 507 |
+
|
| 508 |
+
This preset is intended to train on:
|
| 509 |
+
|
| 510 |
+
- `ov1_foa.jsonl`
|
| 511 |
+
- `ov2_foa.jsonl`
|
| 512 |
+
- `ov3_foa.jsonl`
|
| 513 |
+
|
| 514 |
+
with split filtering:
|
| 515 |
+
|
| 516 |
+
- train: `("train",)`
|
| 517 |
+
- val: `("valid",)`
|
| 518 |
+
- test: `("test",)`
|
| 519 |
+
|
| 520 |
+
and clip truncation:
|
| 521 |
+
|
| 522 |
+
- `max_clip_duration_seconds = 20.0`
|
| 523 |
+
- `crop_mode = "start"`
|
| 524 |
+
|
| 525 |
+
### 12.2 `OV23_STAGE1_CFG`
|
| 526 |
+
|
| 527 |
+
This is the safer baseline preset using only:
|
| 528 |
+
|
| 529 |
+
- `ov2_foa.jsonl`
|
| 530 |
+
- `ov3_foa.jsonl`
|
| 531 |
+
|
| 532 |
+
It uses the same split and truncation policy.
|
| 533 |
+
|
| 534 |
+
### 12.3 Important note on `ov1`
|
| 535 |
+
|
| 536 |
+
The trainer is already written to use `split` filtering for `ov1`.
|
| 537 |
+
|
| 538 |
+
If the active `ov1` manifest at the configured path does not yet contain `split`, then:
|
| 539 |
+
|
| 540 |
+
- the `OV123` preset will not automatically include those samples in train, valid, or test
|
| 541 |
+
- the fix is simply to point the preset at the updated `ov1` manifest path
|
| 542 |
+
|
| 543 |
+
The code path itself already supports split-aware loading.
|
| 544 |
+
|
| 545 |
+
## 13. Current Runtime Status
|
| 546 |
+
|
| 547 |
+
The current implementation has already been checked on:
|
| 548 |
+
|
| 549 |
+
- a real FOA waveform file
|
| 550 |
+
- mixed-length real manifest samples
|
| 551 |
+
- full forward pass
|
| 552 |
+
- fixed-slot matching
|
| 553 |
+
- multi-task loss computation
|
| 554 |
+
- BEATs pretrained weight loading
|
| 555 |
+
|
| 556 |
+
The following paths are already operational:
|
| 557 |
+
|
| 558 |
+
- dataset parsing
|
| 559 |
+
- waveform batching
|
| 560 |
+
- mixed-length temporal masking
|
| 561 |
+
- model forward
|
| 562 |
+
- matching
|
| 563 |
+
- loss computation
|
| 564 |
+
- stage-1 optimization loop
|
| 565 |
+
|
| 566 |
+
## 14. Recommended Launch Pattern
|
| 567 |
+
|
| 568 |
+
Example usage:
|
| 569 |
+
|
| 570 |
+
```python
|
| 571 |
+
from spatial_beats_ov123_stage1_config import OV123_STAGE1_CFG
|
| 572 |
+
from train_spatial_beats import main
|
| 573 |
+
|
| 574 |
+
main(OV123_STAGE1_CFG)
|
| 575 |
+
```
|
| 576 |
+
|
| 577 |
+
If `ov1` still needs a different manifest path, update only:
|
| 578 |
+
|
| 579 |
+
```python
|
| 580 |
+
OV123_STAGE1_CFG.train_manifest_paths
|
| 581 |
+
OV123_STAGE1_CFG.val_manifest_paths
|
| 582 |
+
OV123_STAGE1_CFG.test_manifest_paths
|
| 583 |
+
```
|
| 584 |
+
|
| 585 |
+
or rebuild the config through:
|
| 586 |
+
|
| 587 |
+
```python
|
| 588 |
+
from train_spatial_beats import make_ov123_stage1_config
|
| 589 |
+
```
|
| 590 |
+
|
| 591 |
+
## 15. Summary
|
| 592 |
+
|
| 593 |
+
The current `Spatial-BEATs` implementation is a FOA-first BEATs-based spatial encoder with:
|
| 594 |
+
|
| 595 |
+
- Qwen-like low-level mel settings
|
| 596 |
+
- a 7-channel FOA front-end
|
| 597 |
+
- reused BEATs trunk
|
| 598 |
+
- fixed-rate `2.5 Hz` spatial token output
|
| 599 |
+
- fixed-slot supervision heads
|
| 600 |
+
- variable-length batching
|
| 601 |
+
- split-aware `ov1/ov2/ov3` training presets
|
| 602 |
+
|
| 603 |
+
The central training idea is:
|
| 604 |
+
|
| 605 |
+
- use explicit spatial supervision to shape the front-end and BEATs trunk
|
| 606 |
+
- keep the supervision head lightweight
|
| 607 |
+
- use encoder-side spatial embeddings as the final source of LLM spatial tokens
|
| 608 |
+
|
docs/v13_honest_postmortem.md
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# v13 系列实验 Postmortem — 诚实记录我的错误判断
|
| 2 |
+
|
| 3 |
+
> 日期:2026-05-02
|
| 4 |
+
> 目的:保留错误判断的证据,避免后续重蹈覆辙
|
| 5 |
+
> 作者说明:多次对实验状态做出错误诊断,本文记录这些错误及其根因
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## 1. 核心错误:没注意 v13_D 的 cls warmup 是 8 ep 不是 3 ep
|
| 10 |
+
|
| 11 |
+
### 事实
|
| 12 |
+
|
| 13 |
+
v13_D 配置:
|
| 14 |
+
```python
|
| 15 |
+
cfg.frame_spatial_loss_warmup_epochs = 8 # v12 是 3
|
| 16 |
+
cfg.num_epochs = 25
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
这意味着 **ep0-7 全是 cls-only 训练**(spatial lambda 被 warmup 机制压为 warmup_scale),ep8 才是 spatial loss 真正放开的第一个 epoch。
|
| 20 |
+
|
| 21 |
+
### v13_D 真实训练曲线
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
ep F20 o_cls azi 阶段
|
| 25 |
+
0 0.311 0.650 28.6° cls warmup, spatial loss 几乎为 0
|
| 26 |
+
1 0.340 0.768 25.5° ...
|
| 27 |
+
3 0.308 0.788 26.5° ...
|
| 28 |
+
7 0.193 0.786 31.0° spatial 还没 ramp,azi 越走越远
|
| 29 |
+
8 0.397 0.876 18.5° ★ spatial loss 真正放开,F 跳涨
|
| 30 |
+
9 0.402 0.868 17.3°
|
| 31 |
+
10 0.402 0.864 17.2° (best so far)
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
### 我的错误判断(两次对话前)
|
| 35 |
+
|
| 36 |
+
看到 ep0-7 的数据,我写了:
|
| 37 |
+
|
| 38 |
+
> v13_D:ep1 就到 best,之后一路发散
|
| 39 |
+
|
| 40 |
+
> 这是完全不同的问题,不是过拟合 — 是 top-k rank loss 本身没用对。
|
| 41 |
+
|
| 42 |
+
> Top-K rank loss + EMA + resume optimizer + cosine LR 这四个改动叠加起来打架了
|
| 43 |
+
|
| 44 |
+
**完全错误**。真相:
|
| 45 |
+
- ep1 的 "best" 只是 cls warmup 期间的 F20 虚高值,不代表 spatial 性能
|
| 46 |
+
- F 从 ep1 → ep7 下降,是因为 spatial loss 被压住但 trunk 在学 class(trunk 权重变化 → 没有 spatial 监督 → azi 漂移)
|
| 47 |
+
- ep8 spatial loss 放开后 F 立刻 **0.19 → 0.40**(+107%)
|
| 48 |
+
- **Top-K rank、EMA、cosine LR、resume optimizer 全部工作正常**
|
| 49 |
+
|
| 50 |
+
### 我为什么犯这个错
|
| 51 |
+
|
| 52 |
+
- 用 v12 的 3 ep warmup 经验直觉套 v13_D 的 8 ep warmup
|
| 53 |
+
- 没有先拉长查看 ep0-10 的完整曲线,只看前几个 epoch 就下诊断
|
| 54 |
+
- "F 在下降" 这个表面现象让我急于给出解释,没有对照 **ramp schedule**
|
| 55 |
+
|
| 56 |
+
## 2. 次要错误:误判了 v13_B 的状态
|
| 57 |
+
|
| 58 |
+
v13_B 实际跑到 ep4 停下:
|
| 59 |
+
|
| 60 |
+
```
|
| 61 |
+
ep F20 o_cls
|
| 62 |
+
0 0.255 0.647
|
| 63 |
+
1 0.292 0.771
|
| 64 |
+
2 0.298 0.779
|
| 65 |
+
3 0.357 0.776 ← spatial 放开第一个 epoch(warmup=3)
|
| 66 |
+
4 0.356 0.775
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
v13_B 的 warmup 是 3(与 v12 一致),所以 ep3 才是 spatial 放开第一个 epoch。ep3-4 刚放开 2 个 epoch,F 还没上涨空间。
|
| 70 |
+
|
| 71 |
+
### 我的错误判断
|
| 72 |
+
|
| 73 |
+
> v13_B 的 ASL + soft-F1 在这个数据规模下不足以扭转 precision/recall 权衡
|
| 74 |
+
|
| 75 |
+
**错**:ep4 就停了训练,根本没给 ASL + soft-F1 时间学习。如果和 v12 同样跑 15 epoch,F 可能到 0.38-0.40 也说不定。
|
| 76 |
+
|
| 77 |
+
### 更正
|
| 78 |
+
|
| 79 |
+
- v13_B 的结论应该是 **"跑不完整,不能下结论"**,而不是 "设计失败"
|
| 80 |
+
- 如果用户还想看 v13_B 的真实效果,应该重启并跑满 15+ epoch
|
| 81 |
+
|
| 82 |
+
## 3. v13_C 的判断基本正确,但也需校准
|
| 83 |
+
|
| 84 |
+
v13_C 的 spatial warmup 也是 3 ep(继承 v12)。跑满 15 epoch:
|
| 85 |
+
|
| 86 |
+
```
|
| 87 |
+
ep F20 val_loss
|
| 88 |
+
3 0.385 2.67 spatial 放开第一个 epoch,F 就跳到 0.385
|
| 89 |
+
7 0.387 3.40
|
| 90 |
+
10 0.385 3.38 F 平了
|
| 91 |
+
15 0.385 3.60 val_loss 持续上升
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
v13_C **确实** 从 ep3 就到 0.385,之后 12 个 epoch 没再涨,这是真的 overfitting(val_loss 从 2.67 → 3.60)。
|
| 95 |
+
|
| 96 |
+
但我的归因可能也有偏差:
|
| 97 |
+
- 我说 "real replication 6× 是失败配方"
|
| 98 |
+
- 但实际上 v13_C 是 C-1(real 6×)+ C-2(refinement)+ C-3(V3 adapter)+ C-4(log-dist) 四个改动同时叠加,不能武断归罪 C-1
|
| 99 |
+
- 正确的做法:做 ablation,只开 C-2,或只开 C-1,分别看
|
| 100 |
+
|
| 101 |
+
## 4. v13_E 的设计基于错误诊断
|
| 102 |
+
|
| 103 |
+
v13_E 的目标写的是 "F 0.40-0.43,基于 v13_D 崩溃的教训"。
|
| 104 |
+
但 v13_D 其实没崩,F 已到 0.402 @ ep10,还在涨。
|
| 105 |
+
|
| 106 |
+
**v13_E 的实际价值**:
|
| 107 |
+
- 它开启 num_active head 训练 + SELD evaluator 的 top-K̂ gate —— v13_D 没做的
|
| 108 |
+
- 作为 v13_D 之后的 **扩展实验**(v13_F?),不是替代
|
| 109 |
+
|
| 110 |
+
## 5. 学到的教训
|
| 111 |
+
|
| 112 |
+
### 教训 1:看 full trajectory,不看 prefix
|
| 113 |
+
以后评估实验状态,必须等到 **spatial ramp 结束** 且 **至少 5 个 epoch 的 spatial 阶段数据**。看 ep0-7 的 cls warmup 数据就下判断是严重错误。
|
| 114 |
+
|
| 115 |
+
### 教训 2:warmup schedule 是不同实验的关键差异
|
| 116 |
+
v12: warmup=3, v13_B/C: warmup=3, v13_D: warmup=8。相同 epoch number 对应的训练阶段完全不同。画图时应标注 "spatial_enabled_epoch" 作为基准点对齐。
|
| 117 |
+
|
| 118 |
+
### 教训 3:诊断要基于 pipeline 理解
|
| 119 |
+
我多次说 "Top-K rank 和 Hungarian 对着干"、"ASL 和 gate 互相抵消",但这些都是**基于少量 ep 数据的事后解释**,不是真正的机制推导。下次先问:"这组数据在 pipeline 的哪个阶段?"
|
| 120 |
+
|
| 121 |
+
### 教训 4:不要急于写新实验替代旧的
|
| 122 |
+
v13_E 本来不需要。v13_D 只要让它跑完 25 epoch 就够了。写 v13_E 的动机是 "v13_D 崩了所以要抢救",这个前提本身就错。
|
| 123 |
+
|
| 124 |
+
## 6. 接下来该怎么办
|
| 125 |
+
|
| 126 |
+
### 立即行动
|
| 127 |
+
- **让 v13_D 继续跑到 25 epoch**,不要停
|
| 128 |
+
- **v13_B 不要急着下结论**。如果有算力,重启跑满 15 epoch;没有就标记为 "incomplete"
|
| 129 |
+
- **v13_C 的结论保留**(确实 overfit),但不能归罪单一改动
|
| 130 |
+
|
| 131 |
+
### v13_D 最终预期(基于 v12 曲线外推)
|
| 132 |
+
|
| 133 |
+
```
|
| 134 |
+
v12 从 spatial 放开后曲线:
|
| 135 |
+
ep3: 0.353 (刚放开)
|
| 136 |
+
ep12: 0.378 (best, +0.025 / 9 ep)
|
| 137 |
+
|
| 138 |
+
v13_D 类比:
|
| 139 |
+
ep10: 0.402 (刚放开 ramp 结束)
|
| 140 |
+
ep19 (+9 ep): ~0.425 (保守估计)
|
| 141 |
+
ep22 (+12 ep): ~0.43~0.46 (乐观)
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
EMA + cosine LR 可能再给 +0.005~0.01。最终 v13_D 预期 **0.43 ~ 0.46**。
|
| 145 |
+
|
| 146 |
+
### v13_E 的定位调整
|
| 147 |
+
|
| 148 |
+
从"替代 v13_D" 改为 "v13_D 之后的扩展实验":
|
| 149 |
+
- 先等 v13_D 跑完
|
| 150 |
+
- 用 v13_D best.pt 作为 hot-start
|
| 151 |
+
- 在其上启用 num_active head + top-K̂ gate 看是否再涨
|
| 152 |
+
- 如果涨了 → v13_F 路径
|
| 153 |
+
- 如果不涨 → num_active 在这个任务上意义不大
|
| 154 |
+
|
| 155 |
+
代码和 run 脚本都已落地,随时可用,不影响 v13_D 的实验。
|
| 156 |
+
|
| 157 |
+
## 7. 对用户的道歉
|
| 158 |
+
|
| 159 |
+
我多次给出过于自信的错误判断:
|
| 160 |
+
- "v13_C F 卡在 0.38 不再上升" —— 其实那是 overfitting 没错,但原因归咎 C-1 过于武断
|
| 161 |
+
- "v13_D 不收敛" —— 完全错误
|
| 162 |
+
- "Top-K rank loss 本身没用对" —— 没有证据
|
| 163 |
+
- "改 activity loss 的改动都失败了" —— 基于错误数据的推断
|
| 164 |
+
|
| 165 |
+
应该做但没做的事:
|
| 166 |
+
- 应该先看 full trajectory 再诊断
|
| 167 |
+
- 应该注意 warmup schedule 差异
|
| 168 |
+
- 应该说 "ep0-7 是 cls warmup 期间的数据,不能代表最终性能"
|
| 169 |
+
|
| 170 |
+
**今后对策**:评估实验前先读该实验的 preset 代码,看 schedule 是什么,再去解读数据。不基于前几个 epoch 就给"崩了/不收敛"的结论。
|
docs/v13_spatial_beats_design.md
ADDED
|
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# v13_B + v13_C 设计文档
|
| 2 |
+
|
| 3 |
+
> 日期:2026-05-01
|
| 4 |
+
> 作者:Claude + user
|
| 5 |
+
> 目标:在 v12(F20=0.378)基础上,分两个正交实验把 F20 推到 0.45+,合并实验 v14 目标 0.55~0.62
|
| 6 |
+
> 设计原则:**所有改动增量式**,通过 cfg flag 控制,默认全关 → 不破坏任何现有实验(v7/v11/v12 全可正常复现)
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
## 0. 背景与瓶颈诊断(摘自 v12 per-subset 分析)
|
| 11 |
+
|
| 12 |
+
v12 best.pt (F20=0.3779 聚合) 按子集拆分:
|
| 13 |
+
|
| 14 |
+
```
|
| 15 |
+
子集 N F20 ER20 LE_CD LR_CD o_cls o_azi aP aR
|
| 16 |
+
ov1_sim 4800 0.386 0.686 26.8° 0.640 0.796 25.5° 0.950 0.046
|
| 17 |
+
ov2_sim 1718 0.299 0.926 29.6° 0.546 0.653 30.0° 0.916 0.151
|
| 18 |
+
ov3_sim 1612 0.270 0.917 30.5° 0.499 0.599 31.8° 0.888 0.165
|
| 19 |
+
ov1_real 3374 0.140 0.924 117.7° 0.232 0.809 27.4° 0.767 0.144
|
| 20 |
+
ov2_real 2230 0.098 0.866 121.1° 0.198 0.747 38.3° 0.655 0.227
|
| 21 |
+
ov3_real 740 0.052 0.766 146.8° 0.125 0.624 51.5° 0.612 0.232
|
| 22 |
+
dcase_starss 4560 0.071 1.171 130.3° 0.185 0.698 36.0° 0.625 0.176
|
| 23 |
+
unified 35021 (~0.40, 推算)
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
**三大诊断**:
|
| 27 |
+
|
| 28 |
+
1. **瓶颈是 activity_recall (0.13),不是 cls 也不是 spatial。** oracle_class_acc 在 real 上仍有 0.80,oracle_azi_mae_deg 27-51°,说明表征是健康的,被 activity gate 挡住了。
|
| 29 |
+
2. **Real/Sim gap 极大**:同 ov 级别下 real F20 比 sim 低 64~81%。原因是 train 里 dcase_real 只占 6%。
|
| 30 |
+
3. **Overlap 惩罚显著**:sim ov3 F20 比 ov1 低 30%(0.27 vs 0.39),K=4 track head 在 overlap 下 slot 分配混乱。
|
| 31 |
+
|
| 32 |
+
---
|
| 33 |
+
|
| 34 |
+
## 1. 实验切分:B 和 C 是正交维度
|
| 35 |
+
|
| 36 |
+
| 维度 | v13_B | v13_C |
|
| 37 |
+
|---|---|---|
|
| 38 |
+
| 改 loss / head 输出接口 | ✅ | ❌(沿用 v12 loss) |
|
| 39 |
+
| 改数据比例 / augment | 仅 augment | ✅ replication |
|
| 40 |
+
| 改主干架构 / 容量 | ❌ | ✅ refinement + adapter V3 |
|
| 41 |
+
| 热启动 | v12 best.pt (strict=False) | v12 best.pt (strict=False) |
|
| 42 |
+
| 能解的子集 | sim + real 的 activity 瓶颈 | ov2/ov3 overlap + real gap |
|
| 43 |
+
|
| 44 |
+
理由:让 B 的收益和 C 的收益互不污染,便于 ablation。v14 = B + C 合并。
|
| 45 |
+
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
## 2. 共同约束(关键:不破坏现有实验)
|
| 49 |
+
|
| 50 |
+
所有 B/C 改动都要遵守:
|
| 51 |
+
|
| 52 |
+
### 2.1 Cfg flag 默认全关
|
| 53 |
+
|
| 54 |
+
每条新改动对应一个 `SpatialBEATsConfig` / `SpatialLossConfig` / `TrainSpatialBEATsConfig` 字段,**默认值就是"这条改动不启用"**。
|
| 55 |
+
|
| 56 |
+
```python
|
| 57 |
+
# spatial_modules 侧
|
| 58 |
+
self.use_class_activity_bias: bool = False # [B-1]
|
| 59 |
+
self.use_class_conditional_gate: bool = False # [B-3]
|
| 60 |
+
self.use_track_refinement: bool = False # [C-2]
|
| 61 |
+
self.track_refinement_layers: int = 2
|
| 62 |
+
self.patch_adapter_version: str = "v1" # "v1"/"v2"/"v3" [C-3] 加 v3
|
| 63 |
+
self.use_log_distance_head: bool = False # [C-4]
|
| 64 |
+
|
| 65 |
+
# spatial_loss 侧
|
| 66 |
+
self.activity_loss_type: str = "bce" # "bce"/"asymmetric" [B-2]
|
| 67 |
+
self.asl_gamma_neg: float = 4.0
|
| 68 |
+
self.asl_gamma_pos: float = 0.0
|
| 69 |
+
self.asl_probability_margin: float = 0.05
|
| 70 |
+
self.soft_f1_weight: float = 0.0 # 0 = disabled [B-4]
|
| 71 |
+
self.distance_loss_type: str = "l1" # "l1"/"laplace_nll" [C-4]
|
| 72 |
+
|
| 73 |
+
# dataset 侧
|
| 74 |
+
self.use_spec_augment: bool = False # [B-5]
|
| 75 |
+
self.spec_augment_time_mask_ratio: float = 0.0
|
| 76 |
+
self.spec_augment_freq_mask_ratio: float = 0.0
|
| 77 |
+
self.random_gain_db: float = 0.0
|
| 78 |
+
self.channel_dropout_prob: float = 0.0
|
| 79 |
+
self.lowpass_sim_real_prob: float = 0.0
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
`ov1_unified_v12` preset 和之前所有 preset 都**不动**,因为它们没设置这些 flag → 新代码分支不走。
|
| 83 |
+
|
| 84 |
+
### 2.2 热启动安全
|
| 85 |
+
|
| 86 |
+
新加的 `nn.Module` / `nn.Parameter` 全部**零初始化**或**identity 等价初始化**:
|
| 87 |
+
|
| 88 |
+
- `class_activity_bias`:`torch.zeros(num_classes)` → logit 不变
|
| 89 |
+
- `GatingMLP` 最后一层 bias/weight zero-init → gate_logit = 0
|
| 90 |
+
- `TrackRefinementDecoder`:`layer_scale = zeros(num_layers)` → 残差 x + 0 = x
|
| 91 |
+
- `log_distance_head`:bias 初始化为 `log(mean_distance_v12) ≈ log(1.5)`
|
| 92 |
+
|
| 93 |
+
从 v12 best.pt `strict=False` 加载时:
|
| 94 |
+
- 不存在的 key → 走零初始化
|
| 95 |
+
- 存在的 key(所有 v12 组件)→ 正常加载
|
| 96 |
+
- ep0 forward 输出应与 v12 best.pt 完全一致(或数值上差 < 1e-5)
|
| 97 |
+
|
| 98 |
+
### 2.3 完全向后兼容
|
| 99 |
+
|
| 100 |
+
任何旧脚本跑起来(比如 `run_ov1_unified_v12.sh`)**不需要改一行**,因为所有新字段都有默认值且默认 disabled。
|
| 101 |
+
|
| 102 |
+
### 2.4 Preset 命名
|
| 103 |
+
|
| 104 |
+
- `ov1_unified_v13b`:启用所有 B 相关 flag
|
| 105 |
+
- `ov1_unified_v13c`:启用所有 C 相关 flag
|
| 106 |
+
- `ov1_unified_v14`:B + C 全开,热启动 max(v13b.best, v13c.best)
|
| 107 |
+
|
| 108 |
+
---
|
| 109 |
+
|
| 110 |
+
## 3. v13_B 详细设计:Loss + Decision 全面重写
|
| 111 |
+
|
| 112 |
+
### [B-1] Per-class learnable logit bias
|
| 113 |
+
|
| 114 |
+
**为什么**:全局阈值 0.5 对所有类别一视同仁不合理。稀有类(jackhammer)和常见类(singing)的 activity 先验差异巨大,应当让模型自己学各类的 logit bias(等价于 per-class threshold)。
|
| 115 |
+
|
| 116 |
+
**实现点**:
|
| 117 |
+
- 文件:`spatial_modules.py`
|
| 118 |
+
- 类:`FrameTrackPredictionHeads`(定位:`class_logits` 和 `activity_logits` 的出口处)
|
| 119 |
+
- 新增 parameter:`self.class_activity_bias = nn.Parameter(torch.zeros(num_classes))`
|
| 120 |
+
- 新增 buffer:`self.use_class_activity_bias: bool`
|
| 121 |
+
- Forward 改动:
|
| 122 |
+
|
| 123 |
+
```python
|
| 124 |
+
# 原:
|
| 125 |
+
activity_logit = self.activity_head(token) # [B, K, T, 1]
|
| 126 |
+
|
| 127 |
+
# 新:
|
| 128 |
+
activity_logit_raw = self.activity_head(token) # [B, K, T, 1]
|
| 129 |
+
if self.use_class_activity_bias:
|
| 130 |
+
class_probs = F.softmax(class_logits, dim=-1) # [B, K, T, C]
|
| 131 |
+
# 用 class_probs 作为加权软分配(避免 argmax 阻断梯度)
|
| 132 |
+
expected_bias = torch.einsum('bktc,c->bkt', class_probs, self.class_activity_bias)
|
| 133 |
+
activity_logit = activity_logit_raw + expected_bias.unsqueeze(-1)
|
| 134 |
+
else:
|
| 135 |
+
activity_logit = activity_logit_raw
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
**训练/推理一致性**:bias 在训练的 BCE loss 和推理的 sigmoid 里都是**同一个量**,因此不需要 threshold sweep。推理时 threshold 始终 = 0.0(logit 空间)或 0.5(prob 空间),完全等价。
|
| 139 |
+
|
| 140 |
+
**参数量**:63 个标量,忽略不计。
|
| 141 |
+
|
| 142 |
+
### [B-2] Asymmetric Loss 替换 BCE
|
| 143 |
+
|
| 144 |
+
**为什么**:BCE 把 FN 和 FP 等权重。当前 activity_recall=0.13 说明 FN 惩罚严重不够。ASL 对 easy negatives 用 `(1-p)^γ-` 下压,对 positives 用弱 `γ+=0`,正负不均衡下表现显著好于 BCE。
|
| 145 |
+
|
| 146 |
+
**实现点**:
|
| 147 |
+
- 文件:`spatial_loss.py`
|
| 148 |
+
- 新增函数:`asymmetric_loss_with_logits(logits, targets, gamma_neg=4, gamma_pos=0, margin=0.05)`
|
| 149 |
+
- 在 `compute_frame_track_losses`(或同名函数)里根据 `config.activity_loss_type` 分支:
|
| 150 |
+
|
| 151 |
+
```python
|
| 152 |
+
if config.activity_loss_type == "asymmetric":
|
| 153 |
+
loss_act = asymmetric_loss_with_logits(
|
| 154 |
+
activity_logit, target_active,
|
| 155 |
+
gamma_neg=config.asl_gamma_neg,
|
| 156 |
+
gamma_pos=config.asl_gamma_pos,
|
| 157 |
+
margin=config.asl_probability_margin,
|
| 158 |
+
)
|
| 159 |
+
else: # "bce"
|
| 160 |
+
loss_act = F.binary_cross_entropy_with_logits(activity_logit, target_active, ...)
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
**数学**:
|
| 164 |
+
```
|
| 165 |
+
p = sigmoid(logit)
|
| 166 |
+
positive: -( (1-p)**γ+ ) * log(p)
|
| 167 |
+
negative: p_shifted = max(p - m, 0)
|
| 168 |
+
-( p_shifted**γ- ) * log(1 - p_shifted)
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
**参数**:`γ+ = 0`, `γ- = 4`, `margin = 0.05`(ASL paper 推荐起点)
|
| 172 |
+
|
| 173 |
+
### [B-3] Class-conditional gating MLP
|
| 174 |
+
|
| 175 |
+
**为什么**:activity 当前只看 token embedding。应该让 activity 也依赖 class/DOA 的确信度 —— class softmax 尖锐、DOA 稳定时更大胆判 active。
|
| 176 |
+
|
| 177 |
+
**实现点**:
|
| 178 |
+
- 文件:`spatial_modules.py`
|
| 179 |
+
- 新增类:`ClassConditionalGate(embed_dim, num_classes, hidden_dim=128)`
|
| 180 |
+
- 输入:`fused_token [B, K, T, D]`, `class_logits [B, K, T, C]`, `pred_dir [B, K, T, 3]`
|
| 181 |
+
- 融合:`gate_input = concat(token, class_emb_avg, dir_vec)` → MLP → `gate_logit [B, K, T, 1]`
|
| 182 |
+
- class_emb 用 `class_logits.softmax()` 加权的 class embedding(新增 `nn.Embedding(C, 32)`)
|
| 183 |
+
- 在 FrameTrackPredictionHeads 里:
|
| 184 |
+
|
| 185 |
+
```python
|
| 186 |
+
if self.use_class_conditional_gate:
|
| 187 |
+
gate_logit = self.class_conditional_gate(token, class_logits, pred_dir)
|
| 188 |
+
activity_logit = activity_logit + self.gate_scale * gate_logit
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
**初始化**:MLP 最后一层 `weight=zero, bias=zero` → gate_logit = 0 → ep0 等价 v12。
|
| 192 |
+
|
| 193 |
+
**参数量**:~80K。
|
| 194 |
+
|
| 195 |
+
### [B-4] Soft-F1 auxiliary loss
|
| 196 |
+
|
| 197 |
+
**为什么**:BCE/ASL 仍是 per-sample 损失,优化目标和 macro-F20 评测有 gap。Soft-F1 直接按类聚合,和 DCASE 评估同构。
|
| 198 |
+
|
| 199 |
+
**实现点**:
|
| 200 |
+
- 文件:`spatial_loss.py`
|
| 201 |
+
- 新增函数:`soft_macro_f1_loss(activity_logits, class_logits, target_active, target_class)`
|
| 202 |
+
- 对每个类 `c`:
|
| 203 |
+
- `p_c = sigmoid(act_logit) * softmax(class)[c]` (class-c 的软 activity)
|
| 204 |
+
- `y_c = (target_active and target_class==c)`
|
| 205 |
+
- `tp_c = sum(p_c * y_c)`, `fp_c = sum(p_c * (1-y_c))`, `fn_c = sum((1-p_c) * y_c)`
|
| 206 |
+
- `f1_c = 2 tp_c / (2 tp_c + fp_c + fn_c + eps)`
|
| 207 |
+
- `loss = 1 - mean(f1_c)`
|
| 208 |
+
- 在总 loss 里:
|
| 209 |
+
|
| 210 |
+
```python
|
| 211 |
+
if config.soft_f1_weight > 0:
|
| 212 |
+
total_loss = total_loss + config.soft_f1_weight * soft_macro_f1_loss(...)
|
| 213 |
+
```
|
| 214 |
+
|
| 215 |
+
**warmup**(已确认采用):前 3 ep `soft_f1_weight=0.1`,第 3 ep 起硬切到 `0.3`。
|
| 216 |
+
|
| 217 |
+
实现方式:在 `train_spatial_beats.py` 的 epoch 循环里根据 `epoch >= soft_f1_warmup_epochs` 动态设置 `train_cfg.loss.soft_f1_weight`,新增 config 字段:
|
| 218 |
+
```python
|
| 219 |
+
cfg.loss.soft_f1_weight_warmup: float = 0.1 # ep < warmup_epochs 时使用
|
| 220 |
+
cfg.loss.soft_f1_weight: float = 0.3 # ep >= warmup_epochs
|
| 221 |
+
cfg.loss.soft_f1_warmup_epochs: int = 3
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
### [B-5] Real-distribution augment
|
| 225 |
+
|
| 226 |
+
**为什么**:sim_static 混响干净,模型学到的 activity 判据在低 SNR 下崩溃。augment 让模型见到各种"污染"的 spec,对 real 数据更鲁棒。
|
| 227 |
+
|
| 228 |
+
**实现点**:
|
| 229 |
+
- 文件:`spatial_dataset.py`
|
| 230 |
+
- 在 `SpatialDataset.__getitem__` 或 collate 里加 augment pipeline
|
| 231 |
+
- 只在训练集(`split='train'`)启用,valid/test 不启用
|
| 232 |
+
- 新增 config flag:
|
| 233 |
+
- `use_spec_augment`(默认 False)
|
| 234 |
+
- `spec_augment_time_mask_ratio`(0.2 = 20% time 长度)
|
| 235 |
+
- `spec_augment_freq_mask_ratio`(0.15)
|
| 236 |
+
- `random_gain_db`(±8)
|
| 237 |
+
- `channel_dropout_prob`(0.1)
|
| 238 |
+
- `lowpass_sim_real_prob`(0.1,cutoff ∈ U[4000, 8000] Hz)
|
| 239 |
+
|
| 240 |
+
**顺序**:waveform-level augment(gain, lowpass, channel_dropout)→ feature-level augment(SpecAugment)。
|
| 241 |
+
|
| 242 |
+
**重要**:augment 只作用在 FOA 4 通道 waveform / delta feature 上,**target labels 不变**。
|
| 243 |
+
|
| 244 |
+
### B 实验 preset: `make_ov1_unified_v13b_config`
|
| 245 |
+
|
| 246 |
+
热启动:`v12_best.pt` (strict=False)
|
| 247 |
+
开关:
|
| 248 |
+
```python
|
| 249 |
+
cfg.model.use_class_activity_bias = True # [B-1]
|
| 250 |
+
cfg.model.use_class_conditional_gate = True # [B-3]
|
| 251 |
+
cfg.loss.activity_loss_type = "asymmetric" # [B-2]
|
| 252 |
+
cfg.loss.asl_gamma_neg = 4.0
|
| 253 |
+
cfg.loss.asl_probability_margin = 0.05
|
| 254 |
+
cfg.loss.soft_f1_weight = 0.3 # [B-4]
|
| 255 |
+
cfg.dataset.use_spec_augment = True # [B-5]
|
| 256 |
+
cfg.dataset.spec_augment_time_mask_ratio = 0.2
|
| 257 |
+
cfg.dataset.spec_augment_freq_mask_ratio = 0.15
|
| 258 |
+
cfg.dataset.random_gain_db = 8.0
|
| 259 |
+
cfg.dataset.channel_dropout_prob = 0.1
|
| 260 |
+
cfg.dataset.lowpass_sim_real_prob = 0.1
|
| 261 |
+
cfg.learning_rate = 1e-5
|
| 262 |
+
cfg.num_epochs = 15
|
| 263 |
+
cfg.output_dir = "checkpoints/spatial_beats_ov1_unified_v13b_exp/03_ov123_top4"
|
| 264 |
+
```
|
| 265 |
+
|
| 266 |
+
**数据 manifest 完全复用 v12**:unified train/valid + old ov1/2/3 sim/real + dcase_starss 作为 val。
|
| 267 |
+
|
| 268 |
+
---
|
| 269 |
+
|
| 270 |
+
## 4. v13_C 详细设计:Data + Architecture 全面重写
|
| 271 |
+
|
| 272 |
+
### [C-1] Real data upsampling (replication)
|
| 273 |
+
|
| 274 |
+
**为什么**:real (dcase_real) 在 train 里占 6%,梯度感受不到。DCASE 社区标准做法是 20-30% real。
|
| 275 |
+
|
| 276 |
+
**实现点**:
|
| 277 |
+
- 预处理脚本:`scripts/split_unified_train_by_source.py`
|
| 278 |
+
- 读 `unified_spatial_foa_fsd63_all/train.jsonl`
|
| 279 |
+
- 按 `data_source` 字段拆成三份:
|
| 280 |
+
- `train_sim_static.jsonl`
|
| 281 |
+
- `train_qa_sim.jsonl`
|
| 282 |
+
- `train_dcase_real.jsonl`
|
| 283 |
+
- 写到 `unified_spatial_foa_fsd63_all/` 同目录下
|
| 284 |
+
- Preset:
|
| 285 |
+
|
| 286 |
+
```python
|
| 287 |
+
cfg.train_manifest_paths = (
|
| 288 |
+
unified_root / "train_sim_static.jsonl",
|
| 289 |
+
unified_root / "train_qa_sim.jsonl",
|
| 290 |
+
unified_root / "train_dcase_real.jsonl",
|
| 291 |
+
)
|
| 292 |
+
cfg.train_manifest_replication = (1, 1, 6)
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
**影响估算**(基于 v12 已知分布):
|
| 296 |
+
- sim_static 304K × 1 = 304K
|
| 297 |
+
- qa_sim ~? × 1 = ~?
|
| 298 |
+
- dcase_real 20K × 6 = 120K
|
| 299 |
+
- real 占比从 6% → ~25% (取决于 qa_sim 规模)
|
| 300 |
+
|
| 301 |
+
**兼容性**:`train_manifest_replication` 机制在 `train_spatial_beats.py` 已经存在(v7j 用过),不需要新加框架代码。只改 preset。
|
| 302 |
+
|
| 303 |
+
### [C-2] Track-wise Refinement Transformer(2 layers)
|
| 304 |
+
|
| 305 |
+
**为什么**:K=4 track slots 之间互相不知道对方在干嘛,overlap 时同一源被多个 slot 抢,或同一 slot 被多个源抢。引入 self-attention 让 slot 互相"排斥"。
|
| 306 |
+
|
| 307 |
+
**实现点**:
|
| 308 |
+
- 文件:`spatial_modules.py`
|
| 309 |
+
- 新增类:
|
| 310 |
+
|
| 311 |
+
```python
|
| 312 |
+
class TrackRefinementDecoder(nn.Module):
|
| 313 |
+
def __init__(self, num_tracks=4, embed_dim=768, num_layers=2,
|
| 314 |
+
num_heads=8, dim_feedforward=2048, dropout=0.0):
|
| 315 |
+
self.track_queries = nn.Parameter(torch.randn(num_tracks, embed_dim) * 0.02)
|
| 316 |
+
self.layers = nn.ModuleList([
|
| 317 |
+
nn.TransformerDecoderLayer(
|
| 318 |
+
d_model=embed_dim, nhead=num_heads,
|
| 319 |
+
dim_feedforward=dim_feedforward, dropout=dropout,
|
| 320 |
+
activation='gelu', norm_first=True, batch_first=True,
|
| 321 |
+
) for _ in range(num_layers)
|
| 322 |
+
])
|
| 323 |
+
# Zero-init layer scale: ep0 refinement = identity
|
| 324 |
+
self.layer_scale = nn.Parameter(torch.zeros(num_layers))
|
| 325 |
+
|
| 326 |
+
def forward(self, memory):
|
| 327 |
+
# memory: fused_spatial_embeddings [B, T, D]
|
| 328 |
+
# 输出:refined track tokens [B, K, T, D]
|
| 329 |
+
B, T, D = memory.shape
|
| 330 |
+
K = self.track_queries.size(0)
|
| 331 |
+
# 复制 K queries 到时间维度:[B, K, T, D]
|
| 332 |
+
q = self.track_queries[None, :, None, :].expand(B, K, T, D)
|
| 333 |
+
# 每个时间步独立做 decoder
|
| 334 |
+
# 为简化,把 T 维 flatten 进 batch:[B*T, K, D] cross-attn with [B*T, 1, D]
|
| 335 |
+
q_flat = q.permute(0, 2, 1, 3).reshape(B * T, K, D)
|
| 336 |
+
mem_flat = memory.reshape(B * T, 1, D)
|
| 337 |
+
for i, layer in enumerate(self.layers):
|
| 338 |
+
out = layer(q_flat, mem_flat)
|
| 339 |
+
q_flat = q_flat + self.layer_scale[i] * (out - q_flat)
|
| 340 |
+
# reshape 回 [B, K, T, D]
|
| 341 |
+
refined = q_flat.reshape(B, T, K, D).permute(0, 2, 1, 3).contiguous()
|
| 342 |
+
return refined
|
| 343 |
+
```
|
| 344 |
+
|
| 345 |
+
- 在 `SpatialBEATs` 里:
|
| 346 |
+
|
| 347 |
+
```python
|
| 348 |
+
if cfg.use_track_refinement:
|
| 349 |
+
self.track_refinement = TrackRefinementDecoder(
|
| 350 |
+
num_tracks=cfg.num_tracks,
|
| 351 |
+
embed_dim=cfg.encoder_embed_dim,
|
| 352 |
+
num_layers=cfg.track_refinement_layers,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
# encode_patches 之后、送入 head 之前:
|
| 356 |
+
if self.track_refinement is not None:
|
| 357 |
+
track_tokens = self.track_refinement(encoder_memory) # [B, K, T, D]
|
| 358 |
+
# 传给 FrameTrackPredictionHeads 的输入从 [B, T, D] 改成 [B, K, T, D]
|
| 359 |
+
else:
|
| 360 |
+
track_tokens = None # head 沿用旧 expand 逻辑
|
| 361 |
+
```
|
| 362 |
+
|
| 363 |
+
- `FrameTrackPredictionHeads` 的 forward ��个 `track_tokens: Optional[Tensor]` 参数:
|
| 364 |
+
- 传入 None → 沿用现有的"[B,T,D] 复制到 K slots"
|
| 365 |
+
- 传入 `[B,K,T,D]` → 用 refined tokens 走 head
|
| 366 |
+
|
| 367 |
+
**参数量**:2 layer × (self_attn + cross_attn + FFN) ≈ 2 × 2M = 4M。
|
| 368 |
+
|
| 369 |
+
**Zero-init 校验**:`layer_scale = zeros(2)` + 残差公式 `q + scale * (out - q)` → ep0 输出 = `track_queries`(静态,和 memory 无关)。但这会丢掉时间信息 —— **修正**:改用 `q + scale * layer_out`,并且把 track_queries 初始化成 `memory` 投影:
|
| 370 |
+
|
| 371 |
+
实际更安全的等价初始化:
|
| 372 |
+
|
| 373 |
+
```python
|
| 374 |
+
# Zero-init 方案:layer 不改 query,query 本身先吸收 memory 信息
|
| 375 |
+
# 思路:在 refine 前先做一次 "identity fallback":如果 scale=0,输出 = memory 广播到 K
|
| 376 |
+
def forward(self, memory):
|
| 377 |
+
B, T, D = memory.shape
|
| 378 |
+
K = ...
|
| 379 |
+
# 初始 track_tokens = memory 广播到 K(+ 一个很小的 query 偏移)
|
| 380 |
+
track_tokens = memory[:, None, :, :].expand(B, K, T, D).contiguous()
|
| 381 |
+
track_tokens = track_tokens + 0.02 * self.track_queries[None, :, None, :]
|
| 382 |
+
# refine
|
| 383 |
+
for i, layer in enumerate(self.layers):
|
| 384 |
+
...
|
| 385 |
+
track_tokens = track_tokens + self.layer_scale[i] * delta
|
| 386 |
+
return track_tokens
|
| 387 |
+
```
|
| 388 |
+
|
| 389 |
+
这样 `layer_scale=0` 时 refinement 输出 ≈ `memory` 广播到 K,和 v12 "把 [B,T,D] 复制到 [B,K,T,D]" 等价。热启动安全。
|
| 390 |
+
|
| 391 |
+
### [C-3] Multi-scale Patch Adapter V3
|
| 392 |
+
|
| 393 |
+
**为什么**:v12 用的 V2 adapter 只看 3 个时间 bin(30 ms),抓不到房间冲激响应的 early reflection (50-150ms)。V3 加多尺度 + dilated conv。
|
| 394 |
+
|
| 395 |
+
**实现点**:
|
| 396 |
+
- 文件:`spatial_modules.py`
|
| 397 |
+
- 新增类:`SpatialDeltaPatchAdapterV3`
|
| 398 |
+
- 三路 branch:
|
| 399 |
+
- branch_3x3: `Conv2d(C, H, kernel=3, padding=1)` (同 V2)
|
| 400 |
+
- branch_5x5: `Conv2d(C, H, kernel=5, padding=2)` (中尺度)
|
| 401 |
+
- branch_dilated: `Conv2d(C, H, kernel=3, padding=2, dilation=2)` (长时)
|
| 402 |
+
- fuse: `torch.cat` along channel → `Conv2d(3H, H, kernel=1)`
|
| 403 |
+
- 接现有 V2 的 SE block + residual + patchify
|
| 404 |
+
- cfg:`patch_adapter_version: str = "v1"` 增加选项 `"v3"`
|
| 405 |
+
|
| 406 |
+
**参数量**:比 V2 多 ~1M。
|
| 407 |
+
|
| 408 |
+
### [C-4] Log-distance head + Laplace NLL loss
|
| 409 |
+
|
| 410 |
+
**为什么**:dist_mae=0.57 很差。距离分布长尾,log 后近似高斯。加 uncertainty 头允许模型对不确信的距离给大 variance,减少高 bias 样本的损失。
|
| 411 |
+
|
| 412 |
+
**实现点**:
|
| 413 |
+
- 文件:`spatial_modules.py`, 类 `FrameTrackPredictionHeads`
|
| 414 |
+
- 把现有 `distance_head: Linear(D, 1)` 升级为 `distance_head: Linear(D, 2)` 输出 `[log_dist, log_var]`
|
| 415 |
+
- 初始化:`bias[0] = log(1.5)`(v12 平均距离附近),`bias[1] = log(0.2^2)`(var=0.04 起点)
|
| 416 |
+
- cfg:`use_log_distance_head: bool = False`, `distance_loss_type: str = "l1" / "laplace_nll"`
|
| 417 |
+
|
| 418 |
+
- 文件:`spatial_loss.py`, Laplace NLL:
|
| 419 |
+
|
| 420 |
+
```python
|
| 421 |
+
def laplace_nll_loss(pred_log_dist, pred_log_var, target_dist, mask):
|
| 422 |
+
# target_dist > 0 的位置才算 loss
|
| 423 |
+
pred_dist = pred_log_dist.exp()
|
| 424 |
+
pred_b = (0.5 * pred_log_var).exp() # Laplace scale
|
| 425 |
+
nll = (target_dist - pred_dist).abs() / pred_b + pred_log_var * 0.5
|
| 426 |
+
return (nll * mask).sum() / mask.sum().clamp(min=1)
|
| 427 |
+
```
|
| 428 |
+
|
| 429 |
+
推理时 `pred_distance = exp(pred_log_dist)`。
|
| 430 |
+
|
| 431 |
+
**初期稳定性**(已确认):v13c 从第 0 ep 就切 Laplace NLL(不做 L1 warmup)。如果训练前期 loss_dist 不稳,再回来调。
|
| 432 |
+
|
| 433 |
+
### C 实验 preset: `make_ov1_unified_v13c_config`
|
| 434 |
+
|
| 435 |
+
热启动:`v12_best.pt` (strict=False)
|
| 436 |
+
开关:
|
| 437 |
+
```python
|
| 438 |
+
cfg.model.use_track_refinement = True # [C-2]
|
| 439 |
+
cfg.model.track_refinement_layers = 2
|
| 440 |
+
cfg.model.patch_adapter_version = "v3" # [C-3]
|
| 441 |
+
cfg.model.use_log_distance_head = True # [C-4]
|
| 442 |
+
cfg.loss.distance_loss_type = "laplace_nll" # [C-4]
|
| 443 |
+
|
| 444 |
+
cfg.train_manifest_paths = (sim_static, qa_sim, dcase_real) # [C-1]
|
| 445 |
+
cfg.train_manifest_replication = (1, 1, 6)
|
| 446 |
+
|
| 447 |
+
cfg.learning_rate = 1e-5
|
| 448 |
+
cfg.num_epochs = 20
|
| 449 |
+
cfg.output_dir = "checkpoints/spatial_beats_ov1_unified_v13c_exp/03_ov123_top4"
|
| 450 |
+
```
|
| 451 |
+
|
| 452 |
+
**loss 完全沿用 v12 默认**(BCE + L1 → Laplace),`soft_f1_weight=0`, `activity_loss_type="bce"`。
|
| 453 |
+
|
| 454 |
+
---
|
| 455 |
+
|
| 456 |
+
## 5. v14 合并实验(后续)
|
| 457 |
+
|
| 458 |
+
在 B 和 C 都验证有效(F20 > 0.42)后启动:
|
| 459 |
+
|
| 460 |
+
- 热启动:`max(v13b.best, v13c.best).pt`
|
| 461 |
+
- 所有 B 和 C 的 flag 全开
|
| 462 |
+
- LR = 5e-6(更保守,防止双改动发散)
|
| 463 |
+
- epochs = 20
|
| 464 |
+
|
| 465 |
+
预期 F20 = 0.55~0.62。
|
| 466 |
+
|
| 467 |
+
---
|
| 468 |
+
|
| 469 |
+
## 6. 预期结果 & 风险矩阵
|
| 470 |
+
|
| 471 |
+
### B 预期
|
| 472 |
+
- 聚合 F20: 0.378 → **0.45~0.52**
|
| 473 |
+
- sim ov1: 0.386 → 0.50~0.55
|
| 474 |
+
- real ov1: 0.140 → 0.22~0.28
|
| 475 |
+
- dcase: 0.071 → 0.15~0.20
|
| 476 |
+
- activity_recall: 0.13 → 0.40~0.55
|
| 477 |
+
|
| 478 |
+
### C 预期
|
| 479 |
+
- 聚合 F20: 0.378 → **0.44~0.50**
|
| 480 |
+
- ov3_sim: 0.270 → 0.38~0.42
|
| 481 |
+
- real ov1: 0.140 → 0.20~0.26
|
| 482 |
+
- dcase: 0.071 → 0.14~0.19
|
| 483 |
+
- dist_mae: 0.566 → 0.38~0.42
|
| 484 |
+
|
| 485 |
+
### 风险
|
| 486 |
+
|
| 487 |
+
| 风险 | 发生概率 | 兜底 |
|
| 488 |
+
|---|---|---|
|
| 489 |
+
| B-2 ASL γ- 太大发散 | 中 | 先 γ-=2 跑 1 ep 验证,再拉到 4 |
|
| 490 |
+
| B-3 gate 挡掉好样本 | 低 | gate_scale 从 0.5 改 0.2 重跑 |
|
| 491 |
+
| B-4 soft-F1 和 ASL 冲突 | 中 | soft_f1_weight 从 0.3 降到 0.1 |
|
| 492 |
+
| B-5 augment 太强 sim 掉点 | 中 | 减半 augment 比例重跑 |
|
| 493 |
+
| C-1 real 6× 导致 sim 掉点 | 中 | 降到 4× |
|
| 494 |
+
| C-2 refinement 不学 | 中 | 手动设 layer_scale warmup |
|
| 495 |
+
| C-3 多尺度显存爆 | 低 | 去掉 dilated branch |
|
| 496 |
+
| C-4 log-dist 初期不稳 | 中 | 前 3 ep 用 L1 再切 |
|
| 497 |
+
| v14 合并发散 | 高 | 降 LR 到 3e-6,freeze trunk 前半段 |
|
| 498 |
+
|
| 499 |
+
---
|
| 500 |
+
|
| 501 |
+
## 7. 落地文件清单
|
| 502 |
+
|
| 503 |
+
| 文件 | 改动类型 | B/C |
|
| 504 |
+
|---|---|---|
|
| 505 |
+
| `spatial_modules.py` | 新增 3 个类 + 现有类加 forward 分支 | B+C |
|
| 506 |
+
| `spatial_loss.py` | 新增 2 个 loss 函数 + config 分支 | B+C |
|
| 507 |
+
| `spatial_dataset.py` | 新增 augment 逻辑 + config 字段 | B |
|
| 508 |
+
| `spatial_beats.py` | config 字段 + 可选模块构造 + forward 分支 | B+C |
|
| 509 |
+
| `train_spatial_beats.py` | 新增 2 个 preset 工厂 + CLI dispatch + choices | B+C |
|
| 510 |
+
| `scripts/split_unified_train_by_source.py` | 新文件,预处理脚本 | C |
|
| 511 |
+
| `run_ov1_unified_v13b.sh` | 新文件 | B |
|
| 512 |
+
| `run_ov1_unified_v13c.sh` | 新文件 | C |
|
| 513 |
+
| `docs/v13_spatial_beats_design.md` | 本文档 | — |
|
| 514 |
+
|
| 515 |
+
所有现有文件的改动都是**新增分支**,不删除/修改任何现有逻辑。
|
| 516 |
+
|
| 517 |
+
---
|
| 518 |
+
|
| 519 |
+
## 8. 验证步骤
|
| 520 |
+
|
| 521 |
+
每完成一步,按顺序验证:
|
| 522 |
+
|
| 523 |
+
1. **语法检查**:`python -c "import ast; ast.parse(open(path).read())"`
|
| 524 |
+
2. **旧 preset 回归**:`python train_spatial_beats.py --preset ov1_unified_v12 --dry-run`(或者 ep=1 跑到第一个 batch),确认 F20 和 v12 一致
|
| 525 |
+
3. **新模块零初始化等价**:在 `v13b_config` / `v13c_config` 下跑 ep=0 valid,确认和 v12 best.pt 的 valid 指标差异 < 1%
|
| 526 |
+
4. **B/C 训练**:完整跑 15/20 ep,观察 F20 曲线
|
| 527 |
+
5. **per-subset eval**:用 `eval_v12_per_subset.py --preset ov1_unified_v13b --checkpoint ...` 看每个子集涨点
|
| 528 |
+
6. **test eval**:用同脚本加 `--split test` 跑最终测试
|
docs/v13d_spatial_beats_design.md
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# v13_D 设计文档 — Loss 机制 + 训练细节双重改进
|
| 2 |
+
|
| 3 |
+
> 日期:2026-05-02
|
| 4 |
+
> 起源:v13_B / v13_C 实验 ep3 结果证明 "loss/arch 改动在 cls warmup 结束时贡献微弱",需要重新从**训练机制**角度切入
|
| 5 |
+
> 目标:F20 从 v12 的 0.378 推到 **0.43+**
|
| 6 |
+
> 设计原则:延续 v13_B/C 的"增量式、不破坏现有实验",所有新 flag 默认 False
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
## 0. 为什么 v13_B / v13_C 都不 work
|
| 11 |
+
|
| 12 |
+
ep3(cls warmup 结束)对比:
|
| 13 |
+
|
| 14 |
+
| | v12 ep3 | v13_B ep3 | v13_C ep3 |
|
| 15 |
+
|---|---|---|---|
|
| 16 |
+
| F20 | 0.3529 | 0.3565 (+0.004) | 0.3849 (+0.032) |
|
| 17 |
+
| o_cls | 0.7793 | 0.7757 | 0.7779 |
|
| 18 |
+
| aR | 0.120 | 0.194 | 0.090 |
|
| 19 |
+
| aP | 0.866 | 0.860 | 0.859 |
|
| 20 |
+
|
| 21 |
+
**三个实验的 oracle_class_acc 基本一致** → 三组实验的表征学的是同一个东西,新改动都没触达表征。loss/head decision 层面的改动(ASL / gate / soft-F1)在 cls warmup 结束时**基本还没起作用**,因为:
|
| 22 |
+
|
| 23 |
+
1. **zero-init 的新模块需要 3-5 ep 才能让 layer_scale 爬起来**,而 cls warmup 只有 3 ep
|
| 24 |
+
2. **ASL γ-=4 让 loss_activity 绝对值从 0.29 降到 0.07**,实际上是把 activity head 的梯度信号弄弱了
|
| 25 |
+
3. **augment 反而降低 activity_precision**(0.95 → 0.86)
|
| 26 |
+
4. **v12 自身从 ep3→ep12 只涨了 +0.025**,外推 v13_B/C 的 best 也就 0.38~0.41
|
| 27 |
+
|
| 28 |
+
## 1. v13_D 的切入点:**训练机制本身**
|
| 29 |
+
|
| 30 |
+
不碰架构、不碰表征,只改三件事:
|
| 31 |
+
|
| 32 |
+
- **D-1**: 扩大 cls warmup(3 ep → 8 ep)+ cosine LR + 总 25 ep,让表征学得更稳
|
| 33 |
+
- **D-2**: 用 Top-K rank activity loss 替换 BCE,直接针对 activity_recall=0.13 的根本瓶颈
|
| 34 |
+
- **D-5**: resume optimizer,从 v12 最后的 Adam momentum 继续,避免前 2-3 ep 梯度方向混乱
|
| 35 |
+
- **D-6**: EMA 权重,validate 时用 EMA 模型,ep 间震荡降低 1-2 个点
|
| 36 |
+
|
| 37 |
+
D-2 是核心,D-1 / D-5 / D-6 是辅助。
|
| 38 |
+
|
| 39 |
+
## 2. 具体设计
|
| 40 |
+
|
| 41 |
+
### D-1: cls warmup 拉长 + cosine LR + 25 epochs
|
| 42 |
+
|
| 43 |
+
#### 诊断
|
| 44 |
+
v12 `frame_spatial_loss_warmup_epochs = 3`,ep0-2 只训练 cls + activity,ep3 起放开 spatial loss。但:
|
| 45 |
+
- v12 曲线:cls_acc ep0=0.65 → ep1=0.81 → ep2=0.85 → ep3=0.84 → ... → ep12=0.83
|
| 46 |
+
- **ep2 已 0.85 但 ep3 反而微降 0.84**,说明 **ep3 放开 spatial loss 的瞬间干扰了 cls**
|
| 47 |
+
- cls 没有再涨的机会 —— 之后一直在 0.83 附近波动
|
| 48 |
+
- FSD63 的 oracle_cls 卡在 0.78 是 class head 学的,不是 trunk 表征的上限
|
| 49 |
+
|
| 50 |
+
#### 改动
|
| 51 |
+
```python
|
| 52 |
+
cfg.frame_spatial_loss_warmup_epochs = 8 # 3 → 8
|
| 53 |
+
cfg.frame_spatial_loss_ramp_epochs = 2 # 1 → 2(ramp 更平滑)
|
| 54 |
+
cfg.num_epochs = 25 # 15 → 25
|
| 55 |
+
|
| 56 |
+
# LR: cosine schedule,峰值在 ep 5-8 之间(warmup 结束附近)
|
| 57 |
+
cfg.use_cosine_lr = True # 新 flag
|
| 58 |
+
cfg.cosine_lr_warmup_epochs = 3 # 前 3 ep linear warmup 到 peak
|
| 59 |
+
cfg.cosine_lr_min_ratio = 0.05 # 最后降到 peak * 0.05
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
**训练循环改动**:
|
| 63 |
+
```python
|
| 64 |
+
# ep 0-2: linear warmup LR 0 → peak_lr
|
| 65 |
+
# ep 3-24: cosine decay peak_lr → peak_lr * min_ratio
|
| 66 |
+
# ep 0-7: spatial loss weight = warmup_scale (0.0 or 0.1)
|
| 67 |
+
# ep 8-9: linear ramp to 1.0
|
| 68 |
+
# ep 10-24: full spatial loss
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
### D-2: Top-K rank activity loss(核心)
|
| 72 |
+
|
| 73 |
+
#### 诊断
|
| 74 |
+
当前 BCE(或 ASL)对每个 `(b, k, t)` 位置独立判断"该 slot 是否 active"。问题:
|
| 75 |
+
- K=4 slots 在 ov1 数据上永远 3/4 inactive,类别极不平衡
|
| 76 |
+
- BCE 的最优解是 `sigmoid(act_logit) ≈ 0.25`(平均 prior),不会敢预测 active
|
| 77 |
+
- SELD 评估实际用 **sorted by activity logit, take top-K̂** 的方式决策
|
| 78 |
+
- 训练目标和评估目标不对齐
|
| 79 |
+
|
| 80 |
+
#### 设计
|
| 81 |
+
|
| 82 |
+
**Top-K rank loss**:在每一帧 `(b, t)`,强制 top-`N_active_gt` 个 slot 的 activity logit **必须比其他 slot 至少高 margin**,而不是独立回归 0/1。
|
| 83 |
+
|
| 84 |
+
```python
|
| 85 |
+
def topk_rank_activity_loss(
|
| 86 |
+
activity_logit: Tensor, # [B, K, T]
|
| 87 |
+
target_active: Tensor, # [B, K, T], 0/1
|
| 88 |
+
valid_time: Tensor, # [B, T]
|
| 89 |
+
margin: float = 2.0,
|
| 90 |
+
) -> Tensor:
|
| 91 |
+
"""
|
| 92 |
+
Per-frame marginal ranking loss:
|
| 93 |
+
For each active slot i (target=1) and each inactive slot j (target=0),
|
| 94 |
+
enforce logit[i] > logit[j] + margin.
|
| 95 |
+
|
| 96 |
+
Equivalent to:
|
| 97 |
+
loss = Σ_{i in A, j in I} max(0, margin + logit[j] - logit[i])
|
| 98 |
+
This gives direct gradient that "ranks" active slots above inactive ones,
|
| 99 |
+
which aligns with the DCASE eval pipeline (take top-K̂ per frame).
|
| 100 |
+
|
| 101 |
+
Plus a weak binary regularizer to anchor logit magnitude:
|
| 102 |
+
+ 0.1 * BCE(activity_logit, target_active)
|
| 103 |
+
"""
|
| 104 |
+
# [B, T] active_count per frame
|
| 105 |
+
n_active = target_active.sum(dim=1) # [B, T]
|
| 106 |
+
# Loop-free formulation using broadcasting:
|
| 107 |
+
# logit_i: [B, K, T] (active side)
|
| 108 |
+
# logit_j: [B, K, T] (inactive side)
|
| 109 |
+
# pairwise diff: [B, K_i, K_j, T]
|
| 110 |
+
# mask: target_active[i] * (1 - target_active[j]) [B, K_i, K_j, T]
|
| 111 |
+
act = target_active.unsqueeze(2) # [B, K, 1, T]
|
| 112 |
+
ina = (1.0 - target_active).unsqueeze(1) # [B, 1, K, T]
|
| 113 |
+
pair_mask = act * ina # [B, K_i, K_j, T]
|
| 114 |
+
logit_i = activity_logit.unsqueeze(2) # [B, K, 1, T]
|
| 115 |
+
logit_j = activity_logit.unsqueeze(1) # [B, 1, K, T]
|
| 116 |
+
diff = logit_j - logit_i + margin # want this <= 0
|
| 117 |
+
# hinge loss, masked
|
| 118 |
+
hinge = F.relu(diff) * pair_mask # [B, K, K, T]
|
| 119 |
+
# normalize by valid pairs count
|
| 120 |
+
pair_valid = pair_mask.sum(dim=(1, 2)) # [B, T]
|
| 121 |
+
time_mask = valid_time.float() * (pair_valid > 0).float()
|
| 122 |
+
loss_rank = (hinge.sum(dim=(1, 2)) * time_mask).sum() / time_mask.sum().clamp(min=1.0)
|
| 123 |
+
|
| 124 |
+
# Anchor term: prevents logits from drifting to ±inf
|
| 125 |
+
loss_bce = F.binary_cross_entropy_with_logits(
|
| 126 |
+
activity_logit, target_active, reduction='none'
|
| 127 |
+
)
|
| 128 |
+
loss_bce = (loss_bce * valid_time.unsqueeze(1)).mean()
|
| 129 |
+
|
| 130 |
+
return loss_rank + 0.1 * loss_bce
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
#### 为什么比 ASL 好
|
| 134 |
+
|
| 135 |
+
| 特性 | BCE | ASL | **Top-K rank** |
|
| 136 |
+
|---|---|---|---|
|
| 137 |
+
| 优化目标 | per-element logprob | per-element logprob (γ-weighted) | **pairwise ranking** |
|
| 138 |
+
| 受 K 不平衡影响 | 严重 | 缓解 | 无(只看 rank) |
|
| 139 |
+
| 与 DCASE 评估对齐 | ❌ | ❌ | **✓** (top-K̂) |
|
| 140 |
+
| 训练稳定性 | 好 | 中 (γ 过大会崩) | **好**(hinge + 小 BCE anchor) |
|
| 141 |
+
| 已知效果 | v12 aR=0.13 | v13_B aR=0.19,aP 降 | **未验证**,但机制对 |
|
| 142 |
+
|
| 143 |
+
#### Config flag
|
| 144 |
+
|
| 145 |
+
```python
|
| 146 |
+
# spatial_loss.py
|
| 147 |
+
frame_activity_loss_type: str = "bce" # + "topk_rank"
|
| 148 |
+
topk_rank_margin: float = 2.0
|
| 149 |
+
topk_rank_bce_weight: float = 0.1 # anchor 的 BCE 权重
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
### D-5: Resume optimizer(低成本改进)
|
| 153 |
+
|
| 154 |
+
#### 诊断
|
| 155 |
+
v13_B/C 都用 `--no-resume-optimizer`,Adam 的 m/v moment buffer 从零重建。前 2-3 个 epoch 梯度方向不稳,尤其在"换 loss 函数"后更明显。
|
| 156 |
+
|
| 157 |
+
#### 改动
|
| 158 |
+
```bash
|
| 159 |
+
# run_ov1_unified_v13d.sh 不加 --no-resume-optimizer
|
| 160 |
+
# 但 LR 设为 v12 最后 LR 的 1/3(避免 resume 后太激进)
|
| 161 |
+
SPATIAL_LR="${SPATIAL_LR:-7e-6}" # v12 是 2e-5,这里是 2e-5/3 ≈ 7e-6
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
**注意**:Optimizer state 包含 LR scheduler 状态,如果我们切 cosine schedule 需要 reset schedule 但保留 Adam moments。实现时:
|
| 165 |
+
```python
|
| 166 |
+
# 加载 optimizer_state_dict
|
| 167 |
+
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
|
| 168 |
+
# 但把所有 param_group 的 LR 重设为新 LR(cosine scheduler 会从这开始)
|
| 169 |
+
for pg in optimizer.param_groups:
|
| 170 |
+
pg['lr'] = new_peak_lr
|
| 171 |
+
# 删掉 step count(avoid schedule confusion)
|
| 172 |
+
# scheduler 从 epoch=0 重新开始
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
### D-6: EMA 权重
|
| 176 |
+
|
| 177 |
+
#### 诊断
|
| 178 |
+
v12 ep10-14 F20 在 0.367-0.378 震荡,SGD 困在鞍点。EMA = 取最近 N 个权重的平滑平均,能稳定在鞍点中间而非某一端。
|
| 179 |
+
|
| 180 |
+
#### 改动
|
| 181 |
+
|
| 182 |
+
**新加 `EMAModel` helper**:
|
| 183 |
+
```python
|
| 184 |
+
class EMAModel:
|
| 185 |
+
def __init__(self, model: nn.Module, decay: float = 0.9995):
|
| 186 |
+
self.decay = decay
|
| 187 |
+
self.shadow: Dict[str, Tensor] = {
|
| 188 |
+
name: p.detach().clone()
|
| 189 |
+
for name, p in model.named_parameters()
|
| 190 |
+
if p.requires_grad
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
@torch.no_grad()
|
| 194 |
+
def update(self, model: nn.Module):
|
| 195 |
+
for name, p in model.named_parameters():
|
| 196 |
+
if not p.requires_grad: continue
|
| 197 |
+
self.shadow[name].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)
|
| 198 |
+
|
| 199 |
+
def apply_to(self, model: nn.Module) -> Dict[str, Tensor]:
|
| 200 |
+
"""Swap model params with EMA shadow, returns backup for restoration."""
|
| 201 |
+
backup = {}
|
| 202 |
+
for name, p in model.named_parameters():
|
| 203 |
+
if name in self.shadow:
|
| 204 |
+
backup[name] = p.data.clone()
|
| 205 |
+
p.data.copy_(self.shadow[name])
|
| 206 |
+
return backup
|
| 207 |
+
|
| 208 |
+
def restore(self, model: nn.Module, backup: Dict[str, Tensor]):
|
| 209 |
+
for name, p in model.named_parameters():
|
| 210 |
+
if name in backup:
|
| 211 |
+
p.data.copy_(backup[name])
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
**训练循环**:
|
| 215 |
+
```python
|
| 216 |
+
# 每 step 后:
|
| 217 |
+
if ema_model is not None:
|
| 218 |
+
ema_model.update(model)
|
| 219 |
+
|
| 220 |
+
# validate 前:
|
| 221 |
+
if ema_model is not None:
|
| 222 |
+
backup = ema_model.apply_to(model)
|
| 223 |
+
val_metrics = evaluate_one_epoch(model, ...)
|
| 224 |
+
if ema_model is not None:
|
| 225 |
+
ema_model.restore(model, backup)
|
| 226 |
+
|
| 227 |
+
# save best.pt 时,保存 EMA 权重而非原始权重
|
| 228 |
+
if ema_model is not None:
|
| 229 |
+
backup = ema_model.apply_to(model)
|
| 230 |
+
torch.save({'model_state_dict': model.state_dict(), ...})
|
| 231 |
+
ema_model.restore(model, backup)
|
| 232 |
+
```
|
| 233 |
+
|
| 234 |
+
#### Config flag
|
| 235 |
+
|
| 236 |
+
```python
|
| 237 |
+
# TrainSpatialBEATsConfig
|
| 238 |
+
use_ema: bool = False
|
| 239 |
+
ema_decay: float = 0.9995
|
| 240 |
+
ema_start_epoch: int = 3 # 前 3 ep 不 EMA(避免 warmup 噪声)
|
| 241 |
+
```
|
| 242 |
+
|
| 243 |
+
## 3. v13_D preset
|
| 244 |
+
|
| 245 |
+
```python
|
| 246 |
+
def make_ov1_unified_v13d_config(...):
|
| 247 |
+
cfg = make_ov1_unified_v12_config(...) # v12 为基础
|
| 248 |
+
|
| 249 |
+
# D-1: 扩大 cls warmup,cosine schedule
|
| 250 |
+
cfg.frame_spatial_loss_warmup_epochs = 8
|
| 251 |
+
cfg.frame_spatial_loss_ramp_epochs = 2
|
| 252 |
+
cfg.num_epochs = 25
|
| 253 |
+
cfg.use_cosine_lr = True
|
| 254 |
+
cfg.cosine_lr_warmup_epochs = 3
|
| 255 |
+
cfg.cosine_lr_min_ratio = 0.05
|
| 256 |
+
cfg.learning_rate = 1.5e-5 # peak LR
|
| 257 |
+
|
| 258 |
+
# D-2: Top-K rank activity loss
|
| 259 |
+
cfg.loss.frame_activity_loss_type = "topk_rank"
|
| 260 |
+
cfg.loss.topk_rank_margin = 2.0
|
| 261 |
+
cfg.loss.topk_rank_bce_weight = 0.1
|
| 262 |
+
|
| 263 |
+
# D-5: resume optimizer (在 run script 里,不写 --no-resume-optimizer)
|
| 264 |
+
|
| 265 |
+
# D-6: EMA
|
| 266 |
+
cfg.use_ema = True
|
| 267 |
+
cfg.ema_decay = 0.9995
|
| 268 |
+
cfg.ema_start_epoch = 3
|
| 269 |
+
|
| 270 |
+
cfg.output_dir = "checkpoints/spatial_beats_ov1_unified_v13d_exp/03_ov123_top4"
|
| 271 |
+
return cfg
|
| 272 |
+
```
|
| 273 |
+
|
| 274 |
+
## 4. 实现步骤
|
| 275 |
+
|
| 276 |
+
| 文件 | 改动 | 对应 D-* |
|
| 277 |
+
|---|---|---|
|
| 278 |
+
| `spatial_loss.py` | 加 `_topk_rank_activity_loss` + config 字段 + 分支 | D-2 |
|
| 279 |
+
| `train_spatial_beats.py` | 加 `EMAModel` class + cosine LR scheduler + 训练循环 hook | D-1, D-6 |
|
| 280 |
+
| `train_spatial_beats.py` | 加 `make_ov1_unified_v13d_config` + CLI + choices | - |
|
| 281 |
+
| `run_ov1_unified_v13d.sh` | 新脚本,不带 `--no-resume-optimizer` | D-5 |
|
| 282 |
+
| `docs/v13d_spatial_beats_design.md` | 本文档 | - |
|
| 283 |
+
|
| 284 |
+
所有改动都通过 cfg flag 控制,默认 False → v12/v13_B/v13_C 不受影响。
|
| 285 |
+
|
| 286 |
+
## 5. 预期结果
|
| 287 |
+
|
| 288 |
+
| 指标 | v12 best | v13_D 预期 | 机制 |
|
| 289 |
+
|---|---|---|---|
|
| 290 |
+
| F20 | 0.378 | **0.42 ~ 0.46** | Top-K rank + EMA + 长 warmup |
|
| 291 |
+
| aR | 0.126 | **0.25 ~ 0.40** | Top-K 强制拉高活跃 slot |
|
| 292 |
+
| aP | 0.855 | **0.80 ~ 0.85** | 可能小降(recall ↑ 的代价),但 hinge 保留 rank 信号 |
|
| 293 |
+
| class_acc | 0.834 | **0.86 ~ 0.88** | 长 warmup 让 cls 真的学完 |
|
| 294 |
+
| azi_mae | 19.7° | **18~20°** | 不变,不是目标 |
|
| 295 |
+
|
| 296 |
+
## 6. 风险和兜底
|
| 297 |
+
|
| 298 |
+
| 风险 | 概率 | 兜底 |
|
| 299 |
+
|---|---|---|
|
| 300 |
+
| Top-K rank 的 hinge 梯度饱和(margin 太大) | 中 | 降 margin 到 1.0 |
|
| 301 |
+
| margin=2.0 导致 logit 分布爆炸(两端拉开) | 低 | anchor BCE 权重从 0.1 升到 0.3 |
|
| 302 |
+
| EMA 反而降 F(热启动 EMA 初始化问题) | 低 | ema_start_epoch 提到 5 |
|
| 303 |
+
| Cosine LR 峰值太高毁掉 v12 表征 | 中 | peak_lr 降到 1e-5(v12 也是这个) |
|
| 304 |
+
| resume optimizer 把 v12 的 Adam moment 固化在错方向 | 低 | 如果 ep0-2 loss 爆炸,退回 --no-resume-optimizer |
|
| 305 |
+
| 总体不涨(所有改动都没用) | 中 | 写 ablation,跑 v13_D_noema / v13_D_nocos 诊断 |
|
| 306 |
+
|
| 307 |
+
## 7. 和 v13_B/C 的关系
|
| 308 |
+
|
| 309 |
+
- **v13_D 不依赖 v13_B/C 的改动**,直接从 v12 best.pt 热启动
|
| 310 |
+
- v13_B/C 可以看作"改模块结构"的尝试,v13_D 是"改训练机制"的尝试
|
| 311 |
+
- 如果 v13_D 成功(F ≥ 0.42),**可以在它之上加回 v13_C 的 refinement 2-layer**,那才是真正的 v14
|
| 312 |
+
|
| 313 |
+
## 8. 验证步骤
|
| 314 |
+
|
| 315 |
+
1. `python -c "import ast; ast.parse(open('spatial_loss.py').read())"` 语法
|
| 316 |
+
2. Top-K rank loss 单测:给已知 activity_logit + target 手算验证
|
| 317 |
+
3. 模型构造 + hot-start v12 best.pt:确认 missing=0, unexpected=0
|
| 318 |
+
4. 单 batch 前向 + backward:确认 loss 是 scalar、梯度非 NaN
|
| 319 |
+
5. EMA 单测:update 后 shadow 权重正确
|
| 320 |
+
6. 1 epoch dry-run:看 cosine LR 曲线 + EMA shadow 随 step 变化
|
| 321 |
+
|
| 322 |
+
---
|
| 323 |
+
|
| 324 |
+
## 附:为什么不加更多改动(D-7 per-class expert 等)
|
| 325 |
+
|
| 326 |
+
诊断:v13_B/C 失败的主因不是"改动不够多",而是"改动不对靶"。v13_D 只碰 loss 机制 + 训练 schedule,属于**最小必要改动**:
|
| 327 |
+
|
| 328 |
+
- D-2 Top-K 直接针对 activity_recall 瓶颈
|
| 329 |
+
- D-1 扩大 warmup 给表征更多学习时间
|
| 330 |
+
- D-6 EMA 降低末期震荡
|
| 331 |
+
- D-5 resume optimizer 让 LR 轨迹连续
|
| 332 |
+
|
| 333 |
+
加更多改动(per-class expert、class-conditional gate v2)会重复 v13_B 的错误——改了 head 但没触达瓶颈,而且同时改太多东西无法 ablation。
|