File size: 13,184 Bytes
44a2550 a5359f9 44a2550 5ad0f35 44a2550 f09b953 44a2550 75d3906 44a2550 5ad0f35 44a2550 ac5c764 44a2550 75d3906 44a2550 f849d05 44a2550 f849d05 75d3906 f849d05 44a2550 a5359f9 44a2550 a5359f9 44a2550 a5359f9 44a2550 a5359f9 44a2550 a5359f9 44a2550 3c8b7e3 6293e69 44a2550 6293e69 44a2550 6293e69 3c8b7e3 6293e69 44a2550 75d3906 6293e69 44a2550 6293e69 44a2550 6293e69 44a2550 6293e69 44a2550 6293e69 44a2550 6293e69 44a2550 ac5c764 44a2550 ac5c764 44a2550 ac5c764 44a2550 ac5c764 44a2550 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 | """Celery tasks for background job processing."""
import sys
from pathlib import Path

# Ensure the backend directory is on sys.path so the absolute imports below
# resolve even when the Celery worker is launched from another directory.
backend_dir = Path(__file__).parent.resolve()
if str(backend_dir) not in sys.path:
    sys.path.insert(0, str(backend_dir))

from celery import Task
from celery_app import celery_app
from pipeline import TranscriptionPipeline, run_transcription_pipeline
from redis_client import get_redis_client
import json
import os
from datetime import datetime
from app_config import settings
import shutil

# Shared Redis client singleton, reused by every task in this worker process.
redis_client = get_redis_client()
class TranscriptionTask(Task):
    """Celery base task that records and broadcasts job progress."""

    def update_progress(self, job_id: str, progress: int, stage: str, message: str) -> None:
        """
        Persist a progress update in Redis and fan it out to subscribers.

        The update is recorded three ways: the job hash is refreshed (for
        polling clients), the update is appended to a history list (so
        clients that connect late — e.g. in eager mode — can replay it),
        and it is published on the job's pub/sub channel for WebSocket
        clients that are already connected.

        Args:
            job_id: Job identifier
            progress: Progress percentage (0-100)
            stage: Current stage name
            message: Status message
        """
        print(f"[PROGRESS] {progress}% - {stage} - {message}")
        job_key = f"job:{job_id}"
        # Keep the job hash current so polling clients see the latest state.
        redis_client.hset(job_key, mapping={
            "progress": progress,
            "current_stage": stage,
            "status_message": message,
            "updated_at": datetime.utcnow().isoformat(),
        })
        # Serialize the update once; the same payload goes to both the
        # history list and the pub/sub channel.
        payload = json.dumps({
            "type": "progress",
            "job_id": job_id,
            "progress": progress,
            "stage": stage,
            "message": message,
            "timestamp": datetime.utcnow().isoformat(),
        })
        redis_client.rpush(f"job:{job_id}:progress_history", payload)
        subscriber_count = redis_client.publish(f"job:{job_id}:updates", payload)
        if subscriber_count > 0:
            print(f"[PROGRESS] Published to {subscriber_count} subscribers")
        else:
            print("[PROGRESS] Stored in history (no live subscribers)")
