Background Job Processing with Celery
Overview
Celery workers handle the heavy audio processing asynchronously, allowing the API to respond immediately while transcription happens in the background.
Why Celery?
- Async Processing: Transcription takes 1-2 minutes—can't block HTTP requests
- Retry Logic: Automatic retries for transient failures
- Priority Queues: Urgent jobs (premium users) get processed first
- Distributed: Scale workers horizontally across machines
- Monitoring: Built-in tools (Flower) for observability
Architecture
graph TB
FastAPI["FastAPI<br/>(submits job)"]
RedisQueue["Redis Queue<br/>(broker)"]
Worker["Celery Worker<br/>(GPU-enabled)<br/>executes transcription pipeline"]
RedisResult["Redis<br/>(result backend + pub/sub for progress)"]
FastAPI -->|"task.delay()"| RedisQueue
RedisQueue -->|"worker pulls task"| Worker
Worker -->|"updates progress"| RedisResult
Celery Configuration
celery_app.py
from celery import Celery
from kombu import Exchange, Queue

# Initialize Celery
# App name "rescored"; the same Redis DB 0 serves as both the message broker
# and the result backend.
celery_app = Celery(
    "rescored",
    broker="redis://localhost:6379/0",
    backend="redis://localhost:6379/0",
)

# Configuration
celery_app.conf.update(
    # JSON-only serialization (no pickle) for task payloads and results.
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="UTC",
    enable_utc=True,
    # Task settings
    task_track_started=True,  # report STARTED state instead of only PENDING/SUCCESS
    task_time_limit=600,  # 10 minutes max per task
    task_soft_time_limit=540,  # Soft limit at 9 minutes
    task_acks_late=True,  # Acknowledge task after completion (safer)
    worker_prefetch_multiplier=1,  # Take 1 task at a time (better for long tasks)
    # Retry settings
    # NOTE(review): in Celery, autoretry_for / retry_kwargs / retry_backoff are
    # per-task decorator options, not documented app-config keys — confirm these
    # "task_*" spellings actually take effect, or move them onto @task(...).
    task_autoretry_for=(Exception,),
    task_retry_kwargs={'max_retries': 3},
    task_retry_backoff=True,  # Exponential backoff: 1s, 2s, 4s
    task_retry_backoff_max=600,
    # Priority queues
    # NOTE(review): Queue(priority=...) maps to RabbitMQ's x-max-priority;
    # priority handling differs on the Redis broker — verify this does what
    # the doc claims with Redis.
    task_queues=(
        Queue('default', Exchange('default'), routing_key='default', priority=5),
        Queue('high_priority', Exchange('high_priority'), routing_key='high_priority', priority=10),
    ),
    task_default_queue='default',
    task_default_routing_key='default',
)
Main Worker Task
tasks.py
from celery import Task
from celery_app import celery_app
from pipeline import AudioPipeline
import redis
import json
from datetime import datetime

# Shared Redis client for job state (hashes under job:{job_id}) and for
# publishing progress events. decode_responses=True makes hgetall/get
# return str instead of bytes, so no manual decoding is needed.
redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
class TranscriptionTask(Task):
    """Celery base task that mirrors job progress into Redis.

    Each progress step is written to the ``job:{job_id}`` hash (for HTTP
    polling) and also published on the ``job:{job_id}:updates`` channel
    (for WebSocket push).
    """

    def update_progress(self, job_id: str, progress: int, stage: str, message: str):
        """Record one progress step for *job_id* and broadcast it."""
        # Persist the latest state so pollers always see current progress.
        redis_client.hset(
            f"job:{job_id}",
            mapping={
                "progress": progress,
                "current_stage": stage,
                "status_message": message,
                "updated_at": datetime.utcnow().isoformat(),
            },
        )
        # Push the same step to any live WebSocket subscribers via pub/sub.
        event = dict(
            type="progress",
            job_id=job_id,
            progress=progress,
            stage=stage,
            message=message,
            timestamp=datetime.utcnow().isoformat(),
        )
        redis_client.publish(f"job:{job_id}:updates", json.dumps(event))
