Ewan committed on
Commit
f0a176a
·
1 Parent(s): 751132e

Initial commit - Mr Octopus piano tutorial app

Browse files
.dockerignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ app/node_modules
2
+ app/dist
3
+ transcriber/venv
4
+ transcriber/uploads
5
+ transcriber/analyze_*.py
6
+ transcriber/compare.py
7
+ transcriber/diagnose_*.py
8
+ transcriber/simulate_*.py
9
+ __pycache__
10
+ *.pyc
11
+ .git
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ app/node_modules
2
+ app/dist
3
+ transcriber/venv
4
+ transcriber/uploads
5
+ transcriber/analyze_*.py
6
+ transcriber/compare.py
7
+ transcriber/diagnose_*.py
8
+ transcriber/simulate_*.py
9
+ __pycache__
10
+ *.pyc
Dockerfile ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM node:20-slim AS frontend
2
+
3
+ WORKDIR /build
4
+ COPY app/package.json app/package-lock.json* ./
5
+ RUN npm ci
6
+ COPY app/ .
7
+ RUN npm run build
8
+
9
+ # --- Python backend ---
10
+ FROM python:3.12-slim
11
+
12
+ # System deps: ffmpeg for audio processing, yt-dlp needs it too
13
+ RUN apt-get update && apt-get install -y --no-install-recommends \
14
+ ffmpeg \
15
+ && rm -rf /var/lib/apt/lists/*
16
+
17
+ WORKDIR /app
18
+
19
+ # Install Python dependencies
20
+ # basic-pitch pulls in tensorflow on Linux, but we only use ONNX runtime.
21
+ # Install it with --no-deps and manually specify what we need.
22
+ COPY api/requirements.txt /app/api/requirements.txt
23
+ RUN pip install --no-cache-dir \
24
+ fastapi uvicorn[standard] python-multipart \
25
+ onnxruntime pretty_midi librosa scipy numpy "setuptools<81" \
26
+ yt-dlp mir-eval resampy scikit-learn && \
27
+ pip install --no-cache-dir --no-deps basic-pitch
28
+
29
+ # Copy application code
30
+ COPY transcriber/ /app/transcriber/
31
+ COPY api/ /app/api/
32
+
33
+ # Copy built frontend
34
+ COPY --from=frontend /build/dist /app/app/dist
35
+
36
+ ENV PORT=7860
37
+ EXPOSE 7860
38
+
39
+ CMD ["uvicorn", "api.server:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,8 @@
1
  ---
2
- title: Mroctopus
3
- emoji: 📈
4
  colorFrom: purple
5
- colorTo: green
6
  sdk: docker
7
- pinned: false
8
  ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Mr Octopus
3
+ emoji: 🐙
4
  colorFrom: purple
5
+ colorTo: blue
6
  sdk: docker
7
+ app_port: 7860
8
  ---
 
 
api/requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.115
2
+ uvicorn[standard]>=0.34
3
+ python-multipart>=0.0.18
4
+
5
+ # Transcription pipeline
6
+ basic-pitch>=0.3
7
+ onnxruntime>=1.17
8
+ pretty_midi>=0.2.10
9
+ librosa>=0.10
10
+ scipy>=1.12
11
+ numpy>=1.24
12
+ setuptools<81
13
+
14
+ # URL downloading
15
+ yt-dlp>=2024.1
api/server.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI backend for the piano tutorial transcription pipeline."""
2
+
3
+ import json
4
+ import sys
5
+ import tempfile
6
+ import uuid
7
+ from pathlib import Path
8
+
9
+ from fastapi import FastAPI, UploadFile, File, HTTPException
10
+ from fastapi.responses import FileResponse, JSONResponse
11
+ from fastapi.staticfiles import StaticFiles
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+
14
+ # Add transcriber to path
15
+ TRANSCRIBER_DIR = Path(__file__).resolve().parent.parent / "transcriber"
16
+ sys.path.insert(0, str(TRANSCRIBER_DIR))
17
+
18
+ app = FastAPI(title="Piano Tutorial API")
19
+
20
+ app.add_middleware(
21
+ CORSMiddleware,
22
+ allow_origins=["*"],
23
+ allow_methods=["*"],
24
+ allow_headers=["*"],
25
+ )
26
+
27
+ # Directory for temporary processing files
28
+ WORK_DIR = Path(tempfile.gettempdir()) / "piano-tutorial"
29
+ WORK_DIR.mkdir(exist_ok=True)
30
+
31
+
32
@app.post("/api/transcribe")
async def transcribe(
    file: UploadFile = File(...),
):
    """Transcribe an uploaded audio file to MIDI.

    Accepts a file upload (MP3, M4A, WAV, OGG, FLAC).
    Returns JSON with a job_id, MIDI download URL, and chord data.
    Raises HTTP 500 when the transcription or optimization step fails.
    """
    job_id = str(uuid.uuid4())[:8]
    job_dir = WORK_DIR / job_id
    job_dir.mkdir(exist_ok=True)

    try:
        # file.filename may be None or empty depending on the client; fall
        # back to ".m4a" rather than crashing on Path(None).
        suffix = Path(file.filename or "").suffix or ".m4a"
        audio_path = job_dir / f"upload{suffix}"
        content = await file.read()
        audio_path.write_bytes(content)

        # Run transcription. Imported lazily so the API process can start
        # even while the heavy transcriber dependencies are loading.
        from transcribe import transcribe as run_transcribe
        raw_midi_path = job_dir / "transcription_raw.mid"
        run_transcribe(str(audio_path), str(raw_midi_path))

        # Run optimization (also runs chord detection as Step 10)
        from optimize import optimize
        optimized_path = job_dir / "transcription.mid"
        optimize(str(audio_path), str(raw_midi_path), str(optimized_path))

        if not optimized_path.exists():
            raise HTTPException(500, "Optimization failed to produce output")

        # Load chord data if available (optimize writes it as a sidecar file)
        chords_path = job_dir / "transcription_chords.json"
        chord_data = None
        if chords_path.exists():
            with open(chords_path) as f:
                chord_data = json.load(f)

        return JSONResponse({
            "job_id": job_id,
            "midi_url": f"/api/jobs/{job_id}/midi",
            "chords_url": f"/api/jobs/{job_id}/chords",
            "chords": chord_data,
        })

    except HTTPException:
        raise
    except Exception as e:
        # Remove partial outputs so failed jobs don't accumulate in tmp,
        # then surface the failure with the original cause chained.
        import shutil
        shutil.rmtree(job_dir, ignore_errors=True)
        raise HTTPException(500, f"Transcription failed: {e}") from e
82
+
83
+
84
@app.get("/api/jobs/{job_id}/midi")
async def get_midi(job_id: str):
    """Download the optimized MIDI file for a completed job.

    Returns HTTP 404 when the job id is malformed or no output exists.
    """
    # job_ids are short hex strings (uuid4 prefix). Reject anything else so
    # job_id can never traverse out of WORK_DIR (e.g. job_id == "..").
    if not job_id.isalnum():
        raise HTTPException(404, f"No MIDI file found for job {job_id}")
    midi_path = WORK_DIR / job_id / "transcription.mid"
    if not midi_path.exists():
        raise HTTPException(404, f"No MIDI file found for job {job_id}")
    return FileResponse(
        midi_path,
        media_type="audio/midi",
        filename="transcription.mid",
    )
95
+
96
+
97
@app.get("/api/jobs/{job_id}/chords")
async def get_chords(job_id: str):
    """Get the detected chord data for a completed job.

    Returns HTTP 404 when the job id is malformed or no chord data exists.
    """
    # Same traversal guard as the MIDI endpoint: valid job ids are
    # alphanumeric uuid4 prefixes, so anything else (e.g. "..") is rejected.
    if not job_id.isalnum():
        raise HTTPException(404, f"No chord data found for job {job_id}")
    chords_path = WORK_DIR / job_id / "transcription_chords.json"
    if not chords_path.exists():
        raise HTTPException(404, f"No chord data found for job {job_id}")
    with open(chords_path) as f:
        chord_data = json.load(f)
    return JSONResponse(chord_data)
106
+
107
+
108
@app.get("/api/health")
async def health():
    """Liveness probe: report that the API process is up."""
    payload = {"status": "ok"}
    return payload
111
+
112
+
113
# Serve the built React frontend (in production)
DIST_DIR = Path(__file__).resolve().parent.parent / "app" / "dist"
if DIST_DIR.exists():
    # Serve static assets
    app.mount("/assets", StaticFiles(directory=str(DIST_DIR / "assets")), name="assets")

    # Serve MIDI files if they exist
    midi_dir = DIST_DIR / "midi"
    if midi_dir.exists():
        app.mount("/midi", StaticFiles(directory=str(midi_dir)), name="midi")

    # Catch-all: serve index.html for SPA routing
    @app.get("/{path:path}")
    async def serve_spa(path: str):
        """Serve a build file if it exists, else index.html (SPA fallback)."""
        # "{path:path}" may contain "/" and "..", so resolve the candidate
        # and confine it to DIST_DIR to block path traversal
        # (e.g. GET /../../etc/passwd). DIST_DIR is already resolved above.
        candidate = (DIST_DIR / path).resolve()
        if candidate.is_file() and candidate.is_relative_to(DIST_DIR):
            return FileResponse(candidate)
        return FileResponse(DIST_DIR / "index.html")
app/.gitignore ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Logs
2
+ logs
3
+ *.log
4
+ npm-debug.log*
5
+ yarn-debug.log*
6
+ yarn-error.log*
7
+ pnpm-debug.log*
8
+ lerna-debug.log*
9
+
10
+ node_modules
11
+ dist
12
+ dist-ssr
13
+ *.local
14
+
15
+ # Editor directories and files
16
+ .vscode/*
17
+ !.vscode/extensions.json
18
+ .idea
19
+ .DS_Store
20
+ *.suo
21
+ *.ntvs*
22
+ *.njsproj
23
+ *.sln
24
+ *.sw?
app/README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # React + Vite
2
+
3
+ This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
4
+
5
+ Currently, two official plugins are available:
6
+
7
+ - [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) (or [oxc](https://oxc.rs) when used in [rolldown-vite](https://vite.dev/guide/rolldown)) for Fast Refresh
8
+ - [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
9
+
10
+ ## React Compiler
11
+
12
+ The React Compiler is not enabled on this template because of its impact on dev & build performances. To add it, see [this documentation](https://react.dev/learn/react-compiler/installation).
13
+
14
+ ## Expanding the ESLint configuration
15
+
16
+ If you are developing a production application, we recommend using TypeScript with type-aware lint rules enabled. Check out the [TS template](https://github.com/vitejs/vite/tree/main/packages/create-vite/template-react-ts) for information on how to integrate TypeScript and [`typescript-eslint`](https://typescript-eslint.io) in your project.
app/eslint.config.js ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import js from '@eslint/js'
2
+ import globals from 'globals'
3
+ import reactHooks from 'eslint-plugin-react-hooks'
4
+ import reactRefresh from 'eslint-plugin-react-refresh'
5
+ import { defineConfig, globalIgnores } from 'eslint/config'
6
+
7
+ export default defineConfig([
8
+ globalIgnores(['dist']),
9
+ {
10
+ files: ['**/*.{js,jsx}'],
11
+ extends: [
12
+ js.configs.recommended,
13
+ reactHooks.configs.flat.recommended,
14
+ reactRefresh.configs.vite,
15
+ ],
16
+ languageOptions: {
17
+ ecmaVersion: 2020,
18
+ globals: globals.browser,
19
+ parserOptions: {
20
+ ecmaVersion: 'latest',
21
+ ecmaFeatures: { jsx: true },
22
+ sourceType: 'module',
23
+ },
24
+ },
25
+ rules: {
26
+ 'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
27
+ },
28
+ },
29
+ ])
app/index.html ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
+ <meta name="theme-color" content="#07070e" />
7
+ <link rel="icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 48 48'%3E%3Cellipse cx='24' cy='16' rx='14' ry='12' fill='%238b5cf6'/%3E%3Ccircle cx='19' cy='14' r='2.5' fill='%2307070e'/%3E%3Ccircle cx='29' cy='14' r='2.5' fill='%2307070e'/%3E%3Ccircle cx='20' cy='13.3' r='0.9' fill='white'/%3E%3Ccircle cx='30' cy='13.3' r='0.9' fill='white'/%3E%3Crect x='7' y='26' width='3' height='18' rx='1' fill='%23f0eef5'/%3E%3Crect x='12' y='26' width='3' height='18' rx='1' fill='%23f0eef5'/%3E%3Crect x='16.5' y='26' width='2.5' height='13' rx='0.8' fill='%231e1b4b'/%3E%3Crect x='21' y='26' width='3' height='18' rx='1' fill='%23f0eef5'/%3E%3Crect x='26' y='26' width='3' height='18' rx='1' fill='%23f0eef5'/%3E%3Crect x='30.5' y='26' width='2.5' height='13' rx='0.8' fill='%231e1b4b'/%3E%3Crect x='35' y='26' width='3' height='18' rx='1' fill='%23f0eef5'/%3E%3Crect x='40' y='26' width='3' height='18' rx='1' fill='%23f0eef5'/%3E%3C/svg%3E" />
8
+ <link rel="preconnect" href="https://fonts.googleapis.com" />
9
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
10
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet" />
11
+ <title>Mr. Octopus</title>
12
+ </head>
13
+ <body>
14
+ <div id="root"></div>
15
+ <script type="module" src="/src/main.jsx"></script>
16
+ </body>
17
+ </html>
app/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
app/package.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "app",
3
+ "private": true,
4
+ "version": "0.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite",
8
+ "build": "vite build",
9
+ "lint": "eslint .",
10
+ "preview": "vite preview"
11
+ },
12
+ "dependencies": {
13
+ "@tonejs/midi": "^2.0.28",
14
+ "react": "^19.2.0",
15
+ "react-dom": "^19.2.0",
16
+ "tone": "^15.1.22"
17
+ },
18
+ "devDependencies": {
19
+ "@eslint/js": "^9.39.1",
20
+ "@types/react": "^19.2.7",
21
+ "@types/react-dom": "^19.2.3",
22
+ "@vitejs/plugin-react": "^5.1.1",
23
+ "eslint": "^9.39.1",
24
+ "eslint-plugin-react-hooks": "^7.0.1",
25
+ "eslint-plugin-react-refresh": "^0.4.24",
26
+ "globals": "^16.5.0",
27
+ "vite": "^7.3.1"
28
+ }
29
+ }
app/public/midi/transcription.mid ADDED
Binary file (8.75 kB). View file
 
app/public/midi/transcription_chords.json ADDED
The diff for this file is too large to render. See raw diff
 
app/public/midi/transcription_raw.mid ADDED
Binary file (13.3 kB). View file
 
