calebhan committed on
Commit
a5359f9
·
1 Parent(s): b5fec2f

file upload pipeline

Browse files
.gitignore CHANGED
@@ -244,6 +244,7 @@ storage/youtube_cookies*
244
  !storage/README.txt
245
  storage/outputs/*
246
  storage/temp/*
 
247
 
248
  # Temp files
249
  /tmp/
 
244
  !storage/README.txt
245
  storage/outputs/*
246
  storage/temp/*
247
+ storage/uploads/*
248
 
249
  # Temp files
250
  /tmp/
backend/celery_app.py CHANGED
@@ -1,4 +1,12 @@
1
  """Celery application configuration."""
 
 
 
 
 
 
 
 
2
  from celery import Celery
3
  from kombu import Exchange, Queue
4
  from app_config import settings
 
1
  """Celery application configuration."""
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ # Ensure backend directory is in Python path for imports
6
+ backend_dir = Path(__file__).parent.resolve()
7
+ if str(backend_dir) not in sys.path:
8
+ sys.path.insert(0, str(backend_dir))
9
+
10
  from celery import Celery
11
  from kombu import Exchange, Queue
12
  from app_config import settings
backend/main.py CHANGED
@@ -1,5 +1,5 @@
1
  """FastAPI application for Rescored backend."""
2
- from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Request, File, UploadFile
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import FileResponse
5
  from pydantic import BaseModel, HttpUrl
@@ -157,6 +157,11 @@ class TranscribeRequest(BaseModel):
157
  options: dict = {"instruments": ["piano"]}
158
 
159
 
 
 
 
 
 
160
  class TranscribeResponse(BaseModel):
161
  """Response model for transcription submission."""
162
  job_id: str
@@ -288,6 +293,93 @@ async def submit_transcription(request: TranscribeRequest):
288
  )
289
 
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  @app.get("/api/v1/jobs/{job_id}", response_model=JobStatusResponse)
292
  async def get_job_status(job_id: str):
293
  """
 
1
  """FastAPI application for Rescored backend."""
2
+ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Request, File, UploadFile, Form
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import FileResponse
5
  from pydantic import BaseModel, HttpUrl
 
157
  options: dict = {"instruments": ["piano"]}
158
 
159
 
160
+ class FileUploadTranscribeRequest(BaseModel):
161
+ """Request model for file upload transcription."""
162
+ options: dict = {"instruments": ["piano"]}
163
+
164
+
165
  class TranscribeResponse(BaseModel):
166
  """Response model for transcription submission."""
167
  job_id: str
 
293
  )
294
 
295
 
296
@app.post("/api/v1/transcribe/upload", response_model=TranscribeResponse, status_code=201)
async def submit_file_transcription(
    file: UploadFile = File(...),
    instruments: str = Form('["piano"]'),
    vocal_instrument: int = Form(40)  # Default to violin (program 40)
):
    """
    Submit an audio file for transcription.

    Args:
        file: Audio file (WAV, MP3, FLAC, etc.)
        instruments: JSON array of instruments (default: ["piano"])
        vocal_instrument: MIDI program number for vocals (default: 40 = violin)

    Returns:
        Job information including job ID and WebSocket URL

    Raises:
        HTTPException: 400 if the file extension is unsupported or the
            file exceeds the 100MB size limit.
    """
    print(f"[DEBUG] FastAPI received instruments parameter: {instruments!r}")
    print(f"[DEBUG] FastAPI received vocal_instrument parameter: {vocal_instrument}")

    # Validate file type by extension only (file content is not sniffed).
    allowed_extensions = {'.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac'}
    file_ext = Path(file.filename or '').suffix.lower()

    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}"
        )

    # Validate file size (max 100MB). The whole upload is read into memory
    # here, which is acceptable under this size cap.
    max_size = 100 * 1024 * 1024  # 100MB
    content = await file.read()
    if len(content) > max_size:
        raise HTTPException(
            status_code=400,
            detail="File too large. Maximum size: 100MB"
        )

    # Parse instruments option; fall back to piano on malformed JSON.
    # Uses the module-level `json` import (the previous local
    # `import json as json_module` alias was redundant).
    try:
        instruments_list = json.loads(instruments)
        print(f"[DEBUG] Parsed instruments list: {instruments_list}")
    except ValueError as e:
        print(f"[DEBUG] Failed to parse instruments, using default ['piano']. Error: {e}")
        instruments_list = ["piano"]

    # Create job
    job_id = str(uuid4())

    # Save uploaded file to storage so the Celery worker can read it later.
    upload_dir = settings.storage_path / "uploads"
    upload_dir.mkdir(parents=True, exist_ok=True)
    upload_path = upload_dir / f"{job_id}{file_ext}"
    upload_path.write_bytes(content)

    job_data = {
        "job_id": job_id,
        "status": "queued",
        "upload_path": str(upload_path),
        "original_filename": file.filename or "unknown",
        "options": json.dumps({"instruments": instruments_list, "vocal_instrument": vocal_instrument}),
        "created_at": datetime.utcnow().isoformat(),
        "progress": 0,
        "current_stage": "queued",
        "status_message": "Job queued for processing",
    }

    # Store job metadata in Redis so status polling / WebSocket can find it.
    redis_client.hset(f"job:{job_id}", mapping=job_data)

    # Queue Celery task
    process_transcription_task.delay(job_id)

    return TranscribeResponse(
        job_id=job_id,
        status="queued",
        created_at=datetime.utcnow(),
        estimated_duration_seconds=120,
        websocket_url=f"ws://localhost:{settings.api_port}/api/v1/jobs/{job_id}/stream"
    )
381
+
382
+
383
  @app.get("/api/v1/jobs/{job_id}", response_model=JobStatusResponse)
384
  async def get_job_status(job_id: str):
385
  """
backend/pipeline.py CHANGED
@@ -33,18 +33,52 @@ except ImportError as e:
33
  print(f"WARNING: madmom not available. Falling back to librosa for tempo/beat detection.")
34
  print(f" Error: {e}")
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  class TranscriptionPipeline:
39
  """Handles the complete transcription workflow."""
40
 
41
- def __init__(self, job_id: str, youtube_url: str, storage_path: Path, config=None):
42
  self.job_id = job_id
43
  self.youtube_url = youtube_url
44
  self.storage_path = storage_path
45
  self.temp_dir = storage_path / "temp" / job_id
46
  self.temp_dir.mkdir(parents=True, exist_ok=True)
47
  self.progress_callback = None
 
48
 
49
  # Load configuration
50
  if config is None:
@@ -117,6 +151,9 @@ class TranscriptionPipeline:
117
  else:
118
  midi_path = piano_midi
119
 
 
 
 
120
  # Apply post-processing filters (Phase 4)
121
  midi_path = self.apply_post_processing_filters(midi_path)
122
 
@@ -164,6 +201,16 @@ class TranscriptionPipeline:
164
  # Log the full error for debugging
165
  print(f"yt-dlp stderr: {result.stderr}")
166
  print(f"yt-dlp stdout: {result.stdout}")
 
 
 
 
 
 
 
 
 
 
167
  raise RuntimeError(f"yt-dlp failed: {result.stderr}")
168
 
169
  if not output_path.exists():
@@ -231,7 +278,9 @@ class TranscriptionPipeline:
231
  # 2. Demucs separates clean instrumental into piano/guitar/drums/bass/other
232
  print(" Using two-stage separation (BS-RoFormer + Demucs)")
233
 
234
- from audio_separator_wrapper import AudioSeparator
 
 
235
  separator = AudioSeparator()
236
 
237
  separation_dir = self.temp_dir / "separation"
@@ -254,7 +303,9 @@ class TranscriptionPipeline:
254
  # Direct Demucs 6-stem separation (no vocal pre-removal)
255
  print(" Using Demucs 6-stem separation")
256
 
257
- from audio_separator_wrapper import AudioSeparator
 
 
258
  separator = AudioSeparator()
259
 
260
  instrument_dir = self.temp_dir / "instruments"
@@ -298,6 +349,52 @@ class TranscriptionPipeline:
298
 
299
  return stems
300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  def transcribe_to_midi(
302
  self,
303
  audio_path: Path,
@@ -407,16 +504,8 @@ class TranscriptionPipeline:
407
  Raises:
408
  RuntimeError: If transcription fails
409
  """
410
- try:
411
- from yourmt3_wrapper import YourMT3Transcriber
412
- except ImportError:
413
- # Try adding backend directory to path
414
- import sys
415
- from pathlib import Path as PathLib
416
- backend_dir = PathLib(__file__).parent
417
- if str(backend_dir) not in sys.path:
418
- sys.path.insert(0, str(backend_dir))
419
- from yourmt3_wrapper import YourMT3Transcriber
420
 
421
  print(f" Transcribing with YourMT3+ (direct call, device: {self.config.yourmt3_device})...")
422
 
@@ -459,20 +548,12 @@ class TranscriptionPipeline:
459
  Raises:
460
  RuntimeError: If transcription fails
461
  """
462
- try:
463
- from yourmt3_wrapper import YourMT3Transcriber
464
- from bytedance_wrapper import ByteDanceTranscriber
465
- from ensemble_transcriber import EnsembleTranscriber
466
- except ImportError:
467
- # Try adding backend directory to path
468
- import sys
469
- from pathlib import Path as PathLib
470
- backend_dir = PathLib(__file__).parent
471
- if str(backend_dir) not in sys.path:
472
- sys.path.insert(0, str(backend_dir))
473
- from yourmt3_wrapper import YourMT3Transcriber
474
- from bytedance_wrapper import ByteDanceTranscriber
475
- from ensemble_transcriber import EnsembleTranscriber
476
 
477
  try:
478
  # Initialize transcribers
@@ -527,15 +608,8 @@ class TranscriptionPipeline:
527
 
528
  # Use YourMT3+ for vocal transcription
529
  # (Could use dedicated melody transcription model in future)
530
- try:
531
- from yourmt3_wrapper import YourMT3Transcriber
532
- except ImportError:
533
- import sys
534
- from pathlib import Path as PathLib
535
- backend_dir = PathLib(__file__).parent
536
- if str(backend_dir) not in sys.path:
537
- sys.path.insert(0, str(backend_dir))
538
- from yourmt3_wrapper import YourMT3Transcriber
539
 
540
  transcriber = YourMT3Transcriber(
541
  model_name="YPTF.MoE+Multi (noPS)",
@@ -648,6 +722,103 @@ class TranscriptionPipeline:
648
 
649
  return merged_path
650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651
  def apply_post_processing_filters(self, midi_path: Path) -> Path:
652
  """
653
  Apply post-processing filters to improve transcription quality.
 
33
  print(f"WARNING: madmom not available. Falling back to librosa for tempo/beat detection.")
34
  print(f" Error: {e}")
35
 
36
+ # Import wrapper modules at top level
37
+ try:
38
+ from audio_separator_wrapper import AudioSeparator
39
+ AUDIO_SEPARATOR_AVAILABLE = True
40
+ except ImportError as e:
41
+ AUDIO_SEPARATOR_AVAILABLE = False
42
+ AudioSeparator = None
43
+ print(f"WARNING: audio_separator_wrapper not available: {e}")
44
+
45
+ try:
46
+ from yourmt3_wrapper import YourMT3Transcriber
47
+ YOURMT3_AVAILABLE = True
48
+ except ImportError as e:
49
+ YOURMT3_AVAILABLE = False
50
+ YourMT3Transcriber = None
51
+ print(f"WARNING: yourmt3_wrapper not available: {e}")
52
+
53
+ try:
54
+ from bytedance_wrapper import ByteDanceTranscriber
55
+ BYTEDANCE_AVAILABLE = True
56
+ except ImportError as e:
57
+ BYTEDANCE_AVAILABLE = False
58
+ ByteDanceTranscriber = None
59
+ print(f"WARNING: bytedance_wrapper not available: {e}")
60
+
61
+ try:
62
+ from ensemble_transcriber import EnsembleTranscriber
63
+ ENSEMBLE_AVAILABLE = True
64
+ except ImportError as e:
65
+ ENSEMBLE_AVAILABLE = False
66
+ EnsembleTranscriber = None
67
+ print(f"WARNING: ensemble_transcriber not available: {e}")
68
+
69
 
70
 
71
  class TranscriptionPipeline:
72
  """Handles the complete transcription workflow."""
73
 
74
+ def __init__(self, job_id: str, youtube_url: str, storage_path: Path, config=None, instruments: list = None):
75
  self.job_id = job_id
76
  self.youtube_url = youtube_url
77
  self.storage_path = storage_path
78
  self.temp_dir = storage_path / "temp" / job_id
79
  self.temp_dir.mkdir(parents=True, exist_ok=True)
80
  self.progress_callback = None
81
+ self.instruments = instruments if instruments else ['piano']
82
 
83
  # Load configuration
84
  if config is None:
 
151
  else:
152
  midi_path = piano_midi
153
 
154
+ # Filter MIDI to only include selected instruments
155
+ midi_path = self.filter_midi_by_instruments(midi_path)
156
+
157
  # Apply post-processing filters (Phase 4)
158
  midi_path = self.apply_post_processing_filters(midi_path)
159
 
 
201
  # Log the full error for debugging
202
  print(f"yt-dlp stderr: {result.stderr}")
203
  print(f"yt-dlp stdout: {result.stdout}")
204
+
205
+ # Check for DNS resolution errors
206
+ stderr_lower = result.stderr.lower()
207
+ if ("failed to resolve" in stderr_lower or
208
+ "no address associated with hostname" in stderr_lower or
209
+ "unable to download api page" in stderr_lower):
210
+ raise RuntimeError(
211
+ "Unable to connect to YouTube. For this demo version, please upload your audio file directly using the file upload option."
212
+ )
213
+
214
  raise RuntimeError(f"yt-dlp failed: {result.stderr}")
215
 
216
  if not output_path.exists():
 
278
  # 2. Demucs separates clean instrumental into piano/guitar/drums/bass/other
279
  print(" Using two-stage separation (BS-RoFormer + Demucs)")
280
 
281
+ if not AUDIO_SEPARATOR_AVAILABLE or AudioSeparator is None:
282
+ raise RuntimeError("audio_separator_wrapper is not available")
283
+
284
  separator = AudioSeparator()
285
 
286
  separation_dir = self.temp_dir / "separation"
 
303
  # Direct Demucs 6-stem separation (no vocal pre-removal)
304
  print(" Using Demucs 6-stem separation")
305
 
306
+ if not AUDIO_SEPARATOR_AVAILABLE or AudioSeparator is None:
307
+ raise RuntimeError("audio_separator_wrapper is not available")
308
+
309
  separator = AudioSeparator()
310
 
311
  instrument_dir = self.temp_dir / "instruments"
 
349
 
350
  return stems
351
 
352
def transcribe_multiple_stems(self, stems: dict) -> Path:
    """
    Transcribe multiple instrument stems and combine into single MIDI.

    Args:
        stems: Dict mapping stem names to file paths (e.g., {'piano': Path, 'vocals': Path})

    Returns:
        Path to combined MIDI file
    """
    import pretty_midi

    print(f" Transcribing {len(stems)} stems: {list(stems.keys())}")

    # Run each stem through the appropriate transcriber.
    per_stem_midi = {}
    for name, stem_file in stems.items():
        print(f" [Stem {name}] Transcribing {stem_file.name}...")

        # Piano may use the ensemble path when configured; everything
        # else goes through YourMT3+.
        wants_ensemble = name == 'piano' and self.config.use_ensemble_transcription
        transcribe = (
            self.transcribe_with_ensemble if wants_ensemble
            else self.transcribe_with_yourmt3
        )
        per_stem_midi[name] = transcribe(stem_file)

        print(f" [Stem {name}] ✓ Complete")

    # Merge every per-stem MIDI into one multi-track file.
    print(f" Combining {len(per_stem_midi)} MIDI files...")
    merged = pretty_midi.PrettyMIDI()
    for midi_file in per_stem_midi.values():
        merged.instruments.extend(pretty_midi.PrettyMIDI(str(midi_file)).instruments)

    combined_path = self.temp_dir / "combined_stems.mid"
    merged.write(str(combined_path))

    print(f" ✓ Combined {len(per_stem_midi)} stems into {len(merged.instruments)} MIDI tracks")

    return combined_path
397
+
398
  def transcribe_to_midi(
399
  self,
400
  audio_path: Path,
 
504
  Raises:
505
  RuntimeError: If transcription fails
506
  """
507
+ if not YOURMT3_AVAILABLE or YourMT3Transcriber is None:
508
+ raise RuntimeError("yourmt3_wrapper is not available")
 
 
 
 
 
 
 
 
509
 
510
  print(f" Transcribing with YourMT3+ (direct call, device: {self.config.yourmt3_device})...")
511
 
 
548
  Raises:
549
  RuntimeError: If transcription fails
550
  """
551
+ if not YOURMT3_AVAILABLE or YourMT3Transcriber is None:
552
+ raise RuntimeError("yourmt3_wrapper is not available")
553
+ if not BYTEDANCE_AVAILABLE or ByteDanceTranscriber is None:
554
+ raise RuntimeError("bytedance_wrapper is not available")
555
+ if not ENSEMBLE_AVAILABLE or EnsembleTranscriber is None:
556
+ raise RuntimeError("ensemble_transcriber is not available")
 
 
 
 
 
 
 
 
557
 
558
  try:
559
  # Initialize transcribers
 
608
 
609
  # Use YourMT3+ for vocal transcription
610
  # (Could use dedicated melody transcription model in future)
611
+ if not YOURMT3_AVAILABLE or YourMT3Transcriber is None:
612
+ raise RuntimeError("yourmt3_wrapper is not available")
 
 
 
 
 
 
 
613
 
614
  transcriber = YourMT3Transcriber(
615
  model_name="YPTF.MoE+Multi (noPS)",
 
722
 
723
  return merged_path
724
 
725
+ def filter_midi_by_instruments(self, midi_path: Path) -> Path:
726
+ """
727
+ Filter MIDI file to only include tracks for selected instruments.
728
+
729
+ YourMT3+ transcribes all instruments it detects. This function filters
730
+ the output to only keep tracks matching the user's selection.
731
+
732
+ Args:
733
+ midi_path: Input MIDI file (may contain multiple instrument tracks)
734
+
735
+ Returns:
736
+ Path to filtered MIDI file containing only selected instruments
737
+ """
738
+ import pretty_midi
739
+
740
+ # Map instrument IDs to MIDI program ranges
741
+ # YourMT3+ uses General MIDI program numbers
742
+ INSTRUMENT_PROGRAMS = {
743
+ 'piano': list(range(0, 8)), # Acoustic Grand Piano to Celesta
744
+ 'guitar': list(range(24, 32)), # Acoustic Guitar to Guitar Harmonics
745
+ 'bass': list(range(32, 40)), # Acoustic Bass to Synth Bass 2
746
+ 'drums': [128], # Drum channel (special case)
747
+ 'vocals': list(range(52, 56)) + [65, 85], # Choir Aahs, Voice Oohs, Synth Voice, Lead Voice, YourMT3+ "Singing Voice" (65)
748
+ 'other': list(range(8, 24)) + list(range(40, 52)) + list(range(56, 65)) + list(range(66, 85)) + list(range(86, 128)) # Everything else (excluding vocals programs)
749
+ }
750
+
751
+ # Load MIDI file
752
+ pm = pretty_midi.PrettyMIDI(str(midi_path))
753
+
754
+ # Debug: Show what's in the MIDI before filtering
755
+ print(f" [DEBUG] MIDI contains {len(pm.instruments)} tracks before filtering:")
756
+ for i, inst in enumerate(pm.instruments):
757
+ print(f" Track {i}: {inst.name} (program={inst.program}, is_drum={inst.is_drum}, notes={len(inst.notes)})")
758
+
759
+ # Determine which programs to keep
760
+ programs_to_keep = set()
761
+ for instrument in self.instruments:
762
+ if instrument in INSTRUMENT_PROGRAMS:
763
+ programs_to_keep.update(INSTRUMENT_PROGRAMS[instrument])
764
+
765
+ print(f" [DEBUG] Looking for programs: {sorted(programs_to_keep)[:20]}... (selected instruments: {self.instruments})")
766
+
767
+ # Group instruments by category to handle YourMT3+ outputting multiple tracks per instrument
768
+ # (e.g., both "Acoustic Piano" and "Electric Piano" for piano)
769
+ instrument_groups = {}
770
+ for inst in pm.instruments:
771
+ # Determine which category this instrument belongs to
772
+ matched_category = None
773
+ if inst.is_drum and 128 in programs_to_keep:
774
+ matched_category = 'drums'
775
+ elif not inst.is_drum and inst.program in programs_to_keep:
776
+ # Find which instrument category this program belongs to
777
+ for instr_name, programs in INSTRUMENT_PROGRAMS.items():
778
+ if inst.program in programs and instr_name in self.instruments:
779
+ matched_category = instr_name
780
+ break
781
+
782
+ if matched_category:
783
+ if matched_category not in instrument_groups:
784
+ instrument_groups[matched_category] = []
785
+ instrument_groups[matched_category].append(inst)
786
+ print(f" [DEBUG] Track '{inst.name}' (program={inst.program}) matched category: {matched_category}")
787
+
788
+ # For each category, keep only the track with the most notes
789
+ # (YourMT3+ sometimes outputs spurious tracks with very few notes)
790
+ filtered_instruments = []
791
+ for category, tracks in instrument_groups.items():
792
+ if len(tracks) == 1:
793
+ filtered_instruments.append(tracks[0])
794
+ else:
795
+ # Keep the track with the most notes
796
+ best_track = max(tracks, key=lambda t: len(t.notes))
797
+ filtered_instruments.append(best_track)
798
+
799
+ # Log which tracks were filtered out
800
+ for track in tracks:
801
+ if track != best_track:
802
+ track_name = track.name or f"Program {track.program}"
803
+ best_name = best_track.name or f"Program {best_track.program}"
804
+ print(f" Filtered out spurious track: {track_name} ({len(track.notes)} notes) - kept {best_name} ({len(best_track.notes)} notes)")
805
+
806
+ # Create new MIDI with only selected instruments
807
+ filtered_pm = pretty_midi.PrettyMIDI()
808
+ filtered_pm.instruments = filtered_instruments
809
+
810
+ # Save filtered MIDI
811
+ filtered_path = midi_path.parent / f"{midi_path.stem}_filtered.mid"
812
+ filtered_pm.write(str(filtered_path))
813
+
814
+ # Log filtering results
815
+ original_count = len(pm.instruments)
816
+ filtered_count = len(filtered_instruments)
817
+ print(f" Filtered MIDI: {original_count} tracks → {filtered_count} tracks (1 per category)")
818
+ print(f" Kept instruments: {self.instruments}")
819
+
820
+ return filtered_path
821
+
822
  def apply_post_processing_filters(self, midi_path: Path) -> Path:
823
  """
824
  Apply post-processing filters to improve transcription quality.
backend/tasks.py CHANGED
@@ -1,4 +1,12 @@
1
  """Celery tasks for background job processing."""
 
 
 
 
 
 
 
 
2
  from celery import Task
3
  from celery_app import celery_app
4
  from pipeline import TranscriptionPipeline, run_transcription_pipeline
@@ -6,7 +14,6 @@ import redis
6
  import json
7
  import os
8
  from datetime import datetime
9
- from pathlib import Path
10
  from app_config import settings
11
  import shutil
12
 
@@ -76,24 +83,126 @@ def process_transcription_task(self, job_id: str):
76
 
77
  # Get job data
78
  job_data = redis_client.hgetall(f"job:{job_id}")
79
-
80
  if not job_data:
81
  raise ValueError(f"Job not found: {job_id}")
82
-
 
 
83
  youtube_url = job_data.get('youtube_url')
84
- if not youtube_url:
85
- raise ValueError(f"Job missing youtube_url: {job_id}")
86
 
87
- # Initialize pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  pipeline = TranscriptionPipeline(
89
  job_id=job_id,
90
- youtube_url=youtube_url,
91
- storage_path=settings.storage_path
 
92
  )
93
  pipeline.set_progress_callback(lambda p, s, m: self.update_progress(job_id, p, s, m))
94
 
95
- # Run pipeline
96
- temp_output_path = pipeline.run()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  # Output is already in the temp directory, move to persistent storage
99
  output_path = settings.outputs_path / f"{job_id}.musicxml"
 
1
  """Celery tasks for background job processing."""
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ # Ensure backend directory is in Python path for imports
6
+ backend_dir = Path(__file__).parent.resolve()
7
+ if str(backend_dir) not in sys.path:
8
+ sys.path.insert(0, str(backend_dir))
9
+
10
  from celery import Task
11
  from celery_app import celery_app
12
  from pipeline import TranscriptionPipeline, run_transcription_pipeline
 
14
  import json
15
  import os
16
  from datetime import datetime
 
17
  from app_config import settings
18
  import shutil
19
 
 
83
 
84
  # Get job data
85
  job_data = redis_client.hgetall(f"job:{job_id}")
86
+
87
  if not job_data:
88
  raise ValueError(f"Job not found: {job_id}")
89
+
90
+ # Check if this is a file upload or YouTube URL job
91
+ upload_path = job_data.get('upload_path')
92
  youtube_url = job_data.get('youtube_url')
 
 
93
 
94
+ # Parse instruments option (defaults to piano only)
95
+ instruments = ['piano']
96
+ vocal_instrument_program = 40 # Default to violin
97
+ if 'options' in job_data:
98
+ try:
99
+ options = json.loads(job_data['options'])
100
+ instruments = options.get('instruments', ['piano'])
101
+ vocal_instrument_program = options.get('vocal_instrument', 40)
102
+ except (json.JSONDecodeError, KeyError):
103
+ instruments = ['piano']
104
+ vocal_instrument_program = 40
105
+
106
+ # Import shutil and subprocess
107
+ import shutil
108
+ import subprocess
109
+
110
+ # Create pipeline
111
  pipeline = TranscriptionPipeline(
112
  job_id=job_id,
113
+ youtube_url=youtube_url or "file://uploaded", # Dummy URL for file uploads
114
+ storage_path=settings.storage_path,
115
+ instruments=instruments
116
  )
117
  pipeline.set_progress_callback(lambda p, s, m: self.update_progress(job_id, p, s, m))
118
 
119
+ # Get audio.wav - either from upload or YouTube download
120
+ audio_path = pipeline.temp_dir / "audio.wav"
121
+
122
+ if upload_path:
123
+ # File upload - convert to WAV if needed
124
+ upload_file = Path(upload_path)
125
+ if upload_file.suffix.lower() == '.wav':
126
+ shutil.copy(str(upload_file), str(audio_path))
127
+ else:
128
+ # Convert to WAV using ffmpeg
129
+ result = subprocess.run([
130
+ 'ffmpeg', '-i', str(upload_file),
131
+ '-ar', '44100', '-ac', '2',
132
+ str(audio_path)
133
+ ], capture_output=True, text=True)
134
+ if result.returncode != 0:
135
+ raise RuntimeError(f"Audio conversion failed: {result.stderr}")
136
+ elif youtube_url:
137
+ # YouTube download
138
+ pipeline.progress(0, "download", "Starting audio download")
139
+ audio_path = pipeline.download_audio()
140
+ else:
141
+ raise ValueError(f"Job missing both youtube_url and upload_path: {job_id}")
142
+
143
+ # From here, both paths converge - process audio.wav the same way
144
+ # Preprocess audio if enabled
145
+ if pipeline.config.enable_audio_preprocessing:
146
+ pipeline.progress(10, "preprocess", "Preprocessing audio")
147
+ audio_path = pipeline.preprocess_audio(audio_path)
148
+
149
+ # Source separation
150
+ pipeline.progress(20, "separate", "Starting source separation")
151
+ all_stems = pipeline.separate_sources(audio_path)
152
+
153
+ # Select stems to transcribe based on user selection
154
+ stems_to_transcribe = {}
155
+ for instrument in instruments:
156
+ if instrument in all_stems:
157
+ stems_to_transcribe[instrument] = all_stems[instrument]
158
+ print(f" [DEBUG] Will transcribe {instrument} stem")
159
+ else:
160
+ print(f" [WARNING] {instrument} stem not found in separated audio")
161
+
162
+ # If no selected stems available, fall back to piano
163
+ if not stems_to_transcribe:
164
+ print(f" [WARNING] No selected stems found, falling back to piano")
165
+ if 'piano' in all_stems:
166
+ stems_to_transcribe['piano'] = all_stems['piano']
167
+ else:
168
+ stems_to_transcribe['other'] = all_stems['other']
169
+
170
+ pipeline.progress(50, "transcribe", f"Transcribing {len(stems_to_transcribe)} instrument(s)")
171
+
172
+ # Transcribe stems
173
+ if len(stems_to_transcribe) == 1:
174
+ # Single stem - use original method
175
+ stem_path = list(stems_to_transcribe.values())[0]
176
+ combined_midi = pipeline.transcribe_to_midi(stem_path)
177
+ else:
178
+ # Multiple stems - use new multi-stem method
179
+ combined_midi = pipeline.transcribe_multiple_stems(stems_to_transcribe)
180
+
181
+ # Filter MIDI to only include selected instruments
182
+ filtered_midi = pipeline.filter_midi_by_instruments(combined_midi)
183
+
184
+ # Remap vocals MIDI program if vocals were selected
185
+ if 'vocals' in instruments and vocal_instrument_program != 65:
186
+ print(f" [DEBUG] Remapping vocals MIDI program from 65 to {vocal_instrument_program}")
187
+ import pretty_midi
188
+ pm = pretty_midi.PrettyMIDI(str(filtered_midi))
189
+ for inst in pm.instruments:
190
+ if inst.program == 65 and not inst.is_drum: # Singing Voice
191
+ inst.program = vocal_instrument_program
192
+ print(f" [DEBUG] Changed track '{inst.name}' program to {vocal_instrument_program}")
193
+ # Save remapped MIDI
194
+ pm.write(str(filtered_midi))
195
+
196
+ # Apply post-processing
197
+ midi_path = pipeline.apply_post_processing_filters(filtered_midi)
198
+ pipeline.final_midi_path = midi_path
199
+
200
+ # Get audio stem for MusicXML generation (use piano if available, otherwise first available stem)
201
+ audio_stem = stems_to_transcribe.get('piano') or list(stems_to_transcribe.values())[0]
202
+
203
+ pipeline.progress(90, "musicxml", "Generating MusicXML")
204
+ temp_output_path = pipeline.generate_musicxml_minimal(midi_path, audio_stem)
205
+ pipeline.progress(100, "complete", "Transcription complete")
206
 
207
  # Output is already in the temp directory, move to persistent storage
208
  output_path = settings.outputs_path / f"{job_id}.musicxml"
frontend/src/api/client.ts CHANGED
@@ -49,7 +49,7 @@ export class RescoredAPI {
49
  private baseURL = API_BASE_URL;
50
  private wsBaseURL = WS_BASE_URL;
51
 
52
- async submitJob(youtubeURL: string, options?: { instruments?: string[] }): Promise<TranscribeResponse> {
53
  const response = await fetch(`${this.baseURL}/api/v1/transcribe`, {
54
  method: 'POST',
55
  headers: {
@@ -69,6 +69,27 @@ export class RescoredAPI {
69
  return response.json();
70
  }
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  async getJobStatus(jobId: string): Promise<JobStatus> {
73
  const response = await fetch(`${this.baseURL}/api/v1/jobs/${jobId}`);
74
 
 
49
  private baseURL = API_BASE_URL;
50
  private wsBaseURL = WS_BASE_URL;
51
 
52
+ async submitJob(youtubeURL: string, options?: { instruments?: string[]; vocalInstrument?: number }): Promise<TranscribeResponse> {
53
  const response = await fetch(`${this.baseURL}/api/v1/transcribe`, {
54
  method: 'POST',
55
  headers: {
 
69
  return response.json();
70
  }
71
 
72
+ async submitFileJob(file: File, options?: { instruments?: string[]; vocalInstrument?: number }): Promise<TranscribeResponse> {
73
+ const formData = new FormData();
74
+ formData.append('file', file);
75
+ formData.append('instruments', JSON.stringify(options?.instruments ?? ['piano']));
76
+ if (options?.vocalInstrument !== undefined) {
77
+ formData.append('vocal_instrument', options.vocalInstrument.toString());
78
+ }
79
+
80
+ const response = await fetch(`${this.baseURL}/api/v1/transcribe/upload`, {
81
+ method: 'POST',
82
+ body: formData,
83
+ });
84
+
85
+ if (!response.ok) {
86
+ const error = await response.json();
87
+ throw new Error(error.detail || 'Failed to submit file');
88
+ }
89
+
90
+ return response.json();
91
+ }
92
+
93
  async getJobStatus(jobId: string): Promise<JobStatus> {
94
  const response = await fetch(`${this.baseURL}/api/v1/jobs/${jobId}`);
95
 
frontend/src/components/InstrumentSelector.css CHANGED
@@ -69,6 +69,42 @@
69
  font-style: italic;
70
  }
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  /* Responsive adjustments */
73
  @media (max-width: 600px) {
74
  .instrument-grid {
 
69
  font-style: italic;
70
  }
71
 
72
+ .vocal-instrument-selector {
73
+ margin: 1.5rem 0;
74
+ padding: 1rem;
75
+ background-color: #f8f9fa;
76
+ border-radius: 8px;
77
+ border: 1px solid #dee2e6;
78
+ }
79
+
80
+ .vocal-instrument-selector label {
81
+ display: block;
82
+ margin-bottom: 0.5rem;
83
+ font-weight: 500;
84
+ color: #495057;
85
+ }
86
+
87
+ .vocal-instrument-selector select {
88
+ width: 100%;
89
+ padding: 0.5rem;
90
+ font-size: 1rem;
91
+ border: 1px solid #ced4da;
92
+ border-radius: 4px;
93
+ background-color: white;
94
+ cursor: pointer;
95
+ transition: border-color 0.2s ease;
96
+ }
97
+
98
+ .vocal-instrument-selector select:hover {
99
+ border-color: #007bff;
100
+ }
101
+
102
+ .vocal-instrument-selector select:focus {
103
+ outline: none;
104
+ border-color: #007bff;
105
+ box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.1);
106
+ }
107
+
108
  /* Responsive adjustments */