@celery_app.task(base=TranscriptionTask, bind=True)
def process_transcription_task(self, job_id: str):
    """
    Main transcription task.

    Loads the job record from Redis, runs the full pipeline
    (download -> separation -> transcription -> MusicXML), and mirrors every
    status transition into the ``job:{job_id}`` hash plus the
    ``job:{job_id}:updates`` pub/sub channel.

    Args:
        job_id: Unique job identifier. The ``job:{job_id}`` hash must
            already exist with at least a ``video_id`` field.

    Returns:
        Path to generated MusicXML file (as a string).
    """
    try:
        # Mark job as started
        redis_client.hset(f"job:{job_id}", mapping={
            "status": "processing",
            "started_at": datetime.utcnow().isoformat(),
        })
        # Get job data
        job_data = redis_client.hgetall(f"job:{job_id}")
        video_id = job_data['video_id']  # KeyError here falls through to the failure path
        # Initialize pipeline
        # Progress callbacks route to TranscriptionTask.update_progress, which
        # updates the hash and publishes the pub/sub event.
        pipeline = AudioPipeline(
            job_id=job_id,
            progress_callback=lambda p, s, m: self.update_progress(job_id, p, s, m)
        )
        # Run pipeline
        output_path = pipeline.process(video_id)
        # Mark job as completed
        redis_client.hset(f"job:{job_id}", mapping={
            "status": "completed",
            "progress": 100,
            "output_path": str(output_path),
            "completed_at": datetime.utcnow().isoformat(),
        })
        # Publish completion message
        completion_msg = {
            "type": "completed",
            "job_id": job_id,
            "result_url": f"/api/v1/scores/{job_id}",
            "timestamp": datetime.utcnow().isoformat(),
        }
        redis_client.publish(f"job:{job_id}:updates", json.dumps(completion_msg))
        return str(output_path)
    except Exception as e:
        # Mark job as failed
        # NOTE(review): status becomes "failed" even when a retry is scheduled
        # just below, so consumers may briefly see "failed" for a job that
        # later succeeds — confirm this is intended.
        redis_client.hset(f"job:{job_id}", mapping={
            "status": "failed",
            "error": json.dumps({
                "message": str(e),
                "retryable": self.request.retries < self.max_retries,
            }),
            "failed_at": datetime.utcnow().isoformat(),
        })
        # Publish error message
        error_msg = {
            "type": "error",
            "job_id": job_id,
            "error": {
                "message": str(e),
                "retryable": self.request.retries < self.max_retries,
            },
            "timestamp": datetime.utcnow().isoformat(),
        }
        redis_client.publish(f"job:{job_id}:updates", json.dumps(error_msg))
        # Retry if retryable
        # self.retry() raises celery.exceptions.Retry; the countdown doubles
        # per attempt (1s, 2s, 4s with the default max_retries of 3).
        if self.request.retries < self.max_retries:
            raise self.retry(exc=e, countdown=2 ** self.request.retries)
        else:
            raise
Pipeline Implementation
pipeline.py
from pathlib import Path
from typing import Callable
import tempfile
import shutil
class AudioPipeline:
    """End-to-end transcription pipeline: download -> separate -> transcribe -> MusicXML.

    Intermediate artifacts live in a per-job temp directory that is removed
    when processing finishes, whether it succeeded or failed.
    """

    def __init__(self, job_id: str, progress_callback: Callable):
        self.job_id = job_id
        self.progress = progress_callback
        # Scratch space for the downloaded audio, stems and MIDI.
        self.temp_dir = Path(tempfile.mkdtemp(prefix=f"rescored_{job_id}_"))

    def process(self, video_id: str) -> Path:
        """Run every stage for *video_id* and return the MusicXML path.

        Progress is reported through the callback as
        ``(percent, stage, message)``; temp files are always cleaned up.
        """
        try:
            # Stage 1: fetch the source audio from YouTube.
            self.progress(5, "download", "Validating YouTube URL")
            raw_audio = self.download_audio(video_id)

            # Stage 2: split the mix into instrument stems.
            self.progress(20, "separation", "Starting source separation")
            stem_paths = self.separate_stems(raw_audio)

            # Stage 3: convert stems to MIDI.
            self.progress(50, "transcription", "Transcribing audio to MIDI")
            midi_file = self.transcribe_stems(stem_paths)

            # Stage 4: render the MIDI as MusicXML.
            self.progress(90, "musicxml", "Generating MusicXML")
            score_path = self.generate_musicxml(midi_file)

            self.progress(100, "musicxml", "Transcription complete")
            return score_path
        finally:
            # Always reclaim the scratch directory.
            self.cleanup()

    def download_audio(self, video_id: str) -> Path:
        # Implementation from pipeline.md
        pass

    def separate_stems(self, audio_path: Path) -> dict:
        # Implementation from pipeline.md
        pass

    def transcribe_stems(self, stems: dict) -> Path:
        # Implementation from pipeline.md (MVP: only 'other' stem)
        pass

    def generate_musicxml(self, midi_path: Path) -> Path:
        # Implementation from pipeline.md
        pass

    def cleanup(self):
        """Remove the per-job temp directory, if it still exists."""
        if self.temp_dir.exists():
            shutil.rmtree(self.temp_dir)
