| | """Celery tasks for background job processing.""" |
| | import sys |
| | from pathlib import Path |
| |
|
| | |
# Make sibling modules (celery_app, pipeline, ...) importable when this file
# is loaded by a Celery worker from an arbitrary working directory.
backend_dir = Path(__file__).parent.resolve()
_backend_dir_str = str(backend_dir)
if _backend_dir_str not in sys.path:
    sys.path.insert(0, _backend_dir_str)
| |
|
import json
import os
import shutil
from datetime import datetime
from typing import Optional

from celery import Task

from app_config import settings
from celery_app import celery_app
from pipeline import TranscriptionPipeline, run_transcription_pipeline
from redis_client import get_redis_client
| |
|
| | |
| | redis_client = get_redis_client() |
| |
|
| |
|
class TranscriptionTask(Task):
    """Base task with progress tracking."""

    def update_progress(self, job_id: str, progress: int, stage: str, message: str) -> None:
        """
        Record job progress in Redis and notify any live WebSocket subscribers.

        Args:
            job_id: Job identifier
            progress: Progress percentage (0-100)
            stage: Current stage name
            message: Status message
        """
        print(f"[PROGRESS] {progress}% - {stage} - {message}")

        # Persist the latest snapshot on the job's state hash.
        redis_client.hset(f"job:{job_id}", mapping={
            "progress": progress,
            "current_stage": stage,
            "status_message": message,
            "updated_at": datetime.utcnow().isoformat(),
        })

        # Serialize the event once; it is both appended to the replayable
        # history and fanned out over pub/sub.
        event_json = json.dumps({
            "type": "progress",
            "job_id": job_id,
            "progress": progress,
            "stage": stage,
            "message": message,
            "timestamp": datetime.utcnow().isoformat(),
        })
        redis_client.rpush(f"job:{job_id}:progress_history", event_json)

        # publish() returns the number of clients that received the message.
        listener_count = redis_client.publish(f"job:{job_id}:updates", event_json)
        if listener_count > 0:
            print(f"[PROGRESS] Published to {listener_count} subscribers")
        else:
            print(f"[PROGRESS] Stored in history (no live subscribers)")
