rescored / backend /tasks.py
calebhan's picture
deployment 25
f849d05
"""Celery tasks for background job processing."""
import sys
from pathlib import Path
# Put the directory containing this file at the front of sys.path (once) so
# that the sibling-module imports that follow can be resolved.
backend_dir = Path(__file__).parent.resolve()
if sys.path.count(str(backend_dir)) == 0:
    sys.path.insert(0, str(backend_dir))
from celery import Task
from celery_app import celery_app
from pipeline import TranscriptionPipeline, run_transcription_pipeline
from redis_client import get_redis_client
import json
import os
from datetime import datetime
from app_config import settings
import shutil
# Get shared Redis client singleton.
# This module-level handle is what the progress-reporting method and the
# Celery task below use for all job-hash writes, history pushes and pub/sub.
redis_client = get_redis_client()
class TranscriptionTask(Task):
    """Celery task base class that can report job progress through Redis."""

    def update_progress(self, job_id: str, progress: int, stage: str, message: str) -> None:
        """
        Record a progress update for a job and broadcast it.

        The update is written to the job hash, appended to the job's
        progress-history list, and published on the job's updates channel.

        Args:
            job_id: Job identifier
            progress: Progress percentage (0-100)
            stage: Current stage name
            message: Status message
        """
        print(f"[PROGRESS] {progress}% - {stage} - {message}")

        hash_key = f"job:{job_id}"
        history_key = f"job:{job_id}:progress_history"
        channel = f"job:{job_id}:updates"

        # Latest state lives in the job hash.
        hash_fields = {
            "progress": progress,
            "current_stage": stage,
            "status_message": message,
            "updated_at": datetime.utcnow().isoformat(),
        }
        redis_client.hset(hash_key, mapping=hash_fields)

        # The same event is appended to a history list (so clients that
        # connect late, e.g. in eager mode, can catch up) and published live.
        event = {
            "type": "progress",
            "job_id": job_id,
            "progress": progress,
            "stage": stage,
            "message": message,
            "timestamp": datetime.utcnow().isoformat(),
        }
        payload = json.dumps(event)
        redis_client.rpush(history_key, payload)

        listener_count = redis_client.publish(channel, payload)
        if listener_count <= 0:
            print(f"[PROGRESS] Stored in history (no live subscribers)")
        else:
            print(f"[PROGRESS] Published to {listener_count} subscribers")
def _parse_transcription_options(job_data):
    """
    Extract the instrument selection and vocal MIDI program from a job hash.

    Args:
        job_data: Mapping read from the job's Redis hash; may contain an
            'options' field holding a JSON document.

    Returns:
        Tuple ``(instruments, vocal_instrument_program)``. Falls back to
        ``(['piano'], 40)`` when 'options' is absent, is not valid JSON, or
        does not decode to a JSON object.
    """
    default_program = 40  # Default to violin (per the original comment)

    if 'options' not in job_data:
        return ['piano'], default_program
    try:
        options = json.loads(job_data['options'])
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError covers a None value.
        return ['piano'], default_program
    if not isinstance(options, dict):
        # Valid JSON but not an object (e.g. a list): nothing to read from.
        return ['piano'], default_program

    instruments = options.get('instruments', ['piano'])
    if isinstance(instruments, str):
        # A bare string would otherwise be iterated character by character.
        instruments = [instruments]
    elif not isinstance(instruments, list):
        instruments = ['piano']
    return instruments, options.get('vocal_instrument', default_program)