app/public/midi/transcription_spectral.json ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "spectral_f1": 0.4765,
3
+ "spectral_precision": 0.7711,
4
+ "spectral_recall": 0.3448,
5
+ "spectral_similarity": 0.4061,
6
+ "per_octave": [
7
+ {
8
+ "octave": 0,
9
+ "range": "A0-A1",
10
+ "audio_energy": 154196,
11
+ "missing_energy": 152357,
12
+ "extra_energy": 156,
13
+ "matched_energy": 1839,
14
+ "coverage": 0.0119
15
+ },
16
+ {
17
+ "octave": 1,
18
+ "range": "A1-A2",
19
+ "audio_energy": 194926,
20
+ "missing_energy": 179822,
21
+ "extra_energy": 310,
22
+ "matched_energy": 15104,
23
+ "coverage": 0.0775
24
+ },
25
+ {
26
+ "octave": 2,
27
+ "range": "A2-A3",
28
+ "audio_energy": 199170,
29
+ "missing_energy": 156470,
30
+ "extra_energy": 1031,
31
+ "matched_energy": 42700,
32
+ "coverage": 0.2144
33
+ },
34
+ {
35
+ "octave": 3,
36
+ "range": "A3-A4",
37
+ "audio_energy": 203356,
38
+ "missing_energy": 124913,
39
+ "extra_energy": 929,
40
+ "matched_energy": 78443,
41
+ "coverage": 0.3857
42
+ },
43
+ {
44
+ "octave": 4,
45
+ "range": "A4-A5",
46
+ "audio_energy": 184288,
47
+ "missing_energy": 77943,
48
+ "extra_energy": 11447,
49
+ "matched_energy": 106345,
50
+ "coverage": 0.5771
51
+ },
52
+ {
53
+ "octave": 5,
54
+ "range": "A5-A6",
55
+ "audio_energy": 178926,
56
+ "missing_energy": 57570,
57
+ "extra_energy": 15304,
58
+ "matched_energy": 121356,
59
+ "coverage": 0.6782
60
+ },
61
+ {
62
+ "octave": 6,
63
+ "range": "A6-A7",
64
+ "audio_energy": 61937,
65
+ "missing_energy": 21815,
66
+ "extra_energy": 73759,
67
+ "matched_energy": 40122,
68
+ "coverage": 0.6478
69
+ },
70
+ {
71
+ "octave": 7,
72
+ "range": "A7-A8",
73
+ "audio_energy": 6945,
74
+ "missing_energy": 4709,
75
+ "extra_energy": 18200,
76
+ "matched_energy": 2236,
77
+ "coverage": 0.322
78
+ }
79
+ ],
80
+ "per_time": [
81
+ {
82
+ "time_start": 0.0,
83
+ "time_end": 6.69,
84
+ "missing": 37299,
85
+ "extra": 8429,
86
+ "matched": 15904,
87
+ "fidelity": 0.299
88
+ },
89
+ {
90
+ "time_start": 6.69,
91
+ "time_end": 13.37,
92
+ "missing": 39869,
93
+ "extra": 6624,
94
+ "matched": 20740,
95
+ "fidelity": 0.342
96
+ },
97
+ {
98
+ "time_start": 13.37,
99
+ "time_end": 20.06,
100
+ "missing": 41062,
101
+ "extra": 6050,
102
+ "matched": 20521,
103
+ "fidelity": 0.333
104
+ },
105
+ {
106
+ "time_start": 20.06,
107
+ "time_end": 26.75,
108
+ "missing": 39620,
109
+ "extra": 7666,
110
+ "matched": 21162,
111
+ "fidelity": 0.348
112
+ },
113
+ {
114
+ "time_start": 26.75,
115
+ "time_end": 33.44,
116
+ "missing": 37182,
117
+ "extra": 6304,
118
+ "matched": 24549,
119
+ "fidelity": 0.398
120
+ },
121
+ {
122
+ "time_start": 33.44,
123
+ "time_end": 40.12,
124
+ "missing": 38849,
125
+ "extra": 5103,
126
+ "matched": 23978,
127
+ "fidelity": 0.382
128
+ },
129
+ {
130
+ "time_start": 40.12,
131
+ "time_end": 46.81,
132
+ "missing": 38366,
133
+ "extra": 5501,
134
+ "matched": 23406,
135
+ "fidelity": 0.379
136
+ },
137
+ {
138
+ "time_start": 46.81,
139
+ "time_end": 53.5,
140
+ "missing": 40410,
141
+ "extra": 4906,
142
+ "matched": 21436,
143
+ "fidelity": 0.347
144
+ },
145
+ {
146
+ "time_start": 53.5,
147
+ "time_end": 60.19,
148
+ "missing": 36247,
149
+ "extra": 6550,
150
+ "matched": 22495,
151
+ "fidelity": 0.383
152
+ },
153
+ {
154
+ "time_start": 60.19,
155
+ "time_end": 66.87,
156
+ "missing": 38617,
157
+ "extra": 6128,
158
+ "matched": 22503,
159
+ "fidelity": 0.368
160
+ },
161
+ {
162
+ "time_start": 66.87,
163
+ "time_end": 73.56,
164
+ "missing": 39224,
165
+ "extra": 5250,
166
+ "matched": 22932,
167
+ "fidelity": 0.369
168
+ },
169
+ {
170
+ "time_start": 73.56,
171
+ "time_end": 80.25,
172
+ "missing": 39544,
173
+ "extra": 5728,
174
+ "matched": 22623,
175
+ "fidelity": 0.364
176
+ },
177
+ {
178
+ "time_start": 80.25,
179
+ "time_end": 86.94,
180
+ "missing": 39767,
181
+ "extra": 5753,
182
+ "matched": 23063,
183
+ "fidelity": 0.367
184
+ },
185
+ {
186
+ "time_start": 86.94,
187
+ "time_end": 93.62,
188
+ "missing": 37375,
189
+ "extra": 6104,
190
+ "matched": 24441,
191
+ "fidelity": 0.395
192
+ },
193
+ {
194
+ "time_start": 93.62,
195
+ "time_end": 100.31,
196
+ "missing": 37995,
197
+ "extra": 6150,
198
+ "matched": 21667,
199
+ "fidelity": 0.363
200
+ },
201
+ {
202
+ "time_start": 100.31,
203
+ "time_end": 107.0,
204
+ "missing": 39151,
205
+ "extra": 5272,
206
+ "matched": 19814,
207
+ "fidelity": 0.336
208
+ },
209
+ {
210
+ "time_start": 107.0,
211
+ "time_end": 113.68,
212
+ "missing": 35573,
213
+ "extra": 6465,
214
+ "matched": 20390,
215
+ "fidelity": 0.364
216
+ },
217
+ {
218
+ "time_start": 113.68,
219
+ "time_end": 120.37,
220
+ "missing": 35737,
221
+ "extra": 10055,
222
+ "matched": 14143,
223
+ "fidelity": 0.284
224
+ },
225
+ {
226
+ "time_start": 120.37,
227
+ "time_end": 127.06,
228
+ "missing": 36421,
229
+ "extra": 6311,
230
+ "matched": 19330,
231
+ "fidelity": 0.347
232
+ },
233
+ {
234
+ "time_start": 127.06,
235
+ "time_end": 133.75,
236
+ "missing": 47035,
237
+ "extra": 787,
238
+ "matched": 3048,
239
+ "fidelity": 0.061
240
+ }
241
+ ],
242
+ "missing_notes": [
243
+ {
244
+ "pitch": 35,
245
+ "note": "B1",
246
+ "time_start": 96.595,
247
+ "time_end": 100.566,
248
+ "duration": 3.971,
249
+ "energy": 0.512
250
+ },
251
+ {
252
+ "pitch": 35,
253
+ "note": "B1",
254
+ "time_start": 18.112,
255
+ "time_end": 21.246,
256
+ "duration": 3.135,
257
+ "energy": 0.611
258
+ },
259
+ {
260
+ "pitch": 62,
261
+ "note": "D4",
262
+ "time_start": 4.853,
263
+ "time_end": 7.732,
264
+ "duration": 2.879,
265
+ "energy": 0.665
266
+ },
267
+ {
268
+ "pitch": 64,
269
+ "note": "E4",
270
+ "time_start": 34.435,
271
+ "time_end": 36.827,
272
+ "duration": 2.392,
273
+ "energy": 0.71
274
+ },
275
+ {
276
+ "pitch": 43,
277
+ "note": "G2",
278
+ "time_start": 75.024,
279
+ "time_end": 77.996,
280
+ "duration": 2.972,
281
+ "energy": 0.571
282
+ },
283
+ {
284
+ "pitch": 65,
285
+ "note": "F4",
286
+ "time_start": 46.788,
287
+ "time_end": 49.041,
288
+ "duration": 2.252,
289
+ "energy": 0.73
290
+ },
291
+ {
292
+ "pitch": 65,
293
+ "note": "F4",
294
+ "time_start": 100.287,
295
+ "time_end": 102.655,
296
+ "duration": 2.368,
297
+ "energy": 0.694
298
+ },
299
+ {
300
+ "pitch": 55,
301
+ "note": "G3",
302
+ "time_start": 77.857,
303
+ "time_end": 80.573,
304
+ "duration": 2.717,
305
+ "energy": 0.578
306
+ },
307
+ {
308
+ "pitch": 59,
309
+ "note": "B3",
310
+ "time_start": 45.79,
311
+ "time_end": 48.042,
312
+ "duration": 2.252,
313
+ "energy": 0.656
314
+ },
315
+ {
316
+ "pitch": 50,
317
+ "note": "D3",
318
+ "time_start": 51.107,
319
+ "time_end": 53.359,
320
+ "duration": 2.252,
321
+ "energy": 0.649
322
+ },
323
+ {
324
+ "pitch": 45,
325
+ "note": "A2",
326
+ "time_start": 112.315,
327
+ "time_end": 114.962,
328
+ "duration": 2.647,
329
+ "energy": 0.552
330
+ },
331
+ {
332
+ "pitch": 45,
333
+ "note": "A2",
334
+ "time_start": 49.087,
335
+ "time_end": 51.432,
336
+ "duration": 2.345,
337
+ "energy": 0.617
338
+ },
339
+ {
340
+ "pitch": 68,
341
+ "note": "G#4",
342
+ "time_start": 100.566,
343
+ "time_end": 102.725,
344
+ "duration": 2.159,
345
+ "energy": 0.666
346
+ },
347
+ {
348
+ "pitch": 60,
349
+ "note": "C4",
350
+ "time_start": 129.939,
351
+ "time_end": 132.423,
352
+ "duration": 2.485,
353
+ "energy": 0.565
354
+ },
355
+ {
356
+ "pitch": 59,
357
+ "note": "B3",
358
+ "time_start": 117.656,
359
+ "time_end": 120.303,
360
+ "duration": 2.647,
361
+ "energy": 0.521
362
+ },
363
+ {
364
+ "pitch": 62,
365
+ "note": "D4",
366
+ "time_start": 100.008,
367
+ "time_end": 102.028,
368
+ "duration": 2.02,
369
+ "energy": 0.679
370
+ },
371
+ {
372
+ "pitch": 62,
373
+ "note": "D4",
374
+ "time_start": 46.208,
375
+ "time_end": 48.089,
376
+ "duration": 1.881,
377
+ "energy": 0.725
378
+ },
379
+ {
380
+ "pitch": 67,
381
+ "note": "G4",
382
+ "time_start": 128.499,
383
+ "time_end": 130.868,
384
+ "duration": 2.368,
385
+ "energy": 0.57
386
+ },
387
+ {
388
+ "pitch": 50,
389
+ "note": "D3",
390
+ "time_start": 4.714,
391
+ "time_end": 7.221,
392
+ "duration": 2.508,
393
+ "energy": 0.535
394
+ },
395
+ {
396
+ "pitch": 55,
397
+ "note": "G3",
398
+ "time_start": 125.922,
399
+ "time_end": 128.058,
400
+ "duration": 2.136,
401
+ "energy": 0.597
402
+ }
403
+ ],
404
+ "extra_notes": [
405
+ {
406
+ "pitch": 97,
407
+ "note": "C#7",
408
+ "time_start": 6.966,
409
+ "time_end": 9.218,
410
+ "duration": 2.252,
411
+ "energy": 0.651
412
+ },
413
+ {
414
+ "pitch": 103,
415
+ "note": "G7",
416
+ "time_start": 112.315,
417
+ "time_end": 114.567,
418
+ "duration": 2.252,
419
+ "energy": 0.623
420
+ },
421
+ {
422
+ "pitch": 104,
423
+ "note": "G#7",
424
+ "time_start": 6.966,
425
+ "time_end": 9.218,
426
+ "duration": 2.252,
427
+ "energy": 0.576
428
+ },
429
+ {
430
+ "pitch": 108,
431
+ "note": "C8",
432
+ "time_start": 112.315,
433
+ "time_end": 114.567,
434
+ "duration": 2.252,
435
+ "energy": 0.573
436
+ },
437
+ {
438
+ "pitch": 107,
439
+ "note": "B7",
440
+ "time_start": 47.903,
441
+ "time_end": 50.016,
442
+ "duration": 2.113,
443
+ "energy": 0.534
444
+ },
445
+ {
446
+ "pitch": 107,
447
+ "note": "B7",
448
+ "time_start": 101.378,
449
+ "time_end": 103.12,
450
+ "duration": 1.741,
451
+ "energy": 0.611
452
+ },
453
+ {
454
+ "pitch": 105,
455
+ "note": "A7",
456
+ "time_start": 116.541,
457
+ "time_end": 118.399,
458
+ "duration": 1.858,
459
+ "energy": 0.564
460
+ },
461
+ {
462
+ "pitch": 99,
463
+ "note": "D#7",
464
+ "time_start": 115.426,
465
+ "time_end": 118.027,
466
+ "duration": 2.601,
467
+ "energy": 0.391
468
+ },
469
+ {
470
+ "pitch": 103,
471
+ "note": "G7",
472
+ "time_start": 93.112,
473
+ "time_end": 95.62,
474
+ "duration": 2.508,
475
+ "energy": 0.387
476
+ },
477
+ {
478
+ "pitch": 95,
479
+ "note": "B6",
480
+ "time_start": 113.871,
481
+ "time_end": 115.659,
482
+ "duration": 1.788,
483
+ "energy": 0.539
484
+ },
485
+ {
486
+ "pitch": 86,
487
+ "note": "D6",
488
+ "time_start": 0.58,
489
+ "time_end": 2.159,
490
+ "duration": 1.579,
491
+ "energy": 0.596
492
+ },
493
+ {
494
+ "pitch": 91,
495
+ "note": "G6",
496
+ "time_start": 0.58,
497
+ "time_end": 2.159,
498
+ "duration": 1.579,
499
+ "energy": 0.558
500
+ },
501
+ {
502
+ "pitch": 107,
503
+ "note": "B7",
504
+ "time_start": 67.291,
505
+ "time_end": 68.801,
506
+ "duration": 1.509,
507
+ "energy": 0.531
508
+ },
509
+ {
510
+ "pitch": 104,
511
+ "note": "G#7",
512
+ "time_start": 62.02,
513
+ "time_end": 63.321,
514
+ "duration": 1.3,
515
+ "energy": 0.605
516
+ },
517
+ {
518
+ "pitch": 108,
519
+ "note": "C8",
520
+ "time_start": 51.13,
521
+ "time_end": 52.477,
522
+ "duration": 1.347,
523
+ "energy": 0.555
524
+ },
525
+ {
526
+ "pitch": 103,
527
+ "note": "G7",
528
+ "time_start": 40.403,
529
+ "time_end": 41.982,
530
+ "duration": 1.579,
531
+ "energy": 0.452
532
+ },
533
+ {
534
+ "pitch": 103,
535
+ "note": "G7",
536
+ "time_start": 104.536,
537
+ "time_end": 105.837,
538
+ "duration": 1.3,
539
+ "energy": 0.549
540
+ },
541
+ {
542
+ "pitch": 96,
543
+ "note": "C7",
544
+ "time_start": 118.445,
545
+ "time_end": 119.722,
546
+ "duration": 1.277,
547
+ "energy": 0.558
548
+ },
549
+ {
550
+ "pitch": 79,
551
+ "note": "G5",
552
+ "time_start": 0.58,
553
+ "time_end": 1.602,
554
+ "duration": 1.022,
555
+ "energy": 0.685
556
+ },
557
+ {
558
+ "pitch": 84,
559
+ "note": "C6",
560
+ "time_start": 0.58,
561
+ "time_end": 1.602,
562
+ "duration": 1.022,
563
+ "energy": 0.682
564
+ }
565
+ ]
566
+ }
app/src/App.jsx ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useEffect, useRef, useCallback } from 'react';
2
+ import * as Tone from 'tone';
3
+ import PianoRoll from './components/PianoRoll';
4
+ import Controls from './components/Controls';
5
+ import OctopusLogo from './components/OctopusLogo';
6
+ import { useMidi } from './hooks/useMidi';
7
+ import { usePlayback } from './hooks/usePlayback';
8
+ import { buildKeyboardLayout } from './utils/midiHelpers';
9
+
10
+ const API_BASE = import.meta.env.DEV ? 'http://localhost:8000' : '';
11
+
12
+ // App states: 'upload' -> 'loading' -> 'player'
13
+
14
// Landing screen: lets the user pick or drag-drop an audio file and
// forwards the accepted File to onFileSelected. Rejects files whose
// extension is not one of the supported audio formats.
function UploadScreen({ onFileSelected }) {
  const [isDragging, setIsDragging] = useState(false);
  const [errorMsg, setErrorMsg] = useState('');
  const fileInputRef = useRef(null);

  // Validate the extension client-side before uploading; accept only the
  // formats the backend pipeline handles.
  const handleFile = useCallback((file) => {
    if (!file) return;
    const ext = file.name.split('.').pop().toLowerCase();
    if (!['mp3', 'm4a', 'wav', 'ogg', 'flac'].includes(ext)) {
      setErrorMsg('Please upload an audio file (MP3, M4A, WAV, OGG, or FLAC)');
      return;
    }
    setErrorMsg('');
    onFileSelected(file);
  }, [onFileSelected]);

  const handleDrop = useCallback((e) => {
    e.preventDefault();
    setIsDragging(false);
    handleFile(e.dataTransfer.files[0]);
  }, [handleFile]);

  // preventDefault is required so the browser doesn't navigate to the file.
  const handleDragOver = useCallback((e) => {
    e.preventDefault();
    setIsDragging(true);
  }, []);

  const handleDragLeave = useCallback(() => {
    setIsDragging(false);
  }, []);

  const handleFileSelect = useCallback((e) => {
    handleFile(e.target.files[0]);
  }, [handleFile]);

  return (
    <div className="upload-screen">
      <div className="upload-content">
        <div className="upload-logo">
          <OctopusLogo size={80} />
          <h1>Mr. Octopus</h1>
          <p className="upload-tagline">Your AI piano teacher</p>
        </div>

        <p className="upload-description">
          Drop a song and Mr. Octopus will transcribe it into a piano tutorial
          you can follow along with, note by note. Works best with clearly
          recorded solo piano pieces.
        </p>

        {/* Clicking the zone proxies to the hidden <input type="file"> below */}
        <div
          className={`drop-zone ${isDragging ? 'dragging' : ''}`}
          onDrop={handleDrop}
          onDragOver={handleDragOver}
          onDragLeave={handleDragLeave}
          onClick={() => fileInputRef.current?.click()}
        >
          <div className="drop-icon">&#9835;</div>
          <p>Drag & drop an audio file</p>
          <p className="drop-hint">MP3, M4A, WAV, OGG, FLAC</p>
          <input
            ref={fileInputRef}
            type="file"
            accept="audio/*,.m4a,.mp3,.wav,.ogg,.flac"
            onChange={handleFileSelect}
            hidden
          />
        </div>

        <div className="copyright-notice">
          Please only upload audio you have the rights to use.
        </div>

        {errorMsg && (
          <div className="upload-error">{errorMsg}</div>
        )}
      </div>
    </div>
  );
}
94
+
95
// Full-screen interstitial shown while the backend transcribes the upload;
// `status` is the human-readable progress message to display.
function LoadingScreen({ status }) {
  return (
    <div className="upload-screen">
      <div className="upload-processing">
        <div className="processing-logo">
          <OctopusLogo size={72} />
        </div>
        <h2>{status}</h2>
        <p className="loading-sub">This usually takes 5-10 seconds</p>
        <div className="loading-bar">
          <div className="loading-bar-fill" />
        </div>
      </div>
    </div>
  );
}
111
+
112
// Root component. Drives a three-state flow: 'upload' -> 'loading' -> 'player'.
// Upload posts the file to the backend, fetches the resulting MIDI + chords,
// and hands playback/rendering to usePlayback + PianoRoll.
export default function App() {
  const containerRef = useRef(null);
  const [dimensions, setDimensions] = useState({ width: 800, height: 600 });
  const [screen, setScreen] = useState('upload'); // 'upload' | 'loading' | 'player'
  const [loadingStatus, setLoadingStatus] = useState('');
  const [chords, setChords] = useState([]);

  const { notes, totalDuration, fileName, loadFromUrl, loadFromBlob } = useMidi();

  const {
    isPlaying,
    currentTimeRef,
    activeNotes,
    tempo,
    samplesLoaded,
    loopStart,
    loopEnd,
    isLooping,
    togglePlayPause,
    setTempo,
    seekTo,
    scheduleNotes,
    setLoopA,
    setLoopB,
    clearLoop,
  } = usePlayback();

  // When samples are loaded and we have notes, transition to player
  useEffect(() => {
    if (screen === 'loading' && samplesLoaded && notes.length > 0) {
      setScreen('player');
    }
  }, [screen, samplesLoaded, notes.length]);

  // Upload the file, then fetch the transcribed MIDI and chord data.
  // On any failure, fall back to the upload screen and surface the message.
  const handleFileSelected = useCallback(async (file) => {
    setScreen('loading');
    setLoadingStatus('Transcribing your song...');
    try {
      const form = new FormData();
      form.append('file', file);
      const res = await fetch(`${API_BASE}/api/transcribe`, {
        method: 'POST',
        body: form,
      });
      if (!res.ok) {
        // Backend errors are FastAPI JSON bodies with a "detail" field;
        // fall back to the HTTP status text if the body isn't JSON.
        const err = await res.json().catch(() => ({ detail: res.statusText }));
        throw new Error(err.detail || 'Transcription failed');
      }
      const data = await res.json();

      setLoadingStatus('Loading piano sounds...');
      const midiRes = await fetch(`${API_BASE}${data.midi_url}`);
      const blob = await midiRes.blob();
      loadFromBlob(blob, file.name.replace(/\.[^.]+$/, '.mid'));

      if (data.chords) {
        // The chord payload may be either {chords: [...]} or a bare array.
        const chordList = data.chords?.chords || data.chords || [];
        setChords(Array.isArray(chordList) ? chordList : []);
      }
      // Screen transition to 'player' happens via the useEffect above
      // once both samplesLoaded and notes.length > 0
    } catch (e) {
      setScreen('upload');
      alert(e.message || 'Something went wrong. Please try again.');
    }
  }, [loadFromBlob]);

  const handleNewSong = useCallback(() => {
    setScreen('upload');
    setChords([]);
  }, []);

  // Reschedule audio when notes change
  useEffect(() => {
    if (notes.length > 0) {
      scheduleNotes(notes, totalDuration);
    }
  }, [notes, totalDuration, scheduleNotes]);

  // Handle resize — re-run when `screen` changes because the canvas
  // container only exists on the player screen.
  useEffect(() => {
    const el = containerRef.current;
    if (!el) return;

    const ro = new ResizeObserver(([entry]) => {
      const { width, height } = entry.contentRect;
      if (width > 0 && height > 0) {
        setDimensions({ width, height });
      }
    });
    ro.observe(el);
    return () => ro.disconnect();
  }, [screen]);

  const keyboardLayout = buildKeyboardLayout(dimensions.width);

  // Tone.start() must run inside a user gesture before audio can play.
  const handleTogglePlay = useCallback(async () => {
    if (!samplesLoaded) return;
    await Tone.start();
    togglePlayPause();
  }, [togglePlayPause, samplesLoaded]);

  if (screen === 'upload') {
    return <UploadScreen onFileSelected={handleFileSelected} />;
  }

  if (screen === 'loading') {
    return <LoadingScreen status={loadingStatus} />;
  }

  return (
    <div className="app">
      <Controls
        isPlaying={isPlaying}
        togglePlayPause={handleTogglePlay}
        tempo={tempo}
        setTempo={setTempo}
        currentTimeRef={currentTimeRef}
        totalDuration={totalDuration}
        seekTo={seekTo}
        fileName={fileName}
        onNewSong={handleNewSong}
        loopStart={loopStart}
        loopEnd={loopEnd}
        isLooping={isLooping}
        onSetLoopA={setLoopA}
        onSetLoopB={setLoopB}
        onClearLoop={clearLoop}
      />
      <div className="canvas-container" ref={containerRef}>
        <PianoRoll
          notes={notes}
          currentTimeRef={currentTimeRef}
          activeNotes={activeNotes}
          keyboardLayout={keyboardLayout}
          width={dimensions.width}
          height={dimensions.height}
          loopStart={loopStart}
          loopEnd={loopEnd}
          chords={chords}
        />
      </div>
    </div>
  );
}
app/src/components/Controls.jsx ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useEffect, useRef } from 'react';
2
+ import OctopusLogo from './OctopusLogo';
3
+
4
// Render a second count as "m:ss", e.g. 75 -> "1:15". Fractional seconds
// are floored, so 5.7 displays as "0:05".
function formatTime(s) {
  const minutes = Math.floor(s / 60);
  const seconds = Math.floor(s % 60);
  return `${minutes}:${String(seconds).padStart(2, '0')}`;
}
9
+
10
+ export default function Controls({
11
+ isPlaying,
12
+ togglePlayPause,
13
+ tempo,
14
+ setTempo,
15
+ currentTimeRef,
16
+ totalDuration,
17
+ seekTo,
18
+ fileName,
19
+ onNewSong,
20
+ loopStart,
21
+ loopEnd,
22
+ isLooping,
23
+ onSetLoopA,
24
+ onSetLoopB,
25
+ onClearLoop,
26
+ }) {
27
+ const [displayTime, setDisplayTime] = useState(0);
28
+ const intervalRef = useRef(null);
29
+
30
+ useEffect(() => {
31
+ intervalRef.current = setInterval(() => {
32
+ setDisplayTime(currentTimeRef.current);
33
+ }, 50);
34
+ return () => clearInterval(intervalRef.current);
35
+ }, [currentTimeRef]);
36
+
37
+ const progress = totalDuration > 0 ? (displayTime / totalDuration) * 100 : 0;
38
+
39
+ // Loop region markers for the timeline
40
+ const loopStartPct = loopStart !== null && totalDuration > 0
41
+ ? (loopStart / totalDuration) * 100 : null;
42
+ const loopEndPct = loopEnd !== null && totalDuration > 0
43
+ ? (loopEnd / totalDuration) * 100 : null;
44
+
45
+ // Build timeline background with loop region
46
+ let timelineBg;
47
+ if (loopStartPct !== null && loopEndPct !== null) {
48
+ timelineBg = `linear-gradient(to right,
49
+ var(--border) ${loopStartPct}%,
50
+ rgba(139, 92, 246, 0.3) ${loopStartPct}%,
51
+ var(--primary) ${Math.min(progress, loopEndPct)}%,
52
+ rgba(139, 92, 246, 0.3) ${Math.min(progress, loopEndPct)}%,
53
+ rgba(139, 92, 246, 0.3) ${loopEndPct}%,
54
+ var(--border) ${loopEndPct}%)`;
55
+ } else {
56
+ timelineBg = `linear-gradient(to right, var(--primary) ${progress}%, var(--border) ${progress}%)`;
57
+ }
58
+
59
+ return (
60
+ <div className="controls">
61
+ {/* Main controls row */}
62
+ <div className="controls-main">
63
+ <div className="controls-left">
64
+ <div className="brand-mark">
65
+ <OctopusLogo size={28} />
66
+ <span className="brand-name">Mr. Octopus</span>
67
+ </div>
68
+ {fileName && (
69
+ <span className="file-name">{fileName.replace(/\.[^.]+$/, '')}</span>
70
+ )}
71
+ </div>
72
+
73
+ <div className="controls-center">
74
+ <button
75
+ className="transport-btn"
76
+ onClick={() => seekTo(Math.max(0, displayTime - 5))}
77
+ title="Back 5s"
78
+ >
79
+ <svg width="16" height="16" viewBox="0 0 24 24" fill="currentColor">
80
+ <path d="M11 18V6l-8.5 6 8.5 6zm.5-6l8.5 6V6l-8.5 6z" />
81
+ </svg>
82
+ </button>
83
+
84
+ <button className="play-btn" onClick={togglePlayPause}>
85
+ {isPlaying ? (
86
+ <svg width="20" height="20" viewBox="0 0 24 24" fill="currentColor">
87
+ <rect x="6" y="4" width="4" height="16" rx="1" />
88
+ <rect x="14" y="4" width="4" height="16" rx="1" />
89
+ </svg>
90
+ ) : (
91
+ <svg width="20" height="20" viewBox="0 0 24 24" fill="currentColor">
92
+ <path d="M8 5v14l11-7z" />
93
+ </svg>
94
+ )}
95
+ </button>
96
+
97
+ <button
98
+ className="transport-btn"
99
+ onClick={() => seekTo(Math.min(totalDuration, displayTime + 5))}
100
+ title="Forward 5s"
101
+ >
102
+ <svg width="16" height="16" viewBox="0 0 24 24" fill="currentColor">
103
+ <path d="M4 18l8.5-6L4 6v12zm9-12v12l8.5-6L13 6z" />
104
+ </svg>
105
+ </button>
106
+ </div>
107
+
108
+ <div className="controls-right">
109
+ {/* Loop controls */}
110
+ <div className="loop-controls">
111
+ {!isLooping ? (
112
+ <>
113
+ <button
114
+ className={`btn btn-loop ${loopStart !== null ? 'active' : ''}`}
115
+ onClick={onSetLoopA}
116
+ title="Set loop start point (A)"
117
+ >
118
+ A
119
+ </button>
120
+ <button
121
+ className={`btn btn-loop ${loopEnd !== null ? 'active' : ''}`}
122
+ onClick={onSetLoopB}
123
+ disabled={loopStart === null}
124
+ title="Set loop end point (B)"
125
+ >
126
+ B
127
+ </button>
128
+ </>
129
+ ) : (
130
+ <button
131
+ className="btn btn-loop active"
132
+ onClick={onClearLoop}
133
+ title="Clear loop"
134
+ >
135
+ {formatTime(loopStart)} - {formatTime(loopEnd)}
136
+ <span className="loop-x">&times;</span>
137
+ </button>
138
+ )}
139
+ </div>
140
+
141
+ {onNewSong && (
142
+ <button className="btn btn-new" onClick={onNewSong}>
143
+ + New Song
144
+ </button>
145
+ )}
146
+
147
+ <div className="tempo-control">
148
+ <span className="tempo-label">Speed</span>
149
+ <input
150
+ type="range"
151
+ min={50}
152
+ max={200}
153
+ value={tempo}
154
+ onChange={(e) => setTempo(Number(e.target.value))}
155
+ />
156
+ <span className="tempo-value">{tempo}%</span>
157
+ </div>
158
+ </div>
159
+ </div>
160
+
161
+ {/* Timeline row */}
162
+ <div className="timeline">
163
+ <span className="timeline-time">{formatTime(displayTime)}</span>
164
+ <div className="timeline-track">
165
+ <input
166
+ type="range"
167
+ min={0}
168
+ max={totalDuration || 1}
169
+ step={0.1}
170
+ value={displayTime}
171
+ onChange={(e) => seekTo(Number(e.target.value))}
172
+ style={{ background: timelineBg }}
173
+ />
174
+ </div>
175
+ <span className="timeline-time">{formatTime(totalDuration)}</span>
176
+ </div>
177
+ </div>
178
+ );
179
+ }
app/src/components/OctopusLogo.jsx ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useId } from 'react';
2
+
3
+ export default function OctopusLogo({ size = 48 }) {
4
+ const id = useId();
5
+ const gradId = `octo${id.replace(/:/g, '')}`;
6
+
7
+ const legs = [
8
+ [-66, 'W'],
9
+ [-47, 'W'],
10
+ [-28, 'B'],
11
+ [-9, 'W'],
12
+ [9, 'W'],
13
+ [28, 'B'],
14
+ [47, 'W'],
15
+ [66, 'W'],
16
+ ];
17
+
18
+ const ox = 24, oy = 25;
19
+
20
+ return (
21
+ <svg width={size} height={size} viewBox="0 0 48 48" fill="none" xmlns="http://www.w3.org/2000/svg">
22
+ <defs>
23
+ <linearGradient id={gradId} x1="10" y1="8" x2="38" y2="44" gradientUnits="userSpaceOnUse">
24
+ <stop stopColor="#a78bfa" />
25
+ <stop offset="1" stopColor="#06b6d4" />
26
+ </linearGradient>
27
+ </defs>
28
+
29
+ {/* Tentacle legs — piano keys splayed outward */}
30
+ {legs.map(([angle, type], i) => {
31
+ const isBlack = type === 'B';
32
+ const w = isBlack ? 2.2 : 3;
33
+ const h = isBlack ? 11 : 16;
34
+ const rx = isBlack ? 0.7 : 1;
35
+ return (
36
+ <g key={i} transform={`rotate(${angle}, ${ox}, ${oy})`}>
37
+ <rect x={ox - w / 2} y={oy} width={w} height={h} rx={rx}
38
+ fill={isBlack ? '#1e1b4b' : '#f0eef5'} />
39
+ <rect x={ox - w / 2} y={oy} width={w} height={h} rx={rx}
40
+ fill={isBlack ? '#4338ca' : '#c8c5d6'}
41
+ opacity={isBlack ? 0.12 : 0.2} />
42
+ </g>
43
+ );
44
+ })}
45
+
46
+ {/* Head */}
47
+ <ellipse cx="24" cy="16" rx="14" ry="12" fill={`url(#${gradId})`} />
48
+ <circle cx="19" cy="14" r="2.5" fill="#07070e" />
49
+ <circle cx="29" cy="14" r="2.5" fill="#07070e" />
50
+ <circle cx="20" cy="13.3" r="0.9" fill="white" opacity="0.9" />
51
+ <circle cx="30" cy="13.3" r="0.9" fill="white" opacity="0.9" />
52
+ </svg>
53
+ );
54
+ }
app/src/components/PianoRoll.jsx ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useRef, useEffect } from 'react';
2
+ import { COLORS, noteColor, noteGlowColor } from '../utils/colorScheme';
3
+ import {
4
+ buildNotePositionMap,
5
+ noteXPositionFast,
6
+ getVisibleNotes,
7
+ isBlackKey,
8
+ } from '../utils/midiHelpers';
9
+
10
+ const LOOK_AHEAD_SECONDS = 4;
11
+ const KEYBOARD_HEIGHT_RATIO = 0.18; // keyboard takes 18% of canvas height
12
+ const MIN_KEYBOARD_HEIGHT = 80;
13
+ const MAX_KEYBOARD_HEIGHT = 150;
14
+
15
/**
 * Trace a rounded-rectangle path on `ctx` (path only; caller fills or
 * strokes it).
 *
 * Degenerate sizes are skipped entirely: the original guarded only
 * h < 0, but callers compute widths like `pos.width - 2`, which can go
 * negative on very narrow keys and previously produced a negative
 * clamped radius and a garbage path. The radius is also clamped to be
 * non-negative.
 */