@celery_app.task(base=TranscriptionTask, bind=True)
def process_transcription_task(self, job_id: str) -> str:
    """
    Main transcription task.

    Runs the full pipeline for one job: obtain the audio (uploaded file or
    YouTube download), optionally preprocess, separate sources, transcribe
    the selected stems to MIDI, and render MusicXML. State, progress,
    results and errors are all recorded in the Redis hash ``job:{job_id}``
    and published on the ``job:{job_id}:updates`` channel.

    Args:
        job_id: Unique job identifier (suffix of the Redis job hash key)

    Returns:
        Path to generated MusicXML file (as a string)

    Raises:
        ValueError: if the job hash is missing, or has neither
            ``youtube_url`` nor ``upload_path``.
        RuntimeError: if ffmpeg fails to convert an uploaded file to WAV.
        Exception: any other pipeline failure is recorded in Redis and
            re-raised; transient errors (connection/timeout/I-O) are
            retried with exponential backoff instead.
    """
    try:
        # Mark job as started
        redis_client.hset(f"job:{job_id}", mapping={
            "status": "processing",
            "started_at": datetime.utcnow().isoformat(),
        })
        # Get job data
        job_data = redis_client.hgetall(f"job:{job_id}")
        if not job_data:
            raise ValueError(f"Job not found: {job_id}")
        # Check if this is a file upload or YouTube URL job
        # (exactly one of these is expected to be set by the API layer)
        upload_path = job_data.get('upload_path')
        youtube_url = job_data.get('youtube_url')
        # Parse instruments option (defaults to piano only)
        instruments = ['piano']
        vocal_instrument_program = 40  # Default to violin
        if 'options' in job_data:
            try:
                options = json.loads(job_data['options'])
                instruments = options.get('instruments', ['piano'])
                vocal_instrument_program = options.get('vocal_instrument', 40)
            except (json.JSONDecodeError, KeyError):
                # Malformed options payload: fall back to the defaults above
                instruments = ['piano']
                vocal_instrument_program = 40
        # Import shutil and subprocess
        import shutil
        import subprocess
        # Create pipeline
        pipeline = TranscriptionPipeline(
            job_id=job_id,
            youtube_url=youtube_url or "file://uploaded",  # Dummy URL for file uploads
            storage_path=settings.storage_path,
            instruments=instruments
        )
        # Route pipeline progress callbacks through this task's Redis publisher
        pipeline.set_progress_callback(lambda p, s, m: self.update_progress(job_id, p, s, m))
        # Get audio.wav - either from upload or YouTube download
        audio_path = pipeline.temp_dir / "audio.wav"
        if upload_path:
            # File upload - convert to WAV if needed
            upload_file = Path(upload_path)
            if upload_file.suffix.lower() == '.wav':
                shutil.copy(str(upload_file), str(audio_path))
            else:
                # Convert to WAV using ffmpeg (44.1 kHz stereo)
                result = subprocess.run([
                    'ffmpeg', '-i', str(upload_file),
                    '-ar', '44100', '-ac', '2',
                    str(audio_path)
                ], capture_output=True, text=True)
                if result.returncode != 0:
                    raise RuntimeError(f"Audio conversion failed: {result.stderr}")
        elif youtube_url:
            # YouTube download
            pipeline.progress(0, "download", "Starting audio download")
            audio_path = pipeline.download_audio()
        else:
            raise ValueError(f"Job missing both youtube_url and upload_path: {job_id}")
        # From here, both paths converge - process audio.wav the same way
        # Preprocess audio if enabled
        if pipeline.config.enable_audio_preprocessing:
            pipeline.progress(10, "preprocess", "Preprocessing audio")
            audio_path = pipeline.preprocess_audio(audio_path)
        # Source separation
        pipeline.progress(20, "separate", "Starting source separation")
        all_stems = pipeline.separate_sources(audio_path)
        # Select stems to transcribe based on user selection
        stems_to_transcribe = {}
        for instrument in instruments:
            if instrument in all_stems:
                stems_to_transcribe[instrument] = all_stems[instrument]
                print(f" [DEBUG] Will transcribe {instrument} stem")
            else:
                print(f" [WARNING] {instrument} stem not found in separated audio")
        # If no selected stems available, fall back to piano
        # NOTE(review): the final fallback indexes all_stems['other'] directly
        # and would raise KeyError if the separator produced no 'other' stem —
        # confirm separate_sources() always emits one.
        if not stems_to_transcribe:
            print(f" [WARNING] No selected stems found, falling back to piano")
            if 'piano' in all_stems:
                stems_to_transcribe['piano'] = all_stems['piano']
            else:
                stems_to_transcribe['other'] = all_stems['other']
        pipeline.progress(50, "transcribe", f"Transcribing {len(stems_to_transcribe)} instrument(s)")
        # Transcribe stems
        if len(stems_to_transcribe) == 1:
            # Single stem - use original method
            stem_path = list(stems_to_transcribe.values())[0]
            combined_midi = pipeline.transcribe_to_midi(stem_path)
        else:
            # Multiple stems - use new multi-stem method
            combined_midi = pipeline.transcribe_multiple_stems(stems_to_transcribe)
        # Filter MIDI to only include selected instruments
        filtered_midi = pipeline.filter_midi_by_instruments(combined_midi)
        # Remap vocals MIDI program if vocals were selected
        # (program 65 is the marker the pipeline assigns to vocal tracks)
        if 'vocals' in instruments and vocal_instrument_program != 65:
            print(f" [DEBUG] Remapping vocals MIDI program from 65 to {vocal_instrument_program}")
            import pretty_midi
            pm = pretty_midi.PrettyMIDI(str(filtered_midi))
            for inst in pm.instruments:
                if inst.program == 65 and not inst.is_drum:  # Singing Voice
                    inst.program = vocal_instrument_program
                    print(f" [DEBUG] Changed track '{inst.name}' program to {vocal_instrument_program}")
            # Save remapped MIDI (overwrites the filtered file in place)
            pm.write(str(filtered_midi))
        # Apply post-processing
        midi_path = pipeline.apply_post_processing_filters(filtered_midi)
        pipeline.final_midi_path = midi_path
        # Get audio stem for MusicXML generation (use piano if available, otherwise first available stem)
        audio_stem = stems_to_transcribe.get('piano') or list(stems_to_transcribe.values())[0]
        pipeline.progress(90, "musicxml", "Generating MusicXML")
        temp_output_path = pipeline.generate_musicxml_minimal(midi_path, audio_stem)
        pipeline.progress(100, "complete", "Transcription complete")
        # Output is already in the temp directory, move to persistent storage
        output_path = settings.outputs_path / f"{job_id}.musicxml"
        midi_path = settings.outputs_path / f"{job_id}.mid"
        # Ensure outputs directory exists
        settings.outputs_path.mkdir(parents=True, exist_ok=True)
        # Copy the MusicXML file to outputs
        shutil.copy(str(temp_output_path), str(output_path))
        # Copy the MIDI file to outputs (use actual processed MIDI from pipeline)
        # Pipeline stores the final MIDI path (after quantization) in final_midi_path
        temp_midi_path = getattr(pipeline, 'final_midi_path', pipeline.temp_dir / "piano.mid")
        print(f"[DEBUG] Using MIDI from pipeline: {temp_midi_path}")
        print(f"[DEBUG] MIDI exists: {temp_midi_path.exists()}")
        if temp_midi_path.exists():
            print(f"[DEBUG] Copying MIDI from {temp_midi_path} to {midi_path}")
            shutil.copy(str(temp_midi_path), str(midi_path))
            print(f"[DEBUG] Copy complete, destination exists: {midi_path.exists()}")
        else:
            print(f"[DEBUG] WARNING: No MIDI file found at {temp_midi_path}!")
        # Store metadata for API access (defaults used if pipeline set none)
        metadata = getattr(pipeline, 'metadata', {
            "tempo": 120.0,
            "time_signature": {"numerator": 4, "denominator": 4},
            "key_signature": "C",
        })
        # Cleanup temp files (pipeline has its own cleanup method)
        pipeline.cleanup()
        # Mark job as completed
        redis_client.hset(f"job:{job_id}", mapping={
            "status": "completed",
            "progress": 100,
            "output_path": str(output_path.absolute()),
            "midi_path": str(midi_path.absolute()) if temp_midi_path.exists() else "",
            "metadata": json.dumps(metadata),
            "completed_at": datetime.utcnow().isoformat(),
        })
        # Publish completion message
        completion_msg = {
            "type": "completed",
            "job_id": job_id,
            "result_url": f"/api/v1/scores/{job_id}",
            "timestamp": datetime.utcnow().isoformat(),
        }
        redis_client.publish(f"job:{job_id}:updates", json.dumps(completion_msg))
        return str(output_path)
    except Exception as e:
        import traceback
        # Determine if error is retryable (only retry transient errors, not code bugs)
        # NOTE(review): IOError is an alias of OSError, so this also matches
        # subclasses like FileNotFoundError — confirm that breadth is intended.
        RETRYABLE_EXCEPTIONS = (
            ConnectionError,  # Network errors
            TimeoutError,  # Timeout errors
            IOError,  # I/O errors (file system, disk full, etc.)
        )
        is_retryable = isinstance(e, RETRYABLE_EXCEPTIONS) and self.request.retries < self.max_retries
        # Mark job as failed
        redis_client.hset(f"job:{job_id}", mapping={
            "status": "failed",
            "error": json.dumps({
                "message": str(e),
                "type": type(e).__name__,
                "retryable": is_retryable,
                "traceback": traceback.format_exc(),
            }),
            "failed_at": datetime.utcnow().isoformat(),
        })
        # Publish error message
        error_msg = {
            "type": "error",
            "job_id": job_id,
            "error": {
                "message": str(e),
                "type": type(e).__name__,
                "retryable": is_retryable,
            },
            "timestamp": datetime.utcnow().isoformat(),
        }
        redis_client.publish(f"job:{job_id}:updates", json.dumps(error_msg))
        # Only retry if the error is transient (network, I/O, etc.)
        if is_retryable:
            # Exponential backoff: 1s, 2s, 4s, ... between attempts
            print(f"[RETRY] Retrying job {job_id} (attempt {self.request.retries + 1}/{self.max_retries})")
            raise self.retry(exc=e, countdown=2 ** self.request.retries)
        else:
            # Non-retryable error (code bug, validation error, etc.) - fail immediately
            print(f"[ERROR] Non-retryable error for job {job_id}: {type(e).__name__}: {e}")
            raise
# === Module-level helper functions ===
def update_progress(job_id: str, progress: int, stage: str, message: str) -> None:
    """
    Update job progress (wrapper for backward compatibility).

    Delegates to ``TranscriptionTask.update_progress`` so pre-existing
    callers of the module-level function keep working.

    Args:
        job_id: Job identifier
        progress: Progress percentage (0-100)
        stage: Current stage name
        message: Status message
    """
    TranscriptionTask().update_progress(job_id, progress, stage, message)
def cleanup_temp_files(job_id: str, storage_path: "Path | None" = None) -> None:
    """
    Clean up temporary files for a job.

    Removes ``<storage_path>/temp/<job_id>`` if it exists. Deletion is
    best-effort: individual file errors are ignored so cleanup can never
    fail the caller.

    Args:
        job_id: Job identifier
        storage_path: Path to storage directory; defaults to
            ``settings.storage_path`` when omitted.
    """
    # Fix: the original annotated this as `Path = None` (implicit Optional,
    # rejected by type checkers); the string annotation stays version-safe.
    if storage_path is None:
        storage_path = settings.storage_path
    temp_dir = storage_path / "temp" / job_id
    if temp_dir.exists():
        shutil.rmtree(temp_dir, ignore_errors=True)
|