def _convert_upload_to_wav(upload_file, audio_path):
    """
    Place an uploaded audio file at ``audio_path`` as a WAV file.

    ``.wav`` uploads are copied as-is; anything else is converted with ffmpeg
    (44.1 kHz, 2 channels).

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    import subprocess  # not imported at module level

    if upload_file.suffix.lower() == '.wav':
        shutil.copy(str(upload_file), str(audio_path))
        return

    # '-y' overwrites an existing output file; without it ffmpeg refuses to
    # overwrite a leftover audio.wav (e.g. from an earlier attempt of the same
    # job) and the conversion fails.
    result = subprocess.run([
        'ffmpeg', '-y', '-i', str(upload_file),
        '-ar', '44100', '-ac', '2',
        str(audio_path)
    ], capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Audio conversion failed: {result.stderr}")


def _select_stems(instruments, all_stems):
    """
    Pick the separated stems that match the requested instruments.

    Falls back to the 'piano' stem, then the 'other' stem, when none of the
    requested instruments is present in ``all_stems``.

    Raises:
        ValueError: If no requested stem and neither fallback stem exists.
    """
    selected = {}
    for instrument in instruments:
        if instrument in all_stems:
            selected[instrument] = all_stems[instrument]
            print(f" [DEBUG] Will transcribe {instrument} stem")
        else:
            print(f" [WARNING] {instrument} stem not found in separated audio")

    if not selected:
        print(f" [WARNING] No selected stems found, falling back to piano")
        for fallback in ('piano', 'other'):
            if fallback in all_stems:
                selected[fallback] = all_stems[fallback]
                break
        else:
            raise ValueError(
                f"No usable stem found (requested {list(instruments)}, "
                f"available {sorted(all_stems)})"
            )
    return selected


def _remap_vocal_program(midi_file, vocal_instrument_program):
    """
    Rewrite ``midi_file`` in place, moving non-drum tracks that use MIDI
    program 65 to ``vocal_instrument_program``.

    Program 65 is treated as the vocal track ("Singing Voice" per the
    original comment) -- presumably assigned by the pipeline; verify there.
    """
    print(f" [DEBUG] Remapping vocals MIDI program from 65 to {vocal_instrument_program}")
    import pretty_midi

    pm = pretty_midi.PrettyMIDI(str(midi_file))
    for inst in pm.instruments:
        if inst.program == 65 and not inst.is_drum:  # Singing Voice
            inst.program = vocal_instrument_program
            print(f" [DEBUG] Changed track '{inst.name}' program to {vocal_instrument_program}")
    # Save remapped MIDI
    pm.write(str(midi_file))


@celery_app.task(base=TranscriptionTask, bind=True)
def process_transcription_task(self, job_id: str):
    """
    Main transcription task.

    Reads the job hash from Redis, obtains audio (uploaded file or YouTube
    download), runs separation, transcription and MusicXML generation through
    ``TranscriptionPipeline``, copies the MusicXML and MIDI results to
    ``settings.outputs_path`` and records the outcome in Redis (job hash plus
    a message on the ``job:{job_id}:updates`` channel).

    Args:
        job_id: Unique job identifier

    Returns:
        Path to generated MusicXML file (as a string)

    Raises:
        ValueError: If the job hash is missing or has neither 'upload_path'
            nor 'youtube_url', or if no usable stem is produced.
        RuntimeError: If ffmpeg conversion of an upload fails.
        Exception: Any other error is recorded as a failure and re-raised;
            ConnectionError/TimeoutError/IOError trigger a Celery retry while
            retries remain.
    """
    try:
        job_key = f"job:{job_id}"

        # Read the job BEFORE writing to it: HSET creates the hash when it
        # does not exist, which would make the not-found check unreachable.
        job_data = redis_client.hgetall(job_key)
        if not job_data:
            raise ValueError(f"Job not found: {job_id}")

        # Mark job as started
        redis_client.hset(job_key, mapping={
            "status": "processing",
            "started_at": datetime.utcnow().isoformat(),
        })

        # A job is either a file upload or a YouTube URL job
        upload_path = job_data.get('upload_path')
        youtube_url = job_data.get('youtube_url')
        instruments, vocal_instrument_program = _parse_transcription_options(job_data)

        # Create pipeline
        pipeline = TranscriptionPipeline(
            job_id=job_id,
            youtube_url=youtube_url or "file://uploaded",  # Dummy URL for file uploads
            storage_path=settings.storage_path,
            instruments=instruments
        )
        pipeline.set_progress_callback(lambda p, s, m: self.update_progress(job_id, p, s, m))

        # Get audio.wav - either from upload or YouTube download
        audio_path = pipeline.temp_dir / "audio.wav"
        if upload_path:
            _convert_upload_to_wav(Path(upload_path), audio_path)
        elif youtube_url:
            pipeline.progress(0, "download", "Starting audio download")
            audio_path = pipeline.download_audio()
        else:
            raise ValueError(f"Job missing both youtube_url and upload_path: {job_id}")

        # From here, both paths converge - process audio.wav the same way
        if pipeline.config.enable_audio_preprocessing:
            pipeline.progress(10, "preprocess", "Preprocessing audio")
            audio_path = pipeline.preprocess_audio(audio_path)

        # Source separation
        pipeline.progress(20, "separate", "Starting source separation")
        all_stems = pipeline.separate_sources(audio_path)

        # Select stems to transcribe based on user selection
        stems_to_transcribe = _select_stems(instruments, all_stems)
        pipeline.progress(50, "transcribe", f"Transcribing {len(stems_to_transcribe)} instrument(s)")

        # Transcribe stems: single-stem and multi-stem use different methods
        if len(stems_to_transcribe) == 1:
            stem_path = next(iter(stems_to_transcribe.values()))
            combined_midi = pipeline.transcribe_to_midi(stem_path)
        else:
            combined_midi = pipeline.transcribe_multiple_stems(stems_to_transcribe)

        # Filter MIDI to only include selected instruments
        filtered_midi = pipeline.filter_midi_by_instruments(combined_midi)

        # Remap vocals MIDI program if vocals were selected
        if 'vocals' in instruments and vocal_instrument_program != 65:
            _remap_vocal_program(filtered_midi, vocal_instrument_program)

        # Apply post-processing
        processed_midi = pipeline.apply_post_processing_filters(filtered_midi)
        pipeline.final_midi_path = processed_midi

        # Audio stem for MusicXML generation (piano if available, otherwise
        # the first selected stem)
        audio_stem = stems_to_transcribe.get('piano') or next(iter(stems_to_transcribe.values()))
        pipeline.progress(90, "musicxml", "Generating MusicXML")
        temp_output_path = pipeline.generate_musicxml_minimal(processed_midi, audio_stem)
        pipeline.progress(100, "complete", "Transcription complete")

        # Results currently live in the temp directory; copy them to
        # persistent storage.
        output_path = settings.outputs_path / f"{job_id}.musicxml"
        midi_output_path = settings.outputs_path / f"{job_id}.mid"
        settings.outputs_path.mkdir(parents=True, exist_ok=True)
        shutil.copy(str(temp_output_path), str(output_path))

        # Re-read final_midi_path after MusicXML generation: per the original
        # comment the pipeline stores the final (post-quantization) MIDI there.
        temp_midi_path = getattr(pipeline, 'final_midi_path', pipeline.temp_dir / "piano.mid")
        print(f"[DEBUG] Using MIDI from pipeline: {temp_midi_path}")
        print(f"[DEBUG] MIDI exists: {temp_midi_path.exists()}")
        midi_copied = False
        if temp_midi_path.exists():
            print(f"[DEBUG] Copying MIDI from {temp_midi_path} to {midi_output_path}")
            shutil.copy(str(temp_midi_path), str(midi_output_path))
            midi_copied = True
            print(f"[DEBUG] Copy complete, destination exists: {midi_output_path.exists()}")
        else:
            print(f"[DEBUG] WARNING: No MIDI file found at {temp_midi_path}!")

        # Store metadata for API access
        metadata = getattr(pipeline, 'metadata', {
            "tempo": 120.0,
            "time_signature": {"numerator": 4, "denominator": 4},
            "key_signature": "C",
        })

        # Cleanup temp files (pipeline has its own cleanup method).
        # NOTE: whether the MIDI was exported is captured in `midi_copied`
        # above, because the temp MIDI presumably no longer exists once
        # cleanup has run.
        pipeline.cleanup()

        # Mark job as completed
        redis_client.hset(job_key, mapping={
            "status": "completed",
            "progress": 100,
            "output_path": str(output_path.absolute()),
            "midi_path": str(midi_output_path.absolute()) if midi_copied else "",
            "metadata": json.dumps(metadata),
            "completed_at": datetime.utcnow().isoformat(),
        })

        # Publish completion message
        completion_msg = {
            "type": "completed",
            "job_id": job_id,
            "result_url": f"/api/v1/scores/{job_id}",
            "timestamp": datetime.utcnow().isoformat(),
        }
        redis_client.publish(f"job:{job_id}:updates", json.dumps(completion_msg))
        return str(output_path)
    except Exception as e:
        import traceback

        # Only transient errors are retried, not code bugs.
        # NOTE(review): IOError is an alias of OSError, so this tuple matches
        # every OSError subclass (e.g. FileNotFoundError) - confirm intended.
        RETRYABLE_EXCEPTIONS = (
            ConnectionError,  # Network errors
            TimeoutError,     # Timeout errors
            IOError,          # I/O errors (file system, disk full, etc.)
        )
        is_retryable = isinstance(e, RETRYABLE_EXCEPTIONS) and self.request.retries < self.max_retries

        error_summary = {
            "message": str(e),
            "type": type(e).__name__,
            "retryable": is_retryable,
        }

        # Mark job as failed (the stored record additionally keeps the traceback)
        redis_client.hset(f"job:{job_id}", mapping={
            "status": "failed",
            "error": json.dumps({**error_summary, "traceback": traceback.format_exc()}),
            "failed_at": datetime.utcnow().isoformat(),
        })

        # Publish error message
        error_msg = {
            "type": "error",
            "job_id": job_id,
            "error": error_summary,
            "timestamp": datetime.utcnow().isoformat(),
        }
        redis_client.publish(f"job:{job_id}:updates", json.dumps(error_msg))

        if is_retryable:
            # Exponential backoff: 1s, 2s, 4s, ...
            print(f"[RETRY] Retrying job {job_id} (attempt {self.request.retries + 1}/{self.max_retries})")
            raise self.retry(exc=e, countdown=2 ** self.request.retries)
        else:
            # Non-retryable error (code bug, validation error, etc.) - fail immediately
            print(f"[ERROR] Non-retryable error for job {job_id}: {type(e).__name__}: {e}")
            raise
# === Module-level helper functions ===
def update_progress(job_id: str, progress: int, stage: str, message: str) -> None:
    """
    Module-level shim kept for backward compatibility.

    Creates a throwaway ``TranscriptionTask`` and forwards the call to its
    ``update_progress`` method.

    Args:
        job_id: Job identifier
        progress: Progress percentage (0-100)
        stage: Current stage name
        message: Status message
    """
    TranscriptionTask().update_progress(job_id, progress, stage, message)
def cleanup_temp_files(job_id: str, storage_path: "Path | None" = None) -> None:
    """
    Clean up temporary files for a job.

    Removes ``<storage_path>/temp/<job_id>`` if it exists. Deletion itself is
    best-effort (errors while removing files are ignored).

    Args:
        job_id: Job identifier; must be a single, non-empty path component
        storage_path: Path to storage directory, as a Path or string
            (uses settings if not provided)

    Raises:
        ValueError: If ``job_id`` is empty, is '.' or '..', or contains a
            path separator. Such values would resolve to the shared temp
            root or to a directory outside it instead of one job's folder.
    """
    # pathlib drops empty segments and lets absolute/'..' segments escape the
    # base directory, so validate before building the path handed to rmtree.
    if not job_id or job_id in (".", "..") or Path(job_id).name != job_id:
        raise ValueError(f"Invalid job_id for temp cleanup: {job_id!r}")

    if storage_path is None:
        storage_path = settings.storage_path

    temp_dir = Path(storage_path) / "temp" / job_id
    if temp_dir.exists():
        shutil.rmtree(temp_dir, ignore_errors=True)