function drawRoundedRect(ctx, x, y, w, h, r) {
  if (h < 0 || w < 0) return;
  r = Math.max(0, Math.min(r, w / 2, h / 2));
  ctx.beginPath();
  ctx.moveTo(x + r, y);
  ctx.lineTo(x + w - r, y);
  ctx.quadraticCurveTo(x + w, y, x + w, y + r);
  ctx.lineTo(x + w, y + h - r);
  ctx.quadraticCurveTo(x + w, y + h, x + w - r, y + h);
  ctx.lineTo(x + r, y + h);
  ctx.quadraticCurveTo(x, y + h, x, y + h - r);
  ctx.lineTo(x, y + r);
  ctx.quadraticCurveTo(x, y, x + r, y);
  ctx.closePath();
}
30
+
31
/**
 * Paint the falling-note rectangles above the hit line. A note reaches
 * the hit line exactly at its start time; LOOK_AHEAD_SECONDS of future
 * material is visible above it.
 */
function drawFallingNotes(ctx, notes, currentTime, hitLineY, positionMap) {
  const pxPerSec = hitLineY / LOOK_AHEAD_SECONDS;
  const upcoming = getVisibleNotes(notes, currentTime, LOOK_AHEAD_SECONDS, 0.5);

  ctx.save();

  for (const note of upcoming) {
    const bottomY = hitLineY - (note.time - currentTime) * pxPerSec;
    const topY =
      hitLineY - (note.time + note.duration - currentTime) * pxPerSec;

    // Entirely outside the scrolling strip: nothing to draw.
    if (bottomY < 0 || topY > hitLineY) continue;

    const top = Math.max(topY, 0);
    const bottom = Math.min(bottomY, hitLineY);
    const bodyHeight = bottom - top;
    if (bodyHeight < 1) continue;

    const pos = noteXPositionFast(note.midi, positionMap);
    if (!pos) continue;

    const pad = 1;
    const x = pos.x + pad;
    const w = pos.width - pad * 2;

    // Soft glow behind the note body.
    ctx.shadowColor = noteGlowColor(note.midi);
    ctx.shadowBlur = 12;

    ctx.fillStyle = noteColor(note.midi);
    drawRoundedRect(ctx, x, top, w, bodyHeight, 4);
    ctx.fill();

    // Brighter strip while the leading edge crosses the hit line.
    if (bottomY <= hitLineY && bottomY >= hitLineY - 3) {
      ctx.shadowBlur = 20;
      ctx.fillStyle = noteGlowColor(note.midi);
      ctx.fillRect(x, hitLineY - 3, w, 3);
    }
  }

  ctx.shadowBlur = 0;
  ctx.restore();
}
77
+
78
/** Draw the glowing horizontal line where falling notes meet the keys. */
function drawHitLine(ctx, y, width) {
  ctx.save();
  ctx.strokeStyle = COLORS.hitLine;
  ctx.shadowColor = COLORS.hitLine;
  ctx.shadowBlur = 8;
  ctx.lineWidth = 2;
  ctx.beginPath();
  ctx.moveTo(0, y);
  ctx.lineTo(width, y);
  ctx.stroke();
  ctx.shadowBlur = 0;
  ctx.restore();
}
91
+
92
/**
 * Render the on-screen keyboard: all white keys in a first pass, then
 * the (shorter) black keys on top. Keys whose MIDI numbers appear in
 * `activeNotes` are highlighted with a glow pass.
 */
function drawKeyboard(ctx, keyboardLayout, keyboardY, keyboardHeight, activeNotes) {
  const blackKeyHeight = keyboardHeight * 0.62;

  // Pass 1: white keys.
  for (const key of keyboardLayout) {
    if (key.isBlack) continue;
    const active = activeNotes.has(key.midiNumber);

    ctx.fillStyle = active ? COLORS.whiteKeyActive : COLORS.whiteKey;
    ctx.fillRect(key.x, keyboardY, key.width, keyboardHeight);

    ctx.strokeStyle = COLORS.keyBorder;
    ctx.lineWidth = 1;
    ctx.strokeRect(key.x, keyboardY, key.width, keyboardHeight);

    if (active) {
      // Extra inset fill with a shadow so the lit key halos outward.
      ctx.save();
      ctx.shadowColor = noteGlowColor(key.midiNumber);
      ctx.shadowBlur = 15;
      ctx.fillStyle = COLORS.whiteKeyActive;
      ctx.fillRect(key.x + 1, keyboardY, key.width - 2, keyboardHeight);
      ctx.shadowBlur = 0;
      ctx.restore();
    }
  }

  // Pass 2: black keys, drawn over the white ones.
  for (const key of keyboardLayout) {
    if (!key.isBlack) continue;
    const active = activeNotes.has(key.midiNumber);

    ctx.fillStyle = active ? COLORS.blackKeyActive : COLORS.blackKey;
    ctx.fillRect(key.x, keyboardY, key.width, blackKeyHeight);

    if (active) {
      ctx.save();
      ctx.shadowColor = noteGlowColor(key.midiNumber);
      ctx.shadowBlur = 15;
      ctx.fillRect(key.x, keyboardY, key.width, blackKeyHeight);
      ctx.shadowBlur = 0;
      ctx.restore();
    }

    ctx.strokeStyle = '#000000';
    ctx.lineWidth = 1;
    ctx.strokeRect(key.x, keyboardY, key.width, blackKeyHeight);
  }
}
141
+
142
/**
 * Overlay chord-name pills along the left edge, scrolling in sync with
 * the falling notes, plus a faint full-width divider at each chord
 * boundary. Single-note "chords" (quality === 'note') are skipped.
 */
function drawChordLabels(ctx, chords, currentTime, hitLineY, width) {
  if (!chords || chords.length === 0) return;

  const pxPerSec = hitLineY / LOOK_AHEAD_SECONDS;
  const CHORD_STRIP_HEIGHT = 28;

  ctx.save();

  for (const chord of chords) {
    const startTime = chord.start_time;
    const endTime = chord.end_time;

    // Outside the visible window entirely.
    if (endTime < currentTime - 0.5 || startTime > currentTime + LOOK_AHEAD_SECONDS) continue;

    // Ignore single-note events.
    if (chord.quality === 'note') continue;

    const yBottom = hitLineY - (startTime - currentTime) * pxPerSec;
    const yTop = hitLineY - (endTime - currentTime) * pxPerSec;
    if (yBottom < 0 || yTop > hitLineY) continue;

    const clippedTop = Math.max(yTop, 0);
    const clippedBottom = Math.min(yBottom, hitLineY);

    const stripY = clippedTop;
    const stripHeight = Math.max(CHORD_STRIP_HEIGHT, clippedBottom - clippedTop);
    const pillHeight = Math.min(CHORD_STRIP_HEIGHT, stripHeight);

    // Translucent background pill.
    ctx.fillStyle = 'rgba(139, 92, 246, 0.15)';
    drawRoundedRect(ctx, 8, stripY, 72, pillHeight, 6);
    ctx.fill();

    ctx.font = 'bold 13px -apple-system, BlinkMacSystemFont, sans-serif';
    ctx.textAlign = 'center';
    ctx.textBaseline = 'middle';

    // Minor/diminished chords get the purple tint.
    const minorish = chord.quality?.includes('minor') || chord.quality?.includes('dim');
    ctx.fillStyle = minorish ? 'rgba(167, 139, 250, 0.9)' : 'rgba(255, 255, 255, 0.9)';
    ctx.fillText(chord.chord_name, 44, stripY + pillHeight / 2);

    // Full-width boundary line at the chord's start edge, when on-screen.
    if (yBottom > 0 && yBottom < hitLineY) {
      ctx.strokeStyle = 'rgba(139, 92, 246, 0.12)';
      ctx.lineWidth = 1;
      ctx.beginPath();
      ctx.moveTo(0, yBottom);
      ctx.lineTo(width, yBottom);
      ctx.stroke();
    }
  }

  ctx.restore();
}
204
+
205
/**
 * Visualize the A/B practice loop: dashed boundary lines at the loop
 * start and end, plus a dark overlay dimming everything outside it.
 * No-op unless both loop points are set.
 */
function drawLoopMarkers(ctx, loopStart, loopEnd, currentTime, hitLineY, width) {
  if (loopStart === null || loopEnd === null) return;
  const pxPerSec = hitLineY / LOOK_AHEAD_SECONDS;
  const toY = (t) => hitLineY - (t - currentTime) * pxPerSec;

  // Dashed line for each boundary that is currently on-screen.
  for (const t of [loopStart, loopEnd]) {
    const y = toY(t);
    if (y < 0 || y > hitLineY) continue;

    ctx.save();
    ctx.setLineDash([6, 4]);
    ctx.strokeStyle = 'rgba(139, 92, 246, 0.6)';
    ctx.lineWidth = 2;
    ctx.beginPath();
    ctx.moveTo(0, y);
    ctx.lineTo(width, y);
    ctx.stroke();
    ctx.setLineDash([]);
    ctx.restore();
  }

  const startY = toY(loopStart);
  const endY = toY(loopEnd);

  ctx.save();
  ctx.fillStyle = 'rgba(0, 0, 0, 0.3)';
  // Beyond the loop end (upcoming material past the loop).
  if (endY > 0) {
    ctx.fillRect(0, 0, width, Math.min(endY, hitLineY));
  }
  // Before the loop start (material already behind the loop).
  if (startY < hitLineY) {
    const overlayTop = Math.max(startY, 0);
    ctx.fillRect(0, overlayTop, width, hitLineY - overlayTop);
  }
  ctx.restore();
}
242
+
243
+ export default function PianoRoll({
244
+ notes,
245
+ currentTimeRef,
246
+ activeNotes,
247
+ keyboardLayout,
248
+ width,
249
+ height,
250
+ loopStart,
251
+ loopEnd,
252
+ chords,
253
+ }) {
254
+ const canvasRef = useRef(null);
255
+ const positionMapRef = useRef(null);
256
+
257
+ // Rebuild position map when layout changes
258
+ useEffect(() => {
259
+ positionMapRef.current = buildNotePositionMap(keyboardLayout);
260
+ }, [keyboardLayout]);
261
+
262
+ // Main render loop
263
+ useEffect(() => {
264
+ const canvas = canvasRef.current;
265
+ if (!canvas) return;
266
+
267
+ const ctx = canvas.getContext('2d');
268
+ const dpr = window.devicePixelRatio || 1;
269
+
270
+ canvas.width = width * dpr;
271
+ canvas.height = height * dpr;
272
+ ctx.scale(dpr, dpr);
273
+
274
+ let frameId;
275
+
276
+ function render() {
277
+ const currentTime = currentTimeRef.current;
278
+ const keyboardHeight = Math.min(
279
+ MAX_KEYBOARD_HEIGHT,
280
+ Math.max(MIN_KEYBOARD_HEIGHT, height * KEYBOARD_HEIGHT_RATIO)
281
+ );
282
+ const hitLineY = height - keyboardHeight;
283
+
284
+ // Clear
285
+ ctx.fillStyle = COLORS.pianoRollBg;
286
+ ctx.fillRect(0, 0, width, height);
287
+
288
+ // Draw subtle grid lines for visual reference
289
+ ctx.strokeStyle = '#ffffff08';
290
+ ctx.lineWidth = 1;
291
+ const pixelsPerSecond = hitLineY / LOOK_AHEAD_SECONDS;
292
+ for (let s = 0; s < LOOK_AHEAD_SECONDS; s++) {
293
+ const y = hitLineY - s * pixelsPerSecond;
294
+ ctx.beginPath();
295
+ ctx.moveTo(0, y);
296
+ ctx.lineTo(width, y);
297
+ ctx.stroke();
298
+ }
299
+
300
+ // Falling notes
301
+ if (positionMapRef.current) {
302
+ drawFallingNotes(ctx, notes, currentTime, hitLineY, positionMapRef.current);
303
+ }
304
+
305
+ // Chord labels
306
+ drawChordLabels(ctx, chords, currentTime, hitLineY, width);
307
+
308
+ // Loop markers
309
+ drawLoopMarkers(ctx, loopStart, loopEnd, currentTime, hitLineY, width);
310
+
311
+ // Hit line
312
+ drawHitLine(ctx, hitLineY, width);
313
+
314
+ // Keyboard
315
+ drawKeyboard(ctx, keyboardLayout, hitLineY, keyboardHeight, activeNotes);
316
+
317
+ frameId = requestAnimationFrame(render);
318
+ }
319
+
320
+ render();
321
+
322
+ return () => cancelAnimationFrame(frameId);
323
+ }, [notes, keyboardLayout, activeNotes, width, height, currentTimeRef, loopStart, loopEnd, chords]);
324
+
325
+ return (
326
+ <canvas
327
+ ref={canvasRef}
328
+ style={{
329
+ width: `${width}px`,
330
+ height: `${height}px`,
331
+ display: 'block',
332
+ }}
333
+ />
334
+ );
335
+ }
app/src/hooks/useMidi.js ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useCallback } from 'react';
2
+ import { Midi } from '@tonejs/midi';
3
+ import { parseMidiFile } from '../utils/midiHelpers';
4
+ import { generateSampleMidi } from '../utils/generateSampleMidi';
5
+
6
/**
 * React hook owning the parsed MIDI state (notes, duration, file name)
 * and exposing loaders for URLs, File objects, Blobs, and a built-in
 * sample song.
 *
 * The original four loaders each repeated the parse/commit and
 * loading-flag/error boilerplate; that is factored into the private
 * `commit` and `runLoader` helpers. External interface is unchanged.
 */
export function useMidi() {
  const [notes, setNotes] = useState([]);
  const [totalDuration, setTotalDuration] = useState(0);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState(null);
  const [fileName, setFileName] = useState('');

  // Parse a @tonejs/midi object and commit the result to state.
  const commit = useCallback((midi, name) => {
    const result = parseMidiFile(midi);
    setNotes(result.notes);
    setTotalDuration(result.totalDuration);
    setFileName(name);
  }, []);

  // Shared async wrapper: loading flag + error capture around a
  // producer that resolves to a Midi instance.
  const runLoader = useCallback(
    async (produce, name) => {
      setIsLoading(true);
      setError(null);
      try {
        const midi = await produce();
        commit(midi, name);
      } catch (e) {
        setError(e.message);
      } finally {
        setIsLoading(false);
      }
    },
    [commit]
  );

  // Fetch and parse a MIDI file from a URL; name is the URL's basename.
  const loadFromUrl = useCallback(
    (url) => runLoader(() => Midi.fromUrl(url), url.split('/').pop()),
    [runLoader]
  );

  // Parse a user-selected File; silently ignores a null/undefined file.
  const loadFromFile = useCallback(
    async (file) => {
      if (!file) return;
      await runLoader(async () => new Midi(await file.arrayBuffer()), file.name);
    },
    [runLoader]
  );

  // Parse a Blob (e.g. a server-side transcription result).
  const loadFromBlob = useCallback(
    (blob, name = 'transcription.mid') =>
      runLoader(async () => new Midi(await blob.arrayBuffer()), name),
    [runLoader]
  );

  // Load the bundled demo song (synchronous, cannot fail).
  const loadSample = useCallback(() => {
    commit(generateSampleMidi(), 'Twinkle Twinkle Little Star');
  }, [commit]);

  return {
    notes,
    totalDuration,
    isLoading,
    error,
    fileName,
    loadFromUrl,
    loadFromFile,
    loadFromBlob,
    loadSample,
  };
}
app/src/hooks/usePlayback.js ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useRef, useState, useCallback, useEffect } from 'react';
2
+ import * as Tone from 'tone';
3
+
4
+ // Salamander Grand Piano samples (Tone.js official CDN, interpolates between keys)
5
+ const SALAMANDER_URL = 'https://tonejs.github.io/audio/salamander/';
6
+ const SAMPLE_NOTES = [
7
+ 'A0', 'C1', 'D#1', 'F#1', 'A1', 'C2', 'D#2', 'F#2',
8
+ 'A2', 'C3', 'D#3', 'F#3', 'A3', 'C4', 'D#4', 'F#4',
9
+ 'A4', 'C5', 'D#5', 'F#5', 'A5', 'C6', 'D#6', 'F#6',
10
+ 'A6', 'C7', 'D#7', 'F#7', 'A7', 'C8',
11
+ ];
12
+
13
/**
 * Map note names to their sample file names for Tone.Sampler.
 * Sharps are encoded as "s" in the CDN file names (e.g. "D#1" -> "Ds1.ogg").
 * OGG is used because it has no encoder padding delay; MP3 adds ~50ms
 * of silence.
 *
 * Generalized to accept an arbitrary note list; defaults to the full
 * Salamander sample set, so existing zero-argument callers are unchanged.
 *
 * @param {string[]} [noteNames=SAMPLE_NOTES] - note names to include
 * @returns {Object<string, string>} note name -> sample file name
 */
function buildSampleMap(noteNames = SAMPLE_NOTES) {
  const map = {};
  for (const note of noteNames) {
    map[note] = note.replace('#', 's') + '.ogg';
  }
  return map;
}
21
+
22
/**
 * React hook driving audio playback through Tone.js: a Salamander piano
 * sampler, Transport-based note scheduling, a rAF clock published via
 * `currentTimeRef`, tempo as a percentage, seek, and an A/B practice
 * loop.
 *
 * Fix: the Tone.Sampler construction was previously duplicated verbatim
 * between the mount effect and `getSynth`; there is now a single
 * construction site (`getSynth`), and the mount effect simply calls it
 * to start sample loading eagerly. External interface is unchanged.
 */
export function usePlayback() {
  const [isPlaying, setIsPlaying] = useState(false);
  const [tempo, setTempoState] = useState(100);
  const [activeNotes, setActiveNotes] = useState(new Set());
  const [samplesLoaded, setSamplesLoaded] = useState(false);
  const [loopStart, setLoopStart] = useState(null);
  const [loopEnd, setLoopEnd] = useState(null);

  const currentTimeRef = useRef(0);      // playback clock, read by canvas/UI
  const animationRef = useRef(null);     // rAF handle for the tick loop
  const synthRef = useRef(null);         // lazily-created Tone.Sampler
  const scheduledRef = useRef(false);    // whether notes are on the Transport
  const totalDurationRef = useRef(0);    // song length for auto-stop
  const loopRef = useRef({ start: null, end: null });

  // Mirror loop state into a ref so the rAF tick can read it without
  // re-subscribing.
  useEffect(() => {
    loopRef.current = { start: loopStart, end: loopEnd };
  }, [loopStart, loopEnd]);

  // Lazily create (and cache) the Salamander sampler — the single
  // construction site for the synth.
  const getSynth = useCallback(() => {
    if (!synthRef.current) {
      synthRef.current = new Tone.Sampler({
        urls: buildSampleMap(),
        baseUrl: SALAMANDER_URL,
        release: 1.5,
        onload: () => setSamplesLoaded(true),
      }).toDestination();
    }
    return synthRef.current;
  }, []);

  // Start sample loading eagerly on mount.
  useEffect(() => {
    getSynth();
  }, [getSynth]);

  // rAF tick: publish the transport clock, wrap at the loop end, and
  // auto-stop shortly after the last note.
  const tick = useCallback(() => {
    const transport = Tone.getTransport();
    if (transport.state === 'started') {
      currentTimeRef.current = transport.seconds;

      const loop = loopRef.current;

      if (loop.start !== null && loop.end !== null && transport.seconds >= loop.end) {
        // Reached B: silence ringing notes and jump back to A.
        if (synthRef.current) synthRef.current.releaseAll();
        transport.seconds = loop.start;
        currentTimeRef.current = loop.start;
        setActiveNotes(new Set());
      } else if (
        totalDurationRef.current > 0 &&
        transport.seconds >= totalDurationRef.current + 0.5
      ) {
        // Past the end (plus a small tail): rewind and stop.
        transport.pause();
        transport.seconds = 0;
        currentTimeRef.current = 0;
        setIsPlaying(false);
        setActiveNotes(new Set());
      }
    }
    animationRef.current = requestAnimationFrame(tick);
  }, []);

  // Run the tick loop for the hook's lifetime.
  useEffect(() => {
    animationRef.current = requestAnimationFrame(tick);
    return () => {
      if (animationRef.current) {
        cancelAnimationFrame(animationRef.current);
      }
    };
  }, [tick]);

  /**
   * (Re)schedule a note list on the Transport: clears any previous
   * schedule, resets the clock, and wires key-highlight draw events
   * for each note's attack and release.
   */
  const scheduleNotes = useCallback(
    (notes, totalDuration) => {
      const transport = Tone.getTransport();
      transport.cancel();
      transport.stop();
      transport.seconds = 0;
      currentTimeRef.current = 0;
      totalDurationRef.current = totalDuration;
      scheduledRef.current = false;
      setIsPlaying(false);
      setActiveNotes(new Set());

      const synth = getSynth();

      notes.forEach((note) => {
        transport.schedule((time) => {
          const noteName = Tone.Frequency(note.midi, 'midi').toNote();
          synth.triggerAttackRelease(noteName, note.duration, time, note.velocity);

          // Key highlight on (at the note's attack).
          Tone.Draw.schedule(() => {
            setActiveNotes((prev) => {
              const next = new Set(prev);
              next.add(note.midi);
              return next;
            });
          }, time);

          // Key highlight off (at the note's release).
          Tone.Draw.schedule(() => {
            setActiveNotes((prev) => {
              const next = new Set(prev);
              next.delete(note.midi);
              return next;
            });
          }, time + note.duration);
        }, note.time);
      });

      scheduledRef.current = true;
    },
    [getSynth]
  );

  const play = useCallback(async () => {
    await Tone.start(); // resume the AudioContext (user-gesture requirement)
    const transport = Tone.getTransport();
    transport.start();
    setIsPlaying(true);
  }, []);

  const pause = useCallback(() => {
    Tone.getTransport().pause();
    setIsPlaying(false);
  }, []);

  const togglePlayPause = useCallback(async () => {
    if (isPlaying) {
      pause();
    } else {
      await play();
    }
  }, [isPlaying, play, pause]);

  // Tempo is a percentage of original speed (100 = normal).
  const setTempo = useCallback((percent) => {
    setTempoState(percent);
    Tone.getTransport().playbackRate = percent / 100;
  }, []);

  // Jump to an absolute time, cutting any notes still sounding.
  const seekTo = useCallback(
    (timeInSeconds) => {
      const synth = getSynth();
      synth.releaseAll();
      Tone.getTransport().seconds = timeInSeconds;
      currentTimeRef.current = timeInSeconds;
      setActiveNotes(new Set());
    },
    [getSynth]
  );

  // Set loop A (start) at the current position; an end point at or
  // before the new start is discarded.
  const setLoopA = useCallback(() => {
    const t = currentTimeRef.current;
    setLoopStart(t);
    if (loopEnd !== null && loopEnd <= t) {
      setLoopEnd(null);
    }
  }, [loopEnd]);

  // Set loop B (end) — only valid once A exists and B lies after it.
  const setLoopB = useCallback(() => {
    const t = currentTimeRef.current;
    if (loopStart !== null && t > loopStart) {
      setLoopEnd(t);
    }
  }, [loopStart]);

  // Remove both loop points.
  const clearLoop = useCallback(() => {
    setLoopStart(null);
    setLoopEnd(null);
  }, []);

  const isLooping = loopStart !== null && loopEnd !== null;

  // Tear down the Transport schedule and dispose the sampler on unmount.
  useEffect(() => {
    return () => {
      Tone.getTransport().cancel();
      Tone.getTransport().stop();
      if (synthRef.current) {
        synthRef.current.dispose();
        synthRef.current = null;
      }
    };
  }, []);

  return {
    isPlaying,
    currentTimeRef,
    activeNotes,
    tempo,
    samplesLoaded,
    loopStart,
    loopEnd,
    isLooping,
    play,
    pause,
    togglePlayPause,
    setTempo,
    seekTo,
    scheduleNotes,
    setLoopA,
    setLoopB,
    clearLoop,
  };
}
app/src/index.css ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :root {
2
+ --bg: #07070e;
3
+ --surface: #0f0f1a;
4
+ --surface-2: #161628;
5
+ --surface-3: #1c1c36;
6
+ --border: #1e1e3a;
7
+ --border-hover: #2e2e52;
8
+ --primary: #8b5cf6;
9
+ --primary-hover: #a78bfa;
10
+ --primary-dim: #7c3aed;
11
+ --primary-glow: rgba(139, 92, 246, 0.25);
12
+ --accent: #06b6d4;
13
+ --text: #f1f5f9;
14
+ --text-muted: #94a3b8;
15
+ --text-subtle: #525280;
16
+ --danger: #ef4444;
17
+ --danger-bg: rgba(239, 68, 68, 0.1);
18
+ --radius: 10px;
19
+ --radius-lg: 14px;
20
+ }
21
+
22
/* Global reset */
* { margin: 0; padding: 0; box-sizing: border-box; }