Worker Startup
Starting Workers Locally
# Start single worker (development)
celery -A tasks worker --loglevel=info
# Start worker with GPU support
celery -A tasks worker --loglevel=info --concurrency=1
# Start with an autoscaling pool: --autoscale=max,min (here up to 10, at least 2 pool processes, scaled up and down based on worker load)
celery -A tasks worker --autoscale=10,2
Starting Workers in Production (Docker)
# Dockerfile.worker
FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
RUN apt-get update && apt-get install -y python3.11 python3-pip ffmpeg
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY . .
CMD ["celery", "-A", "tasks", "worker", "--loglevel=info", "--concurrency=1"]
# docker-compose.yml
worker:
build:
context: ./backend
dockerfile: Dockerfile.worker
command: celery -A tasks worker --loglevel=info --concurrency=1
environment:
- CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/0
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
Priority Queues
Assigning Jobs to Queues
# High priority (premium users, future)
process_transcription_task.apply_async(
args=[job_id],
queue='high_priority'
)
# Normal priority
process_transcription_task.delay(job_id) # Uses default queue
Worker Queue Assignment
# Worker 1: Process both queues (high priority first)
celery -A tasks worker -Q high_priority,default
# Worker 2: Only high priority
celery -A tasks worker -Q high_priority
# Worker 3: Only default
celery -A tasks worker -Q default
Monitoring with Flower
Installation
pip install flower
Start Flower
celery -A tasks flower --port=5555
Dashboard: http://localhost:5555
Features:
- Active/queued/completed tasks
- Worker status and resource usage
- Task runtime distribution
- Real-time task monitoring
Error Handling & Retries
Retry Strategy
# Automatic retry with exponential backoff
@celery_app.task(
    autoretry_for=(NetworkError, GPUOutOfMemoryError),  # app-defined exception classes
    retry_kwargs={'max_retries': 3},
    retry_backoff=True,  # 1s, 2s, 4s
)
def process_transcription_task(job_id: str):
    """Transcription entry point; body elided — see tasks.py."""
    ...  # fix: a bare `# ...` comment is not a valid function body (IndentationError)
Handling Specific Errors
# Fragment: error handling inside a pipeline stage.
try:
    result = demucs_inference(audio)
except torch.cuda.OutOfMemoryError:
    # Retry with CPU
    # Slower, but keeps the job alive when GPU memory is exhausted.
    result = demucs_inference(audio, device='cpu')
except yt_dlp.utils.DownloadError as e:
    if "age-restricted" in str(e):
        # Don't retry, permanent failure
        # NOTE(review): NonRetryableError is app-defined — make sure it is
        # excluded from any autoretry_for tuple so it is never retried.
        raise NonRetryableError("Age-restricted video") from e
    else:
        # Retry (network issue)
        # Bare re-raise preserves the original traceback so a retry policy
        # configured for DownloadError can pick it up.
        raise
