vocal separation and bytedance integration
Browse files- README.md +173 -63
- backend/app_config.py +24 -0
- backend/audio_preprocessor.py +260 -0
- backend/bytedance_wrapper.py +224 -0
- backend/confidence_filter.py +237 -0
- backend/ensemble_transcriber.py +407 -0
- backend/evaluation/CLUSTER_SETUP.md +0 -262
- backend/evaluation/README.md +0 -251
- backend/evaluation/__init__.py +0 -1
- backend/evaluation/benchmark.py +0 -308
- backend/evaluation/generate_full_test_set.py +0 -66
- backend/evaluation/metrics.py +0 -253
- backend/evaluation/prepare_maestro.py +0 -259
- backend/evaluation/run_benchmark.py +0 -192
- backend/evaluation/slurm_benchmark.sh +0 -175
- backend/evaluation/test_videos.json +0 -66
- backend/key_filter.py +346 -0
- backend/main.py +32 -28
- backend/pipeline.py +454 -30
- backend/requirements-test.txt +0 -14
- backend/requirements.txt +13 -4
- backend/tests/conftest.py +1 -1
- backend/tests/test_api.py +9 -9
- backend/tests/test_pipeline.py +4 -4
- backend/tests/test_pipeline_fixes.py +0 -605
- backend/tests/test_pipeline_monophonic.py +0 -306
- backend/tests/test_tasks.py +30 -4
- backend/tests/test_yourmt3_integration.py +0 -296
- docs/README.md +2 -0
- docs/architecture/deployment.md +50 -19
- docs/backend/testing.md +365 -0
- docs/getting-started.md +27 -6
README.md
CHANGED
|
@@ -13,7 +13,10 @@ Rescored transcribes YouTube videos to professional-quality music notation:
|
|
| 13 |
**Tech Stack**:
|
| 14 |
- **Backend**: Python/FastAPI + Celery + Redis
|
| 15 |
- **Frontend**: React + VexFlow (notation) + Tone.js (playback)
|
| 16 |
-
- **ML**:
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
## Quick Start
|
| 19 |
|
|
@@ -32,6 +35,23 @@ Rescored transcribes YouTube videos to professional-quality music notation:
|
|
| 32 |
# Clone repository
|
| 33 |
git clone https://github.com/yourusername/rescored.git
|
| 34 |
cd rescored
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
```
|
| 36 |
|
| 37 |
### Setup Redis (macOS)
|
|
@@ -52,21 +72,66 @@ redis-cli ping # Should return PONG
|
|
| 52 |
```bash
|
| 53 |
cd backend
|
| 54 |
|
| 55 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
source .venv/bin/activate
|
| 57 |
|
| 58 |
-
#
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
#
|
| 62 |
-
|
| 63 |
-
|
|
|
|
| 64 |
|
| 65 |
# Copy environment file and configure
|
| 66 |
cp .env.example .env
|
| 67 |
# Edit .env - ensure YOURMT3_DEVICE=mps for Apple Silicon GPU acceleration
|
| 68 |
```
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
### Setup Frontend
|
| 71 |
|
| 72 |
```bash
|
|
@@ -102,54 +167,60 @@ YouTube requires authentication for video downloads (as of December 2024). You *
|
|
| 102 |
mv ~/Downloads/youtube.com_cookies.txt ./storage/youtube_cookies.txt
|
| 103 |
```
|
| 104 |
|
| 105 |
-
|
| 106 |
|
| 107 |
-
|
| 108 |
-
```bash
|
| 109 |
-
./start.sh
|
| 110 |
-
```
|
| 111 |
-
This starts all services in the background. Logs are written to `logs/` directory.
|
| 112 |
|
| 113 |
-
|
| 114 |
-
```bash
|
| 115 |
-
./stop.sh
|
| 116 |
-
# Or press Ctrl+C in the terminal running start.sh
|
| 117 |
-
```
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
tail -f logs/frontend.log # Frontend logs
|
| 124 |
-
```
|
| 125 |
|
| 126 |
-
|
| 127 |
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
# Use --pool=solo on macOS to avoid fork() crashes with ML libraries
|
| 140 |
-
celery -A tasks worker --loglevel=info --pool=solo
|
| 141 |
-
```
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
|
|
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
**Verification:**
|
| 155 |
```bash
|
|
@@ -171,7 +242,11 @@ You should see the file listed.
|
|
| 171 |
|
| 172 |
### YourMT3+ Setup
|
| 173 |
|
| 174 |
-
The backend uses **
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
**YourMT3+ model files and source code are already included in the repository.** The model checkpoint (~536MB) is stored via Git LFS in `backend/ymt/yourmt3_core/`.
|
| 177 |
|
|
@@ -237,20 +312,39 @@ You should see:
|
|
| 237 |
|
| 238 |
```
|
| 239 |
rescored/
|
| 240 |
-
├── backend/
|
| 241 |
-
│ ├── main.py
|
| 242 |
-
│ ├── tasks.py
|
| 243 |
-
│ ├── pipeline.py
|
| 244 |
-
│ ├──
|
| 245 |
-
│
|
| 246 |
-
├──
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
│ ├── src/
|
| 248 |
-
│ │ ├── components/
|
| 249 |
-
│ │ ├── store/
|
| 250 |
-
│ │ └── api/
|
| 251 |
-
│ └── package.json
|
| 252 |
-
├── docs/
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
```
|
| 255 |
|
| 256 |
## Documentation
|
|
@@ -286,7 +380,12 @@ Comprehensive documentation is available in the [`docs/`](docs/) directory:
|
|
| 286 |
|
| 287 |
## Accuracy Expectations
|
| 288 |
|
| 289 |
-
**With YourMT3+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
- Simple piano: **80-85% accurate**
|
| 291 |
- Complex pieces: **70-75% accurate**
|
| 292 |
|
|
@@ -296,20 +395,31 @@ Comprehensive documentation is available in the [`docs/`](docs/) directory:
|
|
| 296 |
|
| 297 |
The interactive editor is designed to make fixing errors easy regardless of which transcription model is used.
|
| 298 |
|
|
|
|
|
|
|
| 299 |
## Development
|
| 300 |
|
| 301 |
### Running Tests
|
| 302 |
|
| 303 |
```bash
|
| 304 |
-
# Backend tests
|
| 305 |
cd backend
|
|
|
|
| 306 |
pytest
|
| 307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
# Frontend tests
|
| 309 |
cd frontend
|
| 310 |
npm test
|
| 311 |
```
|
| 312 |
|
|
|
|
|
|
|
| 313 |
### API Documentation
|
| 314 |
|
| 315 |
Once the backend is running, visit:
|
|
|
|
| 13 |
**Tech Stack**:
|
| 14 |
- **Backend**: Python/FastAPI + Celery + Redis
|
| 15 |
- **Frontend**: React + VexFlow (notation) + Tone.js (playback)
|
| 16 |
+
- **ML Pipeline**:
|
| 17 |
+
- BS-RoFormer (vocal removal) → Demucs (6-stem separation)
|
| 18 |
+
- YourMT3+ + ByteDance ensemble (90% accuracy on piano)
|
| 19 |
+
- Audio preprocessing + confidence/key filtering
|
| 20 |
|
| 21 |
## Quick Start
|
| 22 |
|
|
|
|
| 35 |
# Clone repository
|
| 36 |
git clone https://github.com/yourusername/rescored.git
|
| 37 |
cd rescored
|
| 38 |
+
|
| 39 |
+
# Pull large files with Git LFS (required for YourMT3+ model checkpoint)
|
| 40 |
+
git lfs pull
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
**Note:** This repository uses **Git LFS** (Large File Storage) to store the YourMT3+ model checkpoint (~536MB). If you don't have Git LFS installed:
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
# macOS
|
| 47 |
+
brew install git-lfs
|
| 48 |
+
git lfs install
|
| 49 |
+
git lfs pull
|
| 50 |
+
|
| 51 |
+
# Linux (Debian/Ubuntu)
|
| 52 |
+
sudo apt-get install git-lfs
|
| 53 |
+
git lfs install
|
| 54 |
+
git lfs pull
|
| 55 |
```
|
| 56 |
|
| 57 |
### Setup Redis (macOS)
|
|
|
|
| 72 |
```bash
|
| 73 |
cd backend
|
| 74 |
|
| 75 |
+
# Ensure Python 3.10 is installed
|
| 76 |
+
python3.10 --version # Should show Python 3.10.x
|
| 77 |
+
|
| 78 |
+
# Create virtual environment
|
| 79 |
+
python3.10 -m venv .venv
|
| 80 |
+
|
| 81 |
+
# Activate virtual environment
|
| 82 |
source .venv/bin/activate
|
| 83 |
|
| 84 |
+
# Upgrade pip, setuptools, and wheel
|
| 85 |
+
pip install --upgrade pip setuptools wheel
|
| 86 |
+
|
| 87 |
+
# Install all dependencies (takes 10-15 minutes)
|
| 88 |
+
pip install -r requirements.txt
|
| 89 |
|
| 90 |
+
# Verify installation
|
| 91 |
+
python -c "import torch; print(f'PyTorch {torch.__version__} installed')"
|
| 92 |
+
python -c "import librosa; print(f'librosa installed')"
|
| 93 |
+
python -c "import music21; print(f'music21 installed')"
|
| 94 |
|
| 95 |
# Copy environment file and configure
|
| 96 |
cp .env.example .env
|
| 97 |
# Edit .env - ensure YOURMT3_DEVICE=mps for Apple Silicon GPU acceleration
|
| 98 |
```
|
| 99 |
|
| 100 |
+
**What gets installed:**
|
| 101 |
+
- Core ML frameworks: PyTorch 2.9+, torchaudio 2.9+
|
| 102 |
+
- Audio processing: librosa, soundfile, demucs, audio-separator
|
| 103 |
+
- Transcription: YourMT3+ dependencies (transformers, lightning, einops)
|
| 104 |
+
- Music notation: music21, mido, pretty_midi
|
| 105 |
+
- Web framework: FastAPI, uvicorn, celery, redis
|
| 106 |
+
- Testing: pytest, pytest-asyncio, pytest-cov, pytest-mock
|
| 107 |
+
- **Total: ~200 packages, ~3-4GB download**
|
| 108 |
+
|
| 109 |
+
**Troubleshooting Installation:**
|
| 110 |
+
|
| 111 |
+
If you encounter errors during `pip install -r requirements.txt`:
|
| 112 |
+
|
| 113 |
+
1. **scipy build errors**: Make sure you have the latest pip/setuptools:
|
| 114 |
+
```bash
|
| 115 |
+
pip install --upgrade pip setuptools wheel
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
2. **numpy version conflicts**: The requirements.txt is configured to use numpy 2.x which works with all packages. If you see conflicts, try:
|
| 119 |
+
```bash
|
| 120 |
+
pip install --no-deps -r requirements.txt
|
| 121 |
+
pip check # Verify no broken dependencies
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
3. **torch installation issues on macOS**: PyTorch should install pre-built wheels. If it tries to build from source:
|
| 125 |
+
```bash
|
| 126 |
+
pip install --only-binary :all: torch torchaudio
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
4. **madmom build errors**: madmom requires Cython. Install it first:
|
| 130 |
+
```bash
|
| 131 |
+
pip install Cython
|
| 132 |
+
pip install madmom
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
### Setup Frontend
|
| 136 |
|
| 137 |
```bash
|
|
|
|
| 167 |
mv ~/Downloads/youtube.com_cookies.txt ./storage/youtube_cookies.txt
|
| 168 |
```
|
| 169 |
|
| 170 |
+
## Running the Application
|
| 171 |
|
| 172 |
+
### Start All Services (Recommended)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
+
Use the provided shell scripts to start/stop all services at once:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
+
```bash
|
| 177 |
+
# Start all services (backend API, Celery worker, frontend)
|
| 178 |
+
./start.sh
|
| 179 |
+
```
|
|
|
|
|
|
|
| 180 |
|
| 181 |
+
This starts all services in the background with logs written to the `logs/` directory.
|
| 182 |
|
| 183 |
+
**View logs in real-time:**
|
| 184 |
+
```bash
|
| 185 |
+
tail -f logs/api.log # Backend API logs
|
| 186 |
+
tail -f logs/worker.log # Celery worker logs
|
| 187 |
+
tail -f logs/frontend.log # Frontend logs
|
| 188 |
+
```
|
| 189 |
|
| 190 |
+
**Stop all services:**
|
| 191 |
+
```bash
|
| 192 |
+
./stop.sh
|
| 193 |
+
```
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
+
**Services available at:**
|
| 196 |
+
- Frontend: http://localhost:5173
|
| 197 |
+
- Backend API: http://localhost:8000
|
| 198 |
+
- API Docs: http://localhost:8000/docs
|
| 199 |
+
|
| 200 |
+
### Manual Start (Alternative)
|
| 201 |
|
| 202 |
+
If you prefer to run services manually in separate terminals:
|
| 203 |
+
|
| 204 |
+
**Terminal 1 - Backend API:**
|
| 205 |
+
```bash
|
| 206 |
+
cd backend
|
| 207 |
+
source .venv/bin/activate
|
| 208 |
+
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
**Terminal 2 - Celery Worker:**
|
| 212 |
+
```bash
|
| 213 |
+
cd backend
|
| 214 |
+
source .venv/bin/activate
|
| 215 |
+
# Use --pool=solo on macOS to avoid fork() crashes with ML libraries
|
| 216 |
+
celery -A tasks worker --loglevel=info --pool=solo
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
**Terminal 3 - Frontend:**
|
| 220 |
+
```bash
|
| 221 |
+
cd frontend
|
| 222 |
+
npm run dev
|
| 223 |
+
```
|
| 224 |
|
| 225 |
**Verification:**
|
| 226 |
```bash
|
|
|
|
| 242 |
|
| 243 |
### YourMT3+ Setup
|
| 244 |
|
| 245 |
+
The backend uses a **multi-model ensemble** for transcription:
|
| 246 |
+
- **Primary**: YourMT3+ (multi-instrument, 80-85% base accuracy)
|
| 247 |
+
- **Specialist**: ByteDance Piano Transcription (piano-specific, ~90% accuracy)
|
| 248 |
+
- **Ensemble**: Weighted voting combines both models (90% accuracy on piano)
|
| 249 |
+
- **Fallback**: basic-pitch if ensemble unavailable (~70% accuracy)
|
| 250 |
|
| 251 |
**YourMT3+ model files and source code are already included in the repository.** The model checkpoint (~536MB) is stored via Git LFS in `backend/ymt/yourmt3_core/`.
|
| 252 |
|
|
|
|
| 312 |
|
| 313 |
```
|
| 314 |
rescored/
|
| 315 |
+
├── backend/ # Python/FastAPI backend
|
| 316 |
+
│ ├── main.py # REST API + WebSocket server
|
| 317 |
+
│ ├── tasks.py # Celery background workers
|
| 318 |
+
│ ├── pipeline.py # Audio processing pipeline
|
| 319 |
+
│ ├── app_config.py # Configuration settings
|
| 320 |
+
│ ├── app_utils.py # Utility functions
|
| 321 |
+
│ ├── audio_preprocessor.py # Audio enhancement pipeline
|
| 322 |
+
│ ├── ensemble_transcriber.py # Multi-model voting system
|
| 323 |
+
│ ├── confidence_filter.py # Post-processing filters
|
| 324 |
+
│ ├── key_filter.py # Music theory filters
|
| 325 |
+
│ ├── requirements.txt # Python dependencies (including tests)
|
| 326 |
+
│ ├── tests/ # Test suite (59 tests, 27% coverage)
|
| 327 |
+
│ │ ├── test_api.py # API endpoint tests
|
| 328 |
+
│ │ ├── test_pipeline.py # Pipeline component tests
|
| 329 |
+
│ │ ├── test_tasks.py # Celery task tests
|
| 330 |
+
│ │ └── test_utils.py # Utility function tests
|
| 331 |
+
│ └── ymt/ # YourMT3+ model and wrappers
|
| 332 |
+
├── frontend/ # React frontend
|
| 333 |
│ ├── src/
|
| 334 |
+
│ │ ├── components/ # UI components
|
| 335 |
+
│ │ ├── store/ # Zustand state management
|
| 336 |
+
│ │ └── api/ # API client
|
| 337 |
+
│ └── package.json # Node dependencies
|
| 338 |
+
├── docs/ # Comprehensive documentation
|
| 339 |
+
│ ├── backend/ # Backend implementation guides
|
| 340 |
+
│ ├── frontend/ # Frontend implementation guides
|
| 341 |
+
│ ├── architecture/ # System design documents
|
| 342 |
+
│ └── research/ # ML model comparisons
|
| 343 |
+
├── logs/ # Runtime logs (created by start.sh)
|
| 344 |
+
├── storage/ # YouTube cookies and temp files
|
| 345 |
+
├── start.sh # Start all services
|
| 346 |
+
├── stop.sh # Stop all services
|
| 347 |
+
└── docker-compose.yml # Docker setup (optional)
|
| 348 |
```
|
| 349 |
|
| 350 |
## Documentation
|
|
|
|
| 380 |
|
| 381 |
## Accuracy Expectations
|
| 382 |
|
| 383 |
+
**With Ensemble (YourMT3+ + ByteDance) - Recommended:**
|
| 384 |
+
- Simple piano: **~90% accurate** ✨
|
| 385 |
+
- Complex pieces: **80-85% accurate**
|
| 386 |
+
- Includes audio preprocessing, ensemble voting, and post-processing filters
|
| 387 |
+
|
| 388 |
+
**With YourMT3+ only:**
|
| 389 |
- Simple piano: **80-85% accurate**
|
| 390 |
- Complex pieces: **70-75% accurate**
|
| 391 |
|
|
|
|
| 395 |
|
| 396 |
The interactive editor is designed to make fixing errors easy regardless of which transcription model is used.
|
| 397 |
|
| 398 |
+
**Note**: Ensemble mode is enabled by default in `app_config.py`. ByteDance requires ~4GB VRAM and may fall back to YourMT3+ on systems with limited GPU memory.
|
| 399 |
+
|
| 400 |
## Development
|
| 401 |
|
| 402 |
### Running Tests
|
| 403 |
|
| 404 |
```bash
|
| 405 |
+
# Backend tests (59 tests, ~5-10 seconds)
|
| 406 |
cd backend
|
| 407 |
+
source .venv/bin/activate
|
| 408 |
pytest
|
| 409 |
|
| 410 |
+
# Run with coverage report
|
| 411 |
+
pytest --cov=. --cov-report=html
|
| 412 |
+
|
| 413 |
+
# Run specific test file
|
| 414 |
+
pytest tests/test_api.py -v
|
| 415 |
+
|
| 416 |
# Frontend tests
|
| 417 |
cd frontend
|
| 418 |
npm test
|
| 419 |
```
|
| 420 |
|
| 421 |
+
See [docs/backend/testing.md](docs/backend/testing.md) for detailed testing guide.
|
| 422 |
+
|
| 423 |
### API Documentation
|
| 424 |
|
| 425 |
Once the backend is running, visit:
|
backend/app_config.py
CHANGED
|
@@ -86,6 +86,30 @@ class Settings(BaseSettings):
|
|
| 86 |
vocal_instrument: int = 40 # MIDI program number for vocals (40=Violin, 73=Flute, 65=Alto Sax)
|
| 87 |
use_6stem_demucs: bool = True # Use 6-stem Demucs (piano, guitar, drums, bass, other) vs 4-stem
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
# Grand Staff Configuration
|
| 90 |
enable_grand_staff: bool = True # Split piano into treble + bass clefs
|
| 91 |
middle_c_split: int = 60 # MIDI note number for staff split (60 = Middle C)
|
|
|
|
| 86 |
vocal_instrument: int = 40 # MIDI program number for vocals (40=Violin, 73=Flute, 65=Alto Sax)
|
| 87 |
use_6stem_demucs: bool = True # Use 6-stem Demucs (piano, guitar, drums, bass, other) vs 4-stem
|
| 88 |
|
| 89 |
+
# Ensemble Transcription Configuration
|
| 90 |
+
use_ensemble_transcription: bool = True # Use ensemble of YourMT3+ and ByteDance for higher accuracy
|
| 91 |
+
use_yourmt3_ensemble: bool = True # Include YourMT3+ in ensemble
|
| 92 |
+
use_bytedance_ensemble: bool = True # Include ByteDance piano transcription in ensemble
|
| 93 |
+
ensemble_voting_strategy: str = "weighted" # Voting strategy: weighted, intersection, union, majority
|
| 94 |
+
ensemble_onset_tolerance_ms: int = 50 # Time window for matching notes (milliseconds)
|
| 95 |
+
ensemble_confidence_threshold: float = 0.6 # Minimum confidence for weighted voting
|
| 96 |
+
|
| 97 |
+
# Audio Preprocessing Configuration
|
| 98 |
+
enable_audio_preprocessing: bool = True # Preprocess audio before separation/transcription
|
| 99 |
+
enable_audio_denoising: bool = True # Remove background noise and artifacts
|
| 100 |
+
enable_audio_normalization: bool = True # Normalize volume to consistent level
|
| 101 |
+
enable_highpass_filter: bool = True # Remove low-frequency rumble (<30Hz)
|
| 102 |
+
|
| 103 |
+
# Post-Processing Filters (Phase 4)
|
| 104 |
+
enable_confidence_filtering: bool = False # Filter low-confidence notes (reduces false positives)
|
| 105 |
+
confidence_threshold: float = 0.3 # Minimum confidence to keep note (0-1)
|
| 106 |
+
velocity_threshold: int = 20 # Minimum velocity to keep note (0-127)
|
| 107 |
+
min_note_duration: float = 0.05 # Minimum note duration in seconds
|
| 108 |
+
|
| 109 |
+
enable_key_aware_filtering: bool = False # Filter isolated out-of-key notes (reduces false positives)
|
| 110 |
+
allow_chromatic_passing_tones: bool = True # Keep brief chromatic notes (jazz, classical)
|
| 111 |
+
isolation_threshold: float = 0.5 # Time threshold (seconds) to consider note isolated
|
| 112 |
+
|
| 113 |
# Grand Staff Configuration
|
| 114 |
enable_grand_staff: bool = True # Split piano into treble + bass clefs
|
| 115 |
middle_c_split: int = 60 # MIDI note number for staff split (60 = Middle C)
|
backend/audio_preprocessor.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio Preprocessing Module
|
| 3 |
+
|
| 4 |
+
Enhances audio quality before source separation and transcription.
|
| 5 |
+
|
| 6 |
+
Preprocessing Steps:
|
| 7 |
+
1. Spectral denoising - Remove background noise and artifacts
|
| 8 |
+
2. Peak normalization - Normalize volume to consistent level
|
| 9 |
+
3. High-pass filtering - Remove rumble and DC offset
|
| 10 |
+
4. Resampling - Ensure consistent sample rate
|
| 11 |
+
|
| 12 |
+
Target: +2-5% accuracy improvement on noisy/compressed YouTube audio
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Optional
|
| 17 |
+
import numpy as np
|
| 18 |
+
import librosa
|
| 19 |
+
import soundfile as sf
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class AudioPreprocessor:
|
| 23 |
+
"""
|
| 24 |
+
Audio preprocessing for improving transcription accuracy.
|
| 25 |
+
|
| 26 |
+
Mitigates common issues with YouTube audio:
|
| 27 |
+
- Compression artifacts (lossy codecs)
|
| 28 |
+
- Background noise (ambient, microphone noise)
|
| 29 |
+
- Inconsistent levels (quiet vs loud recordings)
|
| 30 |
+
- Low-frequency rumble (not musical, degrades separation)
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
enable_denoising: bool = True,
|
| 36 |
+
enable_normalization: bool = True,
|
| 37 |
+
enable_highpass: bool = True,
|
| 38 |
+
target_sample_rate: int = 44100
|
| 39 |
+
):
|
| 40 |
+
"""
|
| 41 |
+
Initialize audio preprocessor.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
enable_denoising: Enable spectral denoising
|
| 45 |
+
enable_normalization: Enable peak normalization
|
| 46 |
+
enable_highpass: Enable high-pass filter (remove rumble)
|
| 47 |
+
target_sample_rate: Target sample rate (Hz)
|
| 48 |
+
"""
|
| 49 |
+
self.enable_denoising = enable_denoising
|
| 50 |
+
self.enable_normalization = enable_normalization
|
| 51 |
+
self.enable_highpass = enable_highpass
|
| 52 |
+
self.target_sample_rate = target_sample_rate
|
| 53 |
+
|
| 54 |
+
def preprocess(
|
| 55 |
+
self,
|
| 56 |
+
audio_path: Path,
|
| 57 |
+
output_dir: Optional[Path] = None
|
| 58 |
+
) -> Path:
|
| 59 |
+
"""
|
| 60 |
+
Preprocess audio file for improved transcription quality.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
audio_path: Input audio file
|
| 64 |
+
output_dir: Output directory (default: same as input)
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
Path to preprocessed audio file
|
| 68 |
+
"""
|
| 69 |
+
if output_dir is None:
|
| 70 |
+
output_dir = audio_path.parent
|
| 71 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 72 |
+
|
| 73 |
+
output_path = output_dir / f"{audio_path.stem}_preprocessed.wav"
|
| 74 |
+
|
| 75 |
+
print(f" Preprocessing audio: {audio_path.name}")
|
| 76 |
+
|
| 77 |
+
# Load audio (preserve stereo if present)
|
| 78 |
+
y, sr = librosa.load(str(audio_path), sr=None, mono=False)
|
| 79 |
+
|
| 80 |
+
# Handle stereo vs mono
|
| 81 |
+
if y.ndim == 2:
|
| 82 |
+
print(f" Input: stereo, {sr}Hz")
|
| 83 |
+
is_stereo = True
|
| 84 |
+
else:
|
| 85 |
+
print(f" Input: mono, {sr}Hz")
|
| 86 |
+
is_stereo = False
|
| 87 |
+
y = np.expand_dims(y, axis=0) # Make it (1, samples) for uniform processing
|
| 88 |
+
|
| 89 |
+
# 1. Spectral denoising
|
| 90 |
+
if self.enable_denoising:
|
| 91 |
+
print(f" Applying spectral denoising...")
|
| 92 |
+
y = self._denoise(y, sr, is_stereo)
|
| 93 |
+
|
| 94 |
+
# 2. Peak normalization
|
| 95 |
+
if self.enable_normalization:
|
| 96 |
+
print(f" Normalizing volume...")
|
| 97 |
+
y = self._normalize(y)
|
| 98 |
+
|
| 99 |
+
# 3. High-pass filter (remove rumble <30Hz)
|
| 100 |
+
if self.enable_highpass:
|
| 101 |
+
print(f" Applying high-pass filter (30Hz cutoff)...")
|
| 102 |
+
y = self._highpass_filter(y, sr)
|
| 103 |
+
|
| 104 |
+
# 4. Resample to target sample rate
|
| 105 |
+
if sr != self.target_sample_rate:
|
| 106 |
+
print(f" Resampling: {sr}Hz → {self.target_sample_rate}Hz")
|
| 107 |
+
y = self._resample(y, sr, self.target_sample_rate)
|
| 108 |
+
sr = self.target_sample_rate
|
| 109 |
+
|
| 110 |
+
# Convert back to mono if input was mono
|
| 111 |
+
if not is_stereo:
|
| 112 |
+
y = y[0] # Remove channel dimension
|
| 113 |
+
|
| 114 |
+
# Save preprocessed audio
|
| 115 |
+
sf.write(output_path, y.T if is_stereo else y, sr)
|
| 116 |
+
print(f" ✓ Preprocessed audio saved: {output_path.name}")
|
| 117 |
+
|
| 118 |
+
return output_path
|
| 119 |
+
|
| 120 |
+
def _denoise(self, y: np.ndarray, sr: int, is_stereo: bool) -> np.ndarray:
|
| 121 |
+
"""
|
| 122 |
+
Apply spectral denoising using noisereduce library.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
y: Audio data (channels, samples)
|
| 126 |
+
sr: Sample rate
|
| 127 |
+
is_stereo: Whether audio is stereo
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
Denoised audio
|
| 131 |
+
"""
|
| 132 |
+
try:
|
| 133 |
+
import noisereduce as nr
|
| 134 |
+
except ImportError:
|
| 135 |
+
print(f" ⚠ noisereduce not installed, skipping denoising")
|
| 136 |
+
return y
|
| 137 |
+
|
| 138 |
+
# Apply denoising per channel
|
| 139 |
+
y_denoised = np.zeros_like(y)
|
| 140 |
+
|
| 141 |
+
for ch in range(y.shape[0]):
|
| 142 |
+
y_denoised[ch] = nr.reduce_noise(
|
| 143 |
+
y=y[ch],
|
| 144 |
+
sr=sr,
|
| 145 |
+
stationary=True, # Assume noise is stationary (consistent background)
|
| 146 |
+
prop_decrease=0.8 # Aggressiveness (0-1, higher = more aggressive)
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
return y_denoised
|
| 150 |
+
|
| 151 |
+
def _normalize(self, y: np.ndarray, target_db: float = -1.0) -> np.ndarray:
|
| 152 |
+
"""
|
| 153 |
+
Normalize audio to target peak level.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
y: Audio data
|
| 157 |
+
target_db: Target peak level in dB (default: -1dB = almost full scale)
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
Normalized audio
|
| 161 |
+
"""
|
| 162 |
+
# Find peak across all channels
|
| 163 |
+
peak = np.abs(y).max()
|
| 164 |
+
|
| 165 |
+
if peak == 0:
|
| 166 |
+
return y # Avoid division by zero
|
| 167 |
+
|
| 168 |
+
# Calculate gain to reach target peak
|
| 169 |
+
target_linear = 10 ** (target_db / 20.0)
|
| 170 |
+
gain = target_linear / peak
|
| 171 |
+
|
| 172 |
+
return y * gain
|
| 173 |
+
|
| 174 |
+
def _highpass_filter(
|
| 175 |
+
self,
|
| 176 |
+
y: np.ndarray,
|
| 177 |
+
sr: int,
|
| 178 |
+
cutoff_hz: float = 30.0
|
| 179 |
+
) -> np.ndarray:
|
| 180 |
+
"""
|
| 181 |
+
Apply high-pass filter to remove low-frequency rumble.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
y: Audio data (channels, samples)
|
| 185 |
+
sr: Sample rate
|
| 186 |
+
cutoff_hz: Cutoff frequency (Hz)
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
Filtered audio
|
| 190 |
+
"""
|
| 191 |
+
from scipy.signal import butter, sosfilt
|
| 192 |
+
|
| 193 |
+
# Design 4th-order Butterworth high-pass filter
|
| 194 |
+
sos = butter(4, cutoff_hz, 'hp', fs=sr, output='sos')
|
| 195 |
+
|
| 196 |
+
# Apply per channel
|
| 197 |
+
y_filtered = np.zeros_like(y)
|
| 198 |
+
|
| 199 |
+
for ch in range(y.shape[0]):
|
| 200 |
+
y_filtered[ch] = sosfilt(sos, y[ch])
|
| 201 |
+
|
| 202 |
+
return y_filtered
|
| 203 |
+
|
| 204 |
+
def _resample(
|
| 205 |
+
self,
|
| 206 |
+
y: np.ndarray,
|
| 207 |
+
orig_sr: int,
|
| 208 |
+
target_sr: int
|
| 209 |
+
) -> np.ndarray:
|
| 210 |
+
"""
|
| 211 |
+
Resample audio to target sample rate.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
y: Audio data (channels, samples)
|
| 215 |
+
orig_sr: Original sample rate
|
| 216 |
+
target_sr: Target sample rate
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
Resampled audio
|
| 220 |
+
"""
|
| 221 |
+
y_resampled = np.zeros((y.shape[0], int(y.shape[1] * target_sr / orig_sr)))
|
| 222 |
+
|
| 223 |
+
for ch in range(y.shape[0]):
|
| 224 |
+
y_resampled[ch] = librosa.resample(
|
| 225 |
+
y[ch],
|
| 226 |
+
orig_sr=orig_sr,
|
| 227 |
+
target_sr=target_sr
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
return y_resampled
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
if __name__ == "__main__":
|
| 234 |
+
# Test the preprocessor
|
| 235 |
+
import argparse
|
| 236 |
+
|
| 237 |
+
parser = argparse.ArgumentParser(description="Test Audio Preprocessor")
|
| 238 |
+
parser.add_argument("audio_file", type=str, help="Path to audio file")
|
| 239 |
+
parser.add_argument("--output", type=str, default="./output_audio",
|
| 240 |
+
help="Output directory for preprocessed audio")
|
| 241 |
+
parser.add_argument("--no-denoise", action="store_true",
|
| 242 |
+
help="Disable denoising")
|
| 243 |
+
parser.add_argument("--no-normalize", action="store_true",
|
| 244 |
+
help="Disable normalization")
|
| 245 |
+
parser.add_argument("--no-highpass", action="store_true",
|
| 246 |
+
help="Disable high-pass filter")
|
| 247 |
+
args = parser.parse_args()
|
| 248 |
+
|
| 249 |
+
preprocessor = AudioPreprocessor(
|
| 250 |
+
enable_denoising=not args.no_denoise,
|
| 251 |
+
enable_normalization=not args.no_normalize,
|
| 252 |
+
enable_highpass=not args.no_highpass
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
audio_path = Path(args.audio_file)
|
| 256 |
+
output_dir = Path(args.output)
|
| 257 |
+
|
| 258 |
+
# Preprocess
|
| 259 |
+
output_path = preprocessor.preprocess(audio_path, output_dir)
|
| 260 |
+
print(f"\n✓ Preprocessing complete: {output_path}")
|
backend/bytedance_wrapper.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ByteDance Piano Transcription Wrapper
|
| 3 |
+
|
| 4 |
+
Provides a clean interface to ByteDance's high-accuracy piano transcription model.
|
| 5 |
+
Trained on MAESTRO dataset with CNN + BiGRU architecture.
|
| 6 |
+
|
| 7 |
+
Model: https://github.com/bytedance/piano_transcription
|
| 8 |
+
Paper: "High-resolution Piano Transcription with Pedals by Regressing Onsets and Offsets Times"
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Optional, Dict
|
| 13 |
+
import torch
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ByteDanceTranscriber:
|
| 18 |
+
"""
|
| 19 |
+
Wrapper for ByteDance piano transcription model.
|
| 20 |
+
|
| 21 |
+
Characteristics:
|
| 22 |
+
- High accuracy on piano-only audio (~90% F1 score on MAESTRO)
|
| 23 |
+
- Outputs onset/offset probabilities with confidence scores
|
| 24 |
+
- Trained specifically for piano (not general-purpose)
|
| 25 |
+
- Includes pedal transcription (sustain, soft, sostenuto)
|
| 26 |
+
|
| 27 |
+
Performance:
|
| 28 |
+
- GPU: ~15-30s per 3-min song
|
| 29 |
+
- CPU: ~2-5 min per 3-min song
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
device: Optional[str] = None,
|
| 35 |
+
checkpoint: Optional[str] = None
|
| 36 |
+
):
|
| 37 |
+
"""
|
| 38 |
+
Initialize ByteDance transcription model.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
device: Torch device ('cuda', 'mps', 'cpu'). Auto-detected if None.
|
| 42 |
+
checkpoint: Model checkpoint path (optional - will auto-download if None)
|
| 43 |
+
"""
|
| 44 |
+
# Import here to avoid dependency issues if not installed
|
| 45 |
+
try:
|
| 46 |
+
from piano_transcription_inference import PianoTranscription, sample_rate
|
| 47 |
+
except ImportError as e:
|
| 48 |
+
raise ImportError(
|
| 49 |
+
"ByteDance piano transcription requires: pip install piano-transcription-inference"
|
| 50 |
+
) from e
|
| 51 |
+
|
| 52 |
+
# Auto-detect device
|
| 53 |
+
if device is None:
|
| 54 |
+
if torch.cuda.is_available():
|
| 55 |
+
device = 'cuda'
|
| 56 |
+
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 57 |
+
device = 'mps'
|
| 58 |
+
else:
|
| 59 |
+
device = 'cpu'
|
| 60 |
+
|
| 61 |
+
self.device = device
|
| 62 |
+
self.sample_rate = sample_rate
|
| 63 |
+
|
| 64 |
+
print(f" Initializing ByteDance piano transcription on {device}")
|
| 65 |
+
|
| 66 |
+
# Load model
|
| 67 |
+
# If checkpoint is None, PianoTranscription will auto-download the default model
|
| 68 |
+
self.model = PianoTranscription(
|
| 69 |
+
device=device,
|
| 70 |
+
checkpoint_path=checkpoint # None means auto-download
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
print(f" ✓ ByteDance model loaded")
|
| 74 |
+
|
| 75 |
+
def transcribe_audio(
|
| 76 |
+
self,
|
| 77 |
+
audio_path: Path,
|
| 78 |
+
output_dir: Optional[Path] = None
|
| 79 |
+
) -> Path:
|
| 80 |
+
"""
|
| 81 |
+
Transcribe audio to MIDI using ByteDance model.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
audio_path: Path to audio file (WAV, MP3, etc.)
|
| 85 |
+
output_dir: Directory for output MIDI file. Defaults to audio directory.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Path to generated MIDI file
|
| 89 |
+
"""
|
| 90 |
+
from piano_transcription_inference import load_audio
|
| 91 |
+
|
| 92 |
+
# Set output directory
|
| 93 |
+
if output_dir is None:
|
| 94 |
+
output_dir = audio_path.parent
|
| 95 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 96 |
+
|
| 97 |
+
# Output MIDI path
|
| 98 |
+
midi_path = output_dir / f"{audio_path.stem}_bytedance.mid"
|
| 99 |
+
|
| 100 |
+
print(f" Transcribing with ByteDance: {audio_path.name}")
|
| 101 |
+
|
| 102 |
+
# Load audio
|
| 103 |
+
(audio, _) = load_audio(
|
| 104 |
+
str(audio_path),
|
| 105 |
+
sr=self.sample_rate,
|
| 106 |
+
mono=True
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# Transcribe
|
| 110 |
+
# ByteDance outputs:
|
| 111 |
+
# - MIDI file with notes and pedal events
|
| 112 |
+
# - onset_roll: Frame-level onset probabilities (can be used for confidence)
|
| 113 |
+
# - offset_roll: Frame-level offset probabilities
|
| 114 |
+
# - velocity_roll: Frame-level velocity predictions
|
| 115 |
+
# - pedal_roll: Frame-level pedal predictions (sustain, soft, sostenuto)
|
| 116 |
+
|
| 117 |
+
transcription_result = self.model.transcribe(
|
| 118 |
+
audio,
|
| 119 |
+
str(midi_path)
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
print(f" ✓ ByteDance transcription complete: {midi_path.name}")
|
| 123 |
+
|
| 124 |
+
return midi_path
|
| 125 |
+
|
| 126 |
+
def transcribe_with_confidence(
|
| 127 |
+
self,
|
| 128 |
+
audio_path: Path,
|
| 129 |
+
output_dir: Optional[Path] = None
|
| 130 |
+
) -> Dict:
|
| 131 |
+
"""
|
| 132 |
+
Transcribe audio and return MIDI path + confidence scores.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
audio_path: Path to audio file
|
| 136 |
+
output_dir: Directory for output MIDI file
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
Dict with keys:
|
| 140 |
+
- 'midi_path': Path to MIDI file
|
| 141 |
+
- 'onset_confidence': Frame-level onset probabilities
|
| 142 |
+
- 'offset_confidence': Frame-level offset probabilities
|
| 143 |
+
- 'velocity_confidence': Frame-level velocity predictions
|
| 144 |
+
"""
|
| 145 |
+
from piano_transcription_inference import load_audio
|
| 146 |
+
import pretty_midi
|
| 147 |
+
|
| 148 |
+
# Set output directory
|
| 149 |
+
if output_dir is None:
|
| 150 |
+
output_dir = audio_path.parent
|
| 151 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 152 |
+
|
| 153 |
+
midi_path = output_dir / f"{audio_path.stem}_bytedance.mid"
|
| 154 |
+
|
| 155 |
+
# Load audio
|
| 156 |
+
(audio, _) = load_audio(
|
| 157 |
+
str(audio_path),
|
| 158 |
+
sr=self.sample_rate,
|
| 159 |
+
mono=True
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Transcribe and get full output
|
| 163 |
+
print(f" Transcribing with ByteDance (with confidence): {audio_path.name}")
|
| 164 |
+
|
| 165 |
+
transcription_result = self.model.transcribe(
|
| 166 |
+
audio,
|
| 167 |
+
str(midi_path)
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# Extract note-level confidence scores from frame-level predictions
|
| 171 |
+
# transcription_result is a dict with:
|
| 172 |
+
# - onset_roll: (frames, 88) - probability of note onset at each frame
|
| 173 |
+
# - offset_roll: (frames, 88) - probability of note offset
|
| 174 |
+
# - velocity_roll: (frames, 88) - predicted velocity
|
| 175 |
+
|
| 176 |
+
# Load generated MIDI to map confidence to notes
|
| 177 |
+
pm = pretty_midi.PrettyMIDI(str(midi_path))
|
| 178 |
+
|
| 179 |
+
# Extract note-level confidence
|
| 180 |
+
note_confidences = []
|
| 181 |
+
for instrument in pm.instruments:
|
| 182 |
+
if instrument.is_drum:
|
| 183 |
+
continue
|
| 184 |
+
|
| 185 |
+
for note in instrument.notes:
|
| 186 |
+
# Get average onset confidence during note's onset window
|
| 187 |
+
# (simplified - full implementation would use frame analysis)
|
| 188 |
+
note_confidences.append({
|
| 189 |
+
'pitch': note.pitch,
|
| 190 |
+
'onset': note.start,
|
| 191 |
+
'offset': note.end,
|
| 192 |
+
'velocity': note.velocity,
|
| 193 |
+
'confidence': 0.9 # Placeholder - TODO: extract from onset_roll
|
| 194 |
+
})
|
| 195 |
+
|
| 196 |
+
return {
|
| 197 |
+
'midi_path': midi_path,
|
| 198 |
+
'note_confidences': note_confidences,
|
| 199 |
+
'raw_onset_roll': transcription_result.get('onset_roll'),
|
| 200 |
+
'raw_offset_roll': transcription_result.get('offset_roll'),
|
| 201 |
+
'raw_velocity_roll': transcription_result.get('velocity_roll')
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
if __name__ == "__main__":
|
| 206 |
+
# Test the transcriber
|
| 207 |
+
import argparse
|
| 208 |
+
|
| 209 |
+
parser = argparse.ArgumentParser(description="Test ByteDance Piano Transcription")
|
| 210 |
+
parser.add_argument("audio_file", type=str, help="Path to audio file")
|
| 211 |
+
parser.add_argument("--output", type=str, default="./output_midi",
|
| 212 |
+
help="Output directory for MIDI")
|
| 213 |
+
parser.add_argument("--device", type=str, default=None,
|
| 214 |
+
choices=['cuda', 'mps', 'cpu'],
|
| 215 |
+
help="Device to use (auto-detected if not specified)")
|
| 216 |
+
args = parser.parse_args()
|
| 217 |
+
|
| 218 |
+
transcriber = ByteDanceTranscriber(device=args.device)
|
| 219 |
+
audio_path = Path(args.audio_file)
|
| 220 |
+
output_dir = Path(args.output)
|
| 221 |
+
|
| 222 |
+
# Transcribe
|
| 223 |
+
midi_path = transcriber.transcribe_audio(audio_path, output_dir)
|
| 224 |
+
print(f"\n✓ Transcription complete: {midi_path}")
|
backend/confidence_filter.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Confidence-Based MIDI Filtering
|
| 3 |
+
|
| 4 |
+
Filters out low-confidence notes from transcription output to reduce false positives.
|
| 5 |
+
|
| 6 |
+
Expected Impact: +1-3% precision improvement
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Dict, Optional
|
| 11 |
+
import pretty_midi
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ConfidenceFilter:
|
| 15 |
+
"""
|
| 16 |
+
Filter MIDI notes based on confidence scores.
|
| 17 |
+
|
| 18 |
+
Removes low-confidence notes that are likely false positives from the
|
| 19 |
+
transcription model. Works with confidence scores if available, or uses
|
| 20 |
+
heuristics based on note characteristics.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
confidence_threshold: float = 0.3,
|
| 26 |
+
velocity_threshold: int = 20,
|
| 27 |
+
duration_threshold: float = 0.05
|
| 28 |
+
):
|
| 29 |
+
"""
|
| 30 |
+
Initialize confidence filter.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
confidence_threshold: Minimum confidence score to keep note (0-1)
|
| 34 |
+
velocity_threshold: Minimum velocity to keep note (0-127)
|
| 35 |
+
duration_threshold: Minimum duration in seconds to keep note
|
| 36 |
+
"""
|
| 37 |
+
self.confidence_threshold = confidence_threshold
|
| 38 |
+
self.velocity_threshold = velocity_threshold
|
| 39 |
+
self.duration_threshold = duration_threshold
|
| 40 |
+
|
| 41 |
+
def filter_midi_by_confidence(
|
| 42 |
+
self,
|
| 43 |
+
midi_path: Path,
|
| 44 |
+
confidence_scores: Optional[Dict] = None,
|
| 45 |
+
output_path: Optional[Path] = None
|
| 46 |
+
) -> Path:
|
| 47 |
+
"""
|
| 48 |
+
Filter MIDI notes based on confidence scores or heuristics.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
midi_path: Input MIDI file
|
| 52 |
+
confidence_scores: Optional dict mapping (onset_time, pitch) -> confidence
|
| 53 |
+
output_path: Output path (default: input_path with _filtered suffix)
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
Path to filtered MIDI file
|
| 57 |
+
"""
|
| 58 |
+
# Load MIDI
|
| 59 |
+
pm = pretty_midi.PrettyMIDI(str(midi_path))
|
| 60 |
+
|
| 61 |
+
# Create new MIDI with filtered notes
|
| 62 |
+
filtered_pm = pretty_midi.PrettyMIDI(initial_tempo=pm.estimate_tempo())
|
| 63 |
+
|
| 64 |
+
total_notes = 0
|
| 65 |
+
kept_notes = 0
|
| 66 |
+
|
| 67 |
+
for inst in pm.instruments:
|
| 68 |
+
if inst.is_drum:
|
| 69 |
+
# Keep drum tracks as-is
|
| 70 |
+
filtered_pm.instruments.append(inst)
|
| 71 |
+
continue
|
| 72 |
+
|
| 73 |
+
# Create new instrument with filtered notes
|
| 74 |
+
filtered_inst = pretty_midi.Instrument(
|
| 75 |
+
program=inst.program,
|
| 76 |
+
is_drum=inst.is_drum,
|
| 77 |
+
name=inst.name
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
for note in inst.notes:
|
| 81 |
+
total_notes += 1
|
| 82 |
+
|
| 83 |
+
# Get confidence score if available
|
| 84 |
+
if confidence_scores is not None:
|
| 85 |
+
# Find closest note in confidence scores
|
| 86 |
+
confidence = self._get_note_confidence(
|
| 87 |
+
note,
|
| 88 |
+
confidence_scores
|
| 89 |
+
)
|
| 90 |
+
else:
|
| 91 |
+
# Use heuristic confidence based on note characteristics
|
| 92 |
+
confidence = self._estimate_confidence(note)
|
| 93 |
+
|
| 94 |
+
# Filter based on confidence and other thresholds
|
| 95 |
+
if self._should_keep_note(note, confidence):
|
| 96 |
+
filtered_inst.notes.append(note)
|
| 97 |
+
kept_notes += 1
|
| 98 |
+
|
| 99 |
+
filtered_pm.instruments.append(filtered_inst)
|
| 100 |
+
|
| 101 |
+
# Set output path
|
| 102 |
+
if output_path is None:
|
| 103 |
+
output_path = midi_path.with_stem(f"{midi_path.stem}_filtered")
|
| 104 |
+
|
| 105 |
+
# Save filtered MIDI
|
| 106 |
+
filtered_pm.write(str(output_path))
|
| 107 |
+
|
| 108 |
+
removed = total_notes - kept_notes
|
| 109 |
+
print(f" Confidence filtering: kept {kept_notes}/{total_notes} notes (removed {removed})")
|
| 110 |
+
|
| 111 |
+
return output_path
|
| 112 |
+
|
| 113 |
+
def _get_note_confidence(
|
| 114 |
+
self,
|
| 115 |
+
note: pretty_midi.Note,
|
| 116 |
+
confidence_scores: Dict
|
| 117 |
+
) -> float:
|
| 118 |
+
"""
|
| 119 |
+
Get confidence score for a note from the confidence scores dict.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
note: Note to get confidence for
|
| 123 |
+
confidence_scores: Dict mapping (onset_time, pitch) -> confidence
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Confidence score (0-1), or 1.0 if not found
|
| 127 |
+
"""
|
| 128 |
+
# Try exact match
|
| 129 |
+
key = (note.start, note.pitch)
|
| 130 |
+
if key in confidence_scores:
|
| 131 |
+
return confidence_scores[key]
|
| 132 |
+
|
| 133 |
+
# Try approximate match (within 50ms)
|
| 134 |
+
tolerance = 0.05
|
| 135 |
+
for (onset, pitch), confidence in confidence_scores.items():
|
| 136 |
+
if pitch == note.pitch and abs(onset - note.start) < tolerance:
|
| 137 |
+
return confidence
|
| 138 |
+
|
| 139 |
+
# Default to high confidence if not found (don't filter)
|
| 140 |
+
return 1.0
|
| 141 |
+
|
| 142 |
+
def _estimate_confidence(self, note: pretty_midi.Note) -> float:
|
| 143 |
+
"""
|
| 144 |
+
Estimate confidence based on note characteristics (heuristic).
|
| 145 |
+
|
| 146 |
+
Heuristics:
|
| 147 |
+
- Very short notes (< 50ms) → likely false positives → low confidence
|
| 148 |
+
- Very quiet notes (velocity < 20) → likely noise → low confidence
|
| 149 |
+
- Normal duration + reasonable velocity → high confidence
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
note: Note to estimate confidence for
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
Estimated confidence (0-1)
|
| 156 |
+
"""
|
| 157 |
+
confidence = 1.0
|
| 158 |
+
|
| 159 |
+
# Duration-based confidence
|
| 160 |
+
duration = note.end - note.start
|
| 161 |
+
if duration < 0.05: # < 50ms
|
| 162 |
+
confidence *= 0.3
|
| 163 |
+
elif duration < 0.1: # < 100ms
|
| 164 |
+
confidence *= 0.6
|
| 165 |
+
|
| 166 |
+
# Velocity-based confidence
|
| 167 |
+
if note.velocity < 20:
|
| 168 |
+
confidence *= 0.2
|
| 169 |
+
elif note.velocity < 40:
|
| 170 |
+
confidence *= 0.5
|
| 171 |
+
|
| 172 |
+
return confidence
|
| 173 |
+
|
| 174 |
+
def _should_keep_note(
|
| 175 |
+
self,
|
| 176 |
+
note: pretty_midi.Note,
|
| 177 |
+
confidence: float
|
| 178 |
+
) -> bool:
|
| 179 |
+
"""
|
| 180 |
+
Determine whether to keep a note based on confidence and thresholds.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
note: Note to evaluate
|
| 184 |
+
confidence: Confidence score for the note
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
True if note should be kept, False otherwise
|
| 188 |
+
"""
|
| 189 |
+
# Check confidence threshold
|
| 190 |
+
if confidence < self.confidence_threshold:
|
| 191 |
+
return False
|
| 192 |
+
|
| 193 |
+
# Check velocity threshold
|
| 194 |
+
if note.velocity < self.velocity_threshold:
|
| 195 |
+
return False
|
| 196 |
+
|
| 197 |
+
# Check duration threshold
|
| 198 |
+
duration = note.end - note.start
|
| 199 |
+
if duration < self.duration_threshold:
|
| 200 |
+
return False
|
| 201 |
+
|
| 202 |
+
return True
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
if __name__ == "__main__":
|
| 206 |
+
# Test the confidence filter
|
| 207 |
+
import argparse
|
| 208 |
+
|
| 209 |
+
parser = argparse.ArgumentParser(description="Test Confidence Filter")
|
| 210 |
+
parser.add_argument("midi_file", type=str, help="Path to MIDI file")
|
| 211 |
+
parser.add_argument("--output", type=str, default=None,
|
| 212 |
+
help="Output MIDI file path")
|
| 213 |
+
parser.add_argument("--confidence-threshold", type=float, default=0.3,
|
| 214 |
+
help="Confidence threshold (0-1)")
|
| 215 |
+
parser.add_argument("--velocity-threshold", type=int, default=20,
|
| 216 |
+
help="Velocity threshold (0-127)")
|
| 217 |
+
parser.add_argument("--duration-threshold", type=float, default=0.05,
|
| 218 |
+
help="Duration threshold in seconds")
|
| 219 |
+
args = parser.parse_args()
|
| 220 |
+
|
| 221 |
+
filter = ConfidenceFilter(
|
| 222 |
+
confidence_threshold=args.confidence_threshold,
|
| 223 |
+
velocity_threshold=args.velocity_threshold,
|
| 224 |
+
duration_threshold=args.duration_threshold
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
midi_path = Path(args.midi_file)
|
| 228 |
+
output_path = Path(args.output) if args.output else None
|
| 229 |
+
|
| 230 |
+
# Filter MIDI
|
| 231 |
+
filtered_path = filter.filter_midi_by_confidence(
|
| 232 |
+
midi_path,
|
| 233 |
+
confidence_scores=None, # Use heuristics
|
| 234 |
+
output_path=output_path
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
print(f"\n✓ Filtered MIDI saved: {filtered_path}")
|
backend/ensemble_transcriber.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Ensemble Transcription Module
|
| 3 |
+
|
| 4 |
+
Combines multiple transcription models via voting for improved accuracy.
|
| 5 |
+
|
| 6 |
+
Ensemble Strategy:
|
| 7 |
+
- YourMT3+: Multi-instrument generalist, excellent polyphony & expressive timing (80-85% F1)
|
| 8 |
+
- ByteDance: Piano specialist, high precision on piano-only audio (90-95% F1)
|
| 9 |
+
- Combined: Voting reduces false positives and false negatives (90-95% F1 expected)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import List, Dict, Optional, Literal
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
import numpy as np
|
| 16 |
+
from mido import MidiFile, MidiTrack, Message, MetaMessage
|
| 17 |
+
import pretty_midi
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class Note:
|
| 22 |
+
"""Musical note with timing and pitch information."""
|
| 23 |
+
pitch: int # MIDI pitch (0-127)
|
| 24 |
+
onset: float # Start time in seconds
|
| 25 |
+
offset: float # End time in seconds
|
| 26 |
+
velocity: int = 64 # Note velocity (0-127)
|
| 27 |
+
confidence: float = 1.0 # Confidence score (0-1)
|
| 28 |
+
|
| 29 |
+
@property
|
| 30 |
+
def duration(self) -> float:
|
| 31 |
+
"""Note duration in seconds."""
|
| 32 |
+
return self.offset - self.onset
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class EnsembleTranscriber:
|
| 36 |
+
"""
|
| 37 |
+
Ensemble transcription using multiple models with voting.
|
| 38 |
+
|
| 39 |
+
Voting Strategies:
|
| 40 |
+
1. 'weighted': Sum confidence scores, keep notes above threshold
|
| 41 |
+
2. 'intersection': Only keep notes agreed upon by all models (high precision)
|
| 42 |
+
3. 'union': Keep all notes from all models (high recall)
|
| 43 |
+
4. 'majority': Keep notes predicted by >=50% of models
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
yourmt3_transcriber,
|
| 49 |
+
bytedance_transcriber,
|
| 50 |
+
voting_strategy: Literal['weighted', 'intersection', 'union', 'majority'] = 'weighted',
|
| 51 |
+
onset_tolerance_ms: int = 50,
|
| 52 |
+
confidence_threshold: float = 0.6
|
| 53 |
+
):
|
| 54 |
+
"""
|
| 55 |
+
Initialize ensemble transcriber.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
yourmt3_transcriber: YourMT3Transcriber instance
|
| 59 |
+
bytedance_transcriber: ByteDanceTranscriber instance
|
| 60 |
+
voting_strategy: How to combine predictions
|
| 61 |
+
onset_tolerance_ms: Time window for matching notes (milliseconds)
|
| 62 |
+
confidence_threshold: Minimum confidence for 'weighted' strategy
|
| 63 |
+
"""
|
| 64 |
+
self.yourmt3 = yourmt3_transcriber
|
| 65 |
+
self.bytedance = bytedance_transcriber
|
| 66 |
+
self.voting_strategy = voting_strategy
|
| 67 |
+
self.onset_tolerance = onset_tolerance_ms / 1000.0 # Convert to seconds
|
| 68 |
+
self.confidence_threshold = confidence_threshold
|
| 69 |
+
|
| 70 |
+
def transcribe(
|
| 71 |
+
self,
|
| 72 |
+
audio_path: Path,
|
| 73 |
+
output_dir: Optional[Path] = None
|
| 74 |
+
) -> Path:
|
| 75 |
+
"""
|
| 76 |
+
Transcribe audio using ensemble of models.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
audio_path: Path to audio file (should be piano stem)
|
| 80 |
+
output_dir: Directory for output MIDI file
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
Path to ensemble MIDI file
|
| 84 |
+
"""
|
| 85 |
+
if output_dir is None:
|
| 86 |
+
output_dir = audio_path.parent
|
| 87 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 88 |
+
|
| 89 |
+
print(f"\n ═══ Ensemble Transcription ═══")
|
| 90 |
+
print(f" Strategy: {self.voting_strategy}")
|
| 91 |
+
print(f" Onset tolerance: {self.onset_tolerance*1000:.0f}ms")
|
| 92 |
+
|
| 93 |
+
# Transcribe with YourMT3+
|
| 94 |
+
print(f"\n [1/2] Transcribing with YourMT3+...")
|
| 95 |
+
yourmt3_midi = self.yourmt3.transcribe_audio(audio_path, output_dir)
|
| 96 |
+
yourmt3_notes = self._extract_notes_from_midi(yourmt3_midi)
|
| 97 |
+
print(f" ✓ YourMT3+ found {len(yourmt3_notes)} notes")
|
| 98 |
+
|
| 99 |
+
# Transcribe with ByteDance
|
| 100 |
+
print(f"\n [2/2] Transcribing with ByteDance...")
|
| 101 |
+
bytedance_midi = self.bytedance.transcribe_audio(audio_path, output_dir)
|
| 102 |
+
bytedance_notes = self._extract_notes_from_midi(bytedance_midi)
|
| 103 |
+
print(f" ✓ ByteDance found {len(bytedance_notes)} notes")
|
| 104 |
+
|
| 105 |
+
# Vote and merge
|
| 106 |
+
print(f"\n Voting with '{self.voting_strategy}' strategy...")
|
| 107 |
+
ensemble_notes = self._vote_notes(
|
| 108 |
+
[yourmt3_notes, bytedance_notes],
|
| 109 |
+
model_names=['YourMT3+', 'ByteDance']
|
| 110 |
+
)
|
| 111 |
+
print(f" ✓ Ensemble result: {len(ensemble_notes)} notes")
|
| 112 |
+
|
| 113 |
+
# Convert to MIDI
|
| 114 |
+
ensemble_midi_path = output_dir / f"{audio_path.stem}_ensemble.mid"
|
| 115 |
+
self._notes_to_midi(ensemble_notes, ensemble_midi_path)
|
| 116 |
+
|
| 117 |
+
print(f" ✓ Ensemble MIDI saved: {ensemble_midi_path.name}")
|
| 118 |
+
print(f" ═══════════════════════════════\n")
|
| 119 |
+
|
| 120 |
+
return ensemble_midi_path
|
| 121 |
+
|
| 122 |
+
def _extract_notes_from_midi(self, midi_path: Path) -> List[Note]:
|
| 123 |
+
"""
|
| 124 |
+
Extract notes from MIDI file.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
midi_path: Path to MIDI file
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
List of Note objects
|
| 131 |
+
"""
|
| 132 |
+
pm = pretty_midi.PrettyMIDI(str(midi_path))
|
| 133 |
+
|
| 134 |
+
notes = []
|
| 135 |
+
for instrument in pm.instruments:
|
| 136 |
+
if instrument.is_drum:
|
| 137 |
+
continue
|
| 138 |
+
|
| 139 |
+
for note in instrument.notes:
|
| 140 |
+
notes.append(Note(
|
| 141 |
+
pitch=note.pitch,
|
| 142 |
+
onset=note.start,
|
| 143 |
+
offset=note.end,
|
| 144 |
+
velocity=note.velocity,
|
| 145 |
+
confidence=1.0 # Default confidence (TODO: extract from model if available)
|
| 146 |
+
))
|
| 147 |
+
|
| 148 |
+
# Sort by onset time
|
| 149 |
+
notes.sort(key=lambda n: n.onset)
|
| 150 |
+
return notes
|
| 151 |
+
|
| 152 |
+
def _vote_notes(
|
| 153 |
+
self,
|
| 154 |
+
note_lists: List[List[Note]],
|
| 155 |
+
model_names: List[str]
|
| 156 |
+
) -> List[Note]:
|
| 157 |
+
"""
|
| 158 |
+
Vote on notes from multiple models.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
note_lists: List of note lists from different models
|
| 162 |
+
model_names: Names of models (for logging)
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
Merged list of notes after voting
|
| 166 |
+
"""
|
| 167 |
+
if self.voting_strategy == 'weighted':
|
| 168 |
+
return self._vote_weighted(note_lists, model_names)
|
| 169 |
+
elif self.voting_strategy == 'intersection':
|
| 170 |
+
return self._vote_intersection(note_lists, model_names)
|
| 171 |
+
elif self.voting_strategy == 'union':
|
| 172 |
+
return self._vote_union(note_lists, model_names)
|
| 173 |
+
elif self.voting_strategy == 'majority':
|
| 174 |
+
return self._vote_majority(note_lists, model_names)
|
| 175 |
+
else:
|
| 176 |
+
raise ValueError(f"Unknown voting strategy: {self.voting_strategy}")
|
| 177 |
+
|
| 178 |
+
def _vote_weighted(
|
| 179 |
+
self,
|
| 180 |
+
note_lists: List[List[Note]],
|
| 181 |
+
model_names: List[str]
|
| 182 |
+
) -> List[Note]:
|
| 183 |
+
"""
|
| 184 |
+
Weighted voting: Sum confidence scores, keep notes above threshold.
|
| 185 |
+
|
| 186 |
+
Gives higher weight to ByteDance (piano specialist).
|
| 187 |
+
"""
|
| 188 |
+
# Model weights (ByteDance is more accurate for piano)
|
| 189 |
+
weights = {'YourMT3+': 0.4, 'ByteDance': 0.6}
|
| 190 |
+
|
| 191 |
+
# Group notes by (onset_bucket, pitch)
|
| 192 |
+
note_groups = {}
|
| 193 |
+
|
| 194 |
+
for model_idx, notes in enumerate(note_lists):
|
| 195 |
+
model_name = model_names[model_idx]
|
| 196 |
+
weight = weights.get(model_name, 1.0 / len(note_lists))
|
| 197 |
+
|
| 198 |
+
for note in notes:
|
| 199 |
+
# Quantize onset to tolerance bucket
|
| 200 |
+
onset_bucket = round(note.onset / self.onset_tolerance)
|
| 201 |
+
key = (onset_bucket, note.pitch)
|
| 202 |
+
|
| 203 |
+
if key not in note_groups:
|
| 204 |
+
note_groups[key] = []
|
| 205 |
+
|
| 206 |
+
# Add note with weighted confidence
|
| 207 |
+
note.confidence *= weight
|
| 208 |
+
note_groups[key].append(note)
|
| 209 |
+
|
| 210 |
+
# Merge notes in each group
|
| 211 |
+
merged_notes = []
|
| 212 |
+
for (onset_bucket, pitch), group in note_groups.items():
|
| 213 |
+
# Sum confidence across models
|
| 214 |
+
total_confidence = sum(n.confidence for n in group)
|
| 215 |
+
|
| 216 |
+
if total_confidence >= self.confidence_threshold:
|
| 217 |
+
# Use average timing and velocity
|
| 218 |
+
avg_onset = np.mean([n.onset for n in group])
|
| 219 |
+
avg_offset = np.mean([n.offset for n in group])
|
| 220 |
+
avg_velocity = int(np.mean([n.velocity for n in group]))
|
| 221 |
+
|
| 222 |
+
merged_notes.append(Note(
|
| 223 |
+
pitch=pitch,
|
| 224 |
+
onset=avg_onset,
|
| 225 |
+
offset=avg_offset,
|
| 226 |
+
velocity=avg_velocity,
|
| 227 |
+
confidence=total_confidence
|
| 228 |
+
))
|
| 229 |
+
|
| 230 |
+
merged_notes.sort(key=lambda n: n.onset)
|
| 231 |
+
return merged_notes
|
| 232 |
+
|
| 233 |
+
def _vote_intersection(
|
| 234 |
+
self,
|
| 235 |
+
note_lists: List[List[Note]],
|
| 236 |
+
model_names: List[str]
|
| 237 |
+
) -> List[Note]:
|
| 238 |
+
"""
|
| 239 |
+
Intersection voting: Only keep notes agreed upon by ALL models.
|
| 240 |
+
High precision, potentially lower recall.
|
| 241 |
+
"""
|
| 242 |
+
if len(note_lists) == 0:
|
| 243 |
+
return []
|
| 244 |
+
|
| 245 |
+
# Start with first model's notes
|
| 246 |
+
base_notes = note_lists[0]
|
| 247 |
+
matched_notes = []
|
| 248 |
+
|
| 249 |
+
for base_note in base_notes:
|
| 250 |
+
# Check if this note appears in ALL other models
|
| 251 |
+
found_in_all = True
|
| 252 |
+
|
| 253 |
+
for other_notes in note_lists[1:]:
|
| 254 |
+
if not self._find_matching_note(base_note, other_notes):
|
| 255 |
+
found_in_all = False
|
| 256 |
+
break
|
| 257 |
+
|
| 258 |
+
if found_in_all:
|
| 259 |
+
matched_notes.append(base_note)
|
| 260 |
+
|
| 261 |
+
return matched_notes
|
| 262 |
+
|
| 263 |
+
def _vote_union(
|
| 264 |
+
self,
|
| 265 |
+
note_lists: List[List[Note]],
|
| 266 |
+
model_names: List[str]
|
| 267 |
+
) -> List[Note]:
|
| 268 |
+
"""
|
| 269 |
+
Union voting: Keep ALL notes from ALL models.
|
| 270 |
+
High recall, potentially more false positives.
|
| 271 |
+
"""
|
| 272 |
+
# Combine all notes
|
| 273 |
+
all_notes = []
|
| 274 |
+
for notes in note_lists:
|
| 275 |
+
all_notes.extend(notes)
|
| 276 |
+
|
| 277 |
+
# Deduplicate: group similar notes and average them
|
| 278 |
+
note_groups = {}
|
| 279 |
+
|
| 280 |
+
for note in all_notes:
|
| 281 |
+
onset_bucket = round(note.onset / self.onset_tolerance)
|
| 282 |
+
key = (onset_bucket, note.pitch)
|
| 283 |
+
|
| 284 |
+
if key not in note_groups:
|
| 285 |
+
note_groups[key] = []
|
| 286 |
+
note_groups[key].append(note)
|
| 287 |
+
|
| 288 |
+
# Average duplicates
|
| 289 |
+
merged_notes = []
|
| 290 |
+
for (onset_bucket, pitch), group in note_groups.items():
|
| 291 |
+
avg_onset = np.mean([n.onset for n in group])
|
| 292 |
+
avg_offset = np.mean([n.offset for n in group])
|
| 293 |
+
avg_velocity = int(np.mean([n.velocity for n in group]))
|
| 294 |
+
|
| 295 |
+
merged_notes.append(Note(
|
| 296 |
+
pitch=pitch,
|
| 297 |
+
onset=avg_onset,
|
| 298 |
+
offset=avg_offset,
|
| 299 |
+
velocity=avg_velocity,
|
| 300 |
+
confidence=len(group) / len(note_lists) # Confidence = agreement ratio
|
| 301 |
+
))
|
| 302 |
+
|
| 303 |
+
merged_notes.sort(key=lambda n: n.onset)
|
| 304 |
+
return merged_notes
|
| 305 |
+
|
| 306 |
+
def _vote_majority(
|
| 307 |
+
self,
|
| 308 |
+
note_lists: List[List[Note]],
|
| 309 |
+
model_names: List[str]
|
| 310 |
+
) -> List[Note]:
|
| 311 |
+
"""
|
| 312 |
+
Majority voting: Keep notes predicted by >=50% of models.
|
| 313 |
+
Balanced precision and recall.
|
| 314 |
+
"""
|
| 315 |
+
threshold = len(note_lists) / 2.0
|
| 316 |
+
|
| 317 |
+
# Group notes by (onset_bucket, pitch)
|
| 318 |
+
note_groups = {}
|
| 319 |
+
|
| 320 |
+
for notes in note_lists:
|
| 321 |
+
for note in notes:
|
| 322 |
+
onset_bucket = round(note.onset / self.onset_tolerance)
|
| 323 |
+
key = (onset_bucket, note.pitch)
|
| 324 |
+
|
| 325 |
+
if key not in note_groups:
|
| 326 |
+
note_groups[key] = []
|
| 327 |
+
note_groups[key].append(note)
|
| 328 |
+
|
| 329 |
+
# Keep notes with majority agreement
|
| 330 |
+
merged_notes = []
|
| 331 |
+
for (onset_bucket, pitch), group in note_groups.items():
|
| 332 |
+
if len(group) >= threshold:
|
| 333 |
+
avg_onset = np.mean([n.onset for n in group])
|
| 334 |
+
avg_offset = np.mean([n.offset for n in group])
|
| 335 |
+
avg_velocity = int(np.mean([n.velocity for n in group]))
|
| 336 |
+
|
| 337 |
+
merged_notes.append(Note(
|
| 338 |
+
pitch=pitch,
|
| 339 |
+
onset=avg_onset,
|
| 340 |
+
offset=avg_offset,
|
| 341 |
+
velocity=avg_velocity,
|
| 342 |
+
confidence=len(group) / len(note_lists)
|
| 343 |
+
))
|
| 344 |
+
|
| 345 |
+
merged_notes.sort(key=lambda n: n.onset)
|
| 346 |
+
return merged_notes
|
| 347 |
+
|
| 348 |
+
def _find_matching_note(self, target: Note, notes: List[Note]) -> Optional[Note]:
|
| 349 |
+
"""Find a note that matches the target note within tolerance."""
|
| 350 |
+
for note in notes:
|
| 351 |
+
if (note.pitch == target.pitch and
|
| 352 |
+
abs(note.onset - target.onset) <= self.onset_tolerance):
|
| 353 |
+
return note
|
| 354 |
+
return None
|
| 355 |
+
|
| 356 |
+
def _notes_to_midi(self, notes: List[Note], output_path: Path):
|
| 357 |
+
"""
|
| 358 |
+
Convert list of notes to MIDI file.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
notes: List of Note objects
|
| 362 |
+
output_path: Path for output MIDI file
|
| 363 |
+
"""
|
| 364 |
+
# Create MIDI file
|
| 365 |
+
mid = MidiFile()
|
| 366 |
+
track = MidiTrack()
|
| 367 |
+
mid.tracks.append(track)
|
| 368 |
+
|
| 369 |
+
# Add tempo (120 BPM default)
|
| 370 |
+
track.append(MetaMessage('set_tempo', tempo=500000, time=0))
|
| 371 |
+
|
| 372 |
+
# Convert notes to MIDI messages
|
| 373 |
+
# (simplified - assumes single instrument, no overlapping notes with same pitch)
|
| 374 |
+
|
| 375 |
+
# Use absolute timing, then convert to delta
|
| 376 |
+
events = []
|
| 377 |
+
|
| 378 |
+
for note in notes:
|
| 379 |
+
# Convert seconds to ticks (480 ticks per beat, 120 BPM)
|
| 380 |
+
ticks_per_second = 480 * 2 # 480 ticks/beat * 2 beats/second at 120 BPM
|
| 381 |
+
onset_ticks = int(note.onset * ticks_per_second)
|
| 382 |
+
offset_ticks = int(note.offset * ticks_per_second)
|
| 383 |
+
|
| 384 |
+
events.append((onset_ticks, 'note_on', note.pitch, note.velocity))
|
| 385 |
+
events.append((offset_ticks, 'note_off', note.pitch, 0))
|
| 386 |
+
|
| 387 |
+
# Sort by time
|
| 388 |
+
events.sort(key=lambda e: e[0])
|
| 389 |
+
|
| 390 |
+
# Convert to delta time and add to track
|
| 391 |
+
previous_time = 0
|
| 392 |
+
for abs_time, msg_type, pitch, velocity in events:
|
| 393 |
+
delta_time = abs_time - previous_time
|
| 394 |
+
previous_time = abs_time
|
| 395 |
+
|
| 396 |
+
track.append(Message(
|
| 397 |
+
msg_type,
|
| 398 |
+
note=pitch,
|
| 399 |
+
velocity=velocity,
|
| 400 |
+
time=delta_time
|
| 401 |
+
))
|
| 402 |
+
|
| 403 |
+
# Add end of track
|
| 404 |
+
track.append(MetaMessage('end_of_track', time=0))
|
| 405 |
+
|
| 406 |
+
# Save
|
| 407 |
+
mid.save(output_path)
|
backend/evaluation/CLUSTER_SETUP.md
DELETED
|
@@ -1,262 +0,0 @@
|
|
| 1 |
-
# Cluster Benchmark Setup
|
| 2 |
-
|
| 3 |
-
Run transcription benchmarks on a SLURM cluster with GPU acceleration.
|
| 4 |
-
|
| 5 |
-
## Directory Structure
|
| 6 |
-
|
| 7 |
-
```
|
| 8 |
-
/cluster/path/
|
| 9 |
-
├── data/
|
| 10 |
-
│ └── maestro-v3.0.0/ # MAESTRO dataset (120GB)
|
| 11 |
-
│ ├── 2004/
|
| 12 |
-
│ │ ├── *.wav # Audio files
|
| 13 |
-
│ │ └── *.midi # Ground truth MIDI
|
| 14 |
-
│ ├── 2006/
|
| 15 |
-
│ ├── 2008/
|
| 16 |
-
│ └── ...
|
| 17 |
-
└── rescored/ # Git repo
|
| 18 |
-
└── backend/
|
| 19 |
-
├── evaluation/
|
| 20 |
-
│ ├── slurm_benchmark.sh
|
| 21 |
-
│ └── ...
|
| 22 |
-
└── requirements.txt
|
| 23 |
-
```
|
| 24 |
-
|
| 25 |
-
## One-Time Setup
|
| 26 |
-
|
| 27 |
-
### 1. Download MAESTRO Dataset (on cluster)
|
| 28 |
-
|
| 29 |
-
```bash
|
| 30 |
-
# Navigate to data directory
|
| 31 |
-
cd /cluster/path/data/
|
| 32 |
-
|
| 33 |
-
# Download MAESTRO v3.0.0 (120GB - will take a while)
|
| 34 |
-
wget https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.zip
|
| 35 |
-
|
| 36 |
-
# Extract (creates maestro-v3.0.0/ directory)
|
| 37 |
-
unzip maestro-v3.0.0.zip
|
| 38 |
-
|
| 39 |
-
# Verify structure
|
| 40 |
-
ls maestro-v3.0.0/
|
| 41 |
-
# Should see: 2004/ 2006/ 2008/ 2009/ 2011/ 2013/ 2014/ 2015/ 2017/ 2018/
|
| 42 |
-
|
| 43 |
-
# Optional: Remove zip to save space
|
| 44 |
-
rm maestro-v3.0.0.zip
|
| 45 |
-
```
|
| 46 |
-
|
| 47 |
-
### 2. Clone Repository (on cluster)
|
| 48 |
-
|
| 49 |
-
```bash
|
| 50 |
-
cd /cluster/path/
|
| 51 |
-
|
| 52 |
-
# Clone your repo
|
| 53 |
-
git clone <your-repo-url> rescored
|
| 54 |
-
|
| 55 |
-
# Navigate to backend
|
| 56 |
-
cd rescored/backend/
|
| 57 |
-
|
| 58 |
-
# Make benchmark script executable
|
| 59 |
-
chmod +x evaluation/slurm_benchmark.sh
|
| 60 |
-
```
|
| 61 |
-
|
| 62 |
-
### 3. Create Virtual Environment (on cluster)
|
| 63 |
-
|
| 64 |
-
```bash
|
| 65 |
-
# In rescored/backend/
|
| 66 |
-
python3.10 -m venv .venv
|
| 67 |
-
source .venv/bin/activate
|
| 68 |
-
|
| 69 |
-
# Install dependencies
|
| 70 |
-
pip install --upgrade pip
|
| 71 |
-
|
| 72 |
-
# Install Cython first (required by madmom)
|
| 73 |
-
pip install Cython
|
| 74 |
-
|
| 75 |
-
# Install madmom separately to avoid build isolation issues
|
| 76 |
-
pip install --no-build-isolation madmom>=0.16.1
|
| 77 |
-
|
| 78 |
-
# Uninstall problematic packages if previously installed
|
| 79 |
-
pip uninstall -y torchcodec torchaudio
|
| 80 |
-
|
| 81 |
-
# Install remaining dependencies (includes torchaudio 2.1.0 which uses SoundFile backend)
|
| 82 |
-
pip install -r requirements.txt
|
| 83 |
-
```
|
| 84 |
-
|
| 85 |
-
**Note**: The SLURM script will automatically load FFmpeg module, which is required by Demucs for audio loading. If running manually, load it with `module load ffmpeg`.
|
| 86 |
-
|
| 87 |
-
## Running Benchmarks
|
| 88 |
-
|
| 89 |
-
### Baseline Benchmark (YourMT3+)
|
| 90 |
-
|
| 91 |
-
```bash
|
| 92 |
-
# In rescored/backend/
|
| 93 |
-
sbatch evaluation/slurm_benchmark.sh
|
| 94 |
-
|
| 95 |
-
# With custom paths:
|
| 96 |
-
sbatch evaluation/slurm_benchmark.sh \
|
| 97 |
-
../data/maestro-v3.0.0 \ # MAESTRO dataset path
|
| 98 |
-
yourmt3 \ # Model to benchmark
|
| 99 |
-
. # Repo directory (current)
|
| 100 |
-
```
|
| 101 |
-
|
| 102 |
-
### Check Job Status
|
| 103 |
-
|
| 104 |
-
```bash
|
| 105 |
-
# View job queue
|
| 106 |
-
squeue -u $USER
|
| 107 |
-
|
| 108 |
-
# View running job output (live)
|
| 109 |
-
tail -f logs/slurm/benchmark_<JOB_ID>.log
|
| 110 |
-
|
| 111 |
-
# View all output after completion
|
| 112 |
-
cat logs/slurm/benchmark_<JOB_ID>.log
|
| 113 |
-
```
|
| 114 |
-
|
| 115 |
-
### Download Results to Local Machine
|
| 116 |
-
|
| 117 |
-
After job completes:
|
| 118 |
-
|
| 119 |
-
```bash
|
| 120 |
-
# From your local machine
|
| 121 |
-
scp <cluster>:/cluster/path/rescored/backend/evaluation/results/yourmt3_results.* .
|
| 122 |
-
|
| 123 |
-
# Or download entire results directory
|
| 124 |
-
scp -r <cluster>:/cluster/path/rescored/backend/evaluation/results/ .
|
| 125 |
-
```
|
| 126 |
-
|
| 127 |
-
## Benchmark Workflow
|
| 128 |
-
|
| 129 |
-
The SLURM script automatically:
|
| 130 |
-
|
| 131 |
-
1. **Validates MAESTRO dataset** exists and has correct structure
|
| 132 |
-
2. **Prepares test cases** - Extracts 8 curated pieces (easy/medium/hard)
|
| 133 |
-
3. **Runs transcription** - Processes each test case through pipeline:
|
| 134 |
-
- YouTube audio download → Demucs separation → YourMT3+ transcription
|
| 135 |
-
4. **Calculates metrics** - F1, precision, recall, onset MAE
|
| 136 |
-
5. **Saves results** - JSON + CSV format
|
| 137 |
-
|
| 138 |
-
## Expected Timeline
|
| 139 |
-
|
| 140 |
-
With **L40 GPU**:
|
| 141 |
-
- MAESTRO download: ~30-60 min (one-time)
|
| 142 |
-
- Test case preparation: ~1 min
|
| 143 |
-
- Benchmark (8 test cases): ~8-12 hours
|
| 144 |
-
- Per test case: ~60-90 min (includes Demucs + YourMT3+)
|
| 145 |
-
|
| 146 |
-
With **CPU only** (no GPU):
|
| 147 |
-
- Benchmark would take ~24-48 hours (not recommended)
|
| 148 |
-
|
| 149 |
-
## Output Files
|
| 150 |
-
|
| 151 |
-
After successful run:
|
| 152 |
-
|
| 153 |
-
```
|
| 154 |
-
evaluation/
|
| 155 |
-
├── test_videos.json # Test case metadata (8 pieces)
|
| 156 |
-
├── results/
|
| 157 |
-
│ ├── yourmt3_results.json # Detailed results (F1, precision, recall)
|
| 158 |
-
│ └── yourmt3_results.csv # Same data in CSV format
|
| 159 |
-
└── logs/
|
| 160 |
-
└── slurm/
|
| 161 |
-
└── benchmark_<JOB_ID>.log # Full execution log
|
| 162 |
-
```
|
| 163 |
-
|
| 164 |
-
### Results Format (JSON)
|
| 165 |
-
|
| 166 |
-
```json
|
| 167 |
-
[
|
| 168 |
-
{
|
| 169 |
-
"test_case": "MAESTRO_2004_Track03",
|
| 170 |
-
"genre": "classical",
|
| 171 |
-
"difficulty": "easy",
|
| 172 |
-
"f1_score": 0.892,
|
| 173 |
-
"precision": 0.871,
|
| 174 |
-
"recall": 0.914,
|
| 175 |
-
"onset_mae": 0.0382,
|
| 176 |
-
"pitch_accuracy": 0.987,
|
| 177 |
-
"processing_time": 127.3,
|
| 178 |
-
"success": true
|
| 179 |
-
}
|
| 180 |
-
]
|
| 181 |
-
```
|
| 182 |
-
|
| 183 |
-
## Benchmarking Multiple Models
|
| 184 |
-
|
| 185 |
-
After implementing ByteDance (Phase 2) or ensemble (Phase 3):
|
| 186 |
-
|
| 187 |
-
```bash
|
| 188 |
-
# Benchmark ByteDance
|
| 189 |
-
sbatch evaluation/slurm_benchmark.sh ../data/maestro-v3.0.0 bytedance
|
| 190 |
-
|
| 191 |
-
# Benchmark Ensemble
|
| 192 |
-
sbatch evaluation/slurm_benchmark.sh ../data/maestro-v3.0.0 ensemble
|
| 193 |
-
```
|
| 194 |
-
|
| 195 |
-
## Troubleshooting
|
| 196 |
-
|
| 197 |
-
### Job Fails with "MAESTRO dataset not found"
|
| 198 |
-
|
| 199 |
-
```bash
|
| 200 |
-
# Check if dataset exists
|
| 201 |
-
ls /cluster/path/data/maestro-v3.0.0/
|
| 202 |
-
|
| 203 |
-
# If missing, download following "One-Time Setup" instructions
|
| 204 |
-
```
|
| 205 |
-
|
| 206 |
-
### Job Fails with "Module not found"
|
| 207 |
-
|
| 208 |
-
```bash
|
| 209 |
-
# Reinstall dependencies in venv
|
| 210 |
-
cd /cluster/path/rescored/backend/
|
| 211 |
-
source .venv/bin/activate
|
| 212 |
-
pip install -r requirements.txt
|
| 213 |
-
```
|
| 214 |
-
|
| 215 |
-
### GPU Out of Memory
|
| 216 |
-
|
| 217 |
-
SLURM script already uses conservative settings:
|
| 218 |
-
- 1 GPU (L40 with 48GB VRAM)
|
| 219 |
-
- 32GB system RAM
|
| 220 |
-
- Demucs uses mixed precision
|
| 221 |
-
|
| 222 |
-
If still failing, check:
|
| 223 |
-
```bash
|
| 224 |
-
# View GPU usage during job
|
| 225 |
-
ssh <node-from-squeue>
|
| 226 |
-
nvidia-smi
|
| 227 |
-
```
|
| 228 |
-
|
| 229 |
-
### Job Times Out
|
| 230 |
-
|
| 231 |
-
Increase time limit in script:
|
| 232 |
-
```bash
|
| 233 |
-
# Edit slurm_benchmark.sh
|
| 234 |
-
#SBATCH --time=1-00:00:00 # Change to 24 hours
|
| 235 |
-
```
|
| 236 |
-
|
| 237 |
-
## Next Steps After Baseline
|
| 238 |
-
|
| 239 |
-
1. **Analyze results** - Download JSON/CSV, review F1 scores
|
| 240 |
-
2. **Implement Phase 2** - ByteDance integration (locally)
|
| 241 |
-
3. **Re-benchmark** - Run ByteDance benchmark on cluster
|
| 242 |
-
4. **Compare** - YourMT3+ vs ByteDance F1 scores
|
| 243 |
-
5. **Iterate** - Continue through Phases 3-5
|
| 244 |
-
|
| 245 |
-
## Quick Reference
|
| 246 |
-
|
| 247 |
-
```bash
|
| 248 |
-
# Submit job
|
| 249 |
-
sbatch evaluation/slurm_benchmark.sh
|
| 250 |
-
|
| 251 |
-
# Check status
|
| 252 |
-
squeue -u $USER
|
| 253 |
-
|
| 254 |
-
# Cancel job
|
| 255 |
-
scancel <JOB_ID>
|
| 256 |
-
|
| 257 |
-
# View live logs
|
| 258 |
-
tail -f logs/slurm/benchmark_<JOB_ID>.log
|
| 259 |
-
|
| 260 |
-
# Download results
|
| 261 |
-
scp <cluster>:/cluster/path/rescored/backend/evaluation/results/* .
|
| 262 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/evaluation/README.md
DELETED
|
@@ -1,251 +0,0 @@
|
|
| 1 |
-
# Transcription Evaluation Module
|
| 2 |
-
|
| 3 |
-
Benchmarking infrastructure for measuring piano transcription accuracy.
|
| 4 |
-
|
| 5 |
-
## Overview
|
| 6 |
-
|
| 7 |
-
This module provides tools to:
|
| 8 |
-
- Calculate F1 score, precision, recall for MIDI transcription
|
| 9 |
-
- Benchmark models on MAESTRO dataset
|
| 10 |
-
- Track accuracy improvements across development phases
|
| 11 |
-
- Generate detailed reports for analysis
|
| 12 |
-
|
| 13 |
-
## Quick Start
|
| 14 |
-
|
| 15 |
-
### 1. Install Dependencies
|
| 16 |
-
|
| 17 |
-
```bash
|
| 18 |
-
cd backend
|
| 19 |
-
pip install -r requirements.txt
|
| 20 |
-
```
|
| 21 |
-
|
| 22 |
-
### 2. Prepare Test Cases (Option A: MAESTRO Dataset)
|
| 23 |
-
|
| 24 |
-
Download MAESTRO v3.0.0 from https://magenta.tensorflow.org/datasets/maestro
|
| 25 |
-
|
| 26 |
-
```bash
|
| 27 |
-
# Extract dataset to /tmp/maestro-v3.0.0/
|
| 28 |
-
# Then prepare test cases:
|
| 29 |
-
python -m evaluation.prepare_maestro \
|
| 30 |
-
--maestro-dir /tmp/maestro-v3.0.0 \
|
| 31 |
-
--output-json evaluation/test_videos.json
|
| 32 |
-
```
|
| 33 |
-
|
| 34 |
-
### 2. Prepare Test Cases (Option B: Custom Videos)
|
| 35 |
-
|
| 36 |
-
Create `evaluation/test_videos.json` manually:
|
| 37 |
-
|
| 38 |
-
```json
|
| 39 |
-
[
|
| 40 |
-
{
|
| 41 |
-
"name": "Simple Piano Melody",
|
| 42 |
-
"audio_path": "/path/to/audio.wav",
|
| 43 |
-
"ground_truth_midi": "/path/to/ground_truth.mid",
|
| 44 |
-
"genre": "classical",
|
| 45 |
-
"difficulty": "easy"
|
| 46 |
-
}
|
| 47 |
-
]
|
| 48 |
-
```
|
| 49 |
-
|
| 50 |
-
### 3. Run Baseline Benchmark
|
| 51 |
-
|
| 52 |
-
```bash
|
| 53 |
-
# Benchmark current YourMT3+ model
|
| 54 |
-
python -m evaluation.run_benchmark \
|
| 55 |
-
--model yourmt3 \
|
| 56 |
-
--test-cases evaluation/test_videos.json \
|
| 57 |
-
--output-dir evaluation/results
|
| 58 |
-
```
|
| 59 |
-
|
| 60 |
-
### 4. View Results
|
| 61 |
-
|
| 62 |
-
Results are saved in two formats:
|
| 63 |
-
- **JSON**: `evaluation/results/yourmt3_results.json` (detailed)
|
| 64 |
-
- **CSV**: `evaluation/results/yourmt3_results.csv` (for spreadsheets)
|
| 65 |
-
|
| 66 |
-
Example output:
|
| 67 |
-
```
|
| 68 |
-
📊 BENCHMARK SUMMARY: yourmt3
|
| 69 |
-
========================================
|
| 70 |
-
Total tests: 8
|
| 71 |
-
Successful: 8
|
| 72 |
-
Failed: 0
|
| 73 |
-
|
| 74 |
-
📈 Overall Accuracy:
|
| 75 |
-
F1 Score: 0.847
|
| 76 |
-
Precision: 0.823
|
| 77 |
-
Recall: 0.872
|
| 78 |
-
Onset MAE: 38.2ms
|
| 79 |
-
Avg Processing Time: 127.3s
|
| 80 |
-
|
| 81 |
-
📊 By Genre:
|
| 82 |
-
Classical: F1=0.847 (8 tests)
|
| 83 |
-
|
| 84 |
-
📊 By Difficulty:
|
| 85 |
-
Easy: F1=0.921 (2 tests)
|
| 86 |
-
Medium: F1=0.854 (3 tests)
|
| 87 |
-
Hard: F1=0.782 (3 tests)
|
| 88 |
-
```
|
| 89 |
-
|
| 90 |
-
## Metrics Explained
|
| 91 |
-
|
| 92 |
-
### F1 Score (Primary Metric)
|
| 93 |
-
Harmonic mean of precision and recall. Balances false positives and false negatives.
|
| 94 |
-
- **Target**: ≥0.95 (95%+ accuracy)
|
| 95 |
-
- **Good**: ≥0.85 (85%+ accuracy)
|
| 96 |
-
- **Needs improvement**: <0.80
|
| 97 |
-
|
| 98 |
-
### Precision
|
| 99 |
-
Percentage of predicted notes that are correct.
|
| 100 |
-
- High precision = few false positives (wrong notes)
|
| 101 |
-
|
| 102 |
-
### Recall
|
| 103 |
-
Percentage of ground truth notes that were detected.
|
| 104 |
-
- High recall = few false negatives (missed notes)
|
| 105 |
-
|
| 106 |
-
### Onset MAE (Mean Absolute Error)
|
| 107 |
-
Average timing error for note onsets, in milliseconds.
|
| 108 |
-
- **Excellent**: <30ms
|
| 109 |
-
- **Good**: <50ms
|
| 110 |
-
- **Acceptable**: <100ms
|
| 111 |
-
|
| 112 |
-
### Pitch Accuracy
|
| 113 |
-
Percentage of matched notes with correct pitch.
|
| 114 |
-
- Should be close to 100% if onset matching is working
|
| 115 |
-
|
| 116 |
-
## Benchmarking Workflow
|
| 117 |
-
|
| 118 |
-
### Phase 1: Baseline (Current)
|
| 119 |
-
```bash
|
| 120 |
-
python -m evaluation.run_benchmark --model yourmt3
|
| 121 |
-
```
|
| 122 |
-
|
| 123 |
-
Expected: F1 ~0.80-0.85 (current YourMT3+ performance)
|
| 124 |
-
|
| 125 |
-
### Phase 2: ByteDance Integration
|
| 126 |
-
```bash
|
| 127 |
-
# After implementing ByteDance wrapper
|
| 128 |
-
python -m evaluation.run_benchmark --model bytedance
|
| 129 |
-
|
| 130 |
-
# Compare with baseline
|
| 131 |
-
python evaluation/compare_results.py yourmt3 bytedance
|
| 132 |
-
```
|
| 133 |
-
|
| 134 |
-
Expected: F1 ~0.83-0.90 (if ByteDance generalizes to YouTube audio)
|
| 135 |
-
|
| 136 |
-
### Phase 3: Ensemble
|
| 137 |
-
```bash
|
| 138 |
-
python -m evaluation.run_benchmark --model ensemble
|
| 139 |
-
```
|
| 140 |
-
|
| 141 |
-
Expected: F1 ~0.88-0.95 (ensemble voting)
|
| 142 |
-
|
| 143 |
-
### Phase 4: With Preprocessing
|
| 144 |
-
```bash
|
| 145 |
-
# Enable audio preprocessing in app_config.py
|
| 146 |
-
# Then re-run ensemble benchmark
|
| 147 |
-
```
|
| 148 |
-
|
| 149 |
-
Expected: F1 ~0.90-0.96 (preprocessing + ensemble)
|
| 150 |
-
|
| 151 |
-
## Tolerance Settings
|
| 152 |
-
|
| 153 |
-
Default onset tolerance: **50ms**
|
| 154 |
-
|
| 155 |
-
For different difficulty levels:
|
| 156 |
-
- **Strict** (20ms): Simple melodies, slow tempo
|
| 157 |
-
- **Default** (50ms): Standard evaluation
|
| 158 |
-
- **Lenient** (100ms): Fast passages, complex music
|
| 159 |
-
|
| 160 |
-
Change tolerance:
|
| 161 |
-
```bash
|
| 162 |
-
python -m evaluation.run_benchmark --model yourmt3 --onset-tolerance 0.02 # 20ms
|
| 163 |
-
```
|
| 164 |
-
|
| 165 |
-
## Test Case Structure
|
| 166 |
-
|
| 167 |
-
Each test case requires:
|
| 168 |
-
- **Audio file** (WAV, MP3, FLAC)
|
| 169 |
-
- **Ground truth MIDI** (verified transcription)
|
| 170 |
-
- **Metadata**: genre, difficulty, name
|
| 171 |
-
|
| 172 |
-
### Recommended Test Suite
|
| 173 |
-
|
| 174 |
-
Minimum 10-15 test cases:
|
| 175 |
-
- 2-3 simple melodies (easy)
|
| 176 |
-
- 5-7 classical piano pieces (medium)
|
| 177 |
-
- 3-5 complex/fast passages (hard)
|
| 178 |
-
- Mix of genres: classical, pop, jazz
|
| 179 |
-
|
| 180 |
-
## MAESTRO Dataset
|
| 181 |
-
|
| 182 |
-
### Subset Selection
|
| 183 |
-
|
| 184 |
-
We use 8 curated pieces from MAESTRO:
|
| 185 |
-
- 2 easy (simple classical)
|
| 186 |
-
- 3 medium (Chopin, moderate tempo)
|
| 187 |
-
- 3 hard (fast passages, complex harmony)
|
| 188 |
-
|
| 189 |
-
### Why MAESTRO?
|
| 190 |
-
|
| 191 |
-
Pros:
|
| 192 |
-
- High-quality ground truth MIDI (aligned by humans)
|
| 193 |
-
- Professional piano performances
|
| 194 |
-
- Varied difficulty and styles
|
| 195 |
-
|
| 196 |
-
Cons:
|
| 197 |
-
- Clean studio recordings (not YouTube quality)
|
| 198 |
-
- All classical piano (no pop/jazz)
|
| 199 |
-
- May overestimate accuracy on real YouTube videos
|
| 200 |
-
|
| 201 |
-
### Validation on YouTube
|
| 202 |
-
|
| 203 |
-
After achieving target accuracy on MAESTRO, validate on real YouTube videos:
|
| 204 |
-
1. Transcribe 5-10 YouTube piano videos
|
| 205 |
-
2. Manually verify transcriptions in MuseScore
|
| 206 |
-
3. Measure accuracy using same metrics
|
| 207 |
-
|
| 208 |
-
## Development Workflow
|
| 209 |
-
|
| 210 |
-
1. **Baseline**: Measure current YourMT3+ (Week 1)
|
| 211 |
-
2. **Implement enhancement**: ByteDance, ensemble, etc. (Week 2-4)
|
| 212 |
-
3. **Benchmark**: Re-run on same test set
|
| 213 |
-
4. **Compare**: Did F1 improve by ≥2%?
|
| 214 |
-
5. **Iterate**: Tune parameters if needed
|
| 215 |
-
|
| 216 |
-
## Troubleshooting
|
| 217 |
-
|
| 218 |
-
### "Test cases file not found"
|
| 219 |
-
```bash
|
| 220 |
-
# Create test cases first:
|
| 221 |
-
python -m evaluation.prepare_maestro --maestro-dir /path/to/maestro-v3.0.0
|
| 222 |
-
```
|
| 223 |
-
|
| 224 |
-
### "Transcription failed"
|
| 225 |
-
Check pipeline logs for errors. Common issues:
|
| 226 |
-
- Demucs CUDA out of memory → use CPU
|
| 227 |
-
- YourMT3+ checkpoint not loaded
|
| 228 |
-
- Audio file format not supported
|
| 229 |
-
|
| 230 |
-
### "F1 score is 0.0"
|
| 231 |
-
- Check that MIDI files are valid
|
| 232 |
-
- Verify onset tolerance isn't too strict
|
| 233 |
-
- Ensure ground truth MIDI has notes
|
| 234 |
-
|
| 235 |
-
## Files
|
| 236 |
-
|
| 237 |
-
- `metrics.py` - F1, precision, recall calculation
|
| 238 |
-
- `benchmark.py` - Benchmark runner framework
|
| 239 |
-
- `prepare_maestro.py` - MAESTRO dataset preparation
|
| 240 |
-
- `run_benchmark.py` - Main CLI script
|
| 241 |
-
- `test_videos.json` - Test case metadata (created by prepare_maestro)
|
| 242 |
-
- `results/` - Benchmark results (JSON + CSV)
|
| 243 |
-
|
| 244 |
-
## Next Steps
|
| 245 |
-
|
| 246 |
-
After Phase 1 baseline:
|
| 247 |
-
- [ ] Integrate ByteDance model (Phase 2)
|
| 248 |
-
- [ ] Implement ensemble voting (Phase 3)
|
| 249 |
-
- [ ] Add audio preprocessing (Phase 4)
|
| 250 |
-
- [ ] Run comprehensive benchmarks
|
| 251 |
-
- [ ] Target: F1 ≥ 0.95 (95%+ accuracy)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/evaluation/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
"""Evaluation module for transcription accuracy benchmarking."""
|
|
|
|
|
|
backend/evaluation/benchmark.py
DELETED
|
@@ -1,308 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Benchmark runner for evaluating transcription accuracy on test datasets.
|
| 3 |
-
|
| 4 |
-
Supports MAESTRO dataset and custom test videos with ground truth MIDI.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
import json
|
| 8 |
-
import time
|
| 9 |
-
from dataclasses import dataclass, asdict
|
| 10 |
-
from pathlib import Path
|
| 11 |
-
from typing import List, Dict, Optional
|
| 12 |
-
import pandas as pd
|
| 13 |
-
import sys
|
| 14 |
-
|
| 15 |
-
# Add backend directory to path for imports
|
| 16 |
-
backend_dir = Path(__file__).parent.parent
|
| 17 |
-
if str(backend_dir) not in sys.path:
|
| 18 |
-
sys.path.insert(0, str(backend_dir))
|
| 19 |
-
|
| 20 |
-
from evaluation.metrics import calculate_metrics, TranscriptionMetrics
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
@dataclass
|
| 24 |
-
class TestCase:
|
| 25 |
-
"""Represents a single test case for benchmarking."""
|
| 26 |
-
name: str # Descriptive name (e.g., "Chopin_Nocturne_Op9_No2")
|
| 27 |
-
audio_path: Path # Path to audio file (WAV/MP3)
|
| 28 |
-
ground_truth_midi: Optional[Path] = None # Path to ground truth MIDI file (None for manual review)
|
| 29 |
-
genre: str = "classical" # Genre: classical, pop, jazz, simple
|
| 30 |
-
difficulty: str = "medium" # Difficulty: easy, medium, hard
|
| 31 |
-
duration: Optional[float] = None # Duration in seconds
|
| 32 |
-
|
| 33 |
-
def to_dict(self) -> dict:
|
| 34 |
-
"""Convert to dictionary for JSON serialization."""
|
| 35 |
-
return {
|
| 36 |
-
'name': self.name,
|
| 37 |
-
'audio_path': str(self.audio_path),
|
| 38 |
-
'ground_truth_midi': str(self.ground_truth_midi) if self.ground_truth_midi else None,
|
| 39 |
-
'genre': self.genre,
|
| 40 |
-
'difficulty': self.difficulty,
|
| 41 |
-
'duration': self.duration
|
| 42 |
-
}
|
| 43 |
-
|
| 44 |
-
@classmethod
|
| 45 |
-
def from_dict(cls, data: dict) -> 'TestCase':
|
| 46 |
-
"""Create TestCase from dictionary."""
|
| 47 |
-
ground_truth = data.get('ground_truth_midi')
|
| 48 |
-
return cls(
|
| 49 |
-
name=data['name'],
|
| 50 |
-
audio_path=Path(data['audio_path']),
|
| 51 |
-
ground_truth_midi=Path(ground_truth) if ground_truth else None,
|
| 52 |
-
genre=data.get('genre', 'classical'),
|
| 53 |
-
difficulty=data.get('difficulty', 'medium'),
|
| 54 |
-
duration=data.get('duration')
|
| 55 |
-
)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
@dataclass
|
| 59 |
-
class BenchmarkResult:
|
| 60 |
-
"""Results for a single test case."""
|
| 61 |
-
test_case_name: str
|
| 62 |
-
genre: str
|
| 63 |
-
difficulty: str
|
| 64 |
-
metrics: TranscriptionMetrics
|
| 65 |
-
processing_time: float # Time taken to transcribe (seconds)
|
| 66 |
-
success: bool = True
|
| 67 |
-
error_message: Optional[str] = None
|
| 68 |
-
|
| 69 |
-
def to_dict(self) -> dict:
|
| 70 |
-
"""Convert to dictionary for JSON serialization."""
|
| 71 |
-
return {
|
| 72 |
-
'test_case': self.test_case_name,
|
| 73 |
-
'genre': self.genre,
|
| 74 |
-
'difficulty': self.difficulty,
|
| 75 |
-
'f1_score': self.metrics.f1_score,
|
| 76 |
-
'precision': self.metrics.precision,
|
| 77 |
-
'recall': self.metrics.recall,
|
| 78 |
-
'onset_mae': self.metrics.onset_mae,
|
| 79 |
-
'pitch_accuracy': self.metrics.pitch_accuracy,
|
| 80 |
-
'true_positives': self.metrics.true_positives,
|
| 81 |
-
'false_positives': self.metrics.false_positives,
|
| 82 |
-
'false_negatives': self.metrics.false_negatives,
|
| 83 |
-
'processing_time': self.processing_time,
|
| 84 |
-
'success': self.success,
|
| 85 |
-
'error': self.error_message
|
| 86 |
-
}
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
class TranscriptionBenchmark:
|
| 90 |
-
"""
|
| 91 |
-
Benchmark runner for transcription models.
|
| 92 |
-
|
| 93 |
-
Evaluates transcription accuracy on a test set and generates reports.
|
| 94 |
-
"""
|
| 95 |
-
|
| 96 |
-
def __init__(
|
| 97 |
-
self,
|
| 98 |
-
test_cases: List[TestCase],
|
| 99 |
-
output_dir: Path,
|
| 100 |
-
onset_tolerance: float = 0.05
|
| 101 |
-
):
|
| 102 |
-
"""
|
| 103 |
-
Initialize benchmark runner.
|
| 104 |
-
|
| 105 |
-
Args:
|
| 106 |
-
test_cases: List of test cases to evaluate
|
| 107 |
-
output_dir: Directory to save results
|
| 108 |
-
onset_tolerance: Onset matching tolerance (seconds)
|
| 109 |
-
"""
|
| 110 |
-
self.test_cases = test_cases
|
| 111 |
-
self.output_dir = Path(output_dir)
|
| 112 |
-
self.onset_tolerance = onset_tolerance
|
| 113 |
-
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 114 |
-
|
| 115 |
-
def run_single_test(
|
| 116 |
-
self,
|
| 117 |
-
test_case: TestCase,
|
| 118 |
-
transcribe_fn,
|
| 119 |
-
output_midi_dir: Path
|
| 120 |
-
) -> BenchmarkResult:
|
| 121 |
-
"""
|
| 122 |
-
Run a single test case.
|
| 123 |
-
|
| 124 |
-
Args:
|
| 125 |
-
test_case: Test case to evaluate
|
| 126 |
-
transcribe_fn: Function that takes audio_path and returns MIDI path
|
| 127 |
-
output_midi_dir: Directory to save transcribed MIDI files
|
| 128 |
-
|
| 129 |
-
Returns:
|
| 130 |
-
BenchmarkResult with metrics and timing
|
| 131 |
-
"""
|
| 132 |
-
print(f"\n{'='*60}")
|
| 133 |
-
print(f"Test: {test_case.name}")
|
| 134 |
-
print(f"Genre: {test_case.genre} | Difficulty: {test_case.difficulty}")
|
| 135 |
-
print(f"{'='*60}")
|
| 136 |
-
|
| 137 |
-
try:
|
| 138 |
-
# Transcribe audio
|
| 139 |
-
start_time = time.time()
|
| 140 |
-
predicted_midi = transcribe_fn(test_case.audio_path, output_midi_dir)
|
| 141 |
-
processing_time = time.time() - start_time
|
| 142 |
-
|
| 143 |
-
if not predicted_midi.exists():
|
| 144 |
-
raise FileNotFoundError(f"Transcription failed: {predicted_midi} not found")
|
| 145 |
-
|
| 146 |
-
print(f"✅ Transcription completed in {processing_time:.1f}s")
|
| 147 |
-
|
| 148 |
-
# Calculate metrics only if ground truth is available
|
| 149 |
-
if test_case.ground_truth_midi:
|
| 150 |
-
metrics = calculate_metrics(
|
| 151 |
-
predicted_midi,
|
| 152 |
-
test_case.ground_truth_midi,
|
| 153 |
-
onset_tolerance=self.onset_tolerance
|
| 154 |
-
)
|
| 155 |
-
|
| 156 |
-
print(f"\n📊 Results:")
|
| 157 |
-
print(f" F1 Score: {metrics.f1_score:.3f}")
|
| 158 |
-
print(f" Precision: {metrics.precision:.3f}")
|
| 159 |
-
print(f" Recall: {metrics.recall:.3f}")
|
| 160 |
-
print(f" Onset MAE: {metrics.onset_mae*1000:.1f}ms")
|
| 161 |
-
else:
|
| 162 |
-
# No ground truth - create placeholder metrics for manual review
|
| 163 |
-
print(f"\n📝 No ground truth available - MIDI saved for manual review")
|
| 164 |
-
print(f" Output: {predicted_midi}")
|
| 165 |
-
metrics = TranscriptionMetrics(
|
| 166 |
-
precision=0.0, recall=0.0, f1_score=0.0,
|
| 167 |
-
onset_mae=0.0, pitch_accuracy=0.0,
|
| 168 |
-
true_positives=0, false_positives=0, false_negatives=0
|
| 169 |
-
)
|
| 170 |
-
|
| 171 |
-
return BenchmarkResult(
|
| 172 |
-
test_case_name=test_case.name,
|
| 173 |
-
genre=test_case.genre,
|
| 174 |
-
difficulty=test_case.difficulty,
|
| 175 |
-
metrics=metrics,
|
| 176 |
-
processing_time=processing_time,
|
| 177 |
-
success=True
|
| 178 |
-
)
|
| 179 |
-
|
| 180 |
-
except Exception as e:
|
| 181 |
-
import traceback
|
| 182 |
-
error_traceback = traceback.format_exc()
|
| 183 |
-
print(f"❌ Test failed: {e}")
|
| 184 |
-
print(f"\nFull traceback:")
|
| 185 |
-
print(error_traceback)
|
| 186 |
-
|
| 187 |
-
# Return placeholder metrics for failed test
|
| 188 |
-
return BenchmarkResult(
|
| 189 |
-
test_case_name=test_case.name,
|
| 190 |
-
genre=test_case.genre,
|
| 191 |
-
difficulty=test_case.difficulty,
|
| 192 |
-
metrics=TranscriptionMetrics(
|
| 193 |
-
precision=0.0, recall=0.0, f1_score=0.0,
|
| 194 |
-
onset_mae=float('inf'), pitch_accuracy=0.0,
|
| 195 |
-
true_positives=0, false_positives=0, false_negatives=0
|
| 196 |
-
),
|
| 197 |
-
processing_time=0.0,
|
| 198 |
-
success=False,
|
| 199 |
-
error_message=str(e)
|
| 200 |
-
)
|
| 201 |
-
|
| 202 |
-
def run_benchmark(self, transcribe_fn, model_name: str = "model") -> List[BenchmarkResult]:
|
| 203 |
-
"""
|
| 204 |
-
Run full benchmark on all test cases.
|
| 205 |
-
|
| 206 |
-
Args:
|
| 207 |
-
transcribe_fn: Function that transcribes audio to MIDI
|
| 208 |
-
model_name: Name of model being tested (for output files)
|
| 209 |
-
|
| 210 |
-
Returns:
|
| 211 |
-
List of BenchmarkResult objects
|
| 212 |
-
"""
|
| 213 |
-
print(f"\n🎹 Starting Benchmark: {model_name}")
|
| 214 |
-
print(f"📝 Test cases: {len(self.test_cases)}")
|
| 215 |
-
print(f"⏱️ Onset tolerance: {self.onset_tolerance*1000:.0f}ms")
|
| 216 |
-
|
| 217 |
-
# Create output directory for transcribed MIDI
|
| 218 |
-
output_midi_dir = self.output_dir / f"{model_name}_midi"
|
| 219 |
-
output_midi_dir.mkdir(parents=True, exist_ok=True)
|
| 220 |
-
|
| 221 |
-
results = []
|
| 222 |
-
for i, test_case in enumerate(self.test_cases, 1):
|
| 223 |
-
print(f"\n[{i}/{len(self.test_cases)}]", end=" ")
|
| 224 |
-
result = self.run_single_test(test_case, transcribe_fn, output_midi_dir)
|
| 225 |
-
results.append(result)
|
| 226 |
-
|
| 227 |
-
# Save results
|
| 228 |
-
self._save_results(results, model_name)
|
| 229 |
-
self._print_summary(results, model_name)
|
| 230 |
-
|
| 231 |
-
return results
|
| 232 |
-
|
| 233 |
-
def _save_results(self, results: List[BenchmarkResult], model_name: str):
|
| 234 |
-
"""Save benchmark results to JSON and CSV."""
|
| 235 |
-
# JSON format (detailed)
|
| 236 |
-
json_path = self.output_dir / f"{model_name}_results.json"
|
| 237 |
-
with open(json_path, 'w') as f:
|
| 238 |
-
json.dump([r.to_dict() for r in results], f, indent=2)
|
| 239 |
-
print(f"\n💾 Saved detailed results to: {json_path}")
|
| 240 |
-
|
| 241 |
-
# CSV format (for spreadsheet analysis)
|
| 242 |
-
csv_path = self.output_dir / f"{model_name}_results.csv"
|
| 243 |
-
df = pd.DataFrame([r.to_dict() for r in results])
|
| 244 |
-
df.to_csv(csv_path, index=False)
|
| 245 |
-
print(f"💾 Saved CSV results to: {csv_path}")
|
| 246 |
-
|
| 247 |
-
def _print_summary(self, results: List[BenchmarkResult], model_name: str):
|
| 248 |
-
"""Print summary statistics."""
|
| 249 |
-
successful = [r for r in results if r.success]
|
| 250 |
-
failed = [r for r in results if not r.success]
|
| 251 |
-
|
| 252 |
-
print(f"\n{'='*60}")
|
| 253 |
-
print(f"📊 BENCHMARK SUMMARY: {model_name}")
|
| 254 |
-
print(f"{'='*60}")
|
| 255 |
-
print(f"Total tests: {len(results)}")
|
| 256 |
-
print(f"Successful: {len(successful)}")
|
| 257 |
-
print(f"Failed: {len(failed)}")
|
| 258 |
-
|
| 259 |
-
if len(successful) == 0:
|
| 260 |
-
print("\n❌ All tests failed!")
|
| 261 |
-
return
|
| 262 |
-
|
| 263 |
-
# Overall metrics
|
| 264 |
-
avg_f1 = sum(r.metrics.f1_score for r in successful) / len(successful)
|
| 265 |
-
avg_precision = sum(r.metrics.precision for r in successful) / len(successful)
|
| 266 |
-
avg_recall = sum(r.metrics.recall for r in successful) / len(successful)
|
| 267 |
-
avg_onset_mae = sum(r.metrics.onset_mae for r in successful) / len(successful)
|
| 268 |
-
avg_time = sum(r.processing_time for r in successful) / len(successful)
|
| 269 |
-
|
| 270 |
-
print(f"\n📈 Overall Accuracy:")
|
| 271 |
-
print(f" F1 Score: {avg_f1:.3f}")
|
| 272 |
-
print(f" Precision: {avg_precision:.3f}")
|
| 273 |
-
print(f" Recall: {avg_recall:.3f}")
|
| 274 |
-
print(f" Onset MAE: {avg_onset_mae*1000:.1f}ms")
|
| 275 |
-
print(f" Avg Processing Time: {avg_time:.1f}s")
|
| 276 |
-
|
| 277 |
-
# By genre
|
| 278 |
-
genres = set(r.genre for r in successful)
|
| 279 |
-
print(f"\n📊 By Genre:")
|
| 280 |
-
for genre in sorted(genres):
|
| 281 |
-
genre_results = [r for r in successful if r.genre == genre]
|
| 282 |
-
genre_f1 = sum(r.metrics.f1_score for r in genre_results) / len(genre_results)
|
| 283 |
-
print(f" {genre.capitalize()}: F1={genre_f1:.3f} ({len(genre_results)} tests)")
|
| 284 |
-
|
| 285 |
-
# By difficulty
|
| 286 |
-
difficulties = set(r.difficulty for r in successful)
|
| 287 |
-
print(f"\n📊 By Difficulty:")
|
| 288 |
-
for diff in ['easy', 'medium', 'hard']:
|
| 289 |
-
if diff in difficulties:
|
| 290 |
-
diff_results = [r for r in successful if r.difficulty == diff]
|
| 291 |
-
diff_f1 = sum(r.metrics.f1_score for r in diff_results) / len(diff_results)
|
| 292 |
-
print(f" {diff.capitalize()}: F1={diff_f1:.3f} ({len(diff_results)} tests)")
|
| 293 |
-
|
| 294 |
-
print(f"\n{'='*60}\n")
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
def load_test_cases_from_json(json_path: Path) -> List[TestCase]:
|
| 298 |
-
"""Load test cases from JSON file."""
|
| 299 |
-
with open(json_path, 'r') as f:
|
| 300 |
-
data = json.load(f)
|
| 301 |
-
return [TestCase.from_dict(case) for case in data]
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
def save_test_cases_to_json(test_cases: List[TestCase], json_path: Path):
|
| 305 |
-
"""Save test cases to JSON file."""
|
| 306 |
-
with open(json_path, 'w') as f:
|
| 307 |
-
json.dump([tc.to_dict() for tc in test_cases], f, indent=2)
|
| 308 |
-
print(f"💾 Saved {len(test_cases)} test cases to: {json_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/evaluation/generate_full_test_set.py
DELETED
|
@@ -1,66 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Generate full 8-piece MAESTRO test set.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
import json
|
| 7 |
-
|
| 8 |
-
# Test cases from prepare_maestro.py
|
| 9 |
-
MAESTRO_SUBSET = [
|
| 10 |
-
# Easy - Simple classical pieces
|
| 11 |
-
("2004", "test", "MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_03_Track03_wav", "classical", "easy"),
|
| 12 |
-
("2004", "test", "MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav", "classical", "easy"),
|
| 13 |
-
|
| 14 |
-
# Medium - Chopin, moderate tempo
|
| 15 |
-
("2004", "test", "MIDI-Unprocessed_XP_14_R1_2004_01-04_ORIG_MID--AUDIO_14_R1_2004_04_Track04_wav", "classical", "medium"),
|
| 16 |
-
("2006", "test", "MIDI-Unprocessed_07_R1_2006_01-09_ORIG_MID--AUDIO_07_R1_2006_04_Track04_wav", "classical", "medium"),
|
| 17 |
-
("2008", "test", "MIDI-Unprocessed_11_R1_2008_01-04_ORIG_MID--AUDIO_11_R1_2008_02_Track02_wav", "classical", "medium"),
|
| 18 |
-
|
| 19 |
-
# Hard - Fast passages, complex harmony
|
| 20 |
-
("2009", "test", "MIDI-Unprocessed_16_R1_2009_01-04_ORIG_MID--AUDIO_16_R1_2009_16_R1_2009_02_WAV", "classical", "hard"),
|
| 21 |
-
("2011", "test", "MIDI-Unprocessed_03_R1_2011_MID--AUDIO_03_R1_2011_03_R1_2011_02_WAV", "classical", "hard"),
|
| 22 |
-
("2013", "test", "MIDI-Unprocessed_20_R1_2013_MID--AUDIO_20_R1_2013_20_R1_2013_02_WAV", "classical", "hard"),
|
| 23 |
-
]
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def generate_test_cases(maestro_dir: str = "../../data/maestro-v3.0.0"):
|
| 27 |
-
"""Generate test_videos.json with all 8 MAESTRO pieces."""
|
| 28 |
-
|
| 29 |
-
test_cases = []
|
| 30 |
-
|
| 31 |
-
for year, split, filename_prefix, genre, difficulty in MAESTRO_SUBSET:
|
| 32 |
-
# Construct paths
|
| 33 |
-
audio_path = f"{maestro_dir}/{year}/{filename_prefix}.wav"
|
| 34 |
-
midi_path = f"{maestro_dir}/{year}/{filename_prefix}.midi"
|
| 35 |
-
|
| 36 |
-
# Create test case
|
| 37 |
-
test_case = {
|
| 38 |
-
"name": filename_prefix,
|
| 39 |
-
"audio_path": audio_path,
|
| 40 |
-
"ground_truth_midi": midi_path,
|
| 41 |
-
"genre": genre,
|
| 42 |
-
"difficulty": difficulty,
|
| 43 |
-
"duration": None
|
| 44 |
-
}
|
| 45 |
-
|
| 46 |
-
test_cases.append(test_case)
|
| 47 |
-
|
| 48 |
-
return test_cases
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
if __name__ == "__main__":
|
| 52 |
-
# Generate test cases
|
| 53 |
-
test_cases = generate_test_cases()
|
| 54 |
-
|
| 55 |
-
# Save to JSON
|
| 56 |
-
output_file = Path(__file__).parent / "test_videos.json"
|
| 57 |
-
|
| 58 |
-
with open(output_file, 'w') as f:
|
| 59 |
-
json.dump(test_cases, f, indent=2)
|
| 60 |
-
|
| 61 |
-
print(f"✅ Generated {len(test_cases)} test cases")
|
| 62 |
-
print(f"📝 Saved to: {output_file}")
|
| 63 |
-
print("\nBreakdown:")
|
| 64 |
-
print(f" - Easy: {sum(1 for tc in test_cases if tc['difficulty'] == 'easy')}")
|
| 65 |
-
print(f" - Medium: {sum(1 for tc in test_cases if tc['difficulty'] == 'medium')}")
|
| 66 |
-
print(f" - Hard: {sum(1 for tc in test_cases if tc['difficulty'] == 'hard')}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/evaluation/metrics.py
DELETED
|
@@ -1,253 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Transcription accuracy metrics for piano transcription evaluation.
|
| 3 |
-
|
| 4 |
-
Implements F1 score, precision, recall, and timing accuracy for comparing
|
| 5 |
-
predicted MIDI against ground truth MIDI files.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from dataclasses import dataclass
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
from typing import List, Tuple, Optional
|
| 11 |
-
import numpy as np
|
| 12 |
-
from mido import MidiFile
|
| 13 |
-
import pretty_midi
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
@dataclass
|
| 17 |
-
class Note:
|
| 18 |
-
"""Represents a musical note with timing and pitch information."""
|
| 19 |
-
pitch: int # MIDI pitch (0-127)
|
| 20 |
-
onset: float # Start time in seconds
|
| 21 |
-
offset: float # End time in seconds
|
| 22 |
-
velocity: int = 64 # Note velocity (0-127)
|
| 23 |
-
|
| 24 |
-
@property
|
| 25 |
-
def duration(self) -> float:
|
| 26 |
-
"""Note duration in seconds."""
|
| 27 |
-
return self.offset - self.onset
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
@dataclass
|
| 31 |
-
class TranscriptionMetrics:
|
| 32 |
-
"""Container for transcription evaluation metrics."""
|
| 33 |
-
precision: float # True positives / (True positives + False positives)
|
| 34 |
-
recall: float # True positives / (True positives + False negatives)
|
| 35 |
-
f1_score: float # Harmonic mean of precision and recall
|
| 36 |
-
onset_mae: float # Mean absolute error for note onsets (seconds)
|
| 37 |
-
pitch_accuracy: float # Percentage of correct pitches (given correct onset)
|
| 38 |
-
true_positives: int
|
| 39 |
-
false_positives: int
|
| 40 |
-
false_negatives: int
|
| 41 |
-
|
| 42 |
-
def __str__(self) -> str:
|
| 43 |
-
"""Human-readable metrics summary."""
|
| 44 |
-
return (
|
| 45 |
-
f"F1 Score: {self.f1_score:.3f}\n"
|
| 46 |
-
f"Precision: {self.precision:.3f}\n"
|
| 47 |
-
f"Recall: {self.recall:.3f}\n"
|
| 48 |
-
f"Onset MAE: {self.onset_mae*1000:.1f}ms\n"
|
| 49 |
-
f"Pitch Accuracy: {self.pitch_accuracy:.3f}\n"
|
| 50 |
-
f"TP: {self.true_positives}, FP: {self.false_positives}, FN: {self.false_negatives}"
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def extract_notes_from_midi(midi_path: Path) -> List[Note]:
|
| 55 |
-
"""
|
| 56 |
-
Extract notes from a MIDI file using pretty_midi.
|
| 57 |
-
|
| 58 |
-
Args:
|
| 59 |
-
midi_path: Path to MIDI file
|
| 60 |
-
|
| 61 |
-
Returns:
|
| 62 |
-
List of Note objects sorted by onset time
|
| 63 |
-
"""
|
| 64 |
-
try:
|
| 65 |
-
pm = pretty_midi.PrettyMIDI(str(midi_path))
|
| 66 |
-
except Exception as e:
|
| 67 |
-
raise ValueError(f"Failed to load MIDI file {midi_path}: {e}")
|
| 68 |
-
|
| 69 |
-
notes = []
|
| 70 |
-
for instrument in pm.instruments:
|
| 71 |
-
# Skip drum tracks
|
| 72 |
-
if instrument.is_drum:
|
| 73 |
-
continue
|
| 74 |
-
|
| 75 |
-
for note in instrument.notes:
|
| 76 |
-
notes.append(Note(
|
| 77 |
-
pitch=note.pitch,
|
| 78 |
-
onset=note.start,
|
| 79 |
-
offset=note.end,
|
| 80 |
-
velocity=note.velocity
|
| 81 |
-
))
|
| 82 |
-
|
| 83 |
-
# Sort by onset time
|
| 84 |
-
notes.sort(key=lambda n: n.onset)
|
| 85 |
-
return notes
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def match_notes(
|
| 89 |
-
predicted_notes: List[Note],
|
| 90 |
-
ground_truth_notes: List[Note],
|
| 91 |
-
onset_tolerance: float = 0.05, # 50ms
|
| 92 |
-
pitch_tolerance: int = 0 # Exact pitch match
|
| 93 |
-
) -> Tuple[List[Tuple[Note, Note]], List[Note], List[Note]]:
|
| 94 |
-
"""
|
| 95 |
-
Match predicted notes to ground truth notes using onset and pitch tolerance.
|
| 96 |
-
|
| 97 |
-
Uses greedy matching: for each ground truth note, find the closest predicted
|
| 98 |
-
note within tolerance. A predicted note can only match one ground truth note.
|
| 99 |
-
|
| 100 |
-
Args:
|
| 101 |
-
predicted_notes: List of predicted notes
|
| 102 |
-
ground_truth_notes: List of ground truth notes
|
| 103 |
-
onset_tolerance: Maximum time difference (seconds) to consider a match
|
| 104 |
-
pitch_tolerance: Maximum pitch difference (semitones) to consider a match
|
| 105 |
-
|
| 106 |
-
Returns:
|
| 107 |
-
Tuple of (matches, false_positives, false_negatives) where:
|
| 108 |
-
- matches: List of (predicted_note, ground_truth_note) pairs
|
| 109 |
-
- false_positives: Predicted notes with no match
|
| 110 |
-
- false_negatives: Ground truth notes with no match
|
| 111 |
-
"""
|
| 112 |
-
matches = []
|
| 113 |
-
matched_pred_indices = set()
|
| 114 |
-
unmatched_gt = []
|
| 115 |
-
|
| 116 |
-
# For each ground truth note, find best matching predicted note
|
| 117 |
-
for gt_note in ground_truth_notes:
|
| 118 |
-
best_match_idx = None
|
| 119 |
-
best_onset_diff = float('inf')
|
| 120 |
-
|
| 121 |
-
for i, pred_note in enumerate(predicted_notes):
|
| 122 |
-
if i in matched_pred_indices:
|
| 123 |
-
continue # Already matched
|
| 124 |
-
|
| 125 |
-
# Check pitch tolerance
|
| 126 |
-
pitch_diff = abs(pred_note.pitch - gt_note.pitch)
|
| 127 |
-
if pitch_diff > pitch_tolerance:
|
| 128 |
-
continue
|
| 129 |
-
|
| 130 |
-
# Check onset tolerance
|
| 131 |
-
onset_diff = abs(pred_note.onset - gt_note.onset)
|
| 132 |
-
if onset_diff <= onset_tolerance and onset_diff < best_onset_diff:
|
| 133 |
-
best_match_idx = i
|
| 134 |
-
best_onset_diff = onset_diff
|
| 135 |
-
|
| 136 |
-
if best_match_idx is not None:
|
| 137 |
-
matches.append((predicted_notes[best_match_idx], gt_note))
|
| 138 |
-
matched_pred_indices.add(best_match_idx)
|
| 139 |
-
else:
|
| 140 |
-
unmatched_gt.append(gt_note)
|
| 141 |
-
|
| 142 |
-
# Unmatched predicted notes are false positives
|
| 143 |
-
false_positives = [
|
| 144 |
-
note for i, note in enumerate(predicted_notes)
|
| 145 |
-
if i not in matched_pred_indices
|
| 146 |
-
]
|
| 147 |
-
|
| 148 |
-
return matches, false_positives, unmatched_gt
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
def calculate_metrics(
|
| 152 |
-
predicted_midi: Path,
|
| 153 |
-
ground_truth_midi: Path,
|
| 154 |
-
onset_tolerance: float = 0.05, # 50ms
|
| 155 |
-
pitch_tolerance: int = 0 # Exact pitch
|
| 156 |
-
) -> TranscriptionMetrics:
|
| 157 |
-
"""
|
| 158 |
-
Calculate transcription accuracy metrics by comparing predicted vs ground truth MIDI.
|
| 159 |
-
|
| 160 |
-
Args:
|
| 161 |
-
predicted_midi: Path to predicted MIDI file
|
| 162 |
-
ground_truth_midi: Path to ground truth MIDI file
|
| 163 |
-
onset_tolerance: Maximum onset time difference for matching (seconds)
|
| 164 |
-
pitch_tolerance: Maximum pitch difference for matching (semitones)
|
| 165 |
-
|
| 166 |
-
Returns:
|
| 167 |
-
TranscriptionMetrics object with all evaluation metrics
|
| 168 |
-
"""
|
| 169 |
-
# Extract notes from both files
|
| 170 |
-
pred_notes = extract_notes_from_midi(predicted_midi)
|
| 171 |
-
gt_notes = extract_notes_from_midi(ground_truth_midi)
|
| 172 |
-
|
| 173 |
-
if len(gt_notes) == 0:
|
| 174 |
-
raise ValueError(f"Ground truth MIDI has no notes: {ground_truth_midi}")
|
| 175 |
-
|
| 176 |
-
# Match notes
|
| 177 |
-
matches, false_positives, false_negatives = match_notes(
|
| 178 |
-
pred_notes, gt_notes, onset_tolerance, pitch_tolerance
|
| 179 |
-
)
|
| 180 |
-
|
| 181 |
-
# Calculate counts
|
| 182 |
-
true_positives = len(matches)
|
| 183 |
-
num_false_positives = len(false_positives)
|
| 184 |
-
num_false_negatives = len(false_negatives)
|
| 185 |
-
|
| 186 |
-
# Calculate precision and recall
|
| 187 |
-
precision = true_positives / (true_positives + num_false_positives) if (true_positives + num_false_positives) > 0 else 0.0
|
| 188 |
-
recall = true_positives / (true_positives + num_false_negatives) if (true_positives + num_false_negatives) > 0 else 0.0
|
| 189 |
-
|
| 190 |
-
# Calculate F1 score
|
| 191 |
-
f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
|
| 192 |
-
|
| 193 |
-
# Calculate onset MAE (only for matched notes)
|
| 194 |
-
if len(matches) > 0:
|
| 195 |
-
onset_errors = [abs(pred.onset - gt.onset) for pred, gt in matches]
|
| 196 |
-
onset_mae = np.mean(onset_errors)
|
| 197 |
-
else:
|
| 198 |
-
onset_mae = float('inf')
|
| 199 |
-
|
| 200 |
-
# Calculate pitch accuracy (for matched notes)
|
| 201 |
-
if len(matches) > 0:
|
| 202 |
-
pitch_correct = sum(1 for pred, gt in matches if pred.pitch == gt.pitch)
|
| 203 |
-
pitch_accuracy = pitch_correct / len(matches)
|
| 204 |
-
else:
|
| 205 |
-
pitch_accuracy = 0.0
|
| 206 |
-
|
| 207 |
-
return TranscriptionMetrics(
|
| 208 |
-
precision=precision,
|
| 209 |
-
recall=recall,
|
| 210 |
-
f1_score=f1_score,
|
| 211 |
-
onset_mae=onset_mae,
|
| 212 |
-
pitch_accuracy=pitch_accuracy,
|
| 213 |
-
true_positives=true_positives,
|
| 214 |
-
false_positives=num_false_positives,
|
| 215 |
-
false_negatives=num_false_negatives
|
| 216 |
-
)
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
def calculate_metrics_by_difficulty(
|
| 220 |
-
predicted_midi: Path,
|
| 221 |
-
ground_truth_midi: Path,
|
| 222 |
-
onset_tolerance: float = 0.05
|
| 223 |
-
) -> dict:
|
| 224 |
-
"""
|
| 225 |
-
Calculate metrics at multiple onset tolerances to assess difficulty.
|
| 226 |
-
|
| 227 |
-
Stricter tolerances (20ms) test timing accuracy for simple music.
|
| 228 |
-
Looser tolerances (100ms) are more forgiving for complex/fast passages.
|
| 229 |
-
|
| 230 |
-
Args:
|
| 231 |
-
predicted_midi: Path to predicted MIDI file
|
| 232 |
-
ground_truth_midi: Path to ground truth MIDI file
|
| 233 |
-
onset_tolerance: Default onset tolerance (seconds)
|
| 234 |
-
|
| 235 |
-
Returns:
|
| 236 |
-
Dictionary with metrics at different tolerance levels
|
| 237 |
-
"""
|
| 238 |
-
tolerances = {
|
| 239 |
-
'strict': 0.02, # 20ms - for simple piano melodies
|
| 240 |
-
'default': 0.05, # 50ms - standard evaluation
|
| 241 |
-
'lenient': 0.10 # 100ms - for fast/complex passages
|
| 242 |
-
}
|
| 243 |
-
|
| 244 |
-
results = {}
|
| 245 |
-
for name, tol in tolerances.items():
|
| 246 |
-
try:
|
| 247 |
-
metrics = calculate_metrics(predicted_midi, ground_truth_midi, onset_tolerance=tol)
|
| 248 |
-
results[name] = metrics
|
| 249 |
-
except Exception as e:
|
| 250 |
-
print(f"Warning: Failed to calculate {name} metrics: {e}")
|
| 251 |
-
results[name] = None
|
| 252 |
-
|
| 253 |
-
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/evaluation/prepare_maestro.py
DELETED
|
@@ -1,259 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Download and prepare MAESTRO dataset subset for benchmarking.
|
| 3 |
-
|
| 4 |
-
MAESTRO (MIDI and Audio Edited for Synchronous TRacks and Organization) is a
|
| 5 |
-
dataset of ~200 hours of piano performances with aligned MIDI.
|
| 6 |
-
|
| 7 |
-
We'll download a curated subset for testing:
|
| 8 |
-
- Simple pieces (easy difficulty)
|
| 9 |
-
- Classical pieces (Chopin, Bach - medium/hard)
|
| 10 |
-
- Varied tempo and complexity
|
| 11 |
-
|
| 12 |
-
Dataset info: https://magenta.tensorflow.org/datasets/maestro
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
import json
|
| 16 |
-
import subprocess
|
| 17 |
-
from pathlib import Path
|
| 18 |
-
from typing import List
|
| 19 |
-
import urllib.request
|
| 20 |
-
import zipfile
|
| 21 |
-
import shutil
|
| 22 |
-
|
| 23 |
-
from evaluation.benchmark import TestCase, save_test_cases_to_json
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# Curated subset of MAESTRO for testing
|
| 27 |
-
# Format: (year, split, filename_prefix, genre, difficulty)
|
| 28 |
-
MAESTRO_SUBSET = [
|
| 29 |
-
# Easy - Simple classical pieces
|
| 30 |
-
("2004", "test", "MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_03_Track03_wav", "classical", "easy"),
|
| 31 |
-
("2004", "test", "MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav", "classical", "easy"),
|
| 32 |
-
|
| 33 |
-
# Medium - Chopin, moderate tempo
|
| 34 |
-
("2004", "test", "MIDI-Unprocessed_XP_14_R1_2004_01-04_ORIG_MID--AUDIO_14_R1_2004_04_Track04_wav", "classical", "medium"),
|
| 35 |
-
("2006", "test", "MIDI-Unprocessed_07_R1_2006_01-09_ORIG_MID--AUDIO_07_R1_2006_04_Track04_wav", "classical", "medium"),
|
| 36 |
-
("2008", "test", "MIDI-Unprocessed_11_R1_2008_01-04_ORIG_MID--AUDIO_11_R1_2008_02_Track02_wav", "classical", "medium"),
|
| 37 |
-
|
| 38 |
-
# Hard - Fast passages, complex harmony
|
| 39 |
-
("2009", "test", "MIDI-Unprocessed_16_R1_2009_01-04_ORIG_MID--AUDIO_16_R1_2009_16_R1_2009_02_WAV", "classical", "hard"),
|
| 40 |
-
("2011", "test", "MIDI-Unprocessed_03_R1_2011_MID--AUDIO_03_R1_2011_03_R1_2011_02_WAV", "classical", "hard"),
|
| 41 |
-
("2013", "test", "MIDI-Unprocessed_20_R1_2013_MID--AUDIO_20_R1_2013_20_R1_2013_02_WAV", "classical", "hard"),
|
| 42 |
-
]
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def download_maestro_subset(
|
| 46 |
-
output_dir: Path,
|
| 47 |
-
version: str = "v3.0.0"
|
| 48 |
-
) -> Path:
|
| 49 |
-
"""
|
| 50 |
-
Download MAESTRO dataset (full version - 100+ GB).
|
| 51 |
-
|
| 52 |
-
Note: This is a large download. For testing, we'll only use a small subset.
|
| 53 |
-
|
| 54 |
-
Args:
|
| 55 |
-
output_dir: Directory to save dataset
|
| 56 |
-
version: MAESTRO version to download
|
| 57 |
-
|
| 58 |
-
Returns:
|
| 59 |
-
Path to extracted dataset directory
|
| 60 |
-
"""
|
| 61 |
-
output_dir = Path(output_dir)
|
| 62 |
-
output_dir.mkdir(parents=True, exist_ok=True)
|
| 63 |
-
|
| 64 |
-
maestro_dir = output_dir / f"maestro-{version}"
|
| 65 |
-
|
| 66 |
-
if maestro_dir.exists():
|
| 67 |
-
print(f"✅ MAESTRO dataset already exists at: {maestro_dir}")
|
| 68 |
-
return maestro_dir
|
| 69 |
-
|
| 70 |
-
# Download URL
|
| 71 |
-
url = f"https://storage.googleapis.com/magentadata/datasets/maestro/{version}/maestro-{version}.zip"
|
| 72 |
-
zip_path = output_dir / f"maestro-{version}.zip"
|
| 73 |
-
|
| 74 |
-
print(f"⬇️ Downloading MAESTRO {version} (this may take a while - ~100GB)...")
|
| 75 |
-
print(f" From: {url}")
|
| 76 |
-
print(f" To: {zip_path}")
|
| 77 |
-
print("\n ⚠️ WARNING: This is a LARGE download (100+ GB)!")
|
| 78 |
-
print(" Consider downloading manually and extracting to:", output_dir)
|
| 79 |
-
|
| 80 |
-
# For now, we'll skip auto-download and assume user has dataset
|
| 81 |
-
# or provide instructions for manual download
|
| 82 |
-
raise NotImplementedError(
|
| 83 |
-
f"Please download MAESTRO manually from:\n"
|
| 84 |
-
f" {url}\n"
|
| 85 |
-
f"Extract to: {output_dir}\n"
|
| 86 |
-
f"Or use the maestro-downloader package: pip install maestro-downloader"
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def find_maestro_files(
|
| 91 |
-
maestro_dir: Path,
|
| 92 |
-
test_case_prefix: str,
|
| 93 |
-
year: str
|
| 94 |
-
) -> tuple:
|
| 95 |
-
"""
|
| 96 |
-
Find audio and MIDI files for a MAESTRO test case.
|
| 97 |
-
|
| 98 |
-
Args:
|
| 99 |
-
maestro_dir: Path to MAESTRO dataset root
|
| 100 |
-
test_case_prefix: Filename prefix (without extension)
|
| 101 |
-
year: Year subdirectory
|
| 102 |
-
|
| 103 |
-
Returns:
|
| 104 |
-
Tuple of (audio_path, midi_path) or (None, None) if not found
|
| 105 |
-
"""
|
| 106 |
-
year_dir = maestro_dir / year
|
| 107 |
-
|
| 108 |
-
# Look for audio file (.wav)
|
| 109 |
-
audio_path = year_dir / f"{test_case_prefix}.wav"
|
| 110 |
-
if not audio_path.exists():
|
| 111 |
-
# Try alternative naming
|
| 112 |
-
audio_path = year_dir / f"{test_case_prefix}.flac"
|
| 113 |
-
|
| 114 |
-
# Look for MIDI file (.midi or .mid)
|
| 115 |
-
midi_path = year_dir / f"{test_case_prefix}.midi"
|
| 116 |
-
if not midi_path.exists():
|
| 117 |
-
midi_path = year_dir / f"{test_case_prefix}.mid"
|
| 118 |
-
|
| 119 |
-
if audio_path.exists() and midi_path.exists():
|
| 120 |
-
return audio_path, midi_path
|
| 121 |
-
|
| 122 |
-
return None, None
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
def create_maestro_test_cases(
|
| 126 |
-
maestro_dir: Path,
|
| 127 |
-
subset: List[tuple] = MAESTRO_SUBSET
|
| 128 |
-
) -> List[TestCase]:
|
| 129 |
-
"""
|
| 130 |
-
Create test cases from MAESTRO dataset subset.
|
| 131 |
-
|
| 132 |
-
Args:
|
| 133 |
-
maestro_dir: Path to MAESTRO dataset root
|
| 134 |
-
subset: List of (year, split, prefix, genre, difficulty) tuples
|
| 135 |
-
|
| 136 |
-
Returns:
|
| 137 |
-
List of TestCase objects
|
| 138 |
-
"""
|
| 139 |
-
test_cases = []
|
| 140 |
-
|
| 141 |
-
for year, split, prefix, genre, difficulty in subset:
|
| 142 |
-
audio_path, midi_path = find_maestro_files(maestro_dir, prefix, year)
|
| 143 |
-
|
| 144 |
-
if audio_path and midi_path:
|
| 145 |
-
# Extract a readable name from the filename
|
| 146 |
-
name = prefix.split("--")[-1].replace("_", " ").replace(".wav", "")
|
| 147 |
-
|
| 148 |
-
test_case = TestCase(
|
| 149 |
-
name=f"MAESTRO_{year}_{name[:50]}", # Truncate long names
|
| 150 |
-
audio_path=audio_path,
|
| 151 |
-
ground_truth_midi=midi_path,
|
| 152 |
-
genre=genre,
|
| 153 |
-
difficulty=difficulty
|
| 154 |
-
)
|
| 155 |
-
test_cases.append(test_case)
|
| 156 |
-
print(f"✅ Added test case: {test_case.name}")
|
| 157 |
-
else:
|
| 158 |
-
print(f"⚠️ Skipping (files not found): {year}/{prefix}")
|
| 159 |
-
|
| 160 |
-
return test_cases
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
def create_simple_test_cases(output_dir: Path) -> List[TestCase]:
|
| 164 |
-
"""
|
| 165 |
-
Create simple test cases for initial testing (without MAESTRO).
|
| 166 |
-
|
| 167 |
-
Uses synthesized MIDI or publicly available simple piano pieces.
|
| 168 |
-
Useful for testing the pipeline without downloading MAESTRO.
|
| 169 |
-
|
| 170 |
-
Args:
|
| 171 |
-
output_dir: Directory to save test files
|
| 172 |
-
|
| 173 |
-
Returns:
|
| 174 |
-
List of TestCase objects
|
| 175 |
-
"""
|
| 176 |
-
test_cases = []
|
| 177 |
-
output_dir = Path(output_dir)
|
| 178 |
-
output_dir.mkdir(parents=True, exist_ok=True)
|
| 179 |
-
|
| 180 |
-
# For now, return empty list with instructions
|
| 181 |
-
print("📝 To use MAESTRO dataset:")
|
| 182 |
-
print(" 1. Download MAESTRO v3.0.0 from: https://magenta.tensorflow.org/datasets/maestro")
|
| 183 |
-
print(" 2. Extract to a directory (e.g., /tmp/maestro-v3.0.0/)")
|
| 184 |
-
print(" 3. Run: prepare_maestro_test_cases('/tmp/maestro-v3.0.0/')")
|
| 185 |
-
print("")
|
| 186 |
-
print("📝 Or create custom test cases with your own audio + MIDI files")
|
| 187 |
-
|
| 188 |
-
return test_cases
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
def prepare_maestro_test_cases(
|
| 192 |
-
maestro_dir: Path,
|
| 193 |
-
output_json: Path
|
| 194 |
-
) -> List[TestCase]:
|
| 195 |
-
"""
|
| 196 |
-
Main function to prepare MAESTRO test cases and save to JSON.
|
| 197 |
-
|
| 198 |
-
Args:
|
| 199 |
-
maestro_dir: Path to MAESTRO dataset root directory
|
| 200 |
-
output_json: Path to save test_videos.json
|
| 201 |
-
|
| 202 |
-
Returns:
|
| 203 |
-
List of TestCase objects
|
| 204 |
-
"""
|
| 205 |
-
maestro_dir = Path(maestro_dir)
|
| 206 |
-
|
| 207 |
-
if not maestro_dir.exists():
|
| 208 |
-
raise FileNotFoundError(
|
| 209 |
-
f"MAESTRO directory not found: {maestro_dir}\n"
|
| 210 |
-
f"Please download and extract MAESTRO dataset first."
|
| 211 |
-
)
|
| 212 |
-
|
| 213 |
-
print(f"🎹 Preparing MAESTRO test cases from: {maestro_dir}")
|
| 214 |
-
|
| 215 |
-
# Create test cases from subset
|
| 216 |
-
test_cases = create_maestro_test_cases(maestro_dir, MAESTRO_SUBSET)
|
| 217 |
-
|
| 218 |
-
if len(test_cases) == 0:
|
| 219 |
-
raise ValueError(
|
| 220 |
-
"No test cases created! Check if MAESTRO directory structure is correct.\n"
|
| 221 |
-
f"Expected structure: {maestro_dir}/YYYY/*.wav and *.midi"
|
| 222 |
-
)
|
| 223 |
-
|
| 224 |
-
# Save to JSON
|
| 225 |
-
save_test_cases_to_json(test_cases, output_json)
|
| 226 |
-
|
| 227 |
-
print(f"\n✅ Created {len(test_cases)} test cases")
|
| 228 |
-
print(f" Easy: {sum(1 for tc in test_cases if tc.difficulty == 'easy')}")
|
| 229 |
-
print(f" Medium: {sum(1 for tc in test_cases if tc.difficulty == 'medium')}")
|
| 230 |
-
print(f" Hard: {sum(1 for tc in test_cases if tc.difficulty == 'hard')}")
|
| 231 |
-
|
| 232 |
-
return test_cases
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
if __name__ == "__main__":
|
| 236 |
-
import argparse
|
| 237 |
-
|
| 238 |
-
parser = argparse.ArgumentParser(description="Prepare MAESTRO dataset for benchmarking")
|
| 239 |
-
parser.add_argument(
|
| 240 |
-
"--maestro-dir",
|
| 241 |
-
type=Path,
|
| 242 |
-
required=True,
|
| 243 |
-
help="Path to MAESTRO dataset root directory (e.g., /tmp/maestro-v3.0.0)"
|
| 244 |
-
)
|
| 245 |
-
parser.add_argument(
|
| 246 |
-
"--output-json",
|
| 247 |
-
type=Path,
|
| 248 |
-
default=Path("backend/evaluation/test_videos.json"),
|
| 249 |
-
help="Path to save test cases JSON file"
|
| 250 |
-
)
|
| 251 |
-
|
| 252 |
-
args = parser.parse_args()
|
| 253 |
-
|
| 254 |
-
test_cases = prepare_maestro_test_cases(args.maestro_dir, args.output_json)
|
| 255 |
-
|
| 256 |
-
print(f"\n🎯 Next steps:")
|
| 257 |
-
print(f" 1. Run baseline benchmark:")
|
| 258 |
-
print(f" python -m evaluation.run_benchmark --model yourmt3")
|
| 259 |
-
print(f" 2. Compare with other models after implementing them")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/evaluation/run_benchmark.py
DELETED
|
@@ -1,192 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Main script to run transcription benchmarks.
|
| 3 |
-
|
| 4 |
-
Usage:
|
| 5 |
-
# Prepare MAESTRO test cases (one-time setup)
|
| 6 |
-
python -m evaluation.prepare_maestro --maestro-dir /path/to/maestro-v3.0.0
|
| 7 |
-
|
| 8 |
-
# Run baseline benchmark on YourMT3+
|
| 9 |
-
python -m evaluation.run_benchmark --model yourmt3 --test-cases evaluation/test_videos.json
|
| 10 |
-
|
| 11 |
-
# Compare with ensemble after Phase 3
|
| 12 |
-
python -m evaluation.run_benchmark --model ensemble --test-cases evaluation/test_videos.json
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
import argparse
|
| 16 |
-
import sys
|
| 17 |
-
from pathlib import Path
|
| 18 |
-
|
| 19 |
-
# Add parent directory to path to import backend modules
|
| 20 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 21 |
-
|
| 22 |
-
from evaluation.benchmark import TranscriptionBenchmark, load_test_cases_from_json
|
| 23 |
-
from evaluation.metrics import calculate_metrics
|
| 24 |
-
from pipeline import TranscriptionPipeline
|
| 25 |
-
from app_config import Settings
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def transcribe_with_yourmt3(audio_path: Path, output_dir: Path) -> Path:
|
| 29 |
-
"""
|
| 30 |
-
Transcribe audio using current YourMT3+ pipeline.
|
| 31 |
-
|
| 32 |
-
Args:
|
| 33 |
-
audio_path: Path to input audio file
|
| 34 |
-
output_dir: Directory to save output MIDI
|
| 35 |
-
|
| 36 |
-
Returns:
|
| 37 |
-
Path to output MIDI file
|
| 38 |
-
"""
|
| 39 |
-
import shutil
|
| 40 |
-
import uuid
|
| 41 |
-
|
| 42 |
-
config = Settings()
|
| 43 |
-
|
| 44 |
-
# Create a temporary job ID and storage for benchmarking
|
| 45 |
-
job_id = f"benchmark_{uuid.uuid4().hex[:8]}"
|
| 46 |
-
temp_storage = Path("/tmp/rescored_benchmark")
|
| 47 |
-
temp_storage.mkdir(parents=True, exist_ok=True)
|
| 48 |
-
|
| 49 |
-
# Initialize pipeline with dummy URL (not used for benchmark)
|
| 50 |
-
pipeline = TranscriptionPipeline(
|
| 51 |
-
job_id=job_id,
|
| 52 |
-
youtube_url="benchmark://test", # Dummy URL
|
| 53 |
-
storage_path=temp_storage,
|
| 54 |
-
config=config
|
| 55 |
-
)
|
| 56 |
-
|
| 57 |
-
try:
|
| 58 |
-
# Copy audio to pipeline's temp directory
|
| 59 |
-
temp_audio = pipeline.temp_dir / "audio.wav"
|
| 60 |
-
shutil.copy(audio_path, temp_audio)
|
| 61 |
-
|
| 62 |
-
# Run source separation
|
| 63 |
-
print(f" Running Demucs separation...")
|
| 64 |
-
separated_stems = pipeline.separate_sources(temp_audio)
|
| 65 |
-
piano_stem_path = separated_stems.get("other")
|
| 66 |
-
|
| 67 |
-
if not piano_stem_path or not Path(piano_stem_path).exists():
|
| 68 |
-
raise FileNotFoundError(f"Source separation failed: piano stem not found")
|
| 69 |
-
|
| 70 |
-
# Run transcription (use the transcribe_with_yourmt3 method)
|
| 71 |
-
print(f" Running YourMT3+ transcription...")
|
| 72 |
-
midi_path = pipeline.transcribe_with_yourmt3(Path(piano_stem_path))
|
| 73 |
-
|
| 74 |
-
if not midi_path or not midi_path.exists():
|
| 75 |
-
raise FileNotFoundError(f"Transcription failed: no MIDI output")
|
| 76 |
-
|
| 77 |
-
# Copy result to output directory
|
| 78 |
-
output_midi = output_dir / f"{audio_path.stem}.mid"
|
| 79 |
-
shutil.copy(midi_path, output_midi)
|
| 80 |
-
|
| 81 |
-
return output_midi
|
| 82 |
-
|
| 83 |
-
finally:
|
| 84 |
-
# Cleanup temp directory
|
| 85 |
-
shutil.rmtree(temp_storage, ignore_errors=True)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def transcribe_with_ensemble(audio_path: Path, output_dir: Path) -> Path:
|
| 89 |
-
"""
|
| 90 |
-
Transcribe audio using ensemble method (Phase 3).
|
| 91 |
-
|
| 92 |
-
Note: This will be implemented in Phase 3.
|
| 93 |
-
|
| 94 |
-
Args:
|
| 95 |
-
audio_path: Path to input audio file
|
| 96 |
-
output_dir: Directory to save output MIDI
|
| 97 |
-
|
| 98 |
-
Returns:
|
| 99 |
-
Path to output MIDI file
|
| 100 |
-
"""
|
| 101 |
-
raise NotImplementedError(
|
| 102 |
-
"Ensemble transcription not yet implemented. "
|
| 103 |
-
"This will be available after Phase 3."
|
| 104 |
-
)
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def transcribe_with_bytedance(audio_path: Path, output_dir: Path) -> Path:
|
| 108 |
-
"""
|
| 109 |
-
Transcribe audio using ByteDance piano model (Phase 2).
|
| 110 |
-
|
| 111 |
-
Note: This will be implemented in Phase 2.
|
| 112 |
-
|
| 113 |
-
Args:
|
| 114 |
-
audio_path: Path to input audio file
|
| 115 |
-
output_dir: Directory to save output MIDI
|
| 116 |
-
|
| 117 |
-
Returns:
|
| 118 |
-
Path to output MIDI file
|
| 119 |
-
"""
|
| 120 |
-
raise NotImplementedError(
|
| 121 |
-
"ByteDance transcription not yet implemented. "
|
| 122 |
-
"This will be available after Phase 2."
|
| 123 |
-
)
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
def main():
|
| 127 |
-
parser = argparse.ArgumentParser(description="Run transcription benchmarks")
|
| 128 |
-
parser.add_argument(
|
| 129 |
-
"--model",
|
| 130 |
-
type=str,
|
| 131 |
-
required=True,
|
| 132 |
-
choices=["yourmt3", "bytedance", "ensemble"],
|
| 133 |
-
help="Model to benchmark"
|
| 134 |
-
)
|
| 135 |
-
parser.add_argument(
|
| 136 |
-
"--test-cases",
|
| 137 |
-
type=Path,
|
| 138 |
-
default=Path("backend/evaluation/test_videos.json"),
|
| 139 |
-
help="Path to test cases JSON file"
|
| 140 |
-
)
|
| 141 |
-
parser.add_argument(
|
| 142 |
-
"--output-dir",
|
| 143 |
-
type=Path,
|
| 144 |
-
default=Path("backend/evaluation/results"),
|
| 145 |
-
help="Directory to save benchmark results"
|
| 146 |
-
)
|
| 147 |
-
parser.add_argument(
|
| 148 |
-
"--onset-tolerance",
|
| 149 |
-
type=float,
|
| 150 |
-
default=0.05,
|
| 151 |
-
help="Onset matching tolerance in seconds (default: 0.05 = 50ms)"
|
| 152 |
-
)
|
| 153 |
-
|
| 154 |
-
args = parser.parse_args()
|
| 155 |
-
|
| 156 |
-
# Load test cases
|
| 157 |
-
if not args.test_cases.exists():
|
| 158 |
-
print(f"❌ Test cases file not found: {args.test_cases}")
|
| 159 |
-
print(f"\n📝 First, prepare test cases:")
|
| 160 |
-
print(f" python -m evaluation.prepare_maestro --maestro-dir /path/to/maestro-v3.0.0")
|
| 161 |
-
sys.exit(1)
|
| 162 |
-
|
| 163 |
-
test_cases = load_test_cases_from_json(args.test_cases)
|
| 164 |
-
print(f"✅ Loaded {len(test_cases)} test cases from {args.test_cases}")
|
| 165 |
-
|
| 166 |
-
# Select transcription function
|
| 167 |
-
transcribe_fn_map = {
|
| 168 |
-
"yourmt3": transcribe_with_yourmt3,
|
| 169 |
-
"bytedance": transcribe_with_bytedance,
|
| 170 |
-
"ensemble": transcribe_with_ensemble
|
| 171 |
-
}
|
| 172 |
-
transcribe_fn = transcribe_fn_map[args.model]
|
| 173 |
-
|
| 174 |
-
# Create benchmark runner
|
| 175 |
-
benchmark = TranscriptionBenchmark(
|
| 176 |
-
test_cases=test_cases,
|
| 177 |
-
output_dir=args.output_dir,
|
| 178 |
-
onset_tolerance=args.onset_tolerance
|
| 179 |
-
)
|
| 180 |
-
|
| 181 |
-
# Run benchmark
|
| 182 |
-
results = benchmark.run_benchmark(
|
| 183 |
-
transcribe_fn=transcribe_fn,
|
| 184 |
-
model_name=args.model
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
print(f"\n✅ Benchmark complete!")
|
| 188 |
-
print(f" Results saved to: {args.output_dir}")
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
if __name__ == "__main__":
|
| 192 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/evaluation/slurm_benchmark.sh
DELETED
|
@@ -1,175 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
#SBATCH --job-name=rescored_benchmark
|
| 3 |
-
#SBATCH --output=../../logs/slurm/benchmark_%j.log
|
| 4 |
-
#SBATCH --error=../../logs/slurm/benchmark_%j.err
|
| 5 |
-
#SBATCH --time=0-12:00:00 # 12 hours for 8-10 test cases
|
| 6 |
-
#SBATCH --partition=l40-gpu # Use l40-gpu partition
|
| 7 |
-
#SBATCH --qos=gpu_access # Required for l40-gpu partition
|
| 8 |
-
#SBATCH --gres=gpu:1 # 1 GPU (L40) - speeds up YourMT3+ and Demucs
|
| 9 |
-
#SBATCH --cpus-per-task=8
|
| 10 |
-
#SBATCH --mem=32G # 32GB memory
|
| 11 |
-
|
| 12 |
-
# Piano Transcription Accuracy Benchmark
|
| 13 |
-
# Evaluates YourMT3+ baseline on MAESTRO dataset
|
| 14 |
-
# Expected runtime: ~8-12 hours for 8 test cases (with GPU)
|
| 15 |
-
|
| 16 |
-
echo "========================================"
|
| 17 |
-
echo "Rescored Transcription Benchmark"
|
| 18 |
-
echo "Job ID: $SLURM_JOB_ID"
|
| 19 |
-
echo "Node: $SLURM_NODELIST"
|
| 20 |
-
echo "GPU: $CUDA_VISIBLE_DEVICES"
|
| 21 |
-
echo "========================================"
|
| 22 |
-
|
| 23 |
-
# Configuration
|
| 24 |
-
MAESTRO_DIR=${1:-"../../data/maestro-v3.0.0"} # Path to MAESTRO dataset (relative to backend/)
|
| 25 |
-
MODEL=${2:-"yourmt3"} # Model to benchmark (yourmt3, bytedance, ensemble)
|
| 26 |
-
REPO_DIR=${3:-".."} # Path to git repo (relative to backend/)
|
| 27 |
-
|
| 28 |
-
# Verify MAESTRO dataset exists
|
| 29 |
-
if [ ! -d "$MAESTRO_DIR" ]; then
|
| 30 |
-
echo "ERROR: MAESTRO dataset not found: $MAESTRO_DIR"
|
| 31 |
-
echo ""
|
| 32 |
-
echo "Please download MAESTRO v3.0.0 from:"
|
| 33 |
-
echo " https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.zip"
|
| 34 |
-
echo ""
|
| 35 |
-
echo "Extract to: ../../data/maestro-v3.0.0/"
|
| 36 |
-
echo ""
|
| 37 |
-
echo "Directory structure should be:"
|
| 38 |
-
echo " rescored/"
|
| 39 |
-
echo " ├── data/"
|
| 40 |
-
echo " │ └── maestro-v3.0.0/ <- MAESTRO dataset here"
|
| 41 |
-
echo " └── rescored/ <- Git repo here"
|
| 42 |
-
echo " └── backend/ <- Script runs from here"
|
| 43 |
-
exit 1
|
| 44 |
-
fi
|
| 45 |
-
|
| 46 |
-
# Verify git repo exists
|
| 47 |
-
if [ ! -d "$REPO_DIR" ]; then
|
| 48 |
-
echo "ERROR: Rescored repo not found: $REPO_DIR"
|
| 49 |
-
echo "Please clone the repo first:"
|
| 50 |
-
echo " git clone <repo-url> $REPO_DIR"
|
| 51 |
-
exit 1
|
| 52 |
-
fi
|
| 53 |
-
|
| 54 |
-
# Navigate to backend directory
|
| 55 |
-
cd "$REPO_DIR/backend" || exit 1
|
| 56 |
-
|
| 57 |
-
echo "MAESTRO dataset: $MAESTRO_DIR"
|
| 58 |
-
echo "Model: $MODEL"
|
| 59 |
-
echo "Repository: $REPO_DIR"
|
| 60 |
-
echo ""
|
| 61 |
-
|
| 62 |
-
# Create necessary directories
|
| 63 |
-
mkdir -p logs/slurm
|
| 64 |
-
mkdir -p evaluation/results
|
| 65 |
-
|
| 66 |
-
# Load required modules
|
| 67 |
-
module load anaconda/2024.02
|
| 68 |
-
module load ffmpeg # Required by Demucs for audio loading
|
| 69 |
-
|
| 70 |
-
# Activate virtual environment
|
| 71 |
-
source activate .venv
|
| 72 |
-
|
| 73 |
-
# Display GPU info
|
| 74 |
-
echo ""
|
| 75 |
-
echo "GPU Information:"
|
| 76 |
-
nvidia-smi
|
| 77 |
-
echo ""
|
| 78 |
-
|
| 79 |
-
# Prepare test cases from MAESTRO
|
| 80 |
-
echo "========================================"
|
| 81 |
-
echo "Step 1: Preparing Test Cases"
|
| 82 |
-
echo "========================================"
|
| 83 |
-
|
| 84 |
-
if [ ! -f "evaluation/test_videos.json" ]; then
|
| 85 |
-
echo "Extracting test subset from MAESTRO..."
|
| 86 |
-
python -m evaluation.prepare_maestro \
|
| 87 |
-
--maestro-dir "$MAESTRO_DIR" \
|
| 88 |
-
--output-json evaluation/test_videos.json
|
| 89 |
-
|
| 90 |
-
if [ $? -ne 0 ]; then
|
| 91 |
-
echo "ERROR: Failed to prepare test cases"
|
| 92 |
-
exit 1
|
| 93 |
-
fi
|
| 94 |
-
else
|
| 95 |
-
echo "✅ Test cases already exist: evaluation/test_videos.json"
|
| 96 |
-
fi
|
| 97 |
-
|
| 98 |
-
NUM_TESTS=$(python -c "import json; print(len(json.load(open('evaluation/test_videos.json'))))")
|
| 99 |
-
echo ""
|
| 100 |
-
echo "Number of test cases: $NUM_TESTS"
|
| 101 |
-
echo ""
|
| 102 |
-
|
| 103 |
-
# Run benchmark
|
| 104 |
-
echo "========================================"
|
| 105 |
-
echo "Step 2: Running Benchmark"
|
| 106 |
-
echo "========================================"
|
| 107 |
-
echo "Model: $MODEL"
|
| 108 |
-
echo "Onset tolerance: 50ms (default)"
|
| 109 |
-
echo "Expected time: ~60-90 min per test case with GPU"
|
| 110 |
-
echo ""
|
| 111 |
-
|
| 112 |
-
python -m evaluation.run_benchmark \
|
| 113 |
-
--model "$MODEL" \
|
| 114 |
-
--test-cases evaluation/test_videos.json \
|
| 115 |
-
--output-dir evaluation/results \
|
| 116 |
-
2>&1 | tee "logs/slurm/benchmark_${SLURM_JOB_ID}.log"
|
| 117 |
-
|
| 118 |
-
EXIT_CODE=$?
|
| 119 |
-
|
| 120 |
-
echo ""
|
| 121 |
-
echo "========================================"
|
| 122 |
-
if [ $EXIT_CODE -eq 0 ]; then
|
| 123 |
-
echo "✅ Benchmark completed successfully!"
|
| 124 |
-
echo ""
|
| 125 |
-
echo "Results:"
|
| 126 |
-
ls -lh evaluation/results/${MODEL}_results.*
|
| 127 |
-
|
| 128 |
-
echo ""
|
| 129 |
-
echo "Summary (from JSON):"
|
| 130 |
-
python -c "
|
| 131 |
-
import json
|
| 132 |
-
import sys
|
| 133 |
-
try:
|
| 134 |
-
with open('evaluation/results/${MODEL}_results.json', 'r') as f:
|
| 135 |
-
results = json.load(f)
|
| 136 |
-
|
| 137 |
-
successful = [r for r in results if r.get('success', False)]
|
| 138 |
-
failed = [r for r in results if not r.get('success', False)]
|
| 139 |
-
|
| 140 |
-
print(f' Total tests: {len(results)}')
|
| 141 |
-
print(f' Successful: {len(successful)}')
|
| 142 |
-
print(f' Failed: {len(failed)}')
|
| 143 |
-
|
| 144 |
-
if successful:
|
| 145 |
-
avg_f1 = sum(r['f1_score'] for r in successful) / len(successful)
|
| 146 |
-
avg_precision = sum(r['precision'] for r in successful) / len(successful)
|
| 147 |
-
avg_recall = sum(r['recall'] for r in successful) / len(successful)
|
| 148 |
-
avg_time = sum(r['processing_time'] for r in successful) / len(successful)
|
| 149 |
-
|
| 150 |
-
print(f'')
|
| 151 |
-
print(f' Average F1 Score: {avg_f1:.3f}')
|
| 152 |
-
print(f' Average Precision: {avg_precision:.3f}')
|
| 153 |
-
print(f' Average Recall: {avg_recall:.3f}')
|
| 154 |
-
print(f' Avg Processing Time: {avg_time:.1f}s')
|
| 155 |
-
except Exception as e:
|
| 156 |
-
print(f' Could not parse results: {e}')
|
| 157 |
-
sys.exit(1)
|
| 158 |
-
"
|
| 159 |
-
|
| 160 |
-
echo ""
|
| 161 |
-
echo "📥 Download results to local machine:"
|
| 162 |
-
echo " scp <cluster>:$(pwd)/evaluation/results/${MODEL}_results.* ."
|
| 163 |
-
|
| 164 |
-
else
|
| 165 |
-
echo "❌ Benchmark failed with exit code: $EXIT_CODE"
|
| 166 |
-
echo ""
|
| 167 |
-
echo "Check logs:"
|
| 168 |
-
echo " tail -100 logs/slurm/benchmark_${SLURM_JOB_ID}.log"
|
| 169 |
-
fi
|
| 170 |
-
|
| 171 |
-
echo ""
|
| 172 |
-
echo "Job ID: $SLURM_JOB_ID"
|
| 173 |
-
echo "========================================"
|
| 174 |
-
|
| 175 |
-
exit $EXIT_CODE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/evaluation/test_videos.json
DELETED
|
@@ -1,66 +0,0 @@
|
|
| 1 |
-
[
|
| 2 |
-
{
|
| 3 |
-
"name": "MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_03_Track03_wav",
|
| 4 |
-
"audio_path": "../../data/maestro-v3.0.0/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_03_Track03_wav.wav",
|
| 5 |
-
"ground_truth_midi": "../../data/maestro-v3.0.0/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_03_Track03_wav.midi",
|
| 6 |
-
"genre": "classical",
|
| 7 |
-
"difficulty": "easy",
|
| 8 |
-
"duration": null
|
| 9 |
-
},
|
| 10 |
-
{
|
| 11 |
-
"name": "MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav",
|
| 12 |
-
"audio_path": "../../data/maestro-v3.0.0/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.wav",
|
| 13 |
-
"ground_truth_midi": "../../data/maestro-v3.0.0/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi",
|
| 14 |
-
"genre": "classical",
|
| 15 |
-
"difficulty": "easy",
|
| 16 |
-
"duration": null
|
| 17 |
-
},
|
| 18 |
-
{
|
| 19 |
-
"name": "MIDI-Unprocessed_XP_14_R1_2004_01-04_ORIG_MID--AUDIO_14_R1_2004_04_Track04_wav",
|
| 20 |
-
"audio_path": "../../data/maestro-v3.0.0/2004/MIDI-Unprocessed_XP_14_R1_2004_01-04_ORIG_MID--AUDIO_14_R1_2004_04_Track04_wav.wav",
|
| 21 |
-
"ground_truth_midi": "../../data/maestro-v3.0.0/2004/MIDI-Unprocessed_XP_14_R1_2004_01-04_ORIG_MID--AUDIO_14_R1_2004_04_Track04_wav.midi",
|
| 22 |
-
"genre": "classical",
|
| 23 |
-
"difficulty": "medium",
|
| 24 |
-
"duration": null
|
| 25 |
-
},
|
| 26 |
-
{
|
| 27 |
-
"name": "MIDI-Unprocessed_07_R1_2006_01-09_ORIG_MID--AUDIO_07_R1_2006_04_Track04_wav",
|
| 28 |
-
"audio_path": "../../data/maestro-v3.0.0/2006/MIDI-Unprocessed_07_R1_2006_01-09_ORIG_MID--AUDIO_07_R1_2006_04_Track04_wav.wav",
|
| 29 |
-
"ground_truth_midi": "../../data/maestro-v3.0.0/2006/MIDI-Unprocessed_07_R1_2006_01-09_ORIG_MID--AUDIO_07_R1_2006_04_Track04_wav.midi",
|
| 30 |
-
"genre": "classical",
|
| 31 |
-
"difficulty": "medium",
|
| 32 |
-
"duration": null
|
| 33 |
-
},
|
| 34 |
-
{
|
| 35 |
-
"name": "MIDI-Unprocessed_11_R1_2008_01-04_ORIG_MID--AUDIO_11_R1_2008_02_Track02_wav",
|
| 36 |
-
"audio_path": "../../data/maestro-v3.0.0/2008/MIDI-Unprocessed_11_R1_2008_01-04_ORIG_MID--AUDIO_11_R1_2008_02_Track02_wav.wav",
|
| 37 |
-
"ground_truth_midi": "../../data/maestro-v3.0.0/2008/MIDI-Unprocessed_11_R1_2008_01-04_ORIG_MID--AUDIO_11_R1_2008_02_Track02_wav.midi",
|
| 38 |
-
"genre": "classical",
|
| 39 |
-
"difficulty": "medium",
|
| 40 |
-
"duration": null
|
| 41 |
-
},
|
| 42 |
-
{
|
| 43 |
-
"name": "MIDI-Unprocessed_16_R1_2009_01-04_ORIG_MID--AUDIO_16_R1_2009_16_R1_2009_02_WAV",
|
| 44 |
-
"audio_path": "../../data/maestro-v3.0.0/2009/MIDI-Unprocessed_16_R1_2009_01-04_ORIG_MID--AUDIO_16_R1_2009_16_R1_2009_02_WAV.wav",
|
| 45 |
-
"ground_truth_midi": "../../data/maestro-v3.0.0/2009/MIDI-Unprocessed_16_R1_2009_01-04_ORIG_MID--AUDIO_16_R1_2009_16_R1_2009_02_WAV.midi",
|
| 46 |
-
"genre": "classical",
|
| 47 |
-
"difficulty": "hard",
|
| 48 |
-
"duration": null
|
| 49 |
-
},
|
| 50 |
-
{
|
| 51 |
-
"name": "MIDI-Unprocessed_03_R1_2011_MID--AUDIO_03_R1_2011_03_R1_2011_02_WAV",
|
| 52 |
-
"audio_path": "../../data/maestro-v3.0.0/2011/MIDI-Unprocessed_03_R1_2011_MID--AUDIO_03_R1_2011_03_R1_2011_02_WAV.wav",
|
| 53 |
-
"ground_truth_midi": "../../data/maestro-v3.0.0/2011/MIDI-Unprocessed_03_R1_2011_MID--AUDIO_03_R1_2011_03_R1_2011_02_WAV.midi",
|
| 54 |
-
"genre": "classical",
|
| 55 |
-
"difficulty": "hard",
|
| 56 |
-
"duration": null
|
| 57 |
-
},
|
| 58 |
-
{
|
| 59 |
-
"name": "MIDI-Unprocessed_20_R1_2013_MID--AUDIO_20_R1_2013_20_R1_2013_02_WAV",
|
| 60 |
-
"audio_path": "../../data/maestro-v3.0.0/2013/MIDI-Unprocessed_20_R1_2013_MID--AUDIO_20_R1_2013_20_R1_2013_02_WAV.wav",
|
| 61 |
-
"ground_truth_midi": "../../data/maestro-v3.0.0/2013/MIDI-Unprocessed_20_R1_2013_MID--AUDIO_20_R1_2013_20_R1_2013_02_WAV.midi",
|
| 62 |
-
"genre": "classical",
|
| 63 |
-
"difficulty": "hard",
|
| 64 |
-
"duration": null
|
| 65 |
-
}
|
| 66 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/key_filter.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Key-Aware MIDI Filtering
|
| 3 |
+
|
| 4 |
+
Filters out notes that are inconsistent with the detected key signature.
|
| 5 |
+
Removes isolated out-of-key notes that are likely false positives.
|
| 6 |
+
|
| 7 |
+
Expected Impact: +1-2% precision improvement (especially for tonal music)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import List, Set, Optional
|
| 12 |
+
import pretty_midi
|
| 13 |
+
from music21 import scale, pitch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class KeyAwareFilter:
|
| 17 |
+
"""
|
| 18 |
+
Filter MIDI notes based on key signature analysis.
|
| 19 |
+
|
| 20 |
+
Removes isolated out-of-key notes that are likely false positives while
|
| 21 |
+
preserving intentional chromatic notes (passing tones, accidentals).
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
allow_chromatic: bool = True,
|
| 27 |
+
isolation_threshold: float = 0.5
|
| 28 |
+
):
|
| 29 |
+
"""
|
| 30 |
+
Initialize key-aware filter.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
allow_chromatic: Allow chromatic passing tones (brief out-of-key notes)
|
| 34 |
+
isolation_threshold: Time threshold (seconds) to consider note "isolated"
|
| 35 |
+
"""
|
| 36 |
+
self.allow_chromatic = allow_chromatic
|
| 37 |
+
self.isolation_threshold = isolation_threshold
|
| 38 |
+
|
| 39 |
+
def filter_midi_by_key(
|
| 40 |
+
self,
|
| 41 |
+
midi_path: Path,
|
| 42 |
+
detected_key: str,
|
| 43 |
+
output_path: Optional[Path] = None
|
| 44 |
+
) -> Path:
|
| 45 |
+
"""
|
| 46 |
+
Filter MIDI notes based on key signature.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
midi_path: Input MIDI file
|
| 50 |
+
detected_key: Detected key (e.g., "C major", "A minor")
|
| 51 |
+
output_path: Output path (default: input_path with _key_filtered suffix)
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
Path to filtered MIDI file
|
| 55 |
+
"""
|
| 56 |
+
# Parse key signature
|
| 57 |
+
key_pitches = self._get_key_pitches(detected_key)
|
| 58 |
+
|
| 59 |
+
# Load MIDI
|
| 60 |
+
pm = pretty_midi.PrettyMIDI(str(midi_path))
|
| 61 |
+
|
| 62 |
+
# Create new MIDI with filtered notes
|
| 63 |
+
filtered_pm = pretty_midi.PrettyMIDI(initial_tempo=pm.estimate_tempo())
|
| 64 |
+
|
| 65 |
+
total_notes = 0
|
| 66 |
+
kept_notes = 0
|
| 67 |
+
|
| 68 |
+
for inst in pm.instruments:
|
| 69 |
+
if inst.is_drum:
|
| 70 |
+
# Keep drum tracks as-is
|
| 71 |
+
filtered_pm.instruments.append(inst)
|
| 72 |
+
continue
|
| 73 |
+
|
| 74 |
+
# Create new instrument with filtered notes
|
| 75 |
+
filtered_inst = pretty_midi.Instrument(
|
| 76 |
+
program=inst.program,
|
| 77 |
+
is_drum=inst.is_drum,
|
| 78 |
+
name=inst.name
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Sort notes by onset for isolation detection
|
| 82 |
+
sorted_notes = sorted(inst.notes, key=lambda n: n.start)
|
| 83 |
+
|
| 84 |
+
for i, note in enumerate(sorted_notes):
|
| 85 |
+
total_notes += 1
|
| 86 |
+
|
| 87 |
+
# Check if note should be kept
|
| 88 |
+
if self._should_keep_note(note, key_pitches, sorted_notes, i):
|
| 89 |
+
filtered_inst.notes.append(note)
|
| 90 |
+
kept_notes += 1
|
| 91 |
+
|
| 92 |
+
filtered_pm.instruments.append(filtered_inst)
|
| 93 |
+
|
| 94 |
+
# Set output path
|
| 95 |
+
if output_path is None:
|
| 96 |
+
output_path = midi_path.with_stem(f"{midi_path.stem}_key_filtered")
|
| 97 |
+
|
| 98 |
+
# Save filtered MIDI
|
| 99 |
+
filtered_pm.write(str(output_path))
|
| 100 |
+
|
| 101 |
+
removed = total_notes - kept_notes
|
| 102 |
+
print(f" Key-aware filtering: kept {kept_notes}/{total_notes} notes (removed {removed})")
|
| 103 |
+
|
| 104 |
+
return output_path
|
| 105 |
+
|
| 106 |
+
def _get_key_pitches(self, detected_key: str) -> Set[int]:
|
| 107 |
+
"""
|
| 108 |
+
Get pitch classes that are in the detected key.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
detected_key: Key signature (e.g., "C major", "A minor")
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
Set of pitch classes (0-11) in the key
|
| 115 |
+
"""
|
| 116 |
+
# Parse key signature
|
| 117 |
+
key_parts = detected_key.split()
|
| 118 |
+
if len(key_parts) != 2:
|
| 119 |
+
# Default to C major if parsing fails
|
| 120 |
+
print(f" ⚠ Could not parse key '{detected_key}', defaulting to C major")
|
| 121 |
+
key_parts = ["C", "major"]
|
| 122 |
+
|
| 123 |
+
key_root = key_parts[0]
|
| 124 |
+
key_mode = key_parts[1].lower()
|
| 125 |
+
|
| 126 |
+
# Create scale
|
| 127 |
+
try:
|
| 128 |
+
if key_mode == "major":
|
| 129 |
+
key_scale = scale.MajorScale(key_root)
|
| 130 |
+
elif key_mode == "minor":
|
| 131 |
+
key_scale = scale.MinorScale(key_root)
|
| 132 |
+
else:
|
| 133 |
+
# Default to major
|
| 134 |
+
key_scale = scale.MajorScale(key_root)
|
| 135 |
+
|
| 136 |
+
# Get pitch classes in key
|
| 137 |
+
pitch_classes = set()
|
| 138 |
+
for p in key_scale.pitches:
|
| 139 |
+
pitch_classes.add(p.pitchClass)
|
| 140 |
+
|
| 141 |
+
return pitch_classes
|
| 142 |
+
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(f" ⚠ Error creating scale for '{detected_key}': {e}")
|
| 145 |
+
# Default to chromatic (all notes) if error
|
| 146 |
+
return set(range(12))
|
| 147 |
+
|
| 148 |
+
def _is_in_key(self, note_pitch: int, key_pitches: Set[int]) -> bool:
|
| 149 |
+
"""
|
| 150 |
+
Check if a note is in the key signature.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
note_pitch: MIDI note number (0-127)
|
| 154 |
+
key_pitches: Set of pitch classes (0-11) in the key
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
True if note is in key, False otherwise
|
| 158 |
+
"""
|
| 159 |
+
pitch_class = note_pitch % 12
|
| 160 |
+
return pitch_class in key_pitches
|
| 161 |
+
|
| 162 |
+
def _is_isolated(
|
| 163 |
+
self,
|
| 164 |
+
note: pretty_midi.Note,
|
| 165 |
+
sorted_notes: List[pretty_midi.Note],
|
| 166 |
+
note_index: int
|
| 167 |
+
) -> bool:
|
| 168 |
+
"""
|
| 169 |
+
Check if a note is isolated (no nearby notes).
|
| 170 |
+
|
| 171 |
+
An isolated out-of-key note is likely a false positive.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
note: Note to check
|
| 175 |
+
sorted_notes: All notes sorted by onset time
|
| 176 |
+
note_index: Index of note in sorted_notes
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
True if note is isolated, False otherwise
|
| 180 |
+
"""
|
| 181 |
+
# Check for nearby notes (within isolation_threshold)
|
| 182 |
+
has_nearby = False
|
| 183 |
+
|
| 184 |
+
# Check previous notes
|
| 185 |
+
for i in range(note_index - 1, -1, -1):
|
| 186 |
+
prev_note = sorted_notes[i]
|
| 187 |
+
time_gap = note.start - prev_note.start
|
| 188 |
+
|
| 189 |
+
if time_gap > self.isolation_threshold:
|
| 190 |
+
break # Too far back
|
| 191 |
+
|
| 192 |
+
# Nearby note found
|
| 193 |
+
has_nearby = True
|
| 194 |
+
break
|
| 195 |
+
|
| 196 |
+
# Check next notes
|
| 197 |
+
for i in range(note_index + 1, len(sorted_notes)):
|
| 198 |
+
next_note = sorted_notes[i]
|
| 199 |
+
time_gap = next_note.start - note.start
|
| 200 |
+
|
| 201 |
+
if time_gap > self.isolation_threshold:
|
| 202 |
+
break # Too far forward
|
| 203 |
+
|
| 204 |
+
# Nearby note found
|
| 205 |
+
has_nearby = True
|
| 206 |
+
break
|
| 207 |
+
|
| 208 |
+
return not has_nearby
|
| 209 |
+
|
| 210 |
+
def _is_chromatic_passing_tone(
|
| 211 |
+
self,
|
| 212 |
+
note: pretty_midi.Note,
|
| 213 |
+
sorted_notes: List[pretty_midi.Note],
|
| 214 |
+
note_index: int,
|
| 215 |
+
key_pitches: Set[int]
|
| 216 |
+
) -> bool:
|
| 217 |
+
"""
|
| 218 |
+
Check if an out-of-key note is a chromatic passing tone.
|
| 219 |
+
|
| 220 |
+
A chromatic passing tone:
|
| 221 |
+
- Is short duration
|
| 222 |
+
- Is surrounded by in-key notes
|
| 223 |
+
- Steps between the surrounding notes (semitone or whole tone)
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
note: Note to check
|
| 227 |
+
sorted_notes: All notes sorted by onset time
|
| 228 |
+
note_index: Index of note in sorted_notes
|
| 229 |
+
key_pitches: Set of pitch classes in the key
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
True if note is likely a chromatic passing tone, False otherwise
|
| 233 |
+
"""
|
| 234 |
+
# Must be short duration (< 0.25 seconds)
|
| 235 |
+
if note.end - note.start > 0.25:
|
| 236 |
+
return False
|
| 237 |
+
|
| 238 |
+
# Check surrounding notes
|
| 239 |
+
prev_note = None
|
| 240 |
+
next_note = None
|
| 241 |
+
|
| 242 |
+
# Find previous in-key note
|
| 243 |
+
for i in range(note_index - 1, -1, -1):
|
| 244 |
+
candidate = sorted_notes[i]
|
| 245 |
+
if self._is_in_key(candidate.pitch, key_pitches):
|
| 246 |
+
prev_note = candidate
|
| 247 |
+
break
|
| 248 |
+
|
| 249 |
+
# Find next in-key note
|
| 250 |
+
for i in range(note_index + 1, len(sorted_notes)):
|
| 251 |
+
candidate = sorted_notes[i]
|
| 252 |
+
if self._is_in_key(candidate.pitch, key_pitches):
|
| 253 |
+
next_note = candidate
|
| 254 |
+
break
|
| 255 |
+
|
| 256 |
+
# Must be surrounded by in-key notes
|
| 257 |
+
if prev_note is None or next_note is None:
|
| 258 |
+
return False
|
| 259 |
+
|
| 260 |
+
# Check if it's a passing tone (connects prev and next)
|
| 261 |
+
pitch_interval = abs(next_note.pitch - prev_note.pitch)
|
| 262 |
+
is_step = pitch_interval in [1, 2, 3] # Semitone, whole tone, or minor third
|
| 263 |
+
|
| 264 |
+
# Check if note is between prev and next
|
| 265 |
+
is_between = (
|
| 266 |
+
(prev_note.pitch < note.pitch < next_note.pitch) or
|
| 267 |
+
(prev_note.pitch > note.pitch > next_note.pitch)
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
return is_step and is_between
|
| 271 |
+
|
| 272 |
+
def _should_keep_note(
|
| 273 |
+
self,
|
| 274 |
+
note: pretty_midi.Note,
|
| 275 |
+
key_pitches: Set[int],
|
| 276 |
+
sorted_notes: List[pretty_midi.Note],
|
| 277 |
+
note_index: int
|
| 278 |
+
) -> bool:
|
| 279 |
+
"""
|
| 280 |
+
Determine whether to keep a note based on key signature analysis.
|
| 281 |
+
|
| 282 |
+
Keep rules:
|
| 283 |
+
1. All in-key notes → keep
|
| 284 |
+
2. Out-of-key but chromatic passing tone → keep (if allow_chromatic)
|
| 285 |
+
3. Out-of-key and isolated → remove (likely false positive)
|
| 286 |
+
4. Out-of-key with nearby notes → keep (intentional accidental)
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
note: Note to evaluate
|
| 290 |
+
key_pitches: Set of pitch classes in the key
|
| 291 |
+
sorted_notes: All notes sorted by onset time
|
| 292 |
+
note_index: Index of note in sorted_notes
|
| 293 |
+
|
| 294 |
+
Returns:
|
| 295 |
+
True if note should be kept, False otherwise
|
| 296 |
+
"""
|
| 297 |
+
# In-key notes always kept
|
| 298 |
+
if self._is_in_key(note.pitch, key_pitches):
|
| 299 |
+
return True
|
| 300 |
+
|
| 301 |
+
# Out-of-key note - apply filtering logic
|
| 302 |
+
|
| 303 |
+
# Check if it's a chromatic passing tone
|
| 304 |
+
if self.allow_chromatic:
|
| 305 |
+
if self._is_chromatic_passing_tone(note, sorted_notes, note_index, key_pitches):
|
| 306 |
+
return True # Keep passing tones
|
| 307 |
+
|
| 308 |
+
# Check if isolated
|
| 309 |
+
if self._is_isolated(note, sorted_notes, note_index):
|
| 310 |
+
# Isolated out-of-key note → likely false positive → remove
|
| 311 |
+
return False
|
| 312 |
+
|
| 313 |
+
# Out-of-key but not isolated → likely intentional accidental → keep
|
| 314 |
+
return True
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
if __name__ == "__main__":
|
| 318 |
+
# Test the key filter
|
| 319 |
+
import argparse
|
| 320 |
+
|
| 321 |
+
parser = argparse.ArgumentParser(description="Test Key-Aware Filter")
|
| 322 |
+
parser.add_argument("midi_file", type=str, help="Path to MIDI file")
|
| 323 |
+
parser.add_argument("--key", type=str, required=True,
|
| 324 |
+
help="Detected key (e.g., 'C major', 'A minor')")
|
| 325 |
+
parser.add_argument("--output", type=str, default=None,
|
| 326 |
+
help="Output MIDI file path")
|
| 327 |
+
parser.add_argument("--no-chromatic", action="store_true",
|
| 328 |
+
help="Disallow chromatic passing tones")
|
| 329 |
+
args = parser.parse_args()
|
| 330 |
+
|
| 331 |
+
filter = KeyAwareFilter(
|
| 332 |
+
allow_chromatic=not args.no_chromatic,
|
| 333 |
+
isolation_threshold=0.5
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
midi_path = Path(args.midi_file)
|
| 337 |
+
output_path = Path(args.output) if args.output else None
|
| 338 |
+
|
| 339 |
+
# Filter MIDI
|
| 340 |
+
filtered_path = filter.filter_midi_by_key(
|
| 341 |
+
midi_path,
|
| 342 |
+
detected_key=args.key,
|
| 343 |
+
output_path=output_path
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
print(f"\n✓ Key-filtered MIDI saved: {filtered_path}")
|
backend/main.py
CHANGED
|
@@ -36,6 +36,9 @@ app = FastAPI(
|
|
| 36 |
# Redis client (initialized before middleware)
|
| 37 |
redis_client = redis.Redis.from_url(settings.redis_url, decode_responses=True)
|
| 38 |
|
|
|
|
|
|
|
|
|
|
| 39 |
# YourMT3+ transcriber (loaded on startup)
|
| 40 |
yourmt3_transcriber: Optional[YourMT3Transcriber] = None
|
| 41 |
YOURMT3_TEMP_DIR = Path(tempfile.gettempdir()) / "yourmt3_service"
|
|
@@ -123,10 +126,13 @@ async def startup_event():
|
|
| 123 |
|
| 124 |
@app.on_event("shutdown")
|
| 125 |
async def shutdown_event():
|
| 126 |
-
"""Clean up temporary files on shutdown."""
|
| 127 |
if YOURMT3_TEMP_DIR.exists():
|
| 128 |
shutil.rmtree(YOURMT3_TEMP_DIR, ignore_errors=True)
|
| 129 |
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
# === Request/Response Models ===
|
| 132 |
|
|
@@ -422,31 +428,12 @@ async def websocket_endpoint(websocket: WebSocket, job_id: str):
|
|
| 422 |
job_id: Job identifier
|
| 423 |
"""
|
| 424 |
await manager.connect(websocket, job_id)
|
|
|
|
| 425 |
|
| 426 |
try:
|
| 427 |
-
# Subscribe to Redis pub/sub for this job
|
| 428 |
-
pubsub =
|
| 429 |
-
pubsub.subscribe(f"job:{job_id}:updates")
|
| 430 |
-
|
| 431 |
-
# Listen for updates in a separate task
|
| 432 |
-
async def listen_for_updates():
|
| 433 |
-
for message in pubsub.listen():
|
| 434 |
-
if message['type'] == 'message':
|
| 435 |
-
update = json.loads(message['data'])
|
| 436 |
-
await websocket.send_json(update)
|
| 437 |
-
|
| 438 |
-
# Close connection if job completed
|
| 439 |
-
if update.get('type') == 'completed':
|
| 440 |
-
break
|
| 441 |
-
|
| 442 |
-
# Close connection if job failed with non-retryable error
|
| 443 |
-
if update.get('type') == 'error':
|
| 444 |
-
error_info = update.get('error', {})
|
| 445 |
-
is_retryable = error_info.get('retryable', False)
|
| 446 |
-
if not is_retryable:
|
| 447 |
-
# Only close if error is permanent
|
| 448 |
-
break
|
| 449 |
-
# If retryable, keep connection open for retry progress updates
|
| 450 |
|
| 451 |
# Send initial status
|
| 452 |
job_data = redis_client.hgetall(f"job:{job_id}")
|
|
@@ -461,14 +448,31 @@ async def websocket_endpoint(websocket: WebSocket, job_id: str):
|
|
| 461 |
}
|
| 462 |
await websocket.send_json(initial_update)
|
| 463 |
|
| 464 |
-
# Listen for updates
|
| 465 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
|
| 467 |
except WebSocketDisconnect:
|
| 468 |
manager.disconnect(websocket, job_id)
|
| 469 |
finally:
|
| 470 |
-
pubsub
|
| 471 |
-
|
|
|
|
| 472 |
|
| 473 |
|
| 474 |
# === Health Check ===
|
|
|
|
| 36 |
# Redis client (initialized before middleware)
|
| 37 |
redis_client = redis.Redis.from_url(settings.redis_url, decode_responses=True)
|
| 38 |
|
| 39 |
+
# Async Redis client for WebSocket pub/sub
|
| 40 |
+
async_redis_client = redis.asyncio.Redis.from_url(settings.redis_url, decode_responses=True)
|
| 41 |
+
|
| 42 |
# YourMT3+ transcriber (loaded on startup)
|
| 43 |
yourmt3_transcriber: Optional[YourMT3Transcriber] = None
|
| 44 |
YOURMT3_TEMP_DIR = Path(tempfile.gettempdir()) / "yourmt3_service"
|
|
|
|
| 126 |
|
| 127 |
@app.on_event("shutdown")
|
| 128 |
async def shutdown_event():
|
| 129 |
+
"""Clean up temporary files and close Redis connections on shutdown."""
|
| 130 |
if YOURMT3_TEMP_DIR.exists():
|
| 131 |
shutil.rmtree(YOURMT3_TEMP_DIR, ignore_errors=True)
|
| 132 |
|
| 133 |
+
# Close async Redis client
|
| 134 |
+
await async_redis_client.close()
|
| 135 |
+
|
| 136 |
|
| 137 |
# === Request/Response Models ===
|
| 138 |
|
|
|
|
| 428 |
job_id: Job identifier
|
| 429 |
"""
|
| 430 |
await manager.connect(websocket, job_id)
|
| 431 |
+
pubsub = None
|
| 432 |
|
| 433 |
try:
|
| 434 |
+
# Subscribe to Redis pub/sub for this job using async client
|
| 435 |
+
pubsub = async_redis_client.pubsub()
|
| 436 |
+
await pubsub.subscribe(f"job:{job_id}:updates")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
# Send initial status
|
| 439 |
job_data = redis_client.hgetall(f"job:{job_id}")
|
|
|
|
| 448 |
}
|
| 449 |
await websocket.send_json(initial_update)
|
| 450 |
|
| 451 |
+
# Listen for updates asynchronously
|
| 452 |
+
async for message in pubsub.listen():
|
| 453 |
+
if message['type'] == 'message':
|
| 454 |
+
update = json.loads(message['data'])
|
| 455 |
+
await websocket.send_json(update)
|
| 456 |
+
|
| 457 |
+
# Close connection if job completed
|
| 458 |
+
if update.get('type') == 'completed':
|
| 459 |
+
break
|
| 460 |
+
|
| 461 |
+
# Close connection if job failed with non-retryable error
|
| 462 |
+
if update.get('type') == 'error':
|
| 463 |
+
error_info = update.get('error', {})
|
| 464 |
+
is_retryable = error_info.get('retryable', False)
|
| 465 |
+
if not is_retryable:
|
| 466 |
+
# Only close if error is permanent
|
| 467 |
+
break
|
| 468 |
+
# If retryable, keep connection open for retry progress updates
|
| 469 |
|
| 470 |
except WebSocketDisconnect:
|
| 471 |
manager.disconnect(websocket, job_id)
|
| 472 |
finally:
|
| 473 |
+
if pubsub:
|
| 474 |
+
await pubsub.unsubscribe(f"job:{job_id}:updates")
|
| 475 |
+
await pubsub.close()
|
| 476 |
|
| 477 |
|
| 478 |
# === Health Check ===
|
backend/pipeline.py
CHANGED
|
@@ -80,11 +80,45 @@ class TranscriptionPipeline:
|
|
| 80 |
self.progress(0, "download", "Starting audio download")
|
| 81 |
audio_path = self.download_audio()
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
self.progress(20, "separate", "Starting source separation")
|
| 84 |
stems = self.separate_sources(audio_path)
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
# Store final MIDI path for tasks.py to access
|
| 90 |
self.final_midi_path = midi_path
|
|
@@ -92,7 +126,7 @@ class TranscriptionPipeline:
|
|
| 92 |
self.progress(90, "musicxml", "Generating MusicXML")
|
| 93 |
# Use minimal MusicXML generation (YourMT3+ optimized)
|
| 94 |
print(f" Using minimal MusicXML generation (YourMT3+)")
|
| 95 |
-
musicxml_path = self.generate_musicxml_minimal(midi_path,
|
| 96 |
|
| 97 |
self.progress(100, "complete", "Transcription complete")
|
| 98 |
return musicxml_path
|
|
@@ -127,6 +161,48 @@ class TranscriptionPipeline:
|
|
| 127 |
|
| 128 |
return output_path
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
def separate_sources(self, audio_path: Path) -> dict:
|
| 131 |
"""
|
| 132 |
Separate audio into 4 stems using Demucs.
|
|
@@ -138,33 +214,79 @@ class TranscriptionPipeline:
|
|
| 138 |
if not audio_path.exists():
|
| 139 |
raise FileNotFoundError(f"Input audio not found: {audio_path}")
|
| 140 |
|
| 141 |
-
#
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
]
|
| 148 |
|
| 149 |
-
|
|
|
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
raise RuntimeError(f"Demucs failed (exit code {result.returncode}): {error_msg}")
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
'
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
|
| 167 |
-
|
| 168 |
|
| 169 |
def transcribe_to_midi(
|
| 170 |
self,
|
|
@@ -187,10 +309,15 @@ class TranscriptionPipeline:
|
|
| 187 |
"""
|
| 188 |
output_dir = self.temp_dir
|
| 189 |
|
| 190 |
-
# Transcribe with
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
# Rename final MIDI to standard name for post-processing
|
| 196 |
final_midi_path = output_dir / "piano.mid"
|
|
@@ -304,6 +431,301 @@ class TranscriptionPipeline:
|
|
| 304 |
except Exception as e:
|
| 305 |
raise RuntimeError(f"YourMT3+ transcription failed: {e}")
|
| 306 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
def _get_midi_range(self, midi_path: Path) -> int:
|
| 308 |
"""
|
| 309 |
Calculate the MIDI note range (max - min) in semitones.
|
|
@@ -723,9 +1145,11 @@ class TranscriptionPipeline:
|
|
| 723 |
if msg.type in ('note_on', 'note_off'):
|
| 724 |
last_note_time = abs_time
|
| 725 |
|
| 726 |
-
# Add end_of_track after last note with
|
| 727 |
from mido import MetaMessage
|
| 728 |
-
|
|
|
|
|
|
|
| 729 |
track.append(end_msg)
|
| 730 |
|
| 731 |
# 5. Save beat-quantized MIDI
|
|
|
|
| 80 |
self.progress(0, "download", "Starting audio download")
|
| 81 |
audio_path = self.download_audio()
|
| 82 |
|
| 83 |
+
# Preprocess audio if enabled (improves separation and transcription quality)
|
| 84 |
+
if self.config.enable_audio_preprocessing:
|
| 85 |
+
self.progress(10, "preprocess", "Preprocessing audio")
|
| 86 |
+
audio_path = self.preprocess_audio(audio_path)
|
| 87 |
+
|
| 88 |
self.progress(20, "separate", "Starting source separation")
|
| 89 |
stems = self.separate_sources(audio_path)
|
| 90 |
|
| 91 |
+
# Select best stem for piano transcription
|
| 92 |
+
# Priority: piano (dedicated stem) > other (mixed instruments)
|
| 93 |
+
if 'piano' in stems:
|
| 94 |
+
piano_stem = stems['piano']
|
| 95 |
+
print(f" Using dedicated piano stem for transcription")
|
| 96 |
+
else:
|
| 97 |
+
piano_stem = stems['other']
|
| 98 |
+
print(f" Using 'other' stem for transcription (legacy mode)")
|
| 99 |
+
|
| 100 |
+
# Transcribe piano
|
| 101 |
+
self.progress(50, "transcribe", "Starting piano transcription")
|
| 102 |
+
piano_midi = self.transcribe_to_midi(piano_stem)
|
| 103 |
+
|
| 104 |
+
# Transcribe vocals if enabled
|
| 105 |
+
if self.config.transcribe_vocals and 'vocals' in stems:
|
| 106 |
+
self.progress(70, "transcribe_vocals", "Transcribing vocal melody")
|
| 107 |
+
vocals_midi = self.transcribe_vocals_to_midi(stems['vocals'])
|
| 108 |
+
|
| 109 |
+
# Merge piano and vocals into single MIDI
|
| 110 |
+
print(f" Merging piano and vocals...")
|
| 111 |
+
midi_path = self.merge_piano_and_vocals(
|
| 112 |
+
piano_midi,
|
| 113 |
+
vocals_midi,
|
| 114 |
+
piano_program=0, # Acoustic Grand Piano
|
| 115 |
+
vocal_program=self.config.vocal_instrument
|
| 116 |
+
)
|
| 117 |
+
else:
|
| 118 |
+
midi_path = piano_midi
|
| 119 |
+
|
| 120 |
+
# Apply post-processing filters (Phase 4)
|
| 121 |
+
midi_path = self.apply_post_processing_filters(midi_path)
|
| 122 |
|
| 123 |
# Store final MIDI path for tasks.py to access
|
| 124 |
self.final_midi_path = midi_path
|
|
|
|
| 126 |
self.progress(90, "musicxml", "Generating MusicXML")
|
| 127 |
# Use minimal MusicXML generation (YourMT3+ optimized)
|
| 128 |
print(f" Using minimal MusicXML generation (YourMT3+)")
|
| 129 |
+
musicxml_path = self.generate_musicxml_minimal(midi_path, piano_stem)
|
| 130 |
|
| 131 |
self.progress(100, "complete", "Transcription complete")
|
| 132 |
return musicxml_path
|
|
|
|
| 161 |
|
| 162 |
return output_path
|
| 163 |
|
| 164 |
+
def preprocess_audio(self, audio_path: Path) -> Path:
|
| 165 |
+
"""
|
| 166 |
+
Preprocess audio for improved separation and transcription quality.
|
| 167 |
+
|
| 168 |
+
Applies:
|
| 169 |
+
- Spectral denoising (remove background noise)
|
| 170 |
+
- Peak normalization (consistent volume)
|
| 171 |
+
- High-pass filtering (remove rumble <30Hz)
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
audio_path: Path to raw audio file
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
Path to preprocessed audio file
|
| 178 |
+
"""
|
| 179 |
+
try:
|
| 180 |
+
from audio_preprocessor import AudioPreprocessor
|
| 181 |
+
except ImportError:
|
| 182 |
+
# Try adding backend directory to path
|
| 183 |
+
import sys
|
| 184 |
+
from pathlib import Path as PathLib
|
| 185 |
+
backend_dir = PathLib(__file__).parent
|
| 186 |
+
if str(backend_dir) not in sys.path:
|
| 187 |
+
sys.path.insert(0, str(backend_dir))
|
| 188 |
+
from audio_preprocessor import AudioPreprocessor
|
| 189 |
+
|
| 190 |
+
print(f" Preprocessing audio to improve quality...")
|
| 191 |
+
|
| 192 |
+
preprocessor = AudioPreprocessor(
|
| 193 |
+
enable_denoising=self.config.enable_audio_denoising,
|
| 194 |
+
enable_normalization=self.config.enable_audio_normalization,
|
| 195 |
+
enable_highpass=self.config.enable_highpass_filter,
|
| 196 |
+
target_sample_rate=44100
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# Preprocess (output will be saved in temp directory)
|
| 200 |
+
preprocessed_path = preprocessor.preprocess(audio_path, self.temp_dir)
|
| 201 |
+
|
| 202 |
+
print(f" ✓ Audio preprocessing complete")
|
| 203 |
+
|
| 204 |
+
return preprocessed_path
|
| 205 |
+
|
| 206 |
def separate_sources(self, audio_path: Path) -> dict:
|
| 207 |
"""
|
| 208 |
Separate audio into 4 stems using Demucs.
|
|
|
|
| 214 |
if not audio_path.exists():
|
| 215 |
raise FileNotFoundError(f"Input audio not found: {audio_path}")
|
| 216 |
|
| 217 |
+
# Source separation - config-driven approach
|
| 218 |
+
if self.config.use_two_stage_separation:
|
| 219 |
+
# Two-stage separation for maximum quality:
|
| 220 |
+
# 1. BS-RoFormer removes vocals (SOTA vocal separation)
|
| 221 |
+
# 2. Demucs separates clean instrumental into piano/guitar/drums/bass/other
|
| 222 |
+
print(" Using two-stage separation (BS-RoFormer + Demucs)")
|
|
|
|
| 223 |
|
| 224 |
+
from audio_separator_wrapper import AudioSeparator
|
| 225 |
+
separator = AudioSeparator()
|
| 226 |
|
| 227 |
+
separation_dir = self.temp_dir / "separation"
|
| 228 |
+
instrument_stems = 6 if self.config.use_6stem_demucs else 4
|
|
|
|
| 229 |
|
| 230 |
+
stems = separator.two_stage_separation(
|
| 231 |
+
audio_path,
|
| 232 |
+
separation_dir,
|
| 233 |
+
instrument_stems=instrument_stems
|
| 234 |
+
)
|
| 235 |
|
| 236 |
+
# Two-stage separation returns: vocals, piano, guitar, drums, bass, other
|
| 237 |
+
# For piano transcription, use the dedicated piano stem
|
| 238 |
+
if 'piano' in stems:
|
| 239 |
+
print(f" ✓ Using dedicated piano stem for transcription")
|
| 240 |
+
|
| 241 |
+
return stems
|
| 242 |
+
|
| 243 |
+
elif self.config.use_6stem_demucs:
|
| 244 |
+
# Direct Demucs 6-stem separation (no vocal pre-removal)
|
| 245 |
+
print(" Using Demucs 6-stem separation")
|
| 246 |
+
|
| 247 |
+
from audio_separator_wrapper import AudioSeparator
|
| 248 |
+
separator = AudioSeparator()
|
| 249 |
+
|
| 250 |
+
instrument_dir = self.temp_dir / "instruments"
|
| 251 |
+
stems = separator.separate_instruments_demucs(
|
| 252 |
+
audio_path,
|
| 253 |
+
instrument_dir,
|
| 254 |
+
stems=6
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# 6-stem returns: vocals, piano, guitar, drums, bass, other
|
| 258 |
+
return stems
|
| 259 |
+
|
| 260 |
+
else:
|
| 261 |
+
# Legacy mode: Demucs 2-stem (backwards compatibility)
|
| 262 |
+
print(" Using legacy Demucs 2-stem separation")
|
| 263 |
+
|
| 264 |
+
cmd = [
|
| 265 |
+
"demucs",
|
| 266 |
+
"--two-stems=other", # For piano, we only need "other" stem
|
| 267 |
+
"-o", str(self.temp_dir),
|
| 268 |
+
str(audio_path)
|
| 269 |
+
]
|
| 270 |
+
|
| 271 |
+
result = subprocess.run(cmd, capture_output=True, text=True)
|
| 272 |
+
|
| 273 |
+
if result.returncode != 0:
|
| 274 |
+
error_msg = result.stderr.strip() or result.stdout.strip() or "Unknown error"
|
| 275 |
+
raise RuntimeError(f"Demucs failed (exit code {result.returncode}): {error_msg}")
|
| 276 |
+
|
| 277 |
+
# Demucs creates: temp/htdemucs/audio/*.wav
|
| 278 |
+
demucs_output = self.temp_dir / "htdemucs" / audio_path.stem
|
| 279 |
+
|
| 280 |
+
stems = {
|
| 281 |
+
'other': demucs_output / "other.wav",
|
| 282 |
+
'no_other': demucs_output / "no_other.wav",
|
| 283 |
+
}
|
| 284 |
|
| 285 |
+
# Verify output
|
| 286 |
+
if not stems['other'].exists():
|
| 287 |
+
raise RuntimeError("Demucs did not create expected output files")
|
| 288 |
|
| 289 |
+
return stems
|
| 290 |
|
| 291 |
def transcribe_to_midi(
|
| 292 |
self,
|
|
|
|
| 309 |
"""
|
| 310 |
output_dir = self.temp_dir
|
| 311 |
|
| 312 |
+
# Transcribe with ensemble or single model
|
| 313 |
+
if self.config.use_ensemble_transcription:
|
| 314 |
+
print(f" Transcribing with ensemble (YourMT3+ + ByteDance)...")
|
| 315 |
+
midi_path = self.transcribe_with_ensemble(audio_path)
|
| 316 |
+
print(f" ✓ Ensemble transcription complete")
|
| 317 |
+
else:
|
| 318 |
+
print(f" Transcribing with YourMT3+...")
|
| 319 |
+
midi_path = self.transcribe_with_yourmt3(audio_path)
|
| 320 |
+
print(f" ✓ YourMT3+ transcription complete")
|
| 321 |
|
| 322 |
# Rename final MIDI to standard name for post-processing
|
| 323 |
final_midi_path = output_dir / "piano.mid"
|
|
|
|
| 431 |
except Exception as e:
|
| 432 |
raise RuntimeError(f"YourMT3+ transcription failed: {e}")
|
| 433 |
|
| 434 |
+
def transcribe_with_ensemble(self, audio_path: Path) -> Path:
|
| 435 |
+
"""
|
| 436 |
+
Transcribe audio using ensemble of YourMT3+ and ByteDance.
|
| 437 |
+
|
| 438 |
+
Ensemble combines:
|
| 439 |
+
- YourMT3+: Multi-instrument generalist (80-85% accuracy)
|
| 440 |
+
- ByteDance: Piano specialist (90-95% accuracy)
|
| 441 |
+
- Result: 90-95% accuracy through voting
|
| 442 |
+
|
| 443 |
+
Args:
|
| 444 |
+
audio_path: Path to audio file (should be piano stem)
|
| 445 |
+
|
| 446 |
+
Returns:
|
| 447 |
+
Path to ensemble MIDI file
|
| 448 |
+
|
| 449 |
+
Raises:
|
| 450 |
+
RuntimeError: If transcription fails
|
| 451 |
+
"""
|
| 452 |
+
try:
|
| 453 |
+
from yourmt3_wrapper import YourMT3Transcriber
|
| 454 |
+
from bytedance_wrapper import ByteDanceTranscriber
|
| 455 |
+
from ensemble_transcriber import EnsembleTranscriber
|
| 456 |
+
except ImportError:
|
| 457 |
+
# Try adding backend directory to path
|
| 458 |
+
import sys
|
| 459 |
+
from pathlib import Path as PathLib
|
| 460 |
+
backend_dir = PathLib(__file__).parent
|
| 461 |
+
if str(backend_dir) not in sys.path:
|
| 462 |
+
sys.path.insert(0, str(backend_dir))
|
| 463 |
+
from yourmt3_wrapper import YourMT3Transcriber
|
| 464 |
+
from bytedance_wrapper import ByteDanceTranscriber
|
| 465 |
+
from ensemble_transcriber import EnsembleTranscriber
|
| 466 |
+
|
| 467 |
+
try:
|
| 468 |
+
# Initialize transcribers
|
| 469 |
+
yourmt3 = YourMT3Transcriber(
|
| 470 |
+
model_name="YPTF.MoE+Multi (noPS)",
|
| 471 |
+
device=self.config.yourmt3_device
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
bytedance = ByteDanceTranscriber(
|
| 475 |
+
device=self.config.yourmt3_device, # Use same device
|
| 476 |
+
checkpoint=None # Auto-download default model
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
# Initialize ensemble
|
| 480 |
+
ensemble = EnsembleTranscriber(
|
| 481 |
+
yourmt3_transcriber=yourmt3,
|
| 482 |
+
bytedance_transcriber=bytedance,
|
| 483 |
+
voting_strategy=self.config.ensemble_voting_strategy,
|
| 484 |
+
onset_tolerance_ms=self.config.ensemble_onset_tolerance_ms,
|
| 485 |
+
confidence_threshold=self.config.ensemble_confidence_threshold
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
# Transcribe with ensemble
|
| 489 |
+
output_dir = self.temp_dir / "ensemble_output"
|
| 490 |
+
output_dir.mkdir(exist_ok=True)
|
| 491 |
+
|
| 492 |
+
midi_path = ensemble.transcribe(audio_path, output_dir)
|
| 493 |
+
|
| 494 |
+
print(f" ✓ Ensemble transcription complete")
|
| 495 |
+
return midi_path
|
| 496 |
+
|
| 497 |
+
except Exception as e:
|
| 498 |
+
# Fallback to YourMT3+ only if ensemble fails
|
| 499 |
+
print(f" ⚠ Ensemble transcription failed: {e}")
|
| 500 |
+
print(f" Falling back to YourMT3+ only...")
|
| 501 |
+
return self.transcribe_with_yourmt3(audio_path)
|
| 502 |
+
|
| 503 |
+
def transcribe_vocals_to_midi(self, vocals_audio_path: Path) -> Path:
|
| 504 |
+
"""
|
| 505 |
+
Transcribe vocal melody to MIDI.
|
| 506 |
+
|
| 507 |
+
Uses YourMT3+ to transcribe vocals stem. YourMT3+ can transcribe melodies,
|
| 508 |
+
though it's primarily trained on multi-instrument music.
|
| 509 |
+
|
| 510 |
+
Args:
|
| 511 |
+
vocals_audio_path: Path to vocals stem audio
|
| 512 |
+
|
| 513 |
+
Returns:
|
| 514 |
+
Path to vocals MIDI file
|
| 515 |
+
"""
|
| 516 |
+
print(f" Transcribing vocals with YourMT3+...")
|
| 517 |
+
|
| 518 |
+
# Use YourMT3+ for vocal transcription
|
| 519 |
+
# (Could use dedicated melody transcription model in future)
|
| 520 |
+
try:
|
| 521 |
+
from yourmt3_wrapper import YourMT3Transcriber
|
| 522 |
+
except ImportError:
|
| 523 |
+
import sys
|
| 524 |
+
from pathlib import Path as PathLib
|
| 525 |
+
backend_dir = PathLib(__file__).parent
|
| 526 |
+
if str(backend_dir) not in sys.path:
|
| 527 |
+
sys.path.insert(0, str(backend_dir))
|
| 528 |
+
from yourmt3_wrapper import YourMT3Transcriber
|
| 529 |
+
|
| 530 |
+
transcriber = YourMT3Transcriber(
|
| 531 |
+
model_name="YPTF.MoE+Multi (noPS)",
|
| 532 |
+
device=self.config.yourmt3_device
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
output_dir = self.temp_dir / "vocals_output"
|
| 536 |
+
output_dir.mkdir(exist_ok=True)
|
| 537 |
+
|
| 538 |
+
vocals_midi = transcriber.transcribe_audio(vocals_audio_path, output_dir)
|
| 539 |
+
|
| 540 |
+
print(f" ✓ Vocals transcription complete")
|
| 541 |
+
|
| 542 |
+
return vocals_midi
|
| 543 |
+
|
| 544 |
+
def merge_piano_and_vocals(
|
| 545 |
+
self,
|
| 546 |
+
piano_midi_path: Path,
|
| 547 |
+
vocals_midi_path: Path,
|
| 548 |
+
piano_program: int = 0,
|
| 549 |
+
vocal_program: int = 40
|
| 550 |
+
) -> Path:
|
| 551 |
+
"""
|
| 552 |
+
Merge piano and vocals MIDI into single file with proper instrument assignments.
|
| 553 |
+
|
| 554 |
+
Filters out spurious instruments from YourMT3+ output (keeps only piano notes),
|
| 555 |
+
then adds vocals on separate track with specified instrument.
|
| 556 |
+
|
| 557 |
+
Args:
|
| 558 |
+
piano_midi_path: Path to piano MIDI
|
| 559 |
+
vocals_midi_path: Path to vocals MIDI
|
| 560 |
+
piano_program: MIDI program for piano (0 = Acoustic Grand Piano)
|
| 561 |
+
vocal_program: MIDI program for vocals (40 = Violin, 73 = Flute, etc.)
|
| 562 |
+
|
| 563 |
+
Returns:
|
| 564 |
+
Path to merged MIDI file
|
| 565 |
+
"""
|
| 566 |
+
import pretty_midi
|
| 567 |
+
|
| 568 |
+
# Load piano MIDI
|
| 569 |
+
piano_pm = pretty_midi.PrettyMIDI(str(piano_midi_path))
|
| 570 |
+
|
| 571 |
+
# Load vocals MIDI
|
| 572 |
+
vocals_pm = pretty_midi.PrettyMIDI(str(vocals_midi_path))
|
| 573 |
+
|
| 574 |
+
# Create new MIDI file
|
| 575 |
+
merged_pm = pretty_midi.PrettyMIDI(initial_tempo=piano_pm.estimate_tempo())
|
| 576 |
+
|
| 577 |
+
# Add piano track (keep ONLY piano instrument 0, filter out false positives)
|
| 578 |
+
piano_instrument = pretty_midi.Instrument(program=piano_program, name="Piano")
|
| 579 |
+
|
| 580 |
+
# Collect all notes from piano MIDI, filtering to only program 0 (piano)
|
| 581 |
+
for inst in piano_pm.instruments:
|
| 582 |
+
if inst.is_drum:
|
| 583 |
+
continue
|
| 584 |
+
# Only keep notes from Acoustic Grand Piano (program 0)
|
| 585 |
+
# Discard organs, guitars, strings, etc. (false positives from YourMT3+)
|
| 586 |
+
if inst.program == 0:
|
| 587 |
+
piano_instrument.notes.extend(inst.notes)
|
| 588 |
+
|
| 589 |
+
print(f" Piano: {len(piano_instrument.notes)} notes (filtered from YourMT3+ output)")
|
| 590 |
+
|
| 591 |
+
# Add vocals track (keep highest/loudest notes - melody line)
|
| 592 |
+
vocal_instrument = pretty_midi.Instrument(program=vocal_program, name="Vocals")
|
| 593 |
+
|
| 594 |
+
# Collect vocals notes
|
| 595 |
+
# YourMT3+ may output multiple instruments for vocals - take the melody (highest notes)
|
| 596 |
+
all_vocal_notes = []
|
| 597 |
+
for inst in vocals_pm.instruments:
|
| 598 |
+
if inst.is_drum:
|
| 599 |
+
continue
|
| 600 |
+
all_vocal_notes.extend(inst.notes)
|
| 601 |
+
|
| 602 |
+
# Sort by time, then filter to monophonic melody (one note at a time)
|
| 603 |
+
all_vocal_notes.sort(key=lambda n: n.start)
|
| 604 |
+
|
| 605 |
+
# Simple melody extraction: at each time point, keep only highest note
|
| 606 |
+
melody_notes = []
|
| 607 |
+
if len(all_vocal_notes) > 0:
|
| 608 |
+
time_tolerance = 0.05 # 50ms tolerance for simultaneous notes
|
| 609 |
+
|
| 610 |
+
i = 0
|
| 611 |
+
while i < len(all_vocal_notes):
|
| 612 |
+
# Find all notes starting around the same time
|
| 613 |
+
current_time = all_vocal_notes[i].start
|
| 614 |
+
simultaneous = []
|
| 615 |
+
|
| 616 |
+
while i < len(all_vocal_notes) and all_vocal_notes[i].start - current_time < time_tolerance:
|
| 617 |
+
simultaneous.append(all_vocal_notes[i])
|
| 618 |
+
i += 1
|
| 619 |
+
|
| 620 |
+
# Keep only highest note (melody)
|
| 621 |
+
highest = max(simultaneous, key=lambda n: n.pitch)
|
| 622 |
+
melody_notes.append(highest)
|
| 623 |
+
|
| 624 |
+
vocal_instrument.notes.extend(melody_notes)
|
| 625 |
+
|
| 626 |
+
print(f" Vocals: {len(vocal_instrument.notes)} notes (melody extracted)")
|
| 627 |
+
|
| 628 |
+
# Add both instruments to merged MIDI
|
| 629 |
+
merged_pm.instruments.append(piano_instrument)
|
| 630 |
+
merged_pm.instruments.append(vocal_instrument)
|
| 631 |
+
|
| 632 |
+
# Save merged MIDI
|
| 633 |
+
merged_path = self.temp_dir / "merged_piano_vocals.mid"
|
| 634 |
+
merged_pm.write(str(merged_path))
|
| 635 |
+
|
| 636 |
+
print(f" ✓ Merged MIDI saved: {merged_path.name}")
|
| 637 |
+
print(f" Instruments: Piano (program {piano_program}), Vocals (program {vocal_program})")
|
| 638 |
+
|
| 639 |
+
return merged_path
|
| 640 |
+
|
| 641 |
+
def apply_post_processing_filters(self, midi_path: Path) -> Path:
|
| 642 |
+
"""
|
| 643 |
+
Apply post-processing filters to improve transcription quality.
|
| 644 |
+
|
| 645 |
+
Applies confidence filtering and key-aware filtering based on config.
|
| 646 |
+
|
| 647 |
+
Args:
|
| 648 |
+
midi_path: Input MIDI file
|
| 649 |
+
|
| 650 |
+
Returns:
|
| 651 |
+
Path to filtered MIDI file (or original if no filtering enabled)
|
| 652 |
+
"""
|
| 653 |
+
filtered_path = midi_path
|
| 654 |
+
|
| 655 |
+
# Apply confidence filtering
|
| 656 |
+
if self.config.enable_confidence_filtering:
|
| 657 |
+
print(f" Applying confidence filtering...")
|
| 658 |
+
|
| 659 |
+
try:
|
| 660 |
+
from confidence_filter import ConfidenceFilter
|
| 661 |
+
except ImportError:
|
| 662 |
+
import sys
|
| 663 |
+
from pathlib import Path as PathLib
|
| 664 |
+
backend_dir = PathLib(__file__).parent
|
| 665 |
+
if str(backend_dir) not in sys.path:
|
| 666 |
+
sys.path.insert(0, str(backend_dir))
|
| 667 |
+
from confidence_filter import ConfidenceFilter
|
| 668 |
+
|
| 669 |
+
filter = ConfidenceFilter(
|
| 670 |
+
confidence_threshold=self.config.confidence_threshold,
|
| 671 |
+
velocity_threshold=self.config.velocity_threshold,
|
| 672 |
+
duration_threshold=self.config.min_note_duration
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
filtered_path = filter.filter_midi_by_confidence(
|
| 676 |
+
filtered_path,
|
| 677 |
+
confidence_scores=None # Use heuristics for now
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
# Apply key-aware filtering
|
| 681 |
+
if self.config.enable_key_aware_filtering:
|
| 682 |
+
print(f" Applying key-aware filtering...")
|
| 683 |
+
|
| 684 |
+
# Need to detect key first (or get from MusicXML generation)
|
| 685 |
+
# For now, we'll apply after key detection in generate_musicxml_minimal
|
| 686 |
+
# Skip here to avoid redundant key detection
|
| 687 |
+
pass
|
| 688 |
+
|
| 689 |
+
return filtered_path
|
| 690 |
+
|
| 691 |
+
def apply_key_aware_filter(self, midi_path: Path, detected_key: str) -> Path:
|
| 692 |
+
"""
|
| 693 |
+
Apply key-aware filtering using detected key signature.
|
| 694 |
+
|
| 695 |
+
This is called from generate_musicxml_minimal after key detection.
|
| 696 |
+
|
| 697 |
+
Args:
|
| 698 |
+
midi_path: Input MIDI file
|
| 699 |
+
detected_key: Detected key signature (e.g., "C major")
|
| 700 |
+
|
| 701 |
+
Returns:
|
| 702 |
+
Path to filtered MIDI file
|
| 703 |
+
"""
|
| 704 |
+
if not self.config.enable_key_aware_filtering:
|
| 705 |
+
return midi_path
|
| 706 |
+
|
| 707 |
+
try:
|
| 708 |
+
from key_filter import KeyAwareFilter
|
| 709 |
+
except ImportError:
|
| 710 |
+
import sys
|
| 711 |
+
from pathlib import Path as PathLib
|
| 712 |
+
backend_dir = PathLib(__file__).parent
|
| 713 |
+
if str(backend_dir) not in sys.path:
|
| 714 |
+
sys.path.insert(0, str(backend_dir))
|
| 715 |
+
from key_filter import KeyAwareFilter
|
| 716 |
+
|
| 717 |
+
filter = KeyAwareFilter(
|
| 718 |
+
allow_chromatic=self.config.allow_chromatic_passing_tones,
|
| 719 |
+
isolation_threshold=self.config.isolation_threshold
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
filtered_path = filter.filter_midi_by_key(
|
| 723 |
+
midi_path,
|
| 724 |
+
detected_key=detected_key
|
| 725 |
+
)
|
| 726 |
+
|
| 727 |
+
return filtered_path
|
| 728 |
+
|
| 729 |
def _get_midi_range(self, midi_path: Path) -> int:
|
| 730 |
"""
|
| 731 |
Calculate the MIDI note range (max - min) in semitones.
|
|
|
|
| 1145 |
if msg.type in ('note_on', 'note_off'):
|
| 1146 |
last_note_time = abs_time
|
| 1147 |
|
| 1148 |
+
# Add end_of_track after last note with proper delta
|
| 1149 |
from mido import MetaMessage
|
| 1150 |
+
# Use 1 beat gap after last note for clean ending
|
| 1151 |
+
gap_after_last_note = mid.ticks_per_beat
|
| 1152 |
+
end_msg = MetaMessage('end_of_track', time=gap_after_last_note)
|
| 1153 |
track.append(end_msg)
|
| 1154 |
|
| 1155 |
# 5. Save beat-quantized MIDI
|
backend/requirements-test.txt
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
# Test dependencies for Rescored backend
|
| 2 |
-
-r requirements.txt
|
| 3 |
-
|
| 4 |
-
# Testing framework
|
| 5 |
-
pytest==8.2.0
|
| 6 |
-
pytest-asyncio==0.24.0
|
| 7 |
-
pytest-cov==4.1.0
|
| 8 |
-
pytest-mock==3.12.0
|
| 9 |
-
|
| 10 |
-
# HTTP testing
|
| 11 |
-
httpx==0.26.0
|
| 12 |
-
|
| 13 |
-
# Test utilities
|
| 14 |
-
faker==22.5.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/requirements.txt
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
# Web Framework
|
| 2 |
-
# Note: This file now includes torch/torchaudio as they are required by demucs on macOS
|
| 3 |
fastapi==0.115.5
|
| 4 |
uvicorn[standard]==0.32.1
|
| 5 |
python-multipart==0.0.20
|
|
@@ -12,11 +11,11 @@ redis==5.2.1
|
|
| 12 |
yt-dlp>=2025.12.8
|
| 13 |
soundfile==0.12.1
|
| 14 |
librosa>=0.11.0
|
|
|
|
| 15 |
Cython # Required by madmom
|
| 16 |
madmom>=0.16.1 # Zero-tradeoff: Beat tracking and multi-scale tempo detection
|
| 17 |
-
scipy
|
| 18 |
torch>=2.0.0
|
| 19 |
-
torchaudio
|
| 20 |
demucs>=3.0.6
|
| 21 |
audio-separator>=0.40.0 # BS-RoFormer and UVR models for better vocal separation
|
| 22 |
|
|
@@ -41,7 +40,6 @@ python-dotenv==1.0.1
|
|
| 41 |
tenacity==9.0.0
|
| 42 |
pydantic==2.10.4
|
| 43 |
pydantic-settings==2.7.0
|
| 44 |
-
numpy<2.0.0
|
| 45 |
filelock # Required by huggingface-hub and transformers
|
| 46 |
pyyaml>=5.1 # Required by huggingface-hub and transformers
|
| 47 |
requests>=2.21.0 # Required by tensorflow and transformers
|
|
@@ -51,3 +49,14 @@ wrapt>=1.11.0 # Required by tensorflow
|
|
| 51 |
|
| 52 |
# WebSocket
|
| 53 |
websockets==14.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Web Framework
|
|
|
|
| 2 |
fastapi==0.115.5
|
| 3 |
uvicorn[standard]==0.32.1
|
| 4 |
python-multipart==0.0.20
|
|
|
|
| 11 |
yt-dlp>=2025.12.8
|
| 12 |
soundfile==0.12.1
|
| 13 |
librosa>=0.11.0
|
| 14 |
+
scipy>=1.10.0
|
| 15 |
Cython # Required by madmom
|
| 16 |
madmom>=0.16.1 # Zero-tradeoff: Beat tracking and multi-scale tempo detection
|
|
|
|
| 17 |
torch>=2.0.0
|
| 18 |
+
torchaudio>=2.0.0
|
| 19 |
demucs>=3.0.6
|
| 20 |
audio-separator>=0.40.0 # BS-RoFormer and UVR models for better vocal separation
|
| 21 |
|
|
|
|
| 40 |
tenacity==9.0.0
|
| 41 |
pydantic==2.10.4
|
| 42 |
pydantic-settings==2.7.0
|
|
|
|
| 43 |
filelock # Required by huggingface-hub and transformers
|
| 44 |
pyyaml>=5.1 # Required by huggingface-hub and transformers
|
| 45 |
requests>=2.21.0 # Required by tensorflow and transformers
|
|
|
|
| 49 |
|
| 50 |
# WebSocket
|
| 51 |
websockets==14.1
|
| 52 |
+
|
| 53 |
+
# Testing dependencies
|
| 54 |
+
pytest==8.2.0
|
| 55 |
+
pytest-asyncio==0.24.0
|
| 56 |
+
pytest-cov==4.1.0
|
| 57 |
+
pytest-mock==3.12.0
|
| 58 |
+
httpx==0.26.0
|
| 59 |
+
faker==22.5.1
|
| 60 |
+
|
| 61 |
+
# Audio preprocessing (optional, for noise reduction)
|
| 62 |
+
noisereduce>=3.0.0 # For spectral denoising in audio preprocessor
|
backend/tests/conftest.py
CHANGED
|
@@ -32,7 +32,7 @@ def mock_redis():
|
|
| 32 |
def test_client(mock_redis, temp_storage_dir):
|
| 33 |
"""Create FastAPI test client with mocked dependencies."""
|
| 34 |
with patch('main.redis_client', mock_redis):
|
| 35 |
-
with patch('
|
| 36 |
from main import app
|
| 37 |
client = TestClient(app)
|
| 38 |
yield client
|
|
|
|
| 32 |
def test_client(mock_redis, temp_storage_dir):
|
| 33 |
"""Create FastAPI test client with mocked dependencies."""
|
| 34 |
with patch('main.redis_client', mock_redis):
|
| 35 |
+
with patch('app_config.settings.storage_path', temp_storage_dir):
|
| 36 |
from main import app
|
| 37 |
client = TestClient(app)
|
| 38 |
yield client
|
backend/tests/test_api.py
CHANGED
|
@@ -45,8 +45,8 @@ class TestTranscribeEndpoint:
|
|
| 45 |
"""Test transcription submission endpoint."""
|
| 46 |
|
| 47 |
@patch('main.process_transcription_task')
|
| 48 |
-
@patch('
|
| 49 |
-
@patch('
|
| 50 |
def test_submit_valid_transcription(
|
| 51 |
self,
|
| 52 |
mock_validate,
|
|
@@ -81,7 +81,7 @@ class TestTranscribeEndpoint:
|
|
| 81 |
# Verify Celery task was queued
|
| 82 |
assert mock_task.delay.called
|
| 83 |
|
| 84 |
-
@patch('
|
| 85 |
def test_submit_invalid_url(self, mock_validate, test_client):
|
| 86 |
"""Test submitting invalid YouTube URL."""
|
| 87 |
mock_validate.return_value = (False, "Invalid YouTube URL format")
|
|
@@ -117,8 +117,8 @@ class TestTranscribeEndpoint:
|
|
| 117 |
assert response.status_code == 422
|
| 118 |
assert "too long" in response.json()["detail"]
|
| 119 |
|
| 120 |
-
@patch('
|
| 121 |
-
@patch('
|
| 122 |
def test_submit_with_options(
|
| 123 |
self,
|
| 124 |
mock_check_availability,
|
|
@@ -145,8 +145,8 @@ class TestTranscribeEndpoint:
|
|
| 145 |
class TestRateLimiting:
|
| 146 |
"""Test rate limiting middleware."""
|
| 147 |
|
| 148 |
-
@patch('
|
| 149 |
-
@patch('
|
| 150 |
@patch('main.process_transcription_task')
|
| 151 |
def test_rate_limit_enforced(
|
| 152 |
self,
|
|
@@ -172,8 +172,8 @@ class TestRateLimiting:
|
|
| 172 |
assert response.status_code == 429
|
| 173 |
assert "Rate limit exceeded" in response.json()["detail"]
|
| 174 |
|
| 175 |
-
@patch('
|
| 176 |
-
@patch('
|
| 177 |
@patch('main.process_transcription_task')
|
| 178 |
def test_rate_limit_under_limit(
|
| 179 |
self,
|
|
|
|
| 45 |
"""Test transcription submission endpoint."""
|
| 46 |
|
| 47 |
@patch('main.process_transcription_task')
|
| 48 |
+
@patch('app_utils.check_video_availability')
|
| 49 |
+
@patch('main.validate_youtube_url')
|
| 50 |
def test_submit_valid_transcription(
|
| 51 |
self,
|
| 52 |
mock_validate,
|
|
|
|
| 81 |
# Verify Celery task was queued
|
| 82 |
assert mock_task.delay.called
|
| 83 |
|
| 84 |
+
@patch('main.validate_youtube_url')
|
| 85 |
def test_submit_invalid_url(self, mock_validate, test_client):
|
| 86 |
"""Test submitting invalid YouTube URL."""
|
| 87 |
mock_validate.return_value = (False, "Invalid YouTube URL format")
|
|
|
|
| 117 |
assert response.status_code == 422
|
| 118 |
assert "too long" in response.json()["detail"]
|
| 119 |
|
| 120 |
+
@patch('main.validate_youtube_url')
|
| 121 |
+
@patch('main.check_video_availability')
|
| 122 |
def test_submit_with_options(
|
| 123 |
self,
|
| 124 |
mock_check_availability,
|
|
|
|
| 145 |
class TestRateLimiting:
|
| 146 |
"""Test rate limiting middleware."""
|
| 147 |
|
| 148 |
+
@patch('main.validate_youtube_url')
|
| 149 |
+
@patch('main.check_video_availability')
|
| 150 |
@patch('main.process_transcription_task')
|
| 151 |
def test_rate_limit_enforced(
|
| 152 |
self,
|
|
|
|
| 172 |
assert response.status_code == 429
|
| 173 |
assert "Rate limit exceeded" in response.json()["detail"]
|
| 174 |
|
| 175 |
+
@patch('main.validate_youtube_url')
|
| 176 |
+
@patch('main.check_video_availability')
|
| 177 |
@patch('main.process_transcription_task')
|
| 178 |
def test_rate_limit_under_limit(
|
| 179 |
self,
|
backend/tests/test_pipeline.py
CHANGED
|
@@ -85,18 +85,18 @@ class TestTranscriptionPipelineClass:
|
|
| 85 |
def test_pipeline_has_required_methods(self):
|
| 86 |
"""Test TranscriptionPipeline has all required methods."""
|
| 87 |
from pipeline import TranscriptionPipeline
|
| 88 |
-
|
| 89 |
pipeline = TranscriptionPipeline("test_job", "http://example.com", Path("/tmp"))
|
| 90 |
-
|
| 91 |
required_methods = [
|
| 92 |
'download_audio',
|
| 93 |
'separate_sources',
|
| 94 |
'transcribe_to_midi',
|
| 95 |
'clean_midi',
|
| 96 |
-
'
|
| 97 |
'cleanup'
|
| 98 |
]
|
| 99 |
-
|
| 100 |
for method in required_methods:
|
| 101 |
assert hasattr(pipeline, method)
|
| 102 |
assert callable(getattr(pipeline, method))
|
|
|
|
| 85 |
def test_pipeline_has_required_methods(self):
|
| 86 |
"""Test TranscriptionPipeline has all required methods."""
|
| 87 |
from pipeline import TranscriptionPipeline
|
| 88 |
+
|
| 89 |
pipeline = TranscriptionPipeline("test_job", "http://example.com", Path("/tmp"))
|
| 90 |
+
|
| 91 |
required_methods = [
|
| 92 |
'download_audio',
|
| 93 |
'separate_sources',
|
| 94 |
'transcribe_to_midi',
|
| 95 |
'clean_midi',
|
| 96 |
+
'generate_musicxml_minimal',
|
| 97 |
'cleanup'
|
| 98 |
]
|
| 99 |
+
|
| 100 |
for method in required_methods:
|
| 101 |
assert hasattr(pipeline, method)
|
| 102 |
assert callable(getattr(pipeline, method))
|
backend/tests/test_pipeline_fixes.py
DELETED
|
@@ -1,605 +0,0 @@
|
|
| 1 |
-
"""Unit tests for pipeline fixes (Issues #6, #7, #8)."""
|
| 2 |
-
import pytest
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
import mido
|
| 5 |
-
from music21 import note, chord, stream, converter
|
| 6 |
-
from pipeline import TranscriptionPipeline
|
| 7 |
-
from app_config import Settings
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
@pytest.fixture
|
| 11 |
-
def pipeline(temp_storage_dir):
|
| 12 |
-
"""Create a TranscriptionPipeline instance for testing."""
|
| 13 |
-
config = Settings(storage_path=temp_storage_dir)
|
| 14 |
-
return TranscriptionPipeline(
|
| 15 |
-
job_id="test-job",
|
| 16 |
-
youtube_url="https://www.youtube.com/watch?v=test",
|
| 17 |
-
storage_path=temp_storage_dir,
|
| 18 |
-
config=config
|
| 19 |
-
)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
@pytest.fixture
|
| 23 |
-
def midi_with_sequential_notes(temp_storage_dir):
|
| 24 |
-
"""Create MIDI file with sequential notes of same pitch with small gaps."""
|
| 25 |
-
mid = mido.MidiFile()
|
| 26 |
-
track = mido.MidiTrack()
|
| 27 |
-
mid.tracks.append(track)
|
| 28 |
-
|
| 29 |
-
# Note 1: C4 (60) from 0-480 ticks (1 beat)
|
| 30 |
-
track.append(mido.Message('note_on', note=60, velocity=64, time=0))
|
| 31 |
-
track.append(mido.Message('note_off', note=60, velocity=64, time=480))
|
| 32 |
-
|
| 33 |
-
# Tiny gap of 10 ticks (~20ms at 120 BPM)
|
| 34 |
-
# Note 2: C4 (60) from 490-970 ticks (1 beat)
|
| 35 |
-
track.append(mido.Message('note_on', note=60, velocity=64, time=10))
|
| 36 |
-
track.append(mido.Message('note_off', note=60, velocity=64, time=480))
|
| 37 |
-
|
| 38 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 39 |
-
|
| 40 |
-
midi_path = temp_storage_dir / "sequential_notes.mid"
|
| 41 |
-
mid.save(str(midi_path))
|
| 42 |
-
return midi_path
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
@pytest.fixture
|
| 46 |
-
def midi_with_low_velocity_notes(temp_storage_dir):
|
| 47 |
-
"""Create MIDI file with low velocity notes (noise)."""
|
| 48 |
-
mid = mido.MidiFile()
|
| 49 |
-
track = mido.MidiTrack()
|
| 50 |
-
mid.tracks.append(track)
|
| 51 |
-
|
| 52 |
-
# Loud note (should keep)
|
| 53 |
-
track.append(mido.Message('note_on', note=60, velocity=64, time=0))
|
| 54 |
-
track.append(mido.Message('note_off', note=60, velocity=64, time=480))
|
| 55 |
-
|
| 56 |
-
# Very quiet note (should filter - velocity < 45)
|
| 57 |
-
track.append(mido.Message('note_on', note=62, velocity=30, time=0))
|
| 58 |
-
track.append(mido.Message('note_off', note=62, velocity=30, time=480))
|
| 59 |
-
|
| 60 |
-
# Another loud note
|
| 61 |
-
track.append(mido.Message('note_on', note=64, velocity=80, time=0))
|
| 62 |
-
track.append(mido.Message('note_off', note=64, velocity=80, time=480))
|
| 63 |
-
|
| 64 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 65 |
-
|
| 66 |
-
midi_path = temp_storage_dir / "low_velocity.mid"
|
| 67 |
-
mid.save(str(midi_path))
|
| 68 |
-
return midi_path
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
@pytest.fixture
|
| 72 |
-
def score_with_tiny_gaps():
|
| 73 |
-
"""Create music21 score with sequential notes that have tiny gaps."""
|
| 74 |
-
s = stream.Score()
|
| 75 |
-
p = stream.Part()
|
| 76 |
-
|
| 77 |
-
# Note 1: C4 at offset 0.0, duration 1.0 QN
|
| 78 |
-
n1 = note.Note('C4', quarterLength=1.0)
|
| 79 |
-
p.insert(0.0, n1)
|
| 80 |
-
|
| 81 |
-
# Note 2: C4 at offset 1.01 (tiny gap of 0.01 QN) - should merge
|
| 82 |
-
n2 = note.Note('C4', quarterLength=1.0)
|
| 83 |
-
p.insert(1.01, n2)
|
| 84 |
-
|
| 85 |
-
# Note 3: C4 at offset 2.05 (larger gap of 0.04 QN) - should still merge (< 0.02 threshold)
|
| 86 |
-
# Actually, this should NOT merge as gap > 0.02
|
| 87 |
-
n3 = note.Note('C4', quarterLength=1.0)
|
| 88 |
-
p.insert(2.05, n3)
|
| 89 |
-
|
| 90 |
-
s.append(p)
|
| 91 |
-
return s
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
@pytest.fixture
|
| 95 |
-
def score_with_chord():
|
| 96 |
-
"""Create music21 score with a chord (should not merge chord notes)."""
|
| 97 |
-
s = stream.Score()
|
| 98 |
-
p = stream.Part()
|
| 99 |
-
|
| 100 |
-
# Chord: C4, E4, G4 at offset 0.0
|
| 101 |
-
c = chord.Chord(['C4', 'E4', 'G4'], quarterLength=2.0)
|
| 102 |
-
p.insert(0.0, c)
|
| 103 |
-
|
| 104 |
-
s.append(p)
|
| 105 |
-
return s
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
@pytest.fixture
|
| 109 |
-
def score_with_staccato():
|
| 110 |
-
"""Create music21 score with staccato notes (large gaps - should NOT merge)."""
|
| 111 |
-
s = stream.Score()
|
| 112 |
-
p = stream.Part()
|
| 113 |
-
|
| 114 |
-
# Note 1: C4 at offset 0.0, duration 0.5 QN
|
| 115 |
-
n1 = note.Note('C4', quarterLength=0.5)
|
| 116 |
-
p.insert(0.0, n1)
|
| 117 |
-
|
| 118 |
-
# Note 2: C4 at offset 1.0 (gap of 0.5 QN - staccato, should NOT merge)
|
| 119 |
-
n2 = note.Note('C4', quarterLength=0.5)
|
| 120 |
-
p.insert(1.0, n2)
|
| 121 |
-
|
| 122 |
-
s.append(p)
|
| 123 |
-
return s
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
class TestIssue8MergeMusic21Notes:
|
| 127 |
-
"""Tests for Issue #8: Tiny rests between notes."""
|
| 128 |
-
|
| 129 |
-
def test_merge_sequential_notes_same_pitch(self, pipeline, score_with_tiny_gaps):
|
| 130 |
-
"""Sequential notes of same pitch with small gap should merge."""
|
| 131 |
-
# Get notes before merging
|
| 132 |
-
notes_before = list(score_with_tiny_gaps.flatten().notes)
|
| 133 |
-
assert len(notes_before) == 3 # 3 separate C4 notes
|
| 134 |
-
|
| 135 |
-
# Merge with 0.02 QN threshold
|
| 136 |
-
merged_score = pipeline._merge_music21_notes(score_with_tiny_gaps, gap_threshold_qn=0.02)
|
| 137 |
-
|
| 138 |
-
# Get notes after merging
|
| 139 |
-
notes_after = list(merged_score.flatten().notes)
|
| 140 |
-
|
| 141 |
-
# First two notes should merge (gap 0.01 < 0.02)
|
| 142 |
-
# Third note should NOT merge (gap 0.04 > 0.02 from second note's end)
|
| 143 |
-
# Actually need to recalculate: n1 ends at 1.0, n2 at 2.01, gap to n3 at 2.05 = 0.04
|
| 144 |
-
assert len(notes_after) == 2, f"Expected 2 notes after merging, got {len(notes_after)}"
|
| 145 |
-
|
| 146 |
-
# First note should have extended duration
|
| 147 |
-
first_note = notes_after[0]
|
| 148 |
-
assert first_note.pitch.midi == 60 # C4
|
| 149 |
-
# Duration should cover both first and second note: 0.0 to 2.01 = 2.01 QN
|
| 150 |
-
assert abs(first_note.quarterLength - 2.01) < 0.001, \
|
| 151 |
-
f"Expected duration ~2.01, got {first_note.quarterLength}"
|
| 152 |
-
|
| 153 |
-
def test_dont_merge_different_pitches(self, pipeline):
|
| 154 |
-
"""Notes with different pitches should NOT merge."""
|
| 155 |
-
s = stream.Score()
|
| 156 |
-
p = stream.Part()
|
| 157 |
-
|
| 158 |
-
# C4, then D4 with tiny gap - should NOT merge
|
| 159 |
-
p.insert(0.0, note.Note('C4', quarterLength=1.0))
|
| 160 |
-
p.insert(1.01, note.Note('D4', quarterLength=1.0))
|
| 161 |
-
|
| 162 |
-
s.append(p)
|
| 163 |
-
|
| 164 |
-
merged_score = pipeline._merge_music21_notes(s, gap_threshold_qn=0.02)
|
| 165 |
-
notes_after = list(merged_score.flatten().notes)
|
| 166 |
-
|
| 167 |
-
assert len(notes_after) == 2, "Different pitches should not merge"
|
| 168 |
-
assert notes_after[0].pitch.midi == 60 # C4
|
| 169 |
-
assert notes_after[1].pitch.midi == 62 # D4
|
| 170 |
-
|
| 171 |
-
def test_dont_merge_large_gaps(self, pipeline, score_with_staccato):
|
| 172 |
-
"""Notes with large gaps (staccato) should NOT merge."""
|
| 173 |
-
notes_before = list(score_with_staccato.flatten().notes)
|
| 174 |
-
assert len(notes_before) == 2
|
| 175 |
-
|
| 176 |
-
merged_score = pipeline._merge_music21_notes(score_with_staccato, gap_threshold_qn=0.02)
|
| 177 |
-
notes_after = list(merged_score.flatten().notes)
|
| 178 |
-
|
| 179 |
-
# Gap is 0.5 QN (n1 ends at 0.5, n2 starts at 1.0)
|
| 180 |
-
# Should NOT merge (0.5 > 0.02)
|
| 181 |
-
assert len(notes_after) == 2, "Staccato notes with large gaps should not merge"
|
| 182 |
-
assert notes_after[0].quarterLength == 0.5
|
| 183 |
-
assert notes_after[1].quarterLength == 0.5
|
| 184 |
-
|
| 185 |
-
def test_dont_merge_same_chord_notes(self, pipeline, score_with_chord):
|
| 186 |
-
"""Notes from the SAME chord should NOT merge into one note."""
|
| 187 |
-
chords_before = list(score_with_chord.flatten().getElementsByClass(chord.Chord))
|
| 188 |
-
assert len(chords_before) == 1
|
| 189 |
-
assert len(chords_before[0].pitches) == 3 # C, E, G
|
| 190 |
-
|
| 191 |
-
merged_score = pipeline._merge_music21_notes(score_with_chord, gap_threshold_qn=0.02)
|
| 192 |
-
|
| 193 |
-
# Chord should remain intact
|
| 194 |
-
chords_after = list(merged_score.flatten().getElementsByClass(chord.Chord))
|
| 195 |
-
assert len(chords_after) == 1, "Chord should remain"
|
| 196 |
-
assert len(chords_after[0].pitches) == 3, "Chord should still have 3 notes"
|
| 197 |
-
|
| 198 |
-
def test_merge_across_different_velocities(self, pipeline):
|
| 199 |
-
"""Sequential notes with different velocities but same pitch should merge."""
|
| 200 |
-
s = stream.Score()
|
| 201 |
-
p = stream.Part()
|
| 202 |
-
|
| 203 |
-
# Create notes with different velocities
|
| 204 |
-
n1 = note.Note('C4', quarterLength=1.0)
|
| 205 |
-
n1.volume.velocity = 64
|
| 206 |
-
p.insert(0.0, n1)
|
| 207 |
-
|
| 208 |
-
n2 = note.Note('C4', quarterLength=1.0)
|
| 209 |
-
n2.volume.velocity = 80
|
| 210 |
-
p.insert(1.01, n2) # Small gap
|
| 211 |
-
|
| 212 |
-
s.append(p)
|
| 213 |
-
|
| 214 |
-
merged_score = pipeline._merge_music21_notes(s, gap_threshold_qn=0.02)
|
| 215 |
-
notes_after = list(merged_score.flatten().notes)
|
| 216 |
-
|
| 217 |
-
# Should merge despite different velocities
|
| 218 |
-
assert len(notes_after) == 1, "Notes with different velocities should still merge"
|
| 219 |
-
assert abs(notes_after[0].quarterLength - 2.01) < 0.001
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
class TestIssue6NoiseFiltering:
|
| 223 |
-
"""Tests for Issue #6: Random noise notes."""
|
| 224 |
-
|
| 225 |
-
def test_filters_low_velocity_notes(self, pipeline, midi_with_low_velocity_notes):
|
| 226 |
-
"""Notes with velocity < 45 should be filtered."""
|
| 227 |
-
# Process MIDI through cleanup
|
| 228 |
-
cleaned_midi = pipeline.clean_midi(midi_with_low_velocity_notes)
|
| 229 |
-
|
| 230 |
-
# Load cleaned MIDI
|
| 231 |
-
mid = mido.MidiFile(cleaned_midi)
|
| 232 |
-
|
| 233 |
-
# Count note_on messages
|
| 234 |
-
note_ons = []
|
| 235 |
-
for track in mid.tracks:
|
| 236 |
-
for msg in track:
|
| 237 |
-
if msg.type == 'note_on' and msg.velocity > 0:
|
| 238 |
-
note_ons.append((msg.note, msg.velocity))
|
| 239 |
-
|
| 240 |
-
# Should have 2 notes (60 vel=64, 64 vel=80), not 3
|
| 241 |
-
assert len(note_ons) == 2, f"Expected 2 notes after filtering, got {len(note_ons)}"
|
| 242 |
-
|
| 243 |
-
# Check that low velocity note (62) is not present
|
| 244 |
-
notes = [n[0] for n in note_ons]
|
| 245 |
-
assert 62 not in notes, "Low velocity note should be filtered"
|
| 246 |
-
assert 60 in notes, "High velocity note should remain"
|
| 247 |
-
assert 64 in notes, "High velocity note should remain"
|
| 248 |
-
|
| 249 |
-
def test_keeps_moderate_velocity_notes(self, pipeline, temp_storage_dir):
|
| 250 |
-
"""Notes with velocity >= 45 should be kept."""
|
| 251 |
-
mid = mido.MidiFile()
|
| 252 |
-
track = mido.MidiTrack()
|
| 253 |
-
mid.tracks.append(track)
|
| 254 |
-
|
| 255 |
-
# Note with velocity exactly at threshold (45)
|
| 256 |
-
track.append(mido.Message('note_on', note=60, velocity=45, time=0))
|
| 257 |
-
track.append(mido.Message('note_off', note=60, velocity=45, time=480))
|
| 258 |
-
|
| 259 |
-
# Note with velocity above threshold (60)
|
| 260 |
-
track.append(mido.Message('note_on', note=62, velocity=60, time=0))
|
| 261 |
-
track.append(mido.Message('note_off', note=62, velocity=60, time=480))
|
| 262 |
-
|
| 263 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 264 |
-
|
| 265 |
-
midi_path = temp_storage_dir / "moderate_velocity.mid"
|
| 266 |
-
mid.save(str(midi_path))
|
| 267 |
-
|
| 268 |
-
cleaned_midi = pipeline.clean_midi(midi_path)
|
| 269 |
-
mid_cleaned = mido.MidiFile(cleaned_midi)
|
| 270 |
-
|
| 271 |
-
note_ons = []
|
| 272 |
-
for track in mid_cleaned.tracks:
|
| 273 |
-
for msg in track:
|
| 274 |
-
if msg.type == 'note_on' and msg.velocity > 0:
|
| 275 |
-
note_ons.append(msg.note)
|
| 276 |
-
|
| 277 |
-
# Both notes should be kept
|
| 278 |
-
assert len(note_ons) == 2, "Notes with velocity >= 45 should be kept"
|
| 279 |
-
assert 60 in note_ons
|
| 280 |
-
assert 62 in note_ons
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
class TestIssue7MeasureNormalization:
|
| 284 |
-
"""Tests for Issue #7: Corrupted measures."""
|
| 285 |
-
|
| 286 |
-
def test_deduplication_bucketing_precision(self, pipeline):
|
| 287 |
-
"""Deduplication should use 0.005 QN bucketing (5ms at 120 BPM)."""
|
| 288 |
-
s = stream.Score()
|
| 289 |
-
p = stream.Part()
|
| 290 |
-
|
| 291 |
-
# Create two notes very close together (should be in same bucket)
|
| 292 |
-
# At 0.0 and 0.003 QN (~3ms) - should be bucketed together
|
| 293 |
-
# After bucketing: 0.0 -> bucket 0.0, 0.003 -> bucket 0.005 (rounded)
|
| 294 |
-
# Actually need to be within the same bucket after rounding
|
| 295 |
-
n1 = note.Note('C4', quarterLength=1.0)
|
| 296 |
-
p.insert(0.0, n1)
|
| 297 |
-
|
| 298 |
-
n2 = note.Note('C4', quarterLength=1.0)
|
| 299 |
-
p.insert(0.002, n2) # 0.002 rounds to 0.0 bucket (0.002/0.005 = 0.4 -> rounds to 0)
|
| 300 |
-
|
| 301 |
-
s.append(p)
|
| 302 |
-
|
| 303 |
-
# Deduplicate
|
| 304 |
-
deduped_score = pipeline._deduplicate_overlapping_notes(s)
|
| 305 |
-
notes_after = list(deduped_score.flatten().notes)
|
| 306 |
-
|
| 307 |
-
# Should merge into one note (duplicate in same bucket)
|
| 308 |
-
assert len(notes_after) == 1, "Notes within same 0.005 QN bucket should be deduplicated"
|
| 309 |
-
|
| 310 |
-
def test_skip_threshold_relaxed(self, pipeline):
|
| 311 |
-
"""Normalization should skip measures within 0.05 QN of correct duration."""
|
| 312 |
-
s = stream.Score()
|
| 313 |
-
p = stream.Part()
|
| 314 |
-
|
| 315 |
-
# Create a measure that's slightly off (3.98 QN instead of 4.0)
|
| 316 |
-
# This is within 0.05 tolerance, so should be skipped
|
| 317 |
-
n1 = note.Note('C4', quarterLength=1.98)
|
| 318 |
-
n2 = note.Note('D4', quarterLength=2.0)
|
| 319 |
-
|
| 320 |
-
p.insert(0.0, n1)
|
| 321 |
-
p.insert(1.98, n2)
|
| 322 |
-
|
| 323 |
-
s.append(p)
|
| 324 |
-
s = s.makeMeasures()
|
| 325 |
-
|
| 326 |
-
# Normalize with 4/4 time signature
|
| 327 |
-
normalized = pipeline._normalize_measure_durations(s, 4, 4)
|
| 328 |
-
|
| 329 |
-
# Get first measure
|
| 330 |
-
measures = normalized.parts[0].getElementsByClass('Measure')
|
| 331 |
-
if measures:
|
| 332 |
-
first_measure = measures[0]
|
| 333 |
-
elements = list(first_measure.notesAndRests)
|
| 334 |
-
|
| 335 |
-
# Measure should not be modified (within tolerance)
|
| 336 |
-
# Total duration should still be ~3.98
|
| 337 |
-
total_duration = sum(e.quarterLength for e in elements)
|
| 338 |
-
assert abs(total_duration - 3.98) < 0.1, \
|
| 339 |
-
"Measure within tolerance should not be heavily modified"
|
| 340 |
-
|
| 341 |
-
def test_rest_fill_minimum_lowered(self, pipeline):
|
| 342 |
-
"""Gaps > 0.15 QN (new tolerance) should be filled with rests, smaller gaps skipped."""
|
| 343 |
-
s = stream.Score()
|
| 344 |
-
p = stream.Part()
|
| 345 |
-
|
| 346 |
-
# Create measure with gap LARGER than tolerance (0.15 QN)
|
| 347 |
-
# Total: 3.80 QN (gap of 0.20 QN to fill to 4.0) - exceeds 0.15 tolerance
|
| 348 |
-
# This should trigger normalization and rest filling
|
| 349 |
-
n1 = note.Note('C4', quarterLength=2.0)
|
| 350 |
-
n2 = note.Note('D4', quarterLength=1.80)
|
| 351 |
-
|
| 352 |
-
p.insert(0.0, n1)
|
| 353 |
-
p.insert(2.0, n2)
|
| 354 |
-
|
| 355 |
-
s.append(p)
|
| 356 |
-
s = s.makeMeasures()
|
| 357 |
-
|
| 358 |
-
# Normalize - should add rest to fill 0.20 QN gap (exceeds 0.15 tolerance)
|
| 359 |
-
normalized = pipeline._normalize_measure_durations(s, 4, 4)
|
| 360 |
-
|
| 361 |
-
measures = normalized.parts[0].getElementsByClass('Measure')
|
| 362 |
-
if measures:
|
| 363 |
-
first_measure = measures[0]
|
| 364 |
-
elements = list(first_measure.notesAndRests)
|
| 365 |
-
|
| 366 |
-
# Total duration should now be close to 4.0 after normalization
|
| 367 |
-
total_duration = sum(e.quarterLength for e in elements)
|
| 368 |
-
|
| 369 |
-
# With 0.20 gap (> 0.15 tolerance), should have been normalized
|
| 370 |
-
# Either proportionally scaled or filled with rest
|
| 371 |
-
assert abs(total_duration - 4.0) < 0.15, \
|
| 372 |
-
f"Measure should be normalized to ~4.0 QN, got {total_duration}"
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
class TestIntegration:
|
| 376 |
-
"""Integration tests for full pipeline."""
|
| 377 |
-
|
| 378 |
-
def test_onset_threshold_config(self):
|
| 379 |
-
"""Config should have onset_threshold = 0.5 (increased to reduce false positives)."""
|
| 380 |
-
config = Settings()
|
| 381 |
-
assert config.onset_threshold == 0.5, \
|
| 382 |
-
f"onset_threshold should be 0.5, got {config.onset_threshold}"
|
| 383 |
-
|
| 384 |
-
def test_sequential_note_merging_in_pipeline(self, pipeline, midi_with_sequential_notes):
|
| 385 |
-
"""Full pipeline should merge sequential notes."""
|
| 386 |
-
# Convert MIDI to music21
|
| 387 |
-
score = converter.parse(str(midi_with_sequential_notes))
|
| 388 |
-
|
| 389 |
-
# Check notes before merging
|
| 390 |
-
notes_before = list(score.flatten().notes)
|
| 391 |
-
# MIDI has 2 notes with tiny gap
|
| 392 |
-
assert len(notes_before) >= 2, "MIDI should have at least 2 notes"
|
| 393 |
-
|
| 394 |
-
# Run merge
|
| 395 |
-
merged_score = pipeline._merge_music21_notes(score, gap_threshold_qn=0.02)
|
| 396 |
-
|
| 397 |
-
# After merging, should have 1 note
|
| 398 |
-
notes_after = list(merged_score.flatten().notes)
|
| 399 |
-
assert len(notes_after) == 1, \
|
| 400 |
-
f"Sequential notes should merge into 1, got {len(notes_after)}"
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
class TestEnvelopeAnalysis:
|
| 404 |
-
"""Test velocity envelope analysis and sustain artifact detection."""
|
| 405 |
-
|
| 406 |
-
@pytest.fixture
|
| 407 |
-
def midi_with_decay_pattern(self, temp_storage_dir):
|
| 408 |
-
"""Create MIDI with decreasing velocity pattern (sustain decay artifact)."""
|
| 409 |
-
mid = mido.MidiFile()
|
| 410 |
-
track = mido.MidiTrack()
|
| 411 |
-
mid.tracks.append(track)
|
| 412 |
-
|
| 413 |
-
# Note 1: C4 (60) velocity 80, 0-480 ticks (1 beat)
|
| 414 |
-
track.append(mido.Message('note_on', note=60, velocity=80, time=0))
|
| 415 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=480))
|
| 416 |
-
|
| 417 |
-
# Gap of 120 ticks (~250ms at 120 BPM)
|
| 418 |
-
# Note 2: C4 (60) velocity 50 (decaying), 600-1080 ticks
|
| 419 |
-
track.append(mido.Message('note_on', note=60, velocity=50, time=120))
|
| 420 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=480))
|
| 421 |
-
|
| 422 |
-
# Gap of 100 ticks
|
| 423 |
-
# Note 3: C4 (60) velocity 30 (further decay), 1180-1660 ticks
|
| 424 |
-
track.append(mido.Message('note_on', note=60, velocity=30, time=100))
|
| 425 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=480))
|
| 426 |
-
|
| 427 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 428 |
-
|
| 429 |
-
midi_path = temp_storage_dir / "decay_pattern.mid"
|
| 430 |
-
mid.save(str(midi_path))
|
| 431 |
-
return midi_path
|
| 432 |
-
|
| 433 |
-
@pytest.fixture
|
| 434 |
-
def midi_with_staccato_pattern(self, temp_storage_dir):
|
| 435 |
-
"""Create MIDI with similar velocities (intentional staccato)."""
|
| 436 |
-
mid = mido.MidiFile()
|
| 437 |
-
track = mido.MidiTrack()
|
| 438 |
-
mid.tracks.append(track)
|
| 439 |
-
|
| 440 |
-
# Note 1: C4 (60) velocity 70
|
| 441 |
-
track.append(mido.Message('note_on', note=60, velocity=70, time=0))
|
| 442 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=240))
|
| 443 |
-
|
| 444 |
-
# Gap
|
| 445 |
-
# Note 2: C4 (60) velocity 68 (similar, intentional)
|
| 446 |
-
track.append(mido.Message('note_on', note=60, velocity=68, time=100))
|
| 447 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=240))
|
| 448 |
-
|
| 449 |
-
# Gap
|
| 450 |
-
# Note 3: C4 (60) velocity 72 (similar, intentional)
|
| 451 |
-
track.append(mido.Message('note_on', note=60, velocity=72, time=100))
|
| 452 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=240))
|
| 453 |
-
|
| 454 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 455 |
-
|
| 456 |
-
midi_path = temp_storage_dir / "staccato_pattern.mid"
|
| 457 |
-
mid.save(str(midi_path))
|
| 458 |
-
return midi_path
|
| 459 |
-
|
| 460 |
-
def test_detects_and_merges_decay_pattern(self, pipeline, midi_with_decay_pattern):
|
| 461 |
-
"""Test that decreasing velocity patterns are detected and merged."""
|
| 462 |
-
result = pipeline.analyze_note_envelope_and_merge_sustains(
|
| 463 |
-
midi_with_decay_pattern,
|
| 464 |
-
tempo_bpm=120.0
|
| 465 |
-
)
|
| 466 |
-
|
| 467 |
-
# Load result and count note_on events
|
| 468 |
-
mid = mido.MidiFile(result)
|
| 469 |
-
note_ons = [msg for msg in mid.tracks[0] if msg.type == 'note_on' and msg.velocity > 0]
|
| 470 |
-
|
| 471 |
-
# Should merge 3 decaying notes into 1
|
| 472 |
-
assert len(note_ons) == 1, \
|
| 473 |
-
f"Decay pattern should merge to 1 note, got {len(note_ons)}"
|
| 474 |
-
|
| 475 |
-
def test_preserves_staccato_pattern(self, pipeline, midi_with_staccato_pattern):
|
| 476 |
-
"""Test that similar velocities are NOT merged (intentional staccato)."""
|
| 477 |
-
result = pipeline.analyze_note_envelope_and_merge_sustains(
|
| 478 |
-
midi_with_staccato_pattern,
|
| 479 |
-
tempo_bpm=120.0
|
| 480 |
-
)
|
| 481 |
-
|
| 482 |
-
# Load result and count note_on events
|
| 483 |
-
mid = mido.MidiFile(result)
|
| 484 |
-
note_ons = [msg for msg in mid.tracks[0] if msg.type == 'note_on' and msg.velocity > 0]
|
| 485 |
-
|
| 486 |
-
# Should keep all 3 notes (similar velocities = intentional)
|
| 487 |
-
assert len(note_ons) == 3, \
|
| 488 |
-
f"Staccato pattern should keep 3 notes, got {len(note_ons)}"
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
class TestTempoAdaptiveThresholds:
|
| 492 |
-
"""Test tempo-adaptive threshold selection."""
|
| 493 |
-
|
| 494 |
-
def test_fast_tempo_uses_strict_thresholds(self, pipeline):
|
| 495 |
-
"""Test that fast tempos (>140 BPM) use stricter thresholds."""
|
| 496 |
-
thresholds = pipeline._get_tempo_adaptive_thresholds(160.0)
|
| 497 |
-
|
| 498 |
-
assert thresholds['onset_threshold'] == 0.50, "Fast tempo should use 0.50 onset threshold"
|
| 499 |
-
assert thresholds['min_velocity'] == 50, "Fast tempo should use 50 min velocity"
|
| 500 |
-
assert thresholds['min_duration_divisor'] == 6, "Fast tempo should use 48th notes"
|
| 501 |
-
|
| 502 |
-
def test_slow_tempo_uses_permissive_thresholds(self, pipeline):
|
| 503 |
-
"""Test that slow tempos (<80 BPM) use more permissive thresholds."""
|
| 504 |
-
thresholds = pipeline._get_tempo_adaptive_thresholds(60.0)
|
| 505 |
-
|
| 506 |
-
assert thresholds['onset_threshold'] == 0.40, "Slow tempo should use 0.40 onset threshold"
|
| 507 |
-
assert thresholds['min_velocity'] == 40, "Slow tempo should use 40 min velocity"
|
| 508 |
-
assert thresholds['min_duration_divisor'] == 10, "Slow tempo should use permissive divisor"
|
| 509 |
-
|
| 510 |
-
def test_medium_tempo_uses_default_thresholds(self, pipeline):
|
| 511 |
-
"""Test that medium tempos (80-140 BPM) use default thresholds."""
|
| 512 |
-
thresholds = pipeline._get_tempo_adaptive_thresholds(120.0)
|
| 513 |
-
|
| 514 |
-
assert thresholds['onset_threshold'] == 0.45, "Medium tempo should use 0.45 onset threshold"
|
| 515 |
-
assert thresholds['min_velocity'] == 45, "Medium tempo should use 45 min velocity"
|
| 516 |
-
assert thresholds['min_duration_divisor'] == 8, "Medium tempo should use 32nd notes"
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
class TestMusicXMLTies:
|
| 520 |
-
"""Test MusicXML tie notation generation."""
|
| 521 |
-
|
| 522 |
-
@pytest.fixture
|
| 523 |
-
def score_with_long_note(self):
|
| 524 |
-
"""Create a score with a note that spans multiple measures."""
|
| 525 |
-
from music21 import stream, note, meter
|
| 526 |
-
|
| 527 |
-
s = stream.Score()
|
| 528 |
-
part = stream.Part()
|
| 529 |
-
|
| 530 |
-
# Add 4/4 time signature
|
| 531 |
-
part.append(meter.TimeSignature('4/4'))
|
| 532 |
-
|
| 533 |
-
# Measure 1: C4 whole note (4 QN)
|
| 534 |
-
m1 = stream.Measure()
|
| 535 |
-
m1.append(note.Note('C4', quarterLength=4.0))
|
| 536 |
-
part.append(m1)
|
| 537 |
-
|
| 538 |
-
# Measure 2: D4 whole note (4 QN)
|
| 539 |
-
m2 = stream.Measure()
|
| 540 |
-
m2.append(note.Note('D4', quarterLength=4.0))
|
| 541 |
-
part.append(m2)
|
| 542 |
-
|
| 543 |
-
s.insert(0, part)
|
| 544 |
-
return s
|
| 545 |
-
|
| 546 |
-
@pytest.fixture
|
| 547 |
-
def score_with_cross_measure_note(self):
|
| 548 |
-
"""Create a score with a note crossing measure boundary."""
|
| 549 |
-
from music21 import stream, note, meter
|
| 550 |
-
|
| 551 |
-
s = stream.Score()
|
| 552 |
-
part = stream.Part()
|
| 553 |
-
|
| 554 |
-
# Add 4/4 time signature
|
| 555 |
-
part.append(meter.TimeSignature('4/4'))
|
| 556 |
-
|
| 557 |
-
# Measure 1: Two quarter notes + note that extends into measure 2
|
| 558 |
-
m1 = stream.Measure()
|
| 559 |
-
m1.append(note.Note('C4', quarterLength=1.0))
|
| 560 |
-
m1.append(note.Note('D4', quarterLength=1.0))
|
| 561 |
-
# This note is 2.5 QN, extends 0.5 QN beyond measure boundary
|
| 562 |
-
m1.append(note.Note('E4', quarterLength=2.5))
|
| 563 |
-
part.append(m1)
|
| 564 |
-
|
| 565 |
-
# Measure 2: Continuation
|
| 566 |
-
m2 = stream.Measure()
|
| 567 |
-
# The E4 should continue here with a tie
|
| 568 |
-
m2.append(note.Note('E4', quarterLength=1.0))
|
| 569 |
-
m2.append(note.Note('F4', quarterLength=2.0))
|
| 570 |
-
part.append(m2)
|
| 571 |
-
|
| 572 |
-
s.insert(0, part)
|
| 573 |
-
return s
|
| 574 |
-
|
| 575 |
-
def test_adds_ties_to_cross_measure_notes(self, pipeline, score_with_cross_measure_note):
|
| 576 |
-
"""Test that ties are added to notes crossing measure boundaries."""
|
| 577 |
-
result = pipeline._add_ties_to_score(score_with_cross_measure_note)
|
| 578 |
-
|
| 579 |
-
# Get all notes
|
| 580 |
-
all_notes = list(result.flatten().notes)
|
| 581 |
-
|
| 582 |
-
# Find E4 notes (should have ties)
|
| 583 |
-
e4_notes = [n for n in all_notes if n.pitch.name == 'E']
|
| 584 |
-
|
| 585 |
-
# Should have at least 1 E4 with 'start' tie
|
| 586 |
-
has_start_tie = any(n.tie is not None and n.tie.type == 'start' for n in e4_notes)
|
| 587 |
-
assert has_start_tie, "E4 note crossing measure should have 'start' tie"
|
| 588 |
-
|
| 589 |
-
def test_does_not_add_ties_to_within_measure_notes(self, pipeline, score_with_long_note):
|
| 590 |
-
"""Test that ties are NOT added to notes within a single measure."""
|
| 591 |
-
result = pipeline._add_ties_to_score(score_with_long_note)
|
| 592 |
-
|
| 593 |
-
# Get all notes
|
| 594 |
-
all_notes = list(result.flatten().notes)
|
| 595 |
-
|
| 596 |
-
# C4 and D4 are within measures, should not have ties
|
| 597 |
-
c4_notes = [n for n in all_notes if n.pitch.name == 'C']
|
| 598 |
-
d4_notes = [n for n in all_notes if n.pitch.name == 'D']
|
| 599 |
-
|
| 600 |
-
# None should have ties (they don't cross boundaries)
|
| 601 |
-
c4_has_ties = any(n.tie is not None for n in c4_notes)
|
| 602 |
-
d4_has_ties = any(n.tie is not None for n in d4_notes)
|
| 603 |
-
|
| 604 |
-
assert not c4_has_ties, "C4 within measure should not have tie"
|
| 605 |
-
assert not d4_has_ties, "D4 within measure should not have tie"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/tests/test_pipeline_monophonic.py
DELETED
|
@@ -1,306 +0,0 @@
|
|
| 1 |
-
"""Tests for monophonic melody extraction and frequency filtering."""
|
| 2 |
-
import pytest
|
| 3 |
-
import mido
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
from pipeline import TranscriptionPipeline
|
| 6 |
-
from app_config import Settings
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
@pytest.fixture
|
| 10 |
-
def pipeline(tmp_path):
|
| 11 |
-
"""Create pipeline instance for testing."""
|
| 12 |
-
config = Settings(storage_path=tmp_path)
|
| 13 |
-
return TranscriptionPipeline(
|
| 14 |
-
job_id="test-mono",
|
| 15 |
-
youtube_url="https://test.com",
|
| 16 |
-
storage_path=tmp_path,
|
| 17 |
-
config=config
|
| 18 |
-
)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
@pytest.fixture
|
| 22 |
-
def midi_with_octaves(tmp_path):
|
| 23 |
-
"""Create MIDI file with simultaneous octave notes (C4 + C6)."""
|
| 24 |
-
mid = mido.MidiFile(ticks_per_beat=220)
|
| 25 |
-
track = mido.MidiTrack()
|
| 26 |
-
mid.tracks.append(track)
|
| 27 |
-
|
| 28 |
-
# Simultaneous C4 (60) + C6 (84) - octave duplicate
|
| 29 |
-
track.append(mido.Message('note_on', note=60, velocity=64, time=0))
|
| 30 |
-
track.append(mido.Message('note_on', note=84, velocity=64, time=0))
|
| 31 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=480))
|
| 32 |
-
track.append(mido.Message('note_off', note=84, velocity=0, time=0))
|
| 33 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 34 |
-
|
| 35 |
-
path = tmp_path / "octaves.mid"
|
| 36 |
-
mid.save(str(path))
|
| 37 |
-
return path
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
@pytest.fixture
|
| 41 |
-
def midi_with_single_notes(tmp_path):
|
| 42 |
-
"""Create MIDI file with sequential single notes."""
|
| 43 |
-
mid = mido.MidiFile(ticks_per_beat=220)
|
| 44 |
-
track = mido.MidiTrack()
|
| 45 |
-
mid.tracks.append(track)
|
| 46 |
-
|
| 47 |
-
# C4, D4, E4 in sequence
|
| 48 |
-
track.append(mido.Message('note_on', note=60, velocity=64, time=0))
|
| 49 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=480))
|
| 50 |
-
track.append(mido.Message('note_on', note=62, velocity=64, time=0))
|
| 51 |
-
track.append(mido.Message('note_off', note=62, velocity=0, time=480))
|
| 52 |
-
track.append(mido.Message('note_on', note=64, velocity=64, time=0))
|
| 53 |
-
track.append(mido.Message('note_off', note=64, velocity=0, time=480))
|
| 54 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 55 |
-
|
| 56 |
-
path = tmp_path / "single_notes.mid"
|
| 57 |
-
mid.save(str(path))
|
| 58 |
-
return path
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
@pytest.fixture
|
| 62 |
-
def midi_with_close_onset(tmp_path):
|
| 63 |
-
"""Create MIDI file with notes starting within ONSET_TOLERANCE (10 ticks)."""
|
| 64 |
-
mid = mido.MidiFile(ticks_per_beat=220)
|
| 65 |
-
track = mido.MidiTrack()
|
| 66 |
-
mid.tracks.append(track)
|
| 67 |
-
|
| 68 |
-
# C4 at time 0, G4 at time 5 (within tolerance, should be treated as simultaneous)
|
| 69 |
-
track.append(mido.Message('note_on', note=60, velocity=64, time=0))
|
| 70 |
-
track.append(mido.Message('note_on', note=67, velocity=64, time=5))
|
| 71 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=475))
|
| 72 |
-
track.append(mido.Message('note_off', note=67, velocity=0, time=0))
|
| 73 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 74 |
-
|
| 75 |
-
path = tmp_path / "close_onset.mid"
|
| 76 |
-
mid.save(str(path))
|
| 77 |
-
return path
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
@pytest.fixture
|
| 81 |
-
def midi_with_consecutive_same_pitch(tmp_path):
|
| 82 |
-
"""Create MIDI file with consecutive same-pitch notes (not simultaneous)."""
|
| 83 |
-
mid = mido.MidiFile(ticks_per_beat=220)
|
| 84 |
-
track = mido.MidiTrack()
|
| 85 |
-
mid.tracks.append(track)
|
| 86 |
-
|
| 87 |
-
# C4 at time 0, C4 at time 500 (sequential, not simultaneous)
|
| 88 |
-
track.append(mido.Message('note_on', note=60, velocity=64, time=0))
|
| 89 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=480))
|
| 90 |
-
track.append(mido.Message('note_on', note=60, velocity=64, time=20))
|
| 91 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=480))
|
| 92 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 93 |
-
|
| 94 |
-
path = tmp_path / "consecutive_same.mid"
|
| 95 |
-
mid.save(str(path))
|
| 96 |
-
return path
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
@pytest.fixture
|
| 100 |
-
def midi_with_triple_octaves(tmp_path):
|
| 101 |
-
"""Create MIDI file with three simultaneous octaves (C4 + C5 + C6)."""
|
| 102 |
-
mid = mido.MidiFile(ticks_per_beat=220)
|
| 103 |
-
track = mido.MidiTrack()
|
| 104 |
-
mid.tracks.append(track)
|
| 105 |
-
|
| 106 |
-
# Simultaneous C4 (60) + C5 (72) + C6 (84)
|
| 107 |
-
track.append(mido.Message('note_on', note=60, velocity=64, time=0))
|
| 108 |
-
track.append(mido.Message('note_on', note=72, velocity=64, time=0))
|
| 109 |
-
track.append(mido.Message('note_on', note=84, velocity=64, time=0))
|
| 110 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=480))
|
| 111 |
-
track.append(mido.Message('note_off', note=72, velocity=0, time=0))
|
| 112 |
-
track.append(mido.Message('note_off', note=84, velocity=0, time=0))
|
| 113 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 114 |
-
|
| 115 |
-
path = tmp_path / "triple_octaves.mid"
|
| 116 |
-
mid.save(str(path))
|
| 117 |
-
return path
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
@pytest.fixture
|
| 121 |
-
def midi_with_different_pitch_classes(tmp_path):
|
| 122 |
-
"""Create MIDI file with bass + treble (different pitch classes) simultaneous."""
|
| 123 |
-
mid = mido.MidiFile(ticks_per_beat=220)
|
| 124 |
-
track = mido.MidiTrack()
|
| 125 |
-
mid.tracks.append(track)
|
| 126 |
-
|
| 127 |
-
# Simultaneous D2 (38, pitch class 2) + A5 (81, pitch class 9)
|
| 128 |
-
# This simulates piano with left hand bass + right hand treble
|
| 129 |
-
track.append(mido.Message('note_on', note=38, velocity=64, time=0))
|
| 130 |
-
track.append(mido.Message('note_on', note=81, velocity=64, time=0))
|
| 131 |
-
track.append(mido.Message('note_off', note=38, velocity=0, time=480))
|
| 132 |
-
track.append(mido.Message('note_off', note=81, velocity=0, time=0))
|
| 133 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 134 |
-
|
| 135 |
-
path = tmp_path / "different_pitch_classes.mid"
|
| 136 |
-
mid.save(str(path))
|
| 137 |
-
return path
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
@pytest.fixture
|
| 141 |
-
def midi_wide_range(tmp_path):
|
| 142 |
-
"""Create MIDI file with wide range (>24 semitones) - simulates piano."""
|
| 143 |
-
mid = mido.MidiFile(ticks_per_beat=220)
|
| 144 |
-
track = mido.MidiTrack()
|
| 145 |
-
mid.tracks.append(track)
|
| 146 |
-
|
| 147 |
-
# D2 (38) to E6 (88) = 50 semitones (wide range)
|
| 148 |
-
track.append(mido.Message('note_on', note=38, velocity=64, time=0))
|
| 149 |
-
track.append(mido.Message('note_off', note=38, velocity=0, time=480))
|
| 150 |
-
track.append(mido.Message('note_on', note=88, velocity=64, time=0))
|
| 151 |
-
track.append(mido.Message('note_off', note=88, velocity=0, time=480))
|
| 152 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 153 |
-
|
| 154 |
-
path = tmp_path / "wide_range.mid"
|
| 155 |
-
mid.save(str(path))
|
| 156 |
-
return path
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
@pytest.fixture
|
| 160 |
-
def midi_narrow_range(tmp_path):
|
| 161 |
-
"""Create MIDI file with narrow range (≤24 semitones) - simulates monophonic melody."""
|
| 162 |
-
mid = mido.MidiFile(ticks_per_beat=220)
|
| 163 |
-
track = mido.MidiTrack()
|
| 164 |
-
mid.tracks.append(track)
|
| 165 |
-
|
| 166 |
-
# C4 (60) to A4 (69) = 9 semitones (narrow range)
|
| 167 |
-
track.append(mido.Message('note_on', note=60, velocity=64, time=0))
|
| 168 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=480))
|
| 169 |
-
track.append(mido.Message('note_on', note=69, velocity=64, time=0))
|
| 170 |
-
track.append(mido.Message('note_off', note=69, velocity=0, time=480))
|
| 171 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 172 |
-
|
| 173 |
-
path = tmp_path / "narrow_range.mid"
|
| 174 |
-
mid.save(str(path))
|
| 175 |
-
return path
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
class TestExtractMonophonicMelody:
|
| 179 |
-
"""Tests for the extract_monophonic_melody() function."""
|
| 180 |
-
|
| 181 |
-
def test_skyline_algorithm_keeps_highest_pitch(self, pipeline, midi_with_octaves):
|
| 182 |
-
"""Skyline algorithm should keep C6 (highest) and remove C4."""
|
| 183 |
-
result_path = pipeline.extract_monophonic_melody(midi_with_octaves)
|
| 184 |
-
|
| 185 |
-
# Load result and extract notes
|
| 186 |
-
mid = mido.MidiFile(result_path)
|
| 187 |
-
notes = []
|
| 188 |
-
for track in mid.tracks:
|
| 189 |
-
for msg in track:
|
| 190 |
-
if msg.type == 'note_on' and msg.velocity > 0:
|
| 191 |
-
notes.append(msg.note)
|
| 192 |
-
|
| 193 |
-
assert len(notes) == 1, f"Should have exactly one note, got {len(notes)}"
|
| 194 |
-
assert notes[0] == 84, f"Should keep C6 (84), not C4 (60), got {notes[0]}"
|
| 195 |
-
|
| 196 |
-
def test_preserves_single_notes_unchanged(self, pipeline, midi_with_single_notes):
|
| 197 |
-
"""Single sequential notes should pass through unchanged."""
|
| 198 |
-
result_path = pipeline.extract_monophonic_melody(midi_with_single_notes)
|
| 199 |
-
|
| 200 |
-
# Load result and extract notes
|
| 201 |
-
mid = mido.MidiFile(result_path)
|
| 202 |
-
notes = []
|
| 203 |
-
for track in mid.tracks:
|
| 204 |
-
for msg in track:
|
| 205 |
-
if msg.type == 'note_on' and msg.velocity > 0:
|
| 206 |
-
notes.append(msg.note)
|
| 207 |
-
|
| 208 |
-
assert len(notes) == 3, f"Should preserve all 3 notes, got {len(notes)}"
|
| 209 |
-
assert notes == [60, 62, 64], f"Should preserve C4, D4, E4, got {notes}"
|
| 210 |
-
|
| 211 |
-
def test_handles_onset_tolerance(self, pipeline, midi_with_close_onset):
|
| 212 |
-
"""Notes within ONSET_TOLERANCE (10 ticks) should be treated as simultaneous."""
|
| 213 |
-
result_path = pipeline.extract_monophonic_melody(midi_with_close_onset)
|
| 214 |
-
|
| 215 |
-
# Load result and extract notes
|
| 216 |
-
mid = mido.MidiFile(result_path)
|
| 217 |
-
notes = []
|
| 218 |
-
for track in mid.tracks:
|
| 219 |
-
for msg in track:
|
| 220 |
-
if msg.type == 'note_on' and msg.velocity > 0:
|
| 221 |
-
notes.append(msg.note)
|
| 222 |
-
|
| 223 |
-
# C4 (60, pitch class 0) and G4 (67, pitch class 7) start within 5 ticks
|
| 224 |
-
# Different pitch classes → both should be kept
|
| 225 |
-
assert len(notes) == 2, f"Should keep both notes (different pitch classes), got {len(notes)}"
|
| 226 |
-
assert set(notes) == {60, 67}, f"Should keep both C4 (60) and G4 (67), got {notes}"
|
| 227 |
-
|
| 228 |
-
def test_consecutive_same_pitch_preserved(self, pipeline, midi_with_consecutive_same_pitch):
|
| 229 |
-
"""Consecutive same-pitch notes should be preserved (not merged)."""
|
| 230 |
-
result_path = pipeline.extract_monophonic_melody(midi_with_consecutive_same_pitch)
|
| 231 |
-
|
| 232 |
-
# Load result and extract notes
|
| 233 |
-
mid = mido.MidiFile(result_path)
|
| 234 |
-
notes = []
|
| 235 |
-
for track in mid.tracks:
|
| 236 |
-
for msg in track:
|
| 237 |
-
if msg.type == 'note_on' and msg.velocity > 0:
|
| 238 |
-
notes.append(msg.note)
|
| 239 |
-
|
| 240 |
-
assert len(notes) == 2, f"Should preserve both C4 notes, got {len(notes)}"
|
| 241 |
-
assert notes == [60, 60], f"Should have two C4 notes, got {notes}"
|
| 242 |
-
|
| 243 |
-
def test_removes_multiple_octave_duplicates(self, pipeline, midi_with_triple_octaves):
|
| 244 |
-
"""Should keep only the highest pitch from multiple simultaneous octaves."""
|
| 245 |
-
result_path = pipeline.extract_monophonic_melody(midi_with_triple_octaves)
|
| 246 |
-
|
| 247 |
-
# Load result and extract notes
|
| 248 |
-
mid = mido.MidiFile(result_path)
|
| 249 |
-
notes = []
|
| 250 |
-
for track in mid.tracks:
|
| 251 |
-
for msg in track:
|
| 252 |
-
if msg.type == 'note_on' and msg.velocity > 0:
|
| 253 |
-
notes.append(msg.note)
|
| 254 |
-
|
| 255 |
-
assert len(notes) == 1, f"Should have exactly one note from three octaves, got {len(notes)}"
|
| 256 |
-
assert notes[0] == 84, f"Should keep C6 (84) as highest, got {notes[0]}"
|
| 257 |
-
|
| 258 |
-
def test_preserves_different_pitch_classes(self, pipeline, midi_with_different_pitch_classes):
|
| 259 |
-
"""Should preserve notes of different pitch classes (bass + treble)."""
|
| 260 |
-
result_path = pipeline.extract_monophonic_melody(midi_with_different_pitch_classes)
|
| 261 |
-
|
| 262 |
-
# Load result and extract notes
|
| 263 |
-
mid = mido.MidiFile(result_path)
|
| 264 |
-
notes = []
|
| 265 |
-
for track in mid.tracks:
|
| 266 |
-
for msg in track:
|
| 267 |
-
if msg.type == 'note_on' and msg.velocity > 0:
|
| 268 |
-
notes.append(msg.note)
|
| 269 |
-
|
| 270 |
-
# D2 (38, pitch class 2) and A5 (81, pitch class 9) are different
|
| 271 |
-
# Both should be preserved (simulates piano left + right hand)
|
| 272 |
-
assert len(notes) == 2, f"Should preserve both bass and treble notes, got {len(notes)}"
|
| 273 |
-
assert set(notes) == {38, 81}, f"Should keep both D2 (38) and A5 (81), got {notes}"
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
class TestRangeDetection:
|
| 277 |
-
"""Tests for MIDI range detection and adaptive processing."""
|
| 278 |
-
|
| 279 |
-
def test_detects_wide_range_piano(self, pipeline, midi_wide_range):
|
| 280 |
-
"""Should detect wide range (>24 semitones) as polyphonic."""
|
| 281 |
-
range_semitones = pipeline._get_midi_range(midi_wide_range)
|
| 282 |
-
|
| 283 |
-
# D2 (38) to E6 (88) = 50 semitones
|
| 284 |
-
assert range_semitones == 50, f"Expected 50 semitones, got {range_semitones}"
|
| 285 |
-
assert range_semitones > 24, "Should be detected as wide range (polyphonic)"
|
| 286 |
-
|
| 287 |
-
def test_detects_narrow_range_melody(self, pipeline, midi_narrow_range):
|
| 288 |
-
"""Should detect narrow range (≤24 semitones) as monophonic."""
|
| 289 |
-
range_semitones = pipeline._get_midi_range(midi_narrow_range)
|
| 290 |
-
|
| 291 |
-
# C4 (60) to A4 (69) = 9 semitones
|
| 292 |
-
assert range_semitones == 9, f"Expected 9 semitones, got {range_semitones}"
|
| 293 |
-
assert range_semitones <= 24, "Should be detected as narrow range (monophonic)"
|
| 294 |
-
|
| 295 |
-
def test_empty_midi_returns_zero_range(self, pipeline, tmp_path):
|
| 296 |
-
"""Should return 0 for MIDI with no notes."""
|
| 297 |
-
mid = mido.MidiFile(ticks_per_beat=220)
|
| 298 |
-
track = mido.MidiTrack()
|
| 299 |
-
mid.tracks.append(track)
|
| 300 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 301 |
-
|
| 302 |
-
path = tmp_path / "empty.mid"
|
| 303 |
-
mid.save(str(path))
|
| 304 |
-
|
| 305 |
-
range_semitones = pipeline._get_midi_range(path)
|
| 306 |
-
assert range_semitones == 0, f"Expected 0 for empty MIDI, got {range_semitones}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/tests/test_tasks.py
CHANGED
|
@@ -14,18 +14,31 @@ class TestProcessTranscriptionTask:
|
|
| 14 |
"""Test successful task execution."""
|
| 15 |
from tasks import process_transcription_task
|
| 16 |
|
| 17 |
-
# Mock job data in Redis
|
| 18 |
job_data = {
|
| 19 |
-
'job_id': sample_job_id,
|
| 20 |
'youtube_url': 'https://www.youtube.com/watch?v=dQw4w9WgXcQ',
|
| 21 |
'video_id': 'dQw4w9WgXcQ',
|
| 22 |
'options': '{"instruments": ["piano"]}'
|
| 23 |
}
|
| 24 |
mock_redis.hgetall.return_value = job_data
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# Mock successful pipeline instance
|
| 27 |
mock_pipeline = MagicMock()
|
| 28 |
mock_pipeline.run.return_value = str(temp_storage_dir / "output.musicxml")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
mock_pipeline_class.return_value = mock_pipeline
|
| 30 |
|
| 31 |
# Execute task
|
|
@@ -84,15 +97,28 @@ class TestProcessTranscriptionTask:
|
|
| 84 |
from tasks import process_transcription_task
|
| 85 |
|
| 86 |
job_data = {
|
| 87 |
-
'job_id': sample_job_id,
|
| 88 |
'youtube_url': 'https://www.youtube.com/watch?v=dQw4w9WgXcQ',
|
| 89 |
'video_id': 'dQw4w9WgXcQ',
|
| 90 |
'options': '{}'
|
| 91 |
}
|
| 92 |
mock_redis.hgetall.return_value = job_data
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
mock_pipeline = MagicMock()
|
| 95 |
mock_pipeline.run.return_value = str(temp_storage_dir / "output.musicxml")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
mock_pipeline_class.return_value = mock_pipeline
|
| 97 |
|
| 98 |
process_transcription_task(sample_job_id)
|
|
|
|
| 14 |
"""Test successful task execution."""
|
| 15 |
from tasks import process_transcription_task
|
| 16 |
|
| 17 |
+
# Mock job data in Redis - all string values
|
| 18 |
job_data = {
|
| 19 |
+
'job_id': str(sample_job_id),
|
| 20 |
'youtube_url': 'https://www.youtube.com/watch?v=dQw4w9WgXcQ',
|
| 21 |
'video_id': 'dQw4w9WgXcQ',
|
| 22 |
'options': '{"instruments": ["piano"]}'
|
| 23 |
}
|
| 24 |
mock_redis.hgetall.return_value = job_data
|
| 25 |
|
| 26 |
+
# Ensure pipeline method returns None
|
| 27 |
+
mock_redis.pipeline.return_value.__enter__.return_value = mock_redis
|
| 28 |
+
|
| 29 |
+
# Create actual files so they exist
|
| 30 |
+
(temp_storage_dir / "output.musicxml").write_text("<?xml version='1.0'?><score-partwise></score-partwise>")
|
| 31 |
+
(temp_storage_dir / "output.mid").write_bytes(b"MThd")
|
| 32 |
+
|
| 33 |
# Mock successful pipeline instance
|
| 34 |
mock_pipeline = MagicMock()
|
| 35 |
mock_pipeline.run.return_value = str(temp_storage_dir / "output.musicxml")
|
| 36 |
+
mock_pipeline.final_midi_path = temp_storage_dir / "output.mid"
|
| 37 |
+
mock_pipeline.metadata = {
|
| 38 |
+
"tempo": 120.0,
|
| 39 |
+
"time_signature": {"numerator": 4, "denominator": 4},
|
| 40 |
+
"key_signature": "C"
|
| 41 |
+
}
|
| 42 |
mock_pipeline_class.return_value = mock_pipeline
|
| 43 |
|
| 44 |
# Execute task
|
|
|
|
| 97 |
from tasks import process_transcription_task
|
| 98 |
|
| 99 |
job_data = {
|
| 100 |
+
'job_id': str(sample_job_id),
|
| 101 |
'youtube_url': 'https://www.youtube.com/watch?v=dQw4w9WgXcQ',
|
| 102 |
'video_id': 'dQw4w9WgXcQ',
|
| 103 |
'options': '{}'
|
| 104 |
}
|
| 105 |
mock_redis.hgetall.return_value = job_data
|
| 106 |
+
|
| 107 |
+
# Create actual files so they exist
|
| 108 |
+
(temp_storage_dir / "output.musicxml").write_text("<?xml version='1.0'?><score-partwise></score-partwise>")
|
| 109 |
+
(temp_storage_dir / "output.mid").write_bytes(b"MThd")
|
| 110 |
+
|
| 111 |
+
# Ensure pipeline method returns None
|
| 112 |
+
mock_redis.pipeline.return_value.__enter__.return_value = mock_redis
|
| 113 |
+
|
| 114 |
mock_pipeline = MagicMock()
|
| 115 |
mock_pipeline.run.return_value = str(temp_storage_dir / "output.musicxml")
|
| 116 |
+
mock_pipeline.final_midi_path = temp_storage_dir / "output.mid"
|
| 117 |
+
mock_pipeline.metadata = {
|
| 118 |
+
"tempo": 120.0,
|
| 119 |
+
"time_signature": {"numerator": 4, "denominator": 4},
|
| 120 |
+
"key_signature": "C"
|
| 121 |
+
}
|
| 122 |
mock_pipeline_class.return_value = mock_pipeline
|
| 123 |
|
| 124 |
process_transcription_task(sample_job_id)
|
backend/tests/test_yourmt3_integration.py
DELETED
|
@@ -1,296 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Tests for YourMT3+ transcription service integration.
|
| 3 |
-
|
| 4 |
-
Tests cover:
|
| 5 |
-
- YourMT3+ service health check
|
| 6 |
-
- Successful transcription
|
| 7 |
-
- Fallback to basic-pitch on service failure
|
| 8 |
-
- Fallback to basic-pitch when service disabled
|
| 9 |
-
"""
|
| 10 |
-
import pytest
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
from unittest.mock import Mock, patch, MagicMock
|
| 13 |
-
import mido
|
| 14 |
-
import tempfile
|
| 15 |
-
import shutil
|
| 16 |
-
|
| 17 |
-
from pipeline import TranscriptionPipeline
|
| 18 |
-
from app_config import Settings
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
@pytest.fixture
|
| 22 |
-
def temp_storage():
|
| 23 |
-
"""Create temporary storage directory for tests."""
|
| 24 |
-
temp_dir = Path(tempfile.mkdtemp())
|
| 25 |
-
yield temp_dir
|
| 26 |
-
shutil.rmtree(temp_dir)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
@pytest.fixture
|
| 30 |
-
def test_audio_file(temp_storage):
|
| 31 |
-
"""Create a minimal test audio file."""
|
| 32 |
-
import soundfile as sf
|
| 33 |
-
import numpy as np
|
| 34 |
-
|
| 35 |
-
audio_path = temp_storage / "test_audio.wav"
|
| 36 |
-
# Create 1 second of silence
|
| 37 |
-
sample_rate = 44100
|
| 38 |
-
audio_data = np.zeros(sample_rate)
|
| 39 |
-
sf.write(str(audio_path), audio_data, sample_rate)
|
| 40 |
-
|
| 41 |
-
return audio_path
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
@pytest.fixture
|
| 45 |
-
def mock_yourmt3_midi(temp_storage):
|
| 46 |
-
"""Create a mock MIDI file that YourMT3+ would return."""
|
| 47 |
-
midi_path = temp_storage / "yourmt3_output.mid"
|
| 48 |
-
|
| 49 |
-
# Create a simple MIDI file with one note
|
| 50 |
-
mid = mido.MidiFile()
|
| 51 |
-
track = mido.MidiTrack()
|
| 52 |
-
mid.tracks.append(track)
|
| 53 |
-
|
| 54 |
-
track.append(mido.Message('note_on', note=60, velocity=80, time=0))
|
| 55 |
-
track.append(mido.Message('note_off', note=60, velocity=0, time=480))
|
| 56 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 57 |
-
|
| 58 |
-
mid.save(str(midi_path))
|
| 59 |
-
return midi_path
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
@pytest.fixture
|
| 63 |
-
def mock_basic_pitch_midi(temp_storage):
|
| 64 |
-
"""Create a mock MIDI file that basic-pitch would return."""
|
| 65 |
-
midi_path = temp_storage / "basic_pitch_output.mid"
|
| 66 |
-
|
| 67 |
-
# Create a simple MIDI file with one note
|
| 68 |
-
mid = mido.MidiFile()
|
| 69 |
-
track = mido.MidiTrack()
|
| 70 |
-
mid.tracks.append(track)
|
| 71 |
-
|
| 72 |
-
track.append(mido.Message('note_on', note=62, velocity=70, time=0))
|
| 73 |
-
track.append(mido.Message('note_off', note=62, velocity=0, time=480))
|
| 74 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 75 |
-
|
| 76 |
-
mid.save(str(midi_path))
|
| 77 |
-
return midi_path
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class TestYourMT3Integration:
|
| 81 |
-
"""Test suite for YourMT3+ transcription service integration."""
|
| 82 |
-
|
| 83 |
-
def test_yourmt3_enabled_by_default(self):
|
| 84 |
-
"""Test that YourMT3+ is enabled by default in config."""
|
| 85 |
-
config = Settings()
|
| 86 |
-
assert config.use_yourmt3_transcription is True
|
| 87 |
-
|
| 88 |
-
def test_yourmt3_service_health_check(self, temp_storage):
|
| 89 |
-
"""Test YourMT3+ service health check endpoint."""
|
| 90 |
-
config = Settings(use_yourmt3_transcription=True)
|
| 91 |
-
pipeline = TranscriptionPipeline(
|
| 92 |
-
job_id="test_health",
|
| 93 |
-
youtube_url="https://youtube.com/test",
|
| 94 |
-
storage_path=temp_storage,
|
| 95 |
-
config=config
|
| 96 |
-
)
|
| 97 |
-
|
| 98 |
-
with patch('requests.get') as mock_get:
|
| 99 |
-
# Mock successful health check
|
| 100 |
-
mock_response = Mock()
|
| 101 |
-
mock_response.json.return_value = {
|
| 102 |
-
"status": "healthy",
|
| 103 |
-
"model_loaded": True,
|
| 104 |
-
"device": "mps"
|
| 105 |
-
}
|
| 106 |
-
mock_response.raise_for_status = Mock()
|
| 107 |
-
mock_get.return_value = mock_response
|
| 108 |
-
|
| 109 |
-
# Call transcribe_with_yourmt3 (which includes health check)
|
| 110 |
-
with patch('requests.post') as mock_post:
|
| 111 |
-
mock_post_response = Mock()
|
| 112 |
-
mock_post_response.content = b"mock midi data"
|
| 113 |
-
mock_post.return_value = mock_post_response
|
| 114 |
-
|
| 115 |
-
with patch('builtins.open', create=True):
|
| 116 |
-
with patch('pathlib.Path.exists', return_value=True):
|
| 117 |
-
# This would fail in real scenario, but we're testing health check
|
| 118 |
-
try:
|
| 119 |
-
pipeline.transcribe_with_yourmt3(temp_storage / "test.wav")
|
| 120 |
-
except:
|
| 121 |
-
pass # Expected to fail, we just want to verify health check was called
|
| 122 |
-
|
| 123 |
-
# Verify health check was called
|
| 124 |
-
assert mock_get.called
|
| 125 |
-
assert "/health" in str(mock_get.call_args)
|
| 126 |
-
|
| 127 |
-
def test_yourmt3_transcription_success(self, temp_storage, test_audio_file, mock_yourmt3_midi):
|
| 128 |
-
"""Test successful YourMT3+ transcription."""
|
| 129 |
-
config = Settings(use_yourmt3_transcription=True)
|
| 130 |
-
pipeline = TranscriptionPipeline(
|
| 131 |
-
job_id="test_success",
|
| 132 |
-
youtube_url="https://youtube.com/test",
|
| 133 |
-
storage_path=temp_storage,
|
| 134 |
-
config=config
|
| 135 |
-
)
|
| 136 |
-
|
| 137 |
-
with patch('requests.get') as mock_get:
|
| 138 |
-
# Mock successful health check
|
| 139 |
-
mock_health = Mock()
|
| 140 |
-
mock_health.json.return_value = {"status": "healthy", "model_loaded": True}
|
| 141 |
-
mock_health.raise_for_status = Mock()
|
| 142 |
-
mock_get.return_value = mock_health
|
| 143 |
-
|
| 144 |
-
with patch('requests.post') as mock_post:
|
| 145 |
-
# Mock successful transcription
|
| 146 |
-
with open(mock_yourmt3_midi, 'rb') as f:
|
| 147 |
-
mock_midi_data = f.read()
|
| 148 |
-
|
| 149 |
-
mock_response = Mock()
|
| 150 |
-
mock_response.content = mock_midi_data
|
| 151 |
-
mock_post.return_value = mock_response
|
| 152 |
-
|
| 153 |
-
result = pipeline.transcribe_with_yourmt3(test_audio_file)
|
| 154 |
-
|
| 155 |
-
assert result.exists()
|
| 156 |
-
assert result.suffix == '.mid'
|
| 157 |
-
|
| 158 |
-
# Verify MIDI file is valid
|
| 159 |
-
mid = mido.MidiFile(result)
|
| 160 |
-
assert len(mid.tracks) > 0
|
| 161 |
-
|
| 162 |
-
def test_yourmt3_fallback_on_service_error(self, temp_storage, test_audio_file):
|
| 163 |
-
"""Test fallback to basic-pitch when YourMT3+ service fails."""
|
| 164 |
-
config = Settings(use_yourmt3_transcription=True)
|
| 165 |
-
pipeline = TranscriptionPipeline(
|
| 166 |
-
job_id="test_fallback",
|
| 167 |
-
youtube_url="https://youtube.com/test",
|
| 168 |
-
storage_path=temp_storage,
|
| 169 |
-
config=config
|
| 170 |
-
)
|
| 171 |
-
|
| 172 |
-
with patch('requests.get') as mock_get:
|
| 173 |
-
# Mock health check failure
|
| 174 |
-
mock_get.side_effect = Exception("Service unavailable")
|
| 175 |
-
|
| 176 |
-
with patch('basic_pitch.inference.predict_and_save') as mock_bp:
|
| 177 |
-
# Mock basic-pitch creating a MIDI file
|
| 178 |
-
def create_basic_pitch_midi(*args, **kwargs):
|
| 179 |
-
output_dir = Path(kwargs['output_directory'])
|
| 180 |
-
audio_path = Path(kwargs['audio_path_list'][0])
|
| 181 |
-
midi_path = output_dir / f"{audio_path.stem}_basic_pitch.mid"
|
| 182 |
-
|
| 183 |
-
# Create simple MIDI
|
| 184 |
-
mid = mido.MidiFile()
|
| 185 |
-
track = mido.MidiTrack()
|
| 186 |
-
mid.tracks.append(track)
|
| 187 |
-
track.append(mido.Message('note_on', note=64, velocity=75, time=0))
|
| 188 |
-
track.append(mido.Message('note_off', note=64, velocity=0, time=480))
|
| 189 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 190 |
-
mid.save(str(midi_path))
|
| 191 |
-
|
| 192 |
-
mock_bp.side_effect = create_basic_pitch_midi
|
| 193 |
-
|
| 194 |
-
# This should use basic-pitch as fallback
|
| 195 |
-
result = pipeline.transcribe_to_midi(
|
| 196 |
-
audio_path=test_audio_file
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
assert result.exists()
|
| 200 |
-
assert result.suffix == '.mid'
|
| 201 |
-
|
| 202 |
-
# Verify basic-pitch was called
|
| 203 |
-
assert mock_bp.called
|
| 204 |
-
|
| 205 |
-
def test_yourmt3_disabled_uses_basic_pitch(self, temp_storage, test_audio_file):
|
| 206 |
-
"""Test that basic-pitch is used when YourMT3+ is disabled."""
|
| 207 |
-
config = Settings(use_yourmt3_transcription=False)
|
| 208 |
-
pipeline = TranscriptionPipeline(
|
| 209 |
-
job_id="test_disabled",
|
| 210 |
-
youtube_url="https://youtube.com/test",
|
| 211 |
-
storage_path=temp_storage,
|
| 212 |
-
config=config
|
| 213 |
-
)
|
| 214 |
-
|
| 215 |
-
with patch('basic_pitch.inference.predict_and_save') as mock_bp:
|
| 216 |
-
# Mock basic-pitch creating a MIDI file
|
| 217 |
-
def create_basic_pitch_midi(*args, **kwargs):
|
| 218 |
-
output_dir = Path(kwargs['output_directory'])
|
| 219 |
-
audio_path = Path(kwargs['audio_path_list'][0])
|
| 220 |
-
midi_path = output_dir / f"{audio_path.stem}_basic_pitch.mid"
|
| 221 |
-
|
| 222 |
-
# Create simple MIDI
|
| 223 |
-
mid = mido.MidiFile()
|
| 224 |
-
track = mido.MidiTrack()
|
| 225 |
-
mid.tracks.append(track)
|
| 226 |
-
track.append(mido.Message('note_on', note=65, velocity=78, time=0))
|
| 227 |
-
track.append(mido.Message('note_off', note=65, velocity=0, time=480))
|
| 228 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 229 |
-
mid.save(str(midi_path))
|
| 230 |
-
|
| 231 |
-
mock_bp.side_effect = create_basic_pitch_midi
|
| 232 |
-
|
| 233 |
-
result = pipeline.transcribe_to_midi(
|
| 234 |
-
audio_path=test_audio_file
|
| 235 |
-
)
|
| 236 |
-
|
| 237 |
-
assert result.exists()
|
| 238 |
-
assert result.suffix == '.mid'
|
| 239 |
-
|
| 240 |
-
# Verify basic-pitch was called and YourMT3+ was not
|
| 241 |
-
assert mock_bp.called
|
| 242 |
-
|
| 243 |
-
def test_yourmt3_service_timeout(self, temp_storage, test_audio_file):
|
| 244 |
-
"""Test that timeouts are handled gracefully with fallback."""
|
| 245 |
-
config = Settings(
|
| 246 |
-
use_yourmt3_transcription=True,
|
| 247 |
-
transcription_service_timeout=5
|
| 248 |
-
)
|
| 249 |
-
pipeline = TranscriptionPipeline(
|
| 250 |
-
job_id="test_timeout",
|
| 251 |
-
youtube_url="https://youtube.com/test",
|
| 252 |
-
storage_path=temp_storage,
|
| 253 |
-
config=config
|
| 254 |
-
)
|
| 255 |
-
|
| 256 |
-
import requests
|
| 257 |
-
|
| 258 |
-
with patch('requests.get') as mock_get:
|
| 259 |
-
# Mock health check success
|
| 260 |
-
mock_health = Mock()
|
| 261 |
-
mock_health.json.return_value = {"status": "healthy", "model_loaded": True}
|
| 262 |
-
mock_get.return_value = mock_health
|
| 263 |
-
|
| 264 |
-
with patch('requests.post') as mock_post:
|
| 265 |
-
# Mock timeout
|
| 266 |
-
mock_post.side_effect = requests.exceptions.Timeout()
|
| 267 |
-
|
| 268 |
-
with patch('basic_pitch.inference.predict_and_save') as mock_bp:
|
| 269 |
-
# Mock basic-pitch creating a MIDI file
|
| 270 |
-
def create_basic_pitch_midi(*args, **kwargs):
|
| 271 |
-
output_dir = Path(kwargs['output_directory'])
|
| 272 |
-
audio_path = Path(kwargs['audio_path_list'][0])
|
| 273 |
-
midi_path = output_dir / f"{audio_path.stem}_basic_pitch.mid"
|
| 274 |
-
|
| 275 |
-
# Create simple MIDI
|
| 276 |
-
mid = mido.MidiFile()
|
| 277 |
-
track = mido.MidiTrack()
|
| 278 |
-
mid.tracks.append(track)
|
| 279 |
-
track.append(mido.Message('note_on', note=66, velocity=80, time=0))
|
| 280 |
-
track.append(mido.Message('note_off', note=66, velocity=0, time=480))
|
| 281 |
-
track.append(mido.MetaMessage('end_of_track', time=0))
|
| 282 |
-
mid.save(str(midi_path))
|
| 283 |
-
|
| 284 |
-
mock_bp.side_effect = create_basic_pitch_midi
|
| 285 |
-
|
| 286 |
-
result = pipeline.transcribe_to_midi(
|
| 287 |
-
audio_path=test_audio_file
|
| 288 |
-
)
|
| 289 |
-
|
| 290 |
-
assert result.exists()
|
| 291 |
-
# Verify fallback to basic-pitch
|
| 292 |
-
assert mock_bp.called
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
if __name__ == "__main__":
|
| 296 |
-
pytest.main([__file__, "-v"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/README.md
CHANGED
|
@@ -28,6 +28,7 @@ This documentation serves as the technical blueprint for implementing Rescored.
|
|
| 28 |
3. [Audio Processing Pipeline](backend/pipeline.md) - Core workflow
|
| 29 |
4. [API Design](backend/api.md)
|
| 30 |
5. [Background Workers](backend/workers.md)
|
|
|
|
| 31 |
|
| 32 |
### For Frontend Engineers
|
| 33 |
1. [Architecture Overview](architecture/overview.md)
|
|
@@ -53,6 +54,7 @@ This documentation serves as the technical blueprint for implementing Rescored.
|
|
| 53 |
- [Audio Processing Pipeline](backend/pipeline.md) - End-to-end audio → notation workflow
|
| 54 |
- [API Design](backend/api.md) - REST endpoints and WebSocket protocol
|
| 55 |
- [Background Workers](backend/workers.md) - Async job processing with Celery
|
|
|
|
| 56 |
|
| 57 |
### Frontend
|
| 58 |
- [Notation Rendering](frontend/notation-rendering.md) - Sheet music display with VexFlow
|
|
|
|
| 28 |
3. [Audio Processing Pipeline](backend/pipeline.md) - Core workflow
|
| 29 |
4. [API Design](backend/api.md)
|
| 30 |
5. [Background Workers](backend/workers.md)
|
| 31 |
+
6. [Testing Guide](backend/testing.md) - Writing and running tests
|
| 32 |
|
| 33 |
### For Frontend Engineers
|
| 34 |
1. [Architecture Overview](architecture/overview.md)
|
|
|
|
| 54 |
- [Audio Processing Pipeline](backend/pipeline.md) - End-to-end audio → notation workflow
|
| 55 |
- [API Design](backend/api.md) - REST endpoints and WebSocket protocol
|
| 56 |
- [Background Workers](backend/workers.md) - Async job processing with Celery
|
| 57 |
+
- [Testing Guide](backend/testing.md) - Backend test suite and best practices
|
| 58 |
|
| 59 |
### Frontend
|
| 60 |
- [Notation Rendering](frontend/notation-rendering.md) - Sheet music display with VexFlow
|
docs/architecture/deployment.md
CHANGED
|
@@ -25,17 +25,17 @@ graph TB
|
|
| 25 |
### Setup Requirements
|
| 26 |
|
| 27 |
**Hardware**:
|
| 28 |
-
- **GPU**: NVIDIA GPU with
|
| 29 |
-
- Alternative: Run on CPU (10-
|
| 30 |
- **RAM**: 16GB+ recommended
|
| 31 |
- **Disk**: 10GB for models and temp files
|
| 32 |
|
| 33 |
**Software**:
|
| 34 |
-
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
|
| 40 |
### Docker Compose Setup (Recommended)
|
| 41 |
|
|
@@ -107,44 +107,75 @@ volumes:
|
|
| 107 |
- Slower hot reload than native
|
| 108 |
- GPU support requires Docker Desktop on Mac (experimental)
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
### Manual Setup (Alternative)
|
| 111 |
|
| 112 |
-
|
|
|
|
|
|
|
| 113 |
```bash
|
| 114 |
-
redis
|
|
|
|
| 115 |
```
|
| 116 |
|
| 117 |
**Terminal 2 - Backend API**:
|
| 118 |
```bash
|
| 119 |
cd backend
|
| 120 |
-
uv venv
|
| 121 |
source .venv/bin/activate
|
| 122 |
-
|
| 123 |
-
uvicorn main:app --reload --port 8000
|
| 124 |
```
|
| 125 |
|
| 126 |
**Terminal 3 - Celery Worker**:
|
| 127 |
```bash
|
| 128 |
cd backend
|
| 129 |
source .venv/bin/activate
|
| 130 |
-
|
|
|
|
| 131 |
```
|
| 132 |
|
| 133 |
**Terminal 4 - Frontend**:
|
| 134 |
```bash
|
| 135 |
cd frontend
|
| 136 |
-
npm install
|
| 137 |
npm run dev
|
| 138 |
```
|
| 139 |
|
| 140 |
**Benefits**:
|
| 141 |
-
-
|
| 142 |
-
-
|
| 143 |
-
-
|
| 144 |
|
| 145 |
**Limitations**:
|
| 146 |
-
- Managing
|
| 147 |
-
-
|
| 148 |
|
| 149 |
---
|
| 150 |
|
|
|
|
| 25 |
### Setup Requirements
|
| 26 |
|
| 27 |
**Hardware**:
|
| 28 |
+
- **GPU**: Apple Silicon (M1/M2/M3/M4 with MPS) OR NVIDIA GPU with 4GB+ VRAM
|
| 29 |
+
- Alternative: Run on CPU (10-15x slower, acceptable for development)
|
| 30 |
- **RAM**: 16GB+ recommended
|
| 31 |
- **Disk**: 10GB for models and temp files
|
| 32 |
|
| 33 |
**Software**:
|
| 34 |
+
- **Python 3.10** (required for madmom compatibility)
|
| 35 |
+
- **Node.js 18+**
|
| 36 |
+
- **Redis 7+**
|
| 37 |
+
- **FFmpeg**
|
| 38 |
+
- **YouTube cookies** (required as of December 2024)
|
| 39 |
|
| 40 |
### Docker Compose Setup (Recommended)
|
| 41 |
|
|
|
|
| 107 |
- Slower hot reload than native
|
| 108 |
- GPU support requires Docker Desktop on Mac (experimental)
|
| 109 |
|
| 110 |
+
### Quick Start (Recommended)
|
| 111 |
+
|
| 112 |
+
Use the provided shell scripts to start/stop all services:
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
# From project root
|
| 116 |
+
./start.sh
|
| 117 |
+
|
| 118 |
+
# View logs
|
| 119 |
+
tail -f logs/api.log # Backend API
|
| 120 |
+
tail -f logs/worker.log # Celery worker
|
| 121 |
+
tail -f logs/frontend.log # Frontend
|
| 122 |
+
|
| 123 |
+
# Stop all services
|
| 124 |
+
./stop.sh
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
**What `start.sh` does:**
|
| 128 |
+
1. Starts Redis (if not already running via Homebrew)
|
| 129 |
+
2. Activates Python 3.10 venv
|
| 130 |
+
3. Starts Backend API (uvicorn) in background
|
| 131 |
+
4. Starts Celery Worker (--pool=solo for macOS) in background
|
| 132 |
+
5. Starts Frontend (npm run dev) in background
|
| 133 |
+
6. Writes all logs to `logs/` directory
|
| 134 |
+
|
| 135 |
+
**Services available at:**
|
| 136 |
+
- Frontend: http://localhost:5173
|
| 137 |
+
- Backend API: http://localhost:8000
|
| 138 |
+
- API Docs: http://localhost:8000/docs
|
| 139 |
+
|
| 140 |
### Manual Setup (Alternative)
|
| 141 |
|
| 142 |
+
If you prefer to run services manually in separate terminals:
|
| 143 |
+
|
| 144 |
+
**Terminal 1 - Redis (macOS with Homebrew)**:
|
| 145 |
```bash
|
| 146 |
+
brew services start redis
|
| 147 |
+
redis-cli ping # Should return PONG
|
| 148 |
```
|
| 149 |
|
| 150 |
**Terminal 2 - Backend API**:
|
| 151 |
```bash
|
| 152 |
cd backend
|
|
|
|
| 153 |
source .venv/bin/activate
|
| 154 |
+
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
|
|
|
| 155 |
```
|
| 156 |
|
| 157 |
**Terminal 3 - Celery Worker**:
|
| 158 |
```bash
|
| 159 |
cd backend
|
| 160 |
source .venv/bin/activate
|
| 161 |
+
# Use --pool=solo on macOS to avoid fork() crashes with ML libraries
|
| 162 |
+
celery -A tasks worker --loglevel=info --pool=solo
|
| 163 |
```
|
| 164 |
|
| 165 |
**Terminal 4 - Frontend**:
|
| 166 |
```bash
|
| 167 |
cd frontend
|
|
|
|
| 168 |
npm run dev
|
| 169 |
```
|
| 170 |
|
| 171 |
**Benefits**:
|
| 172 |
+
- Easier debugging (separate terminal per service)
|
| 173 |
+
- More control over each service
|
| 174 |
+
- See output in real-time
|
| 175 |
|
| 176 |
**Limitations**:
|
| 177 |
+
- Managing 4 terminals
|
| 178 |
+
- Need to manually stop each service
|
| 179 |
|
| 180 |
---
|
| 181 |
|
docs/backend/testing.md
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Backend Testing Guide
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
The backend test suite ensures the reliability of the audio processing pipeline, API endpoints, Celery tasks, and utility functions. All tests are written using pytest and can be run locally or in CI/CD pipelines.
|
| 6 |
+
|
| 7 |
+
## Test Structure
|
| 8 |
+
|
| 9 |
+
```
|
| 10 |
+
backend/tests/
|
| 11 |
+
├── conftest.py # Shared fixtures and test configuration
|
| 12 |
+
├── test_api.py # API endpoint tests (21 tests)
|
| 13 |
+
├── test_pipeline.py # Pipeline component tests (14 tests)
|
| 14 |
+
├── test_tasks.py # Celery task tests (9 tests)
|
| 15 |
+
└── test_utils.py # Utility function tests (15 tests)
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
**Total: 59 tests, 27% code coverage**
|
| 19 |
+
|
| 20 |
+
## Running Tests
|
| 21 |
+
|
| 22 |
+
### Quick Start
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
cd backend
|
| 26 |
+
source .venv/bin/activate
|
| 27 |
+
|
| 28 |
+
# Run all tests
|
| 29 |
+
pytest
|
| 30 |
+
|
| 31 |
+
# Run with coverage report
|
| 32 |
+
pytest --cov=. --cov-report=html
|
| 33 |
+
|
| 34 |
+
# Run specific test file
|
| 35 |
+
pytest tests/test_api.py
|
| 36 |
+
|
| 37 |
+
# Run specific test
|
| 38 |
+
pytest tests/test_api.py::TestRootEndpoint::test_root
|
| 39 |
+
|
| 40 |
+
# Run with verbose output
|
| 41 |
+
pytest -v
|
| 42 |
+
|
| 43 |
+
# Run with short traceback on failures
|
| 44 |
+
pytest --tb=short
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
### Test Categories
|
| 48 |
+
|
| 49 |
+
**API Tests** (`test_api.py`):
|
| 50 |
+
- Root endpoint
|
| 51 |
+
- Health check (Redis connectivity)
|
| 52 |
+
- Transcription submission (validation, rate limiting)
|
| 53 |
+
- Job status queries
|
| 54 |
+
- Score/MIDI downloads
|
| 55 |
+
|
| 56 |
+
**Pipeline Tests** (`test_pipeline.py`):
|
| 57 |
+
- Function imports and callability
|
| 58 |
+
- Pipeline class instantiation
|
| 59 |
+
- Required method availability
|
| 60 |
+
- Progress callback functionality
|
| 61 |
+
|
| 62 |
+
**Task Tests** (`test_tasks.py`):
|
| 63 |
+
- Celery task execution
|
| 64 |
+
- Progress updates
|
| 65 |
+
- Error handling and retries
|
| 66 |
+
- Job not found scenarios
|
| 67 |
+
- Temp file cleanup
|
| 68 |
+
|
| 69 |
+
**Utility Tests** (`test_utils.py`):
|
| 70 |
+
- YouTube URL validation
|
| 71 |
+
- Video availability checks
|
| 72 |
+
- Error handling for invalid inputs
|
| 73 |
+
|
| 74 |
+
## Writing Tests
|
| 75 |
+
|
| 76 |
+
### Test Fixtures
|
| 77 |
+
|
| 78 |
+
Common fixtures are defined in `conftest.py`:
|
| 79 |
+
|
| 80 |
+
```python
|
| 81 |
+
# Temporary storage directory
|
| 82 |
+
def temp_storage_dir():
|
| 83 |
+
"""Create temporary storage directory for tests."""
|
| 84 |
+
|
| 85 |
+
# Mock Redis client
|
| 86 |
+
def mock_redis():
|
| 87 |
+
"""Mock Redis client for testing."""
|
| 88 |
+
|
| 89 |
+
# FastAPI test client
|
| 90 |
+
def test_client(mock_redis, temp_storage_dir):
|
| 91 |
+
"""Create FastAPI test client with mocked dependencies."""
|
| 92 |
+
|
| 93 |
+
# Sample job data
|
| 94 |
+
def sample_job_id():
|
| 95 |
+
"""Generate a sample job ID for testing."""
|
| 96 |
+
|
| 97 |
+
def sample_job_data(sample_job_id):
|
| 98 |
+
"""Sample job data for testing."""
|
| 99 |
+
|
| 100 |
+
# Sample media files
|
| 101 |
+
def sample_audio_file(temp_storage_dir):
|
| 102 |
+
"""Create a sample WAV file for testing."""
|
| 103 |
+
|
| 104 |
+
def sample_midi_file(temp_storage_dir):
|
| 105 |
+
"""Create a sample MIDI file for testing."""
|
| 106 |
+
|
| 107 |
+
def sample_musicxml_content():
|
| 108 |
+
"""Sample MusicXML content for testing."""
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
### Example Test
|
| 112 |
+
|
| 113 |
+
```python
|
| 114 |
+
import pytest
|
| 115 |
+
from unittest.mock import patch, MagicMock
|
| 116 |
+
|
| 117 |
+
class TestTranscriptionPipeline:
|
| 118 |
+
"""Test the transcription pipeline."""
|
| 119 |
+
|
| 120 |
+
@patch('pipeline.TranscriptionPipeline')
|
| 121 |
+
def test_pipeline_runs_successfully(
|
| 122 |
+
self,
|
| 123 |
+
mock_pipeline,
|
| 124 |
+
temp_storage_dir
|
| 125 |
+
):
|
| 126 |
+
"""Test successful pipeline execution."""
|
| 127 |
+
# Setup mock
|
| 128 |
+
mock_instance = MagicMock()
|
| 129 |
+
mock_instance.run.return_value = str(temp_storage_dir / "output.musicxml")
|
| 130 |
+
mock_pipeline.return_value = mock_instance
|
| 131 |
+
|
| 132 |
+
# Execute
|
| 133 |
+
result = mock_instance.run()
|
| 134 |
+
|
| 135 |
+
# Assert
|
| 136 |
+
assert result.endswith("output.musicxml")
|
| 137 |
+
mock_instance.run.assert_called_once()
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
### Mocking Best Practices
|
| 141 |
+
|
| 142 |
+
**1. Mock External Dependencies**
|
| 143 |
+
|
| 144 |
+
Always mock:
|
| 145 |
+
- Redis connections
|
| 146 |
+
- File system operations (when testing logic, not I/O)
|
| 147 |
+
- External API calls (yt-dlp, YourMT3+ service)
|
| 148 |
+
- Time-dependent operations
|
| 149 |
+
|
| 150 |
+
**2. Use Proper Patch Targets**
|
| 151 |
+
|
| 152 |
+
Patch at the point of import, not the definition:
|
| 153 |
+
|
| 154 |
+
```python
|
| 155 |
+
# CORRECT - patch where it's imported
|
| 156 |
+
@patch('main.validate_youtube_url')
|
| 157 |
+
|
| 158 |
+
# WRONG - patch at definition
|
| 159 |
+
@patch('app_utils.validate_youtube_url')
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
**3. Create Real Files for Integration Tests**
|
| 163 |
+
|
| 164 |
+
When testing file operations, create real temp files:
|
| 165 |
+
|
| 166 |
+
```python
|
| 167 |
+
def test_midi_processing(temp_storage_dir):
|
| 168 |
+
midi_file = temp_storage_dir / "test.mid"
|
| 169 |
+
midi_file.write_bytes(b"MThd...") # Create real file
|
| 170 |
+
result = process_midi(midi_file)
|
| 171 |
+
assert result.exists()
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
## Test Coverage
|
| 175 |
+
|
| 176 |
+
Current coverage by module:
|
| 177 |
+
|
| 178 |
+
| Module | Coverage | Notes |
|
| 179 |
+
|--------|----------|-------|
|
| 180 |
+
| app_config.py | 92% | Configuration loading |
|
| 181 |
+
| app_utils.py | 100% | URL validation, video checks |
|
| 182 |
+
| main.py | 55% | API endpoints (some error paths untested) |
|
| 183 |
+
| tasks.py | 91% | Celery task execution |
|
| 184 |
+
| pipeline.py | 5% | Needs integration tests with real ML models |
|
| 185 |
+
| tests/*.py | 100% | Test code itself |
|
| 186 |
+
|
| 187 |
+
**Note**: Low pipeline.py coverage is expected since it requires ML models and GPU. Integration tests should be run separately with real hardware.
|
| 188 |
+
|
| 189 |
+
## Continuous Integration
|
| 190 |
+
|
| 191 |
+
### GitHub Actions Example
|
| 192 |
+
|
| 193 |
+
```yaml
|
| 194 |
+
name: Backend Tests
|
| 195 |
+
|
| 196 |
+
on: [push, pull_request]
|
| 197 |
+
|
| 198 |
+
jobs:
|
| 199 |
+
test:
|
| 200 |
+
runs-on: ubuntu-latest
|
| 201 |
+
|
| 202 |
+
services:
|
| 203 |
+
redis:
|
| 204 |
+
image: redis:7-alpine
|
| 205 |
+
ports:
|
| 206 |
+
- 6379:6379
|
| 207 |
+
|
| 208 |
+
steps:
|
| 209 |
+
- uses: actions/checkout@v3
|
| 210 |
+
|
| 211 |
+
- name: Set up Python 3.10
|
| 212 |
+
uses: actions/setup-python@v4
|
| 213 |
+
with:
|
| 214 |
+
python-version: '3.10'
|
| 215 |
+
|
| 216 |
+
- name: Install dependencies
|
| 217 |
+
run: |
|
| 218 |
+
cd backend
|
| 219 |
+
pip install -r requirements.txt
|
| 220 |
+
|
| 221 |
+
- name: Run tests
|
| 222 |
+
run: |
|
| 223 |
+
cd backend
|
| 224 |
+
pytest --cov=. --cov-report=xml
|
| 225 |
+
|
| 226 |
+
- name: Upload coverage
|
| 227 |
+
uses: codecov/codecov-action@v3
|
| 228 |
+
with:
|
| 229 |
+
file: ./backend/coverage.xml
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
## Common Testing Patterns
|
| 233 |
+
|
| 234 |
+
### Testing API Endpoints
|
| 235 |
+
|
| 236 |
+
```python
|
| 237 |
+
def test_submit_transcription(test_client, mock_redis):
|
| 238 |
+
"""Test transcription submission."""
|
| 239 |
+
response = test_client.post(
|
| 240 |
+
"/api/v1/transcribe",
|
| 241 |
+
json={"youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}
|
| 242 |
+
)
|
| 243 |
+
assert response.status_code == 201
|
| 244 |
+
assert "job_id" in response.json()
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
### Testing Celery Tasks
|
| 248 |
+
|
| 249 |
+
```python
|
| 250 |
+
@patch('tasks.TranscriptionPipeline')
|
| 251 |
+
@patch('tasks.redis_client')
|
| 252 |
+
def test_task_execution(mock_redis, mock_pipeline):
|
| 253 |
+
"""Test Celery task executes successfully."""
|
| 254 |
+
from tasks import process_transcription_task
|
| 255 |
+
|
| 256 |
+
# Setup
|
| 257 |
+
mock_redis.hgetall.return_value = {
|
| 258 |
+
'job_id': 'test-123',
|
| 259 |
+
'youtube_url': 'https://youtube.com/watch?v=test'
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
# Execute
|
| 263 |
+
process_transcription_task('test-123')
|
| 264 |
+
|
| 265 |
+
# Verify
|
| 266 |
+
assert mock_pipeline.called
|
| 267 |
+
```
|
| 268 |
+
|
| 269 |
+
### Testing File Operations
|
| 270 |
+
|
| 271 |
+
```python
|
| 272 |
+
def test_file_cleanup(temp_storage_dir):
|
| 273 |
+
"""Test temporary files are cleaned up."""
|
| 274 |
+
temp_file = temp_storage_dir / "temp.wav"
|
| 275 |
+
temp_file.write_bytes(b"test")
|
| 276 |
+
|
| 277 |
+
cleanup_temp_files(temp_storage_dir)
|
| 278 |
+
|
| 279 |
+
assert not temp_file.exists()
|
| 280 |
+
```
|
| 281 |
+
|
| 282 |
+
## Troubleshooting Tests
|
| 283 |
+
|
| 284 |
+
### Common Issues
|
| 285 |
+
|
| 286 |
+
**1. Import Errors**
|
| 287 |
+
|
| 288 |
+
```bash
|
| 289 |
+
# Make sure you're in the venv
|
| 290 |
+
source .venv/bin/activate
|
| 291 |
+
|
| 292 |
+
# Verify pytest is installed
|
| 293 |
+
pytest --version
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
**2. Redis Connection Errors**
|
| 297 |
+
|
| 298 |
+
Tests mock Redis by default. If you see connection errors:
|
| 299 |
+
|
| 300 |
+
```python
|
| 301 |
+
# Check conftest.py has mock_redis fixture
|
| 302 |
+
# Ensure test uses the fixture:
|
| 303 |
+
def test_something(mock_redis): # Add this parameter
|
| 304 |
+
...
|
| 305 |
+
```
|
| 306 |
+
|
| 307 |
+
**3. File Permission Errors**
|
| 308 |
+
|
| 309 |
+
Temp directories should be writable:
|
| 310 |
+
|
| 311 |
+
```python
|
| 312 |
+
# Use the temp_storage_dir fixture
|
| 313 |
+
def test_something(temp_storage_dir):
|
| 314 |
+
file_path = temp_storage_dir / "test.txt"
|
| 315 |
+
file_path.write_text("content")
|
| 316 |
+
```
|
| 317 |
+
|
| 318 |
+
**4. Async Test Errors**
|
| 319 |
+
|
| 320 |
+
For async tests, use pytest-asyncio:
|
| 321 |
+
|
| 322 |
+
```python
|
| 323 |
+
import pytest
|
| 324 |
+
|
| 325 |
+
@pytest.mark.asyncio
|
| 326 |
+
async def test_async_function():
|
| 327 |
+
result = await some_async_function()
|
| 328 |
+
assert result is not None
|
| 329 |
+
```
|
| 330 |
+
|
| 331 |
+
## Test Performance
|
| 332 |
+
|
| 333 |
+
**Running all tests**: ~5-10 seconds
|
| 334 |
+
- API tests: ~2 seconds
|
| 335 |
+
- Pipeline tests: <1 second
|
| 336 |
+
- Task tests: ~2 seconds
|
| 337 |
+
- Utils tests: <1 second
|
| 338 |
+
|
| 339 |
+
**Tips for faster tests**:
|
| 340 |
+
- Mock expensive operations (ML inference, file I/O)
|
| 341 |
+
- Use `pytest -n auto` for parallel execution (requires pytest-xdist)
|
| 342 |
+
- Run specific test files during development
|
| 343 |
+
|
| 344 |
+
## Future Improvements
|
| 345 |
+
|
| 346 |
+
**Needed Tests**:
|
| 347 |
+
1. Integration tests with real YourMT3+ model
|
| 348 |
+
2. End-to-end tests with actual YouTube videos
|
| 349 |
+
3. Performance benchmarks
|
| 350 |
+
4. Load testing for concurrent jobs
|
| 351 |
+
5. WebSocket connection tests
|
| 352 |
+
6. MIDI quantization edge cases
|
| 353 |
+
7. MusicXML generation validation
|
| 354 |
+
|
| 355 |
+
**Coverage Goals**:
|
| 356 |
+
- Increase pipeline.py to 40% (integration tests)
|
| 357 |
+
- Increase main.py to 80% (all error paths)
|
| 358 |
+
- Add performance regression tests
|
| 359 |
+
|
| 360 |
+
## References
|
| 361 |
+
|
| 362 |
+
- [pytest Documentation](https://docs.pytest.org/)
|
| 363 |
+
- [pytest-asyncio](https://pytest-asyncio.readthedocs.io/)
|
| 364 |
+
- [unittest.mock](https://docs.python.org/3/library/unittest.mock.html)
|
| 365 |
+
- [FastAPI Testing](https://fastapi.tiangolo.com/tutorial/testing/)
|
docs/getting-started.md
CHANGED
|
@@ -164,23 +164,44 @@ graph TB
|
|
| 164 |
|
| 165 |
## Setting Up Local Development
|
| 166 |
|
| 167 |
-
See [
|
| 168 |
|
| 169 |
```bash
|
| 170 |
# Clone repo
|
| 171 |
git clone https://github.com/yourusername/rescored.git
|
| 172 |
cd rescored
|
| 173 |
|
| 174 |
-
#
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
# Services:
|
| 178 |
# - Frontend: http://localhost:5173
|
| 179 |
-
# - API: http://localhost:8000
|
| 180 |
-
# -
|
| 181 |
-
# -
|
|
|
|
|
|
|
|
|
|
| 182 |
```
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
---
|
| 185 |
|
| 186 |
## Common Questions
|
|
|
|
| 164 |
|
| 165 |
## Setting Up Local Development
|
| 166 |
|
| 167 |
+
See the [main README](../README.md) for detailed setup instructions. Quick start:
|
| 168 |
|
| 169 |
```bash
|
| 170 |
# Clone repo
|
| 171 |
git clone https://github.com/yourusername/rescored.git
|
| 172 |
cd rescored
|
| 173 |
|
| 174 |
+
# Setup backend (Python 3.10)
|
| 175 |
+
cd backend
|
| 176 |
+
python3.10 -m venv .venv
|
| 177 |
+
source .venv/bin/activate
|
| 178 |
+
pip install -r requirements.txt
|
| 179 |
+
|
| 180 |
+
# Setup frontend
|
| 181 |
+
cd ../frontend
|
| 182 |
+
npm install
|
| 183 |
+
|
| 184 |
+
# Start all services (from project root)
|
| 185 |
+
cd ..
|
| 186 |
+
./start.sh
|
| 187 |
|
| 188 |
# Services:
|
| 189 |
# - Frontend: http://localhost:5173
|
| 190 |
+
# - Backend API: http://localhost:8000
|
| 191 |
+
# - API Docs: http://localhost:8000/docs
|
| 192 |
+
# - Redis: localhost:6379 (must be running: brew services start redis)
|
| 193 |
+
|
| 194 |
+
# Stop all services
|
| 195 |
+
./stop.sh
|
| 196 |
```
|
| 197 |
|
| 198 |
+
**Requirements:**
|
| 199 |
+
- Python 3.10 (for madmom compatibility)
|
| 200 |
+
- Node.js 18+
|
| 201 |
+
- Redis 7+
|
| 202 |
+
- FFmpeg
|
| 203 |
+
- YouTube cookies (see README for setup)
|
| 204 |
+
|
| 205 |
---
|
| 206 |
|
| 207 |
## Common Questions
|