| |
|
| |
|
@celery_app.task(base=TranscriptionTask, bind=True)
def process_transcription_task(self, job_id: str):
    """
    Main transcription task.

    Drives the full pipeline for one job: load job data from Redis, obtain
    audio (uploaded file or YouTube download), separate sources, transcribe
    the selected stems to MIDI, generate MusicXML, copy results to the
    outputs directory, and record completion (or failure) back in Redis.

    Args:
        job_id: Unique job identifier

    Returns:
        Path to generated MusicXML file

    Raises:
        ValueError: If the job is missing from Redis or has neither an
            upload path nor a YouTube URL.
        RuntimeError: If ffmpeg audio conversion fails.
    """
    try:
        # Mark the job as started before any heavy work begins.
        redis_client.hset(f"job:{job_id}", mapping={
            "status": "processing",
            "started_at": datetime.utcnow().isoformat(),
        })

        # Full job record as written by the API layer.
        job_data = redis_client.hgetall(f"job:{job_id}")

        if not job_data:
            raise ValueError(f"Job not found: {job_id}")

        # Exactly one of these two sources is expected to be present.
        upload_path = job_data.get('upload_path')
        youtube_url = job_data.get('youtube_url')

        # Defaults: piano only; program 40 (violin in General MIDI numbering
        # if 0-based — TODO confirm which convention the pipeline uses).
        instruments = ['piano']
        vocal_instrument_program = 40
        if 'options' in job_data:
            try:
                options = json.loads(job_data['options'])
                instruments = options.get('instruments', ['piano'])
                vocal_instrument_program = options.get('vocal_instrument', 40)
            except (json.JSONDecodeError, KeyError):
                # Malformed options fall back to the defaults above.
                instruments = ['piano']
                vocal_instrument_program = 40

        # NOTE(review): shutil is already imported at module level; this local
        # re-import is redundant (harmless, but could be removed).
        import shutil
        import subprocess

        pipeline = TranscriptionPipeline(
            job_id=job_id,
            youtube_url=youtube_url or "file://uploaded",
            storage_path=settings.storage_path,
            instruments=instruments
        )
        # Route pipeline progress through the Celery task's Redis updater.
        pipeline.set_progress_callback(lambda p, s, m: self.update_progress(job_id, p, s, m))

        # Canonical working-audio location inside the job's temp dir.
        audio_path = pipeline.temp_dir / "audio.wav"

        if upload_path:
            # Uploaded file: use as-is if already WAV, otherwise convert.
            upload_file = Path(upload_path)
            if upload_file.suffix.lower() == '.wav':
                shutil.copy(str(upload_file), str(audio_path))
            else:
                # Convert to 44.1 kHz stereo WAV via ffmpeg.
                result = subprocess.run([
                    'ffmpeg', '-i', str(upload_file),
                    '-ar', '44100', '-ac', '2',
                    str(audio_path)
                ], capture_output=True, text=True)
                if result.returncode != 0:
                    raise RuntimeError(f"Audio conversion failed: {result.stderr}")
        elif youtube_url:
            # No upload: fetch audio from YouTube.
            pipeline.progress(0, "download", "Starting audio download")
            audio_path = pipeline.download_audio()
        else:
            raise ValueError(f"Job missing both youtube_url and upload_path: {job_id}")

        # Optional preprocessing stage, gated by pipeline configuration.
        if pipeline.config.enable_audio_preprocessing:
            pipeline.progress(10, "preprocess", "Preprocessing audio")
            audio_path = pipeline.preprocess_audio(audio_path)

        # Split the mix into per-instrument stems.
        pipeline.progress(20, "separate", "Starting source separation")
        all_stems = pipeline.separate_sources(audio_path)

        # Keep only the stems the user asked for.
        stems_to_transcribe = {}
        for instrument in instruments:
            if instrument in all_stems:
                stems_to_transcribe[instrument] = all_stems[instrument]
                print(f" [DEBUG] Will transcribe {instrument} stem")
            else:
                print(f" [WARNING] {instrument} stem not found in separated audio")

        # Fallback when none of the requested stems exist.
        # NOTE(review): if neither 'piano' nor 'other' is in all_stems this
        # raises KeyError — confirm separate_sources always yields 'other'.
        if not stems_to_transcribe:
            print(f" [WARNING] No selected stems found, falling back to piano")
            if 'piano' in all_stems:
                stems_to_transcribe['piano'] = all_stems['piano']
            else:
                stems_to_transcribe['other'] = all_stems['other']

        pipeline.progress(50, "transcribe", f"Transcribing {len(stems_to_transcribe)} instrument(s)")

        # Single stem takes the simpler path; multiple stems are merged.
        if len(stems_to_transcribe) == 1:
            stem_path = list(stems_to_transcribe.values())[0]
            combined_midi = pipeline.transcribe_to_midi(stem_path)
        else:
            combined_midi = pipeline.transcribe_multiple_stems(stems_to_transcribe)

        # Drop tracks for instruments the user did not select.
        filtered_midi = pipeline.filter_midi_by_instruments(combined_midi)

        # Vocals are transcribed with program 65; remap them to the user's
        # chosen playback instrument unless they asked for 65 itself.
        if 'vocals' in instruments and vocal_instrument_program != 65:
            print(f" [DEBUG] Remapping vocals MIDI program from 65 to {vocal_instrument_program}")
            import pretty_midi
            pm = pretty_midi.PrettyMIDI(str(filtered_midi))
            for inst in pm.instruments:
                if inst.program == 65 and not inst.is_drum:
                    inst.program = vocal_instrument_program
                    print(f" [DEBUG] Changed track '{inst.name}' program to {vocal_instrument_program}")
            # Rewrite the MIDI file in place with the remapped programs.
            pm.write(str(filtered_midi))

        # Final cleanup filters (quantization/velocity/etc. per pipeline).
        midi_path = pipeline.apply_post_processing_filters(filtered_midi)
        pipeline.final_midi_path = midi_path

        # Audio used for alignment: prefer the piano stem, else any stem.
        audio_stem = stems_to_transcribe.get('piano') or list(stems_to_transcribe.values())[0]

        pipeline.progress(90, "musicxml", "Generating MusicXML")
        temp_output_path = pipeline.generate_musicxml_minimal(midi_path, audio_stem)
        pipeline.progress(100, "complete", "Transcription complete")

        # Note: midi_path is rebound here from the temp-dir MIDI to the final
        # outputs-dir destination path.
        output_path = settings.outputs_path / f"{job_id}.musicxml"
        midi_path = settings.outputs_path / f"{job_id}.mid"

        settings.outputs_path.mkdir(parents=True, exist_ok=True)

        # Copy the MusicXML out of the temp dir before cleanup deletes it.
        shutil.copy(str(temp_output_path), str(output_path))

        # final_midi_path was set above; the getattr default is a legacy
        # fallback for older pipeline versions.
        temp_midi_path = getattr(pipeline, 'final_midi_path', pipeline.temp_dir / "piano.mid")
        print(f"[DEBUG] Using MIDI from pipeline: {temp_midi_path}")
        print(f"[DEBUG] MIDI exists: {temp_midi_path.exists()}")

        if temp_midi_path.exists():
            print(f"[DEBUG] Copying MIDI from {temp_midi_path} to {midi_path}")
            shutil.copy(str(temp_midi_path), str(midi_path))
            print(f"[DEBUG] Copy complete, destination exists: {midi_path.exists()}")
        else:
            print(f"[DEBUG] WARNING: No MIDI file found at {temp_midi_path}!")

        # Score metadata for the API response; defaults if the pipeline
        # did not populate it.
        metadata = getattr(pipeline, 'metadata', {
            "tempo": 120.0,
            "time_signature": {"numerator": 4, "denominator": 4},
            "key_signature": "C",
        })

        # Remove the job's temp working directory.
        pipeline.cleanup()

        # Record completion; midi_path is stored empty if no MIDI was copied.
        redis_client.hset(f"job:{job_id}", mapping={
            "status": "completed",
            "progress": 100,
            "output_path": str(output_path.absolute()),
            "midi_path": str(midi_path.absolute()) if temp_midi_path.exists() else "",
            "metadata": json.dumps(metadata),
            "completed_at": datetime.utcnow().isoformat(),
        })

        # Notify live subscribers that results are ready.
        completion_msg = {
            "type": "completed",
            "job_id": job_id,
            "result_url": f"/api/v1/scores/{job_id}",
            "timestamp": datetime.utcnow().isoformat(),
        }
        redis_client.publish(f"job:{job_id}:updates", json.dumps(completion_msg))

        return str(output_path)

    except Exception as e:
        import traceback

        # Transient infrastructure errors are worth retrying; anything else
        # (bad input, pipeline bugs) fails permanently.
        RETRYABLE_EXCEPTIONS = (
            ConnectionError,
            TimeoutError,
            IOError,
        )

        is_retryable = isinstance(e, RETRYABLE_EXCEPTIONS) and self.request.retries < self.max_retries

        # NOTE(review): status is set to "failed" even when a retry is about
        # to be scheduled; the retried run resets it to "processing", but
        # there is a window where clients see "failed" — confirm intended.
        redis_client.hset(f"job:{job_id}", mapping={
            "status": "failed",
            "error": json.dumps({
                "message": str(e),
                "type": type(e).__name__,
                "retryable": is_retryable,
                "traceback": traceback.format_exc(),
            }),
            "failed_at": datetime.utcnow().isoformat(),
        })

        # Publish the error (without the traceback) to live subscribers.
        error_msg = {
            "type": "error",
            "job_id": job_id,
            "error": {
                "message": str(e),
                "type": type(e).__name__,
                "retryable": is_retryable,
            },
            "timestamp": datetime.utcnow().isoformat(),
        }
        redis_client.publish(f"job:{job_id}:updates", json.dumps(error_msg))

        if is_retryable:
            # Exponential backoff: 1s, 2s, 4s, ... between attempts.
            print(f"[RETRY] Retrying job {job_id} (attempt {self.request.retries + 1}/{self.max_retries})")
            raise self.retry(exc=e, countdown=2 ** self.request.retries)
        else:
            # Re-raise so Celery records the task as failed.
            print(f"[ERROR] Non-retryable error for job {job_id}: {type(e).__name__}: {e}")
            raise
| |
|
| |
|
| | |
| |
|
def update_progress(job_id: str, progress: int, stage: str, message: str) -> None:
    """
    Update job progress (wrapper for backward compatibility).

    Args:
        job_id: Job identifier
        progress: Progress percentage (0-100)
        stage: Current stage name
        message: Status message
    """
    # Delegate to the task-level implementation via a throwaway instance.
    TranscriptionTask().update_progress(job_id, progress, stage, message)
| |
|
| |
|
def cleanup_temp_files(job_id: str, storage_path: Optional[Path] = None) -> None:
    """
    Clean up temporary files for a job.

    Removes ``<storage_path>/temp/<job_id>`` and everything inside it.
    Best-effort: deletion errors are silently ignored so cleanup never
    crashes the caller.

    Args:
        job_id: Job identifier
        storage_path: Path to storage directory (uses settings if not provided)
    """
    # Fix: annotation was `Path = None`, which contradicts the None default.
    if storage_path is None:
        storage_path = settings.storage_path

    temp_dir = storage_path / "temp" / job_id
    if temp_dir.exists():
        # ignore_errors=True: partial cleanup is acceptable here.
        shutil.rmtree(temp_dir, ignore_errors=True)
|