calebhan committed on
Commit
44a2550
·
1 Parent(s): c27ae8d

mvp scope

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitignore +31 -1
  2. README.md +264 -1
  3. backend/.env.example +16 -0
  4. backend/Dockerfile +25 -0
  5. backend/Dockerfile.worker +33 -0
  6. backend/celery_app.py +41 -0
  7. backend/config.py +50 -0
  8. backend/main.py +407 -0
  9. backend/pipeline.py +881 -0
  10. backend/pytest.ini +45 -0
  11. backend/requirements-test.txt +14 -0
  12. backend/requirements.txt +35 -0
  13. backend/scripts/README.md +184 -0
  14. backend/scripts/analyze_transcription.py +175 -0
  15. backend/scripts/diagnose_pipeline.py +307 -0
  16. backend/scripts/test_accuracy.py +277 -0
  17. backend/scripts/test_demucs_models.py +199 -0
  18. backend/scripts/test_e2e.py +106 -0
  19. backend/scripts/test_quick_verify.py +142 -0
  20. backend/tasks.py +205 -0
  21. backend/tests/__init__.py +1 -0
  22. backend/tests/conftest.py +169 -0
  23. backend/tests/test_api.py +369 -0
  24. backend/tests/test_pipeline.py +102 -0
  25. backend/tests/test_tasks.py +243 -0
  26. backend/tests/test_utils.py +147 -0
  27. backend/utils.py +79 -0
  28. docker-compose.yml +79 -0
  29. docs/testing/backend-testing.md +520 -0
  30. docs/testing/baseline-accuracy.md +178 -0
  31. docs/testing/failure-modes.md +216 -0
  32. docs/testing/frontend-testing.md +653 -0
  33. docs/testing/overview.md +315 -0
  34. docs/testing/test-videos.md +371 -0
  35. frontend/.env.example +1 -0
  36. frontend/.gitignore +24 -0
  37. frontend/Dockerfile +19 -0
  38. frontend/README.md +73 -0
  39. frontend/eslint.config.js +23 -0
  40. frontend/index.html +13 -0
  41. frontend/package-lock.json +0 -0
  42. frontend/package.json +44 -0
  43. frontend/public/vite.svg +1 -0
  44. frontend/scripts/debug-parser.cjs +58 -0
  45. frontend/scripts/test-chord-handling.cjs +42 -0
  46. frontend/src/App.css +30 -0
  47. frontend/src/App.tsx +36 -0
  48. frontend/src/api/client.ts +143 -0
  49. frontend/src/assets/react.svg +1 -0
  50. frontend/src/components/JobSubmission.css +83 -0
.gitignore CHANGED
@@ -156,6 +156,7 @@ ENV/
156
  env.bak/
157
  venv.bak/
158
  CLAUDE.md
 
159
 
160
  # Spyder project settings
161
  .spyderproject
@@ -213,4 +214,33 @@ marimo/_lsp/
213
  __marimo__/
214
 
215
  # Streamlit
216
- .streamlit/secrets.toml
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  env.bak/
157
  venv.bak/
158
  CLAUDE.md
159
+ .claude
160
 
161
  # Spyder project settings
162
  .spyderproject
 
214
  __marimo__/
215
 
216
  # Streamlit