body {
  background: var(--bg);
  color: var(--text);
  font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
  overflow: hidden;
  -webkit-font-smoothing: antialiased;
  -moz-osx-font-smoothing: grayscale;
}

#root { width: 100vw; height: 100vh; }

.app { width: 100%; height: 100%; display: flex; flex-direction: column; }

/* ========================================
   Upload Screen
   ======================================== */

.upload-screen {
  width: 100vw;
  height: 100vh;
  display: flex;
  align-items: center;
  justify-content: center;
  background: var(--bg);
  background-image:
    radial-gradient(ellipse at 50% 30%, rgba(139, 92, 246, 0.08) 0%, transparent 60%),
    radial-gradient(ellipse at 50% 60%, rgba(6, 182, 212, 0.04) 0%, transparent 50%);
}

.upload-content { width: 100%; max-width: 520px; padding: 0 24px; }

.upload-logo {
  display: flex;
  flex-direction: column;
  align-items: center;
  margin-bottom: 48px;
}

/* Gradient-filled wordmark */
.upload-logo h1 {
  font-size: 36px;
  font-weight: 700;
  letter-spacing: -1px;
  background: linear-gradient(135deg, #a78bfa 0%, #06b6d4 100%);
  -webkit-background-clip: text;
  -webkit-text-fill-color: transparent;
  background-clip: text;
  margin-top: 14px;
}

.upload-tagline { color: var(--text-muted); font-size: 16px; margin-top: 8px; }

.upload-description {
  color: var(--text-muted);
  font-size: 14px;
  line-height: 1.6;
  text-align: center;
  margin-bottom: 28px;
}

/* Copyright notice */
.copyright-notice {
  margin-top: 20px;
  padding: 12px 16px;
  font-size: 12px;
  line-height: 1.5;
  color: var(--text-subtle);
  text-align: center;
  border-top: 1px solid var(--border);
}

/* Drop zone */
.drop-zone {
  border: 2px dashed var(--border);
  border-radius: var(--radius-lg);
  padding: 36px 24px;
  cursor: pointer;
  transition: all 0.2s;
  text-align: center;
}

.drop-zone:hover,
.drop-zone.dragging {
  border-color: var(--primary);
  background: rgba(139, 92, 246, 0.05);
  box-shadow: inset 0 0 30px rgba(139, 92, 246, 0.03);
}

.drop-icon { font-size: 32px; margin-bottom: 10px; opacity: 0.4; }

.drop-zone p { color: var(--text-muted); font-size: 14px; font-weight: 500; }

.drop-hint {
  font-size: 12px !important;
  color: var(--text-subtle) !important;
  font-weight: 400 !important;
  margin-top: 6px;
}

/* Error message */
.upload-error {
  margin-top: 20px;
  color: var(--danger);
  font-size: 13px;
  font-weight: 500;
  background: var(--danger-bg);
  padding: 12px 16px;
  border-radius: var(--radius);
  border: 1px solid rgba(239, 68, 68, 0.2);
}

/* ========================================
   Processing / Loading Screen
   ======================================== */

.upload-processing { text-align: center; }

.processing-logo { animation: pulse 2s ease-in-out infinite; margin-bottom: 24px; }

.upload-processing h2 { font-size: 24px; font-weight: 700; color: var(--text); margin-bottom: 8px; }

.upload-processing p { color: var(--text-muted); font-size: 15px; }

.loading-sub { margin-bottom: 28px; }

.loading-bar {
  width: 200px;
  height: 4px;
  background: var(--border);
  border-radius: 2px;
  margin: 0 auto;
  overflow: hidden;
}

/* Indeterminate progress: a 40%-wide fill sliding left-to-right */
.loading-bar-fill {
  width: 40%;
  height: 100%;
  background: linear-gradient(90deg, var(--primary), var(--accent));
  border-radius: 2px;
  animation: loading-slide 1.5s ease-in-out infinite;
}

@keyframes loading-slide {
  0% { transform: translateX(-100%); }
  100% { transform: translateX(350%); }
}

@keyframes pulse {
  0%, 100% { transform: scale(1); opacity: 1; }
  50% { transform: scale(1.08); opacity: 0.7; }
}

/* ========================================
   Controls Bar (Player)
   ======================================== */

.controls {
  background: var(--surface);
  border-bottom: 1px solid var(--border);
  flex-shrink: 0;
  display: flex;
  flex-direction: column;
}

.controls-main {
  height: 56px;
  display: flex;
  align-items: center;
  justify-content: space-between;
  padding: 0 20px;
  gap: 16px;
}

.controls-left { display: flex; align-items: center; gap: 14px; min-width: 0; flex: 1; }

.brand-mark { display: flex; align-items: center; gap: 10px; flex-shrink: 0; }

.brand-name {
  font-size: 15px;
  font-weight: 700;
  background: linear-gradient(135deg, #a78bfa, #06b6d4);
  -webkit-background-clip: text;
  -webkit-text-fill-color: transparent;
  background-clip: text;
  white-space: nowrap;
  letter-spacing: -0.3px;
}

.file-name {
  font-size: 13px;
  color: var(--text-muted);
  white-space: nowrap;
  overflow: hidden;
  text-overflow: ellipsis;
  max-width: 200px;
  padding-left: 14px;
  border-left: 1.5px solid var(--border);
  font-weight: 500;
}

.controls-center { display: flex; align-items: center; gap: 6px; flex-shrink: 0; }

.controls-right { display: flex; align-items: center; gap: 16px; flex: 1; justify-content: flex-end; }

/* Transport buttons */
.transport-btn {
  width: 36px;
  height: 36px;
  border-radius: 8px;
  border: none;
  background: var(--surface-2);
  color: var(--text-muted);
  cursor: pointer;
  display: flex;
  align-items: center;
  justify-content: center;
  transition: all 0.15s;
}

.transport-btn:hover { background: var(--surface-3); color: var(--text); }

/* Play button — bold and prominent */
.play-btn {
  width: 48px;
  height: 48px;
  border-radius: 50%;
  border: none;
  background: var(--primary);
  color: white;
  font-size: 18px;
  cursor: pointer;
  transition: all 0.2s;
  display: flex;
  align-items: center;
  justify-content: center;
  box-shadow: 0 0 20px var(--primary-glow);
}

.play-btn:hover {
  background: var(--primary-hover);
  box-shadow: 0 0 30px var(--primary-glow);
  transform: scale(1.05);
}

.play-btn:active { transform: scale(0.97); }

/* + New Song button */
.btn {
  background: var(--surface-2);
  color: var(--text-muted);
  border: 1.5px solid var(--border);
  border-radius: 8px;
  padding: 7px 16px;
  font-size: 12px;
  font-weight: 600;
  font-family: inherit;
  cursor: pointer;
  transition: all 0.15s;
  white-space: nowrap;
  letter-spacing: 0.2px;
}

.btn:hover { background: var(--surface-3); color: var(--text); border-color: var(--border-hover); }

.btn-new { border-color: var(--primary-dim); color: var(--primary-hover); }

.btn-new:hover {
  background: rgba(139, 92, 246, 0.1);
  border-color: var(--primary);
  color: var(--primary-hover);
}

/* Tempo control */
.tempo-control {
  display: flex;
  align-items: center;
  gap: 8px;
  background: var(--surface-2);
  padding: 6px 14px;
  border-radius: 8px;
  border: 1px solid var(--border);
}

.tempo-label {
  font-size: 11px;
  font-weight: 600;
  color: var(--text-subtle);
  text-transform: uppercase;
  letter-spacing: 0.5px;
  white-space: nowrap;
}

.tempo-value {
  font-size: 13px;
  font-weight: 600;
  color: var(--text-muted);
  min-width: 36px;
  text-align: right;
  font-variant-numeric: tabular-nums;
}

.tempo-control input[type='range'] { width: 80px; }

/* ========================================
   Timeline / Progress Bar
   ======================================== */

.timeline { display: flex; align-items: center; gap: 12px; padding: 0 20px 10px; }

.timeline-time {
  font-size: 12px;
  font-weight: 600;
  color: var(--text-muted);
  font-variant-numeric: tabular-nums;
  min-width: 36px;
}

.timeline-time:last-child { text-align: right; }

.timeline-track { flex: 1; position: relative; }

.timeline-track input[type='range'] {
  width: 100%;
  height: 6px;
  border-radius: 3px;
  -webkit-appearance: none;
  appearance: none;
  outline: none;
  cursor: pointer;
  transition: height 0.15s;
}

.timeline-track input[type='range']:hover { height: 8px; }

.timeline-track input[type='range']::-webkit-slider-thumb {
  -webkit-appearance: none;
  appearance: none;
  width: 14px;
  height: 14px;
  border-radius: 50%;
  background: var(--primary-hover);
  cursor: pointer;
  border: 2px solid white;
  box-shadow: 0 0 8px var(--primary-glow);
  transition: transform 0.1s;
}

.timeline-track input[type='range']::-webkit-slider-thumb:hover { transform: scale(1.25); }

.timeline-track input[type='range']::-moz-range-thumb {
  width: 14px;
  height: 14px;
  border-radius: 50%;
  background: var(--primary-hover);
  cursor: pointer;
  border: 2px solid white;
  box-shadow: 0 0 8px var(--primary-glow);
}

/* General range sliders (for tempo) */
input[type='range'] {
  -webkit-appearance: none;
  appearance: none;
  background: var(--border);
  height: 4px;
  border-radius: 2px;
  outline: none;
  cursor: pointer;
}

input[type='range']::-webkit-slider-thumb {
  -webkit-appearance: none;
  appearance: none;
  width: 14px;
  height: 14px;
  border-radius: 50%;
  background: var(--primary);
  cursor: pointer;
  border: none;
  transition: transform 0.1s;
}

input[type='range']::-webkit-slider-thumb:hover { transform: scale(1.2); }

input[type='range']::-moz-range-thumb {
  width: 14px;
  height: 14px;
  border-radius: 50%;
  background: var(--primary);
  cursor: pointer;
  border: none;
}

/* Loop controls */
.loop-controls { display: flex; align-items: center; gap: 4px; }

.btn-loop {
  min-width: 32px;
  text-align: center;
  font-weight: 700;
  font-size: 12px;
  padding: 6px 10px;
  border-radius: 6px;
  font-family: inherit;
  letter-spacing: 0.3px;
}

.btn-loop.active {
  background: rgba(139, 92, 246, 0.15);
  border-color: var(--primary);
  color: var(--primary-hover);
}

.btn-loop:disabled { opacity: 0.3; cursor: not-allowed; }

.loop-x { margin-left: 6px; font-size: 14px; opacity: 0.6; }

/* Canvas area */
.canvas-container { flex: 1; position: relative; overflow: hidden; width: 100%; }
app/src/main.jsx ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
// Application entry point: mounts the React app into the #root element.
import { StrictMode } from 'react'
import { createRoot } from 'react-dom/client'
import './index.css'
import App from './App.jsx'

// StrictMode double-invokes renders/effects in development to surface
// unsafe side effects; it is a no-op in production builds.
createRoot(document.getElementById('root')).render(
  <StrictMode>
    <App />
  </StrictMode>,
)
app/src/utils/colorScheme.js ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// Central color palette for the canvas renderer and UI.
// Plain hex / 8-digit-hex (last byte = alpha) strings consumed directly
// as canvas fill/stroke styles.
export const COLORS = {
  background: '#07070e',
  pianoRollBg: '#0a0a14',
  hitLine: '#ffffff33', // 8-digit hex: trailing byte is alpha
  hitLineGlow: '#ffffff18',

  // Note colors — brand-aligned
  leftHand: '#8b5cf6',
  leftHandGlow: '#a78bfa',
  rightHand: '#06b6d4',
  rightHandGlow: '#22d3ee',

  // Piano key colors
  whiteKey: '#e8e8e8',
  whiteKeyActive: '#c4b5fd',
  blackKey: '#1a1a2e',
  blackKeyActive: '#7c3aed',
  keyBorder: '#2a2a40',

  // UI
  text: '#f1f5f9',
  textMuted: '#94a3b8',
  controlsBg: '#0a0a14',
  controlsBorder: '#1e1e3a',
};

// Pitch boundary between the "left hand" and "right hand" color groups:
// notes strictly below this are treated as left-hand.
export const MIDI_SPLIT_POINT = 60; // Middle C (C4)
28
+
29
/** Body color for a falling note, split by hand at MIDI_SPLIT_POINT. */
export function noteColor(midiNumber) {
  if (midiNumber >= MIDI_SPLIT_POINT) return COLORS.rightHand;
  return COLORS.leftHand;
}
32
+
33
/** Glow color for a falling note, split by hand at MIDI_SPLIT_POINT. */
export function noteGlowColor(midiNumber) {
  if (midiNumber >= MIDI_SPLIT_POINT) return COLORS.rightHandGlow;
  return COLORS.leftHandGlow;
}
app/src/utils/generateSampleMidi.js ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { Midi } from '@tonejs/midi';
2
+
3
/**
 * Generate a sample "Twinkle Twinkle Little Star" MIDI with both hands.
 *
 * Builds a two-track @tonejs/midi object: a right-hand melody track
 * (all pitches >= 60 / middle C) and a left-hand chord track (all
 * pitches < 60), so the app's hand-split coloring has both sides.
 *
 * @returns {Midi} an in-memory @tonejs/midi Midi object
 */
export function generateSampleMidi() {
  const midi = new Midi();

  // Right hand melody (MIDI >= 60)
  const rhTrack = midi.addTrack();
  rhTrack.name = 'Right Hand';

  // C C G G A A G - F F E E D D C
  // Then: G G F F E E D - G G F F E E D
  const melody = [
    // Phrase 1
    60, 60, 67, 67, 69, 69, 67,
    // Phrase 2
    65, 65, 64, 64, 62, 62, 60,
    // Phrase 3
    67, 67, 65, 65, 64, 64, 62,
    // Phrase 4
    67, 67, 65, 65, 64, 64, 62,
    // Phrase 5 (repeat phrase 1)
    60, 60, 67, 67, 69, 69, 67,
    // Phrase 6 (repeat phrase 2)
    65, 65, 64, 64, 62, 62, 60,
  ];
  // One duration per melody entry: six 0.5-length notes and a 1-length
  // note closing each phrase (units per @tonejs/midi `time`/`duration`).
  const durations = [
    0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1,
    0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1,
    0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1,
    0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1,
    0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1,
    0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1,
  ];

  // Lay the melody end-to-end; each note sounds for 90% of its slot so
  // repeated pitches are visually/aurally separated.
  let t = 0;
  melody.forEach((note, i) => {
    rhTrack.addNote({
      midi: note,
      time: t,
      duration: durations[i] * 0.9,
      velocity: 0.8,
    });
    t += durations[i];
  });

  // Left hand accompaniment (MIDI < 60)
  const lhTrack = midi.addTrack();
  lhTrack.name = 'Left Hand';

  // Block chords with absolute start times aligned to the melody phrases.
  const chords = [
    // Phrase 1: C major
    { notes: [48, 52, 55], time: 0, dur: 2 },
    { notes: [48, 52, 55], time: 2, dur: 1 },
    // Phrase 2: F major -> C major
    { notes: [41, 45, 48], time: 3, dur: 2 },
    { notes: [48, 52, 55], time: 5, dur: 1 },
    // Phrase 3: C -> G -> Am -> F
    { notes: [48, 52], time: 7, dur: 1 },
    { notes: [43, 47], time: 8, dur: 1 },
    { notes: [45, 48], time: 9, dur: 1 },
    // Phrase 4: same
    { notes: [41, 45], time: 10, dur: 1 },
    { notes: [48, 52], time: 11, dur: 1 },
    { notes: [43, 47], time: 12, dur: 1 },
    { notes: [45, 48], time: 13, dur: 1 },
    // Phrase 5
    { notes: [48, 52, 55], time: 14, dur: 2 },
    { notes: [48, 52, 55], time: 16, dur: 1 },
    // Phrase 6
    { notes: [41, 45, 48], time: 17, dur: 2 },
    { notes: [48, 52, 55], time: 19, dur: 2 },
  ];

  chords.forEach((chord) => {
    chord.notes.forEach((note) => {
      lhTrack.addNote({
        midi: note,
        time: chord.time,
        duration: chord.dur * 0.9, // same 90% articulation gap as the melody
        velocity: 0.6, // accompaniment softer than the melody (0.8)
      });
    });
  });

  return midi;
}
app/src/utils/midiHelpers.js ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { MIDI_SPLIT_POINT } from './colorScheme';
2
+
3
// Piano range: C2 (36) to C7 (96)
export const LOWEST_NOTE = 36;
export const HIGHEST_NOTE = 96;

// Semitone offsets within an octave that land on black keys (C#=1, D#=3, ...).
const BLACK_KEY_OFFSETS = new Set([1, 3, 6, 8, 10]);

/** True when the given MIDI note number is a black key. */
export function isBlackKey(midiNumber) {
  const pitchClass = midiNumber % 12;
  return BLACK_KEY_OFFSETS.has(pitchClass);
}
12
+
13
/**
 * Build an array of key objects with pixel positions for a given canvas width.
 * Returns [{ midiNumber, x, width, isBlack }]
 */
export function buildKeyboardLayout(canvasWidth) {
  // All white keys in range, ascending; they evenly divide the canvas width.
  const whites = [];
  for (let m = LOWEST_NOTE; m <= HIGHEST_NOTE; m++) {
    if (!isBlackKey(m)) whites.push(m);
  }

  const whiteWidth = canvasWidth / whites.length;
  const blackWidth = whiteWidth * 0.6;

  // Index white keys by MIDI number with their computed pixel positions.
  const byMidi = new Map(
    whites.map((m, idx) => [
      m,
      { midiNumber: m, x: idx * whiteWidth, width: whiteWidth, isBlack: false },
    ])
  );

  // Each black key straddles the right edge of the white key just below it.
  for (let m = LOWEST_NOTE; m <= HIGHEST_NOTE; m++) {
    if (!isBlackKey(m)) continue;
    const below = byMidi.get(m - 1);
    if (below) {
      byMidi.set(m, {
        midiNumber: m,
        x: below.x + below.width - blackWidth / 2,
        width: blackWidth,
        isBlack: true,
      });
    }
  }

  // Emit in ascending MIDI order.
  const layout = [];
  for (let m = LOWEST_NOTE; m <= HIGHEST_NOTE; m++) {
    const key = byMidi.get(m);
    if (key) layout.push(key);
  }
  return layout;
}
66
+
67
/**
 * Get the x position and width for a falling note block.
 * Linear scan of the layout; see noteXPositionFast for the per-frame path.
 */
export function noteXPosition(midiNumber, keyboardLayout) {
  for (const key of keyboardLayout) {
    if (key.midiNumber === midiNumber) {
      return { x: key.x, width: key.width };
    }
  }

  // Out-of-range pitches clamp to the nearest edge key.
  const edge =
    midiNumber < LOWEST_NOTE
      ? keyboardLayout[0]
      : keyboardLayout[keyboardLayout.length - 1];
  return { x: edge.x, width: edge.width };
}
82
+
83
// Build a fast lookup map for noteXPosition (avoids .find() per note per frame)
export function buildNotePositionMap(keyboardLayout) {
  return new Map(
    keyboardLayout.map((key) => [key.midiNumber, { x: key.x, width: key.width }])
  );
}
91
+
92
/** O(1) variant of noteXPosition backed by a precomputed position map. */
export function noteXPositionFast(midiNumber, positionMap) {
  const hit = positionMap.get(midiNumber);
  if (hit !== undefined) return hit;

  // Out-of-range pitches clamp to the edges of the keyboard.
  return midiNumber < LOWEST_NOTE
    ? positionMap.get(LOWEST_NOTE)
    : positionMap.get(HIGHEST_NOTE);
}
100
+
101
/**
 * Parse a Midi object (from @tonejs/midi) into our note format.
 * Flattens all tracks, tags each note with its hand via MIDI_SPLIT_POINT,
 * and returns the notes sorted by onset plus the overall duration.
 */
export function parseMidiFile(midiObject) {
  const notes = midiObject.tracks.flatMap((track) =>
    track.notes.map((note) => ({
      midi: note.midi,
      name: note.name,
      time: note.time,
      duration: note.duration,
      velocity: note.velocity,
      hand: note.midi < MIDI_SPLIT_POINT ? 'left' : 'right',
    }))
  );

  // Chronological order by onset.
  notes.sort((a, b) => a.time - b.time);

  // Total duration = latest note-off, 0 for an empty file.
  let totalDuration = 0;
  for (const n of notes) {
    const off = n.time + n.duration;
    if (off > totalDuration) totalDuration = off;
  }

  return { notes, totalDuration };
}
130
+
131
// Per-array cache of the longest note duration. Keyed weakly so we never
// write an expando property onto the caller's array (the previous version
// stashed `notes._maxDur` on the array itself, mutating caller data and
// breaking on frozen arrays) and never retain arrays after the caller drops them.
const maxDurationCache = new WeakMap();

/**
 * Get notes visible in the current time window using binary search.
 * Notes array must be sorted by `time` (start time).
 *
 * @param {Array<{time: number, duration: number}>} notes - sorted by onset
 * @param {number} currentTime - playhead position
 * @param {number} lookAheadSeconds - window extent ahead of the playhead
 * @param {number} [maxPastSeconds=1] - window extent behind the playhead
 * @returns {Array} subset of `notes` overlapping the window
 */
export function getVisibleNotes(
  notes,
  currentTime,
  lookAheadSeconds,
  maxPastSeconds = 1
) {
  const endTime = currentTime + lookAheadSeconds;

  // The longest note duration bounds how far back a still-visible,
  // long-held note can have started. Computed once per array.
  // NOTE: like the original `_maxDur` cache, this goes stale if the
  // caller mutates the array contents in place.
  let maxDur = maxDurationCache.get(notes);
  if (maxDur === undefined) {
    maxDur = 0;
    for (const n of notes) {
      if (n.duration > maxDur) maxDur = n.duration;
    }
    maxDurationCache.set(notes, maxDur);
  }
  const searchBack = maxPastSeconds + maxDur;

  // Binary search on `time` (which IS sorted) for the earliest note that
  // could possibly still be visible.
  const earliest = currentTime - searchBack;
  let lo = 0;
  let hi = notes.length;
  while (lo < hi) {
    const mid = (lo + hi) >> 1;
    if (notes[mid].time < earliest) {
      lo = mid + 1;
    } else {
      hi = mid;
    }
  }

  // Collect notes that start before the window ends and have not fully
  // ended before the visible window begins.
  const cutoff = currentTime - maxPastSeconds;
  const result = [];
  for (let i = lo; i < notes.length && notes[i].time < endTime; i++) {
    if (notes[i].time + notes[i].duration >= cutoff) {
      result.push(notes[i]);
    }
  }
  return result;
}
app/vite.config.js ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'

// https://vite.dev/config/
// Dev-server only: forwards /api/* to the backend on port 8000 so the SPA
// and API share an origin during development. (Production serving is
// handled by the backend container instead.)
export default defineConfig({
  plugins: [react()],
  server: {
    proxy: {
      '/api': 'http://localhost:8000',
    },
  },
})
fly.toml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ app = "mroctopus"
2
+ primary_region = "lhr"
3
+
4
+ [build]
5
+
6
+ [http_service]
7
+ internal_port = 7860
8
+ force_https = true
9
+ auto_stop_machines = "stop"
10
+ auto_start_machines = true
11
+ min_machines_running = 0
12
+
13
+ [vm]
14
+ memory = "2gb"
15
+ cpu_kind = "shared"
16
+ cpus = 2
transcriber/chords.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chord detection from MIDI files using template-matching music theory.
2
+
3
+ Analyzes a MIDI file to detect chords at each note onset, producing a
4
+ time-stamped list of chord events with root, quality, and constituent notes.
5
+ Designed for the Mr. Octopus piano tutorial pipeline.
6
+ """
7
+
8
+ import json
9
+ from pathlib import Path
10
+ from collections import defaultdict
11
+
12
+ import pretty_midi
13
+ import numpy as np
14
+
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Music theory constants
18
+ # ---------------------------------------------------------------------------
19
+
20
# Pitch-class names indexed 0-11, sharps only; user-facing spelling comes
# from ENHARMONIC_DISPLAY below.
NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]

# Enharmonic display preferences: use flats for certain roots to match
# standard music notation (e.g. Bb major, not A# major).
ENHARMONIC_DISPLAY = {
    "C": "C", "C#": "Db", "D": "D", "D#": "Eb", "E": "E", "F": "F",
    "F#": "F#", "G": "G", "G#": "Ab", "A": "A", "A#": "Bb", "B": "B",
}

# Chord templates: quality name -> set of semitone intervals from root.
# Each template is a frozenset of pitch-class intervals (0 = root).
# NOTE: dict insertion order matters — match_chord iterates this dict and
# keeps the FIRST best-scoring candidate (strict `>` comparison), so
# earlier entries win exact score ties.
CHORD_TEMPLATES = {
    # Triads
    "major": frozenset([0, 4, 7]),
    "minor": frozenset([0, 3, 7]),
    "diminished": frozenset([0, 3, 6]),
    "augmented": frozenset([0, 4, 8]),

    # Suspended
    "sus2": frozenset([0, 2, 7]),
    "sus4": frozenset([0, 5, 7]),

    # Seventh chords
    "dominant 7": frozenset([0, 4, 7, 10]),
    "major 7": frozenset([0, 4, 7, 11]),
    "minor 7": frozenset([0, 3, 7, 10]),
    "diminished 7": frozenset([0, 3, 6, 9]),
    "half-dim 7": frozenset([0, 3, 6, 10]),
    "min/maj 7": frozenset([0, 3, 7, 11]),
    "augmented 7": frozenset([0, 4, 8, 10]),

    # Extended / added-tone
    "add9": frozenset([0, 2, 4, 7]),
    "minor add9": frozenset([0, 2, 3, 7]),
    "6": frozenset([0, 4, 7, 9]),
    "minor 6": frozenset([0, 3, 7, 9]),
}

# Short suffix for display (e.g. "Cm7", "Gdim", "Fsus4")
QUALITY_SUFFIX = {
    "major": "",
    "minor": "m",
    "diminished": "dim",
    "augmented": "aug",
    "sus2": "sus2",
    "sus4": "sus4",
    "dominant 7": "7",
    "major 7": "maj7",
    "minor 7": "m7",
    "diminished 7": "dim7",
    "half-dim 7": "m7b5",
    "min/maj 7": "m(maj7)",
    "augmented 7": "aug7",
    "add9": "add9",
    "minor add9": "madd9",
    "6": "6",
    "minor 6": "m6",
}

# Priority ordering for tie-breaking when multiple templates match equally.
# Lower index = preferred. Triads > sevenths > extended > suspended.
# Consumed by match_chord as a small score penalty (index * 0.05).
QUALITY_PRIORITY = [
    "major", "minor", "dominant 7", "minor 7", "major 7",
    "diminished", "augmented", "half-dim 7", "diminished 7",
    "6", "minor 6", "sus4", "sus2", "add9", "minor add9",
    "min/maj 7", "augmented 7",
]
87
+
88
+
89
+ # ---------------------------------------------------------------------------
90
+ # Frame extraction
91
+ # ---------------------------------------------------------------------------
92
+
93
def extract_note_frames(midi_data, onset_tolerance=0.05):
    """Group MIDI notes into simultaneous frames (chords / single notes).

    Notes whose onsets fall within ``onset_tolerance`` seconds of the first
    note in the current group are pooled into the same frame. Returns a list
    of dicts::

        {
            "start": float,       # earliest onset in the group
            "end": float,         # latest note-off in the group
            "pitches": [int],     # MIDI pitch numbers
            "velocities": [int],  # corresponding velocities
        }

    sorted by start time.
    """
    # Flatten all instruments (typically just one for piano) and order by onset.
    ordered = sorted(
        (note for inst in midi_data.instruments for note in inst.notes),
        key=lambda n: n.start,
    )
    if not ordered:
        return []

    frames = []
    group = [ordered[0]]
    for note in ordered[1:]:
        # Tolerance is measured against the group's FIRST onset, so a long
        # run of slightly-staggered notes does not chain into one frame.
        if note.start - group[0].start <= onset_tolerance:
            group.append(note)
        else:
            frames.append(_group_to_frame(group))
            group = [note]
    frames.append(_group_to_frame(group))

    return frames
128
+
129
+
130
+ def _group_to_frame(notes):
131
+ """Convert a group of pretty_midi Note objects into a frame dict."""
132
+ return {
133
+ "start": min(n.start for n in notes),
134
+ "end": max(n.end for n in notes),
135
+ "pitches": [n.pitch for n in notes],
136
+ "velocities": [n.velocity for n in notes],
137
+ }
138
+
139
+
140
+ # ---------------------------------------------------------------------------
141
+ # Template matching
142
+ # ---------------------------------------------------------------------------
143
+
144
def pitch_class_set(pitches):
    """Convert a list of MIDI pitches to a set of pitch classes (0-11)."""
    return {pitch % 12 for pitch in pitches}
147
+
148
+
149
def match_chord(pitches, velocities=None):
    """Identify a chord from a set of MIDI pitches.

    Uses a template-matching approach that tests every possible root (0-11)
    against every chord template. Scoring:

    1. Count how many template tones are present in the pitch-class set
       (weighted by velocity when available).
    2. Penalize extra notes not in the template.
    3. Prefer templates that explain more notes.
    4. Handle inversions: the bass note does not need to be the root.

    Ties on score keep the first candidate in (root, template) iteration
    order (strict ``>`` comparison).

    Returns a dict::

        {
            "root": int,            # pitch class 0-11
            "root_name": str,       # e.g. "C", "Db"
            "quality": str,         # e.g. "minor 7"
            "chord_name": str,      # e.g. "Cm7"
            "notes": [str],         # constituent note names
            "midi_pitches": [int],  # original MIDI pitches
        }

    or a single-note result / None if fewer than 2 distinct pitch classes.
    """
    pcs = pitch_class_set(pitches)
    if len(pcs) < 2:
        return _single_note_result(pitches) if pitches else None

    # Build a velocity weight map (pitch class -> total velocity)
    pc_weight = defaultdict(float)
    if velocities and len(velocities) == len(pitches):
        for p, v in zip(pitches, velocities):
            pc_weight[p % 12] += v
    else:
        for p in pitches:
            pc_weight[p % 12] += 80  # default velocity

    # Normalize weights so the max is 1.0
    max_w = max(pc_weight.values()) if pc_weight else 1.0
    for pc in pc_weight:
        pc_weight[pc] /= max_w

    # Determine the bass note (lowest pitch) for inversion bonus
    bass_pc = min(pitches) % 12

    # Hoisted out of the 12 x len(CHORD_TEMPLATES) loop: O(1) rank lookup
    # instead of an O(n) QUALITY_PRIORITY.index() per candidate.
    priority_rank = {q: i for i, q in enumerate(QUALITY_PRIORITY)}
    default_rank = len(QUALITY_PRIORITY)

    best_score = -999
    best_result = None

    for root in range(12):
        for quality, template in CHORD_TEMPLATES.items():
            # Transpose template to this root
            transposed = frozenset((root + interval) % 12 for interval in template)

            # Weighted count of template tones present in the input
            matched_weight = 0.0
            matched_count = 0
            for pc in transposed:
                if pc in pcs:
                    matched_weight += pc_weight.get(pc, 0.5)
                    matched_count += 1

            # A template must match at least 2 pitch classes to be viable.
            # (Checked before scoring; the original computed the full score
            # first and then discarded it.)
            if matched_count < 2:
                continue

            # Input pitch classes the template does not explain
            extra_notes = len(pcs - transposed)
            # Template tones absent from the input
            missing = len(transposed) - matched_count

            # Base score: reward matches, penalize misses and extras
            score = matched_weight * 2.0 - missing * 1.5 - extra_notes * 0.5

            # Bonus if this template perfectly covers all input notes
            if extra_notes == 0 and missing == 0:
                score += 3.0

            # Bonus if root is the bass note (root position)
            if root == bass_pc:
                score += 0.8

            # Bonus for root having high velocity
            score += pc_weight.get(root, 0) * 0.3

            # Smaller bonus for simpler chord types (triads over 7ths)
            score -= priority_rank.get(quality, default_rank) * 0.05

            if score > best_score:
                best_score = score
                root_name = ENHARMONIC_DISPLAY[NOTE_NAMES[root]]
                suffix = QUALITY_SUFFIX.get(quality, quality)
                best_result = {
                    "root": root,
                    "root_name": root_name,
                    "quality": quality,
                    "chord_name": f"{root_name}{suffix}",
                    "notes": sorted(ENHARMONIC_DISPLAY[NOTE_NAMES[pc]] for pc in transposed),
                    "midi_pitches": sorted(pitches),
                }

    # If no template matched well enough, fall back to describing the bass + interval
    if best_result is None:
        return _fallback_chord(pitches)

    return best_result
257
+
258
+
259
def _single_note_result(pitches):
    """Return a result for a single note (no chord)."""
    if not pitches:
        return None

    pitch_class = pitches[0] % 12
    label = ENHARMONIC_DISPLAY[NOTE_NAMES[pitch_class]]
    return {
        "root": pitch_class,
        "root_name": label,
        "quality": "note",
        "chord_name": label,
        "notes": [label],
        "midi_pitches": sorted(pitches),
    }
273
+
274
+
275
def _fallback_chord(pitches):
    """Produce a best-effort label for unrecognized pitch combinations."""
    pcs = pitch_class_set(pitches)
    bass = min(pitches) % 12
    bass_name = ENHARMONIC_DISPLAY[NOTE_NAMES[bass]]

    # Describe the cluster as the bass plus its interval content, e.g. "C(2,5)".
    intervals = sorted((pc - bass) % 12 for pc in pcs if pc != bass)
    interval_str = ",".join(str(i) for i in intervals)

    return {
        "root": bass,
        "root_name": bass_name,
        "quality": "unknown",
        "chord_name": f"{bass_name}({interval_str})",
        "notes": sorted(ENHARMONIC_DISPLAY[NOTE_NAMES[pc]] for pc in pcs),
        "midi_pitches": sorted(pitches),
    }
293
+
294
+
295
+ # ---------------------------------------------------------------------------
296
+ # Smoothing
297
+ # ---------------------------------------------------------------------------
298
+
299
def smooth_chords(chord_events, min_duration=0.1):
    """Remove very short chord changes and merge consecutive identical chords.

    If an intermediate event lasts less than `min_duration`, its labels are
    overwritten with the preceding event's so it is absorbed into the
    surrounding chord. Consecutive events that then share a chord name are
    merged into a single event spanning their combined time range, with
    their `midi_pitches` unioned.

    Fix over the previous version: the function now works on per-event
    copies, so the caller's event dicts are never mutated in place.

    Parameters
    ----------
    chord_events : list[dict]
        Chord event dicts with at least "start_time", "end_time",
        "chord_name", "quality", "root_note", "notes", "midi_pitches".
    min_duration : float
        Events shorter than this (seconds) are absorbed into the previous.

    Returns
    -------
    list[dict]
        New event dicts; the input list and its dicts are left unmodified.
    """
    if not chord_events:
        return chord_events

    # Work on shallow copies of each event so callers keep their data intact.
    filtered = [dict(event) for event in chord_events]

    # Pass 1: absorb transient chords (< min_duration) into the previous
    # chord by overwriting their labels; the merge pass then fuses them.
    # First and last events are never absorbed.
    for i in range(1, len(filtered) - 1):
        duration = filtered[i]["end_time"] - filtered[i]["start_time"]
        if duration < min_duration:
            filtered[i]["chord_name"] = filtered[i - 1]["chord_name"]
            filtered[i]["quality"] = filtered[i - 1]["quality"]
            filtered[i]["root_note"] = filtered[i - 1]["root_note"]
            filtered[i]["notes"] = filtered[i - 1]["notes"]

    # Pass 2: merge consecutive events with the same chord name.
    merged = [filtered[0]]
    for event in filtered[1:]:
        if event["chord_name"] == merged[-1]["chord_name"]:
            # Extend the previous event and union its MIDI pitches.
            merged[-1]["end_time"] = event["end_time"]
            pitches = set(merged[-1].get("midi_pitches", []))
            pitches.update(event.get("midi_pitches", []))
            merged[-1]["midi_pitches"] = sorted(pitches)
        else:
            merged.append(event)

    return merged
335
+
336
+
337
+ # ---------------------------------------------------------------------------
338
+ # Main detection pipeline
339
+ # ---------------------------------------------------------------------------
340
+
341
def detect_chords(midi_path, output_path=None, onset_tolerance=0.05,
                  min_chord_duration=0.1):
    """Detect chords from a MIDI file and save results as JSON.

    Parameters
    ----------
    midi_path : str or Path
        Path to the input MIDI file.
    output_path : str or Path, optional
        Path for the output JSON file. Defaults to the MIDI filename
        with "_chords.json" suffix.
    onset_tolerance : float
        Maximum time difference (seconds) to group notes into the same frame.
    min_chord_duration : float
        Minimum duration for a chord event; shorter events get smoothed away.

    Returns
    -------
    list[dict]
        Chord event dicts with keys: start_time, end_time, chord_name,
        root_note, quality, notes, midi_pitches.
    """
    midi_path = Path(midi_path)
    if output_path is None:
        output_path = midi_path.with_name(midi_path.stem + "_chords.json")
    else:
        output_path = Path(output_path)

    print(f"\nChord detection: {midi_path.name}")

    # Load MIDI and group notes into onset frames.
    midi_data = pretty_midi.PrettyMIDI(str(midi_path))
    frames = extract_note_frames(midi_data, onset_tolerance=onset_tolerance)
    print(f" Extracted {len(frames)} note frames")

    if not frames:
        _write_json([], output_path)
        return []

    # Label each frame with its best-matching chord.
    raw_events = []
    for frame in frames:
        match = match_chord(frame["pitches"], frame["velocities"])
        if match is None:
            continue
        raw_events.append({
            "start_time": round(frame["start"], 4),
            "end_time": round(frame["end"], 4),
            "chord_name": match["chord_name"],
            "root_note": match["root_name"],
            "quality": match["quality"],
            "notes": match["notes"],
            "midi_pitches": match["midi_pitches"],
        })

    print(f" Identified {len(raw_events)} raw chord events")

    # Smooth out transient chord changes.
    smoothed = smooth_chords(raw_events, min_duration=min_chord_duration)
    print(f" After smoothing: {len(smoothed)} chord events")

    # Round all times for clean output.
    for event in smoothed:
        event["start_time"] = round(event["start_time"], 4)
        event["end_time"] = round(event["end_time"], 4)

    # Summary of the distinct chords found.
    unique_chords = set(e["chord_name"] for e in smoothed)
    print(f" Unique chords: {len(unique_chords)} ({', '.join(sorted(unique_chords))})")

    _write_json(smoothed, output_path)
    print(f" Saved to {output_path}")

    return smoothed
428
+
429
+
430
+ def _write_json(data, path):
431
+ """Write chord data to a JSON file."""
432
+ output = {
433
+ "version": 1,
434
+ "description": "Chord detection output from Mr. Octopus piano tutorial pipeline",
435
+ "chord_count": len(data),
436
+ "chords": data,
437
+ }
438
+ with open(path, "w") as f:
439
+ json.dump(output, f, indent=2)
440
+
441
+
442
+ # ---------------------------------------------------------------------------
443
+ # CLI entry point
444
+ # ---------------------------------------------------------------------------
445
+
446
if __name__ == "__main__":
    import sys

    # Require at least the MIDI file argument; print usage otherwise.
    args = sys.argv[1:]
    if not args:
        print("Usage: python chords.py <midi_file> [output.json]")
        print()
        print("Analyzes a MIDI file and detects chords at each note onset.")
        print("Outputs a JSON file with timestamped chord events.")
        sys.exit(1)

    midi_file = args[0]
    out_file = args[1] if len(args) > 1 else None
    events = detect_chords(midi_file, out_file)
    print(f"\nDetected {len(events)} chord events")
transcriber/optimize.py ADDED
@@ -0,0 +1,1356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Optimize MIDI transcription by correcting onsets, cleaning artifacts, and
2
+ ensuring rhythmic accuracy against the original audio."""
3
+
4
+ import copy
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import pretty_midi
9
+ import librosa
10
+ from collections import Counter
11
+
12
+
13
+ def remove_leading_silence_notes(midi_data, y, sr):
14
+ """Remove notes that appear during silence/noise before the music starts.
15
+
16
+ Finds the first moment of real musical energy and removes any MIDI notes
17
+ before that point (typically microphone rumble / low-freq noise artifacts).
18
+ """
19
+ midi_out = copy.deepcopy(midi_data)
20
+
21
+ # Compute RMS in 50ms windows
22
+ hop = int(0.05 * sr)
23
+ rms = np.array([
24
+ np.sqrt(np.mean(y[i * hop:(i + 1) * hop] ** 2))
25
+ for i in range(len(y) // hop)
26
+ ])
27
+
28
+ if len(rms) == 0:
29
+ return midi_out, 0, 0.0
30
+
31
+ # Music starts when RMS first exceeds 10% of the peak energy
32
+ max_rms = np.max(rms)
33
+ music_start = 0.0
34
+ for i, r in enumerate(rms):
35
+ if r > max_rms * 0.1:
36
+ music_start = i * 0.05
37
+ break
38
+
39
+ if music_start < 0.1:
40
+ return midi_out, 0, music_start
41
+
42
+ removed = 0
43
+ for instrument in midi_out.instruments:
44
+ filtered = []
45
+ for note in instrument.notes:
46
+ if note.start < music_start:
47
+ removed += 1
48
+ else:
49
+ filtered.append(note)
50
+ instrument.notes = filtered
51
+
52
+ return midi_out, removed, music_start
53
+
54
+
55
+ def remove_trailing_silence_notes(midi_data, y, sr):
56
+ """Remove notes that appear during the audio fade-out/silence at the end."""
57
+ midi_out = copy.deepcopy(midi_data)
58
+
59
+ hop = int(0.05 * sr)
60
+ rms = np.array([
61
+ np.sqrt(np.mean(y[i * hop:(i + 1) * hop] ** 2))
62
+ for i in range(len(y) // hop)
63
+ ])
64
+ if len(rms) == 0:
65
+ return midi_out, 0, len(y) / sr
66
+
67
+ max_rms = np.max(rms)
68
+
69
+ # Find the last moment where RMS exceeds 5% of peak (searching backwards)
70
+ music_end = len(y) / sr
71
+ for i in range(len(rms) - 1, -1, -1):
72
+ if rms[i] > max_rms * 0.05:
73
+ music_end = (i + 1) * 0.05
74
+ break
75
+
76
+ removed = 0
77
+ for instrument in midi_out.instruments:
78
+ filtered = []
79
+ for note in instrument.notes:
80
+ if note.start > music_end:
81
+ removed += 1
82
+ else:
83
+ filtered.append(note)
84
+ instrument.notes = filtered
85
+
86
+ return midi_out, removed, music_end
87
+
88
+
89
+ def remove_low_energy_notes(midi_data, y, sr, hop_length=512):
90
+ """Remove notes whose onsets don't correspond to real audio energy.
91
+
92
+ Computes the onset strength envelope and removes notes at times
93
+ where the audio shows no significant onset energy. This catches
94
+ basic-pitch hallucinations that appear at normal pitches but have
95
+ no corresponding audio event.
96
+
97
+ Uses an adaptive threshold based on the recording's onset strength
98
+ distribution (15th percentile), so it works equally well on loud
99
+ and quiet recordings.
100
+ """
101
+ midi_out = copy.deepcopy(midi_data)
102
+
103
+ onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
104
+ onset_times = librosa.frames_to_time(
105
+ np.arange(len(onset_env)), sr=sr, hop_length=hop_length
106
+ )
107
+
108
+ removed = 0
109
+ for instrument in midi_out.instruments:
110
+ # First pass: measure strength per note
111
+ note_strengths = []
112
+ for note in instrument.notes:
113
+ frame = np.argmin(np.abs(onset_times - note.start))
114
+ lo = max(0, frame - 2)
115
+ hi = min(len(onset_env), frame + 3)
116
+ strength = float(np.max(onset_env[lo:hi]))
117
+ note_strengths.append(strength)
118
+
119
+ if not note_strengths:
120
+ continue
121
+
122
+ # Adaptive threshold: 15th percentile of note onset strengths
123
+ # This adapts to the recording's volume — quiet recordings get
124
+ # a lower threshold, loud recordings get a higher one.
125
+ # Floor at 0.5 to always catch clearly silent hallucinations.
126
+ strength_threshold = max(0.5, float(np.percentile(note_strengths, 15)))
127
+
128
+ filtered = []
129
+ for idx, note in enumerate(instrument.notes):
130
+ if note_strengths[idx] >= strength_threshold:
131
+ filtered.append(note)
132
+ else:
133
+ # Keep notes that are part of a chord with a strong onset
134
+ chord_has_energy = False
135
+ for other_idx, other in enumerate(instrument.notes):
136
+ if other is note:
137
+ continue
138
+ if abs(other.start - note.start) < 0.03 and note_strengths[other_idx] >= strength_threshold:
139
+ chord_has_energy = True
140
+ break
141
+ if chord_has_energy:
142
+ filtered.append(note)
143
+ else:
144
+ removed += 1
145
+ instrument.notes = filtered
146
+
147
+ return midi_out, removed
148
+
149
+
150
+ def remove_harmonic_ghosts(midi_data, y=None, sr=22050, hop_length=512):
151
+ """Remove notes that are harmonic doublings of louder lower notes.
152
+
153
+ Pairwise detector: for notes at harmonic intervals (7, 12, 19, 24
154
+ semitones), remove the upper note if it's clearly a harmonic ghost.
155
+
156
+ Uses CQT energy to protect strong notes: if the CQT shows the note
157
+ has strong energy (> -10dB), it's a real played note regardless of
158
+ velocity ratio. This prevents removing notes like C6 that happen to
159
+ co-occur with C5 but are genuinely played.
160
+ """
161
+ midi_out = copy.deepcopy(midi_data)
162
+ removed = 0
163
+
164
+ harmonic_intervals = {7, 12, 19, 24}
165
+
166
+ # Compute CQT for energy verification if audio provided
167
+ C_db = None
168
+ if y is not None:
169
+ N_BINS = 88 * 3
170
+ FMIN = librosa.note_to_hz('A0')
171
+ C = np.abs(librosa.cqt(
172
+ y, sr=sr, hop_length=hop_length,
173
+ fmin=FMIN, n_bins=N_BINS, bins_per_octave=36,
174
+ ))
175
+ C_db = librosa.amplitude_to_db(C, ref=np.max(C))
176
+
177
+ for instrument in midi_out.instruments:
178
+ notes = sorted(instrument.notes, key=lambda n: n.start)
179
+ to_remove = set()
180
+
181
+ for i, note in enumerate(notes):
182
+ if i in to_remove:
183
+ continue
184
+ if note.pitch < 48:
185
+ continue
186
+
187
+ # Check CQT energy — protect notes with moderate+ energy
188
+ if C_db is not None:
189
+ fund_bin = (note.pitch - 21) * 3 + 1
190
+ if 0 <= fund_bin < C_db.shape[0]:
191
+ start_frame = max(0, int(note.start * sr / hop_length))
192
+ end_frame = min(C_db.shape[1], start_frame + max(1, int(0.2 * sr / hop_length)))
193
+ lo = max(0, fund_bin - 1)
194
+ hi = min(C_db.shape[0], fund_bin + 2)
195
+ onset_db = float(np.max(C_db[lo:hi, start_frame:end_frame]))
196
+ if onset_db > -12.0:
197
+ # Real CQT energy present — keep this note
198
+ continue
199
+
200
+ for j, other in enumerate(notes):
201
+ if i == j or j in to_remove:
202
+ continue
203
+ if abs(other.start - note.start) > 0.10:
204
+ continue
205
+ diff = note.pitch - other.pitch
206
+ if diff in harmonic_intervals and diff > 0:
207
+ ratio = note.velocity / max(1, other.velocity)
208
+
209
+ if note.pitch >= 72:
210
+ # C5+: remove if noticeably quieter
211
+ if ratio < 0.75:
212
+ to_remove.add(i)
213
+ break
214
+ elif other.pitch < 48:
215
+ if ratio < 0.95:
216
+ to_remove.add(i)
217
+ break
218
+ else:
219
+ if ratio < 0.75:
220
+ to_remove.add(i)
221
+ break
222
+
223
+ instrument.notes = [n for k, n in enumerate(notes) if k not in to_remove]
224
+ removed += len(to_remove)
225
+
226
+ return midi_out, removed
227
+
228
+
229
+ def remove_phantom_notes(midi_data, max_pitch=None):
230
+ """Remove high-register notes that are likely harmonic artifacts.
231
+
232
+ Uses multiple factors to distinguish real notes from phantoms:
233
+ - Must be above the 95th percentile pitch
234
+ - Must be rare (< 3 occurrences at that exact pitch)
235
+ - Must have low velocity (< 40)
236
+ - Must be isolated (no other notes within 2 semitones and 500ms)
237
+ """
238
+ midi_out = copy.deepcopy(midi_data)
239
+ all_notes = [(n, i) for i, inst in enumerate(midi_out.instruments) for n in inst.notes]
240
+ all_pitches = [n.pitch for n, _ in all_notes]
241
+ if not all_pitches:
242
+ return midi_out, 0
243
+
244
+ if max_pitch is None:
245
+ max_pitch = int(np.percentile(all_pitches, 95))
246
+
247
+ pitch_counts = Counter(all_pitches)
248
+
249
+ # Build a time-sorted list for neighbor checking
250
+ time_sorted = sorted(all_notes, key=lambda x: x[0].start)
251
+
252
+ def is_isolated(note, all_sorted):
253
+ """Check if a note has no other notes nearby (within 100ms).
254
+
255
+ A note in a chord or musical event is not isolated, regardless
256
+ of the pitch of its neighbors. This prevents falsely removing
257
+ high notes that are part of chords with lower-pitched notes.
258
+ """
259
+ for other, _ in all_sorted:
260
+ if other is note:
261
+ continue
262
+ if other.start > note.start + 0.1:
263
+ break
264
+ if abs(other.start - note.start) < 0.1:
265
+ return False
266
+ return True
267
+
268
+ removed = 0
269
+ for instrument in midi_out.instruments:
270
+ filtered = []
271
+ for note in instrument.notes:
272
+ if note.pitch > max_pitch:
273
+ count = pitch_counts[note.pitch]
274
+ duration = note.end - note.start
275
+ # Higher velocity threshold for very high notes (above MIDI 80)
276
+ vel_thresh = 55 if note.pitch > 80 else 40
277
+ # Only remove if MULTIPLE indicators suggest it's a phantom:
278
+ # Very rare AND (low velocity OR very short OR isolated)
279
+ if count < 3 and (note.velocity < vel_thresh or duration < 0.08 or
280
+ is_isolated(note, time_sorted)):
281
+ removed += 1
282
+ continue
283
+ filtered.append(note)
284
+ instrument.notes = filtered
285
+
286
+ return midi_out, removed
287
+
288
+
289
+ def remove_spurious_onsets(midi_data, y, sr, ref_onsets, hop_length=512):
290
+ """Remove MIDI notes that form false-positive onsets not backed by audio.
291
+
292
+ Analysis shows 37 extra MIDI onsets cause the biggest F1 drag (precision=0.918).
293
+ This filter targets three categories of false positives:
294
+
295
+ 1. Chord fragments: notes that basic-pitch split from a real chord, creating
296
+ a separate onset within 60ms of a matched onset. These should have been
297
+ grouped with the chord.
298
+ 2. Isolated ghost onsets: single-note, low-strength onsets far from any
299
+ audio onset -- pure hallucinations.
300
+ 3. Short+quiet artifacts: onsets where every note is both short (<200ms)
301
+ and quiet (velocity < 50).
302
+
303
+ The filter first identifies which MIDI onsets already match audio onsets,
304
+ then only removes unmatched onsets meeting the above criteria.
305
+ """
306
+ midi_out = copy.deepcopy(midi_data)
307
+ tolerance = 0.05
308
+
309
+ onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
310
+ onset_times = librosa.frames_to_time(
311
+ np.arange(len(onset_env)), sr=sr, hop_length=hop_length
312
+ )
313
+
314
+ # Collect all notes and compute unique onsets
315
+ all_notes = sorted(
316
+ [n for inst in midi_out.instruments for n in inst.notes],
317
+ key=lambda n: n.start
318
+ )
319
+ midi_onsets = sorted(set(round(n.start, 4) for n in all_notes))
320
+ midi_onsets_arr = np.array(midi_onsets)
321
+
322
+ # Identify which MIDI onsets are already matched to audio onsets
323
+ matched_est = set()
324
+ for r in ref_onsets:
325
+ diffs = np.abs(midi_onsets_arr - r)
326
+ best = np.argmin(diffs)
327
+ if diffs[best] <= tolerance and best not in matched_est:
328
+ matched_est.add(best)
329
+
330
+ # For each unmatched onset, check removal criteria
331
+ onsets_to_remove = set()
332
+ for j, mo in enumerate(midi_onsets_arr):
333
+ if j in matched_est:
334
+ continue
335
+
336
+ # Get notes at this onset
337
+ onset_notes = [n for n in all_notes if abs(n.start - mo) < 0.03]
338
+ if not onset_notes:
339
+ continue
340
+
341
+ # Compute onset strength at this time
342
+ frame = np.argmin(np.abs(onset_times - mo))
343
+ lo = max(0, frame - 2)
344
+ hi = min(len(onset_env), frame + 3)
345
+ strength = float(np.max(onset_env[lo:hi]))
346
+
347
+ # Distance to nearest audio onset
348
+ diffs_audio = np.abs(ref_onsets - mo)
349
+ nearest_audio_ms = float(np.min(diffs_audio)) * 1000
350
+
351
+ # Check if near a matched MIDI onset (chord fragment)
352
+ near_matched = any(
353
+ abs(midi_onsets_arr[k] - mo) < 0.060
354
+ for k in matched_est
355
+ )
356
+
357
+ # Category 1: Chord fragment -- near a matched onset, but only if
358
+ # the onset has weak audio energy. Strong onsets near chords may be
359
+ # real grace notes or arpeggios.
360
+ if near_matched and strength < 2.0:
361
+ onsets_to_remove.add(j)
362
+ continue
363
+
364
+ # Category 2: Isolated ghost -- single note, low strength or far from audio
365
+ if len(onset_notes) == 1 and (strength < 1.5 or nearest_audio_ms > 100):
366
+ onsets_to_remove.add(j)
367
+ continue
368
+
369
+ # Category 3: Short+quiet artifact
370
+ if all(n.end - n.start < 0.2 and n.velocity < 50 for n in onset_notes):
371
+ onsets_to_remove.add(j)
372
+ continue
373
+
374
+ # Category 4: Low-velocity bass ghost -- single bass note (< MIDI 40),
375
+ # low velocity (< 35), far from audio onset. These are rumble artifacts
376
+ # that survive the energy filter.
377
+ if (len(onset_notes) == 1 and onset_notes[0].pitch < 40
378
+ and onset_notes[0].velocity < 35 and nearest_audio_ms > 60):
379
+ onsets_to_remove.add(j)
380
+ continue
381
+
382
+ # Category 5: Multi-note onset very far from any audio onset (> 200ms)
383
+ # with weak onset strength. These are chord-split artifacts or
384
+ # hallucinated events with no audio support.
385
+ if nearest_audio_ms > 200 and strength < 2.0:
386
+ onsets_to_remove.add(j)
387
+ continue
388
+
389
+ # Remove notes belonging to spurious onsets
390
+ times_to_remove = set(midi_onsets_arr[j] for j in onsets_to_remove)
391
+ removed = 0
392
+ for instrument in midi_out.instruments:
393
+ filtered = []
394
+ for note in instrument.notes:
395
+ note_onset = round(note.start, 4)
396
+ if any(abs(note_onset - t) < 0.03 for t in times_to_remove):
397
+ removed += 1
398
+ else:
399
+ filtered.append(note)
400
+ instrument.notes = filtered
401
+
402
+ return midi_out, removed, len(onsets_to_remove)
403
+
404
+
405
+ def remove_pitch_unconfirmed_notes(midi_data, y, sr, hop_length=512):
406
+ """Remove notes where the CQT has no energy at their fundamental pitch.
407
+
408
+ Checks the onset region (first 200ms) of each note for CQT energy,
409
+ not the full duration. This prevents CQT-extended notes from being
410
+ falsely removed due to low average energy over their extended tail.
411
+
412
+ Targets two ranges where hallucinations concentrate:
413
+ - Sub-bass (< MIDI 40): rumble artifacts
414
+ - Upper register (> MIDI 72): harmonic doublings
415
+ Core piano range (MIDI 40-72 / E2-C5) is reliable from basic-pitch.
416
+ """
417
+ midi_out = copy.deepcopy(midi_data)
418
+
419
+ N_BINS = 88 * 3
420
+ FMIN = librosa.note_to_hz('A0')
421
+ C = np.abs(librosa.cqt(
422
+ y, sr=sr, hop_length=hop_length,
423
+ fmin=FMIN, n_bins=N_BINS, bins_per_octave=36,
424
+ ))
425
+ C_db = librosa.amplitude_to_db(C, ref=np.max(C))
426
+
427
+ # Collect all notes for chord checking
428
+ all_notes = sorted(
429
+ [n for inst in midi_out.instruments for n in inst.notes],
430
+ key=lambda n: n.start
431
+ )
432
+
433
+ # Onset region: check max energy in first 200ms
434
+ onset_frames = max(1, int(0.2 * sr / hop_length))
435
+
436
+ removed = 0
437
+ for instrument in midi_out.instruments:
438
+ filtered = []
439
+ for note in instrument.notes:
440
+ # Only filter sub-bass and upper register — core range is reliable
441
+ if 40 <= note.pitch <= 72:
442
+ filtered.append(note)
443
+ continue
444
+
445
+ fund_bin = (note.pitch - 21) * 3 + 1
446
+ if fund_bin < 0 or fund_bin >= N_BINS:
447
+ filtered.append(note)
448
+ continue
449
+
450
+ start_frame = max(0, int(note.start * sr / hop_length))
451
+ check_end = min(C.shape[1], start_frame + onset_frames)
452
+ if start_frame >= check_end:
453
+ filtered.append(note)
454
+ continue
455
+
456
+ lo = max(0, fund_bin - 1)
457
+ hi = min(N_BINS, fund_bin + 2)
458
+ # Use max energy in onset region, not average over full duration
459
+ onset_db = float(np.max(C_db[lo:hi, start_frame:check_end]))
460
+
461
+ if note.pitch < 40:
462
+ thresh = -42.0
463
+ else: # > 72, upper register
464
+ thresh = -20.0
465
+
466
+ if onset_db < thresh:
467
+ # Remove if weak CQT evidence regardless of context
468
+ # Very weak = always remove; moderate weak = check isolation
469
+ if onset_db < thresh - 10:
470
+ # Extremely weak: always remove
471
+ removed += 1
472
+ continue
473
+ concurrent = sum(1 for o in all_notes
474
+ if abs(o.start - note.start) < 0.05 and o is not note)
475
+ if concurrent <= 3 or note.velocity < 55:
476
+ removed += 1
477
+ else:
478
+ filtered.append(note)
479
+ else:
480
+ filtered.append(note)
481
+ instrument.notes = filtered
482
+
483
+ return midi_out, removed
484
+
485
+
486
+ def apply_pitch_ceiling(midi_data, max_pitch=96):
487
+ """Remove notes above a hard pitch ceiling (C7 / MIDI 96).
488
+
489
+ Only truly extreme high notes are blanket-removed. Notes between C6-C7
490
+ are kept and handled by the CQT energy filter instead, since some
491
+ (like C6, D6) can be legitimate played notes.
492
+ """
493
+ midi_out = copy.deepcopy(midi_data)
494
+ removed = 0
495
+
496
+ for instrument in midi_out.instruments:
497
+ filtered = []
498
+ for note in instrument.notes:
499
+ if note.pitch >= max_pitch:
500
+ removed += 1
501
+ else:
502
+ filtered.append(note)
503
+ instrument.notes = filtered
504
+
505
+ return midi_out, removed
506
+
507
+
508
+ def limit_concurrent_notes(midi_data, max_per_hand=4, hand_split=60):
509
+ """Limit notes per chord to max_per_hand per hand.
510
+
511
+ Groups notes by onset time (within 30ms) and splits into left/right hand.
512
+ Removes excess notes — protects melody (highest RH pitch) and bass
513
+ (lowest LH pitch), then removes lowest velocity.
514
+ """
515
+ midi_out = copy.deepcopy(midi_data)
516
+ removed = 0
517
+
518
+ for instrument in midi_out.instruments:
519
+ notes = sorted(instrument.notes, key=lambda n: n.start)
520
+ if not notes:
521
+ continue
522
+
523
+ chords = []
524
+ current_chord = [0]
525
+ for i in range(1, len(notes)):
526
+ if notes[i].start - notes[current_chord[0]].start < 0.03:
527
+ current_chord.append(i)
528
+ else:
529
+ chords.append(current_chord)
530
+ current_chord = [i]
531
+ chords.append(current_chord)
532
+
533
+ to_remove = set()
534
+ for chord_indices in chords:
535
+ left = [idx for idx in chord_indices if notes[idx].pitch < hand_split]
536
+ right = [idx for idx in chord_indices if notes[idx].pitch >= hand_split]
537
+
538
+ for is_right, hand_indices in [(True, right), (False, left)]:
539
+ if len(hand_indices) <= max_per_hand:
540
+ continue
541
+
542
+ # Protect melody (highest RH) or bass (lowest LH)
543
+ if is_right:
544
+ protected = max(hand_indices, key=lambda idx: notes[idx].pitch)
545
+ else:
546
+ protected = min(hand_indices, key=lambda idx: notes[idx].pitch)
547
+
548
+ trimmable = [idx for idx in hand_indices if idx != protected]
549
+ scored = [(notes[idx].velocity, idx) for idx in trimmable]
550
+ scored.sort()
551
+
552
+ excess = len(hand_indices) - max_per_hand
553
+ for _, idx in scored[:excess]:
554
+ to_remove.add(idx)
555
+
556
+ instrument.notes = [n for k, n in enumerate(notes) if k not in to_remove]
557
+ removed += len(to_remove)
558
+
559
+ return midi_out, removed
560
+
561
+
562
+ def limit_total_concurrent(midi_data, max_per_hand=4, hand_split=60):
563
+ """Limit concurrent sounding notes to max_per_hand per hand.
564
+
565
+ Splits notes into left hand (< hand_split) and right hand (>= hand_split).
566
+ At each note onset, count concurrent notes in that hand. If > max_per_hand,
567
+ trim sustained notes — but protect the melody (highest RH pitch) and bass
568
+ (lowest LH pitch). Among the rest, trim lowest velocity first.
569
+ """
570
+ midi_out = copy.deepcopy(midi_data)
571
+ trimmed = 0
572
+
573
+ for instrument in midi_out.instruments:
574
+ notes = sorted(instrument.notes, key=lambda n: n.start)
575
+ if not notes:
576
+ continue
577
+
578
+ for i, note in enumerate(notes):
579
+ is_right = note.pitch >= hand_split
580
+
581
+ # Find all notes in the same hand currently sounding
582
+ sounding = []
583
+ for j in range(i):
584
+ if notes[j].end > note.start:
585
+ same_hand = (notes[j].pitch >= hand_split) == is_right
586
+ if same_hand:
587
+ sounding.append(j)
588
+
589
+ if len(sounding) + 1 > max_per_hand:
590
+ excess = len(sounding) + 1 - max_per_hand
591
+ # All indices including the current note
592
+ all_indices = sounding + [i]
593
+
594
+ if is_right:
595
+ # Protect highest pitch (melody)
596
+ protected = max(all_indices, key=lambda j: notes[j].pitch)
597
+ else:
598
+ # Protect lowest pitch (bass)
599
+ protected = min(all_indices, key=lambda j: notes[j].pitch)
600
+
601
+ # Among the sustained (not the new note), trim lowest velocity
602
+ # but never trim the protected note
603
+ trimmable = [j for j in sounding if j != protected]
604
+ scored = [(notes[j].velocity, j) for j in trimmable]
605
+ scored.sort() # lowest velocity trimmed first
606
+ for _, j in scored[:excess]:
607
+ notes[j].end = note.start
608
+ trimmed += 1
609
+
610
+ instrument.notes = [n for n in notes if n.end - n.start > 0.01]
611
+
612
+ return midi_out, trimmed
613
+
614
+
615
+ def extend_note_durations(midi_data, y, sr, hop_length=512, max_per_hand=4, hand_split=60):
616
+ """Extend MIDI note durations to match audio CQT energy decay.
617
+
618
+ Basic-pitch systematically underestimates note durations. This uses
619
+ the CQT spectrogram to find where the audio energy actually decays
620
+ and extends each note to match, dramatically improving spectral recall.
621
+
622
+ Concurrent-aware: won't extend a note past the point where doing so
623
+ would exceed max_per_hand concurrent notes in the same hand. This
624
+ prevents the downstream concurrent limiter from having to trim hundreds
625
+ of over-extended notes.
626
+ """
627
+ midi_out = copy.deepcopy(midi_data)
628
+
629
+ N_BINS = 88 * 3
630
+ FMIN = librosa.note_to_hz('A0')
631
+ C = np.abs(librosa.cqt(
632
+ y, sr=sr, hop_length=hop_length,
633
+ fmin=FMIN, n_bins=N_BINS, bins_per_octave=36,
634
+ ))
635
+ C_db = librosa.amplitude_to_db(C, ref=np.max(C))
636
+ C_norm = np.maximum(C_db, -80.0)
637
+ C_norm = (C_norm + 80.0) / 80.0
638
+ n_frames = C.shape[1]
639
+
640
+ # Pre-compute per-frame concurrent counts per hand (fast O(1) lookup)
641
+ right_count = np.zeros(n_frames, dtype=int)
642
+ left_count = np.zeros(n_frames, dtype=int)
643
+ for inst in midi_out.instruments:
644
+ for n in inst.notes:
645
+ sf = max(0, int(n.start * sr / hop_length))
646
+ ef = min(n_frames, int(n.end * sr / hop_length))
647
+ if n.pitch >= hand_split:
648
+ right_count[sf:ef] += 1
649
+ else:
650
+ left_count[sf:ef] += 1
651
+
652
+ extended = 0
653
+ for inst in midi_out.instruments:
654
+ # Sort notes by start time for overlap checking
655
+ notes_sorted = sorted(inst.notes, key=lambda n: (n.pitch, n.start))
656
+
657
+ for idx, note in enumerate(notes_sorted):
658
+ fund_bin = (note.pitch - 21) * 3 + 1
659
+ if fund_bin < 0 or fund_bin >= N_BINS:
660
+ continue
661
+
662
+ end_frame = min(n_frames, int(note.end * sr / hop_length))
663
+ # Max extension: 2 seconds beyond current end
664
+ max_extend = min(n_frames, end_frame + int(2.0 * sr / hop_length))
665
+
666
+ # Don't extend into the next note at the same pitch
667
+ next_start_frame = max_extend
668
+ for other in notes_sorted[idx + 1:]:
669
+ if other.pitch == note.pitch:
670
+ next_start_frame = min(next_start_frame, int(other.start * sr / hop_length) - 1)
671
+ break
672
+
673
+ is_right = note.pitch >= hand_split
674
+ hand_count = right_count if is_right else left_count
675
+
676
+ actual_end = end_frame
677
+ for f in range(end_frame, min(max_extend, next_start_frame)):
678
+ lo = max(0, fund_bin - 1)
679
+ hi = min(N_BINS, fund_bin + 2)
680
+ if np.mean(C_norm[lo:hi, f]) > 0.20:
681
+ # Check concurrent: this note isn't counted in hand_count
682
+ # beyond end_frame, so hand_count[f] >= max_per_hand means
683
+ # extending here would create max_per_hand + 1 concurrent
684
+ if hand_count[f] >= max_per_hand:
685
+ break
686
+ actual_end = f
687
+ else:
688
+ break
689
+
690
+ new_end = actual_end * hop_length / sr
691
+ if new_end > note.end + 0.05:
692
+ # Update the concurrent count array for the extended region
693
+ old_end_frame = end_frame
694
+ new_end_frame = min(n_frames, int(new_end * sr / hop_length))
695
+ if new_end_frame > old_end_frame:
696
+ hand_count[old_end_frame:new_end_frame] += 1
697
+ note.end = new_end
698
+ extended += 1
699
+
700
+ return midi_out, extended
701
+
702
+
703
def align_chords(midi_data, threshold=0.02):
    """Snap notes within a chord to the exact same onset time.

    basic-pitch's ~12ms frame resolution can make notes in the same chord
    start at slightly different times, causing a 'flammy' sound. Notes whose
    onsets fall within ``threshold`` seconds of a cluster's first note are
    all moved to the cluster's median onset (durations preserved).

    Returns:
        (midi_out, aligned) — a deep copy with aligned notes, and the
        number of notes whose onset was moved.
    """
    result = copy.deepcopy(midi_data)
    moved = 0

    for instrument in result.instruments:
        by_onset = sorted(instrument.notes, key=lambda n: n.start)
        total = len(by_onset)
        anchor = 0
        while anchor < total:
            # Gather every note within `threshold` of the anchor's onset.
            cursor = anchor + 1
            while cursor < total and by_onset[cursor].start - by_onset[anchor].start < threshold:
                cursor += 1
            cluster = by_onset[anchor:cursor]

            if len(cluster) > 1:
                target = float(np.median([n.start for n in cluster]))
                for member in cluster:
                    if member.start != target:
                        length = member.end - member.start
                        member.start = target
                        member.end = target + length
                        moved += 1

            anchor = cursor

    return result, moved
734
+
735
+
736
def quantize_to_beat_grid(midi_data, y, sr, hop_length=512, strength=0.5):
    """Quantize note onsets to a detected beat grid.

    Tracks tempo and beat positions with librosa, subdivides each beat
    interval into four 16th-note slots (extrapolating one beat past the
    last detected beat), and pulls each onset toward its nearest slot by
    ``strength`` (0..1). Durations are preserved.

    Returns:
        (midi_out, quantized, tempo) — modified copy, count of onsets that
        were moved more than 5ms, and the detected tempo in BPM.
    """
    result = copy.deepcopy(midi_data)

    tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr, hop_length=hop_length)
    # Some librosa versions return tempo as a length-1 array.
    if hasattr(tempo, '__len__'):
        tempo = float(tempo[0])
    beat_times = librosa.frames_to_time(beat_frames, sr=sr, hop_length=hop_length)

    if len(beat_times) < 2:
        print(" Could not detect beats, skipping quantization")
        return result, 0, tempo

    # 16th-note grid: four subdivisions of every beat interval, plus four
    # extrapolated slots after the final beat (reusing the last interval).
    grid_points = []
    for b0, b1 in zip(beat_times[:-1], beat_times[1:]):
        step = (b1 - b0) / 4
        grid_points.extend(b0 + k * step for k in range(4))
    step = (beat_times[-1] - beat_times[-2]) / 4
    grid_points.extend(beat_times[-1] + k * step for k in range(4))

    grid_points = np.array(grid_points)
    snapped = 0

    for instrument in result.instruments:
        for note in instrument.notes:
            offsets = grid_points - note.start
            deviation = offsets[np.argmin(np.abs(offsets))]

            # Only quantize onsets already near the grid (< 60ms off).
            if abs(deviation) < 0.06:
                length = note.end - note.start
                note.start = note.start + deviation * strength
                note.end = note.start + length
                if abs(deviation) > 0.005:
                    snapped += 1

    return result, snapped, tempo
785
+
786
+
787
def correct_onsets(midi_data, ref_onsets, min_off=0.02, max_off=0.15):
    """Correct chord onsets that are clearly misaligned with audio onsets.

    Groups notes (across all instruments) into chords by onset proximity
    (< 30ms from the group's first note), then for each chord checks if
    there's a closer audio onset. Only corrects if the gap is between
    min_off and max_off AND no adjacent chord is a better match for that
    audio onset.

    Args:
        midi_data: pretty_midi.PrettyMIDI-like object; not mutated (a deep
            copy is returned).
        ref_onsets: 1-D np.ndarray of detected audio onset times (seconds).
        min_off: minimum misalignment (s) worth correcting.
        max_off: maximum misalignment (s) still considered the same event.

    Returns:
        (midi_out, corrections, total_shift, n_chords, initial_f1, final_f1)
        where the F1 values are onset_f1 scores before/after correction.
    """
    midi_out = copy.deepcopy(midi_data)

    # Flatten to (note, instrument_index) pairs sorted by onset. The
    # instrument index is carried along but currently unused below.
    all_notes = sorted(
        [(n, inst_idx) for inst_idx, inst in enumerate(midi_out.instruments)
         for n in inst.notes],
        key=lambda x: x[0].start
    )

    # Cluster into chords: each note within 30ms of its group's FIRST note.
    chord_groups = []
    if all_notes:
        current_group = [all_notes[0]]
        for item in all_notes[1:]:
            if item[0].start - current_group[0][0].start < 0.03:
                current_group.append(item)
            else:
                chord_groups.append(current_group)
                current_group = [item]
        chord_groups.append(current_group)

    # Snapshot of pre-correction onsets: the adjacency checks below
    # deliberately compare against these original positions.
    chord_onsets = np.array([g[0][0].start for g in chord_groups])
    corrections = 0
    total_shift = 0.0

    for group_idx, group in enumerate(chord_groups):
        chord_onset = chord_onsets[group_idx]
        diffs = ref_onsets - chord_onset
        abs_diffs = np.abs(diffs)
        nearest_idx = np.argmin(abs_diffs)
        nearest_diff = diffs[nearest_idx]
        abs_diff = abs_diffs[nearest_idx]

        if min_off < abs_diff < max_off:
            # Verify no adjacent chord is a better match
            if group_idx > 0:
                prev_onset = chord_onsets[group_idx - 1]
                if abs(ref_onsets[nearest_idx] - prev_onset) < abs_diff:
                    continue
            if group_idx < len(chord_onsets) - 1:
                next_onset = chord_onsets[group_idx + 1]
                if abs(ref_onsets[nearest_idx] - next_onset) < abs_diff:
                    continue

            # Shift the whole chord onto the audio onset, keeping durations.
            for note, inst_idx in group:
                duration = note.end - note.start
                note.start = max(0, note.start + nearest_diff)
                note.end = note.start + duration

            corrections += 1
            total_shift += abs(nearest_diff)

    # F1 before (snapshot) vs. after (re-read from the mutated notes).
    initial_f1 = onset_f1(ref_onsets, chord_onsets)
    corrected_onsets = np.array([g[0][0].start for g in chord_groups])
    final_f1 = onset_f1(ref_onsets, corrected_onsets)

    return midi_out, corrections, total_shift, len(chord_groups), initial_f1, final_f1
849
+
850
+
851
def apply_global_offset(midi_data, ref_onsets):
    """Measure and remove a systematic timing bias against audio onsets.

    Every MIDI onset within 100ms of some audio onset contributes its
    signed gap to the closest one; the median of those gaps is the bias.
    All notes are shifted by that amount unless it is below 5ms.

    Returns:
        (midi_out, offset) — modified deep copy and the applied offset in
        seconds (0.0 when no match or the bias was negligible).
    """
    result = copy.deepcopy(midi_data)
    midi_onsets = sorted({note.start for inst in result.instruments for note in inst.notes})

    gaps = []
    for onset in midi_onsets:
        dists = np.abs(ref_onsets - onset)
        best = np.argmin(dists)
        if dists[best] < 0.10:
            gaps.append(ref_onsets[best] - onset)  # positive = MIDI is early, negative = late

    if not gaps:
        return result, 0.0

    bias = float(np.median(gaps))

    # Only apply if the offset is meaningful (> 5ms)
    if abs(bias) < 0.005:
        return result, 0.0

    for inst in result.instruments:
        for note in inst.notes:
            length = note.end - note.start
            note.start = max(0, note.start + bias)
            note.end = note.start + length

    return result, bias
883
+
884
+
885
def fix_note_overlap(midi_data, hand_split=60, min_duration=0.10):
    """Trim overlapping right-hand notes so each releases cleanly.

    For each right-hand note (pitch >= hand_split), looks ahead at up to
    seven later notes; on the first later onset it overlaps by more than
    50ms, the note is cut 10ms before that onset, but never below 70% of
    its original length or below min_duration. A second pass then enforces
    min_duration on ALL notes.

    Returns:
        (midi_out, trimmed, enforced) — modified deep copy, number of
        overlaps trimmed, and number of minimum durations enforced.
    """
    result = copy.deepcopy(midi_data)
    trimmed = 0

    for instrument in result.instruments:
        right_hand = sorted(
            (n for n in instrument.notes if n.pitch >= hand_split),
            key=lambda n: (n.start, n.pitch),
        )

        for pos, current in enumerate(right_hand):
            for upcoming in right_hand[pos + 1:pos + 8]:
                if upcoming.start <= current.start:
                    continue

                if current.end - upcoming.start > 0.05:  # >50ms overlap
                    span = current.end - current.start
                    release = upcoming.start - 0.01
                    # Never shorten more than 30% of the original duration.
                    floor = current.start + span * 0.7
                    current.end = max(release, floor)
                    if current.end - current.start < min_duration:
                        current.end = current.start + min_duration
                    trimmed += 1
                    break

    # Enforce minimum duration on ALL notes (catches any collapsed durations)
    enforced = 0
    for instrument in result.instruments:
        for note in instrument.notes:
            if note.end - note.start < min_duration:
                note.end = note.start + min_duration
                enforced += 1

    return result, trimmed, enforced
926
+
927
+
928
def recover_missing_notes(midi_data, y, sr, hop_length=512, snap_onsets=None):
    """Recover strong notes the transcriber missed using CQT analysis.

    Scans the audio CQT for pitch energy that isn't represented in the MIDI.
    When a pitch has strong, sustained energy but no corresponding MIDI note,
    synthesize one. Targets upper register (>= C4) where basic-pitch can
    under-detect, especially when harmonics cause false removal.

    If snap_onsets is provided, recovered notes are snapped to the nearest
    existing onset for rhythmic alignment.

    Should be run AFTER all removal filters so the coverage map reflects
    what actually survived.

    Args:
        midi_data: pretty_midi.PrettyMIDI object; not mutated (a deep copy
            is returned). Recovered notes are appended to instruments[0],
            which is assumed to exist.
        y: mono audio signal.
        sr: sample rate of ``y``.
        hop_length: CQT hop length in samples.
        snap_onsets: optional sequence of existing onset times (seconds).

    Returns:
        (midi_out, recovered) — the modified copy and how many notes were
        synthesized.
    """
    midi_out = copy.deepcopy(midi_data)

    # 3 CQT bins per semitone over the 88 piano keys, starting at A0.
    N_BINS = 88 * 3
    FMIN = librosa.note_to_hz('A0')
    C = np.abs(librosa.cqt(
        y, sr=sr, hop_length=hop_length,
        fmin=FMIN, n_bins=N_BINS, bins_per_octave=36,
    ))
    C_db = librosa.amplitude_to_db(C, ref=np.max(C))

    times = librosa.frames_to_time(np.arange(C.shape[1]), sr=sr, hop_length=hop_length)

    # Build a set of existing note coverage: (pitch, frame) pairs
    existing = set()
    for inst in midi_out.instruments:
        for note in inst.notes:
            start_frame = max(0, int(note.start * sr / hop_length))
            end_frame = min(C.shape[1], int(note.end * sr / hop_length))
            for f in range(start_frame, end_frame):
                existing.add((note.pitch, f))

    # Scan C4 (60) to B6 (95) for uncovered energy
    recovered = 0
    min_energy = -10.0  # dB threshold — only recover notes with strong CQT energy
    min_duration_s = 0.05  # ~50ms minimum
    gap_tolerance = 3  # allow 3-frame dips without breaking a note

    for midi_pitch in range(60, 96):
        # Center bin of this pitch's 3-bin semitone group (A0 = pitch 21).
        fund_bin = (midi_pitch - 21) * 3 + 1
        if fund_bin < 0 or fund_bin >= N_BINS:
            continue

        # Harmonic check: skip if an octave-below note is much louder
        # (this note is likely a harmonic, not a real played note).
        # NOTE: compares whole-clip peak energies, not per-frame ones.
        lower_pitch = midi_pitch - 12
        if lower_pitch >= 21:
            lower_bin = (lower_pitch - 21) * 3 + 1
            if 0 <= lower_bin < N_BINS:
                lower_lo = max(0, lower_bin - 1)
                lower_hi = min(N_BINS, lower_bin + 2)
                upper_energy = float(np.max(C_db[max(0, fund_bin - 1):min(N_BINS, fund_bin + 2), :]))
                lower_energy = float(np.max(C_db[lower_lo:lower_hi, :]))
                if lower_energy - upper_energy > 12:
                    # Octave below is 12+ dB louder — likely a harmonic
                    continue

        lo = max(0, fund_bin - 1)
        hi = min(N_BINS, fund_bin + 2)

        # Get energy and coverage for this pitch
        pitch_energy = np.max(C_db[lo:hi, :], axis=0)

        # Find uncovered regions with strong energy
        strong_uncovered = np.array([
            pitch_energy[f] >= min_energy and (midi_pitch, f) not in existing
            for f in range(len(pitch_energy))
        ])

        # Close small gaps (morphological closing): a dipped frame joins its
        # neighbors if it still has near-threshold energy.
        for f in range(1, len(strong_uncovered) - 1):
            if not strong_uncovered[f] and pitch_energy[f] >= min_energy - 5:
                before = any(strong_uncovered[max(0, f - gap_tolerance):f])
                after = any(strong_uncovered[f + 1:min(len(strong_uncovered), f + gap_tolerance + 1)])
                if before and after:
                    strong_uncovered[f] = True

        # Extract contiguous True runs as candidate (start, end) frame regions.
        regions = []
        in_region = False
        start_f = 0
        for f in range(len(strong_uncovered)):
            if strong_uncovered[f] and not in_region:
                start_f = f
                in_region = True
            elif not strong_uncovered[f] and in_region:
                regions.append((start_f, f))
                in_region = False
        if in_region:
            regions.append((start_f, len(strong_uncovered)))

        for start_f, end_f in regions:
            t_start = times[start_f]
            t_end = times[min(end_f, len(times) - 1)]
            if t_end - t_start < min_duration_s:
                continue

            # Map average dB energy to a modest velocity in [35, 75].
            avg_energy = float(np.mean(pitch_energy[start_f:end_f]))
            velocity = min(75, max(35, int(55 + avg_energy * 1.5)))

            # Snap to nearest existing onset for rhythmic alignment
            note_start = t_start
            note_end = t_end
            if snap_onsets is not None and len(snap_onsets) > 0:
                snap_arr = np.array(snap_onsets)
                diffs = np.abs(snap_arr - t_start)
                nearest_idx = np.argmin(diffs)
                if diffs[nearest_idx] < 0.06:
                    dur = t_end - t_start
                    note_start = snap_arr[nearest_idx]
                    note_end = note_start + dur

            new_note = pretty_midi.Note(
                velocity=velocity,
                pitch=midi_pitch,
                start=note_start,
                end=note_end,
            )
            midi_out.instruments[0].notes.append(new_note)
            recovered += 1

    return midi_out, recovered
1053
+
1054
+
1055
def optimize(original_audio_path, midi_path, output_path=None):
    """Full optimization pipeline: clean up a transcribed piano MIDI file.

    Runs, in order: silence trimming, hallucination/ghost removal, chord
    alignment, beat-grid quantization, multi-pass onset correction against
    detected audio onsets, overlap/duration repair, CQT-based duration
    extension and note recovery, and playability limits — then writes the
    result and optional spectral / chord reports to disk.

    Args:
        original_audio_path: path to the source audio (anything librosa
            can load).
        midi_path: path to the MIDI transcription to optimize.
        output_path: where to write the optimized MIDI; defaults to
            overwriting midi_path.

    Returns:
        The optimized pretty_midi.PrettyMIDI object (also written to
        output_path).
    """
    if output_path is None:
        output_path = midi_path

    sr = 22050
    hop_length = 512

    # Load audio and detect onsets
    print(f"Analyzing audio: {original_audio_path}")
    y, _ = librosa.load(original_audio_path, sr=sr, mono=True)
    audio_duration = len(y) / sr

    onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
    # Use backtrack=False: basic-pitch onsets align with energy peaks, not
    # the earlier rise points that backtrack finds (~50ms earlier).
    # Use delta=0.04 for higher sensitivity — detects ~15% more onsets,
    # reducing unmatched MIDI onsets from 116 to 80.
    ref_onset_frames = librosa.onset.onset_detect(
        onset_envelope=onset_env, sr=sr, hop_length=hop_length,
        backtrack=False, delta=0.04
    )
    ref_onsets = librosa.frames_to_time(ref_onset_frames, sr=sr, hop_length=hop_length)
    print(f" {audio_duration:.1f}s, {len(ref_onsets)} audio onsets")

    # Load MIDI
    midi_data = pretty_midi.PrettyMIDI(str(midi_path))
    total_notes = sum(len(inst.notes) for inst in midi_data.instruments)
    print(f" {total_notes} MIDI notes")

    # Step 0: Remove notes in leading silence (mic rumble artifacts)
    print("\nStep 0: Removing notes in leading silence...")
    midi_data, silence_removed, music_start = remove_leading_silence_notes(midi_data, y, sr)
    if silence_removed:
        print(f" Music starts at {music_start:.2f}s, removed {silence_removed} noise notes")
    else:
        print(f" No leading silence detected")

    # Step 0b: Remove notes in trailing silence
    print("\nStep 0b: Removing notes in trailing silence...")
    midi_data, trail_removed, music_end = remove_trailing_silence_notes(midi_data, y, sr)
    if trail_removed:
        print(f" Music ends at {music_end:.2f}s, removed {trail_removed} trailing notes")
    else:
        print(f" No trailing silence notes detected")

    # Step 0c: Remove low-energy hallucinations
    print("\nStep 0c: Removing low-energy hallucinations...")
    midi_data, energy_removed = remove_low_energy_notes(midi_data, y, sr, hop_length)
    print(f" Removed {energy_removed} notes with no audio onset energy")

    # Step 0d: Remove harmonic ghost notes (CQT-aware)
    print("\nStep 0d: Removing harmonic ghost notes...")
    midi_data, ghosts_removed = remove_harmonic_ghosts(midi_data, y, sr, hop_length)
    print(f" Removed {ghosts_removed} octave-harmonic ghosts")

    # Step 1: Remove phantom high notes (conservative)
    print("\nStep 1: Removing phantom high notes...")
    midi_data, phantoms_removed = remove_phantom_notes(midi_data)
    print(f" Removed {phantoms_removed} phantom notes")

    # Step 1b: Hard pitch ceiling at C7 (MIDI 96) — extreme highs only
    print("\nStep 1b: Applying pitch ceiling (C7 / MIDI 96)...")
    midi_data, ceiling_removed = apply_pitch_ceiling(midi_data, max_pitch=96)
    print(f" Removed {ceiling_removed} notes above C7")

    # Step 2: Align chord notes to single onset
    print("\nStep 2: Aligning chord notes...")
    midi_data, chords_aligned = align_chords(midi_data)
    print(f" Aligned {chords_aligned} notes within chords")

    # Step 3: Full beat-grid quantization
    print("\nStep 3: Quantizing to beat grid...")
    midi_data, notes_quantized, detected_tempo = quantize_to_beat_grid(
        midi_data, y, sr, hop_length, strength=1.0
    )
    print(f" Detected tempo: {detected_tempo:.0f} BPM")
    print(f" Quantized {notes_quantized} notes (full snap)")

    # Step 4: Targeted onset correction against audio
    print("\nStep 4: Correcting onsets against audio...")
    midi_data, corrections, total_shift, n_chords, pre_f1, post_f1 = \
        correct_onsets(midi_data, ref_onsets)
    avg_shift = (total_shift / corrections * 1000) if corrections > 0 else 0
    print(f" Corrected {corrections}/{n_chords} (avg {avg_shift:.0f}ms)")
    print(f" Onset F1: {pre_f1:.4f} -> {post_f1:.4f}")

    # Step 5: Tight second correction pass (10-60ms window)
    print("\nStep 5: Fine-tuning onsets (tight pass)...")
    midi_data, corrections2, total_shift2, n_chords2, _, post_f1_2 = \
        correct_onsets(midi_data, ref_onsets, min_off=0.01, max_off=0.06)
    avg_shift2 = (total_shift2 / corrections2 * 1000) if corrections2 > 0 else 0
    print(f" Fine-tuned {corrections2}/{n_chords2} (avg {avg_shift2:.0f}ms)")
    print(f" Onset F1: {post_f1:.4f} -> {post_f1_2:.4f}")

    # Step 6: Micro-correction pass (5-25ms window)
    print("\nStep 6: Micro-correcting onsets...")
    midi_data, corrections3, total_shift3, n_chords3, _, post_f1_3 = \
        correct_onsets(midi_data, ref_onsets, min_off=0.005, max_off=0.025)
    avg_shift3 = (total_shift3 / corrections3 * 1000) if corrections3 > 0 else 0
    print(f" Micro-corrected {corrections3}/{n_chords3} (avg {avg_shift3:.0f}ms)")
    print(f" Onset F1: {post_f1_2:.4f} -> {post_f1_3:.4f}")

    # Step 6b: Remove spurious false-positive onsets
    print("\nStep 6b: Removing spurious onsets (false positive cleanup)...")
    midi_data, spurious_notes, spurious_onsets = remove_spurious_onsets(
        midi_data, y, sr, ref_onsets, hop_length
    )
    print(f" Removed {spurious_notes} notes across {spurious_onsets} spurious onsets")

    # Step 6c: Wide onset recovery pass (50-120ms window) to rescue false negatives
    print("\nStep 6c: Wide onset recovery (rescuing false negatives)...")
    midi_data, corrections_wide, total_shift_wide, n_chords_wide, _, post_f1_wide = \
        correct_onsets(midi_data, ref_onsets, min_off=0.04, max_off=0.12)
    avg_shift_wide = (total_shift_wide / corrections_wide * 1000) if corrections_wide > 0 else 0
    print(f" Recovered {corrections_wide}/{n_chords_wide} (avg {avg_shift_wide:.0f}ms)")
    print(f" Onset F1: {post_f1_3:.4f} -> {post_f1_wide:.4f}")

    # Step 7: Global offset correction
    print("\nStep 7: Correcting systematic offset...")
    midi_data, offset = apply_global_offset(midi_data, ref_onsets)
    print(f" Applied {offset*1000:+.1f}ms global offset")

    # Step 8: Fix overlaps and enforce min duration (LAST — after all position changes)
    print("\nStep 8: Fixing overlaps and enforcing min duration...")
    midi_data, notes_trimmed, durations_enforced = fix_note_overlap(midi_data)
    print(f" Trimmed {notes_trimmed} overlapping notes")
    print(f" Enforced min duration on {durations_enforced} notes")

    # Step 8b: CQT-based duration extension
    print("\nStep 8b: Extending note durations to match audio decay...")
    midi_data, notes_extended = extend_note_durations(midi_data, y, sr, hop_length)
    print(f" Extended {notes_extended} notes to match audio CQT decay")

    # Step 8c: Re-enforce minimum duration after CQT extension
    min_dur_enforced_2 = 0
    for instrument in midi_data.instruments:
        for note in instrument.notes:
            if note.end - note.start < 0.10:
                note.end = note.start + 0.10
                min_dur_enforced_2 += 1
    if min_dur_enforced_2:
        print(f"\nStep 8c: Re-enforced min duration on {min_dur_enforced_2} notes after CQT extension")

    # Step 8d: CQT pitch-specific energy filter (remove bass hallucinations)
    print("\nStep 8d: Removing pitch-unconfirmed bass notes...")
    midi_data, cqt_removed = remove_pitch_unconfirmed_notes(midi_data, y, sr, hop_length)
    print(f" Removed {cqt_removed} notes with no CQT energy at their pitch")

    # Step 8e: Recover missing notes from CQT energy
    # Runs late so the coverage map reflects what actually survived all filters.
    # Recovered notes won't be touched by phantom/spurious/pitch filters.
    print("\nStep 8e: Recovering missing notes from CQT analysis...")
    # Collect existing onset times to snap recovered notes to
    existing_onsets = sorted(set(
        round(n.start, 4) for inst in midi_data.instruments for n in inst.notes
    ))
    midi_data, notes_recovered = recover_missing_notes(
        midi_data, y, sr, hop_length, snap_onsets=existing_onsets
    )
    print(f" Recovered {notes_recovered} notes from CQT energy")

    # Step 8f: Playability filter — limit per-onset chord size (4 per hand)
    print("\nStep 8f: Playability filter (max 4 notes per hand per chord)...")
    midi_data, playability_removed = limit_concurrent_notes(midi_data, max_per_hand=4)
    print(f" Removed {playability_removed} excess chord notes")

    # Step 8g: Limit total concurrent sounding notes (4 per hand)
    print("\nStep 8g: Concurrent sounding limit (max 4 per hand)...")
    midi_data, sustain_trimmed = limit_total_concurrent(midi_data, max_per_hand=4)
    print(f" Trimmed {sustain_trimmed} sustained notes to reduce pileup")

    # Final metrics
    final_onsets = []
    for inst in midi_data.instruments:
        for n in inst.notes:
            final_onsets.append(n.start)
    final_onsets = np.unique(np.round(np.sort(final_onsets), 4))
    final_f1 = onset_f1(ref_onsets, final_onsets)
    final_notes = sum(len(inst.notes) for inst in midi_data.instruments)

    # Duration sanity check
    all_durs = [n.end - n.start for inst in midi_data.instruments for n in inst.notes]
    min_dur = min(all_durs) * 1000 if all_durs else 0

    print(f"\nDone:")
    print(f" Phantoms removed: {phantoms_removed}")
    print(f" Pitch ceiling removed: {ceiling_removed}")
    print(f" Playability filter: {playability_removed} chord / {sustain_trimmed} sustain")
    print(f" Chords aligned: {chords_aligned}")
    print(f" Notes quantized: {notes_quantized} ({detected_tempo:.0f} BPM)")
    print(f" Onsets corrected: {corrections}/{n_chords}")
    print(f" Spurious onsets removed: {spurious_onsets} ({spurious_notes} notes)")
    print(f" FN recovery corrections: {corrections_wide}")
    print(f" Global offset: {offset*1000:+.1f}ms")
    print(f" Overlaps trimmed: {notes_trimmed}")
    print(f" Min durations enforced: {durations_enforced}")
    print(f" Notes extended (CQT decay): {notes_extended}")
    # Playability check: max concurrent notes per hand
    # (O(n^2) scan over final notes — diagnostic only, not a filter)
    all_final = sorted(
        [n for inst in midi_data.instruments for n in inst.notes],
        key=lambda n: n.start
    )
    max_left = 0
    max_right = 0
    for i, note in enumerate(all_final):
        is_right = note.pitch >= 60
        concurrent = sum(1 for o in all_final
                         if o.start <= note.start < o.end
                         and (o.pitch >= 60) == is_right)
        if is_right:
            max_right = max(max_right, concurrent)
        else:
            max_left = max(max_left, concurrent)

    print(f" Final onset F1: {final_f1:.4f}")
    print(f" Min note duration: {min_dur:.0f}ms")
    print(f" Max concurrent: L={max_left} R={max_right}")
    print(f" Notes: {total_notes} -> {final_notes}")

    # Final step: shift all notes so music starts at t=0
    # (must be AFTER all audio-aligned processing like onset detection, CQT filters)
    if music_start > 0.1:
        print(f"\nShifting all notes by -{music_start:.2f}s so music starts at t=0...")
        for instrument in midi_data.instruments:
            for note in instrument.notes:
                note.start = max(0, note.start - music_start)
                note.end = max(note.start + 0.01, note.end - music_start)

    midi_data.write(str(output_path))
    print(f" Written to {output_path}")

    # Step 9: Spectral fidelity analysis (CQT comparison)
    print("\nStep 9: Spectral fidelity analysis (CQT comparison)...")
    try:
        from spectral import spectral_fidelity
        spec_results = spectral_fidelity(y, sr, midi_data, hop_length)
        print(f" Spectral F1: {spec_results['spectral_f1']:.4f}")
        print(f" Spectral Precision: {spec_results['spectral_precision']:.4f}")
        print(f" Spectral Recall: {spec_results['spectral_recall']:.4f}")
        print(f" Spectral Similarity: {spec_results['spectral_similarity']:.4f}")

        # Save spectral report alongside MIDI
        import json
        report_path = str(output_path).replace('.mid', '_spectral.json')
        Path(report_path).write_text(json.dumps(spec_results, indent=2))
        print(f" Report saved to {report_path}")
    except Exception as e:
        # Best-effort: a failed report never blocks the pipeline.
        print(f" Spectral analysis failed: {e}")

    # Step 10: Chord detection
    print("\nStep 10: Detecting chords...")
    try:
        from chords import detect_chords
        chords_json_path = str(Path(output_path).with_name(
            Path(output_path).stem + "_chords.json"
        ))
        chord_events = detect_chords(str(output_path), chords_json_path)
        print(f" Detected {len(chord_events)} chord regions")
    except Exception as e:
        # Best-effort: chord data is optional for downstream consumers.
        print(f" Chord detection failed: {e}")
        chord_events = []

    return midi_data
1319
+
1320
+
1321
def onset_f1(ref_onsets, est_onsets, tolerance=0.05):
    """Compute onset detection F1 score.

    Greedy one-to-one matching: each reference onset claims its nearest
    estimated onset when within ``tolerance`` seconds and not yet taken.

    Returns:
        F1 in [0, 1]; 1.0 when both sets are empty, 0.0 when only one is.
    """
    n_ref = len(ref_onsets)
    n_est = len(est_onsets)
    if n_ref == 0 and n_est == 0:
        return 1.0
    if n_ref == 0 or n_est == 0:
        return 0.0

    claimed = set()
    hits = 0
    for ref_time in ref_onsets:
        gaps = np.abs(est_onsets - ref_time)
        candidate = np.argmin(gaps)
        if gaps[candidate] <= tolerance and candidate not in claimed:
            claimed.add(candidate)
            hits += 1

    precision = hits / n_est
    recall = hits / n_ref

    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)
1344
+
1345
+
1346
if __name__ == "__main__":
    import sys

    # CLI entry point: optimize.py <original_audio> <midi_file> [output.mid]
    args = sys.argv[1:]
    if len(args) < 2:
        print("Usage: python optimize.py <original_audio> <midi_file> [output.mid]")
        sys.exit(1)

    audio_path, midi_path = args[0], args[1]
    out_path = args[2] if len(args) > 2 else None
    optimize(audio_path, midi_path, out_path)
transcriber/spectral.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Spectral fidelity comparison: CQT of original audio vs MIDI piano-roll.
2
+
3
+ Compares a CQT spectrogram of the original audio against an ideal
4
+ spectrogram synthesized from the MIDI note positions, including harmonic
5
+ modeling. This gives a holistic fidelity metric that captures pitch,
6
+ timing, duration, and velocity accuracy simultaneously.
7
+
8
+ Piano notes produce rich harmonics (2nd, 3rd, 4th... partials). The MIDI
9
+ pianoroll models these harmonics so the comparison is fair against the
10
+ real audio CQT which contains them naturally.
11
+ """
12
+
13
+ import json
14
+ import numpy as np
15
+ import librosa
16
+ import pretty_midi
17
+ from pathlib import Path
18
+
19
+
20
# CQT parameters — 88 piano keys, 3 bins per semitone for smoothness.
FMIN = librosa.note_to_hz('A0')  # 27.5 Hz — lowest piano key
N_BINS = 88 * 3                  # 264 bins covering A0–C8
BINS_PER_OCTAVE = 12 * 3         # 36 (3 bins per semitone)

# Piano harmonic model: partial number -> semitones above fundamental, dB attenuation.
# Based on typical grand piano harmonic spectrum. The semitone offsets are the
# equal-tempered positions of the harmonic frequency ratios (e.g. ratio 3 is
# 19.02 semitones, an octave plus a just fifth); attenuations are rough
# per-partial level drops used when synthesizing the ideal MIDI spectrogram.
HARMONICS = [
    # (freq_ratio, semitone_offset, dB_attenuation)
    (2, 12.0, -6),    # 2nd partial: octave above
    (3, 19.02, -12),  # 3rd partial: octave + fifth
    (4, 24.0, -16),   # 4th partial: 2 octaves
    (5, 27.86, -20),  # 5th partial: 2 octaves + major 3rd
    (6, 31.02, -22),  # 6th partial: 2 octaves + fifth
    (7, 33.69, -26),  # 7th partial: ~2 octaves + minor 7th
    (8, 36.0, -28),   # 8th partial: 3 octaves
]
37
+
38
+
39
def audio_cqt(y, sr, hop_length=512):
    """Compute magnitude CQT of an audio signal in dB.

    Spans the 88 piano keys from A0 upward at 3 bins per semitone,
    referenced so the loudest bin is 0 dB, and floored at -80 dB.
    """
    magnitude = np.abs(librosa.cqt(
        y,
        sr=sr,
        hop_length=hop_length,
        fmin=FMIN,
        n_bins=N_BINS,
        bins_per_octave=BINS_PER_OCTAVE,
    ))
    db = librosa.amplitude_to_db(magnitude, ref=np.max(magnitude))
    return np.maximum(db, -80.0)
48
+
49
+
50
def midi_to_pianoroll_cqt(midi_data, duration, sr=22050, hop_length=512):
    """Build a harmonic-aware CQT-like spectrogram from MIDI notes.

    For each MIDI note, places energy at the fundamental AND its
    harmonics (partials 2-8) with appropriate attenuation, matching
    how a real piano sounds in the CQT domain. Returns an (N_BINS,
    n_frames) array of dB values floored at -80.
    """
    n_frames = int(np.ceil(duration * sr / hop_length))
    roll = np.full((N_BINS, n_frames), -80.0)

    all_notes = (note for inst in midi_data.instruments for note in inst.notes)
    for note in all_notes:
        # Map MIDI pitch to CQT bin (A0 = MIDI 21, 3 bins/semitone).
        # NOTE(review): the +1 biases the center 1/3 semitone above the
        # nominal bin; presumably intentional so the ±1-bin spread covers
        # the semitone band — confirm against audio_cqt alignment.
        fund_bin = (note.pitch - 21) * 3 + 1
        if not (0 <= fund_bin < N_BINS):
            continue

        first = max(0, int(note.start * sr / hop_length))
        last = min(n_frames, int(note.end * sr / hop_length))
        if first >= last:
            continue  # note too short to occupy even one frame

        # Velocity 0..127 maps linearly onto -30..0 dB.
        level_db = -30.0 + (note.velocity / 127.0) * 30.0

        # Fundamental with ±1 bin spread.
        _place_energy(roll, fund_bin, first, last, level_db)

        # Overtones per the HARMONICS model.
        for _, semitone_offset, attenuation in HARMONICS:
            overtone_bin = fund_bin + int(round(semitone_offset * 3))
            if overtone_bin >= N_BINS:
                break  # higher partials only go further off the top
            overtone_db = level_db + attenuation
            if overtone_db < -70:  # skip inaudible harmonics
                continue
            _place_energy(roll, overtone_bin, first, last, overtone_db)

    return roll
88
+
89
+
90
def _place_energy(pianoroll, center_bin, start, end, db_level):
    """Place energy in the pianoroll at center ± 1 bin.

    The center row gets *db_level*; its two neighbours get it attenuated
    by 6 dB. Existing values are only ever raised (element-wise max),
    and out-of-range rows are ignored.
    """
    spread = (
        (center_bin - 1, db_level - 6),
        (center_bin, db_level),
        (center_bin + 1, db_level - 6),
    )
    n_rows = pianoroll.shape[0]
    for row, level in spread:
        if 0 <= row < n_rows:
            np.maximum(
                pianoroll[row, start:end], level,
                out=pianoroll[row, start:end],
            )
98
+
99
+
100
def spectral_fidelity(y, sr, midi_data, hop_length=512):
    """Compute spectral fidelity: how well MIDI matches the original audio.

    Compares the CQT of the real audio against the harmonic pianoroll
    synthesized from *midi_data* and returns dict with scores and
    detailed diagnostics (global F1/precision/recall/similarity,
    per-octave coverage, per-time-segment fidelity, and the top
    missing/extra note regions).
    """
    duration = len(y) / sr

    # Real spectrum vs. ideal spectrum rendered from the MIDI notes.
    audio_spec = audio_cqt(y, sr, hop_length)
    midi_spec = midi_to_pianoroll_cqt(midi_data, duration, sr, hop_length)

    # Normalize to 0-1 range (both spectrograms are dB floored at -80).
    audio_norm = (audio_spec + 80.0) / 80.0
    midi_norm = (midi_spec + 80.0) / 80.0

    # Active energy thresholds
    audio_active = audio_norm > 0.25  # > -60dB
    midi_active = midi_norm > 0.25

    # Cell-wise confusion counts over the whole time-frequency plane.
    tp = np.sum(audio_active & midi_active)
    fn = np.sum(audio_active & ~midi_active)  # audio energy the MIDI lacks
    fp = np.sum(~audio_active & midi_active)  # MIDI energy absent from audio

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    # Weighted spectral similarity (MSE on active regions).
    # The *4 factor maps an MSE of 0.25 down to similarity 0.
    active_mask = audio_active | midi_active
    if np.sum(active_mask) > 0:
        mse = np.mean((audio_norm[active_mask] - midi_norm[active_mask]) ** 2)
        spectral_similarity = max(0, 1.0 - mse * 4)
    else:
        spectral_similarity = 0

    # Per-octave analysis (36 CQT bins per octave, starting at A0).
    octave_analysis = []
    for octave in range(8):
        b0 = octave * 36
        b1 = min((octave + 1) * 36, N_BINS)
        oct_audio = audio_active[b0:b1]
        oct_midi = midi_active[b0:b1]
        oct_tp = int(np.sum(oct_audio & oct_midi))
        oct_fn = int(np.sum(oct_audio & ~oct_midi))
        oct_fp = int(np.sum(~oct_audio & oct_midi))
        oct_total = int(np.sum(oct_audio))
        octave_analysis.append({
            'octave': octave,
            'range': f"A{octave}-A{octave+1}",
            'audio_energy': oct_total,
            'missing_energy': oct_fn,
            'extra_energy': oct_fp,
            'matched_energy': oct_tp,
            # A silent octave counts as fully covered (1.0).
            'coverage': round(float(oct_tp / oct_total) if oct_total > 0 else 1.0, 4),
        })

    # Per-time analysis (20 segments).
    # NOTE(review): window = n_frames // 20 drops up to 19 trailing frames
    # from the last segment when n_frames is not divisible by 20 — confirm
    # this is acceptable for the coarse report.
    n_frames = audio_spec.shape[1]
    window = max(1, n_frames // 20)
    time_analysis = []
    for seg in range(20):
        f0 = seg * window
        f1_t = min((seg + 1) * window, n_frames)  # _t suffix avoids shadowing the f1 score
        seg_audio = audio_active[:, f0:f1_t]
        seg_midi = midi_active[:, f0:f1_t]
        seg_tp = int(np.sum(seg_audio & seg_midi))
        seg_fn = int(np.sum(seg_audio & ~seg_midi))
        seg_fp = int(np.sum(~seg_audio & seg_midi))
        seg_total = int(np.sum(seg_audio))
        t0 = librosa.frames_to_time(f0, sr=sr, hop_length=hop_length)
        t1 = librosa.frames_to_time(f1_t, sr=sr, hop_length=hop_length)
        time_analysis.append({
            'time_start': round(float(t0), 2),
            'time_end': round(float(t1), 2),
            'missing': seg_fn,
            'extra': seg_fp,
            'matched': seg_tp,
            'fidelity': round(float(seg_tp / seg_total) if seg_total > 0 else 1.0, 3),
        })

    # Find specific missing and extra note regions
    missing_notes = _find_note_gaps(audio_spec, midi_spec, sr, hop_length, mode='missing')
    extra_notes = _find_note_gaps(audio_spec, midi_spec, sr, hop_length, mode='extra')

    return {
        'spectral_f1': round(f1, 4),
        'spectral_precision': round(precision, 4),
        'spectral_recall': round(recall, 4),
        'spectral_similarity': round(spectral_similarity, 4),
        'per_octave': octave_analysis,
        'per_time': time_analysis,
        # Cap report size to the 20 most significant regions each.
        'missing_notes': missing_notes[:20],
        'extra_notes': extra_notes[:20],
    }
193
+
194
+
195
def _find_note_gaps(audio_spec, midi_spec, sr, hop_length, mode='missing'):
    """Find time-frequency regions with energy in one but not the other.

    mode='missing': audio has energy, MIDI doesn't (notes basic-pitch missed)
    mode='extra': MIDI has energy, audio doesn't (hallucinations)

    Scans one row per semitone (every 3rd CQT bin), collects contiguous
    gap runs of at least 3 frames, and returns them as dicts sorted by
    energy * duration, most significant first.
    """
    audio_norm = (audio_spec + 80.0) / 80.0
    midi_norm = (midi_spec + 80.0) / 80.0

    if mode == 'missing':
        gap = (audio_norm > 0.5) & (midi_norm < 0.25)
        energy_source = audio_norm
    else:
        gap = (midi_norm > 0.5) & (audio_norm < 0.25)
        energy_source = midi_norm

    n_bins, n_frames = audio_spec.shape
    regions = []

    for b in range(0, n_bins, 3):
        row = gap[b]
        f = 0
        while f < n_frames:
            if not row[f]:
                f += 1
                continue
            # Consume one contiguous run of gap frames.
            run_start = f
            while f < n_frames and row[f]:
                f += 1
            run_end = f
            if run_end - run_start < 3:
                continue  # too short to be a meaningful note region

            pitch = 21 + b // 3  # bin row back to MIDI pitch (A0 = 21)
            t0 = librosa.frames_to_time(run_start, sr=sr, hop_length=hop_length)
            t1 = librosa.frames_to_time(run_end, sr=sr, hop_length=hop_length)
            mean_energy = float(np.mean(energy_source[b:b + 3, run_start:run_end]))

            regions.append({
                'pitch': pitch,
                'note': pretty_midi.note_number_to_name(pitch),
                'time_start': round(float(t0), 3),
                'time_end': round(float(t1), 3),
                'duration': round(float(t1 - t0), 3),
                'energy': round(mean_energy, 3),
            })

    regions.sort(key=lambda r: r['energy'] * r['duration'], reverse=True)
    return regions
244
+
245
+
246
def compare(audio_path, midi_path, output_json=None):
    """Run full spectral comparison and print report.

    Loads the audio (22.05 kHz mono) and the MIDI transcription, computes
    the fidelity metrics, prints a human-readable summary, and optionally
    writes the full results dict to *output_json*.
    """
    samples, rate = librosa.load(str(audio_path), sr=22050, mono=True)
    transcription = pretty_midi.PrettyMIDI(str(midi_path))

    report = spectral_fidelity(samples, rate, transcription)

    print("\nSpectral Fidelity Report:")
    for label, key in (
        ("Spectral F1", 'spectral_f1'),
        ("Spectral Precision", 'spectral_precision'),
        ("Spectral Recall", 'spectral_recall'),
        ("Spectral Similarity", 'spectral_similarity'),
    ):
        print(f" {label}: {report[key]:.4f}")

    print("\n Per-octave coverage:")
    for band in report['per_octave']:
        if band['audio_energy'] > 0:
            print(f" {band['range']}: {band['coverage']:.1%} "
                  f"(missing: {band['missing_energy']}, extra: {band['extra_energy']})")

    print("\n Worst time segments:")
    for seg in sorted(report['per_time'], key=lambda s: s['fidelity'])[:5]:
        print(f" {seg['time_start']:.1f}-{seg['time_end']:.1f}s: "
              f"fidelity={seg['fidelity']:.1%} "
              f"(missing: {seg['missing']}, extra: {seg['extra']})")

    # Missing and extra note regions share the same print shape.
    for kind, key in (("missing", 'missing_notes'), ("extra", 'extra_notes')):
        regions = report[key]
        if regions:
            print(f"\n Top {kind} notes:")
            for region in regions[:10]:
                print(f" {region['note']} at "
                      f"{region['time_start']:.2f}-{region['time_end']:.2f}s "
                      f"(energy: {region['energy']:.2f})")

    if output_json:
        Path(output_json).write_text(json.dumps(report, indent=2))
        print(f"\n Report saved to {output_json}")

    return report
289
+
290
+
291
if __name__ == "__main__":
    import sys

    # CLI: audio file + MIDI file, optional JSON report path.
    argv = sys.argv
    if len(argv) < 3:
        print("Usage: python spectral.py <audio_file> <midi_file> [output.json]")
        sys.exit(1)

    report_path = argv[3] if len(argv) > 3 else None
    compare(argv[1], argv[2], report_path)
transcriber/transcribe.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""Transcribe an audio file to MIDI using basic-pitch."""

import sys
from pathlib import Path

# Patch scipy.signal.gaussian (removed in scipy >=1.12, basic-pitch hasn't
# updated): re-expose the window function at its old location so basic-pitch's
# imports keep working. Must run before importing basic_pitch below.
import scipy.signal
if not hasattr(scipy.signal, "gaussian"):
    from scipy.signal.windows import gaussian
    scipy.signal.gaussian = gaussian

from basic_pitch.inference import predict
import basic_pitch

# Locate the ICASSP-2022 ONNX model bundled inside the installed
# basic_pitch package (the Dockerfile installs basic-pitch with --no-deps
# and runs it via onnxruntime only).
_MODELS_DIR = Path(basic_pitch.__file__).parent / "saved_models" / "icassp_2022"
ONNX_MODEL_PATH = _MODELS_DIR / "nmp.onnx"
+
18
def transcribe(input_path: str, output_path: str | None = None):
    """Run basic-pitch on *input_path* and write the resulting MIDI.

    When *output_path* is omitted, the MIDI lands next to the input with
    a ``.mid`` suffix. Exits the process with status 1 if the input file
    does not exist. Returns the output path as a ``Path``.
    """
    src = Path(input_path)
    if not src.exists():
        print(f"Error: {src} not found")
        sys.exit(1)

    dest = src.with_suffix(".mid") if output_path is None else Path(output_path)

    print(f"Transcribing {src}...")
    # Thresholds tuned for piano material; model output itself is unused.
    _, midi_data, note_events = predict(
        str(src),
        ONNX_MODEL_PATH,
        onset_threshold=0.33,
        frame_threshold=0.20,
        minimum_note_length=100.0,
    )

    midi_data.write(str(dest))
    print(f"MIDI written to {dest}")
    print(f"Found {len(note_events)} note events")
    return dest
42
+
43
if __name__ == "__main__":
    # CLI: audio file, optional MIDI output path.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python transcribe.py <audio_file> [output.mid]")
        sys.exit(1)

    transcribe(cli_args[0], cli_args[1] if len(cli_args) > 1 else None)