Resource Management
GPU Memory Management
import torch
def process_with_gpu_cleanup(audio_path):
    """Run Demucs on *audio_path*, always releasing cached GPU memory after.

    torch.cuda.empty_cache() releases unoccupied blocks held by PyTorch's
    caching allocator (it does not free tensors that are still referenced),
    so the next task starts with as much free GPU memory as possible.
    """
    try:
        # Load model and process
        result = run_demucs(audio_path)
        return result
    finally:
        # Free GPU memory
        # Runs on success *and* on failure, so an OOM/crash mid-inference
        # does not leave the cache bloated for the next job.
        torch.cuda.empty_cache()
Concurrency Limits
Single GPU: Run 1 worker at a time (concurrency=1)
- Demucs needs full GPU memory
- Multiple workers would OOM
Multiple GPUs: Run 1 worker per GPU
# Worker 1 on GPU 0
CUDA_VISIBLE_DEVICES=0 celery -A tasks worker --concurrency=1
# Worker 2 on GPU 1
CUDA_VISIBLE_DEVICES=1 celery -A tasks worker --concurrency=1
Task Routing
Custom Routing
# Route different task types to different queues
celery_app.conf.task_routes = {
'tasks.process_transcription_task': {'queue': 'transcription'},
'tasks.generate_pdf_task': {'queue': 'pdf_export'},
}
Testing Workers
Unit Test
import pytest
from tasks import process_transcription_task

@pytest.mark.celery(result_backend='redis://localhost:6379/1')
def test_transcription_task():
    """Round-trip a job through a worker (requires Redis; uses DB 1 for results)."""
    result = process_transcription_task.delay("test_job_id")
    # NOTE(review): the doc states transcription takes 1-2 minutes, so
    # timeout=10 only passes if the task short-circuits (e.g. cached
    # result) — confirm the intended fixture behavior.
    assert result.get(timeout=10) is not None
Manual Test
# In Python shell
from tasks import process_transcription_task
# Synchronous execution (for debugging)
result = process_transcription_task("test_job_id")
print(result)
Production Optimizations
Model Preloading
Load ML models once on worker startup (not per task):
# tasks.py
from celery.signals import worker_process_init

@worker_process_init.connect
def init_worker(**kwargs):
    """Initialize models on worker startup.

    Connected to the worker_process_init signal, so this runs once per
    worker process; every task then reuses the pre-loaded models instead
    of paying the model-load cost per job.
    """
    global demucs_model, basicpitch_model
    demucs_model = load_demucs_model()
    basicpitch_model = load_basicpitch_model()

@celery_app.task
def process_transcription_task(job_id: str):
    # Use pre-loaded models
    # NOTE(review): `audio` is not defined in this fragment — illustrative only.
    stems = demucs_model.separate(audio)
    # ...
Benefit: Saves 5-10 seconds per job (no model re-loading)
Result Expiration
Don't keep results forever in Redis:
celery_app.conf.result_expires = 3600 # 1 hour
Task Deduplication
Skip re-processing a job whose result is already cached (note: to deduplicate across different jobs for the same video, key the check on the video ID rather than the job ID):
@celery_app.task(bind=True)
def process_transcription_task(self, job_id: str):
    """Skip the pipeline when a cached result for this job already exists."""
    # Check if job already processed
    # NOTE(review): keyed by job_id, so this only short-circuits re-runs of
    # the *same* job — two jobs for the same video have distinct job_ids and
    # both get processed. Key on the video ID to dedupe per-video.
    if redis_client.exists(f"result:{job_id}"):
        return redis_client.get(f"result:{job_id}")
    # ... process
Scaling Strategies
Horizontal Scaling
Add more workers:
# Start 3 workers on different machines
# Machine 1
celery -A tasks worker --hostname=worker1@%h
# Machine 2
celery -A tasks worker --hostname=worker2@%h
# Machine 3
celery -A tasks worker --hostname=worker3@%h
Auto-Scaling
Cloud-based (AWS ECS, K8s):
- Scale workers based on queue depth
- If queue > 10 jobs, add worker
- If queue < 2 jobs, remove worker
Serverless GPU (Modal, RunPod):
- Workers auto-scale from 0 to N
- Pay only when processing
Next Steps
- Implement complete Audio Processing Pipeline
- Test worker with sample YouTube videos
- Monitor GPU utilization and optimize concurrency
- Set up Flower for production monitoring
See Deployment Strategy for production worker deployment.