217
+ .streamlit/secrets.toml
218
+
219
+ # Rescored specific
220
+ # Backend
221
+ backend/.env
222
+ backend/storage/
223
+ backend/*.musicxml
224
+ backend/*.mid
225
+ backend/*.wav
226
+
227
+ # Frontend
228
+ frontend/node_modules/
229
+ frontend/dist/
230
+ frontend/.env.local
231
+ frontend/.env.production
232
+
233
+ # Storage (contains sensitive cookies)
234
+ storage/*.txt
235
+ storage/*.json
236
+ storage/youtube_cookies*
237
+ !storage/README.txt
238
+ storage/outputs/*
239
+ storage/temp/*
240
+
241
+ # Temp files
242
+ /tmp/
243
+ *.tmp
244
+
245
+ # Docker volumes
246
+ docker-compose.override.yml
README.md CHANGED
@@ -1 +1,264 @@
1
- # rescored
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Rescored - AI Music Transcription
2
+
3
+ Convert YouTube videos into editable sheet music using AI.
4
+
5
+ ## Overview
6
+
7
+ Rescored transcribes YouTube videos to professional-quality music notation:
8
+ 1. **Submit** a YouTube URL
9
+ 2. **AI Processing** extracts audio, separates instruments, and transcribes to MIDI
10
+ 3. **Edit** the notation in an interactive editor
11
+ 4. **Export** as MusicXML or MIDI
12
+
13
+ **Tech Stack**:
14
+ - **Backend**: Python/FastAPI + Celery + Redis
15
+ - **Frontend**: React + VexFlow (notation) + Tone.js (playback)
16
+ - **ML**: Demucs (source separation) + basic-pitch (transcription)
17
+
18
+ ## Quick Start
19
+
20
+ ### Prerequisites
21
+
22
+ - **Docker Desktop** (recommended) OR:
23
+ - Python 3.11+
24
+ - Node.js 18+
25
+ - Redis 7+
26
+ - FFmpeg
27
+ - (Optional) NVIDIA GPU with CUDA for faster processing
28
+
29
+ ### Option 1: Docker Compose (Recommended)
30
+
31
+ ```bash
32
+ # Clone repository
33
+ git clone https://github.com/yourusername/rescored.git
34
+ cd rescored
35
+ ```
36
+
37
+ #### ⚠️ REQUIRED: YouTube Cookies Setup
38
+
39
+ YouTube requires authentication for video downloads (as of December 2024). You **MUST** export your YouTube cookies before the application will work.
40
+
41
+ **Quick Setup (5 minutes):**
42
+
43
+ 1. **Install Browser Extension**
44
+ - Install [Get cookies.txt LOCALLY](https://chrome.google.com/webstore/detail/cclelndahbckbenkjhflpdbgdldlbecc) for Chrome/Edge/Brave
45
+
46
+ 2. **Export Cookies**
47
+ - Open a **NEW private/incognito window** (this is important!)
48
+ - **Sign in to YouTube** with your Google account
49
+ - **Visit any YouTube video page**
50
+ - **Click the extension icon** in your browser toolbar
51
+ - **Click "Export"** or "Download"
52
+ - **Save the file** to your computer
53
+
54
+ 3. **Place Cookie File**
55
+ ```bash
56
+ # Create storage directory
57
+ mkdir -p storage
58
+
59
+ # Move the exported file (adjust path if needed)
60
+ mv ~/Downloads/youtube.com_cookies.txt ./storage/youtube_cookies.txt
61
+
62
+ # OR on Windows:
63
+ # move %USERPROFILE%\Downloads\youtube.com_cookies.txt storage\youtube_cookies.txt
64
+ ```
65
+
66
+ 4. **Start Services**
67
+ ```bash
68
+ docker-compose up
69
+
70
+ # Services will be available at:
71
+ # - Frontend: http://localhost:5173
72
+ # - Backend API: http://localhost:8000
73
+ # - API Docs: http://localhost:8000/docs
74
+ ```
75
+
76
+ **Verification:**
77
+ ```bash
78
+ docker-compose exec worker ls -lh /app/storage/youtube_cookies.txt
79
+ ```
80
+ You should see the file listed.
81
+
82
+ **Troubleshooting:**
83
+
84
+ - **"Please sign in" error**: Make sure you exported from a private/incognito window. Export fresh cookies (don't reuse old ones). Ensure the file is named exactly `youtube_cookies.txt` and isn't empty.
85
+
86
+ - **File format errors**: The first line should be `# Netscape HTTP Cookie File`. If not, use the browser extension method.
87
+
88
+ - **Cookies expire quickly**: Export from a NEW incognito window each time. You may need to re-export periodically.
89
+
90
+ **Security Note:** ⚠️ Never commit `youtube_cookies.txt` to git (it's already in `.gitignore`). Your cookies contain authentication tokens for your Google account—keep them private!
91
+
92
+ **Why Is This Required?** YouTube implemented bot detection in late 2024 that blocks unauthenticated downloads. Even though our tool is for legitimate transcription purposes, YouTube's systems can't distinguish it from scrapers. By providing your cookies, you're proving you're a real user who has agreed to YouTube's terms of service.
93
+
94
+ ### Option 2: Manual Setup
95
+
96
+ **Backend**:
97
+ ```bash
98
+ cd backend
99
+
100
+ # Create virtual environment
101
+ python3 -m venv venv
102
+ source venv/bin/activate # On Windows: venv\Scripts\activate
103
+
104
+ # Install dependencies
105
+ pip install -r requirements.txt
106
+
107
+ # Copy environment file
108
+ cp .env.example .env
109
+
110
+ # Start Redis (in separate terminal)
111
+ redis-server
112
+
113
+ # Start Celery worker (in separate terminal)
114
+ celery -A tasks worker --loglevel=info
115
+
116
+ # Start API server
117
+ python main.py
118
+ ```
119
+
120
+ **Frontend**:
121
+ ```bash
122
+ cd frontend
123
+
124
+ # Install dependencies
125
+ npm install
126
+
127
+ # Start dev server
128
+ npm run dev
129
+ ```
130
+
131
+ ## Usage
132
+
133
+ 1. Open [http://localhost:5173](http://localhost:5173)
134
+ 2. Paste a YouTube URL (piano music recommended for best results)
135
+ 3. Wait 1-2 minutes for transcription (with GPU) or 10-15 minutes (CPU)
136
+ 4. Edit the notation in the interactive editor
137
+ 5. Export as MusicXML or MIDI
138
+
139
+ ## MVP Features
140
+
141
+ ✅ YouTube URL input and validation
142
+ ✅ Piano-only transcription (MVP limitation)
143
+ ✅ Single staff notation (treble clef)
144
+ ✅ Basic editing: select, delete, add notes
145
+ ✅ Play/pause with tempo control
146
+ ✅ Export MusicXML
147
+
148
+ ### Coming in Phase 2
149
+
150
+ - Multi-instrument transcription
151
+ - Grand staff (treble + bass)
152
+ - Advanced editing (copy/paste, undo/redo)
153
+ - MIDI export
154
+ - PDF export
155
+
156
+ ## Project Structure
157
+
158
+ ```
159
+ rescored/
160
+ ├── backend/ # Python/FastAPI backend
161
+ │ ├── main.py # REST API + WebSocket server
162
+ │ ├── tasks.py # Celery background workers
163
+ │ ├── pipeline.py # Audio processing pipeline
164
+ │ ├── config.py # Configuration
165
+ │ └── requirements.txt # Python dependencies
166
+ ├── frontend/ # React frontend
167
+ │ ├── src/
168
+ │ │ ├── components/ # UI components
169
+ │ │ ├── store/ # Zustand state management
170
+ │ │ └── api/ # API client
171
+ │ └── package.json # Node dependencies
172
+ ├── docs/ # Comprehensive documentation
173
+ └── docker-compose.yml # Docker setup
174
+ ```
175
+
176
+ ## Documentation
177
+
178
+ Comprehensive documentation is available in the [`docs/`](docs/) directory:
179
+
180
+ - [Getting Started](docs/getting-started.md)
181
+ - [Architecture Overview](docs/architecture/overview.md)
182
+ - [Backend Pipeline](docs/backend/pipeline.md)
183
+ - [Frontend Rendering](docs/frontend/notation-rendering.md)
184
+ - [MVP Scope](docs/features/mvp.md)
185
+ - [Known Challenges](docs/research/challenges.md)
186
+
187
+ ## Performance
188
+
189
+ **With GPU (RTX 3080)**:
190
+ - Download: ~10 seconds
191
+ - Source separation: ~45 seconds
192
+ - Transcription: ~5 seconds
193
+ - **Total: ~1-2 minutes**
194
+
195
+ **With CPU**:
196
+ - Download: ~10 seconds
197
+ - Source separation: ~8-10 minutes
198
+ - Transcription: ~30 seconds
199
+ - **Total: ~10-15 minutes**
200
+
201
+ ## Accuracy Expectations
202
+
203
+ Transcription is **70-80% accurate** for simple piano music, **60-70%** for complex pieces. The interactive editor is designed to make fixing errors easy.
204
+
205
+ ## Development
206
+
207
+ ### Running Tests
208
+
209
+ ```bash
210
+ # Backend tests
211
+ cd backend
212
+ pytest
213
+
214
+ # Frontend tests
215
+ cd frontend
216
+ npm test
217
+ ```
218
+
219
+ ### API Documentation
220
+
221
+ Once the backend is running, visit:
222
+ - Swagger UI: [http://localhost:8000/docs](http://localhost:8000/docs)
223
+ - ReDoc: [http://localhost:8000/redoc](http://localhost:8000/redoc)
224
+
225
+ ## Troubleshooting
226
+
227
+ **Worker not processing jobs?**
228
+ - Check Redis is running: `redis-cli ping` (should return PONG)
229
+ - Check worker logs: `docker-compose logs worker`
230
+
231
+ **GPU not detected?**
232
+ - Install NVIDIA Docker runtime
233
+ - Uncomment GPU section in `docker-compose.yml`
234
+ - Set `GPU_ENABLED=true` in `.env`
235
+
236
+ **YouTube download fails?**
237
+ - Video may be age-restricted or private
238
+ - Check yt-dlp is up to date: `pip install -U yt-dlp`
239
+
240
+ ## Contributing
241
+
242
+ See [CLAUDE.md](CLAUDE.md) for development guidelines.
243
+
244
+ ## License
245
+
246
+ MIT License - see [LICENSE](LICENSE) for details.
247
+
248
+ ## Acknowledgments
249
+
250
+ - **Demucs** (Meta AI Research) - Source separation
251
+ - **basic-pitch** (Spotify) - Audio transcription
252
+ - **VexFlow** - Music notation rendering
253
+ - **Tone.js** - Web audio synthesis
254
+
255
+ ## Roadmap
256
+
257
+ - **Phase 1 (MVP)**: ✅ Piano transcription with basic editing
258
+ - **Phase 2**: Multi-instrument, advanced editing, PDF export
259
+ - **Phase 3**: User accounts, cloud storage, collaboration
260
+ - **Phase 4**: Mobile app, real-time collaboration
261
+
262
+ ---
263
+
264
+ **Note**: This is an educational project. Users are responsible for copyright compliance when transcribing YouTube content.
backend/.env.example ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Redis Configuration
2
+ REDIS_URL=redis://localhost:6379/0
3
+
4
+ # Storage Configuration
5
+ STORAGE_PATH=/tmp/rescored
6
+
7
+ # API Configuration
8
+ API_HOST=0.0.0.0
9
+ API_PORT=8000
10
+
11
+ # Worker Configuration
12
+ GPU_ENABLED=true
13
+ MAX_VIDEO_DURATION=900 # 15 minutes in seconds
14
+
15
+ # CORS Origins (comma-separated)
16
+ CORS_ORIGINS=http://localhost:5173,http://localhost:3000
backend/Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ # Install system dependencies
4
+ RUN apt-get update && apt-get install -y \
5
+ ffmpeg \
6
+ git \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
+ # Set working directory
10
+ WORKDIR /app
11
+
12
+ # Copy requirements
13
+ COPY requirements.txt .
14
+
15
+ # Install Python dependencies
16
+ RUN pip install --no-cache-dir -r requirements.txt
17
+
18
+ # Copy application code
19
+ COPY . .
20
+
21
+ # Expose API port
22
+ EXPOSE 8000
23
+
24
+ # Default command (can be overridden in docker-compose)
25
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
backend/Dockerfile.worker ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use NVIDIA CUDA base image for GPU support
2
+ # For CPU-only, use: FROM python:3.11-slim
3
+ FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
4
+
5
+ # Install Python and system dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ python3.11 \
8
+ python3-pip \
9
+ ffmpeg \
10
+ git \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ # Set working directory
14
+ WORKDIR /app
15
+
16
+ # Copy requirements
17
+ COPY requirements.txt .
18
+
19
+ # Install Python dependencies
20
+ RUN pip3 install --no-cache-dir -r requirements.txt
21
+
22
+ # Copy application code
23
+ COPY . .
24
+
25
+ # Create a wrapper script to patch torchaudio to use soundfile backend
26
+ RUN echo '#!/bin/bash\n\
27
+ # Force torchaudio to use soundfile backend\n\
28
+ export TORCHAUDIO_USE_BACKEND_DISPATCHER=0\n\
29
+ exec celery -A tasks worker --loglevel=info --concurrency=1\n\
30
+ ' > /app/start-worker.sh && chmod +x /app/start-worker.sh
31
+
32
+ # Default command
33
+ CMD ["/app/start-worker.sh"]
backend/celery_app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Celery application configuration."""
2
+ from celery import Celery
3
+ from kombu import Exchange, Queue
4
+ from config import settings
5
+
6
+ # Initialize Celery
7
+ celery_app = Celery(
8
+ "rescored",
9
+ broker=settings.redis_url,
10
+ backend=settings.redis_url,
11
+ )
12
+
13
+ # Configuration
14
+ celery_app.conf.update(
15
+ task_serializer="json",
16
+ accept_content=["json"],
17
+ result_serializer="json",
18
+ timezone="UTC",
19
+ enable_utc=True,
20
+
21
+ # Task settings
22
+ task_track_started=True,
23
+ task_time_limit=600, # 10 minutes max per task
24
+ task_soft_time_limit=540, # Soft limit at 9 minutes
25
+ task_acks_late=True, # Acknowledge task after completion (safer)
26
+ worker_prefetch_multiplier=1, # Take 1 task at a time
27
+
28
+ # Retry settings
29
+ task_autoretry_for=(Exception,),
30
+ task_retry_kwargs={'max_retries': 3},
31
+ task_retry_backoff=True, # Exponential backoff
32
+ task_retry_backoff_max=600,
33
+
34
+ # Priority queues
35
+ task_queues=(
36
+ Queue('default', Exchange('default'), routing_key='default', priority=5),
37
+ Queue('high_priority', Exchange('high_priority'), routing_key='high_priority', priority=10),
38
+ ),
39
+ task_default_queue='default',
40
+ task_default_routing_key='default',
41
+ )
backend/config.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration module for Rescored backend."""
2
+ from pydantic_settings import BaseSettings
3
+ from pathlib import Path
4
+
5
+
6
+ class Settings(BaseSettings):
7
+ """Application settings."""
8
+
9
+ # Redis Configuration
10
+ redis_url: str = "redis://localhost:6379/0"
11
+
12
+ # Storage Configuration
13
+ storage_path: Path = Path("/tmp/rescored")
14
+
15
+ # API Configuration
16
+ api_host: str = "0.0.0.0"
17
+ api_port: int = 8000
18
+
19
+ # Worker Configuration
20
+ gpu_enabled: bool = True
21
+ max_video_duration: int = 900 # 15 minutes
22
+
23
+ # CORS Configuration
24
+ cors_origins: str = "http://localhost:5173,http://localhost:3000"
25
+
26
+ class Config:
27
+ env_file = ".env"
28
+ env_file_encoding = "utf-8"
29
+
30
+ @property
31
+ def cors_origins_list(self) -> list[str]:
32
+ """Parse CORS origins as list."""
33
+ return [origin.strip() for origin in self.cors_origins.split(",")]
34
+
35
+ @property
36
+ def temp_audio_path(self) -> Path:
37
+ """Temporary audio storage path."""
38
+ path = self.storage_path / "temp_audio"
39
+ path.mkdir(parents=True, exist_ok=True)
40
+ return path
41
+
42
+ @property
43
+ def outputs_path(self) -> Path:
44
+ """Output files storage path."""
45
+ path = self.storage_path / "outputs"
46
+ path.mkdir(parents=True, exist_ok=True)
47
+ return path
48
+
49
+
50
+ settings = Settings()
backend/main.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI application for Rescored backend."""
2
+ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Request
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi.responses import FileResponse
5
+ from pydantic import BaseModel, HttpUrl
6
+ from uuid import uuid4
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+ from starlette.middleware.base import BaseHTTPMiddleware
10
+ from starlette.responses import JSONResponse
11
+ import redis
12
+ import json
13
+ import asyncio
14
+ from config import settings
15
+ from utils import validate_youtube_url, check_video_availability
16
+ from tasks import process_transcription_task
17
+
18
+ # Initialize FastAPI
19
+ app = FastAPI(
20
+ title="Rescored API",
21
+ description="AI-powered music transcription from YouTube videos",
22
+ version="1.0.0"
23
+ )
24
+
25
+ # Redis client (initialized before middleware)
26
+ redis_client = redis.Redis.from_url(settings.redis_url, decode_responses=True)
27
+
28
+
29
+ # === Rate Limiting Middleware ===
30
+
31
+ class RateLimitMiddleware(BaseHTTPMiddleware):
32
+ """
33
+ Rate limiting middleware to prevent abuse.
34
+
35
+ Limits: 10 transcription jobs per IP per hour (security requirement).
36
+ Uses Redis with sliding window counter.
37
+ """
38
+
39
+ async def dispatch(self, request: Request, call_next):
40
+ # Only rate limit the transcribe endpoint
41
+ if request.url.path == "/api/v1/transcribe" and request.method == "POST":
42
+ # Get client IP (handle proxies)
43
+ client_ip = request.client.host
44
+ if "x-forwarded-for" in request.headers:
45
+ client_ip = request.headers["x-forwarded-for"].split(",")[0].strip()
46
+
47
+ # Redis key for this IP
48
+ rate_limit_key = f"ratelimit:{client_ip}"
49
+
50
+ # Get current count
51
+ current_count = redis_client.get(rate_limit_key)
52
+
53
+ if current_count and int(current_count) >= 10:
54
+ return JSONResponse(
55
+ status_code=429,
56
+ content={
57
+ "detail": "Rate limit exceeded. Maximum 10 transcription jobs per hour per IP."
58
+ }
59
+ )
60
+
61
+ # Increment counter
62
+ pipe = redis_client.pipeline()
63
+ pipe.incr(rate_limit_key)
64
+ pipe.expire(rate_limit_key, 3600) # 1 hour TTL
65
+ pipe.execute()
66
+
67
+ response = await call_next(request)
68
+ return response
69
+
70
+
71
+ # CORS middleware
72
+ app.add_middleware(
73
+ CORSMiddleware,
74
+ allow_origins=settings.cors_origins_list,
75
+ allow_credentials=True,
76
+ allow_methods=["*"],
77
+ allow_headers=["*"],
78
+ )
79
+
80
+ # Rate limiting middleware
81
+ app.add_middleware(RateLimitMiddleware)
82
+
83
+
84
+ # === Request/Response Models ===
85
+
86
+ class TranscribeRequest(BaseModel):
87
+ """Request model for transcription."""
88
+ youtube_url: HttpUrl
89
+ options: dict = {"instruments": ["piano"]}
90
+
91
+
92
+ class TranscribeResponse(BaseModel):
93
+ """Response model for transcription submission."""
94
+ job_id: str
95
+ status: str
96
+ created_at: datetime
97
+ estimated_duration_seconds: int
98
+ websocket_url: str
99
+
100
+
101
+ class JobStatusResponse(BaseModel):
102
+ """Response model for job status."""
103
+ job_id: str
104
+ status: str
105
+ progress: int
106
+ current_stage: str | None
107
+ status_message: str | None
108
+ created_at: str
109
+ started_at: str | None
110
+ completed_at: str | None
111
+ failed_at: str | None
112
+ error: dict | None
113
+ result_url: str | None
114
+
115
+
116
+ # === WebSocket Connection Manager ===
117
+
118
+ class ConnectionManager:
119
+ """Manages WebSocket connections."""
120
+
121
+ def __init__(self):
122
+ self.active_connections: dict[str, list[WebSocket]] = {}
123
+
124
+ async def connect(self, websocket: WebSocket, job_id: str):
125
+ """Accept and register a WebSocket connection."""
126
+ await websocket.accept()
127
+ if job_id not in self.active_connections:
128
+ self.active_connections[job_id] = []
129
+ self.active_connections[job_id].append(websocket)
130
+
131
+ def disconnect(self, websocket: WebSocket, job_id: str):
132
+ """Remove a WebSocket connection."""
133
+ if job_id in self.active_connections:
134
+ self.active_connections[job_id].remove(websocket)
135
+ if not self.active_connections[job_id]:
136
+ del self.active_connections[job_id]
137
+
138
+ async def broadcast(self, job_id: str, message: dict):
139
+ """Broadcast message to all clients connected to a job."""
140
+ if job_id in self.active_connections:
141
+ dead_connections = []
142
+
143
+ for connection in self.active_connections[job_id]:
144
+ try:
145
+ await connection.send_json(message)
146
+ except:
147
+ dead_connections.append(connection)
148
+
149
+ # Clean up dead connections
150
+ for conn in dead_connections:
151
+ self.disconnect(conn, job_id)
152
+
153
+
154
+ manager = ConnectionManager()
155
+
156
+
157
+ # === REST Endpoints ===
158
+
159
+ @app.get("/")
160
+ async def root():
161
+ """Root endpoint."""
162
+ return {
163
+ "name": "Rescored API",
164
+ "version": "1.0.0",
165
+ "docs": "/docs"
166
+ }
167
+
168
+
169
+ @app.post("/api/v1/transcribe", response_model=TranscribeResponse, status_code=201)
170
+ async def submit_transcription(request: TranscribeRequest):
171
+ """
172
+ Submit a YouTube URL for transcription.
173
+
174
+ Args:
175
+ request: Transcription request with YouTube URL
176
+
177
+ Returns:
178
+ Job information including job ID and WebSocket URL
179
+ """
180
+ # Validate YouTube URL
181
+ is_valid, video_id_or_error = validate_youtube_url(str(request.youtube_url))
182
+ if not is_valid:
183
+ raise HTTPException(status_code=400, detail=video_id_or_error)
184
+
185
+ video_id = video_id_or_error
186
+
187
+ # Check video availability
188
+ availability = check_video_availability(video_id, settings.max_video_duration)
189
+ if not availability['available']:
190
+ raise HTTPException(status_code=422, detail=availability['reason'])
191
+
192
+ # Create job
193
+ job_id = str(uuid4())
194
+ job_data = {
195
+ "job_id": job_id,
196
+ "status": "queued",
197
+ "youtube_url": str(request.youtube_url),
198
+ "video_id": video_id,
199
+ "options": json.dumps(request.options),
200
+ "created_at": datetime.utcnow().isoformat(),
201
+ "progress": 0,
202
+ "current_stage": "queued",
203
+ "status_message": "Job queued for processing",
204
+ }
205
+
206
+ # Store in Redis
207
+ redis_client.hset(f"job:{job_id}", mapping=job_data)
208
+
209
+ # Queue Celery task
210
+ process_transcription_task.delay(job_id)
211
+
212
+ return TranscribeResponse(
213
+ job_id=job_id,
214
+ status="queued",
215
+ created_at=datetime.utcnow(),
216
+ estimated_duration_seconds=120,
217
+ websocket_url=f"ws://localhost:{settings.api_port}/api/v1/jobs/{job_id}/stream"
218
+ )
219
+
220
+
221
+ @app.get("/api/v1/jobs/{job_id}", response_model=JobStatusResponse)
222
+ async def get_job_status(job_id: str):
223
+ """
224
+ Get job status.
225
+
226
+ Args:
227
+ job_id: Job identifier
228
+
229
+ Returns:
230
+ Job status information
231
+ """
232
+ job_data = redis_client.hgetall(f"job:{job_id}")
233
+
234
+ if not job_data:
235
+ raise HTTPException(status_code=404, detail="Job not found")
236
+
237
+ # Parse error if present
238
+ error = None
239
+ if 'error' in job_data:
240
+ try:
241
+ error = json.loads(job_data['error'])
242
+ except:
243
+ error = {"message": job_data['error']}
244
+
245
+ # Construct result URL if completed
246
+ result_url = None
247
+ if job_data.get('status') == 'completed':
248
+ result_url = f"/api/v1/scores/{job_id}"
249
+
250
+ return JobStatusResponse(
251
+ job_id=job_id,
252
+ status=job_data.get('status', 'unknown'),
253
+ progress=int(job_data.get('progress', 0)),
254
+ current_stage=job_data.get('current_stage'),
255
+ status_message=job_data.get('status_message'),
256
+ created_at=job_data.get('created_at', ''),
257
+ started_at=job_data.get('started_at'),
258
+ completed_at=job_data.get('completed_at'),
259
+ failed_at=job_data.get('failed_at'),
260
+ error=error,
261
+ result_url=result_url
262
+ )
263
+
264
+
265
+ @app.get("/api/v1/scores/{job_id}")
266
+ async def download_score(job_id: str):
267
+ """
268
+ Download MusicXML score.
269
+
270
+ Args:
271
+ job_id: Job identifier
272
+
273
+ Returns:
274
+ MusicXML file
275
+ """
276
+ job_data = redis_client.hgetall(f"job:{job_id}")
277
+
278
+ if not job_data or job_data.get('status') != 'completed':
279
+ raise HTTPException(status_code=404, detail="Score not available")
280
+
281
+ output_path = job_data.get('output_path')
282
+ if not output_path:
283
+ raise HTTPException(status_code=404, detail="Score file path not found")
284
+
285
+ file_path = Path(output_path)
286
+ if not file_path.exists():
287
+ raise HTTPException(status_code=404, detail="Score file not found")
288
+
289
+ return FileResponse(
290
+ path=file_path,
291
+ media_type="application/vnd.recordare.musicxml+xml",
292
+ filename=f"score_{job_id}.musicxml"
293
+ )
294
+
295
+
296
+ @app.get("/api/v1/scores/{job_id}/midi")
297
+ async def download_midi(job_id: str):
298
+ """
299
+ Download MIDI version of score.
300
+
301
+ For MVP, this returns the cleaned MIDI from transcription (piano_clean.mid).
302
+
303
+ Args:
304
+ job_id: Job identifier
305
+
306
+ Returns:
307
+ MIDI file
308
+ """
309
+ job_data = redis_client.hgetall(f"job:{job_id}")
310
+
311
+ if not job_data or job_data.get('status') != 'completed':
312
+ raise HTTPException(status_code=404, detail="MIDI not available")
313
+
314
+ midi_path_str = job_data.get('midi_path')
315
+ if not midi_path_str:
316
+ raise HTTPException(status_code=404, detail="MIDI file path not found")
317
+
318
+ file_path = Path(midi_path_str)
319
+ if not file_path.exists():
320
+ raise HTTPException(status_code=404, detail="MIDI file not found")
321
+
322
+ return FileResponse(
323
+ path=file_path,
324
+ media_type="audio/midi",
325
+ filename=f"score_{job_id}.mid"
326
+ )
327
+
328
+
329
+ # === WebSocket Endpoint ===
330
+
331
+ @app.websocket("/api/v1/jobs/{job_id}/stream")
332
+ async def websocket_endpoint(websocket: WebSocket, job_id: str):
333
+ """
334
+ WebSocket endpoint for real-time progress updates.
335
+
336
+ Args:
337
+ websocket: WebSocket connection
338
+ job_id: Job identifier
339
+ """
340
+ await manager.connect(websocket, job_id)
341
+
342
+ try:
343
+ # Subscribe to Redis pub/sub for this job
344
+ pubsub = redis_client.pubsub()
345
+ pubsub.subscribe(f"job:{job_id}:updates")
346
+
347
+ # Listen for updates in a separate task
348
+ async def listen_for_updates():
349
+ for message in pubsub.listen():
350
+ if message['type'] == 'message':
351
+ update = json.loads(message['data'])
352
+ await websocket.send_json(update)
353
+
354
+ # Close connection if job completed or failed
355
+ if update.get('type') in ['completed', 'error']:
356
+ break
357
+
358
+ # Send initial status
359
+ job_data = redis_client.hgetall(f"job:{job_id}")
360
+ if job_data:
361
+ initial_update = {
362
+ "type": "progress",
363
+ "job_id": job_id,
364
+ "progress": int(job_data.get('progress', 0)),
365
+ "stage": job_data.get('current_stage', 'queued'),
366
+ "message": job_data.get('status_message', 'Starting...'),
367
+ "timestamp": datetime.utcnow().isoformat(),
368
+ }
369
+ await websocket.send_json(initial_update)
370
+
371
+ # Listen for updates (blocking)
372
+ await listen_for_updates()
373
+
374
+ except WebSocketDisconnect:
375
+ manager.disconnect(websocket, job_id)
376
+ finally:
377
+ pubsub.unsubscribe(f"job:{job_id}:updates")
378
+ pubsub.close()
379
+
380
+
381
+ # === Health Check ===
382
+
383
+ @app.get("/health")
384
+ async def health_check():
385
+ """Health check endpoint."""
386
+ # Check Redis connection
387
+ try:
388
+ redis_client.ping()
389
+ redis_status = "healthy"
390
+ except:
391
+ redis_status = "unhealthy"
392
+
393
+ return {
394
+ "status": "healthy" if redis_status == "healthy" else "degraded",
395
+ "redis": redis_status,
396
+ "storage": str(settings.storage_path)
397
+ }
398
+
399
+
400
+ if __name__ == "__main__":
401
+ import uvicorn
402
+ uvicorn.run(
403
+ "main:app",
404
+ host=settings.api_host,
405
+ port=settings.api_port,
406
+ reload=True
407
+ )
backend/pipeline.py ADDED
@@ -0,0 +1,881 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AI-powered music transcription pipeline.
3
+
4
+ Processes YouTube videos to extract audio, separate sources, transcribe to MIDI,
5
+ and generate MusicXML notation.
6
+ """
7
+ import subprocess
8
+ from pathlib import Path
9
+ import tempfile
10
+ from typing import Optional
11
+ import mido
12
+ import librosa
13
+ from piano_transcription_inference import PianoTranscription, sample_rate
14
+ from music21 import converter, key, meter, tempo, note, clef, stream, chord as m21_chord
15
+
16
+
17
+ class TranscriptionPipeline:
18
+ """Handles the complete transcription workflow."""
19
+
20
+ def __init__(self, job_id: str, youtube_url: str, storage_path: Path):
21
+ self.job_id = job_id
22
+ self.youtube_url = youtube_url
23
+ self.storage_path = storage_path
24
+ self.temp_dir = storage_path / "temp" / job_id
25
+ self.temp_dir.mkdir(parents=True, exist_ok=True)
26
+ self.progress_callback = None
27
+
28
+ # Initialize ByteDance piano transcription model (lazy loading)
29
+ self._transcriptor = None
30
+
31
+ def set_progress_callback(self, callback):
32
+ """Set callback for progress updates: callback(percent, stage, message)"""
33
+ self.progress_callback = callback
34
+
35
+ def progress(self, percent: int, stage: str, message: str):
36
+ """Report progress if callback is set."""
37
+ if self.progress_callback:
38
+ self.progress_callback(percent, stage, message)
39
+
40
+ def run(self) -> Path:
41
+ """
42
+ Execute full pipeline and return path to MusicXML file.
43
+
44
+ Raises:
45
+ Exception: If any stage fails
46
+ """
47
+ try:
48
+ self.progress(0, "download", "Starting audio download")
49
+ audio_path = self.download_audio()
50
+
51
+ self.progress(20, "separate", "Starting source separation")
52
+ stems = self.separate_sources(audio_path)
53
+
54
+ self.progress(50, "transcribe", "Starting MIDI transcription")
55
+ midi_path = self.transcribe_to_midi(stems['other'])
56
+
57
+ self.progress(90, "musicxml", "Generating MusicXML")
58
+ musicxml_path = self.generate_musicxml(midi_path)
59
+
60
+ self.progress(100, "complete", "Transcription complete")
61
+ return musicxml_path
62
+
63
+ except Exception as e:
64
+ self.progress(0, "error", str(e))
65
+ raise
66
+
67
+ def download_audio(self) -> Path:
68
+ """Download audio from YouTube URL using yt-dlp."""
69
+ output_path = self.temp_dir / "audio.wav"
70
+
71
+ cmd = [
72
+ "yt-dlp",
73
+ "-x", # Extract audio
74
+ "--audio-format", "wav",
75
+ "--audio-quality", "0", # Best quality
76
+ "--output", str(output_path.with_suffix('')), # yt-dlp adds .wav
77
+ # Workarounds for YouTube restrictions
78
+ "--extractor-args", "youtube:player_client=android,web",
79
+ "--no-check-certificates",
80
+ self.youtube_url
81
+ ]
82
+
83
+ result = subprocess.run(cmd, capture_output=True, text=True)
84
+
85
+ if result.returncode != 0:
86
+ raise RuntimeError(f"yt-dlp failed: {result.stderr}")
87
+
88
+ if not output_path.exists():
89
+ raise RuntimeError("Audio file not created")
90
+
91
+ return output_path
92
+
93
+ def separate_sources(self, audio_path: Path) -> dict:
94
+ """
95
+ Separate audio into 4 stems using Demucs.
96
+
97
+ Returns:
98
+ dict with keys: drums, bass, vocals, other
99
+ """
100
+ # Run Demucs
101
+ cmd = [
102
+ "demucs",
103
+ "--two-stems=other", # For piano, we only need "other" stem
104
+ "-o", str(self.temp_dir),
105
+ str(audio_path)
106
+ ]
107
+
108
+ result = subprocess.run(cmd, capture_output=True, text=True)
109
+
110
+ if result.returncode != 0:
111
+ raise RuntimeError(f"Demucs failed: {result.stderr}")
112
+
113
+ # Demucs creates: temp/htdemucs/audio/*.wav
114
+ demucs_output = self.temp_dir / "htdemucs" / audio_path.stem
115
+
116
+ stems = {
117
+ 'other': demucs_output / "other.wav",
118
+ 'no_other': demucs_output / "no_other.wav",
119
+ }
120
+
121
+ # Verify output
122
+ if not stems['other'].exists():
123
+ raise RuntimeError("Demucs did not create expected output files")
124
+
125
+ return stems
126
+
127
+ def _get_transcriptor(self):
128
+ """Lazy load ByteDance piano transcription model."""
129
+ if self._transcriptor is None:
130
+ import torch
131
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
132
+ print(f" Loading ByteDance piano transcription model on {device}...")
133
+ self._transcriptor = PianoTranscription(device=device, checkpoint_path=None)
134
+ return self._transcriptor
135
+
136
+ def transcribe_to_midi(self, audio_path: Path) -> Path:
137
+ """
138
+ Transcribe audio to MIDI using ByteDance piano_transcription.
139
+
140
+ Args:
141
+ audio_path: Path to audio file (should be 'other' stem for piano)
142
+
143
+ Returns:
144
+ Path to generated MIDI file
145
+ """
146
+ output_dir = self.temp_dir
147
+ midi_path = output_dir / "piano.mid"
148
+
149
+ # Load audio with librosa (ByteDance expects specific sample rate and mono)
150
+ print(f" Loading audio from {audio_path}...")
151
+ audio, _ = librosa.load(str(audio_path), sr=sample_rate, mono=True)
152
+
153
+ # Get transcriptor (lazy loaded)
154
+ transcriptor = self._get_transcriptor()
155
+
156
+ # Transcribe to MIDI
157
+ print(f" Transcribing with ByteDance model...")
158
+ transcriptor.transcribe(audio, str(midi_path))
159
+
160
+ if not midi_path.exists():
161
+ raise RuntimeError("ByteDance transcription did not create MIDI file")
162
+
163
+ # Post-process MIDI (quantize, clean up)
164
+ cleaned_midi = self.clean_midi(midi_path)
165
+
166
+ return cleaned_midi
167
+
168
+ def clean_midi(self, midi_path: Path) -> Path:
169
+ """
170
+ Clean up MIDI file: filter invalid notes, remove very short notes, light quantization.
171
+
172
+ Args:
173
+ midi_path: Path to raw MIDI file
174
+
175
+ Returns:
176
+ Path to cleaned MIDI file
177
+ """
178
+ mid = mido.MidiFile(midi_path)
179
+
180
+ # First pass: collect all notes with timing info to filter by duration
181
+ for track in mid.tracks:
182
+ absolute_time = 0
183
+ active_notes = {} # note_number -> (start_time, start_msg_index, velocity)
184
+ note_durations = {} # msg_index -> duration_ticks
185
+ messages_with_abs_time = []
186
+
187
+ # Build list of messages with absolute timing
188
+ for msg_idx, msg in enumerate(track):
189
+ absolute_time += msg.time
190
+ messages_with_abs_time.append((msg_idx, msg, absolute_time))
191
+
192
+ if msg.type == 'note_on' and msg.velocity > 0:
193
+ active_notes[msg.note] = (absolute_time, msg_idx, msg.velocity)
194
+ elif msg.type in ['note_off', 'note_on']: # note_on with vel=0 is note_off
195
+ if msg.note in active_notes:
196
+ start_time, start_idx, velocity = active_notes.pop(msg.note)
197
+ duration = absolute_time - start_time
198
+ note_durations[start_idx] = duration
199
+
200
+ # Second pass: filter messages based on criteria
201
+ messages_to_keep = []
202
+ min_duration_ticks = mid.ticks_per_beat // 8 # Minimum 32nd note duration
203
+ min_velocity = 20 # Filter very quiet notes (likely noise)
204
+ notes_to_skip = set() # Track note_on indices to skip
205
+
206
+ # Identify notes to skip based on duration
207
+ for msg_idx in note_durations:
208
+ if note_durations[msg_idx] < min_duration_ticks:
209
+ notes_to_skip.add(msg_idx)
210
+
211
+ for msg_idx, msg, abs_time in messages_with_abs_time:
212
+ # Filter out notes outside piano range (A0 = 21, C8 = 108)
213
+ if hasattr(msg, 'note') and (msg.note < 21 or msg.note > 108):
214
+ continue
215
+
216
+ # Filter very quiet notes (likely false positives)
217
+ if msg.type == 'note_on' and msg.velocity > 0 and msg.velocity < min_velocity:
218
+ notes_to_skip.add(msg_idx)
219
+ continue
220
+
221
+ # Skip notes marked for removal (very short)
222
+ if msg.type == 'note_on' and msg_idx in notes_to_skip:
223
+ continue
224
+
225
+ # Skip note_off for notes we filtered out
226
+ if msg.type in ['note_off', 'note_on'] and hasattr(msg, 'note'):
227
+ # Check if this note_off corresponds to a filtered note_on
228
+ should_skip = False
229
+ for skip_idx in notes_to_skip:
230
+ if skip_idx < msg_idx:
231
+ skip_msg = messages_with_abs_time[skip_idx][1]
232
+ if skip_msg.type == 'note_on' and skip_msg.note == msg.note:
233
+ should_skip = True
234
+ break
235
+ if should_skip and msg.type == 'note_off':
236
+ continue
237
+
238
+ messages_to_keep.append((msg, abs_time))
239
+
240
+ # Third pass: rebuild track with delta times and light quantization
241
+ track.clear()
242
+ previous_time = 0
243
+
244
+ # Use 16th note quantization grid (less aggressive than 8th)
245
+ ticks_per_16th = mid.ticks_per_beat // 4
246
+
247
+ for msg, abs_time in messages_to_keep:
248
+ if msg.type in ['note_on', 'note_off']:
249
+ # Light quantization - only snap if close to grid (within 10%)
250
+ nearest_grid = round(abs_time / ticks_per_16th) * ticks_per_16th
251
+ snap_threshold = ticks_per_16th * 0.1
252
+
253
+ if abs(abs_time - nearest_grid) < snap_threshold:
254
+ abs_time = nearest_grid
255
+
256
+ # Set delta time from previous message
257
+ msg.time = max(0, abs_time - previous_time)
258
+ previous_time = abs_time
259
+ track.append(msg)
260
+
261
+ # Save cleaned MIDI
262
+ cleaned_path = midi_path.with_stem(f"{midi_path.stem}_clean")
263
+ mid.save(cleaned_path)
264
+
265
+ return cleaned_path
266
+
267
    def generate_musicxml(self, midi_path: Path) -> Path:
        """
        Convert MIDI to MusicXML using music21 (single treble staff for MVP).

        Stages: parse MIDI -> detect key -> add 4/4 time signature and tempo ->
        deduplicate overlapping notes -> make measures -> strip impossible
        durations/tuplets -> validate -> write, retrying around music21's
        "2048th duration" export errors by emptying the offending measure.

        Args:
            midi_path: Path to input MIDI file.

        Returns:
            Path to the output MusicXML file (<job_id>.musicxml in temp_dir).

        Raises:
            RuntimeError: if 2048th-note export errors persist after 10 retries.
            Exception: any non-2048th export error is propagated unchanged.
        """
        self.progress(92, "musicxml", "Parsing MIDI")

        # Parse MIDI
        score = converter.parse(midi_path)

        self.progress(94, "musicxml", "Analyzing key signature")

        # Detect key signature via music21's analyzer.
        # NOTE(review): bare except also catches KeyboardInterrupt — consider
        # narrowing to Exception.
        try:
            analyzed_key = score.analyze('key')
            score.insert(0, analyzed_key)
        except:
            # Default to C major if analysis fails
            score.insert(0, key.Key('C'))

        # Set time signature (always 4/4 for MVP; no meter detection here).
        score.insert(0, meter.TimeSignature('4/4'))

        # Extract tempo from the parsed MIDI, or default to 120 BPM.
        midi_tempo = self._extract_tempo(score)
        score.insert(0, tempo.MetronomeMark(number=midi_tempo))

        self.progress(95, "musicxml", "Deduplicating overlapping notes")

        # Fix overlapping polyphonic notes BEFORE creating measures.
        # This prevents MusicXML corruption where measures have >4.0 beats.
        score = self._deduplicate_overlapping_notes(score)

        self.progress(96, "musicxml", "Creating measures")

        # For MVP: single staff with treble clef. Grand-staff splitting causes
        # issues with overlapping polyphonic notes from transcription.
        # TODO: Implement proper grand staff in Phase 2 with better note splitting algorithm

        # Add treble clef (bass notes will render with ledger lines).
        for part in score.parts:
            part.insert(0, clef.TrebleClef())
            part.partName = "Piano"

        # Create measures
        score = score.makeMeasures()

        # Remove impossible note durations that makeMeasures() might have created.
        score = self._remove_impossible_durations(score)

        # Fix tuplets containing impossible durations (must run AFTER makeMeasures)
        # to avoid "Cannot convert 2048th duration to MusicXML" export errors.
        score = self._fix_tuplet_durations(score)

        # Log warnings for overfull/underfull measures that slipped through.
        self._validate_measures(score)

        self.progress(97, "musicxml", "Finalizing score")

        self.progress(98, "musicxml", "Writing MusicXML file")

        # Write MusicXML with retry logic: music21 may still synthesize 2048th
        # durations during export; each retry empties the offending measure.
        output_path = self.temp_dir / f"{self.job_id}.musicxml"
        max_retries = 10  # prevent an infinite retry loop
        retry_count = 0

        while retry_count < max_retries:
            try:
                score.write('musicxml', fp=str(output_path))
                break  # Success!
            except Exception as e:
                error_msg = str(e)
                # Only handle the known impossible-duration export errors.
                if 'Cannot convert "2048th" duration to MusicXML' in error_msg or \
                   'Cannot convert "4096th" duration to MusicXML' in error_msg:
                    # music21 reports the 1-based measure number in the message.
                    import re
                    match = re.search(r'measure \((\d+)\)', error_msg)
                    if match:
                        measure_num = int(match.group(1))
                        print(f" Fixing 2048th note error in measure {measure_num}...")

                        # Remove ALL notes/rests from the measure as a last
                        # resort: the bad duration is created BY music21 during
                        # export, so we cannot fix it in place — only empty it.
                        for part in score.parts:
                            measures = list(part.getElementsByClass('Measure'))
                            if measure_num <= len(measures):
                                problem_measure = measures[measure_num - 1]

                                to_remove = list(problem_measure.recurse().notesAndRests)

                                for element in to_remove:
                                    # Remove from whichever container holds it.
                                    element.activeSite.remove(element)

                                # Invalidate music21's cached element lists.
                                problem_measure.coreElementsChanged()
                                part.coreElementsChanged()

                                print(f" Removed all {len(to_remove)} elements from measure {measure_num}")

                        retry_count += 1
                    else:
                        # Can't parse measure number, give up
                        raise
                else:
                    # Different error, give up
                    raise

        if retry_count >= max_retries:
            raise RuntimeError(f"Failed to fix 2048th note errors after {max_retries} attempts")

        return output_path
387
+
388
    def _deduplicate_overlapping_notes(self, score):
        """
        Deduplicate overlapping notes to prevent MusicXML measure corruption.

        Problem: transcription outputs multiple notes at the same timestamp for
        polyphonic detection; makeMeasures() then creates measures with >4.0
        beats. Solution: bucket near-simultaneous notes, merge duplicate
        pitches, and emit a single note or chord per bucket.

        NOTE on units: despite the *_ms names, buckets are in thousandths of a
        QUARTER NOTE (element.offset is in quarterLengths), so a "10ms" bucket
        is really 0.01 quarterLength — wall-clock only at 60 BPM.

        Args:
            score: music21 Score object, BEFORE makeMeasures().

        Returns:
            The score with each part rebuilt from deduplicated notes.
        """
        from music21 import stream, note, chord as m21_chord
        from collections import defaultdict

        # Process each part independently.
        for part in score.parts:
            # Bucket notes by (quarterLength * 1000) rounded down to 10 units.
            notes_by_time = defaultdict(list)  # bucket -> [Note]

            for element in part.flatten().notesAndRests:
                if isinstance(element, note.Rest):
                    continue  # rests are irrelevant to deduplication

                # Offset in quarter notes, scaled by 1000 for integer bucketing.
                offset_qn = element.offset
                offset_ms = round(offset_qn * 1000)

                # Merge notes within 0.01 quarterLength of each other.
                bucket = (offset_ms // 10) * 10

                if isinstance(element, note.Note):
                    notes_by_time[bucket].append(element)
                elif isinstance(element, m21_chord.Chord):
                    # Explode chords into individual notes so duplicate pitches
                    # across overlapping chords/notes can be merged.
                    for pitch in element.pitches:
                        n = note.Note(pitch)
                        n.quarterLength = element.quarterLength
                        n.offset = element.offset
                        notes_by_time[bucket].append(n)

            # Rebuild the part from the deduplicated buckets.
            new_part = stream.Part()

            # Carry over identity; key/tempo/time signature are inserted later
            # by generate_musicxml at the score level.
            new_part.id = part.id
            new_part.partName = part.partName

            for bucket_ms in sorted(notes_by_time.keys()):
                bucket_notes = notes_by_time[bucket_ms]

                if not bucket_notes:
                    continue

                # Group simultaneous notes by MIDI pitch to find duplicates.
                pitch_groups = defaultdict(list)
                for n in bucket_notes:
                    pitch_groups[n.pitch.midi].append(n)

                # For each unique pitch, keep the longest (then loudest) note.
                unique_notes = []
                for midi_pitch, pitch_notes in pitch_groups.items():
                    # Velocity as int for sorting; music21 may store None.
                    def get_velocity(note):
                        if hasattr(note, 'volume') and hasattr(note.volume, 'velocity'):
                            vel = note.volume.velocity
                            return vel if vel is not None else 64
                        return 64

                    pitch_notes.sort(key=lambda x: (x.quarterLength, get_velocity(x)), reverse=True)
                    best_note = pitch_notes[0]

                    # Drop notes shorter than a 64th (0.0625 quarterLength);
                    # MusicXML cannot represent notes shorter than 1024th.
                    if best_note.quarterLength >= 0.0625:
                        unique_notes.append(best_note)

                if not unique_notes:
                    continue  # every candidate in this bucket was too short

                # Back from milli-quarterLengths to quarterLength offset.
                offset_qn = bucket_ms / 1000.0

                if len(unique_notes) == 1:
                    # Single note — snap duration to a MusicXML-safe value.
                    n = note.Note(unique_notes[0].pitch)
                    n.quarterLength = self._snap_duration(unique_notes[0].quarterLength)
                    new_part.insert(offset_qn, n)
                elif len(unique_notes) > 1:
                    # Multiple pitches at one instant -> one chord; use the
                    # shortest member duration to avoid overlaps, then snap.
                    min_duration = min(n.quarterLength for n in unique_notes)

                    c = m21_chord.Chord([n.pitch for n in unique_notes])
                    c.quarterLength = self._snap_duration(min_duration)
                    new_part.insert(offset_qn, c)

            # Swap the rebuilt part in place of the original.
            score.replace(part, new_part)

        return score
493
+
494
+ def _snap_duration(self, duration):
495
+ """
496
+ Snap duration to nearest MusicXML-valid note value to avoid impossible tuplets.
497
+
498
+ Valid durations: whole (4.0), half (2.0), quarter (1.0), eighth (0.5),
499
+ sixteenth (0.25), thirty-second (0.125), sixty-fourth (0.0625)
500
+
501
+ Args:
502
+ duration: Quarter length as float or Fraction
503
+
504
+ Returns:
505
+ Snapped quarter length
506
+ """
507
+ valid_durations = [4.0, 2.0, 1.0, 0.5, 0.25, 0.125, 0.0625]
508
+
509
+ # Convert to float for comparison
510
+ dur_float = float(duration)
511
+
512
+ # Find nearest valid duration
513
+ nearest = min(valid_durations, key=lambda x: abs(x - dur_float))
514
+
515
+ return nearest
516
+
517
+ def _remove_impossible_durations(self, score):
518
+ """
519
+ Remove notes/rests with durations too short for MusicXML export (<128th note).
520
+
521
+ music21's makeMeasures() can create rests with impossible durations (2048th notes)
522
+ when filling gaps. This removes them to prevent MusicXML export errors.
523
+
524
+ Args:
525
+ score: music21 Score with measures
526
+
527
+ Returns:
528
+ Cleaned score
529
+ """
530
+ from music21 import note, stream
531
+
532
+ # Be VERY aggressive - remove anything shorter than 16th note
533
+ # ByteDance transcription creates many very short notes that cause music21
534
+ # to generate complex tuplets with impossible durations (2048th notes)
535
+ # By filtering aggressively, we prevent this MusicXML export error
536
+ MIN_DURATION = 0.25 # 16th note (1.0 / 4)
537
+
538
+ removed_count = 0
539
+ for part in score.parts:
540
+ for measure in part.getElementsByClass('Measure'):
541
+ # Collect elements to remove
542
+ to_remove = []
543
+
544
+ for element in measure.notesAndRests:
545
+ if element.quarterLength < MIN_DURATION:
546
+ to_remove.append(element)
547
+ removed_count += 1
548
+
549
+ # Remove impossible durations
550
+ for element in to_remove:
551
+ measure.remove(element)
552
+
553
+ if removed_count > 0:
554
+ print(f" Removed {removed_count} notes/rests shorter than 16th note to prevent tuplet errors")
555
+
556
+ return score
557
+
558
    def _fix_tuplet_durations(self, score):
        """
        Remove notes/rests whose tuplet or own duration cannot reach MusicXML.

        MusicXML export fails when music21 converts a tuplet whose
        durationNormal.type is 2048th or shorter. This pass removes such
        elements BEFORE export. Must run AFTER makeMeasures(), which is what
        creates the tuplets.

        Note: fixed_tuplets can be incremented twice for one element (once for
        a bad tuplet, once for a bad own duration), so it is an upper bound,
        while removed_count is the exact number of removed elements.

        Args:
            score: music21 Score with measures and tuplets.

        Returns:
            The same score, cleaned in place.
        """
        from music21 import note, stream, duration

        # Duration types MusicXML has no representation for.
        IMPOSSIBLE_TYPES = {'2048th', '4096th', '8192th', '16384th', '32768th'}

        removed_count = 0
        fixed_tuplets = 0

        for part in score.parts:
            for measure_idx, measure in enumerate(part.getElementsByClass('Measure')):
                # Collect first (can't modify while iterating).
                to_remove = []

                # Direct children only — not flattened.
                for element in measure.notesAndRests:
                    should_remove = False

                    # Any tuplet on this element with an impossible
                    # durationNormal type condemns the element.
                    if element.duration.tuplets:
                        for tuplet in element.duration.tuplets:
                            if hasattr(tuplet, 'durationNormal') and tuplet.durationNormal:
                                dur_type = tuplet.durationNormal.type
                                if dur_type in IMPOSSIBLE_TYPES:
                                    should_remove = True
                                    fixed_tuplets += 1
                                    break

                    # The element's own duration type may also be impossible.
                    if element.duration.type in IMPOSSIBLE_TYPES:
                        should_remove = True
                        fixed_tuplets += 1

                    if should_remove:
                        to_remove.append(element)

                # Remove problematic elements; a failed removal is logged and
                # skipped rather than aborting the whole cleanup.
                for element in to_remove:
                    try:
                        measure.remove(element)
                        removed_count += 1
                    except Exception as e:
                        print(f" Warning: Could not remove element from measure {measure_idx + 1}: {e}")
                        continue

        if removed_count > 0:
            print(f" Fixed {fixed_tuplets} tuplets by removing {removed_count} elements with impossible durations")

        return score
624
+
625
+ def _validate_measures(self, score):
626
+ """
627
+ Validate that all measures have correct durations matching their time signature.
628
+
629
+ Logs warnings for any measures that are overfull or underfull.
630
+
631
+ Args:
632
+ score: music21 Score with measures already created
633
+ """
634
+ for part_idx, part in enumerate(score.parts):
635
+ for measure_idx, measure in enumerate(part.getElementsByClass('Measure')):
636
+ # Get time signature for this measure
637
+ ts = measure.timeSignature or measure.getContextByClass('TimeSignature')
638
+ if not ts:
639
+ continue # Skip if no time signature
640
+
641
+ expected_duration = ts.barDuration.quarterLength
642
+ actual_duration = measure.duration.quarterLength
643
+
644
+ # Allow small floating-point tolerance (0.01 quarter notes = ~10ms at 120 BPM)
645
+ tolerance = 0.01
646
+
647
+ if abs(actual_duration - expected_duration) > tolerance:
648
+ print(f"WARNING: Measure {measure_idx + 1} in part {part_idx} has duration {float(actual_duration):.2f} "
649
+ f"(expected {float(expected_duration):.2f} for {ts.ratioString} time)")
650
+
651
    def _split_into_grand_staff(self, score):
        """
        Split a measured score into treble and bass parts (piano grand staff).

        Notes >= Middle C (C4 / MIDI 60) go to the treble clef (right hand);
        notes below go to the bass clef (left hand). Chords are split per
        pitch. Expects a score that ALREADY has measures from makeMeasures().

        NOTE: currently unused by generate_musicxml (MVP uses a single treble
        staff); kept for the planned Phase 2 grand-staff support.

        Args:
            score: music21 Score with measures.

        Returns:
            A new two-part Score, or the input score when it already has
            multiple parts (clefs are then just inserted) or has no parts.
        """
        from music21 import stream, note, chord as m21_chord

        # Already multi-part: just label the first part treble, the rest bass.
        if len(score.parts) > 1:
            for part_idx, part in enumerate(score.parts):
                if part_idx == 0:
                    part.insert(0, clef.TrebleClef())
                else:
                    part.insert(0, clef.BassClef())
            return score

        # Single-part score expected from here on.
        original_part = score.parts[0] if len(score.parts) > 0 else None
        if not original_part:
            return score

        # Fresh score that will hold the two hands.
        new_score = stream.Score()

        # Carry over key, meter and tempo marks from the flattened original.
        for element in score.flatten():
            if isinstance(element, (key.Key, meter.TimeSignature, tempo.MetronomeMark)):
                new_score.insert(0, element)

        # Right hand (treble) and left hand (bass) parts.
        treble_part = stream.Part()
        treble_part.insert(0, clef.TrebleClef())
        treble_part.partName = "Piano Right Hand"

        bass_part = stream.Part()
        bass_part.insert(0, clef.BassClef())
        bass_part.partName = "Piano Left Hand"

        # Middle C (C4) is MIDI note 60: the hand-split threshold.
        SPLIT_POINT = 60

        # Mirror the original measure structure into both hands.
        for measure in original_part.getElementsByClass('Measure'):
            treble_measure = stream.Measure(number=measure.number)
            bass_measure = stream.Measure(number=measure.number)

            # Copy any mid-piece time signature change into both staves.
            for ts in measure.getElementsByClass(meter.TimeSignature):
                treble_measure.insert(0, ts)
                bass_measure.insert(0, ts)

            # Distribute the measure's notes between the hands.
            for element in measure.notesAndRests:
                offset = element.getOffsetInHierarchy(measure)

                if isinstance(element, note.Rest):
                    # Skip rests — makeRests() below regenerates them per staff.
                    continue

                elif isinstance(element, note.Note):
                    # Single note: route by pitch relative to Middle C.
                    new_note = note.Note(element.pitch, quarterLength=element.quarterLength)

                    if element.pitch.midi >= SPLIT_POINT:
                        treble_measure.insert(offset, new_note)
                    else:
                        bass_measure.insert(offset, new_note)

                elif isinstance(element, m21_chord.Chord):
                    # Chord: its pitches may straddle the split point.
                    treble_pitches = []
                    bass_pitches = []

                    for pitch in element.pitches:
                        if pitch.midi >= SPLIT_POINT:
                            treble_pitches.append(pitch)
                        else:
                            bass_pitches.append(pitch)

                    # Emit a per-hand chord only when that hand has pitches.
                    if treble_pitches:
                        treble_chord = m21_chord.Chord(treble_pitches, quarterLength=element.quarterLength)
                        treble_measure.insert(offset, treble_chord)

                    if bass_pitches:
                        bass_chord = m21_chord.Chord(bass_pitches, quarterLength=element.quarterLength)
                        bass_measure.insert(offset, bass_chord)

            treble_part.append(treble_measure)
            bass_part.append(bass_measure)

        # Treble inserted first for conventional grand-staff ordering.
        new_score.insert(0, treble_part)
        new_score.insert(0, bass_part)

        # Let music21 fill gaps with rests; failure here is non-fatal.
        # NOTE(review): bare except also swallows KeyboardInterrupt.
        try:
            new_score.makeRests(inPlace=True, fillGaps=True)
        except:
            pass

        return new_score
763
+
764
+ def _extract_tempo(self, score) -> int:
765
+ """Extract tempo from MIDI or default to 120 BPM."""
766
+ for element in score.flatten():
767
+ if isinstance(element, tempo.MetronomeMark):
768
+ return int(element.number)
769
+ return 120
770
+
771
+ def cleanup(self):
772
+ """Delete temporary files (except output)."""
773
+ # Don't delete entire temp_dir yet - output file is still there
774
+ # Delete individual temp files instead
775
+ for file in self.temp_dir.glob("*.wav"):
776
+ file.unlink(missing_ok=True)
777
+ for file in self.temp_dir.glob("*_clean.mid"):
778
+ if file.name != "piano_clean.mid":
779
+ file.unlink(missing_ok=True)
780
+
781
+
782
+ # === Module-level convenience functions for backward compatibility ===
783
+
784
def download_audio(youtube_url: str, storage_path: Path) -> Path:
    """Download audio from YouTube URL (module-level wrapper)."""
    return TranscriptionPipeline("compat_job", youtube_url, storage_path).download_audio()


def separate_sources(audio_path: Path, storage_path: Path) -> dict:
    """Separate audio sources (module-level wrapper)."""
    job = TranscriptionPipeline("compat_job", "http://example.com", storage_path)
    return job.separate_sources(audio_path)


def transcribe_audio(
    audio_path: Path,
    storage_path: Path,
    onset_threshold: float = 0.4,
    frame_threshold: float = 0.35
) -> Path:
    """Transcribe audio to MIDI (module-level wrapper).

    Note: onset_threshold and frame_threshold are accepted for backward
    compatibility but are not consumed by the underlying class method, whose
    current signature takes no thresholds.
    """
    job = TranscriptionPipeline("compat_job", "http://example.com", storage_path)
    return job.transcribe_to_midi(audio_path)


def quantize_midi(midi_path: Path, resolution: int = 480) -> Path:
    """Quantize MIDI file (module-level wrapper around clean_midi)."""
    job = TranscriptionPipeline("compat_job", "http://example.com", midi_path.parent)
    return job.clean_midi(midi_path)


def remove_duplicate_notes(midi_path: Path) -> Path:
    """Remove duplicate notes from MIDI (clean_midi performs this as part of its pass)."""
    job = TranscriptionPipeline("compat_job", "http://example.com", midi_path.parent)
    return job.clean_midi(midi_path)


def remove_short_notes(midi_path: Path, min_duration: int = 60) -> Path:
    """Remove short notes from MIDI (clean_midi performs this as part of its pass)."""
    job = TranscriptionPipeline("compat_job", "http://example.com", midi_path.parent)
    return job.clean_midi(midi_path)


def generate_musicxml(midi_path: Path, storage_path: Path) -> Path:
    """Generate MusicXML from MIDI (module-level wrapper)."""
    job = TranscriptionPipeline("compat_job", "http://example.com", storage_path)
    return job.generate_musicxml(midi_path)
834
+
835
+
836
def detect_key_signature(midi_path: Path) -> dict:
    """Detect key signature from MIDI.

    Args:
        midi_path: Path to a MIDI file.

    Returns:
        dict with 'tonic' (e.g. 'C') and 'mode' ('major'/'minor');
        falls back to C major when analysis fails.
    """
    score = converter.parse(midi_path)
    # Catch only Exception: the previous bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit.
    try:
        analyzed_key = score.analyze('key')
        return {
            'tonic': analyzed_key.tonic.name,
            'mode': analyzed_key.mode
        }
    except Exception:
        return {'tonic': 'C', 'mode': 'major'}
847
+
848
+
849
def detect_time_signature(midi_path: Path) -> dict:
    """Detect time signature from MIDI, defaulting to 4/4 when none is found."""
    score = converter.parse(midi_path)
    found = next(iter(score.flatten().getElementsByClass(meter.TimeSignature)), None)
    if found is None:
        return {'numerator': 4, 'denominator': 4}
    return {
        'numerator': found.numerator,
        'denominator': found.denominator
    }


def detect_tempo(midi_path: Path) -> int:
    """Detect tempo (BPM) from MIDI, defaulting to 120 when no mark is found."""
    score = converter.parse(midi_path)
    mark = next(iter(score.flatten().getElementsByClass(tempo.MetronomeMark)), None)
    return 120 if mark is None else int(mark.number)
866
+
867
+
868
def run_transcription_pipeline(youtube_url: str, storage_path: Path) -> dict:
    """Run the full transcription pipeline (module-level wrapper).

    Returns a status dict instead of raising: {'status': 'success',
    'musicxml_path': ...} on success, {'status': 'failed', 'error': ...}
    on any pipeline failure.
    """
    job = TranscriptionPipeline("compat_job", youtube_url, storage_path)
    try:
        musicxml = job.run()
    except Exception as exc:
        return {
            'status': 'failed',
            'error': str(exc)
        }
    return {
        'status': 'success',
        'musicxml_path': str(musicxml)
    }
backend/pytest.ini ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [pytest]
2
+ testpaths = tests
3
+ python_files = test_*.py
4
+ python_classes = Test*
5
+ python_functions = test_*
6
+
7
+ # Show extra test summary info
8
+ addopts =
9
+ -v
10
+ --strict-markers
11
+ --tb=short
12
+ --disable-warnings
13
+ --cov=.
14
+ --cov-report=term-missing
15
+ --cov-report=html
16
+ --cov-branch
17
+
18
+ # Markers for categorizing tests
19
+ markers =
20
+ unit: Unit tests for individual functions
21
+ integration: Integration tests for multiple components
22
+ slow: Tests that take longer to run
23
+ gpu: Tests that require GPU
24
+ network: Tests that require network access
25
+
26
+ # Asyncio configuration
27
+ asyncio_mode = auto
28
+ asyncio_default_fixture_loop_scope = function
29
+
30
+ # Coverage options — NOTE: coverage.py does not read pytest.ini; these [coverage:*] sections are ignored here and should live in .coveragerc, setup.cfg, or tox.ini to take effect.
31
+ [coverage:run]
32
+ omit =
33
+ tests/*
34
+ __pycache__/*
35
+ */site-packages/*
36
+ venv/*
37
+
38
+ [coverage:report]
39
+ exclude_lines =
40
+ pragma: no cover
41
+ def __repr__
42
+ raise AssertionError
43
+ raise NotImplementedError
44
+ if __name__ == .__main__.:
45
+ if TYPE_CHECKING:
backend/requirements-test.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Test dependencies for Rescored backend
2
+ -r requirements.txt
3
+
4
+ # Testing framework
5
+ pytest==8.2.0
6
+ pytest-asyncio==0.24.0
7
+ pytest-cov==4.1.0
8
+ pytest-mock==3.12.0
9
+
10
+ # HTTP testing
11
+ httpx==0.26.0
12
+
13
+ # Test utilities
14
+ faker==22.5.1
backend/requirements.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Web Framework
2
+ # Note: This file now includes torch/torchaudio as they are required by demucs on macOS
3
+ fastapi==0.115.5
4
+ uvicorn[standard]==0.32.1
5
+ python-multipart==0.0.20
6
+
7
+ # Task Queue
8
+ celery==5.4.0
9
+ redis==5.2.1
10
+
11
+ # Audio Processing
12
+ yt-dlp>=2025.12.8
13
+ soundfile==0.12.1
14
+ scipy
15
+ torch>=2.0.0
16
+ torchaudio>=2.9.1
17
+ torchcodec>=0.9.1
18
+ demucs>=3.0.6
19
+
20
+ # Pitch detection (macOS default runtime is CoreML)
21
+ basic-pitch==0.4.0
22
+
23
+ # Music Processing
24
+ music21==9.3.0
25
+ mido==1.3.3
26
+
27
+ # Utilities
28
+ python-dotenv==1.0.1
29
+ tenacity==9.0.0
30
+ pydantic==2.10.4
31
+ pydantic-settings==2.7.0
32
+ numpy<2.0.0
33
+
34
+ # WebSocket
35
+ websockets==14.1
backend/scripts/README.md ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Backend Scripts
2
+
3
+ Utility scripts for testing and analyzing the Rescored transcription pipeline.
4
+
5
+ ## Scripts
6
+
7
+ ### test_accuracy.py
8
+
9
+ **NEW** - Comprehensive accuracy testing suite that tests the pipeline with 10 diverse piano videos covering different styles and difficulty levels.
10
+
11
+ **Usage:**
12
+ ```bash
13
+ cd backend
14
+ python scripts/test_accuracy.py
15
+ ```
16
+
17
+ **Output:**
18
+ - Progress for each of 10 test videos
19
+ - Success/failure status per video
20
+ - Metrics: note count, measure count, separation quality
21
+ - Summary statistics (success rate, average metrics)
22
+ - Full results saved to JSON: `/tmp/rescored/accuracy_test_results.json`
23
+
24
+ **Test Videos** (varying difficulty):
25
+ - **Easy**: Simple scales, Twinkle Twinkle
26
+ - **Medium**: Für Elise, Canon in D, River Flows in You, Moonlight Sonata, Jazz Blues
27
+ - **Hard**: Chopin Nocturne, Clair de Lune
28
+ - **Very Hard**: La Campanella (Liszt)
29
+
30
+ **Expected Runtime**: 30-60 minutes for all 10 videos
31
+
32
+ **Purpose**: Establish baseline accuracy metrics for the MVP pipeline, identify common failure modes, and track improvements across phases.
33
+
34
+ ### test_e2e.py
35
+
36
+ End-to-end pipeline testing script. Downloads a YouTube video, runs the full transcription pipeline, and displays results.
37
+
38
+ **Usage:**
39
+ ```bash
40
+ cd backend
41
+ python scripts/test_e2e.py "<youtube_url>"
42
+ ```
43
+
44
+ **Example:**
45
+ ```bash
46
+ python scripts/test_e2e.py "https://www.youtube.com/watch?v=PAE88urB1xs"
47
+ ```
48
+
49
+ **Output:**
50
+ - Progress updates for each pipeline stage
51
+ - Total processing time
52
+ - MusicXML file path and size
53
+ - List of intermediate files
54
+ - Preview of generated MusicXML
55
+
56
+ **Test Videos:**
57
+ - Simple piano melody: https://www.youtube.com/watch?v=WyTb3DTu88c
58
+ - Classical piano: https://www.youtube.com/watch?v=fJ9rUzIMcZQ
59
+
60
+ ---
61
+
62
+ ### analyze_transcription.py
63
+
64
+ MIDI file analysis tool. Provides detailed statistics about transcribed notes to identify quality issues.
65
+
66
+ **Usage:**
67
+ ```bash
68
+ cd backend
69
+ python scripts/analyze_transcription.py <midi_path>
70
+ ```
71
+
72
+ **Example:**
73
+ ```bash
74
+ python scripts/analyze_transcription.py /tmp/rescored/temp/test_e2e/piano.mid
75
+ python scripts/analyze_transcription.py /tmp/rescored/temp/test_e2e/piano_clean.mid
76
+ ```
77
+
78
+ **Analysis Includes:**
79
+ - Total note count and density (notes/second)
80
+ - Pitch range and distribution
81
+ - Note duration statistics (average, median, min, max)
82
+ - Velocity (dynamics) analysis
83
+ - Polyphony (simultaneous notes)
84
+ - Detection of potential issues:
85
+ - Very short notes (< 100ms) - likely false positives
86
+ - Very quiet notes (velocity < 30) - likely noise
87
+ - High note density - over-transcription
88
+ - Extreme polyphony - detecting noise as notes
89
+ - Notes outside piano range
90
+
91
+ **Output Example:**
92
+ ```
93
+ ============================================================
94
+ MIDI Transcription Analysis
95
+ ============================================================
96
+ File: piano.mid
97
+ Duration: 248.1 seconds
98
+ Total notes: 1333
99
+ Notes per second: 5.37
100
+
101
+ Pitch Range:
102
+ Lowest: 35 (MIDI) = B1
103
+ Highest: 86 (MIDI) = D6
104
+ Range: 51 semitones
105
+
106
+ Note Durations:
107
+ Average: 0.433 seconds
108
+ Median: 0.325 seconds
109
+ Very short notes (< 100ms): 0 (0.0%)
110
+
111
+ Potential Issues:
112
+ ✓ No obvious issues detected
113
+ ============================================================
114
+ ```
115
+
116
+ ---
117
+
118
+ ## Workflow
119
+
120
+ 1. **Test the pipeline:**
121
+ ```bash
122
+ python scripts/test_e2e.py "https://www.youtube.com/watch?v=VIDEO_ID"
123
+ ```
124
+
125
+ 2. **Analyze the raw output:**
126
+ ```bash
127
+ python scripts/analyze_transcription.py /tmp/rescored/temp/test_e2e/piano.mid
128
+ ```
129
+
130
+ 3. **Analyze the cleaned output:**
131
+ ```bash
132
+ python scripts/analyze_transcription.py /tmp/rescored/temp/test_e2e/piano_clean.mid
133
+ ```
134
+
135
+ 4. **Listen to the result:**
136
+ ```bash
137
+ # Using MuseScore
138
+ musescore /tmp/rescored/temp/test_e2e/test_e2e.musicxml
139
+
140
+ # Or using timidity (MIDI playback)
141
+ timidity /tmp/rescored/temp/test_e2e/piano_clean.mid
142
+ ```
143
+
144
+ ---
145
+
146
+ ## Interpreting Results
147
+
148
+ ### Good Transcription Indicators
149
+ - Notes/second: 3-8 for piano (depends on complexity)
150
+ - Very short notes: < 10%
151
+ - Max polyphony: 3-10 simultaneous notes (piano is typically 2-6)
152
+ - Pitch range: Within MIDI 21-108 (A0 to C8)
153
+ - No significant issues detected
154
+
155
+ ### Warning Signs
156
+ - Notes/second > 10: Likely over-transcribing (too many false positives)
157
+ - Very short notes > 30%: Detecting noise as notes
158
+ - Max polyphony > 15: Probably including noise
159
+ - Many notes outside piano range: Need better filtering
160
+
161
+ ### Tuning Recommendations
162
+ If you see issues, adjust parameters in [pipeline.py](../pipeline.py):
163
+
164
+ **For too many false positives:**
165
+ - Increase `onset-threshold` (0.5 → 0.6)
166
+ - Increase `frame-threshold` (0.4 → 0.45)
167
+ - Increase `minimum-note-length` (127 → 150ms)
168
+
169
+ **For too many missing notes:**
170
+ - Decrease `onset-threshold` (0.5 → 0.45)
171
+ - Decrease `frame-threshold` (0.4 → 0.35)
172
+
173
+ **For timing issues:**
174
+ - Adjust quantization in `clean_midi()` method
175
+ - Change `ticks_per_16th` to `ticks_per_32nd` for lighter quantization
176
+
177
+ ---
178
+
179
+ ## Notes
180
+
181
+ - Scripts must be run from the `backend` directory (they use relative imports)
182
+ - Temporary files are stored in `/tmp/rescored/temp/<job_id>/`
183
+ - MusicXML output is saved in the temp directory with the job_id as filename
184
+ - Analysis works on both raw and cleaned MIDI files for comparison
backend/scripts/analyze_transcription.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Analyze transcription quality and identify common issues.
4
+
5
+ Usage (from backend directory):
6
+ python scripts/analyze_transcription.py <midi_path>
7
+
8
+ Example:
9
+ python scripts/analyze_transcription.py /tmp/rescored/temp/test_e2e/piano.mid
10
+ """
11
+ import sys
12
+ from pathlib import Path
13
+ import mido
14
+ from collections import Counter
15
+ import statistics
16
+
17
+
18
def analyze_midi(midi_path: Path):
    """Analyze a MIDI file for common transcription issues and print a report.

    Pairs note_on/note_off events per track to reconstruct notes, then prints
    statistics on pitch range, durations, velocities, polyphony, and the most
    frequent pitches, followed by heuristic issue detection and tuning
    recommendations for the basic-pitch transcriber.

    Args:
        midi_path: Path to the MIDI file to analyze.

    Side effects:
        Prints the full report to stdout; returns early (after a message)
        when the file contains no notes.

    NOTE(review): divisions by mid.length assume a non-empty, nonzero-length
    file — a zero-length file with notes would raise ZeroDivisionError; confirm
    this cannot occur upstream.
    """
    mid = mido.MidiFile(midi_path)

    # Collect all notes with timing
    notes = []  # (time, pitch, velocity, duration) — times/durations in ticks

    for track in mid.tracks:
        absolute_time = 0
        active_notes = {}  # pitch -> (start_time, velocity)

        for msg in track:
            # msg.time is a delta in ticks; accumulate to absolute tick time.
            absolute_time += msg.time

            if msg.type == 'note_on' and msg.velocity > 0:
                active_notes[msg.note] = (absolute_time, msg.velocity)

            elif msg.type in ['note_off', 'note_on']:  # note_on with velocity 0 is also note_off
                if msg.note in active_notes:
                    start_time, velocity = active_notes.pop(msg.note)
                    duration = absolute_time - start_time
                    notes.append((start_time, msg.note, velocity, duration))

    if not notes:
        print("No notes found in MIDI file!")
        return

    # Sort notes by time
    notes.sort(key=lambda n: n[0])

    # Analysis
    print("=" * 60)
    print("MIDI Transcription Analysis")
    print("=" * 60)
    print(f"File: {midi_path.name}")
    print(f"Duration: {mid.length:.1f} seconds")
    print(f"Total notes: {len(notes)}")
    print(f"Notes per second: {len(notes) / mid.length:.2f}")
    print()

    # Pitch analysis
    pitches = [n[1] for n in notes]
    pitch_counts = Counter(pitches)
    print("Pitch Range:")
    print(f" Lowest: {min(pitches)} (MIDI) = {_midi_to_note(min(pitches))}")
    print(f" Highest: {max(pitches)} (MIDI) = {_midi_to_note(max(pitches))}")
    print(f" Range: {max(pitches) - min(pitches)} semitones")
    print()

    # Duration analysis.
    # NOTE(review): tick2second is called with a fixed tempo of 500000 µs/beat
    # (120 BPM); set_tempo events in the file are ignored, so durations are
    # approximate for any piece not at 120 BPM — TODO confirm acceptable.
    durations_ticks = [n[3] for n in notes]
    durations_seconds = [mido.tick2second(d, mid.ticks_per_beat, 500000) for d in durations_ticks]
    print("Note Durations:")
    print(f" Average: {statistics.mean(durations_seconds):.3f} seconds")
    print(f" Median: {statistics.median(durations_seconds):.3f} seconds")
    print(f" Min: {min(durations_seconds):.3f} seconds")
    print(f" Max: {max(durations_seconds):.3f} seconds")

    # Identify very short notes (likely noise/false positives)
    very_short_notes = [d for d in durations_seconds if d < 0.1]  # < 100ms
    short_notes = [d for d in durations_seconds if d < 0.2]  # < 200ms
    print(f" Very short notes (< 100ms): {len(very_short_notes)} ({len(very_short_notes)/len(notes)*100:.1f}%)")
    print(f" Short notes (< 200ms): {len(short_notes)} ({len(short_notes)/len(notes)*100:.1f}%)")
    print()

    # Velocity analysis
    velocities = [n[2] for n in notes]
    print("Velocity (dynamics):")
    print(f" Average: {statistics.mean(velocities):.1f}")
    print(f" Min: {min(velocities)}")
    print(f" Max: {max(velocities)}")
    print(f" Range: {max(velocities) - min(velocities)}")

    # Identify very quiet notes (likely noise/false positives)
    quiet_notes = [v for v in velocities if v < 30]
    print(f" Very quiet notes (velocity < 30): {len(quiet_notes)} ({len(quiet_notes)/len(notes)*100:.1f}%)")
    print()

    # Polyphony analysis: counts note ONSETS per 50-tick bucket, so this is a
    # rough proxy for simultaneous notes, not true sustained polyphony.
    time_windows = {}  # time_window -> count
    window_size = 50  # 50 ticks
    for note_time, _, _, _ in notes:
        window = note_time // window_size
        time_windows[window] = time_windows.get(window, 0) + 1

    max_polyphony = max(time_windows.values())
    avg_polyphony = statistics.mean(time_windows.values())
    print("Polyphony (simultaneous notes):")
    print(f" Max simultaneous: ~{max_polyphony}")
    print(f" Average: ~{avg_polyphony:.1f}")
    print()

    # Most common pitches
    print("Most frequent pitches (top 10):")
    for pitch, count in pitch_counts.most_common(10):
        print(f" {_midi_to_note(pitch):>3s} (MIDI {pitch:>2d}): {count:>4d} times ({count/len(notes)*100:>5.1f}%)")
    print()

    # Identify potential issues — thresholds below are heuristics tuned for
    # piano transcription output.
    print("Potential Issues:")
    issues = []

    if len(very_short_notes) / len(notes) > 0.3:
        issues.append(f"⚠️ {len(very_short_notes)/len(notes)*100:.1f}% of notes are very short (< 100ms) - likely false positives")

    if len(quiet_notes) / len(notes) > 0.3:
        issues.append(f"⚠️ {len(quiet_notes)/len(notes)*100:.1f}% of notes are very quiet (velocity < 30) - likely noise")

    if len(notes) / mid.length > 15:
        issues.append(f"⚠️ Very high note density ({len(notes) / mid.length:.1f} notes/sec) - likely over-transcribing")

    if max_polyphony > 20:
        issues.append(f"⚠️ Very high polyphony (max {max_polyphony} notes) - likely detecting noise as notes")

    if min(pitches) < 21 or max(pitches) > 108:
        issues.append(f"⚠️ Notes outside piano range (MIDI 21-108) detected")

    if not issues:
        print(" ✓ No obvious issues detected")
    else:
        for issue in issues:
            print(f" {issue}")

    print()
    # Recommendations mirror the issue thresholds above, one tip per trigger.
    print("Recommendations:")
    if len(very_short_notes) / len(notes) > 0.3:
        print(" • Increase minimum-note-length threshold in basic-pitch")
    if len(quiet_notes) / len(notes) > 0.3:
        print(" • Increase frame-threshold in basic-pitch to ignore quieter notes")
    if len(notes) / mid.length > 15:
        print(" • Increase onset-threshold in basic-pitch to be less sensitive")
    if max_polyphony > 20:
        print(" • Use median filtering or harmonic analysis to remove noise")

    print("=" * 60)
154
+
155
+ def _midi_to_note(midi_num):
156
+ """Convert MIDI number to note name."""
157
+ notes = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
158
+ octave = (midi_num // 12) - 1
159
+ note = notes[midi_num % 12]
160
+ return f"{note}{octave}"
161
+
162
+
163
if __name__ == "__main__":
    # CLI entry point: expects exactly one argument, the MIDI file path.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python analyze_transcription.py <midi_path>")
        print("\nExample:")
        print(" python analyze_transcription.py /tmp/rescored/temp/test_e2e/piano.mid")
        sys.exit(1)

    target = Path(cli_args[0])
    if not target.exists():
        print(f"Error: File not found: {target}")
        sys.exit(1)

    analyze_midi(target)
backend/scripts/diagnose_pipeline.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Diagnose pipeline accuracy issues by analyzing each stage.
4
+
5
+ Usage (from backend directory):
6
+ python scripts/diagnose_pipeline.py <job_id>
7
+
8
+ Example:
9
+ python scripts/diagnose_pipeline.py test_e2e
10
+ """
11
+ import sys
12
+ from pathlib import Path
13
+ import soundfile as sf
14
+ import numpy as np
15
+ import mido
16
+
17
+ # Add parent directory to path for imports
18
+ sys.path.insert(0, str(Path(__file__).parent.parent))
19
+
20
+ from config import settings
21
+
22
+
23
def analyze_audio_file(audio_path: Path, label: str):
    """Print diagnostic statistics for an audio file.

    Reports duration, sample rate, channel count, peak, RMS, and crest-factor
    dynamic range, then flags clipping, large silent portions, and low overall
    level (a sign of poor source separation).

    Args:
        audio_path: Path to the audio file (read via soundfile).
        label: Heading printed above the stats (e.g. "Downloaded Audio").

    Side effects:
        Prints to stdout; prints a not-found message and returns early when
        the file does not exist.
    """
    print(f"\n{label}:")
    print(f" Path: {audio_path}")

    if not audio_path.exists():
        print(f" ❌ File not found!")
        return

    # Read audio
    data, samplerate = sf.read(audio_path)

    # Calculate statistics
    duration = len(data) / samplerate
    # Mono files come back as a 1-D array; multi-channel as (frames, channels).
    channels = 1 if len(data.shape) == 1 else data.shape[1]

    # RMS energy (loudness); per-channel for multi-channel input.
    if channels == 1:
        rms = np.sqrt(np.mean(data**2))
    else:
        rms = np.sqrt(np.mean(data**2, axis=0))

    # Peak amplitude
    peak = np.max(np.abs(data))

    # Dynamic range as crest factor in dB: 20*log10(peak/rms).
    # The 1e-10 avoids division by zero on digital silence.
    if channels == 1:
        dynamic_range = 20 * np.log10(peak / (rms + 1e-10))
    else:
        dynamic_range = 20 * np.log10(peak / (np.mean(rms) + 1e-10))

    print(f" Duration: {duration:.1f}s")
    print(f" Sample rate: {samplerate} Hz")
    print(f" Channels: {channels}")
    print(f" Peak amplitude: {peak:.3f}")

    if channels == 1:
        print(f" RMS energy: {rms:.3f}")
    else:
        print(f" RMS energy (L/R): {rms[0]:.3f} / {rms[1]:.3f}")

    print(f" Dynamic range: {dynamic_range:.1f} dB")

    # Check for clipping.
    # NOTE(review): for multi-channel audio this counts clipped SAMPLES across
    # all channels but divides by the FRAME count, so the percentage can exceed
    # 100% — confirm whether a per-frame count was intended.
    clipped_samples = np.sum(np.abs(data) >= 0.99)
    if clipped_samples > 0:
        print(f" ⚠️ Clipped samples: {clipped_samples} ({clipped_samples/len(data)*100:.2f}%)")

    # Check for silence (a frame is silent when every channel is below threshold)
    silence_threshold = 0.01
    if channels == 1:
        silent_samples = np.sum(np.abs(data) < silence_threshold)
    else:
        silent_samples = np.sum(np.max(np.abs(data), axis=1) < silence_threshold)

    if silent_samples > len(data) * 0.1:
        print(f" ⚠️ Silence: {silent_samples/len(data)*100:.1f}% of audio")

    # Check if mostly quiet (could indicate poor separation);
    # collapse per-channel RMS to a single average first.
    if isinstance(rms, np.ndarray):
        avg_rms = np.mean(rms)
    else:
        avg_rms = rms

    if avg_rms < 0.01:
        print(f" ⚠️ Very quiet audio (RMS: {avg_rms:.4f}) - may indicate poor source separation")
    elif avg_rms < 0.05:
        print(f" ⚠️ Quiet audio (RMS: {avg_rms:.4f}) - basic-pitch may struggle")
91
+
92
+
93
def analyze_midi_file(midi_path: Path, label: str):
    """Print summary statistics for a MIDI file.

    Reports duration, note count, note density, pitch range, and velocity
    stats. Only note_on events with velocity > 0 are counted as notes.

    Args:
        midi_path: Path to the MIDI file.
        label: Heading printed above the stats (e.g. "Raw MIDI Output").

    Side effects:
        Prints to stdout; prints a not-found message and returns early when
        the file does not exist.
    """
    print(f"\n{label}:")
    print(f" Path: {midi_path}")

    if not midi_path.exists():
        print(f" ❌ File not found!")
        return

    mid = mido.MidiFile(midi_path)

    # Count notes
    note_count = 0
    note_pitches = []
    note_velocities = []

    for track in mid.tracks:
        for msg in track:
            if msg.type == 'note_on' and msg.velocity > 0:
                note_count += 1
                note_pitches.append(msg.note)
                note_velocities.append(msg.velocity)

    print(f" Duration: {mid.length:.1f}s")
    print(f" Total notes: {note_count}")
    # Bug fix: a zero-length file previously raised ZeroDivisionError here.
    if mid.length > 0:
        print(f" Notes per second: {note_count / mid.length:.2f}")
    else:
        print(" Notes per second: n/a (zero-length file)")

    if note_pitches:
        print(f" Pitch range: {min(note_pitches)} - {max(note_pitches)}")
        print(f" Avg velocity: {np.mean(note_velocities):.1f}")
        print(f" Velocity range: {min(note_velocities)} - {max(note_velocities)}")
124
+
125
+
126
def diagnose_job(job_id: str):
    """Print a stage-by-stage diagnostic report for a transcription job.

    Walks the job's temp directory and analyzes the artifacts of each
    pipeline stage (download, Demucs separation, basic-pitch transcription,
    MIDI cleaning), then prints a summary of likely issues and manual
    next-step recommendations.

    Args:
        job_id: Job identifier; artifacts are read from
            <settings.storage_path>/temp/<job_id>.

    Side effects:
        Prints to stdout; exits the process (sys.exit(1)) when the job
        directory does not exist.

    NOTE(review): several ratios below (energy ratios, notes-removed %,
    note density) divide by values that could be zero for degenerate
    inputs (silent audio, empty MIDI) — confirm upstream guarantees.
    """
    storage_path = Path(settings.storage_path)
    job_dir = storage_path / "temp" / job_id

    print("=" * 60)
    print("PIPELINE DIAGNOSTIC REPORT")
    print("=" * 60)
    print(f"Job ID: {job_id}")
    print(f"Job Directory: {job_dir}")

    if not job_dir.exists():
        print(f"\n❌ Job directory not found: {job_dir}")
        print("\nRun test_e2e.py first to create a job:")
        print(f' python scripts/test_e2e.py "https://www.youtube.com/watch?v=VIDEO_ID"')
        sys.exit(1)

    print("\n" + "=" * 60)
    print("STAGE 1: AUDIO DOWNLOAD")
    print("=" * 60)

    audio_path = job_dir / "audio.wav"
    analyze_audio_file(audio_path, "Downloaded Audio")

    print("\n" + "=" * 60)
    print("STAGE 2: SOURCE SEPARATION (Demucs)")
    print("=" * 60)

    # Demucs writes stems under <job_dir>/htdemucs/<input-name>/.
    demucs_dir = job_dir / "htdemucs" / "audio"
    other_stem = demucs_dir / "other.wav"
    no_other_stem = demucs_dir / "no_other.wav"

    analyze_audio_file(other_stem, "Other Stem (Piano/Melodic)")
    analyze_audio_file(no_other_stem, "No-Other Stem (Drums/Bass/Vocals)")

    # Compare separation quality via energy split between the two stems.
    if audio_path.exists() and other_stem.exists() and no_other_stem.exists():
        print("\n Separation Quality Check:")

        # Read all audio
        original, sr = sf.read(audio_path)
        other, _ = sf.read(other_stem)
        no_other, _ = sf.read(no_other_stem)

        # Calculate energy distribution
        original_energy = np.sum(original**2)
        other_energy = np.sum(other**2)
        no_other_energy = np.sum(no_other**2)
        total_separated_energy = other_energy + no_other_energy

        print(f" Original energy: {original_energy:.2e}")
        print(f" Other energy: {other_energy:.2e} ({other_energy/original_energy*100:.1f}%)")
        print(f" No-other energy: {no_other_energy:.2e} ({no_other_energy/original_energy*100:.1f}%)")
        print(f" Energy preservation: {total_separated_energy/original_energy*100:.1f}%")

        # Check if 'other' stem is too quiet (bad separation)
        if other_energy / original_energy < 0.1:
            print(f" ⚠️ 'Other' stem has very low energy - poor separation for melodic content")
        elif other_energy / original_energy < 0.2:
            print(f" ⚠️ 'Other' stem has low energy - separation may not be ideal")

    print("\n" + "=" * 60)
    print("STAGE 3: TRANSCRIPTION (basic-pitch)")
    print("=" * 60)

    piano_midi = job_dir / "piano.mid"
    analyze_midi_file(piano_midi, "Raw MIDI Output")

    print("\n" + "=" * 60)
    print("STAGE 4: MIDI CLEANING")
    print("=" * 60)

    clean_midi = job_dir / "piano_clean.mid"
    analyze_midi_file(clean_midi, "Cleaned MIDI Output")

    # Compare raw vs cleaned note counts to gauge how aggressive cleaning was.
    if piano_midi.exists() and clean_midi.exists():
        raw_mid = mido.MidiFile(piano_midi)
        clean_mid = mido.MidiFile(clean_midi)

        raw_notes = sum(1 for track in raw_mid.tracks for msg in track if msg.type == 'note_on' and msg.velocity > 0)
        clean_notes = sum(1 for track in clean_mid.tracks for msg in track if msg.type == 'note_on' and msg.velocity > 0)

        removed_notes = raw_notes - clean_notes
        print(f"\n Cleaning Impact:")
        print(f" Notes removed: {removed_notes} ({removed_notes/raw_notes*100:.1f}%)")

        if removed_notes / raw_notes > 0.5:
            print(f" ⚠️ Removed >50% of notes - cleaning may be too aggressive")

    print("\n" + "=" * 60)
    print("DIAGNOSIS SUMMARY")
    print("=" * 60)

    # Provide recommendations based on analysis
    print("\nPotential Issues:")

    issues_found = False

    # Check 1: Source separation quality (low RMS in the 'other' stem
    # means the piano content likely leaked into other stems).
    if other_stem.exists():
        other_data, _ = sf.read(other_stem)
        other_rms = np.sqrt(np.mean(other_data**2))

        if other_rms < 0.05:
            print(" ⚠️ 'Other' stem is very quiet - Demucs may not be separating piano well")
            print(" → This is the most likely cause of poor transcription accuracy")
            print(" → The piano might be mixed with other instruments in different stems")
            issues_found = True

    # Check 2: Note density (too low = under-detection, too high = noise).
    if piano_midi.exists():
        mid = mido.MidiFile(piano_midi)
        note_count = sum(1 for track in mid.tracks for msg in track if msg.type == 'note_on' and msg.velocity > 0)
        density = note_count / mid.length

        if density < 2:
            print(" ⚠️ Very low note density - basic-pitch may be too conservative")
            print(" → Try decreasing onset-threshold and frame-threshold")
            issues_found = True
        elif density > 10:
            print(" ⚠️ Very high note density - basic-pitch may be too aggressive")
            print(" → Current thresholds might already be good; check if it's detecting noise")
            issues_found = True

    if not issues_found:
        print(" No obvious technical issues detected")
        print(" The problem may be:")
        print(" • Music is too complex for current models")
        print(" • Need better source separation (try different Demucs model)")
        print(" • basic-pitch limitations with this type of music")

    print("\n" + "=" * 60)
    print("RECOMMENDATIONS")
    print("=" * 60)

    print("""
 Next steps to improve accuracy:

 1. LISTEN to the separated stems:
 - Play 'other.wav' to verify piano is properly separated
 - If piano is barely audible, source separation failed

 2. Try different Demucs models:
 - Current: htdemucs with --two-stems=other
 - Try: htdemucs_6s (6-stem with dedicated piano separation)
 - Command: demucs --model htdemucs_6s audio.wav

 3. Test with simpler music:
 - Solo piano (no other instruments)
 - Clear, slow melodies
 - This helps isolate if issue is separation or transcription

 4. Compare with ground truth:
 - Find sheet music for the test song
 - Compare transcribed notes with actual notes
 - Identify patterns (missing high notes? wrong octaves?)

 5. Try alternative transcription models:
 - MT3 (Music Transformer) - slower but more accurate
 - Omnizart piano model - specialized for piano
 """)

    print("=" * 60)
    print("\nTo listen to the separated 'other' stem:")
    print(f" play {other_stem}")
    print(f" # or")
    print(f" ffplay {other_stem}")
    print("=" * 60)
295
+
296
+
297
if __name__ == "__main__":
    # CLI entry point: requires a job id; otherwise show usage and exit.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python scripts/diagnose_pipeline.py <job_id>")
        print("\nExample:")
        print(" python scripts/diagnose_pipeline.py test_e2e")
        print("\nFirst run test_e2e.py to create a job:")
        print(' python scripts/test_e2e.py "https://www.youtube.com/watch?v=VIDEO_ID"')
        sys.exit(1)

    diagnose_job(cli_args[0])
backend/scripts/test_accuracy.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Accuracy Testing Suite for Rescored Pipeline
4
+
5
+ Tests transcription accuracy on 10 diverse piano videos covering different styles and complexities.
6
+ """
7
+ import sys
8
+ from pathlib import Path
9
+ sys.path.insert(0, str(Path(__file__).parent.parent))
10
+
11
+ from pipeline import TranscriptionPipeline
12
+ from config import settings
13
+ import json
14
+ from datetime import datetime
15
+
16
+
17
+ # Test videos with varying complexity
18
# Each entry: 'id' (job-id suffix), 'url' (YouTube source), 'description',
# 'difficulty' (easy/medium/hard/very_hard), 'expected_accuracy' (informal
# target used in the printed report), and free-form 'notes'.
TEST_VIDEOS = [
    {
        "id": "simple_melody",
        "url": "https://www.youtube.com/watch?v=TK1Ij_-mank",
        "description": "Simple piano melody - C major scale practice",
        "difficulty": "easy",
        "expected_accuracy": ">80%",
        "notes": "Slow tempo, single notes, clear recording"
    },
    {
        "id": "twinkle_twinkle",
        "url": "https://www.youtube.com/watch?v=YCZ_d_4ZEqk",
        "description": "Twinkle Twinkle Little Star - Beginner piano",
        "difficulty": "easy",
        "expected_accuracy": ">75%",
        "notes": "Very simple melody, slow tempo"
    },
    {
        "id": "fur_elise",
        "url": "https://www.youtube.com/watch?v=_mVW8tgGY_w",
        "description": "Beethoven - Für Elise (simplified)",
        "difficulty": "medium",
        "expected_accuracy": "60-70%",
        "notes": "Classic piece, moderate tempo, some ornaments"
    },
    {
        "id": "chopin_nocturne",
        "url": "https://www.youtube.com/watch?v=9E6b3swbnWg",
        "description": "Chopin - Nocturne Op. 9 No. 2",
        "difficulty": "hard",
        "expected_accuracy": "50-60%",
        "notes": "Complex harmonies, expressive dynamics, rubato"
    },
    {
        "id": "canon_in_d",
        "url": "https://www.youtube.com/watch?v=NlprozGcs80",
        "description": "Pachelbel - Canon in D (piano arrangement)",
        "difficulty": "medium",
        "expected_accuracy": "60-70%",
        "notes": "Repetitive patterns, moderate polyphony"
    },
    {
        "id": "river_flows",
        "url": "https://www.youtube.com/watch?v=7maJOI3QMu0",
        "description": "Yiruma - River Flows in You",
        "difficulty": "medium",
        "expected_accuracy": "60-70%",
        "notes": "Modern piano, flowing arpeggios"
    },
    {
        "id": "moonlight_sonata",
        "url": "https://www.youtube.com/watch?v=4Tr0otuiQuU",
        "description": "Beethoven - Moonlight Sonata (1st movement)",
        "difficulty": "medium",
        "expected_accuracy": "60-70%",
        "notes": "Slow tempo, triplet arpeggios, bass notes"
    },
    {
        "id": "jazz_blues",
        "url": "https://www.youtube.com/watch?v=F3W_alUuFkA",
        "description": "Simple jazz blues piano",
        "difficulty": "medium",
        "expected_accuracy": "55-65%",
        "notes": "Swing rhythm, blue notes, syncopation"
    },
    {
        "id": "claire_de_lune",
        "url": "https://www.youtube.com/watch?v=WNcsUNKlAKw",
        "description": "Debussy - Clair de Lune",
        "difficulty": "hard",
        "expected_accuracy": "50-60%",
        "notes": "Impressionist harmony, complex textures"
    },
    {
        "id": "la_campanella",
        "url": "https://www.youtube.com/watch?v=MD6xMyuZls0",
        "description": "Liszt - La Campanella",
        "difficulty": "very_hard",
        "expected_accuracy": "40-50%",
        "notes": "Virtuosic, extremely fast, wide range, many notes"
    }
]
100
+
101
+
102
+ def run_accuracy_test(video, verbose=True):
103
+ """
104
+ Run transcription pipeline on a test video and collect metrics.
105
+
106
+ Args:
107
+ video: Dictionary with video metadata
108
+ verbose: Print progress messages
109
+
110
+ Returns:
111
+ Dictionary with test results and metrics
112
+ """
113
+ if verbose:
114
+ print(f"\n{'='*70}")
115
+ print(f"Testing: {video['description']}")
116
+ print(f"Difficulty: {video['difficulty']} | Expected: {video['expected_accuracy']}")
117
+ print(f"{'='*70}")
118
+
119
+ job_id = f"accuracy_test_{video['id']}"
120
+ storage_path = Path(settings.storage_path)
121
+
122
+ # Progress callback
123
+ def progress_callback(percent, stage, message):
124
+ if verbose:
125
+ print(f"[{percent:3d}%] {stage:12s} | {message}")
126
+
127
+ result = {
128
+ "video_id": video["id"],
129
+ "description": video["description"],
130
+ "difficulty": video["difficulty"],
131
+ "url": video["url"],
132
+ "timestamp": datetime.utcnow().isoformat(),
133
+ "success": False,
134
+ "error": None,
135
+ "metrics": {}
136
+ }
137
+
138
+ try:
139
+ # Run pipeline
140
+ pipeline = TranscriptionPipeline(job_id, video["url"], storage_path)
141
+ pipeline.set_progress_callback(progress_callback)
142
+
143
+ musicxml_path = pipeline.run()
144
+
145
+ # Get intermediate file paths for analysis
146
+ temp_dir = pipeline.temp_dir
147
+ original_audio = temp_dir / "audio.wav"
148
+ other_stem = temp_dir / "htdemucs" / job_id / "other.wav"
149
+ midi_path = temp_dir / "other_basic_pitch.mid"
150
+ clean_midi = temp_dir / "piano_clean.mid"
151
+
152
+ # Collect metrics
153
+ import soundfile as sf
154
+ import mido
155
+
156
+ # Audio metrics
157
+ if original_audio.exists():
158
+ audio_data, sr = sf.read(original_audio)
159
+ result["metrics"]["audio_duration_seconds"] = len(audio_data) / sr
160
+
161
+ # Separation quality (simple energy ratio)
162
+ if original_audio.exists() and other_stem.exists():
163
+ import numpy as np
164
+ original_data, _ = sf.read(original_audio)
165
+ other_data, _ = sf.read(other_stem)
166
+
167
+ original_energy = np.sum(original_data ** 2)
168
+ other_energy = np.sum(other_data ** 2)
169
+
170
+ result["metrics"]["separation"] = {
171
+ "other_energy_ratio": other_energy / original_energy if original_energy > 0 else 0
172
+ }
173
+
174
+ # MIDI analysis (simple note count)
175
+ if clean_midi.exists():
176
+ mid = mido.MidiFile(clean_midi)
177
+ note_count = sum(1 for track in mid.tracks for msg in track if msg.type == 'note_on')
178
+
179
+ result["metrics"]["midi"] = {
180
+ "total_notes": note_count,
181
+ "duration_seconds": mid.length
182
+ }
183
+
184
+ # MusicXML analysis (measure count, etc)
185
+ if musicxml_path.exists():
186
+ from music21 import converter
187
+ score = converter.parse(musicxml_path)
188
+ measures = score.parts[0].getElementsByClass('Measure') if score.parts else []
189
+
190
+ result["metrics"]["musicxml"] = {
191
+ "total_measures": len(measures),
192
+ "file_size_kb": musicxml_path.stat().st_size / 1024
193
+ }
194
+
195
+ result["success"] = True
196
+ result["output_files"] = {
197
+ "musicxml": str(musicxml_path),
198
+ "midi": str(clean_midi),
199
+ "temp_dir": str(temp_dir)
200
+ }
201
+
202
+ if verbose:
203
+ print(f"\n✅ SUCCESS - Output: {musicxml_path}")
204
+ print(f" MIDI notes: {result['metrics']['midi']['total_notes']}")
205
+ print(f" Measures: {result['metrics']['musicxml']['total_measures']}")
206
+ if 'separation' in result['metrics']:
207
+ sep = result['metrics']['separation']
208
+ print(f" Separation: {sep['other_energy_ratio']:.1%} energy in 'other' stem")
209
+
210
+ except Exception as e:
211
+ result["error"] = str(e)
212
+ if verbose:
213
+ print(f"\n❌ FAILED - Error: {e}")
214
+
215
+ return result
216
+
217
+
218
+ def main():
219
+ """Run accuracy tests on all test videos."""
220
+ print("="*70)
221
+ print("Rescored Accuracy Testing Suite")
222
+ print("="*70)
223
+ print(f"Testing {len(TEST_VIDEOS)} videos with varying difficulty")
224
+ print(f"Storage: {settings.storage_path}")
225
+ print()
226
+
227
+ # Run tests
228
+ results = []
229
+ for i, video in enumerate(TEST_VIDEOS, 1):
230
+ print(f"\n[{i}/{len(TEST_VIDEOS)}] Starting test: {video['id']}")
231
+ result = run_accuracy_test(video, verbose=True)
232
+ results.append(result)
233
+
234
+ # Summary
235
+ print("\n" + "="*70)
236
+ print("ACCURACY TEST SUMMARY")
237
+ print("="*70)
238
+
239
+ successful = [r for r in results if r["success"]]
240
+ failed = [r for r in results if not r["success"]]
241
+
242
+ print(f"\nTotal: {len(results)} | Success: {len(successful)} | Failed: {len(failed)}")
243
+ print(f"Success Rate: {len(successful)/len(results)*100:.1f}%")
244
+
245
+ if successful:
246
+ print("\n✅ Successful Transcriptions:")
247
+ for r in successful:
248
+ midi_notes = r["metrics"]["midi"]["total_notes"]
249
+ measures = r["metrics"]["musicxml"]["total_measures"]
250
+ print(f" - {r['video_id']:20s} | {midi_notes:4d} notes | {measures:3d} measures | {r['difficulty']}")
251
+
252
+ if failed:
253
+ print("\n❌ Failed Transcriptions:")
254
+ for r in failed:
255
+ print(f" - {r['video_id']:20s} | Error: {r['error'][:60]}")
256
+
257
+ # Save results to JSON
258
+ output_path = Path(settings.storage_path) / "accuracy_test_results.json"
259
+ output_path.parent.mkdir(parents=True, exist_ok=True)
260
+
261
+ with open(output_path, 'w') as f:
262
+ json.dump({
263
+ "test_date": datetime.utcnow().isoformat(),
264
+ "total_tests": len(results),
265
+ "successful": len(successful),
266
+ "failed": len(failed),
267
+ "success_rate": len(successful) / len(results),
268
+ "results": results
269
+ }, f, indent=2)
270
+
271
+ print(f"\n📊 Full results saved to: {output_path}")
272
+
273
+ return 0 if not failed else 1
274
+
275
+
276
+ if __name__ == "__main__":
277
+ sys.exit(main())
backend/scripts/test_demucs_models.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test different Demucs models to find the best source separation.
4
+
5
+ Usage (from backend directory):
6
+ python scripts/test_demucs_models.py <audio_path>
7
+
8
+ Example:
9
+ python scripts/test_demucs_models.py /tmp/rescored/temp/test_e2e/audio.wav
10
+ """
11
+ import sys
12
+ from pathlib import Path
13
+ import subprocess
14
+ import soundfile as sf
15
+ import numpy as np
16
+ import tempfile
17
+ import shutil
18
+
19
+
20
+ def test_demucs_model(audio_path: Path, model_name: str, stems: str = None):
21
+ """Test a specific Demucs model."""
22
+ print(f"\n{'='*60}")
23
+ print(f"Testing: {model_name}")
24
+ print(f"{'='*60}")
25
+
26
+ # Create temp directory for this test
27
+ with tempfile.TemporaryDirectory() as temp_dir:
28
+ temp_path = Path(temp_dir)
29
+
30
+ # Build command
31
+ cmd = ["demucs", "--model", model_name, "-o", str(temp_path), str(audio_path)]
32
+
33
+ if stems:
34
+ cmd.extend(["--two-stems", stems])
35
+
36
+ print(f"Command: {' '.join(cmd)}")
37
+ print("Running... (this may take a minute)")
38
+
39
+ # Run Demucs
40
+ try:
41
+ result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
42
+
43
+ if result.returncode != 0:
44
+ print(f"❌ Failed: {result.stderr[:500]}")
45
+ return None
46
+
47
+ # Find output directory
48
+ model_output_dir = temp_path / model_name / audio_path.stem
49
+
50
+ if not model_output_dir.exists():
51
+ print(f"❌ Output directory not found: {model_output_dir}")
52
+ return None
53
+
54
+ # Analyze stems
55
+ print("\nStem Analysis:")
56
+ original_data, sr = sf.read(audio_path)
57
+ original_energy = np.sum(original_data**2)
58
+
59
+ stem_energies = {}
60
+
61
+ for stem_file in sorted(model_output_dir.glob("*.wav")):
62
+ stem_name = stem_file.stem
63
+ stem_data, _ = sf.read(stem_file)
64
+ stem_energy = np.sum(stem_data**2)
65
+ stem_rms = np.sqrt(np.mean(stem_data**2))
66
+
67
+ percentage = (stem_energy / original_energy) * 100
68
+ stem_energies[stem_name] = (stem_energy, stem_rms, percentage)
69
+
70
+ print(f" {stem_name:15s}: {percentage:5.1f}% energy, RMS: {stem_rms:.3f}")
71
+
72
+ # Find best stem for piano/melodic content
73
+ # Usually 'other', 'piano', or 'other' in 2-stem
74
+ print("\nBest stem for piano:")
75
+
76
+ if 'piano' in stem_energies:
77
+ best_stem = 'piano'
78
+ print(f" ✓ Dedicated 'piano' stem found")
79
+ elif 'other' in stem_energies:
80
+ best_stem = 'other'
81
+ print(f" ✓ Using 'other' stem")
82
+ else:
83
+ # Find stem with most energy
84
+ best_stem = max(stem_energies.items(), key=lambda x: x[1][0])[0]
85
+ print(f" → Using '{best_stem}' (highest energy)")
86
+
87
+ energy, rms, percentage = stem_energies[best_stem]
88
+ print(f" Energy: {percentage:.1f}%, RMS: {rms:.3f}")
89
+
90
+ if percentage < 15:
91
+ print(f" ⚠️ Very low energy - may not work well")
92
+ elif percentage < 25:
93
+ print(f" ⚠️ Low energy - borderline")
94
+ else:
95
+ print(f" ✓ Good energy level")
96
+
97
+ return {
98
+ 'model': model_name,
99
+ 'best_stem': best_stem,
100
+ 'energy_percentage': percentage,
101
+ 'rms': rms,
102
+ 'all_stems': stem_energies
103
+ }
104
+
105
+ except subprocess.TimeoutExpired:
106
+ print(f"❌ Timeout after 5 minutes")
107
+ return None
108
+ except Exception as e:
109
+ print(f"❌ Error: {e}")
110
+ return None
111
+
112
+
113
+ def main():
114
+ if len(sys.argv) < 2:
115
+ print("Usage: python scripts/test_demucs_models.py <audio_path>")
116
+ print("\nExample:")
117
+ print(" python scripts/test_demucs_models.py /tmp/rescored/temp/test_e2e/audio.wav")
118
+ sys.exit(1)
119
+
120
+ audio_path = Path(sys.argv[1])
121
+
122
+ if not audio_path.exists():
123
+ print(f"Error: Audio file not found: {audio_path}")
124
+ sys.exit(1)
125
+
126
+ print("=" * 60)
127
+ print("DEMUCS MODEL COMPARISON")
128
+ print("=" * 60)
129
+ print(f"Audio file: {audio_path}")
130
+ print(f"Duration: ~{sf.info(audio_path).duration:.1f}s")
131
+
132
+ # Test different models
133
+ results = []
134
+
135
+ # Test 1: Current model (htdemucs 2-stem)
136
+ print("\n\n" + "="*60)
137
+ print("TEST 1: htdemucs (2-stem: other)")
138
+ print("="*60)
139
+ result = test_demucs_model(audio_path, "htdemucs", stems="other")
140
+ if result:
141
+ results.append(result)
142
+
143
+ # Test 2: htdemucs_6s (6-stem with dedicated piano)
144
+ print("\n\n" + "="*60)
145
+ print("TEST 2: htdemucs_6s (6-stem with piano)")
146
+ print("="*60)
147
+ result = test_demucs_model(audio_path, "htdemucs_6s")
148
+ if result:
149
+ results.append(result)
150
+
151
+ # Test 3: htdemucs full 4-stem
152
+ print("\n\n" + "="*60)
153
+ print("TEST 3: htdemucs (4-stem)")
154
+ print("="*60)
155
+ result = test_demucs_model(audio_path, "htdemucs")
156
+ if result:
157
+ results.append(result)
158
+
159
+ # Summary
160
+ print("\n\n" + "="*60)
161
+ print("SUMMARY & RECOMMENDATIONS")
162
+ print("="*60)
163
+
164
+ if not results:
165
+ print("No successful tests!")
166
+ sys.exit(1)
167
+
168
+ # Sort by energy percentage
169
+ results.sort(key=lambda x: x['energy_percentage'], reverse=True)
170
+
171
+ print("\nRanking (by piano/melodic energy):")
172
+ for i, result in enumerate(results, 1):
173
+ print(f"{i}. {result['model']:20s} - {result['best_stem']:10s} - "
174
+ f"{result['energy_percentage']:5.1f}% energy, RMS: {result['rms']:.3f}")
175
+
176
+ best_result = results[0]
177
+ print(f"\n✓ RECOMMENDED: Use {best_result['model']} with '{best_result['best_stem']}' stem")
178
+
179
+ if best_result['energy_percentage'] < 20:
180
+ print("\n⚠️ WARNING: Even the best model has low energy (<20%)")
181
+ print(" This suggests:")
182
+ print(" - The audio may not have much piano/melodic content")
183
+ print(" - The piano may be heavily mixed with other instruments")
184
+ print(" - You may need to try a different test video")
185
+
186
+ print("\nTo update pipeline.py:")
187
+ if best_result['model'] == 'htdemucs_6s':
188
+ print(f" 1. Change line ~98: --two-stems=other → remove this flag")
189
+ print(f" 2. Change line ~96: demucs_output / 'htdemucs_6s' / audio_path.stem")
190
+ print(f" 3. Use stem: {best_result['best_stem']}.wav")
191
+ elif best_result['model'] == 'htdemucs' and '--two-stems' not in str(best_result):
192
+ print(f" 1. Change line ~98: --two-stems=other → remove this flag")
193
+ print(f" 2. Use stem: {best_result['best_stem']}.wav")
194
+
195
+ print("\n" + "="*60)
196
+
197
+
198
+ if __name__ == "__main__":
199
+ main()
backend/scripts/test_e2e.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ End-to-end test script for the transcription pipeline.
4
+
5
+ Usage (from backend directory):
6
+ python scripts/test_e2e.py <youtube_url>
7
+
8
+ Example:
9
+ python scripts/test_e2e.py "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
10
+ """
11
+ import sys
12
+ from pathlib import Path
13
+
14
+ # Add parent directory to path for imports
15
+ sys.path.insert(0, str(Path(__file__).parent.parent))
16
+
17
+ from pipeline import TranscriptionPipeline
18
+ from config import settings
19
+ import time
20
+
21
+
22
+ def progress_callback(percent: int, stage: str, message: str):
23
+ """Print progress updates."""
24
+ print(f"[{percent:3d}%] {stage:12s} | {message}")
25
+
26
+
27
+ def main():
28
+ if len(sys.argv) < 2:
29
+ print("Usage: python test_e2e.py <youtube_url>")
30
+ print("\nExample simple piano videos to test:")
31
+ print("1. Twinkle Twinkle: https://www.youtube.com/watch?v=WyTb3DTu88c")
32
+ print("2. Simple melody: https://www.youtube.com/watch?v=fJ9rUzIMcZQ")
33
+ sys.exit(1)
34
+
35
+ youtube_url = sys.argv[1]
36
+ job_id = "test_e2e"
37
+ storage_path = Path(settings.storage_path)
38
+
39
+ print("=" * 60)
40
+ print("Rescored End-to-End Pipeline Test")
41
+ print("=" * 60)
42
+ print(f"YouTube URL: {youtube_url}")
43
+ print(f"Job ID: {job_id}")
44
+ print(f"Storage: {storage_path}")
45
+ print("=" * 60)
46
+ print()
47
+
48
+ # Create pipeline
49
+ pipeline = TranscriptionPipeline(job_id, youtube_url, storage_path)
50
+ pipeline.set_progress_callback(progress_callback)
51
+
52
+ # Run pipeline
53
+ try:
54
+ start_time = time.time()
55
+ musicxml_path = pipeline.run()
56
+ elapsed_time = time.time() - start_time
57
+
58
+ print()
59
+ print("=" * 60)
60
+ print("SUCCESS!")
61
+ print("=" * 60)
62
+ print(f"Total time: {elapsed_time:.1f} seconds")
63
+ print(f"MusicXML file: {musicxml_path}")
64
+ print(f"File size: {musicxml_path.stat().st_size / 1024:.1f} KB")
65
+ print()
66
+
67
+ # Show temp directory contents
68
+ print("Intermediate files:")
69
+ temp_dir = storage_path / "temp" / job_id
70
+ for file in sorted(temp_dir.rglob("*")):
71
+ if file.is_file():
72
+ size_kb = file.stat().st_size / 1024
73
+ rel_path = file.relative_to(temp_dir)
74
+ print(f" {rel_path} ({size_kb:.1f} KB)")
75
+ print()
76
+
77
+ # Preview MusicXML
78
+ print("MusicXML preview (first 50 lines):")
79
+ print("-" * 60)
80
+ with open(musicxml_path, 'r') as f:
81
+ for i, line in enumerate(f):
82
+ if i >= 50:
83
+ print("... (truncated)")
84
+ break
85
+ print(line.rstrip())
86
+ print("-" * 60)
87
+ print()
88
+
89
+ print("Next steps:")
90
+ print(f"1. Open in MuseScore: musescore {musicxml_path}")
91
+ print(f"2. Inspect MIDI: timidity {temp_dir}/piano_clean.mid")
92
+ print(f"3. Review temp files: ls -lh {temp_dir}")
93
+
94
+ except Exception as e:
95
+ print()
96
+ print("=" * 60)
97
+ print("FAILED!")
98
+ print("=" * 60)
99
+ print(f"Error: {e}")
100
+ import traceback
101
+ traceback.print_exc()
102
+ sys.exit(1)
103
+
104
+
105
+ if __name__ == "__main__":
106
+ main()
backend/scripts/test_quick_verify.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick verification test - only runs the 6 videos that had code bugs (now fixed).
4
+
5
+ This is faster than the full suite and verifies our bug fixes work.
6
+ """
7
+ import sys
8
+ from pathlib import Path
9
+ sys.path.insert(0, str(Path(__file__).parent.parent))
10
+
11
+ from test_accuracy import run_accuracy_test
12
+ import json
13
+ from datetime import datetime
14
+
15
+ # Only test the 6 videos that had code bugs (should all pass now)
16
+ QUICK_TEST_VIDEOS = [
17
+ {
18
+ "id": "chopin_nocturne",
19
+ "url": "https://www.youtube.com/watch?v=9E6b3swbnWg",
20
+ "description": "Chopin - Nocturne Op. 9 No. 2",
21
+ "difficulty": "hard",
22
+ "expected_accuracy": "50-60%",
23
+ "notes": "2048th note duration (Bug #2b)",
24
+ "bug": "2048th note duration (Bug #2b)"
25
+ },
26
+ {
27
+ "id": "canon_in_d",
28
+ "url": "https://www.youtube.com/watch?v=NlprozGcs80",
29
+ "description": "Pachelbel - Canon in D",
30
+ "difficulty": "medium",
31
+ "expected_accuracy": "60-70%",
32
+ "notes": "NoneType velocity (Bug #2a)",
33
+ "bug": "NoneType velocity (Bug #2a)"
34
+ },
35
+ {
36
+ "id": "river_flows",
37
+ "url": "https://www.youtube.com/watch?v=7maJOI3QMu0",
38
+ "description": "Yiruma - River Flows in You",
39
+ "difficulty": "medium",
40
+ "expected_accuracy": "60-70%",
41
+ "notes": "NoneType velocity (Bug #2a)",
42
+ "bug": "NoneType velocity (Bug #2a)"
43
+ },
44
+ {
45
+ "id": "moonlight_sonata",
46
+ "url": "https://www.youtube.com/watch?v=4Tr0otuiQuU",
47
+ "description": "Beethoven - Moonlight Sonata",
48
+ "difficulty": "medium",
49
+ "expected_accuracy": "60-70%",
50
+ "notes": "NoneType velocity (Bug #2a)",
51
+ "bug": "NoneType velocity (Bug #2a)"
52
+ },
53
+ {
54
+ "id": "claire_de_lune",
55
+ "url": "https://www.youtube.com/watch?v=WNcsUNKlAKw",
56
+ "description": "Debussy - Clair de Lune",
57
+ "difficulty": "hard",
58
+ "expected_accuracy": "50-60%",
59
+ "notes": "2048th note duration (Bug #2b)",
60
+ "bug": "2048th note duration (Bug #2b)"
61
+ },
62
+ {
63
+ "id": "la_campanella",
64
+ "url": "https://www.youtube.com/watch?v=MD6xMyuZls0",
65
+ "description": "Liszt - La Campanella",
66
+ "difficulty": "very_hard",
67
+ "expected_accuracy": "40-50%",
68
+ "notes": "NoneType velocity (Bug #2a)",
69
+ "bug": "NoneType velocity (Bug #2a)"
70
+ }
71
+ ]
72
+
73
+ def main():
74
+ """Run quick verification tests."""
75
+ print("="*70)
76
+ print("Quick Verification Test - Bug Fixes")
77
+ print("="*70)
78
+ print(f"Testing {len(QUICK_TEST_VIDEOS)} videos that previously failed")
79
+ print("All should now succeed (verifies bug fixes)")
80
+ print()
81
+
82
+ results = []
83
+ for i, video in enumerate(QUICK_TEST_VIDEOS, 1):
84
+ print(f"\n[{i}/{len(QUICK_TEST_VIDEOS)}] Testing: {video['id']}")
85
+ print(f"Previous error: {video['bug']}")
86
+
87
+ result = run_accuracy_test(video, verbose=True)
88
+ results.append(result)
89
+
90
+ # Summary
91
+ print("\n" + "="*70)
92
+ print("QUICK VERIFICATION SUMMARY")
93
+ print("="*70)
94
+
95
+ successful = [r for r in results if r["success"]]
96
+ failed = [r for r in results if not r["success"]]
97
+
98
+ print(f"\nTotal: {len(results)} | Success: {len(successful)} | Failed: {len(failed)}")
99
+ print(f"Success Rate: {len(successful)/len(results)*100:.1f}%")
100
+
101
+ if successful:
102
+ print("\n✅ Bug Fixes Verified - Successful Transcriptions:")
103
+ for r in successful:
104
+ if "midi" in r["metrics"] and "musicxml" in r["metrics"]:
105
+ notes = r["metrics"]["midi"]["total_notes"]
106
+ measures = r["metrics"]["musicxml"]["total_measures"]
107
+ print(f" - {r['video_id']:20s} | {notes:4d} notes | {measures:3d} measures")
108
+
109
+ if failed:
110
+ print("\n❌ Still Failing:")
111
+ for r in failed:
112
+ error_preview = r["error"][:80] if r["error"] else "Unknown"
113
+ print(f" - {r['video_id']:20s} | {error_preview}")
114
+
115
+ # Save results
116
+ from config import settings
117
+ output_path = Path(settings.storage_path) / "quick_verify_results.json"
118
+ output_path.parent.mkdir(parents=True, exist_ok=True)
119
+
120
+ with open(output_path, 'w') as f:
121
+ json.dump({
122
+ "test_date": datetime.utcnow().isoformat(),
123
+ "test_type": "bug_fix_verification",
124
+ "total_tests": len(results),
125
+ "successful": len(successful),
126
+ "failed": len(failed),
127
+ "success_rate": len(successful) / len(results),
128
+ "results": results
129
+ }, f, indent=2)
130
+
131
+ print(f"\n📊 Results saved to: {output_path}")
132
+
133
+ if len(successful) == len(results):
134
+ print("\n🎉 ALL BUG FIXES VERIFIED! Ready for full test suite.")
135
+ return 0
136
+ else:
137
+ print(f"\n⚠️ {len(failed)} test(s) still failing - investigate before full suite")
138
+ return 1
139
+
140
+
141
+ if __name__ == "__main__":
142
+ sys.exit(main())
backend/tasks.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Celery tasks for background job processing."""
2
+ from celery import Task
3
+ from celery_app import celery_app
4
+ from pipeline import TranscriptionPipeline, run_transcription_pipeline
5
+ import redis
6
+ import json
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+ from config import settings
10
+ import shutil
11
+
12
+ # Redis client
13
+ redis_client = redis.Redis.from_url(settings.redis_url, decode_responses=True)
14
+
15
+
16
+ class TranscriptionTask(Task):
17
+ """Base task with progress tracking."""
18
+
19
+ def update_progress(self, job_id: str, progress: int, stage: str, message: str):
20
+ """
21
+ Update job progress in Redis and publish to WebSocket subscribers.
22
+
23
+ Args:
24
+ job_id: Job identifier
25
+ progress: Progress percentage (0-100)
26
+ stage: Current stage name
27
+ message: Status message
28
+ """
29
+ job_key = f"job:{job_id}"
30
+
31
+ # Update Redis hash
32
+ redis_client.hset(job_key, mapping={
33
+ "progress": progress,
34
+ "current_stage": stage,
35
+ "status_message": message,
36
+ "updated_at": datetime.utcnow().isoformat(),
37
+ })
38
+
39
+ # Publish to pub/sub for WebSocket clients
40
+ update = {
41
+ "type": "progress",
42
+ "job_id": job_id,
43
+ "progress": progress,
44
+ "stage": stage,
45
+ "message": message,
46
+ "timestamp": datetime.utcnow().isoformat(),
47
+ }
48
+ redis_client.publish(f"job:{job_id}:updates", json.dumps(update))
49
+
50
+
51
+ @celery_app.task(base=TranscriptionTask, bind=True)
52
+ def process_transcription_task(self, job_id: str):
53
+ """
54
+ Main transcription task.
55
+
56
+ Args:
57
+ job_id: Unique job identifier
58
+
59
+ Returns:
60
+ Path to generated MusicXML file
61
+ """
62
+ try:
63
+ # Mark job as started
64
+ redis_client.hset(f"job:{job_id}", mapping={
65
+ "status": "processing",
66
+ "started_at": datetime.utcnow().isoformat(),
67
+ })
68
+
69
+ # Get job data
70
+ job_data = redis_client.hgetall(f"job:{job_id}")
71
+
72
+ if not job_data:
73
+ raise ValueError(f"Job not found: {job_id}")
74
+
75
+ youtube_url = job_data.get('youtube_url')
76
+ if not youtube_url:
77
+ raise ValueError(f"Job missing youtube_url: {job_id}")
78
+
79
+ # Initialize pipeline
80
+ pipeline = TranscriptionPipeline(
81
+ job_id=job_id,
82
+ youtube_url=youtube_url,
83
+ storage_path=settings.storage_path
84
+ )
85
+ pipeline.set_progress_callback(lambda p, s, m: self.update_progress(job_id, p, s, m))
86
+
87
+ # Run pipeline
88
+ temp_output_path = pipeline.run()
89
+
90
+ # Output is already in the temp directory, move to persistent storage
91
+ output_path = settings.outputs_path / f"{job_id}.musicxml"
92
+ midi_path = settings.outputs_path / f"{job_id}.mid"
93
+
94
+ # Ensure outputs directory exists
95
+ settings.outputs_path.mkdir(parents=True, exist_ok=True)
96
+
97
+ # Copy the MusicXML file to outputs
98
+ shutil.copy(str(temp_output_path), str(output_path))
99
+
100
+ # Copy the cleaned MIDI file to outputs
101
+ temp_midi_path = pipeline.temp_dir / "piano_clean.mid"
102
+ if temp_midi_path.exists():
103
+ shutil.copy(str(temp_midi_path), str(midi_path))
104
+
105
+ # Cleanup temp files (pipeline has its own cleanup method)
106
+ pipeline.cleanup()
107
+
108
+ # Mark job as completed
109
+ redis_client.hset(f"job:{job_id}", mapping={
110
+ "status": "completed",
111
+ "progress": 100,
112
+ "output_path": str(output_path),
113
+ "midi_path": str(midi_path) if temp_midi_path.exists() else "",
114
+ "completed_at": datetime.utcnow().isoformat(),
115
+ })
116
+
117
+ # Publish completion message
118
+ completion_msg = {
119
+ "type": "completed",
120
+ "job_id": job_id,
121
+ "result_url": f"/api/v1/scores/{job_id}",
122
+ "timestamp": datetime.utcnow().isoformat(),
123
+ }
124
+ redis_client.publish(f"job:{job_id}:updates", json.dumps(completion_msg))
125
+
126
+ return str(output_path)
127
+
128
+ except Exception as e:
129
+ # Mark job as failed
130
+ redis_client.hset(f"job:{job_id}", mapping={
131
+ "status": "failed",
132
+ "error": json.dumps({
133
+ "message": str(e),
134
+ "retryable": self.request.retries < self.max_retries,
135
+ }),
136
+ "failed_at": datetime.utcnow().isoformat(),
137
+ })
138
+
139
+ # Publish error message
140
+ error_msg = {
141
+ "type": "error",
142
+ "job_id": job_id,
143
+ "error": {
144
+ "message": str(e),
145
+ "retryable": self.request.retries < self.max_retries,
146
+ },
147
+ "timestamp": datetime.utcnow().isoformat(),
148
+ }
149
+ redis_client.publish(f"job:{job_id}:updates", json.dumps(error_msg))
150
+
151
+ # Retry if retryable
152
+ if self.request.retries < self.max_retries:
153
+ raise self.retry(exc=e, countdown=2 ** self.request.retries)
154
+ else:
155
+ raise
156
+
157
+
158
+ # === Module-level helper functions for backward compatibility ===
159
+
160
+ def update_progress(job_id: str, progress: int, stage: str, message: str):
161
+ """
162
+ Update job progress in Redis and publish to WebSocket subscribers.
163
+
164
+ Args:
165
+ job_id: Job identifier
166
+ progress: Progress percentage (0-100)
167
+ stage: Current stage name
168
+ message: Status message
169
+ """
170
+ job_key = f"job:{job_id}"
171
+
172
+ # Update Redis hash
173
+ redis_client.hset(job_key, mapping={
174
+ "progress": progress,
175
+ "current_stage": stage,
176
+ "status_message": message,
177
+ "updated_at": datetime.utcnow().isoformat(),
178
+ })
179
+
180
+ # Publish to pub/sub for WebSocket clients
181
+ update = {
182
+ "type": "progress",
183
+ "job_id": job_id,
184
+ "progress": progress,
185
+ "stage": stage,
186
+ "message": message,
187
+ "timestamp": datetime.utcnow().isoformat(),
188
+ }
189
+ redis_client.publish(f"job:{job_id}:updates", json.dumps(update))
190
+
191
+
192
+ def cleanup_temp_files(job_id: str, storage_path: Path = None):
193
+ """
194
+ Clean up temporary files for a job.
195
+
196
+ Args:
197
+ job_id: Job identifier
198
+ storage_path: Path to storage directory (uses settings if not provided)
199
+ """
200
+ if storage_path is None:
201
+ storage_path = settings.storage_path
202
+
203
+ temp_dir = storage_path / "temp" / job_id
204
+ if temp_dir.exists():
205
+ shutil.rmtree(temp_dir, ignore_errors=True)
backend/tests/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Test suite for Rescored backend."""
backend/tests/conftest.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pytest configuration and fixtures for backend tests."""
2
+ import pytest
3
+ from pathlib import Path
4
+ import tempfile
5
+ import shutil
6
+ from fastapi.testclient import TestClient
7
+ from redis import Redis
8
+ from unittest.mock import MagicMock, patch
9
+ import uuid
10
+
11
+
12
+ @pytest.fixture
13
+ def temp_storage_dir():
14
+ """Create temporary storage directory for tests."""
15
+ temp_dir = tempfile.mkdtemp()
16
+ yield Path(temp_dir)
17
+ shutil.rmtree(temp_dir, ignore_errors=True)
18
+
19
+
20
+ @pytest.fixture
21
+ def mock_redis():
22
+ """Mock Redis client for testing."""
23
+ redis_mock = MagicMock(spec=Redis)
24
+ redis_mock.ping.return_value = True
25
+ redis_mock.hgetall.return_value = {}
26
+ redis_mock.hset.return_value = True
27
+ redis_mock.pubsub.return_value.subscribe.return_value = None
28
+ return redis_mock
29
+
30
+
31
+ @pytest.fixture
32
+ def test_client(mock_redis, temp_storage_dir):
33
+ """Create FastAPI test client with mocked dependencies."""
34
+ with patch('main.redis_client', mock_redis):
35
+ with patch('config.settings.storage_path', temp_storage_dir):
36
+ from main import app
37
+ client = TestClient(app)
38
+ yield client
39
+
40
+
41
+ @pytest.fixture
42
+ def sample_job_id():
43
+ """Generate a sample job ID for testing."""
44
+ return str(uuid.uuid4())
45
+
46
+
47
+ @pytest.fixture
48
+ def sample_job_data(sample_job_id):
49
+ """Sample job data for testing."""
50
+ return {
51
+ "job_id": sample_job_id,
52
+ "status": "queued",
53
+ "youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
54
+ "video_id": "dQw4w9WgXcQ",
55
+ "options": '{"instruments": ["piano"]}',
56
+ "created_at": "2025-01-01T00:00:00",
57
+ "progress": 0,
58
+ "current_stage": "queued",
59
+ "status_message": "Job queued for processing",
60
+ }
61
+
62
+
63
+ @pytest.fixture
64
+ def sample_youtube_urls():
65
+ """Collection of sample YouTube URLs for testing."""
66
+ return {
67
+ "valid": [
68
+ "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
69
+ "https://youtu.be/dQw4w9WgXcQ",
70
+ "https://m.youtube.com/watch?v=dQw4w9WgXcQ",
71
+ "https://www.youtube.com/embed/dQw4w9WgXcQ",
72
+ ],
73
+ "invalid": [
74
+ "https://example.com/video",
75
+ "not-a-url",
76
+ "https://vimeo.com/12345",
77
+ "https://youtube.com/invalid",
78
+ ]
79
+ }
80
+
81
+
82
+ @pytest.fixture
83
+ def mock_yt_dlp_info():
84
+ """Mock yt-dlp video info."""
85
+ return {
86
+ 'id': 'dQw4w9WgXcQ',
87
+ 'title': 'Test Video',
88
+ 'duration': 180, # 3 minutes
89
+ 'age_limit': 0,
90
+ 'formats': [
91
+ {'format_id': '140', 'ext': 'wav', 'abr': 128}
92
+ ]
93
+ }
94
+
95
+
96
+ @pytest.fixture
97
+ def sample_audio_file(temp_storage_dir):
98
+ """Create a sample WAV file for testing."""
99
+ import numpy as np
100
+ import soundfile as sf
101
+
102
+ # Generate 1 second of silence at 44.1kHz
103
+ sample_rate = 44100
104
+ duration = 1.0
105
+ samples = np.zeros(int(sample_rate * duration), dtype=np.float32)
106
+
107
+ audio_path = temp_storage_dir / "test_audio.wav"
108
+ sf.write(str(audio_path), samples, sample_rate)
109
+
110
+ return audio_path
111
+
112
+
113
+ @pytest.fixture
114
+ def sample_midi_file(temp_storage_dir):
115
+ """Create a sample MIDI file for testing."""
116
+ import mido
117
+
118
+ mid = mido.MidiFile()
119
+ track = mido.MidiTrack()
120
+ mid.tracks.append(track)
121
+
122
+ # Add some notes (middle C for 1 beat)
123
+ track.append(mido.Message('note_on', note=60, velocity=64, time=0))
124
+ track.append(mido.Message('note_off', note=60, velocity=64, time=480))
125
+
126
+ midi_path = temp_storage_dir / "test_midi.mid"
127
+ mid.save(str(midi_path))
128
+
129
+ return midi_path
130
+
131
+
132
+ @pytest.fixture
133
+ def sample_musicxml_content():
134
+ """Sample MusicXML content for testing."""
135
+ return '''<?xml version="1.0" encoding="UTF-8"?>
136
+ <!DOCTYPE score-partwise PUBLIC "-//Recordare//DTD MusicXML 3.1 Partwise//EN" "http://www.musicxml.org/dtds/partwise.dtd">
137
+ <score-partwise version="3.1">
138
+ <part-list>
139
+ <score-part id="P1">
140
+ <part-name>Piano</part-name>
141
+ </score-part>
142
+ </part-list>
143
+ <part id="P1">
144
+ <measure number="1">
145
+ <attributes>
146
+ <divisions>1</divisions>
147
+ <key>
148
+ <fifths>0</fifths>
149
+ </key>
150
+ <time>
151
+ <beats>4</beats>
152
+ <beat-type>4</beat-type>
153
+ </time>
154
+ <clef>
155
+ <sign>G</sign>
156
+ <line>2</line>
157
+ </clef>
158
+ </attributes>
159
+ <note>
160
+ <pitch>
161
+ <step>C</step>
162
+ <octave>4</octave>
163
+ </pitch>
164
+ <duration>4</duration>
165
+ <type>whole</type>
166
+ </note>
167
+ </measure>
168
+ </part>
169
+ </score-partwise>'''
backend/tests/test_api.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Integration tests for FastAPI endpoints."""
2
+ import pytest
3
+ from unittest.mock import patch, MagicMock
4
+ import json
5
+
6
+
7
class TestRootEndpoint:
    """Tests for the API root ("/") endpoint."""

    def test_root(self, test_client):
        """GET / advertises the API name, version and docs location."""
        response = test_client.get("/")
        assert response.status_code == 200

        body = response.json()
        expected = {"name": "Rescored API", "version": "1.0.0", "docs": "/docs"}
        for key, value in expected.items():
            assert body[key] == value
18
+
19
+
20
class TestHealthCheck:
    """Tests for the /health endpoint."""

    def test_health_check_healthy(self, test_client, mock_redis):
        """When Redis pings successfully, both statuses report healthy."""
        mock_redis.ping.return_value = True

        response = test_client.get("/health")

        assert response.status_code == 200
        body = response.json()
        assert (body["status"], body["redis"]) == ("healthy", "healthy")

    def test_health_check_redis_down(self, test_client, mock_redis):
        """A Redis failure degrades the service instead of erroring out."""
        mock_redis.ping.side_effect = Exception("Connection failed")

        response = test_client.get("/health")

        assert response.status_code == 200
        body = response.json()
        assert (body["status"], body["redis"]) == ("degraded", "unhealthy")
42
+
43
+
44
class TestTranscribeEndpoint:
    """Test transcription submission endpoint (POST /api/v1/transcribe)."""

    # NOTE(review): this class mixes patch targets — most tests patch
    # 'utils.<name>' while test_submit_unavailable_video patches
    # 'main.<name>'. Which target actually intercepts the call depends on
    # how main.py imports these helpers (module attribute vs from-import);
    # confirm against main.py and unify.
    # Stacked @patch decorators apply bottom-up: the bottom decorator
    # supplies the first injected mock argument.
    @patch('main.process_transcription_task')
    @patch('utils.check_video_availability')
    @patch('utils.validate_youtube_url')
    def test_submit_valid_transcription(
        self,
        mock_validate,
        mock_check_availability,
        mock_task,
        test_client,
        mock_redis
    ):
        """Test submitting valid transcription request."""
        mock_validate.return_value = (True, "dQw4w9WgXcQ")
        mock_check_availability.return_value = {
            'available': True,
            'info': {'duration': 180}
        }
        mock_task.delay.return_value = MagicMock(id="task-id")

        response = test_client.post(
            "/api/v1/transcribe",
            json={"youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}
        )

        assert response.status_code == 201
        data = response.json()
        assert "job_id" in data
        assert data["status"] == "queued"
        assert "websocket_url" in data
        # Fixed estimate asserted here regardless of the mocked duration.
        assert data["estimated_duration_seconds"] == 120

        # Verify Redis was called to store job
        assert mock_redis.hset.called

        # Verify Celery task was queued
        assert mock_task.delay.called

    @patch('utils.validate_youtube_url')
    def test_submit_invalid_url(self, mock_validate, test_client):
        """Test submitting invalid YouTube URL."""
        mock_validate.return_value = (False, "Invalid YouTube URL format")

        response = test_client.post(
            "/api/v1/transcribe",
            json={"youtube_url": "https://invalid.com/video"}
        )

        assert response.status_code == 400
        assert "Invalid YouTube URL format" in response.json()["detail"]

    # NOTE(review): patches 'main.*' here, unlike the sibling tests above.
    @patch('main.validate_youtube_url')
    @patch('main.check_video_availability')
    def test_submit_unavailable_video(
        self,
        mock_check_availability,
        mock_validate,
        test_client
    ):
        """Test submitting unavailable video."""
        mock_validate.return_value = (True, "dQw4w9WgXcQ")
        mock_check_availability.return_value = {
            'available': False,
            'reason': 'Video too long (max 15 minutes)'
        }

        response = test_client.post(
            "/api/v1/transcribe",
            json={"youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}
        )

        assert response.status_code == 422
        assert "too long" in response.json()["detail"]

    @patch('utils.validate_youtube_url')
    @patch('utils.check_video_availability')
    def test_submit_with_options(
        self,
        mock_check_availability,
        mock_validate,
        test_client,
        mock_redis
    ):
        """Test submitting transcription with custom options."""
        mock_validate.return_value = (True, "dQw4w9WgXcQ")
        mock_check_availability.return_value = {'available': True, 'info': {}}

        # Context-manager patch keeps the task mock scoped to this request.
        with patch('main.process_transcription_task') as mock_task:
            response = test_client.post(
                "/api/v1/transcribe",
                json={
                    "youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
                    "options": {"instruments": ["piano", "guitar"]}
                }
            )

            assert response.status_code == 201
143
+
144
+
145
class TestRateLimiting:
    """Test rate limiting middleware."""

    # NOTE(review): these tests patch 'utils.*' for validation but 'main.*'
    # for the task — whether 'utils.*' patches take effect depends on how
    # main.py imports the helpers; confirm and unify with
    # TestTranscribeEndpoint. Decorators apply bottom-up, so
    # 'main.process_transcription_task' (bottom) maps to mock_task (first).
    @patch('utils.validate_youtube_url')
    @patch('utils.check_video_availability')
    @patch('main.process_transcription_task')
    def test_rate_limit_enforced(
        self,
        mock_task,
        mock_check_availability,
        mock_validate,
        test_client,
        mock_redis
    ):
        """Test that rate limit is enforced after 10 requests."""
        mock_validate.return_value = (True, "dQw4w9WgXcQ")
        mock_check_availability.return_value = {'available': True, 'info': {}}
        mock_task.delay.return_value = MagicMock(id="task-id")

        # Mock Redis counter for rate limiting.
        # The counter is returned as a string, matching Redis GET semantics.
        mock_redis.get.return_value = "10"  # Already at limit

        response = test_client.post(
            "/api/v1/transcribe",
            json={"youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}
        )

        assert response.status_code == 429
        assert "Rate limit exceeded" in response.json()["detail"]

    @patch('utils.validate_youtube_url')
    @patch('utils.check_video_availability')
    @patch('main.process_transcription_task')
    def test_rate_limit_under_limit(
        self,
        mock_task,
        mock_check_availability,
        mock_validate,
        test_client,
        mock_redis
    ):
        """Test that requests under limit succeed."""
        mock_validate.return_value = (True, "dQw4w9WgXcQ")
        mock_check_availability.return_value = {'available': True, 'info': {}}
        mock_task.delay.return_value = MagicMock(id="task-id")

        # Mock Redis counter for rate limiting (under limit)
        mock_redis.get.return_value = "5"  # 5 out of 10

        response = test_client.post(
            "/api/v1/transcribe",
            json={"youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}
        )

        assert response.status_code == 201  # Request succeeds
        # The endpoint increments the counter via a Redis pipeline.
        assert mock_redis.pipeline.called  # Counter incremented
201
+
202
+
203
class TestJobStatusEndpoint:
    """Tests for GET /api/v1/jobs/{job_id}."""

    def test_get_existing_job_status(self, test_client, mock_redis, sample_job_data):
        """A known job id reports queued status with zero progress."""
        mock_redis.hgetall.return_value = sample_job_data

        res = test_client.get(f"/api/v1/jobs/{sample_job_data['job_id']}")

        assert res.status_code == 200
        payload = res.json()
        assert payload["job_id"] == sample_job_data["job_id"]
        assert payload["status"] == "queued"
        assert payload["progress"] == 0
        assert payload["current_stage"] == "queued"

    def test_get_nonexistent_job(self, test_client, mock_redis):
        """An unknown job id yields 404."""
        mock_redis.hgetall.return_value = {}

        res = test_client.get("/api/v1/jobs/nonexistent-id")

        assert res.status_code == 404
        assert "not found" in res.json()["detail"]

    def test_get_completed_job_status(self, test_client, mock_redis, sample_job_data):
        """A completed job reports 100% progress and a result URL."""
        mock_redis.hgetall.return_value = dict(
            sample_job_data, status="completed", progress=100
        )

        res = test_client.get(f"/api/v1/jobs/{sample_job_data['job_id']}")

        assert res.status_code == 200
        payload = res.json()
        assert payload["status"] == "completed"
        assert payload["progress"] == 100
        assert payload["result_url"] is not None

    def test_get_failed_job_status(self, test_client, mock_redis, sample_job_data):
        """A failed job surfaces its JSON-encoded error details."""
        error_data = {"message": "Transcription failed", "stage": "audio_download"}
        mock_redis.hgetall.return_value = dict(
            sample_job_data, status="failed", error=json.dumps(error_data)
        )

        res = test_client.get(f"/api/v1/jobs/{sample_job_data['job_id']}")

        assert res.status_code == 200
        payload = res.json()
        assert payload["status"] == "failed"
        assert payload["error"] is not None
        assert payload["error"]["message"] == "Transcription failed"
258
+
259
+
260
class TestScoreDownloadEndpoint:
    """Tests for the MusicXML score download endpoint."""

    def test_download_completed_score(
        self,
        test_client,
        mock_redis,
        sample_job_data,
        temp_storage_dir,
        sample_musicxml_content
    ):
        """A completed job streams its MusicXML file with the MusicXML MIME type."""
        # Write a real score file so the endpoint has something to serve.
        score_path = temp_storage_dir / "score.musicxml"
        score_path.write_text(sample_musicxml_content)

        mock_redis.hgetall.return_value = dict(
            sample_job_data, status="completed", output_path=str(score_path)
        )

        res = test_client.get(f"/api/v1/scores/{sample_job_data['job_id']}")

        assert res.status_code == 200
        assert res.headers["content-type"] == "application/vnd.recordare.musicxml+xml"
        assert "score-partwise" in res.text

    def test_download_nonexistent_job(self, test_client, mock_redis):
        """An unknown job id yields 404."""
        mock_redis.hgetall.return_value = {}

        res = test_client.get("/api/v1/scores/nonexistent-id")

        assert res.status_code == 404

    def test_download_incomplete_job(self, test_client, mock_redis, sample_job_data):
        """A job that is still queued has no downloadable score yet."""
        mock_redis.hgetall.return_value = sample_job_data  # still queued

        res = test_client.get(f"/api/v1/scores/{sample_job_data['job_id']}")

        assert res.status_code == 404
        assert "not available" in res.json()["detail"]

    def test_download_missing_file(self, test_client, mock_redis, sample_job_data):
        """A completed job whose file vanished from disk yields 404."""
        mock_redis.hgetall.return_value = dict(
            sample_job_data,
            status="completed",
            output_path="/nonexistent/path/score.musicxml",
        )

        res = test_client.get(f"/api/v1/scores/{sample_job_data['job_id']}")

        assert res.status_code == 404
        assert "not found" in res.json()["detail"]
319
+
320
+
321
class TestMIDIDownloadEndpoint:
    """Tests for the MIDI download endpoint."""

    def test_download_completed_midi(self, test_client, sample_job_id, tmp_path, mock_redis):
        """A completed job streams its MIDI bytes with the audio/midi MIME type."""
        midi_file = tmp_path / "test.mid"
        midi_file.write_bytes(b"MIDI_DATA")

        mock_redis.hgetall.return_value = {
            "status": "completed",
            "midi_path": str(midi_file),
        }

        res = test_client.get(f"/api/v1/scores/{sample_job_id}/midi")

        assert res.status_code == 200
        assert res.headers["content-type"] == "audio/midi"
        assert res.content == b"MIDI_DATA"

    def test_download_nonexistent_job_midi(self, test_client, mock_redis):
        """Unknown job id yields 404 with a 'not available' message."""
        mock_redis.hgetall.return_value = {}

        res = test_client.get("/api/v1/scores/nonexistent/midi")

        assert res.status_code == 404
        assert "not available" in res.json()["detail"]

    def test_download_incomplete_job_midi(self, test_client, sample_job_id, mock_redis):
        """A still-processing job has no MIDI to download."""
        mock_redis.hgetall.return_value = {"status": "processing"}

        res = test_client.get(f"/api/v1/scores/{sample_job_id}/midi")

        assert res.status_code == 404

    def test_download_missing_midi_file(self, test_client, sample_job_id, mock_redis):
        """A completed job whose MIDI file is gone from disk yields 404."""
        mock_redis.hgetall.return_value = {
            "status": "completed",
            "midi_path": "/nonexistent/path.mid",
        }

        res = test_client.get(f"/api/v1/scores/{sample_job_id}/midi")

        assert res.status_code == 404
        assert "file not found" in res.json()["detail"].lower()
backend/tests/test_pipeline.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for audio processing pipeline - simplified version."""
2
+ import pytest
3
+ from pathlib import Path
4
+
5
+
6
class TestPipelineImports:
    """Test that pipeline functions can be imported and are callable.

    The eleven near-identical per-function tests are collapsed into one
    parametrized test; pytest still reports a separate case per name.
    """

    # Public pipeline entry points that must be importable and callable.
    PIPELINE_FUNCTIONS = [
        "download_audio",
        "separate_sources",
        "transcribe_audio",
        "quantize_midi",
        "remove_duplicate_notes",
        "remove_short_notes",
        "generate_musicxml",
        "detect_key_signature",
        "detect_time_signature",
        "detect_tempo",
        "run_transcription_pipeline",
    ]

    @pytest.mark.parametrize("func_name", PIPELINE_FUNCTIONS)
    def test_function_callable(self, func_name):
        """Each public pipeline function exists and is callable."""
        import pipeline

        # getattr (vs from-import) keeps a single code path for all names;
        # a missing name fails the test with AttributeError.
        assert callable(getattr(pipeline, func_name))
63
+
64
+
65
class TestTranscriptionPipelineClass:
    """Tests for the TranscriptionPipeline class surface."""

    @staticmethod
    def _make_pipeline():
        """Build a pipeline with fixed test arguments."""
        from pipeline import TranscriptionPipeline
        return TranscriptionPipeline("test_job", "http://example.com", Path("/tmp"))

    def test_pipeline_class_exists(self):
        """Constructor stores job id, URL and storage path."""
        p = self._make_pipeline()
        assert p.job_id == "test_job"
        assert p.youtube_url == "http://example.com"
        assert isinstance(p.storage_path, Path)

    def test_pipeline_has_progress_callback(self):
        """set_progress_callback is exposed and callable."""
        p = self._make_pipeline()
        assert hasattr(p, 'set_progress_callback') and callable(p.set_progress_callback)

    def test_pipeline_has_required_methods(self):
        """All pipeline stage methods exist and are callable."""
        p = self._make_pipeline()
        for name in (
            'download_audio',
            'separate_sources',
            'transcribe_to_midi',
            'clean_midi',
            'generate_musicxml',
            'cleanup',
        ):
            assert hasattr(p, name) and callable(getattr(p, name))
backend/tests/test_tasks.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for Celery tasks."""
2
+ import pytest
3
+ from unittest.mock import patch, MagicMock, call
4
+ import json
5
+
6
+
7
class TestProcessTranscriptionTask:
    """Test the main Celery transcription task.

    All collaborators (Redis client, TranscriptionPipeline, shutil.copy)
    are patched at the 'tasks' module level. Stacked @patch decorators
    apply bottom-up, so 'tasks.redis_client' (bottom) maps to the first
    injected mock argument.
    """

    @patch('tasks.shutil.copy')
    @patch('tasks.TranscriptionPipeline')
    @patch('tasks.redis_client')
    def test_task_success(self, mock_redis, mock_pipeline_class, mock_copy, sample_job_id, temp_storage_dir):
        """Test successful task execution."""
        from tasks import process_transcription_task

        # Mock job data in Redis
        job_data = {
            'job_id': sample_job_id,
            'youtube_url': 'https://www.youtube.com/watch?v=dQw4w9WgXcQ',
            'video_id': 'dQw4w9WgXcQ',
            'options': '{"instruments": ["piano"]}'
        }
        mock_redis.hgetall.return_value = job_data

        # Mock successful pipeline instance
        mock_pipeline = MagicMock()
        mock_pipeline.run.return_value = str(temp_storage_dir / "output.musicxml")
        mock_pipeline_class.return_value = mock_pipeline

        # Execute task
        process_transcription_task(sample_job_id)

        # Verify pipeline ran
        mock_pipeline.run.assert_called_once()

        # Verify progress updates were published
        assert mock_redis.publish.call_count > 0

        # Verify final status was set to completed
        # NOTE(review): the loop variable `call` shadows the `call` helper
        # imported from unittest.mock at module top; harmless here but
        # worth renaming for clarity.
        completed_calls = [
            call for call in mock_redis.hset.call_args_list
            if 'completed' in str(call)
        ]
        assert len(completed_calls) > 0

    @patch('tasks.shutil.copy')
    @patch('tasks.TranscriptionPipeline')
    @patch('tasks.redis_client')
    def test_task_failure(self, mock_redis, mock_pipeline_class, mock_copy, sample_job_id):
        """Test task execution with pipeline failure."""
        from tasks import process_transcription_task
        from celery.exceptions import Retry

        job_data = {
            'job_id': sample_job_id,
            'youtube_url': 'https://www.youtube.com/watch?v=invalid',
            'video_id': 'invalid',
            'options': '{}'
        }
        mock_redis.hgetall.return_value = job_data

        # Mock failed pipeline
        mock_pipeline = MagicMock()
        mock_pipeline.run.side_effect = RuntimeError("Download failed")
        mock_pipeline_class.return_value = mock_pipeline

        # Execute task - should raise Retry due to Celery's retry mechanism.
        # Either exception type is accepted since retry behavior depends on
        # the task's max_retries configuration.
        with pytest.raises((Retry, RuntimeError)):
            process_transcription_task(sample_job_id)

        # Verify error was stored in Redis before retry
        error_calls = [
            call for call in mock_redis.hset.call_args_list
            if 'error' in str(call)
        ]
        assert len(error_calls) > 0

    @patch('tasks.shutil.copy')
    @patch('tasks.TranscriptionPipeline')
    @patch('tasks.redis_client')
    def test_task_progress_updates(self, mock_redis, mock_pipeline_class, mock_copy, sample_job_id, temp_storage_dir):
        """Test that task publishes progress updates."""
        from tasks import process_transcription_task

        job_data = {
            'job_id': sample_job_id,
            'youtube_url': 'https://www.youtube.com/watch?v=dQw4w9WgXcQ',
            'video_id': 'dQw4w9WgXcQ',
            'options': '{}'
        }
        mock_redis.hgetall.return_value = job_data

        mock_pipeline = MagicMock()
        mock_pipeline.run.return_value = str(temp_storage_dir / "output.musicxml")
        mock_pipeline_class.return_value = mock_pipeline

        process_transcription_task(sample_job_id)

        # Verify completion message was published
        publish_calls = mock_redis.publish.call_args_list
        assert len(publish_calls) >= 1  # At least completion message

        # Verify final publish call contains completion info
        # (publishes go to the per-job pub/sub channel consumed by the
        # WebSocket layer).
        final_call = publish_calls[-1]
        channel, message = final_call[0]
        assert channel == f"job:{sample_job_id}:updates"
        update_data = json.loads(message)
        assert 'type' in update_data
        assert update_data['type'] == 'completed'

    @patch('tasks.redis_client')
    def test_task_job_not_found(self, mock_redis, sample_job_id):
        """Test task execution when job doesn't exist."""
        from tasks import process_transcription_task

        # Empty hash from Redis means the job was never registered.
        mock_redis.hgetall.return_value = {}

        with pytest.raises(ValueError) as exc_info:
            process_transcription_task(sample_job_id)

        assert "Job not found" in str(exc_info.value)

    @patch('tasks.shutil.copy')
    @patch('tasks.TranscriptionPipeline')
    @patch('tasks.redis_client')
    def test_task_retry_on_network_error(self, mock_redis, mock_pipeline_class, mock_copy, sample_job_id):
        """Test task retry logic for transient errors."""
        from tasks import process_transcription_task
        from celery.exceptions import Retry

        job_data = {
            'job_id': sample_job_id,
            'youtube_url': 'https://www.youtube.com/watch?v=dQw4w9WgXcQ',
            'video_id': 'dQw4w9WgXcQ',
            'options': '{}'
        }
        mock_redis.hgetall.return_value = job_data

        # Mock transient network error
        mock_pipeline = MagicMock()
        mock_pipeline.run.side_effect = ConnectionError("Network timeout")
        mock_pipeline_class.return_value = mock_pipeline

        with pytest.raises((Retry, ConnectionError)):
            process_transcription_task(sample_job_id)
147
+
148
+
149
class TestProgressCallback:
    """Tests for the update_progress helper."""

    @patch('tasks.redis_client')
    def test_update_progress(self, mock_redis, sample_job_id):
        """One update writes the job hash and publishes a pub/sub message."""
        from tasks import update_progress

        update_progress(sample_job_id, 50, "transcription", "Transcribing audio...")

        # Job hash was written under the per-job key.
        mock_redis.hset.assert_called()
        assert mock_redis.hset.call_args[0][0] == f"job:{sample_job_id}"

        # Update was published to the per-job channel.
        mock_redis.publish.assert_called()
        channel, raw_message = mock_redis.publish.call_args[0]
        assert channel == f"job:{sample_job_id}:updates"

        payload = json.loads(raw_message)
        assert (payload['progress'], payload['stage'], payload['message']) == (
            50, "transcription", "Transcribing audio...")

    @patch('tasks.redis_client')
    def test_multiple_progress_updates(self, mock_redis, sample_job_id):
        """Each update in a sequence produces one hset and one publish."""
        from tasks import update_progress

        updates = [
            (5, "download", "Downloading audio"),
            (25, "separation", "Separating audio sources"),
            (60, "transcription", "Transcribing to MIDI"),
            (90, "musicxml", "Generating MusicXML"),
            (100, "completed", "Processing complete"),
        ]
        for args in updates:
            update_progress(sample_job_id, *args)

        # Should have 5 updates
        assert mock_redis.hset.call_count == 5
        assert mock_redis.publish.call_count == 5
193
+
194
+
195
class TestCleanup:
    """Tests for temporary-file cleanup."""

    @patch('tasks.shutil.rmtree')
    def test_cleanup_temp_files(self, mock_rmtree, sample_job_id, temp_storage_dir):
        """cleanup_temp_files removes the job's temp dir via shutil.rmtree."""
        from tasks import cleanup_temp_files

        # The directory must exist for cleanup to attempt removal.
        (temp_storage_dir / "temp" / sample_job_id).mkdir(parents=True, exist_ok=True)

        cleanup_temp_files(sample_job_id, storage_path=temp_storage_dir)

        mock_rmtree.assert_called()

    def test_cleanup_preserves_output(self, sample_job_id, temp_storage_dir):
        """Cleanup deletes temp artifacts but leaves final outputs untouched."""
        from tasks import cleanup_temp_files

        # Populate the job's temp directory with throwaway files.
        temp_dir = temp_storage_dir / "temp" / sample_job_id
        temp_dir.mkdir(parents=True, exist_ok=True)
        for name in ("temp_audio.wav", "temp_midi.mid"):
            (temp_dir / name).touch()

        # Create final output files that cleanup must not remove.
        outputs_dir = temp_storage_dir / "outputs"
        outputs_dir.mkdir(parents=True, exist_ok=True)
        output_files = [outputs_dir / "output.musicxml", outputs_dir / "output.mid"]
        for output in output_files:
            output.touch()

        cleanup_temp_files(sample_job_id, storage_path=temp_storage_dir)

        # Temp directory gone, outputs intact.
        assert not temp_dir.exists()
        assert all(output.exists() for output in output_files)
backend/tests/test_utils.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for utility functions."""
2
+ import pytest
3
+ from utils import validate_youtube_url, check_video_availability
4
+ from unittest.mock import patch, MagicMock
5
+ import yt_dlp
6
+
7
+
8
class TestValidateYouTubeURL:
    """Tests for YouTube URL validation."""

    def test_valid_watch_url(self):
        """Standard youtube.com/watch URL."""
        ok, vid = validate_youtube_url("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
        assert ok is True and vid == "dQw4w9WgXcQ"

    def test_valid_short_url(self):
        """youtu.be short URL."""
        ok, vid = validate_youtube_url("https://youtu.be/dQw4w9WgXcQ")
        assert ok is True and vid == "dQw4w9WgXcQ"

    def test_valid_mobile_url(self):
        """Mobile (m.youtube.com) URL."""
        ok, vid = validate_youtube_url("https://m.youtube.com/watch?v=dQw4w9WgXcQ")
        assert ok is True and vid == "dQw4w9WgXcQ"

    def test_valid_embed_url(self):
        """Embedded-player URL."""
        ok, vid = validate_youtube_url("https://www.youtube.com/embed/dQw4w9WgXcQ")
        assert ok is True and vid == "dQw4w9WgXcQ"

    def test_valid_with_extra_params(self):
        """Watch URL carrying additional query parameters."""
        ok, vid = validate_youtube_url("https://www.youtube.com/watch?v=dQw4w9WgXcQ&t=30s")
        assert ok is True and vid == "dQw4w9WgXcQ"

    def test_invalid_domain(self):
        """Non-YouTube host is rejected."""
        ok, err = validate_youtube_url("https://vimeo.com/12345")
        assert ok is False and err == "Invalid YouTube URL format"

    def test_invalid_format(self):
        """Arbitrary non-URL text is rejected."""
        ok, err = validate_youtube_url("not-a-url")
        assert ok is False and err == "Invalid YouTube URL format"

    def test_invalid_video_id_length(self):
        """Video IDs must be exactly 11 characters."""
        ok, err = validate_youtube_url("https://www.youtube.com/watch?v=short")
        assert ok is False and err == "Invalid YouTube URL format"

    def test_empty_url(self):
        """Empty string is rejected."""
        ok, err = validate_youtube_url("")
        assert ok is False and err == "Invalid YouTube URL format"
64
+
65
+
66
class TestCheckVideoAvailability:
    """Tests for video availability checking."""

    @staticmethod
    def _wire_ydl(mock_ydl_class, *, info=None, error=None):
        """Wire a MagicMock into yt_dlp.YoutubeDL's context-manager protocol.

        Returns the inner mock whose extract_info either yields `info` or
        raises `error`.
        """
        ydl = MagicMock()
        if error is not None:
            ydl.extract_info.side_effect = error
        else:
            ydl.extract_info.return_value = info
        mock_ydl_class.return_value.__enter__.return_value = ydl
        return ydl

    @patch('yt_dlp.YoutubeDL')
    def test_available_video(self, mock_ydl_class, mock_yt_dlp_info):
        """A normal video is reported available with its metadata."""
        self._wire_ydl(mock_ydl_class, info=mock_yt_dlp_info)

        result = check_video_availability("dQw4w9WgXcQ")

        assert result['available'] is True
        assert 'info' in result

    @patch('yt_dlp.YoutubeDL')
    def test_video_too_long(self, mock_ydl_class):
        """A video over the duration limit is rejected."""
        self._wire_ydl(mock_ydl_class, info={'duration': 1200, 'age_limit': 0})

        result = check_video_availability("dQw4w9WgXcQ", max_duration=900)

        assert result['available'] is False
        assert 'max 15 minutes' in result['reason']

    @patch('yt_dlp.YoutubeDL')
    def test_age_restricted_video(self, mock_ydl_class):
        """An age-restricted video is rejected."""
        self._wire_ydl(mock_ydl_class, info={'duration': 180, 'age_limit': 18})

        result = check_video_availability("dQw4w9WgXcQ")

        assert result['available'] is False
        assert 'Age-restricted' in result['reason']

    @patch('yt_dlp.YoutubeDL')
    def test_download_error(self, mock_ydl_class):
        """A yt-dlp DownloadError is surfaced as unavailable."""
        self._wire_ydl(
            mock_ydl_class,
            error=yt_dlp.utils.DownloadError("Video unavailable"),
        )

        result = check_video_availability("invalid_id")

        assert result['available'] is False
        assert 'Video unavailable' in result['reason']

    @patch('yt_dlp.YoutubeDL')
    def test_generic_error(self, mock_ydl_class):
        """Unexpected exceptions produce a generic failure reason."""
        self._wire_ydl(mock_ydl_class, error=Exception("Unknown error"))

        result = check_video_availability("dQw4w9WgXcQ")

        assert result['available'] is False
        assert 'Error checking video' in result['reason']

    @patch('yt_dlp.YoutubeDL')
    def test_video_at_max_duration(self, mock_ydl_class):
        """A video exactly at the duration limit is still accepted."""
        self._wire_ydl(mock_ydl_class, info={'duration': 900, 'age_limit': 0})

        result = check_video_availability("dQw4w9WgXcQ", max_duration=900)

        assert result['available'] is True
backend/utils.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for Rescored backend."""
2
+ import re
3
+ from urllib.parse import urlparse, parse_qs
4
+ import yt_dlp
5
+
6
+
7
+ def validate_youtube_url(url: str) -> tuple[bool, str | None]:
8
+ """
9
+ Validate YouTube URL and extract video ID.
10
+
11
+ Args:
12
+ url: YouTube URL to validate
13
+
14
+ Returns:
15
+ (is_valid, video_id or error_message)
16
+ """
17
+ # Supported formats:
18
+ # - https://www.youtube.com/watch?v=VIDEO_ID
19
+ # - https://youtu.be/VIDEO_ID
20
+ # - https://m.youtube.com/watch?v=VIDEO_ID
21
+
22
+ patterns = [
23
+ r'(?:youtube\.com/watch\?v=|youtu\.be/)([a-zA-Z0-9_-]{11})',
24
+ r'youtube\.com/embed/([a-zA-Z0-9_-]{11})',
25
+ ]
26
+
27
+ for pattern in patterns:
28
+ match = re.search(pattern, url)
29
+ if match:
30
+ return True, match.group(1)
31
+
32
+ return False, "Invalid YouTube URL format"
33
+
34
+
35
def check_video_availability(video_id: str, max_duration: int = 900) -> dict:
    """
    Check if video is available for download.

    Args:
        video_id: YouTube video ID
        max_duration: Maximum allowed duration in seconds

    Returns:
        Dictionary with 'available' (bool) and 'reason' or 'info'
    """
    ydl_opts = {
        'quiet': True,
        'no_warnings': True,
        'extract_flat': True,
    }

    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(
                f"https://youtube.com/watch?v={video_id}",
                download=False
            )

        # extract_info can return None for some unavailable videos.
        if info is None:
            return {'available': False, 'reason': 'Video metadata could not be retrieved'}

        # Check duration. The 'duration' key can be present with a None
        # value (e.g. live streams / premieres), so use `or 0` rather than
        # a dict default — otherwise `None > max_duration` raises TypeError.
        duration = info.get('duration') or 0
        if duration > max_duration:
            return {
                'available': False,
                'reason': f'Video too long (max {max_duration // 60} minutes)'
            }

        # Check if age-restricted
        if info.get('age_limit', 0) > 0:
            return {
                'available': False,
                'reason': 'Age-restricted content not supported'
            }

        return {'available': True, 'info': info}

    except yt_dlp.utils.DownloadError as e:
        # yt-dlp's own message (e.g. "Video unavailable") is passed through.
        return {'available': False, 'reason': str(e)}
    except Exception as e:
        # Catch-all so callers always get a dict, never an exception.
        return {'available': False, 'reason': f'Error checking video: {str(e)}'}
docker-compose.yml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Compose file for the Rescored stack: Redis broker, FastAPI backend,
# Celery ML worker, and the Vite frontend.
# NOTE: the top-level `version:` key was removed — it is obsolete in the
# Compose Specification and only produces a warning in Compose v2.

services:
  # Redis - Message broker and cache
  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    volumes:
      - redis_data:/data
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 5s
      timeout: 3s
      retries: 5

  # Backend API
  api:
    build:
      context: ./backend
      dockerfile: Dockerfile
    ports:
      - "8000:8000"
    environment:
      - REDIS_URL=redis://redis:6379/0
      - STORAGE_PATH=/app/storage
      - API_HOST=0.0.0.0
      - API_PORT=8000
      - CORS_ORIGINS=http://localhost:5173,http://localhost:3000
    volumes:
      - ./backend:/app
      - ./storage:/app/storage
    depends_on:
      redis:
        condition: service_healthy
    command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload

  # Celery Worker (GPU-enabled for ML processing)
  worker:
    build:
      context: ./backend
      dockerfile: Dockerfile.worker
    environment:
      - REDIS_URL=redis://redis:6379/0
      - STORAGE_PATH=/app/storage
      - GPU_ENABLED=true
    volumes:
      - ./backend:/app
      - ./storage:/app/storage
    depends_on:
      redis:
        condition: service_healthy
    command: celery -A tasks worker --loglevel=info --concurrency=1
    # Uncomment for GPU support (requires NVIDIA Docker runtime)
    # deploy:
    #   resources:
    #     reservations:
    #       devices:
    #         - driver: nvidia
    #           count: 1
    #           capabilities: [gpu]

  # Frontend (React + Vite)
  frontend:
    build:
      context: ./frontend
      dockerfile: Dockerfile
    ports:
      - "5173:5173"
    environment:
      - VITE_API_URL=http://localhost:8000
    volumes:
      - ./frontend:/app
      - /app/node_modules
    command: npm run dev -- --host 0.0.0.0

volumes:
  redis_data:
  # NOTE(review): `storage` is declared but no service references it as a
  # named volume (services bind-mount ./storage instead) — confirm intent
  # before removing.
  storage:
docs/testing/backend-testing.md ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Backend Testing Guide
2
+
3
+ Comprehensive guide for testing the Rescored backend.
4
+
5
+ ## Table of Contents
6
+
7
+ - [Setup](#setup)
8
+ - [Running Tests](#running-tests)
9
+ - [Test Structure](#test-structure)
10
+ - [Writing Tests](#writing-tests)
11
+ - [Testing Patterns](#testing-patterns)
12
+ - [Troubleshooting](#troubleshooting)
13
+
14
+ ## Setup
15
+
16
+ ### Install Test Dependencies
17
+
18
+ ```bash
19
+ cd backend
20
+ pip install -r requirements-test.txt
21
+ ```
22
+
23
+ This installs:
24
+ - `pytest`: Test framework
25
+ - `pytest-asyncio`: Async test support
26
+ - `pytest-cov`: Coverage reporting
27
+ - `pytest-mock`: Enhanced mocking
28
+ - `httpx`: HTTP testing client
29
+
30
+ ### Configuration
31
+
32
+ Test configuration is in `pytest.ini`:
33
+
34
+ ```ini
35
+ [pytest]
36
+ testpaths = tests
37
+ markers =
38
+ unit: Unit tests
39
+ integration: Integration tests
40
+ slow: Slow-running tests
41
+ gpu: Tests requiring GPU
42
+ network: Tests requiring network
43
+ ```
44
+
45
+ ## Running Tests
46
+
47
+ ### Basic Commands
48
+
49
+ ```bash
50
+ # Run all tests
51
+ pytest
52
+
53
+ # Run with coverage
54
+ pytest --cov
55
+
56
+ # Run specific file
57
+ pytest tests/test_utils.py
58
+
59
+ # Run specific test
60
+ pytest tests/test_utils.py::TestValidateYouTubeURL::test_valid_watch_url
61
+
62
+ # Run by marker
63
+ pytest -m unit
64
+ pytest -m "unit and not slow"
65
+ ```
66
+
67
+ ### Watch Mode
68
+
69
+ Use `pytest-watch` for continuous testing:
70
+
71
+ ```bash
72
+ pip install pytest-watch
73
+ ptw # Runs tests on file changes
74
+ ```
75
+
76
+ ### Coverage Reports
77
+
78
+ ```bash
79
+ # Terminal report
80
+ pytest --cov --cov-report=term-missing
81
+
82
+ # HTML report
83
+ pytest --cov --cov-report=html
84
+ open htmlcov/index.html
85
+
86
+ # Both
87
+ pytest --cov --cov-report=term-missing --cov-report=html
88
+ ```
89
+
90
+ ## Test Structure
91
+
92
+ ### Test Files
93
+
94
+ Each module has a corresponding test file:
95
+
96
+ - `utils.py` → `tests/test_utils.py`
97
+ - `pipeline.py` → `tests/test_pipeline.py`
98
+ - `main.py` → `tests/test_api.py`
99
+ - `tasks.py` → `tests/test_tasks.py`
100
+
101
+ ### Test Organization
102
+
103
+ Group related tests in classes:
104
+
105
+ ```python
106
+ class TestValidateYouTubeURL:
107
+ """Test YouTube URL validation."""
108
+
109
+ def test_valid_watch_url(self):
110
+ """Test standard youtube.com/watch URL."""
111
+ is_valid, video_id = validate_youtube_url("https://www.youtube.com/watch?v=...")
112
+ assert is_valid is True
113
+ assert video_id == "..."
114
+
115
+ def test_invalid_domain(self):
116
+ """Test URL from wrong domain."""
117
+ is_valid, error = validate_youtube_url("https://vimeo.com/12345")
118
+ assert is_valid is False
119
+ ```
120
+
121
+ ## Writing Tests
122
+
123
+ ### Basic Test Template
124
+
125
+ ```python
126
+ import pytest
127
+ from module_name import function_to_test
128
+
129
+ class TestFunctionName:
130
+ """Test suite for function_name."""
131
+
132
+ def test_happy_path(self):
133
+ """Test normal successful execution."""
134
+ result = function_to_test(valid_input)
135
+ assert result == expected_output
136
+
137
+ def test_edge_case(self):
138
+ """Test boundary condition."""
139
+ result = function_to_test(edge_case_input)
140
+ assert result == expected_edge_output
141
+
142
+ def test_error_handling(self):
143
+ """Test error is raised for invalid input."""
144
+ with pytest.raises(ValueError) as exc_info:
145
+ function_to_test(invalid_input)
146
+ assert "expected error message" in str(exc_info.value)
147
+ ```
148
+
149
+ ### Using Fixtures
150
+
151
+ Fixtures provide reusable test data:
152
+
153
+ ```python
154
+ @pytest.fixture
155
+ def sample_audio_file(temp_storage_dir):
156
+ """Create a sample WAV file for testing."""
157
+ import numpy as np
158
+ import soundfile as sf
159
+
160
+ sample_rate = 44100
161
+ duration = 1.0
162
+ samples = np.zeros(int(sample_rate * duration), dtype=np.float32)
163
+
164
+ audio_path = temp_storage_dir / "test_audio.wav"
165
+ sf.write(str(audio_path), samples, sample_rate)
166
+
167
+ return audio_path
168
+
169
+ def test_using_fixture(sample_audio_file):
170
+ """Test that uses the fixture."""
171
+ assert sample_audio_file.exists()
172
+ assert sample_audio_file.suffix == ".wav"
173
+ ```
174
+
175
+ ### Mocking External Dependencies
176
+
177
+ #### Mock yt-dlp
178
+
179
+ ```python
180
+ from unittest.mock import patch, MagicMock
181
+
182
+ @patch('pipeline.yt_dlp.YoutubeDL')
183
+ def test_download_audio(mock_ydl_class, temp_storage_dir):
184
+ """Test audio download with mocked yt-dlp."""
185
+ mock_ydl = MagicMock()
186
+ mock_ydl_class.return_value.__enter__.return_value = mock_ydl
187
+
188
+ result = download_audio("https://youtube.com/watch?v=...", temp_storage_dir)
189
+
190
+ assert result.exists()
191
+ mock_ydl.download.assert_called_once()
192
+ ```
193
+
194
+ #### Mock Redis
195
+
196
+ ```python
197
+ @pytest.fixture
198
+ def mock_redis():
199
+ """Mock Redis client."""
200
+ redis_mock = MagicMock(spec=Redis)
201
+ redis_mock.ping.return_value = True
202
+ redis_mock.hgetall.return_value = {}
203
+ return redis_mock
204
+
205
+ def test_with_redis(mock_redis):
206
+ """Test function that uses Redis."""
207
+ # Redis is mocked, no real connection needed
208
+ mock_redis.hset("key", "field", "value")
209
+ assert mock_redis.hset.called
210
+ ```
211
+
212
+ #### Mock ML Models
213
+
214
+ ```python
215
+ @patch('pipeline.basic_pitch.inference.predict')
216
+ def test_transcribe_audio(mock_predict, sample_audio_file, temp_storage_dir):
217
+ """Test transcription with mocked ML model."""
218
+ # Mock model output
219
+ mock_predict.return_value = (
220
+ np.zeros((100, 88)), # note activations
221
+ np.zeros((100, 88)), # onsets
222
+ np.zeros((100, 1)) # contours
223
+ )
224
+
225
+ result = transcribe_audio(sample_audio_file, temp_storage_dir)
226
+
227
+ assert result.exists()
228
+ assert result.suffix == ".mid"
229
+ ```
230
+
231
+ ## Testing Patterns
232
+
233
+ ### Testing API Endpoints
234
+
235
+ ```python
236
+ from fastapi.testclient import TestClient
237
+
238
+ def test_submit_transcription(test_client, mock_redis):
239
+ """Test transcription submission endpoint."""
240
+ response = test_client.post(
241
+ "/api/v1/transcribe",
242
+ json={"youtube_url": "https://www.youtube.com/watch?v=..."}
243
+ )
244
+
245
+ assert response.status_code == 201
246
+ data = response.json()
247
+ assert "job_id" in data
248
+ assert data["status"] == "queued"
249
+ ```
250
+
251
+ ### Testing Async Functions
252
+
253
+ ```python
254
+ import pytest
255
+
256
+ @pytest.mark.asyncio
257
+ async def test_async_function():
258
+ """Test async function."""
259
+ result = await async_operation()
260
+ assert result == expected_value
261
+ ```
262
+
263
+ ### Testing WebSocket Connections
264
+
265
+ ```python
266
+ def test_websocket(test_client, sample_job_id):
267
+ """Test WebSocket connection."""
268
+ with test_client.websocket_connect(f"/api/v1/jobs/{sample_job_id}/stream") as websocket:
269
+ data = websocket.receive_json()
270
+ assert data["type"] == "progress"
271
+ assert "job_id" in data
272
+ ```
273
+
274
+ ### Testing Error Scenarios
275
+
276
+ ```python
277
+ def test_video_too_long(test_client):
278
+ """Test error handling for videos exceeding duration limit."""
279
+ with patch('utils.check_video_availability') as mock_check:
280
+ mock_check.return_value = {
281
+ 'available': False,
282
+ 'reason': 'Video too long (max 15 minutes)'
283
+ }
284
+
285
+ response = test_client.post(
286
+ "/api/v1/transcribe",
287
+ json={"youtube_url": "https://www.youtube.com/watch?v=long"}
288
+ )
289
+
290
+ assert response.status_code == 422
291
+ assert "too long" in response.json()["detail"]
292
+ ```
293
+
294
+ ### Testing Retries
295
+
296
+ ```python
297
+ def test_retry_on_network_error():
298
+ """Test that function retries on network error."""
299
+ mock_func = MagicMock()
300
+ mock_func.side_effect = [
301
+ ConnectionError("Network timeout"), # First call fails
302
+ ConnectionError("Network timeout"), # Second call fails
303
+ {"success": True} # Third call succeeds
304
+ ]
305
+
306
+ # Function should retry and eventually succeed
307
+ result = function_with_retry(mock_func)
308
+ assert result == {"success": True}
309
+ assert mock_func.call_count == 3
310
+ ```
311
+
312
+ ### Parametrized Tests
313
+
314
+ Test multiple inputs efficiently:
315
+
316
+ ```python
317
+ @pytest.mark.parametrize("url,expected_valid,expected_id", [
318
+ ("https://www.youtube.com/watch?v=dQw4w9WgXcQ", True, "dQw4w9WgXcQ"),
319
+ ("https://youtu.be/dQw4w9WgXcQ", True, "dQw4w9WgXcQ"),
320
+ ("https://vimeo.com/12345", False, None),
321
+ ("not-a-url", False, None),
322
+ ])
323
+ def test_url_validation(url, expected_valid, expected_id):
324
+ """Test URL validation with multiple inputs."""
325
+ is_valid, result = validate_youtube_url(url)
326
+ assert is_valid == expected_valid
327
+ if expected_valid:
328
+ assert result == expected_id
329
+ ```
330
+
331
+ ## Testing Pipeline Stages
332
+
333
+ ### Audio Download
334
+
335
+ ```python
336
+ @patch('pipeline.yt_dlp.YoutubeDL')
337
+ def test_download_audio_success(mock_ydl_class, temp_storage_dir):
338
+ """Test successful audio download."""
339
+ mock_ydl = MagicMock()
340
+ mock_ydl_class.return_value.__enter__.return_value = mock_ydl
341
+
342
+ result = download_audio("https://youtube.com/watch?v=...", temp_storage_dir)
343
+
344
+ assert result.exists()
345
+ assert result.suffix == ".wav"
346
+ ```
347
+
348
+ ### Source Separation
349
+
350
+ ```python
351
+ @patch('pipeline.demucs.separate.main')
352
+ def test_separate_sources(mock_demucs, sample_audio_file, temp_storage_dir):
353
+ """Test source separation."""
354
+ # Create mock output files
355
+ stems_dir = temp_storage_dir / "htdemucs" / "test_audio"
356
+ stems_dir.mkdir(parents=True)
357
+ for stem in ["drums", "bass", "vocals", "other"]:
358
+ (stems_dir / f"{stem}.wav").touch()
359
+
360
+ result = separate_sources(sample_audio_file, temp_storage_dir)
361
+
362
+ assert all(stem in result for stem in ["drums", "bass", "vocals", "other"])
363
+ assert all(path.exists() for path in result.values())
364
+ ```
365
+
366
+ ### Transcription
367
+
368
+ ```python
369
+ @patch('pipeline.basic_pitch.inference.predict')
370
+ def test_transcribe_audio(mock_predict, sample_audio_file, temp_storage_dir):
371
+ """Test audio transcription."""
372
+ mock_predict.return_value = (
373
+ np.random.rand(100, 88),
374
+ np.random.rand(100, 88),
375
+ np.random.rand(100, 1)
376
+ )
377
+
378
+ result = transcribe_audio(sample_audio_file, temp_storage_dir)
379
+
380
+ assert result.exists()
381
+ assert result.suffix == ".mid"
382
+ ```
383
+
384
+ ### MusicXML Generation
385
+
386
+ ```python
387
+ @patch('pipeline.music21.converter.parse')
388
+ def test_generate_musicxml(mock_parse, sample_midi_file, temp_storage_dir):
389
+ """Test MusicXML generation."""
390
+ mock_score = MagicMock()
391
+ mock_parse.return_value = mock_score
392
+
393
+ result = generate_musicxml(sample_midi_file, temp_storage_dir)
394
+
395
+ assert result.exists()
396
+ assert result.suffix == ".musicxml"
397
+ mock_score.write.assert_called_once()
398
+ ```
399
+
400
+ ## Troubleshooting
401
+
402
+ ### Common Issues
403
+
404
+ **Import Errors**
405
+
406
+ ```bash
407
+ # Ensure backend directory is in PYTHONPATH
408
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
409
+ pytest
410
+ ```
411
+
412
+ **Redis Connection Errors**
413
+
414
+ ```python
415
+ # Always mock Redis in tests unless testing Redis specifically
416
+ @pytest.fixture(autouse=True)
417
+ def mock_redis():
418
+ with patch('main.redis_client') as mock:
419
+ yield mock
420
+ ```
421
+
422
+ **File Permission Errors**
423
+
424
+ ```python
425
+ # Always use temp directories
426
+ @pytest.fixture
427
+ def temp_storage_dir():
428
+ temp_dir = tempfile.mkdtemp()
429
+ yield Path(temp_dir)
430
+ shutil.rmtree(temp_dir, ignore_errors=True)
431
+ ```
432
+
433
+ **GPU Not Available**
434
+
435
+ ```python
436
+ # Mark GPU tests and skip if unavailable
437
+ @pytest.mark.gpu
438
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")
439
+ def test_gpu_processing():
440
+ ...
441
+ ```
442
+
443
+ ### Debugging Failed Tests
444
+
445
+ ```bash
446
+ # Show print statements
447
+ pytest -s
448
+
449
+ # Verbose output
450
+ pytest -vv
451
+
452
+ # Drop into debugger on failure
453
+ pytest --pdb
454
+
455
+ # Run only failed tests
456
+ pytest --lf
457
+ ```
458
+
459
+ ### Performance Issues
460
+
461
+ ```bash
462
+ # Identify slow tests
463
+ pytest --durations=10
464
+
465
+ # Run tests in parallel
466
+ pytest -n auto # Requires pytest-xdist
467
+ ```
468
+
469
+ ## Best Practices
470
+
471
+ 1. **Mock external dependencies**: Don't make real API calls, network requests, or ML inferences
472
+ 2. **Use fixtures**: Share common setup code across tests
473
+ 3. **Test edge cases**: Empty inputs, None values, boundary conditions
474
+ 4. **Clean up resources**: Always clean up temp files, connections
475
+ 5. **Keep tests independent**: Tests should not depend on each other
476
+ 6. **Write descriptive names**: Test names should explain what they verify
477
+ 7. **Test one thing**: Each test should verify one specific behavior
478
+ 8. **Use markers**: Tag tests by type (unit, integration, slow, gpu)
479
+
480
+ ## Example Test File
481
+
482
+ Complete example showing best practices:
483
+
484
+ ```python
485
+ """Tests for audio processing pipeline."""
486
+ import pytest
487
+ from pathlib import Path
488
+ from unittest.mock import patch, MagicMock
489
+ import numpy as np
490
+ from pipeline import download_audio, separate_sources, transcribe_audio
491
+
492
+
493
+ class TestAudioDownload:
494
+ """Test audio download stage."""
495
+
496
+ @patch('pipeline.yt_dlp.YoutubeDL')
497
+ def test_success(self, mock_ydl_class, temp_storage_dir):
498
+ """Test successful audio download."""
499
+ mock_ydl = MagicMock()
500
+ mock_ydl_class.return_value.__enter__.return_value = mock_ydl
501
+
502
+ result = download_audio("https://youtube.com/watch?v=test", temp_storage_dir)
503
+
504
+ assert result.exists()
505
+ assert result.suffix == ".wav"
506
+ mock_ydl.download.assert_called_once()
507
+
508
+ @patch('pipeline.yt_dlp.YoutubeDL')
509
+ def test_network_error(self, mock_ydl_class, temp_storage_dir):
510
+ """Test handling of network error."""
511
+ import yt_dlp
512
+ mock_ydl = MagicMock()
513
+ mock_ydl.download.side_effect = yt_dlp.utils.DownloadError("Network error")
514
+ mock_ydl_class.return_value.__enter__.return_value = mock_ydl
515
+
516
+ with pytest.raises(Exception) as exc_info:
517
+ download_audio("https://youtube.com/watch?v=test", temp_storage_dir)
518
+
519
+ assert "Network error" in str(exc_info.value)
520
+ ```
docs/testing/baseline-accuracy.md ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Baseline Accuracy Report
2
+
3
+ **Date**: 2024-12-24
4
+ **Pipeline Version**: Phase 1 Complete (MusicXML corruption fixes, MIDI export, rate limiting)
5
+ **Test Suite**: 10 diverse piano videos
6
+
7
+ ## Executive Summary
8
+
9
+ This report establishes the baseline transcription accuracy for the Rescored MVP pipeline after Phase 1 improvements.
10
+
11
+ **Initial Test Results** (Before Bug Fixes):
12
+ - Overall Success Rate: **10%** (1/10 videos)
13
+ - Videos Blocked: 3 (YouTube copyright/availability)
14
+ - Code Bugs Found: 6 (all fixed ✅)
15
+ - Successful Test: simple_melody (2,588 notes, 122 measures)
16
+
17
+ **Expected After Fixes**:
18
+ - Success Rate: **70-80%** (7-8/10 videos, excluding blocked ones)
19
+ - All code bugs resolved
20
+ - Need to replace 3 blocked videos with alternatives
21
+
22
+ **Key Finding**: Measure timing accuracy is imperfect (78% of measures show duration warnings), but this is expected for ML-based transcription. MusicXML files load successfully in notation software.
23
+
24
+ ## Test Videos
25
+
26
+ | ID | Description | Difficulty | Expected Accuracy | URL |
27
+ |----|-------------|------------|-------------------|-----|
28
+ | simple_melody | C major scale practice | Easy | >80% | [Link](https://www.youtube.com/watch?v=TK1Ij_-mank) |
29
+ | twinkle_twinkle | Twinkle Twinkle Little Star | Easy | >75% | [Link](https://www.youtube.com/watch?v=YCZ_d_4ZEqk) |
30
+ | fur_elise | Beethoven - Für Elise (simplified) | Medium | 60-70% | [Link](https://www.youtube.com/watch?v=_mVW8tgGY_w) |
31
+ | chopin_nocturne | Chopin - Nocturne Op. 9 No. 2 | Hard | 50-60% | [Link](https://www.youtube.com/watch?v=9E6b3swbnWg) |
32
+ | canon_in_d | Pachelbel - Canon in D | Medium | 60-70% | [Link](https://www.youtube.com/watch?v=NlprozGcs80) |
33
+ | river_flows | Yiruma - River Flows in You | Medium | 60-70% | [Link](https://www.youtube.com/watch?v=7maJOI3QMu0) |
34
+ | moonlight_sonata | Beethoven - Moonlight Sonata | Medium | 60-70% | [Link](https://www.youtube.com/watch?v=4Tr0otuiQuU) |
35
+ | jazz_blues | Simple jazz blues piano | Medium | 55-65% | [Link](https://www.youtube.com/watch?v=F3W_alUuFkA) |
36
+ | claire_de_lune | Debussy - Clair de Lune | Hard | 50-60% | [Link](https://www.youtube.com/watch?v=WNcsUNKlAKw) |
37
+ | la_campanella | Liszt - La Campanella | Very Hard | 40-50% | [Link](https://www.youtube.com/watch?v=MD6xMyuZls0) |
38
+
39
+ ## Results
40
+
41
+ ### Overall Statistics
42
+
43
+ (To be filled after test completion)
44
+
45
+ - **Total Tests**: 10
46
+ - **Successful**: TBD
47
+ - **Failed**: TBD
48
+ - **Success Rate**: TBD%
49
+
50
+ ### Per-Video Results
51
+
52
+ #### Easy Difficulty (2 videos)
53
+
54
+ **simple_melody** ✅:
55
+ - Status: **SUCCESS**
56
+ - MIDI Notes: 2,588
57
+ - Measures: 122
58
+ - Duration: 245.2 seconds
59
+ - Separation Quality: 99.3% energy in 'other' stem (excellent)
60
+ - Measure Warnings: 95/122 (78%) - typical for ML transcription
61
+ - Issues: None - clean transcription
62
+
63
+ **twinkle_twinkle** ❌:
64
+ - Status: **BLOCKED**
65
+ - Error: "Video unavailable"
66
+ - Action: Replace with alternative video
67
+
68
+ #### Medium Difficulty (5 videos)
69
+
70
+ **fur_elise** ❌:
71
+ - Status: **BLOCKED**
72
+ - Error: "Video unavailable"
73
+ - Action: Replace with alternative video
74
+
75
+ **canon_in_d** ❌ → ✅:
76
+ - Status: **FIXED**
77
+ - Error: NoneType velocity comparison (Bug #2a)
78
+ - Fix Applied: Safe velocity handling in deduplication
79
+ - Expected: Success on re-run
80
+
81
+ **river_flows** ❌ → ✅:
82
+ - Status: **FIXED**
83
+ - Error: NoneType velocity comparison (Bug #2a)
84
+ - Fix Applied: Safe velocity handling
85
+ - Expected: Success on re-run
86
+
87
+ **moonlight_sonata** ❌ → ✅:
88
+ - Status: **FIXED**
89
+ - Error: NoneType velocity comparison (Bug #2a)
90
+ - Fix Applied: Safe velocity handling
91
+ - Expected: Success on re-run
92
+
93
+ **jazz_blues** ❌:
94
+ - Status: **BLOCKED**
95
+ - Error: "Blocked on copyright grounds"
96
+ - Action: Replace with public domain jazz piano
97
+
98
+ #### Hard Difficulty (2 videos)
99
+
100
+ **chopin_nocturne** ❌ → ✅:
101
+ - Status: **FIXED**
102
+ - Error: 2048th note duration in measure 129 (Bug #2b)
103
+ - Fix Applied: Increased minimum duration threshold to 128th note
104
+ - Expected: Success on re-run
105
+
106
+ **claire_de_lune** ❌ → ✅:
107
+ - Status: **FIXED**
108
+ - Error: 2048th note duration in measure 30 (Bug #2b)
109
+ - Fix Applied: Increased minimum duration threshold
110
+ - Expected: Success on re-run
111
+
112
+ #### Very Hard Difficulty (1 video)
113
+
114
+ **la_campanella** ❌ → ✅:
115
+ - Status: **FIXED**
116
+ - Error: NoneType velocity comparison (Bug #2a)
117
+ - Fix Applied: Safe velocity handling
118
+ - Expected: Success on re-run (may have low accuracy due to extreme difficulty)
119
+
120
+ ## Common Failure Modes
121
+
122
+ Detailed analysis in [failure-modes.md](failure-modes.md)
123
+
124
+ ### 1. Video Availability (30% of failures)
125
+ - YouTube blocking, copyright claims, unavailable videos
126
+ - **Solution**: Replace with Creative Commons alternatives
127
+
128
+ ### 2. Code Bugs - All Fixed ✅ (60% of failures)
129
+ - **Bug 2a**: NoneType velocity comparison (4 videos)
130
+ - Fixed in [pipeline.py:403-409](../../backend/pipeline.py#L403-L409)
131
+ - **Bug 2b**: 2048th note duration errors (2 videos)
132
+ - Fixed in [pipeline.py:465-502](../../backend/pipeline.py#L465-L502)
133
+
134
+ ### 3. Measure Timing Accuracy (78% imperfect)
135
+ - Most measures deviate from exact 4.0 beats
136
+ - Range: 0.0 to 7.83 beats (should be 4.0)
137
+ - **Root causes**: basic-pitch timing, duration snapping, polyphony
138
+ - **Impact**: MusicXML loads but rhythms need manual correction
139
+ - **Status**: Expected limitation for ML transcription - Phase 3 will improve
140
+
141
+ ## Accuracy by Difficulty
142
+
143
+ | Difficulty | Avg Success Rate | Avg Notes | Avg Measures | Notes |
144
+ |------------|------------------|-----------|--------------|-------|
145
+ | Easy | TBD | TBD | TBD | TBD |
146
+ | Medium | TBD | TBD | TBD | TBD |
147
+ | Hard | TBD | TBD | TBD | TBD |
148
+ | Very Hard | TBD | TBD | TBD | TBD |
149
+
150
+ ## Known Limitations
151
+
152
+ Based on Phase 1 implementation:
153
+
154
+ 1. **Measure Timing**: Many measures show duration warnings (e.g. 0.0-7.83 beats instead of exactly 4.0). This is expected due to:
155
+ - basic-pitch not perfectly aligned to beats
156
+ - Duration snapping to nearest valid note values
157
+ - Imperfect tempo detection
158
+
159
+ 2. **MusicXML Warnings**: music21 reports some "overfull measures" when parsing. These are handled gracefully but indicate timing imperfections.
160
+
161
+ 3. **Single Staff Only**: Grand staff (treble + bass) disabled in Phase 1 due to polyphony issues.
162
+
163
+ 4. **Piano Only**: Currently only transcribes "other" stem from Demucs, assuming piano/keyboard content.
164
+
165
+ ## Recommendations for Phase 3
166
+
167
+ (To be filled based on failure analysis)
168
+
169
+ 1. **Parameter Tuning**: TBD
170
+ 2. **Model Improvements**: TBD
171
+ 3. **Post-Processing**: TBD
172
+ 4. **Source Separation**: TBD
173
+
174
+ ## Appendix: Raw Test Data
175
+
176
+ Full test results JSON: `/tmp/rescored/accuracy_test_results.json`
177
+
178
+ Individual test outputs in: `/tmp/rescored/temp/accuracy_test_*/`
docs/testing/failure-modes.md ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Failure Mode Analysis
2
+
3
+ **Date**: 2024-12-24
4
+ **Test Suite**: Phase 2 Accuracy Baseline (10 videos)
5
+ **Pipeline Version**: Phase 1 Complete + Bug Fixes
6
+
7
+ ## Executive Summary
8
+
9
+ Initial accuracy testing revealed **3 major failure categories** affecting 9 out of 10 test videos:
10
+
11
+ 1. **Video Availability** (30% of failures) - YouTube blocking/copyright
12
+ 2. **Code Bugs** (60% of failures) - NoneType errors and 2048th note duration issues
13
+ 3. **MusicXML Export** (20% of failures) - Impossible duration errors
14
+
15
+ **All code bugs have been fixed.** Success rate expected to improve significantly with re-run.
16
+
17
+ ## Failure Categories
18
+
19
+ ### 1. Video Availability Issues (3 videos - 30%)
20
+
21
+ **Videos Affected:**
22
+ - `twinkle_twinkle` - "Video unavailable"
23
+ - `fur_elise` - "Video unavailable"
24
+ - `jazz_blues` - "Blocked in your country on copyright grounds"
25
+
26
+ **Root Cause:** YouTube access restrictions, not pipeline issues
27
+
28
+ **Mitigation:**
29
+ - Replace with alternative videos for same difficulty level
30
+ - Use Creative Commons licensed videos
31
+ - Host test videos on alternative platforms
32
+
33
+ **Impact:** Not a pipeline issue - will replace test videos
34
+
35
+ ---
36
+
37
+ ### 2. Code Bugs - Fixed ✅ (6 videos - 60%)
38
+
39
+ #### Bug 2a: NoneType Velocity Comparison (4 videos)
40
+
41
+ **Error:** `'<' not supported between instances of 'int' and 'NoneType'`
42
+
43
+ **Videos Affected:**
44
+ - `canon_in_d`
45
+ - `river_flows`
46
+ - `moonlight_sonata`
47
+ - `la_campanella`
48
+
49
+ **Root Cause:** In `_deduplicate_overlapping_notes()` at [pipeline.py:403-407](../backend/pipeline.py#L403-L407), the code tried to sort notes by velocity, but `note.volume.velocity` can return `None`.
50
+
51
+ **Fix Applied:**
52
+ ```python
53
+ def get_velocity(note):
54
+ if hasattr(note, 'volume') and hasattr(note.volume, 'velocity'):
55
+ vel = note.volume.velocity
56
+ return vel if vel is not None else 64
57
+ return 64
58
+
59
+ pitch_notes.sort(key=lambda x: (x.quarterLength, get_velocity(x)), reverse=True)
60
+ ```
61
+
62
+ **Status:** ✅ Fixed in [pipeline.py:403-409](../backend/pipeline.py#L403-L409)
63
+
64
+ ---
65
+
66
+ #### Bug 2b: 2048th Note Duration (2 videos)
67
+
68
+ **Error:** `In part (Piano), measure (X): Cannot convert "2048th" duration to MusicXML (too short).`
69
+
70
+ **Videos Affected:**
71
+ - `chopin_nocturne` (measure 129)
72
+ - `claire_de_lune` (measure 30)
73
+
74
+ **Root Cause:** `music21.makeMeasures()` creates extremely short rests (2048th notes) when filling gaps between notes. MusicXML export fails because these durations are too short to represent.
75
+
76
+ **Previous Attempts:**
77
+ 1. ❌ Filtered notes < 64th note (0.0625) before `makeMeasures()` - didn't work
78
+ 2. ❌ Removed notes < 64th note after `makeMeasures()` - still had issues
79
+
80
+ **Final Fix:**
81
+ - Increased minimum duration threshold to **128th note** (0.03125)
82
+ - Added logging to show how many notes/rests were removed
83
+ - Applied in `_remove_impossible_durations()` at [pipeline.py:465-502](../backend/pipeline.py#L465-L502)
84
+
85
+ **Status:** ✅ Fixed - more aggressive filtering
86
+
87
+ ---
88
+
89
+ ### 3. Successful Test Analysis
90
+
91
+ **Video:** `simple_melody` (C major scale practice, Easy difficulty)
92
+
93
+ **Results:**
94
+ - ✅ Successfully generated MusicXML
95
+ - **2,588 notes** detected
96
+ - **122 measures** created
97
+ - **245 seconds** duration
98
+ - **99.3% energy** preserved in 'other' stem (excellent separation)
99
+
100
+ **Key Metrics:**
101
+
102
+ | Metric | Value | Assessment |
103
+ |--------|-------|------------|
104
+ | Note density | 5.36 notes/sec | Reasonable for piano |
105
+ | Pitch range | G1 to A6 (62 semitones) | Full piano range |
106
+ | Polyphony | ~1.6 avg, ~6 max | Modest polyphony |
107
+ | Short notes | 271 (21%) under 200ms | Acceptable |
108
+ | Measure warnings | 95/122 (78%) | **High** - timing imperfect |
109
+
110
+ **Measure Timing Issues:**
111
+
112
+ 78% of measures showed duration warnings (range 0.0 - 7.83 beats instead of exactly 4.0). Examples:
113
+ - Measure 1: 0.00 beats (empty)
114
+ - Measure 30: 6.41 beats (overfull)
115
+ - Measure 69: 7.33 beats (very overfull)
116
+ - Measure 77: 7.83 beats (worst case)
117
+
118
+ **Root Causes:**
119
+ 1. **basic-pitch timing** not aligned to musical beats
120
+ 2. **Duration snapping** to nearest valid note value loses precision
121
+ 3. **Tempo detection** may be inaccurate
122
+ 4. **Polyphonic overlaps** creating extra duration
123
+
124
+ **Impact:** MusicXML loads in notation software but rhythms are imperfect. This is expected with ML-based transcription.
125
+
126
+ ---
127
+
128
+ ## Common Patterns
129
+
130
+ ### Pattern 1: Quiet Audio Detection
131
+ - Diagnostic shows RMS energy of 0.0432 (very quiet)
132
+ - 20% silence in audio
133
+ - basic-pitch may struggle with quiet inputs
134
+
135
+ ### Pattern 2: Separation Quality
136
+ - For `simple_melody`: 99.3% energy in 'other' stem ✅
137
+ - Only 0.2% in 'no_other' stem (excellent isolation)
138
+ - Demucs successfully isolated piano
139
+
140
+ ### Pattern 3: Measure Duration Accuracy
141
+ - **Only 22%** of measures have exactly 4.0 beats
142
+ - **78%** show timing deviations
143
+ - Range: -4.0 to +3.83 beats deviation
144
+ - Largest errors in complex sections (likely polyphony)
145
+
146
+ ---
147
+
148
+ ## Recommendations
149
+
150
+ ### Immediate Actions (Phase 2 completion)
151
+
152
+ 1. **Replace unavailable videos** with Creative Commons alternatives
153
+ 2. **Re-run accuracy suite** with bug fixes
154
+ 3. **Document actual baseline** with successful tests
155
+
156
+ ### Phase 3 Improvements (Accuracy Tuning)
157
+
158
+ 1. **Tempo Detection:**
159
+ - Implement better tempo detection (analyze beat patterns)
160
+ - Consider fixed tempo option for practice scales
161
+
162
+ 2. **Quantization:**
163
+ - Improve rhythmic quantization to align with detected beats
164
+ - Consider time signature detection
165
+
166
+ 3. **Post-Processing:**
167
+ - Add measure duration normalization
168
+ - Stretch/compress note timings to fit exact 4.0 beats
169
+
170
+ 4. **Parameter Tuning:**
171
+ - Test different `onset-threshold` values (current: 0.5)
172
+ - Test different `frame-threshold` values (current: 0.4)
173
+ - Experiment with `minimum-note-length`
174
+
175
+ ### Alternative Models (Phase 3 - Optional)
176
+
177
+ Consider testing:
178
+ - **MT3** (Google's Music Transformer) - better rhythm accuracy
179
+ - **htdemucs_6s** - 6-stem model with dedicated piano stem
180
+ - **Omnizart** - specialized for classical music
181
+
182
+ ---
183
+
184
+ ## Success Criteria
185
+
186
+ After fixes and re-run, we expect:
187
+
188
+ - ✅ **Video availability**: 7-8 working videos (replacing blocked ones)
189
+ - ✅ **Code bugs**: 0% failure rate (all fixed)
190
+ - ✅ **MusicXML export**: 100% success for available videos
191
+ - 🎯 **Overall success rate**: 70-80% (from 10%)
192
+
193
+ Measure timing accuracy will remain imperfect (~78% with warnings) but this is expected for MVP. Phase 3 will focus on improving timing accuracy.
194
+
195
+ ---
196
+
197
+ ## Appendix: Error Details
198
+
199
+ ### NoneType Error Stack Trace
200
+ ```
201
+ File "pipeline.py", line 403
202
+ pitch_notes.sort(key=lambda x: (x.quarterLength, x.volume.velocity if ...))
203
+ TypeError: '<' not supported between instances of 'int' and 'NoneType'
204
+ ```
205
+
206
+ ### 2048th Note Error Stack Trace
207
+ ```
208
+ File "music21/musicxml/m21ToXml.py", line 4702
209
+ mxNormalType.text = typeToMusicXMLType(tup.durationNormal.type)
210
+ MusicXMLExportException: In part (Piano), measure (129): Cannot convert "2048th" duration to MusicXML (too short).
211
+ ```
212
+
213
+ ---
214
+
215
+ **Last Updated**: 2024-12-24
216
+ **Next Review**: After accuracy suite re-run
docs/testing/frontend-testing.md ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Frontend Testing Guide
2
+
3
+ Comprehensive guide for testing the Rescored frontend.
4
+
5
+ ## Table of Contents
6
+
7
+ - [Setup](#setup)
8
+ - [Running Tests](#running-tests)
9
+ - [Test Structure](#test-structure)
10
+ - [Writing Tests](#writing-tests)
11
+ - [Testing Patterns](#testing-patterns)
12
+ - [Troubleshooting](#troubleshooting)
13
+
14
+ ## Setup
15
+
16
+ ### Install Test Dependencies
17
+
18
+ ```bash
19
+ cd frontend
20
+ npm install
21
+ ```
22
+
23
+ Test dependencies (already in `package.json`):
24
+ - `vitest`: Test framework
25
+ - `@testing-library/react`: React testing utilities
26
+ - `@testing-library/user-event`: User interaction simulation
27
+ - `@testing-library/jest-dom`: DOM matchers
28
+ - `jsdom`: DOM implementation for Node.js
29
+ - `@vitest/ui`: Interactive test UI
30
+ - `@vitest/coverage-v8`: Coverage reporting
31
+
32
+ ### Configuration
33
+
34
+ Test configuration is in `vitest.config.ts`:
35
+
36
+ ```typescript
37
+ export default defineConfig({
38
+ test: {
39
+ globals: true,
40
+ environment: 'jsdom',
41
+ setupFiles: ['./src/tests/setup.ts'],
42
+ coverage: {
43
+ provider: 'v8',
44
+ reporter: ['text', 'html', 'lcov'],
45
+ },
46
+ },
47
+ });
48
+ ```
49
+
50
+ ## Running Tests
51
+
52
+ ### Basic Commands
53
+
54
+ ```bash
55
+ # Run all tests
56
+ npm test
57
+
58
+ # Run in watch mode
59
+ npm test -- --watch
60
+
61
+ # Run with UI
62
+ npm run test:ui
63
+
64
+ # Run with coverage
65
+ npm run test:coverage
66
+
67
+ # Run specific file
68
+ npm test -- src/tests/api/client.test.ts
69
+
70
+ # Run tests matching pattern
71
+ npm test -- --grep "JobSubmission"
72
+ ```
73
+
74
+ ### Watch Mode
75
+
76
+ Watch mode automatically re-runs tests when files change:
77
+
78
+ ```bash
79
+ npm test -- --watch
80
+
81
+ # Watch specific file
82
+ npm test -- --watch src/tests/components/NotationCanvas.test.tsx
83
+ ```
84
+
85
+ ### Coverage Reports
86
+
87
+ ```bash
88
+ # Generate coverage report
89
+ npm run test:coverage
90
+
91
+ # Open HTML report
92
+ open coverage/index.html
93
+ ```
94
+
95
+ ## Test Structure
96
+
97
+ ### Test Files
98
+
99
+ Component tests live alongside components or in `src/tests/`:
100
+
101
+ ```
102
+ frontend/src/
103
+ ├── components/
104
+ │ ├── JobSubmission.tsx
105
+ │ └── JobSubmission.test.tsx # Option 1: Co-located
106
+ ├── tests/
107
+ │ ├── setup.ts # Test configuration
108
+ │ ├── fixtures.ts # Shared test data
109
+ │ ├── components/
110
+ │ │ └── JobSubmission.test.tsx # Option 2: Separate directory
111
+ │ └── api/
112
+ │ └── client.test.ts
113
+ ```
114
+
115
+ ### Test Organization
116
+
117
+ ```typescript
118
+ import { describe, it, expect, vi, beforeEach } from 'vitest';
119
+ import { render, screen } from '@testing-library/react';
120
+ import Component from './Component';
121
+
122
+ describe('Component', () => {
123
+ beforeEach(() => {
124
+ // Setup before each test
125
+ });
126
+
127
+ describe('Rendering', () => {
128
+ it('should render correctly', () => {
129
+ // Test rendering
130
+ });
131
+ });
132
+
133
+ describe('Interactions', () => {
134
+ it('should handle user input', async () => {
135
+ // Test interactions
136
+ });
137
+ });
138
+
139
+ describe('Edge Cases', () => {
140
+ it('should handle empty state', () => {
141
+ // Test edge cases
142
+ });
143
+ });
144
+ });
145
+ ```
146
+
147
+ ## Writing Tests
148
+
149
+ ### Basic Component Test
150
+
151
+ ```typescript
152
+ import { describe, it, expect } from 'vitest';
153
+ import { render, screen } from '@testing-library/react';
154
+ import MyComponent from './MyComponent';
155
+
156
+ describe('MyComponent', () => {
157
+ it('should render text', () => {
158
+ render(<MyComponent text="Hello" />);
159
+ expect(screen.getByText('Hello')).toBeInTheDocument();
160
+ });
161
+
162
+ it('should handle button click', async () => {
163
+ const user = userEvent.setup();
164
+ const handleClick = vi.fn();
165
+
166
+ render(<MyComponent onClick={handleClick} />);
167
+
168
+ const button = screen.getByRole('button');
169
+ await user.click(button);
170
+
171
+ expect(handleClick).toHaveBeenCalledTimes(1);
172
+ });
173
+ });
174
+ ```
175
+
176
+ ### Testing with User Interactions
177
+
178
+ Use `@testing-library/user-event` for realistic interactions:
179
+
180
+ ```typescript
181
+ import userEvent from '@testing-library/user-event';
182
+
183
+ it('should accept user input', async () => {
184
+ const user = userEvent.setup();
185
+ render(<JobSubmission />);
186
+
187
+ const input = screen.getByPlaceholderText(/youtube url/i);
188
+
189
+ // Type into input
190
+ await user.type(input, 'https://www.youtube.com/watch?v=...');
191
+ expect(input).toHaveValue('https://www.youtube.com/watch?v=...');
192
+
193
+ // Click button
194
+ const button = screen.getByRole('button', { name: /submit/i });
195
+ await user.click(button);
196
+
197
+ // Verify action
198
+ await waitFor(() => {
199
+ expect(mockSubmit).toHaveBeenCalled();
200
+ });
201
+ });
202
+ ```
203
+
204
+ ### Testing Async Operations
205
+
206
+ ```typescript
207
+ import { waitFor } from '@testing-library/react';
208
+
209
+ it('should load data', async () => {
210
+ const mockFetch = vi.fn().mockResolvedValue({
211
+ ok: true,
212
+ json: async () => ({ data: 'test' }),
213
+ });
214
+ global.fetch = mockFetch;
215
+
216
+ render(<DataComponent />);
217
+
218
+ await waitFor(() => {
219
+ expect(screen.getByText('test')).toBeInTheDocument();
220
+ });
221
+ });
222
+ ```
223
+
224
+ ### Mocking Dependencies
225
+
226
+ #### Mock API Client
227
+
228
+ ```typescript
229
+ vi.mock('../../api/client', () => ({
230
+ submitTranscription: vi.fn(),
231
+ getJobStatus: vi.fn(),
232
+ downloadScore: vi.fn(),
233
+ }));
234
+
235
+ import { submitTranscription } from '../../api/client';
236
+
237
+ it('should call API', async () => {
238
+ const mockSubmit = vi.mocked(submitTranscription);
239
+ mockSubmit.mockResolvedValue({ job_id: '123', status: 'queued' });
240
+
241
+ // Test component that uses submitTranscription
242
+ // ...
243
+
244
+ expect(mockSubmit).toHaveBeenCalledWith('https://youtube.com/...');
245
+ });
246
+ ```
247
+
248
+ #### Mock Zustand Store
249
+
250
+ ```typescript
251
+ import { renderHook, act } from '@testing-library/react';
252
+ import { useScoreStore } from '../../store/scoreStore';
253
+
254
+ it('should update store', () => {
255
+ const { result } = renderHook(() => useScoreStore());
256
+
257
+ act(() => {
258
+ result.current.setMusicXML('<musicxml>...</musicxml>');
259
+ });
260
+
261
+ expect(result.current.musicXML).toBe('<musicxml>...</musicxml>');
262
+ });
263
+ ```
264
+
265
+ #### Mock VexFlow
266
+
267
+ ```typescript
268
+ // In setup.ts
269
+ vi.mock('vexflow', () => ({
270
+ Flow: {
271
+ Renderer: vi.fn(() => ({
272
+ resize: vi.fn(),
273
+ getContext: vi.fn(() => ({
274
+ clear: vi.fn(),
275
+ setFont: vi.fn(),
276
+ })),
277
+ })),
278
+ Stave: vi.fn(() => ({
279
+ addClef: vi.fn().mockReturnThis(),
280
+ addTimeSignature: vi.fn().mockReturnThis(),
281
+ setContext: vi.fn().mockReturnThis(),
282
+ draw: vi.fn(),
283
+ })),
284
+ },
285
+ }));
286
+ ```
287
+
288
+ ## Testing Patterns
289
+
290
+ ### Testing Form Submission
291
+
292
+ ```typescript
293
+ it('should submit form with valid data', async () => {
294
+ const user = userEvent.setup();
295
+ const onSubmit = vi.fn();
296
+
297
+ render(<Form onSubmit={onSubmit} />);
298
+
299
+ // Fill out form
300
+ await user.type(screen.getByLabelText(/url/i), 'https://youtube.com/...');
301
+
302
+ // Submit
303
+ await user.click(screen.getByRole('button', { name: /submit/i }));
304
+
305
+ // Verify
306
+ await waitFor(() => {
307
+ expect(onSubmit).toHaveBeenCalledWith({
308
+ url: 'https://youtube.com/...',
309
+ });
310
+ });
311
+ });
312
+ ```
313
+
314
+ ### Testing Error States
315
+
316
+ ```typescript
317
+ it('should show error message', async () => {
318
+ const mockFetch = vi.fn().mockRejectedValue(new Error('Network error'));
319
+ global.fetch = mockFetch;
320
+
321
+ render(<Component />);
322
+
323
+ await waitFor(() => {
324
+ expect(screen.getByText(/network error/i)).toBeInTheDocument();
325
+ });
326
+ });
327
+ ```
328
+
329
+ ### Testing Loading States
330
+
331
+ ```typescript
332
+ it('should show loading indicator', async () => {
333
+ const mockFetch = vi.fn(() =>
334
+ new Promise(resolve => setTimeout(() => resolve({ ok: true }), 100))
335
+ );
336
+ global.fetch = mockFetch;
337
+
338
+ render(<Component />);
339
+
340
+ // Should show loading
341
+ expect(screen.getByText(/loading/i)).toBeInTheDocument();
342
+
343
+ // Should hide loading after data loads
344
+ await waitFor(() => {
345
+ expect(screen.queryByText(/loading/i)).not.toBeInTheDocument();
346
+ });
347
+ });
348
+ ```
349
+
350
+ ### Testing WebSocket Connections
351
+
352
+ ```typescript
353
+ it('should handle WebSocket messages', () => {
354
+ const mockWS = {
355
+ addEventListener: vi.fn(),
356
+ send: vi.fn(),
357
+ close: vi.fn(),
358
+ };
359
+
360
+ global.WebSocket = vi.fn(() => mockWS) as any;
361
+
362
+ render(<WebSocketComponent />);
363
+
364
+ // Get message handler
365
+ const messageHandler = mockWS.addEventListener.mock.calls.find(
366
+ call => call[0] === 'message'
367
+ )?.[1];
368
+
369
+ // Simulate message
370
+ messageHandler?.({ data: JSON.stringify({ type: 'progress', progress: 50 }) });
371
+
372
+ // Verify UI updated
373
+ expect(screen.getByText(/50%/)).toBeInTheDocument();
374
+ });
375
+ ```
376
+
377
+ ### Testing Conditional Rendering
378
+
379
+ ```typescript
380
+ it('should render different states', () => {
381
+ const { rerender } = render(<StatusIndicator status="loading" />);
382
+ expect(screen.getByText(/loading/i)).toBeInTheDocument();
383
+
384
+ rerender(<StatusIndicator status="success" />);
385
+ expect(screen.getByText(/success/i)).toBeInTheDocument();
386
+
387
+ rerender(<StatusIndicator status="error" />);
388
+ expect(screen.getByText(/error/i)).toBeInTheDocument();
389
+ });
390
+ ```
391
+
392
+ ### Testing Canvas/VexFlow Components
393
+
394
+ ```typescript
395
+ it('should render notation', () => {
396
+ // Mock canvas context
397
+ const mockContext = {
398
+ fillRect: vi.fn(),
399
+ clearRect: vi.fn(),
400
+ beginPath: vi.fn(),
401
+ stroke: vi.fn(),
402
+ };
403
+
404
+ HTMLCanvasElement.prototype.getContext = vi.fn(() => mockContext) as any;
405
+
406
+ const { container } = render(<NotationCanvas musicXML={sampleXML} />);
407
+
408
+ // Verify canvas or SVG exists
409
+ const canvas = container.querySelector('canvas');
410
+ expect(canvas).toBeInTheDocument();
411
+ });
412
+ ```
413
+
414
+ ### Snapshot Testing
415
+
416
+ Use snapshots for stable UI components:
417
+
418
+ ```typescript
419
+ it('should match snapshot', () => {
420
+ const { container } = render(<StaticComponent />);
421
+ expect(container).toMatchSnapshot();
422
+ });
423
+ ```
424
+
425
+ **Update snapshots:**
426
+ ```bash
427
+ npm test -- -u
428
+ ```
429
+
430
+ ## Testing Custom Hooks
431
+
432
+ ```typescript
433
+ import { renderHook, act } from '@testing-library/react';
434
+ import { useCustomHook } from './useCustomHook';
435
+
436
+ it('should handle state changes', () => {
437
+ const { result } = renderHook(() => useCustomHook());
438
+
439
+ expect(result.current.count).toBe(0);
440
+
441
+ act(() => {
442
+ result.current.increment();
443
+ });
444
+
445
+ expect(result.current.count).toBe(1);
446
+ });
447
+ ```
448
+
449
+ ## Accessibility Testing
450
+
451
+ ```typescript
452
+ it('should be accessible', () => {
453
+ render(<Component />);
454
+
455
+ // Check for proper labels
456
+ expect(screen.getByLabelText(/input field/i)).toBeInTheDocument();
457
+
458
+ // Check for ARIA attributes
459
+ expect(screen.getByRole('button')).toHaveAttribute('aria-label', 'Submit');
460
+
461
+ // Check keyboard navigation
462
+ const button = screen.getByRole('button');
463
+ button.focus();
464
+ expect(button).toHaveFocus();
465
+ });
466
+ ```
467
+
468
+ ## Troubleshooting
469
+
470
+ ### Common Issues
471
+
472
+ **Canvas/VexFlow Errors**
473
+
474
+ ```typescript
475
+ // Mock canvas in setup.ts
476
+ beforeEach(() => {
477
+ HTMLCanvasElement.prototype.getContext = vi.fn(() => ({
478
+ fillRect: vi.fn(),
479
+ // ... other canvas methods
480
+ })) as any;
481
+ });
482
+ ```
483
+
484
+ **WebSocket Errors**
485
+
486
+ ```typescript
487
+ // Mock WebSocket in setup.ts
488
+ global.WebSocket = vi.fn(() => ({
489
+ addEventListener: vi.fn(),
490
+ send: vi.fn(),
491
+ close: vi.fn(),
492
+ readyState: WebSocket.OPEN,
493
+ })) as any;
494
+ ```
495
+
496
+ **Module Import Errors**
497
+
498
+ ```typescript
499
+ // Use vi.mock at top of test file
500
+ vi.mock('external-module', () => ({
501
+ default: vi.fn(),
502
+ namedExport: vi.fn(),
503
+ }));
504
+ ```
505
+
506
+ **Async Test Timeouts**
507
+
508
+ ```typescript
509
+ // Increase timeout for slow tests
510
+ it('slow test', async () => {
511
+ // ...
512
+ }, { timeout: 10000 });
513
+ ```
514
+
515
+ ### Debugging Tests
516
+
517
+ ```bash
518
+ # Run with UI for interactive debugging
519
+ npm run test:ui
520
+
521
+ # Run specific test in watch mode
522
+ npm test -- --watch --grep "test name"
523
+
524
+ # Debug in VS Code
525
+ # Add breakpoint and use "Debug Test" code lens
526
+ ```
527
+
528
+ ### Performance Issues
529
+
530
+ ```bash
531
+ # Identify slow tests
532
+ npm test -- --reporter=verbose
533
+
534
+ # Run tests in parallel (default)
535
+ npm test
536
+
537
+ # Run sequentially if needed
538
+ npm test -- --no-threads
539
+ ```
540
+
541
+ ## Best Practices
542
+
543
+ 1. **Test user behavior, not implementation**: Focus on what users see and do
544
+ 2. **Use accessible queries**: Prefer `getByRole`, `getByLabelText` over `getByTestId`
545
+ 3. **Avoid testing implementation details**: Don't test internal state or methods
546
+ 4. **Keep tests simple**: Each test should verify one thing
547
+ 5. **Use realistic data**: Test with data similar to production
548
+ 6. **Clean up**: Always clean up side effects (timers, listeners)
549
+ 7. **Mock external dependencies**: Don't make real API calls or WebSocket connections
550
+ 8. **Test edge cases**: Empty states, errors, loading states
551
+
552
+ ## Query Priority
553
+
554
+ Use queries in this order (most preferred first):
555
+
556
+ 1. **Accessible Queries**:
557
+ - `getByRole`
558
+ - `getByLabelText`
559
+ - `getByPlaceholderText`
560
+ - `getByText`
561
+
562
+ 2. **Semantic Queries**:
563
+ - `getByAltText`
564
+ - `getByTitle`
565
+
566
+ 3. **Test IDs** (last resort):
567
+ - `getByTestId`
568
+
569
+ Example:
570
+
571
+ ```typescript
572
+ // Good
573
+ const button = screen.getByRole('button', { name: /submit/i });
574
+ const input = screen.getByLabelText(/email/i);
575
+
576
+ // Acceptable
577
+ const image = screen.getByAltText('Logo');
578
+
579
+ // Last resort
580
+ const element = screen.getByTestId('custom-element');
581
+ ```
582
+
583
+ ## Example Test File
584
+
585
+ Complete example showing best practices:
586
+
587
+ ```typescript
588
+ import { describe, it, expect, vi, beforeEach } from 'vitest';
589
+ import { render, screen, waitFor } from '@testing-library/react';
590
+ import userEvent from '@testing-library/user-event';
591
+ import JobSubmission from './JobSubmission';
592
+
593
+ vi.mock('../../api/client', () => ({
594
+ submitTranscription: vi.fn(),
595
+ }));
596
+
597
+ import { submitTranscription } from '../../api/client';
598
+
599
+ describe('JobSubmission', () => {
600
+ beforeEach(() => {
601
+ vi.clearAllMocks();
602
+ });
603
+
604
+ describe('Rendering', () => {
605
+ it('should render input and button', () => {
606
+ render(<JobSubmission />);
607
+
608
+ expect(screen.getByPlaceholderText(/youtube url/i)).toBeInTheDocument();
609
+ expect(screen.getByRole('button', { name: /transcribe/i })).toBeInTheDocument();
610
+ });
611
+ });
612
+
613
+ describe('User Interactions', () => {
614
+ it('should accept and submit valid URL', async () => {
615
+ const user = userEvent.setup();
616
+ const mockSubmit = vi.mocked(submitTranscription);
617
+ mockSubmit.mockResolvedValue({ job_id: '123', status: 'queued' });
618
+
619
+ render(<JobSubmission />);
620
+
621
+ const input = screen.getByPlaceholderText(/youtube url/i);
622
+ const button = screen.getByRole('button', { name: /transcribe/i });
623
+
624
+ await user.type(input, 'https://www.youtube.com/watch?v=...');
625
+ await user.click(button);
626
+
627
+ await waitFor(() => {
628
+ expect(mockSubmit).toHaveBeenCalledWith(
629
+ 'https://www.youtube.com/watch?v=...',
630
+ expect.any(Object)
631
+ );
632
+ });
633
+ });
634
+ });
635
+
636
+ describe('Error Handling', () => {
637
+ it('should show error for invalid URL', async () => {
638
+ const user = userEvent.setup();
639
+ render(<JobSubmission />);
640
+
641
+ const input = screen.getByPlaceholderText(/youtube url/i);
642
+ const button = screen.getByRole('button', { name: /transcribe/i });
643
+
644
+ await user.type(input, 'invalid-url');
645
+ await user.click(button);
646
+
647
+ await waitFor(() => {
648
+ expect(screen.getByText(/invalid/i)).toBeInTheDocument();
649
+ });
650
+ });
651
+ });
652
+ });
653
+ ```
docs/testing/overview.md ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Testing Guide
2
+
3
+ Complete testing guide for the Rescored project.
4
+
5
+ ## Quick Start
6
+
7
+ ### Backend Tests
8
+
9
+ ```bash
10
+ cd backend
11
+ pip install -r requirements-test.txt
12
+ pytest --cov
13
+ ```
14
+
15
+ ### Frontend Tests
16
+
17
+ ```bash
18
+ cd frontend
19
+ npm install
20
+ npm test
21
+ ```
22
+
23
+ ## Testing Philosophy
24
+
25
+ Rescored follows these testing principles:
26
+
27
+ 1. **Test behavior, not implementation** - Verify what the code does, not how
28
+ 2. **Write tests that give confidence** - Focus on high-value tests that catch real bugs
29
+ 3. **Keep tests maintainable** - Tests should be easy to understand and modify
30
+ 4. **Test at the right level** - Unit tests for logic, integration tests for workflows
31
+ 5. **Fast feedback loops** - Tests should run quickly to enable rapid development
32
+
33
+ ## Test Suites
34
+
35
+ ### Backend Test Suite (`backend/tests/`)
36
+
37
+ - **Unit Tests** (`test_utils.py`) - URL validation, video availability checks
38
+ - **API Tests** (`test_api.py`) - FastAPI endpoints, WebSocket connections
39
+ - **Pipeline Tests** (`test_pipeline.py`) - Audio processing, transcription, MusicXML generation
40
+ - **Task Tests** (`test_tasks.py`) - Celery workers, job processing, progress updates
41
+
42
+ **Features**: Mocked external dependencies (yt-dlp, Redis, ML models), temporary file handling, parametrized tests, coverage reporting
43
+
44
+ ### Frontend Test Suite (`frontend/src/tests/`)
45
+
46
+ - **API Client Tests** (`api/client.test.ts`) - HTTP requests, WebSocket connections
47
+ - **Component Tests** (`components/`) - JobSubmission, NotationCanvas, PlaybackControls
48
+ - **Store Tests** (`store/useScoreStore.test.ts`) - Zustand state management
49
+
50
+ **Features**: React Testing Library, user event simulation, mocked VexFlow and Tone.js, coverage reporting
51
+
52
+ ## Coverage Goals
53
+
54
+ | Component | Target | Priority |
55
+ |-----------|--------|----------|
56
+ | Backend Utils | 90%+ | High |
57
+ | Backend Pipeline | 85%+ | Critical |
58
+ | Backend API | 80%+ | High |
59
+ | Frontend API Client | 85%+ | Critical |
60
+ | Frontend Components | 75%+ | High |
61
+ | Frontend Store | 80%+ | High |
62
+
63
+ ## Running Tests
64
+
65
+ ### Backend
66
+
67
+ ```bash
68
+ # Run all tests
69
+ pytest
70
+
71
+ # With coverage
72
+ pytest --cov --cov-report=html
73
+
74
+ # Specific tests
75
+ pytest tests/test_utils.py
76
+ pytest tests/test_utils.py::TestValidateYouTubeURL::test_valid_watch_url
77
+
78
+ # By category
79
+ pytest -m unit # Only unit tests
80
+ pytest -m integration # Only integration tests
81
+ pytest -m "not slow" # Exclude slow tests
82
+ pytest -m "not gpu" # Exclude GPU tests
83
+
84
+ # Debugging
85
+ pytest -vv # Verbose output
86
+ pytest -s # Show print statements
87
+ pytest --pdb # Drop into debugger on failure
88
+ pytest --lf # Run last failed tests
89
+ ```
90
+
91
+ ### Frontend
92
+
93
+ ```bash
94
+ # Run all tests
95
+ npm test
96
+
97
+ # Watch mode
98
+ npm test -- --watch
99
+
100
+ # With UI
101
+ npm run test:ui
102
+
103
+ # With coverage
104
+ npm run test:coverage
105
+
106
+ # Specific tests
107
+ npm test -- src/tests/api/client.test.ts
108
+ npm test -- --grep "JobSubmission"
109
+ ```
110
+
111
+ ## Test Structure
112
+
113
+ ### Backend
114
+
115
+ ```
116
+ backend/tests/
117
+ ├── conftest.py # Shared fixtures (temp dirs, mock Redis, sample files)
118
+ ├── test_utils.py # Utility function tests
119
+ ├── test_api.py # API endpoint tests
120
+ ├── test_pipeline.py # Audio processing tests
121
+ └── test_tasks.py # Celery task tests
122
+ ```
123
+
124
+ ### Frontend
125
+
126
+ ```
127
+ frontend/src/tests/
128
+ ├── setup.ts # Test configuration (mocks for VexFlow, Tone.js, WebSocket)
129
+ ├── fixtures.ts # Shared test data (MusicXML, job responses, etc.)
130
+ ├── api/client.test.ts
131
+ ├── components/
132
+ │ ├── JobSubmission.test.tsx
133
+ │ ├── NotationCanvas.test.tsx
134
+ │ └── PlaybackControls.test.tsx
135
+ └── store/useScoreStore.test.ts
136
+ ```
137
+
138
+ ## Common Patterns
139
+
140
+ ### Backend Testing
141
+
142
+ ```python
143
+ # Mock external services
144
+ @patch('pipeline.yt_dlp.YoutubeDL')
145
+ def test_download_audio(mock_ydl_class, temp_storage_dir):
146
+ mock_ydl = MagicMock()
147
+ mock_ydl_class.return_value.__enter__.return_value = mock_ydl
148
+
149
+ result = download_audio("https://youtube.com/...", temp_storage_dir)
150
+
151
+ assert result.exists()
152
+ assert result.suffix == ".wav"
153
+
154
+ # Test API endpoints
155
+ def test_submit_transcription(test_client):
156
+ response = test_client.post(
157
+ "/api/v1/transcribe",
158
+ json={"youtube_url": "https://www.youtube.com/watch?v=..."}
159
+ )
160
+
161
+ assert response.status_code == 201
162
+ assert "job_id" in response.json()
163
+
164
+ # Parametrized tests
165
+ @pytest.mark.parametrize("url,expected_valid", [
166
+ ("https://www.youtube.com/watch?v=dQw4w9WgXcQ", True),
167
+ ("https://vimeo.com/12345", False),
168
+ ])
169
+ def test_url_validation(url, expected_valid):
170
+ is_valid, _ = validate_youtube_url(url)
171
+ assert is_valid == expected_valid
172
+ ```
173
+
174
+ ### Frontend Testing
175
+
176
+ ```typescript
177
+ // Test components with user interaction
178
+ it('should submit form', async () => {
179
+ const user = userEvent.setup();
180
+ const onSubmit = vi.fn();
181
+
182
+ render(<JobSubmission onSubmit={onSubmit} />);
183
+
184
+ const input = screen.getByPlaceholderText(/youtube url/i);
185
+ await user.type(input, 'https://www.youtube.com/watch?v=...');
186
+
187
+ const button = screen.getByRole('button', { name: /submit/i });
188
+ await user.click(button);
189
+
190
+ await waitFor(() => {
191
+ expect(onSubmit).toHaveBeenCalled();
192
+ });
193
+ });
194
+
195
+ // Mock API calls
196
+ vi.mock('../../api/client', () => ({
197
+ submitTranscription: vi.fn(),
198
+ }));
199
+
200
+ it('should call API', async () => {
201
+ const mockSubmit = vi.mocked(submitTranscription);
202
+ mockSubmit.mockResolvedValue({ job_id: '123' });
203
+
204
+ // Test component that uses submitTranscription
205
+ // ...
206
+ });
207
+
208
+ // Test store
209
+ it('should update store', () => {
210
+ const { result } = renderHook(() => useScoreStore());
211
+
212
+ act(() => {
213
+ result.current.setMusicXML('<musicxml>...</musicxml>');
214
+ });
215
+
216
+ expect(result.current.musicXML).toBe('<musicxml>...</musicxml>');
217
+ });
218
+ ```
219
+
220
+ ## Mocking Strategy
221
+
222
+ ### Backend
223
+ - **External Services**: Mock yt-dlp, Redis, Celery
224
+ - **ML Models**: Mock Demucs and basic-pitch for fast tests
225
+ - **File System**: Use temporary directories
226
+
227
+ ### Frontend
228
+ - **API Calls**: Mock fetch with vitest
229
+ - **WebSockets**: Mock WebSocket connections
230
+ - **Browser APIs**: Mock Canvas, Audio, localStorage
231
+ - **Libraries**: Mock VexFlow, Tone.js
232
+
233
+ ## Best Practices
234
+
235
+ ### General
236
+ 1. ✅ Write descriptive test names that explain the scenario
237
+ 2. ✅ Keep tests simple and focused (one thing per test)
238
+ 3. ✅ Use Arrange-Act-Assert structure
239
+ 4. ✅ Make tests independent (no shared state)
240
+ 5. ✅ Clean up resources (files, connections, timers)
241
+ 6. ✅ Mock external dependencies
242
+ 7. ✅ Add tests when fixing bugs
243
+ 8. ✅ Keep test code as clean as production code
244
+
245
+ ### Backend-Specific
246
+ - Use pytest fixtures for shared setup
247
+ - Mock yt-dlp, Redis, Celery, ML models
248
+ - Use temporary directories for file operations
249
+ - Mark slow/GPU tests with `@pytest.mark.slow` and `@pytest.mark.gpu`
250
+ - Test both success and error paths
251
+
252
+ ### Frontend-Specific
253
+ - Test user behavior, not implementation details
254
+ - Use accessible queries: `getByRole`, `getByLabelText` (not `getByTestId`)
255
+ - Mock API calls and WebSocket connections
256
+ - Test loading states and error handling
257
+ - Clean up side effects (timers, event listeners)
258
+
259
+ ## Troubleshooting
260
+
261
+ ### Backend
262
+
263
+ **Import errors**
264
+ ```bash
265
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
266
+ ```
267
+
268
+ **Redis connection errors** - Always mock Redis unless testing Redis specifically
269
+
270
+ **GPU tests failing** - Mark with `@pytest.mark.gpu` and skip if unavailable
271
+
272
+ ### Frontend
273
+
274
+ **Canvas errors** - Mock canvas context in `setup.ts`
275
+
276
+ **WebSocket errors** - Mock WebSocket in `setup.ts`
277
+
278
+ **Module import errors** - Use `vi.mock()` at top of test file
279
+
280
+ **Async timeouts** - Increase timeout: `it('test', async () => { ... }, { timeout: 10000 })`
281
+
282
+ ## Test Performance
283
+
284
+ **Benchmarks:**
285
+ - Unit tests: < 100ms each
286
+ - Full backend suite: < 30 seconds
287
+ - Full frontend suite: < 20 seconds
288
+
289
+ **Optimization:**
290
+ - Mock expensive operations (ML inference, network calls)
291
+ - Use test markers to skip slow tests during development
292
+ - Parallelize tests (pytest-xdist for backend, vitest default)
293
+ - Cache expensive fixtures
294
+
295
+ ## CI/CD Integration
296
+
297
+ Tests run automatically on:
298
+ - **Pull Requests** - All tests must pass
299
+ - **Main Branch** - Full suite including slow tests
300
+ - **Nightly** - Extended test suite with real YouTube videos
301
+ - **Pre-release** - E2E tests, performance benchmarks
302
+
303
+ ## Detailed Guides
304
+
305
+ For detailed information, see:
306
+ - **[Backend Testing Guide](./backend-testing.md)** - In-depth backend testing patterns and examples
307
+ - **[Frontend Testing Guide](./frontend-testing.md)** - In-depth frontend testing patterns and examples
308
+ - **[Test Video Collection](./test-videos.md)** - Curated YouTube videos for testing transcription quality
309
+
310
+ ## Resources
311
+
312
+ - [pytest Documentation](https://docs.pytest.org/)
313
+ - [Vitest Documentation](https://vitest.dev/)
314
+ - [React Testing Library](https://testing-library.com/react)
315
+ - [FastAPI Testing](https://fastapi.tiangolo.com/tutorial/testing/)
docs/testing/test-videos.md ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Test Video Collection
2
+
3
+ Curated collection of YouTube videos for testing transcription quality and edge cases.
4
+
5
+ ## Table of Contents
6
+
7
+ - [Simple Piano Tests](#simple-piano-tests)
8
+ - [Classical Piano](#classical-piano)
9
+ - [Pop Piano Covers](#pop-piano-covers)
10
+ - [Jazz Piano](#jazz-piano)
11
+ - [Complex/Challenging](#complexchallenging)
12
+ - [Edge Cases](#edge-cases)
13
+ - [Testing Criteria](#testing-criteria)
14
+
15
+ ## Simple Piano Tests
16
+
17
+ Use these for basic functionality and quick iteration.
18
+
19
+ ### 1. Twinkle Twinkle Little Star (Beginner Piano)
20
+ - **Duration**: ~1 minute
21
+ - **Tempo**: Slow (60-80 BPM)
22
+ - **Complexity**: Very simple melody, single notes
23
+ - **Expected Accuracy**: 95%+
24
+ - **Use For**: Smoke tests, basic functionality
25
+
26
+ ### 2. Mary Had a Little Lamb
27
+ - **Duration**: ~1 minute
28
+ - **Tempo**: Moderate (100 BPM)
29
+ - **Complexity**: Simple melody with consistent rhythm
30
+ - **Expected Accuracy**: 90%+
31
+ - **Use For**: Key signature detection, basic transcription
32
+
33
+ ### 3. Happy Birthday (Piano Solo)
34
+ - **Duration**: ~1 minute
35
+ - **Tempo**: Moderate (120 BPM)
36
+ - **Complexity**: Simple melody with occasional harmony
37
+ - **Expected Accuracy**: 85%+
38
+ - **Use For**: Time signature detection (3/4 time)
39
+
40
+ ## Classical Piano
41
+
42
+ Test with well-known classical pieces to verify quality.
43
+
44
+ ### 4. Chopin - Nocturne Op. 9 No. 2
45
+ - **Duration**: 4-5 minutes
46
+ - **Tempo**: Andante (60-70 BPM)
47
+ - **Complexity**: Expressive melody with arpeggiated accompaniment
48
+ - **Expected Accuracy**: 75-80%
49
+ - **Use For**:
50
+ - Pedal sustain handling
51
+ - Rubato tempo changes
52
+ - Expressive timing
53
+
54
+ **Challenges**:
55
+ - Overlapping notes from pedal
56
+ - Tempo fluctuations
57
+ - Decorative grace notes
58
+
59
+ ### 5. Beethoven - Für Elise
60
+ - **Duration**: 3 minutes
61
+ - **Tempo**: Poco moto (120-130 BPM)
62
+ - **Complexity**: Famous melody with consistent rhythm
63
+ - **Expected Accuracy**: 80-85%
64
+ - **Use For**:
65
+ - A minor key signature
66
+ - Repeated patterns
67
+ - Multiple sections
68
+
69
+ **Challenges**:
70
+ - Fast 16th note passages
71
+ - Dynamic contrasts
72
+
73
+ ### 6. Mozart - Piano Sonata K. 545 (1st Movement)
74
+ - **Duration**: 3-4 minutes
75
+ - **Tempo**: Allegro (120-140 BPM)
76
+ - **Complexity**: Clear melody with Alberti bass
77
+ - **Expected Accuracy**: 75-80%
78
+ - **Use For**:
79
+ - C major scale passages
80
+ - Alberti bass pattern recognition
81
+ - Classical form
82
+
83
+ **Challenges**:
84
+ - Fast running passages
85
+ - Hand coordination
86
+
87
+ ## Pop Piano Covers
88
+
89
+ Test with contemporary music to verify modern styles.
90
+
91
+ ### 7. Let It Be (Piano Cover)
92
+ - **Duration**: 3-4 minutes
93
+ - **Tempo**: Moderate (76 BPM)
94
+ - **Complexity**: Block chords with melody
95
+ - **Expected Accuracy**: 70-75%
96
+ - **Use For**:
97
+ - Chord detection
98
+ - Popular music transcription
99
+ - Mixed rhythm patterns
100
+
101
+ **Challenges**:
102
+ - Dense chords
103
+ - Vocal line vs accompaniment
104
+
105
+ ### 8. Someone Like You (Piano Cover)
106
+ - **Duration**: 4-5 minutes
107
+ - **Tempo**: Slow (67 BPM)
108
+ - **Complexity**: Arpeggiated chords with melody
109
+ - **Expected Accuracy**: 70-75%
110
+ - **Use For**:
111
+ - Sustained notes
112
+ - Emotional expression
113
+ - Modern pop harmony
114
+
115
+ **Challenges**:
116
+ - Overlapping arpeggios
117
+ - Pedal sustain
118
+
119
+ ### 9. River Flows in You (Original Piano)
120
+ - **Duration**: 3-4 minutes
121
+ - **Tempo**: Moderato (110 BPM)
122
+ - **Complexity**: Flowing arpeggios with melody
123
+ - **Expected Accuracy**: 75-80%
124
+ - **Use For**:
125
+ - Continuous motion
126
+ - Pattern recognition
127
+ - Popular instrumental
128
+
129
+ **Challenges**:
130
+ - Rapid note sequences
131
+ - Consistent texture
132
+
133
+ ## Jazz Piano
134
+
135
+ Test improvisation and complex harmony.
136
+
137
+ ### 10. Bill Evans - Waltz for Debby
138
+ - **Duration**: 5-7 minutes
139
+ - **Tempo**: Moderate waltz (140-160 BPM)
140
+ - **Complexity**: Jazz voicings, walking bass, improvisation
141
+ - **Expected Accuracy**: 60-70%
142
+ - **Use For**:
143
+ - Jazz harmony
144
+ - 3/4 time signature
145
+ - Complex chord voicings
146
+
147
+ **Challenges**:
148
+ - Extended chords (7ths, 9ths, 11ths)
149
+ - Improvised passages
150
+ - Swing feel
151
+
152
+ ### 11. Oscar Peterson - C Jam Blues
153
+ - **Duration**: 3-4 minutes
154
+ - **Tempo**: Fast (200+ BPM)
155
+ - **Complexity**: Blues progression with virtuosic runs
156
+ - **Expected Accuracy**: 55-65%
157
+ - **Use For**:
158
+ - Fast tempo handling
159
+ - Blues scale
160
+ - Virtuosic passages
161
+
162
+ **Challenges**:
163
+ - Extremely fast notes
164
+ - Grace notes and ornaments
165
+ - Complex rhythm
166
+
167
+ ## Complex/Challenging
168
+
169
+ Stress tests for the transcription system.
170
+
171
+ ### 12. Flight of the Bumblebee (Piano)
172
+ - **Duration**: 1-2 minutes
173
+ - **Tempo**: Presto (170-200 BPM)
174
+ - **Complexity**: Extremely fast chromatic runs
175
+ - **Expected Accuracy**: 50-60%
176
+ - **Use For**:
177
+ - Stress testing
178
+ - Fast passage detection
179
+ - Chromatic scales
180
+
181
+ **Challenges**:
182
+ - Very fast notes (32nd notes)
183
+ - Chromatic passages
184
+ - Continuous motion
185
+
186
+ ### 13. Liszt - La Campanella
187
+ - **Duration**: 4-5 minutes
188
+ - **Tempo**: Allegretto (120 BPM)
189
+ - **Complexity**: Virtuosic with wide leaps and rapid passages
190
+ - **Expected Accuracy**: 55-65%
191
+ - **Use For**:
192
+ - Wide register jumps
193
+ - Repeated notes
194
+ - Virtuosic technique
195
+
196
+ **Challenges**:
197
+ - Octave leaps
198
+ - Repeated staccato notes
199
+ - Ornamentation
200
+
201
+ ### 14. Rachmaninoff - Prelude in C# Minor
202
+ - **Duration**: 3-4 minutes
203
+ - **Tempo**: Lento (60 BPM) to Agitato
204
+ - **Complexity**: Dense chords, dramatic dynamics
205
+ - **Expected Accuracy**: 60-70%
206
+ - **Use For**:
207
+ - Heavy chords
208
+ - Dramatic contrasts
209
+ - Multiple voices
210
+
211
+ **Challenges**:
212
+ - 6+ note chords
213
+ - Extreme dynamics
214
+ - Multiple simultaneous voices
215
+
216
+ ## Edge Cases
217
+
218
+ Special cases to test error handling and boundaries.
219
+
220
+ ### 15. Prepared Piano / Extended Techniques
221
+ - **Use For**: Testing unusual timbres
222
+ - **Expected Accuracy**: 30-50%
223
+ - **Expected Behavior**: Should handle gracefully
224
+
225
+ ### 16. Piano with Background Noise
226
+ - **Use For**: Testing source separation quality
227
+ - **Expected Accuracy**: Variable
228
+ - **Expected Behavior**: Should isolate piano reasonably
229
+
230
+ ### 17. Poor Audio Quality
231
+ - **Use For**: Testing robustness
232
+ - **Expected Accuracy**: Reduced
233
+ - **Expected Behavior**: Should not crash
234
+
235
+ ### 18. Non-Piano Video (Should Fail Gracefully)
236
+ - **Examples**:
237
+ - Drum solo
238
+ - A cappella singing
239
+ - Electronic music
240
+ - **Expected Behavior**: Should complete but with poor results
241
+
242
+ ## Testing Criteria
243
+
244
+ ### Accuracy Metrics
245
+
246
+ **High Priority (Must Work Well)**:
247
+ - Note pitch accuracy: 85%+ for simple pieces
248
+ - Note onset timing: 80%+ within 50ms
249
+ - Note duration: 70%+ within one quantization unit
250
+
251
+ **Medium Priority (Should Work)**:
252
+ - Key signature detection: 80%+ accuracy
253
+ - Time signature detection: 75%+ accuracy
254
+ - Tempo detection: 70%+ within 10 BPM
255
+
256
+ **Low Priority (Nice to Have)**:
257
+ - Dynamic markings: Not implemented in MVP
258
+ - Articulations: Not implemented in MVP
259
+ - Pedal markings: Not implemented in MVP
260
+
261
+ ### Performance Benchmarks
262
+
263
+ | Video Duration | Target Processing Time (GPU) | Max Processing Time (CPU) |
264
+ |---------------|------------------------------|---------------------------|
265
+ | 1 minute | < 30 seconds | < 5 minutes |
266
+ | 3 minutes | < 2 minutes | < 10 minutes |
267
+ | 5 minutes | < 3 minutes | < 15 minutes |
268
+
269
+ ### Success Criteria
270
+
271
+ A transcription is considered successful if:
272
+
273
+ 1. **Job completes without error**: 95%+ success rate
274
+ 2. **Basic pitch accuracy**: 70%+ correct notes for simple pieces, 60%+ for complex
275
+ 3. **Playback sounds recognizable**: User can identify the piece
276
+ 4. **Usable for editing**: Notation is clean enough to edit and correct
277
+
278
+ ### Quality Grades
279
+
280
+ **A (90%+ accuracy)**:
281
+ - Simple melodies
282
+ - Clear recordings
283
+ - Slow to moderate tempo
284
+ - Minimal harmony
285
+
286
+ **B (75-89% accuracy)**:
287
+ - Standard classical pieces
288
+ - Good recordings
289
+ - Moderate tempo
290
+ - Some harmony
291
+
292
+ **C (60-74% accuracy)**:
293
+ - Complex pieces
294
+ - Standard recordings
295
+ - Fast tempo or complex harmony
296
+ - Multiple voices
297
+
298
+ **D (50-59% accuracy)**:
299
+ - Virtuosic pieces
300
+ - Poor recordings
301
+ - Very fast or complex
302
+ - Jazz/improvisation
303
+
304
+ **F (< 50% accuracy)**:
305
+ - Extended techniques
306
+ - Very poor quality
307
+ - Non-piano instruments
308
+ - Extreme complexity
309
+
310
+ ## Using Test Videos
311
+
312
+ ### Manual Testing
313
+
314
+ 1. Submit each video URL through the UI
315
+ 2. Wait for processing to complete
316
+ 3. Check for errors in each pipeline stage
317
+ 4. Download and inspect MusicXML output
318
+ 5. Load in MuseScore or similar to verify quality
319
+ 6. Note accuracy, timing issues, and artifacts
320
+
321
+ ### Automated Testing
322
+
323
+ ```python
324
+ # In tests/test_integration.py
325
+ @pytest.mark.parametrize("video_id,expected_grade", [
326
+ ("simple_melody", "A"),
327
+ ("fur_elise", "B"),
328
+ ("jazz_piece", "C"),
329
+ ])
330
+ def test_transcription_quality(video_id, expected_grade):
331
+ """Test transcription quality meets expectations."""
332
+ result = transcribe_video(video_id)
333
+
334
+ assert result['status'] == 'success'
335
+ accuracy = calculate_accuracy(result['musicxml'])
336
+ assert accuracy >= grade_threshold(expected_grade)
337
+ ```
338
+
339
+ ### Regression Testing
340
+
341
+ Maintain a suite of test videos and track accuracy over time:
342
+
343
+ ```bash
344
+ # Run regression test suite
345
+ python scripts/run_regression_tests.py
346
+
347
+ # Compare with baseline
348
+ python scripts/compare_results.py --baseline v1.0.0 --current HEAD
349
+ ```
350
+
351
+ ## Maintaining Test Collection
352
+
353
+ 1. **Add new test cases** when bugs are found
354
+ 2. **Update expected accuracy** as system improves
355
+ 3. **Remove broken links** and replace with alternatives
356
+ 4. **Document edge cases** that reveal system limitations
357
+ 5. **Share results** with team to track progress
358
+
359
+ ## Test Video Sources
360
+
361
+ When selecting test videos:
362
+
363
+ - ✅ Use videos with clear audio
364
+ - ✅ Prefer solo piano recordings
365
+ - ✅ Choose varied difficulty levels
366
+ - ✅ Include different musical styles
367
+ - ✅ Ensure videos are publicly accessible
368
+ - ✅ Respect copyright and fair use
369
+ - ❌ Avoid videos with talking/commentary
370
+ - ❌ Avoid poor audio quality unless testing robustness
371
+ - ❌ Don't use videos over 15 minutes (MVP limit)
frontend/.env.example ADDED
@@ -0,0 +1 @@
 
 
1
+ VITE_API_URL=http://localhost:8000
frontend/.gitignore ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Logs
2
+ logs
3
+ *.log
4
+ npm-debug.log*
5
+ yarn-debug.log*
6
+ yarn-error.log*
7
+ pnpm-debug.log*
8
+ lerna-debug.log*
9
+
10
+ node_modules
11
+ dist
12
+ dist-ssr
13
+ *.local
14
+
15
+ # Editor directories and files
16
+ .vscode/*
17
+ !.vscode/extensions.json
18
+ .idea
19
+ .DS_Store
20
+ *.suo
21
+ *.ntvs*
22
+ *.njsproj
23
+ *.sln
24
+ *.sw?
frontend/Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM node:20-alpine
2
+
3
+ # Set working directory
4
+ WORKDIR /app
5
+
6
+ # Copy package files
7
+ COPY package*.json ./
8
+
9
+ # Install dependencies
10
+ RUN npm install
11
+
12
+ # Copy application code
13
+ COPY . .
14
+
15
+ # Expose Vite dev server port
16
+ EXPOSE 5173
17
+
18
+ # Default command
19
+ CMD ["npm", "run", "dev", "--", "--host", "0.0.0.0"]
frontend/README.md ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # React + TypeScript + Vite
2
+
3
+ This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
4
+
5
+ Currently, two official plugins are available:
6
+
7
+ - [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) (or [oxc](https://oxc.rs) when used in [rolldown-vite](https://vite.dev/guide/rolldown)) for Fast Refresh
8
+ - [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
9
+
10
+ ## React Compiler
11
+
12
+ The React Compiler is not enabled on this template because of its impact on dev & build performances. To add it, see [this documentation](https://react.dev/learn/react-compiler/installation).
13
+
14
+ ## Expanding the ESLint configuration
15
+
16
+ If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:
17
+
18
+ ```js
19
+ export default defineConfig([
20
+ globalIgnores(['dist']),
21
+ {
22
+ files: ['**/*.{ts,tsx}'],
23
+ extends: [
24
+ // Other configs...
25
+
26
+ // Remove tseslint.configs.recommended and replace with this
27
+ tseslint.configs.recommendedTypeChecked,
28
+ // Alternatively, use this for stricter rules
29
+ tseslint.configs.strictTypeChecked,
30
+ // Optionally, add this for stylistic rules
31
+ tseslint.configs.stylisticTypeChecked,
32
+
33
+ // Other configs...
34
+ ],
35
+ languageOptions: {
36
+ parserOptions: {
37
+ project: ['./tsconfig.node.json', './tsconfig.app.json'],
38
+ tsconfigRootDir: import.meta.dirname,
39
+ },
40
+ // other options...
41
+ },
42
+ },
43
+ ])
44
+ ```
45
+
46
+ You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:
47
+
48
+ ```js
49
+ // eslint.config.js
50
+ import reactX from 'eslint-plugin-react-x'
51
+ import reactDom from 'eslint-plugin-react-dom'
52
+
53
+ export default defineConfig([
54
+ globalIgnores(['dist']),
55
+ {
56
+ files: ['**/*.{ts,tsx}'],
57
+ extends: [
58
+ // Other configs...
59
+ // Enable lint rules for React
60
+ reactX.configs['recommended-typescript'],
61
+ // Enable lint rules for React DOM
62
+ reactDom.configs.recommended,
63
+ ],
64
+ languageOptions: {
65
+ parserOptions: {
66
+ project: ['./tsconfig.node.json', './tsconfig.app.json'],
67
+ tsconfigRootDir: import.meta.dirname,
68
+ },
69
+ // other options...
70
+ },
71
+ },
72
+ ])
73
+ ```
frontend/eslint.config.js ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import js from '@eslint/js'
2
+ import globals from 'globals'
3
+ import reactHooks from 'eslint-plugin-react-hooks'
4
+ import reactRefresh from 'eslint-plugin-react-refresh'
5
+ import tseslint from 'typescript-eslint'
6
+ import { defineConfig, globalIgnores } from 'eslint/config'
7
+
8
+ export default defineConfig([
9
+ globalIgnores(['dist']),
10
+ {
11
+ files: ['**/*.{ts,tsx}'],
12
+ extends: [
13
+ js.configs.recommended,
14
+ tseslint.configs.recommended,
15
+ reactHooks.configs.flat.recommended,
16
+ reactRefresh.configs.vite,
17
+ ],
18
+ languageOptions: {
19
+ ecmaVersion: 2020,
20
+ globals: globals.browser,
21
+ },
22
+ },
23
+ ])
frontend/index.html ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>frontend</title>
8
+ </head>
9
+ <body>
10
+ <div id="root"></div>
11
+ <script type="module" src="/src/main.tsx"></script>
12
+ </body>
13
+ </html>
frontend/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
frontend/package.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "frontend",
3
+ "private": true,
4
+ "version": "0.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite",
8
+ "build": "tsc -b && vite build",
9
+ "lint": "eslint .",
10
+ "preview": "vite preview",
11
+ "test": "vitest",
12
+ "test:ui": "vitest --ui",
13
+ "test:coverage": "vitest --coverage"
14
+ },
15
+ "dependencies": {
16
+ "@xmldom/xmldom": "^0.8.11",
17
+ "react": "^19.2.0",
18
+ "react-dom": "^19.2.0",
19
+ "tone": "^15.1.3",
20
+ "vexflow": "^4.2.4",
21
+ "zustand": "^5.0.3"
22
+ },
23
+ "devDependencies": {
24
+ "@eslint/js": "^9.39.1",
25
+ "@testing-library/jest-dom": "^6.1.5",
26
+ "@testing-library/react": "^14.1.2",
27
+ "@testing-library/user-event": "^14.5.1",
28
+ "@types/node": "^24.10.1",
29
+ "@types/react": "^19.2.5",
30
+ "@types/react-dom": "^19.2.3",
31
+ "@vitejs/plugin-react": "^5.1.1",
32
+ "@vitest/ui": "^1.1.0",
33
+ "eslint": "^9.39.1",
34
+ "eslint-plugin-react-hooks": "^7.0.1",
35
+ "eslint-plugin-react-refresh": "^0.4.24",
36
+ "globals": "^16.5.0",
37
+ "jsdom": "^23.0.1",
38
+ "typescript": "~5.9.3",
39
+ "typescript-eslint": "^8.46.4",
40
+ "vite": "^7.2.4",
41
+ "vitest": "^1.1.0",
42
+ "@vitest/coverage-v8": "^1.1.0"
43
+ }
44
+ }
frontend/public/vite.svg ADDED
frontend/scripts/debug-parser.cjs ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Quick script to test the parser
2
+ const fs = require('fs');
3
+ const { DOMParser } = require('@xmldom/xmldom');
4
+
5
+ const xml = fs.readFileSync('../storage/outputs/497306b6-8e09-41c2-b8c7-0792dbd22022.musicxml', 'utf-8');
6
+ const parser = new DOMParser();
7
+ const doc = parser.parseFromString(xml, 'text/xml');
8
+
9
+ // Check what we're getting
10
+ const beats = doc.getElementsByTagName('beats')[0]?.textContent;
11
+ const beatType = doc.getElementsByTagName('beat-type')[0]?.textContent;
12
+ console.log('Time signature:', beats + '/' + beatType);
13
+ console.log('Divisions:', doc.getElementsByTagName('divisions')[0]?.textContent);
14
+ console.log('Key (fifths):', doc.getElementsByTagName('fifths')[0]?.textContent);
15
+
16
+ const soundEl = doc.getElementsByTagName('sound')[0];
17
+ console.log('Tempo:', soundEl?.getAttribute('tempo'));
18
+
19
+ const measures = doc.getElementsByTagName('measure');
20
+ console.log('\nTotal measures:', measures.length);
21
+ console.log('First 10 measures:');
22
+
23
+ const divisions = parseInt(doc.getElementsByTagName('divisions')[0]?.textContent || '10080');
24
+
25
+ for (let i = 0; i < Math.min(10, measures.length); i++) {
26
+ const m = measures[i];
27
+ const notes = m.getElementsByTagName('note');
28
+ const pitchedNotes = [];
29
+ let totalDuration = 0;
30
+
31
+ for (let n = 0; n < notes.length; n++) {
32
+ const note = notes[n];
33
+ const isRest = note.getElementsByTagName('rest').length > 0;
34
+ const duration = parseInt(note.getElementsByTagName('duration')[0]?.textContent || '0');
35
+ totalDuration += duration;
36
+
37
+ if (!isRest) {
38
+ pitchedNotes.push(note);
39
+ }
40
+ }
41
+
42
+ const expectedDuration = divisions * 4; // 4 beats in 4/4
43
+ const durationMatch = totalDuration === expectedDuration ? '✓' : `✗ (expected ${expectedDuration}, got ${totalDuration})`;
44
+
45
+ console.log(` Measure ${m.getAttribute('number')}: ${notes.length} total notes, ${pitchedNotes.length} pitched notes, duration ${durationMatch}`);
46
+
47
+ // Show first 3 pitched notes
48
+ for (let j = 0; j < Math.min(3, pitchedNotes.length); j++) {
49
+ const note = pitchedNotes[j];
50
+ const pitch = note.getElementsByTagName('step')[0]?.textContent;
51
+ const octave = note.getElementsByTagName('octave')[0]?.textContent;
52
+ const duration = note.getElementsByTagName('duration')[0]?.textContent;
53
+ const type = note.getElementsByTagName('type')[0]?.textContent;
54
+ const alter = note.getElementsByTagName('alter')[0]?.textContent;
55
+ const accidental = alter === '1' ? '#' : alter === '-1' ? 'b' : '';
56
+ console.log(` Note ${j+1}: ${pitch}${accidental}${octave}, duration=${duration}, type=${type}`);
57
+ }
58
+ }
frontend/scripts/test-chord-handling.cjs ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ const fs = require('fs');
2
+ const { DOMParser } = require('@xmldom/xmldom');
3
+
4
+ const xml = fs.readFileSync('../storage/outputs/497306b6-8e09-41c2-b8c7-0792dbd22022.musicxml', 'utf-8');
5
+ const parser = new DOMParser();
6
+ const doc = parser.parseFromString(xml, 'text/xml');
7
+
8
+ // Look at measure 4 which has chords
9
+ const measures = doc.getElementsByTagName('measure');
10
+ const measure4 = measures[3]; // 0-indexed
11
+
12
+ console.log('=== MEASURE 4 ANALYSIS ===');
13
+ const notes = measure4.getElementsByTagName('note');
14
+ console.log('Total note elements:', notes.length);
15
+
16
+ let noteCount = 0;
17
+ let totalDuration = 0;
18
+
19
+ for (let i = 0; i < notes.length; i++) {
20
+ const note = notes[i];
21
+ const isChord = note.getElementsByTagName('chord').length > 0;
22
+ const isRest = note.getElementsByTagName('rest').length > 0;
23
+ const duration = parseInt(note.getElementsByTagName('duration')[0]?.textContent || '0');
24
+
25
+ // Chord notes share duration with previous note
26
+ if (!isChord) {
27
+ totalDuration += duration;
28
+ }
29
+
30
+ if (!isRest) {
31
+ const pitch = note.getElementsByTagName('step')[0]?.textContent;
32
+ const octave = note.getElementsByTagName('octave')[0]?.textContent;
33
+ const type = note.getElementsByTagName('type')[0]?.textContent;
34
+ noteCount++;
35
+ console.log('Note', noteCount, ':', pitch + octave, '(' + type + '), duration=' + duration, ', chord=' + isChord);
36
+ }
37
+ }
38
+
39
+ const divisions = 10080;
40
+ const expected = divisions * 4;
41
+ console.log('\nTotal duration:', totalDuration, '(expected', expected + ')');
42
+ console.log('Duration ratio:', (totalDuration / expected).toFixed(2) + 'x');
frontend/src/App.css ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ * {
2
+ box-sizing: border-box;
3
+ }
4
+
5
+ body {
6
+ margin: 0;
7
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen',
8
+ 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue',
9
+ sans-serif;
10
+ -webkit-font-smoothing: antialiased;
11
+ -moz-osx-font-smoothing: grayscale;
12
+ }
13
+
14
+ .app {
15
+ min-height: 100vh;
16
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
17
+ padding: 2rem;
18
+ }
19
+
20
+ .back-button {
21
+ margin-bottom: 1rem;
22
+ background: white;
23
+ color: #667eea;
24
+ border: 2px solid #667eea;
25
+ }
26
+
27
+ .back-button:hover {
28
+ background: #667eea;
29
+ color: white;
30
+ }
frontend/src/App.tsx ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Main application component.
3
+ */
4
+ import { useState } from 'react';
5
+ import { JobSubmission } from './components/JobSubmission';
6
+ import { ScoreEditor } from './components/ScoreEditor';
7
+ import './App.css';
8
+
9
+ function App() {
10
+ const [currentJobId, setCurrentJobId] = useState<string | null>(null);
11
+
12
+ const handleJobComplete = (jobId: string) => {
13
+ setCurrentJobId(jobId);
14
+ };
15
+
16
+ const handleReset = () => {
17
+ setCurrentJobId(null);
18
+ };
19
+
20
+ return (
21
+ <div className="app">
22
+ {!currentJobId ? (
23
+ <JobSubmission onComplete={handleJobComplete} />
24
+ ) : (
25
+ <div>
26
+ <button className="back-button" onClick={handleReset}>
27
+ ← New Transcription
28
+ </button>
29
+ <ScoreEditor jobId={currentJobId} />
30
+ </div>
31
+ )}
32
+ </div>
33
+ );
34
+ }
35
+
36
+ export default App;
frontend/src/api/client.ts ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * API client for Rescored backend.
3
+ */
4
+
5
+ const API_BASE_URL = import.meta.env.VITE_API_URL || 'http://localhost:8000';
6
+ const WS_BASE_URL = API_BASE_URL.replace('http', 'ws');
7
+
8
+ export interface TranscribeRequest {
9
+ youtube_url: string;
10
+ options?: {
11
+ instruments: string[];
12
+ };
13
+ }
14
+
15
+ export interface TranscribeResponse {
16
+ job_id: string;
17
+ status: string;
18
+ created_at: string;
19
+ estimated_duration_seconds: number;
20
+ websocket_url: string;
21
+ }
22
+
23
+ export interface JobStatus {
24
+ job_id: string;
25
+ status: 'queued' | 'processing' | 'completed' | 'failed';
26
+ progress: number;
27
+ current_stage: string | null;
28
+ status_message: string | null;
29
+ created_at: string;
30
+ started_at: string | null;
31
+ completed_at: string | null;
32
+ failed_at: string | null;
33
+ error: { message: string; retryable: boolean } | null;
34
+ result_url: string | null;
35
+ }
36
+
37
+ export interface ProgressUpdate {
38
+ type: 'progress' | 'completed' | 'error' | 'heartbeat';
39
+ job_id: string;
40
+ progress?: number;
41
+ stage?: string;
42
+ message?: string;
43
+ result_url?: string;
44
+ error?: { message: string; retryable: boolean };
45
+ timestamp: string;
46
+ }
47
+
48
+ export class RescoredAPI {
49
+ private baseURL = API_BASE_URL;
50
+ private wsBaseURL = WS_BASE_URL;
51
+
52
+ async submitJob(youtubeURL: string, options?: { instruments?: string[] }): Promise<TranscribeResponse> {
53
+ const response = await fetch(`${this.baseURL}/api/v1/transcribe`, {
54
+ method: 'POST',
55
+ headers: {
56
+ 'Content-Type': 'application/json',
57
+ },
58
+ body: JSON.stringify({
59
+ youtube_url: youtubeURL,
60
+ options: options ?? { instruments: ['piano'] },
61
+ }),
62
+ });
63
+
64
+ if (!response.ok) {
65
+ const error = await response.json();
66
+ throw new Error(error.detail || 'Failed to submit job');
67
+ }
68
+
69
+ return response.json();
70
+ }
71
+
72
+ async getJobStatus(jobId: string): Promise<JobStatus> {
73
+ const response = await fetch(`${this.baseURL}/api/v1/jobs/${jobId}`);
74
+
75
+ if (!response.ok) {
76
+ throw new Error('Failed to fetch job status');
77
+ }
78
+
79
+ return response.json();
80
+ }
81
+
82
+ async getScore(jobId: string): Promise<string> {
83
+ const response = await fetch(`${this.baseURL}/api/v1/scores/${jobId}`);
84
+
85
+ if (!response.ok) {
86
+ throw new Error('Failed to fetch score');
87
+ }
88
+
89
+ return response.text();
90
+ }
91
+
92
+ connectWebSocket(
93
+ jobId: string,
94
+ onMessage: (update: ProgressUpdate) => void,
95
+ onError?: (error: Event) => void,
96
+ onClose?: () => void
97
+ ): WebSocket {
98
+ const ws = new WebSocket(`${this.wsBaseURL}/api/v1/jobs/${jobId}/stream`);
99
+
100
+ ws.onmessage = (event) => {
101
+ const update: ProgressUpdate = JSON.parse(event.data);
102
+ onMessage(update);
103
+
104
+ // Send pong for heartbeat
105
+ if (update.type === 'heartbeat') {
106
+ ws.send(JSON.stringify({ type: 'pong', timestamp: new Date().toISOString() }));
107
+ }
108
+ };
109
+
110
+ if (onError) {
111
+ ws.onerror = onError;
112
+ }
113
+
114
+ if (onClose) {
115
+ ws.onclose = onClose;
116
+ }
117
+
118
+ return ws;
119
+ }
120
+
121
+ getScoreURL(jobId: string): string {
122
+ return `${this.baseURL}/api/v1/scores/${jobId}`;
123
+ }
124
+ }
125
+
126
+ export const api = new RescoredAPI();
127
+
128
+ // Compatibility function wrappers for tests
129
+ export async function submitTranscription(
130
+ youtubeURL: string,
131
+ options?: { instruments?: string[] }
132
+ ) {
133
+ // Delegate to class method; include options if provided
134
+ return api.submitJob(youtubeURL, options);
135
+ }
136
+
137
+ export async function getJobStatus(jobId: string) {
138
+ return api.getJobStatus(jobId);
139
+ }
140
+
141
+ export async function downloadScore(jobId: string) {
142
+ return api.getScore(jobId);
143
+ }
frontend/src/assets/react.svg ADDED
frontend/src/components/JobSubmission.css ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .job-submission {
2
+ max-width: 600px;
3
+ margin: 0 auto;
4
+ padding: 2rem;
5
+ }
6
+
7
+ .job-submission h1 {
8
+ font-size: 2rem;
9
+ margin-bottom: 0.5rem;
10
+ }
11
+
12
+ .form-group {
13
+ margin-bottom: 1rem;
14
+ }
15
+
16
+ .form-group label {
17
+ display: block;
18
+ margin-bottom: 0.5rem;
19
+ font-weight: bold;
20
+ }
21
+
22
+ .form-group input {
23
+ width: 100%;
24
+ padding: 0.5rem;
25
+ border: 1px solid #ccc;
26
+ border-radius: 4px;
27
+ font-size: 1rem;
28
+ }
29
+
30
+ button {
31
+ padding: 0.75rem 1.5rem;
32
+ background-color: #007bff;
33
+ color: white;
34
+ border: none;
35
+ border-radius: 4px;
36
+ font-size: 1rem;
37
+ cursor: pointer;
38
+ }
39
+
40
+ button:hover {
41
+ background-color: #0056b3;
42
+ }
43
+
44
+ .progress-container {
45
+ text-align: center;
46
+ }
47
+
48
+ .progress-bar {
49
+ width: 100%;
50
+ height: 30px;
51
+ background-color: #f0f0f0;
52
+ border-radius: 15px;
53
+ overflow: hidden;
54
+ margin: 1rem 0;
55
+ }
56
+
57
+ .progress-fill {
58
+ height: 100%;
59
+ background-color: #28a745;
60
+ transition: width 0.3s ease;
61
+ }
62
+
63
+ .progress-text {
64
+ color: #666;
65
+ font-size: 0.9rem;
66
+ }
67
+
68
+ .success-message,
69
+ .error-message {
70
+ text-align: center;
71
+ padding: 2rem;
72
+ border-radius: 8px;
73
+ }
74
+
75
+ .success-message {
76
+ background-color: #d4edda;
77
+ color: #155724;
78
+ }
79
+
80
+ .error-message {
81
+ background-color: #f8d7da;
82
+ color: #721c24;
83
+ }