109
  @media (max-width: 600px) {
110
  .instrument-grid {
frontend/src/components/InstrumentSelector.tsx CHANGED
@@ -12,19 +12,35 @@ export interface Instrument {
12
 
13
  const INSTRUMENTS: Instrument[] = [
14
  { id: 'piano', label: 'Piano', icon: '🎹' },
15
- { id: 'vocals', label: 'Vocals (Violin)', icon: '🎤' },
16
  { id: 'drums', label: 'Drums', icon: '🥁' },
17
  { id: 'bass', label: 'Bass', icon: '🎸' },
18
  { id: 'guitar', label: 'Guitar', icon: '🎸' },
19
  { id: 'other', label: 'Other Instruments', icon: '🎵' }
20
  ];
21
 
 
 
 
 
 
 
 
 
 
22
  interface InstrumentSelectorProps {
23
  selectedInstruments: string[];
24
  onChange: (instruments: string[]) => void;
 
 
25
  }
26
 
27
- export function InstrumentSelector({ selectedInstruments, onChange }: InstrumentSelectorProps) {
 
 
 
 
 
28
  const handleToggle = (instrumentId: string) => {
29
  const isSelected = selectedInstruments.includes(instrumentId);
30
 
@@ -33,12 +49,18 @@ export function InstrumentSelector({ selectedInstruments, onChange }: Instrument
33
  if (selectedInstruments.length === 1) {
34
  return;
35
  }
36
- onChange(selectedInstruments.filter(id => id !== instrumentId));
 
 
37
  } else {
38
- onChange([...selectedInstruments, instrumentId]);
 
 
39
  }
40
  };
41
 
 
 
42
  return (
43
  <div className="instrument-selector">
44
  <label className="selector-label">Select Instruments:</label>
@@ -56,6 +78,24 @@ export function InstrumentSelector({ selectedInstruments, onChange }: Instrument
56
  </button>
57
  ))}
58
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  <p className="selector-hint">
60
  Select at least one instrument to transcribe
61
  </p>
 
12
 
13
  const INSTRUMENTS: Instrument[] = [
14
  { id: 'piano', label: 'Piano', icon: '🎹' },
15
+ { id: 'vocals', label: 'Vocals', icon: '🎤' },
16
  { id: 'drums', label: 'Drums', icon: '🥁' },
17
  { id: 'bass', label: 'Bass', icon: '🎸' },
18
  { id: 'guitar', label: 'Guitar', icon: '🎸' },
19
  { id: 'other', label: 'Other Instruments', icon: '🎵' }
20
  ];
21
 
22
+ export const VOCAL_INSTRUMENTS = [
23
+ { id: 'violin', label: 'Violin', program: 40 },
24
+ { id: 'flute', label: 'Flute', program: 73 },
25
+ { id: 'clarinet', label: 'Clarinet', program: 71 },
26
+ { id: 'saxophone', label: 'Saxophone', program: 64 },
27
+ { id: 'trumpet', label: 'Trumpet', program: 56 },
28
+ { id: 'voice', label: 'Singing Voice', program: 65 },
29
+ ];
30
+
31
interface InstrumentSelectorProps {
  // Ids of the currently-selected instruments (e.g. 'piano', 'vocals').
  selectedInstruments: string[];
  // Called with the complete updated selection whenever it changes.
  onChange: (instruments: string[]) => void;
  // Id of the instrument the vocal stem is transcribed as; defaults to 'violin'.
  vocalInstrument?: string;
  // If omitted, the vocal-instrument dropdown is not rendered at all.
  onVocalInstrumentChange?: (instrument: string) => void;
}
37
 
38
+ export function InstrumentSelector({
39
+ selectedInstruments,
40
+ onChange,
41
+ vocalInstrument = 'violin',
42
+ onVocalInstrumentChange
43
+ }: InstrumentSelectorProps) {
44
  const handleToggle = (instrumentId: string) => {
45
  const isSelected = selectedInstruments.includes(instrumentId);
46
 
 
49
  if (selectedInstruments.length === 1) {
50
  return;
51
  }
52
+ const newInstruments = selectedInstruments.filter(id => id !== instrumentId);
53
+ console.log('[DEBUG] InstrumentSelector: Removing', instrumentId, '-> New list:', newInstruments);
54
+ onChange(newInstruments);
55
  } else {
56
+ const newInstruments = [...selectedInstruments, instrumentId];
57
+ console.log('[DEBUG] InstrumentSelector: Adding', instrumentId, '-> New list:', newInstruments);
58
+ onChange(newInstruments);
59
  }
60
  };
61
 
62
+ const vocalsSelected = selectedInstruments.includes('vocals');
63
+
64
  return (
65
  <div className="instrument-selector">
66
  <label className="selector-label">Select Instruments:</label>
 
78
  </button>
79
  ))}
80
  </div>
81
+
82
+ {vocalsSelected && onVocalInstrumentChange && (
83
+ <div className="vocal-instrument-selector">
84
+ <label htmlFor="vocal-instrument">Transcribe vocals as:</label>
85
+ <select
86
+ id="vocal-instrument"
87
+ value={vocalInstrument}
88
+ onChange={(e) => onVocalInstrumentChange(e.target.value)}
89
+ >
90
+ {VOCAL_INSTRUMENTS.map(inst => (
91
+ <option key={inst.id} value={inst.id}>
92
+ {inst.label}
93
+ </option>
94
+ ))}
95
+ </select>
96
+ </div>
97
+ )}
98
+
99
  <p className="selector-hint">
100
  Select at least one instrument to transcribe
101
  </p>
frontend/src/components/JobSubmission.css CHANGED
@@ -111,3 +111,38 @@ button:hover {
111
  margin-top: 1rem;
112
  border: 1px solid #f5c6cb;
113
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  margin-top: 1rem;
112
  border: 1px solid #f5c6cb;
113
  }
114
+
115
/* Segmented toggle between "YouTube URL" and "Upload Audio File" input modes. */
.upload-mode-selector {
  display: flex;
  gap: 0.5rem;
  margin-top: 0.5rem;
}

/* Both mode buttons share equal width via flex: 1. */
.upload-mode-selector button {
  flex: 1;
  padding: 0.5rem 1rem;
  background-color: #f0f0f0;
  color: #333;
  border: 2px solid #ddd;
  border-radius: 4px;
  font-size: 0.9rem;
  cursor: pointer;
  transition: all 0.2s ease;
}

.upload-mode-selector button:hover {
  background-color: #e0e0e0;
  border-color: #bbb;
}

/* Highlight for the currently-active input mode. */
.upload-mode-selector button.active {
  background-color: #007bff;
  color: white;
  border-color: #007bff;
}

/* "Selected: name (size MB)" line shown once a file is chosen. */
.file-info {
  margin-top: 0.5rem;
  font-size: 0.9rem;
  color: #666;
}
frontend/src/components/JobSubmission.tsx CHANGED
@@ -4,7 +4,7 @@
4
  import { useState, useRef, useEffect } from 'react';
5
  import { api } from '../api/client';
6
  import type { ProgressUpdate } from '../api/client';
7
- import { InstrumentSelector } from './InstrumentSelector';
8
  import './JobSubmission.css';
9
 
10
  interface JobSubmissionProps {
@@ -14,7 +14,10 @@ interface JobSubmissionProps {
14
 
15
  export function JobSubmission({ onComplete, onJobSubmitted }: JobSubmissionProps) {
16
  const [youtubeUrl, setYoutubeUrl] = useState('');
 
 
17
  const [selectedInstruments, setSelectedInstruments] = useState<string[]>(['piano']);
 
18
  const [status, setStatus] = useState<'idle' | 'submitting' | 'processing' | 'failed'>('idle');
19
  const [error, setError] = useState<string | null>(null);
20
  const [progress, setProgress] = useState(0);
@@ -46,11 +49,18 @@ export function JobSubmission({ onComplete, onJobSubmitted }: JobSubmissionProps
46
  e.preventDefault();
47
  setError(null);
48
 
49
- // Validate URL
50
- const validation = validateUrl(youtubeUrl);
51
- if (validation) {
52
- setError(validation);
53
- return;
 
 
 
 
 
 
 
54
  }
55
 
56
  // Validate at least one instrument is selected
@@ -61,9 +71,18 @@ export function JobSubmission({ onComplete, onJobSubmitted }: JobSubmissionProps
61
 
62
  setStatus('submitting');
63
 
 
 
 
 
 
64
  try {
65
- const response = await api.submitJob(youtubeUrl, { instruments: selectedInstruments });
 
 
 
66
  setYoutubeUrl('');
 
67
  if (onJobSubmitted) onJobSubmitted(response);
68
 
69
  // Switch to processing status and connect WebSocket
@@ -164,23 +183,82 @@ export function JobSubmission({ onComplete, onJobSubmitted }: JobSubmissionProps
164
  <InstrumentSelector
165
  selectedInstruments={selectedInstruments}
166
  onChange={setSelectedInstruments}
 
 
167
  />
168
 
169
  <div className="form-group">
170
- <label htmlFor="youtube-url">YouTube URL:</label>
171
- <input
172
- id="youtube-url"
173
- type="text"
174
- value={youtubeUrl}
175
- onChange={(e) => setYoutubeUrl(e.target.value)}
176
- placeholder="https://www.youtube.com/watch?v=..."
177
- required
178
- onBlur={() => {
179
- const validation = validateUrl(youtubeUrl);
180
- if (validation) setError(validation);
181
- }}
182
- />
 
 
 
 
 
 
 
 
 
 
183
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  <button type="submit" disabled={status === 'submitting'}>Transcribe</button>
185
  {status === 'submitting' && <div>Submitting...</div>}
186
  {error && <div role="alert" className="error-alert">{error}</div>}
 
4
  import { useState, useRef, useEffect } from 'react';
5
  import { api } from '../api/client';
6
  import type { ProgressUpdate } from '../api/client';
7
+ import { InstrumentSelector, VOCAL_INSTRUMENTS } from './InstrumentSelector';
8
  import './JobSubmission.css';
9
 
10
  interface JobSubmissionProps {
 
14
 
15
  export function JobSubmission({ onComplete, onJobSubmitted }: JobSubmissionProps) {
16
  const [youtubeUrl, setYoutubeUrl] = useState('');
17
+ const [uploadMode, setUploadMode] = useState<'url' | 'file'>('url');
18
+ const [selectedFile, setSelectedFile] = useState<File | null>(null);
19
  const [selectedInstruments, setSelectedInstruments] = useState<string[]>(['piano']);
20
+ const [vocalInstrument, setVocalInstrument] = useState('violin');
21
  const [status, setStatus] = useState<'idle' | 'submitting' | 'processing' | 'failed'>('idle');
22
  const [error, setError] = useState<string | null>(null);
23
  const [progress, setProgress] = useState(0);
 
49
  e.preventDefault();
50
  setError(null);
51
 
52
+ // Validate based on mode
53
+ if (uploadMode === 'url') {
54
+ const validation = validateUrl(youtubeUrl);
55
+ if (validation) {
56
+ setError(validation);
57
+ return;
58
+ }
59
+ } else {
60
+ if (!selectedFile) {
61
+ setError('Please select an audio file');
62
+ return;
63
+ }
64
  }
65
 
66
  // Validate at least one instrument is selected
 
71
 
72
  setStatus('submitting');
73
 
74
+ console.log('[DEBUG] About to submit job with instruments:', selectedInstruments);
75
+
76
+ // Get the MIDI program number for the selected vocal instrument
77
+ const vocalProgram = VOCAL_INSTRUMENTS.find(v => v.id === vocalInstrument)?.program || 40;
78
+
79
  try {
80
+ const response = uploadMode === 'url'
81
+ ? await api.submitJob(youtubeUrl, { instruments: selectedInstruments, vocalInstrument: vocalProgram })
82
+ : await api.submitFileJob(selectedFile!, { instruments: selectedInstruments, vocalInstrument: vocalProgram });
83
+
84
  setYoutubeUrl('');
85
+ setSelectedFile(null);
86
  if (onJobSubmitted) onJobSubmitted(response);
87
 
88
  // Switch to processing status and connect WebSocket
 
183
  <InstrumentSelector
184
  selectedInstruments={selectedInstruments}
185
  onChange={setSelectedInstruments}
186
+ vocalInstrument={vocalInstrument}
187
+ onVocalInstrumentChange={setVocalInstrument}
188
  />
189
 
190
  <div className="form-group">
191
+ <label>Input Method:</label>
192
+ <div className="upload-mode-selector">
193
+ <button
194
+ type="button"
195
+ className={uploadMode === 'url' ? 'active' : ''}
196
+ onClick={() => {
197
+ setUploadMode('url');
198
+ setError(null);
199
+ }}
200
+ >
201
+ YouTube URL
202
+ </button>
203
+ <button
204
+ type="button"
205
+ className={uploadMode === 'file' ? 'active' : ''}
206
+ onClick={() => {
207
+ setUploadMode('file');
208
+ setError(null);
209
+ }}
210
+ >
211
+ Upload Audio File
212
+ </button>
213
+ </div>
214
  </div>
215
+
216
+ {uploadMode === 'url' ? (
217
+ <div className="form-group">
218
+ <label htmlFor="youtube-url">YouTube URL:</label>
219
+ <input
220
+ id="youtube-url"
221
+ type="text"
222
+ value={youtubeUrl}
223
+ onChange={(e) => setYoutubeUrl(e.target.value)}
224
+ placeholder="https://www.youtube.com/watch?v=..."
225
+ required
226
+ onBlur={() => {
227
+ const validation = validateUrl(youtubeUrl);
228
+ if (validation) setError(validation);
229
+ }}
230
+ />
231
+ </div>
232
+ ) : (
233
+ <div className="form-group">
234
+ <label htmlFor="audio-file">Audio File (WAV, MP3, FLAC, etc.):</label>
235
+ <input
236
+ id="audio-file"
237
+ type="file"
238
+ accept=".wav,.mp3,.flac,.ogg,.m4a,.aac"
239
+ onChange={(e) => {
240
+ const file = e.target.files?.[0];
241
+ if (file) {
242
+ const maxSize = 100 * 1024 * 1024; // 100MB
243
+ if (file.size > maxSize) {
244
+ setError('File too large. Maximum size: 100MB');
245
+ setSelectedFile(null);
246
+ } else {
247
+ setSelectedFile(file);
248
+ setError(null);
249
+ }
250
+ }
251
+ }}
252
+ required
253
+ />
254
+ {selectedFile && (
255
+ <p className="file-info">
256
+ Selected: {selectedFile.name} ({(selectedFile.size / 1024 / 1024).toFixed(2)} MB)
257
+ </p>
258
+ )}
259
+ </div>
260
+ )}
261
+
262
  <button type="submit" disabled={status === 'submitting'}>Transcribe</button>
263
  {status === 'submitting' && <div>Submitting...</div>}
264
  {error && <div role="alert" className="error-alert">{error}</div>}