# jam-tracks/backend/services/audio_processor.py
# Author: Mina Emadi
# Commit 184639f: add post-processing and increase the number of concurrent processes
"""Audio processing for pitch shifting and time stretching."""
import subprocess
import tempfile
import os
from typing import Optional, Callable
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np
import soundfile as sf
import pyrubberband as pyrb
from ..models.session import StemData
from ..utils.audio_utils import normalize
def get_rubberband_options(stem_type: str) -> list[str]:
    """
    Return rubberband CLI flags tuned for a given stem type.

    Args:
        stem_type: Type of stem ("drums", "bass", or default)

    Returns:
        List of CLI flags
    """
    key = stem_type.lower()
    if "drum" in key or "percussion" in key:
        # Maximum transient preservation for percussive material
        flags = ["--crisp", "6"]
    elif "bass" in key:
        # Precise handling of low-frequency content
        flags = ["--crisp", "3", "--fine"]
    else:
        # Sensible default for guitar, synth, piano, etc.
        flags = ["--crisp", "4"]
    return flags
def shift_and_stretch_single(
    audio: np.ndarray,
    sr: int,
    semitones: float,
    tempo_ratio: float,
    stem_type: str
) -> np.ndarray:
    """
    Single-pass pitch shift + time stretch using rubberband.

    Args:
        audio: Audio array
        sr: Sample rate
        semitones: Pitch shift amount (positive = up, negative = down)
        tempo_ratio: Tempo ratio (> 1.0 = faster, < 1.0 = slower)
        stem_type: Type of stem for optimization

    Returns:
        Processed audio array
    """
    pitch_unchanged = semitones == 0
    tempo_unchanged = tempo_ratio == 1.0

    # Nothing to do: hand back an independent copy.
    if pitch_unchanged and tempo_unchanged:
        return audio.copy()

    # Pitch-only change: pyrubberband handles it directly.
    if tempo_unchanged:
        return pyrb.pitch_shift(audio, sr, n_steps=semitones)

    # Tempo-only change: pyrubberband handles it directly.
    if pitch_unchanged:
        return pyrb.time_stretch(audio, sr, rate=tempo_ratio)

    # Both change: single CLI pass avoids two-pass quality loss.
    return _rubberband_cli(audio, sr, semitones, tempo_ratio, stem_type)
def _rubberband_cli(
    audio: np.ndarray,
    sr: int,
    semitones: float,
    tempo_ratio: float,
    stem_type: str
) -> np.ndarray:
    """
    Run the rubberband CLI for combined pitch + tempo processing.

    A single CLI pass processes both transformations at once. If the CLI
    binary is missing or exits non-zero, fall back to two-pass
    pyrubberband processing (pitch shift, then time stretch).

    Args:
        audio: Audio array.
        sr: Sample rate in Hz.
        semitones: Pitch shift amount (positive = up, negative = down).
        tempo_ratio: Tempo ratio (> 1.0 = faster, < 1.0 = slower).
        stem_type: Stem type used to pick CLI quality flags.

    Returns:
        Processed audio: float32 on the CLI path; on the fallback path,
        whatever dtype pyrubberband produces.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        input_path = os.path.join(tmpdir, "input.wav")
        output_path = os.path.join(tmpdir, "output.wav")

        # Write input for the external process to consume.
        sf.write(input_path, audio, sr)

        # Build command; flags are conditional so a neutral value is omitted.
        cmd = ["rubberband"]
        if semitones != 0:
            cmd += ["--pitch", str(semitones)]
        if tempo_ratio != 1.0:
            cmd += ["--tempo", str(tempo_ratio)]
        cmd += get_rubberband_options(stem_type)
        cmd += [input_path, output_path]

        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except (subprocess.CalledProcessError, FileNotFoundError):
            # CLI failed or is not installed: identical fallback either way,
            # so handle both exceptions in one branch (the original
            # duplicated this fallback body in two handlers).
            result = pyrb.pitch_shift(audio, sr, n_steps=semitones)
            return pyrb.time_stretch(result, sr, rate=tempo_ratio)

        # Read output; soundfile defaults to float64, so normalize to float32.
        result, _ = sf.read(output_path)
        return result.astype(np.float32)
def _process_single_stem(args: tuple) -> tuple[str, np.ndarray]:
    """Unpack one work-item tuple and process it (runs in a worker process)."""
    stem_name, audio, sr, semitones, tempo_ratio, stem_type = args
    processed = shift_and_stretch_single(
        audio, sr, semitones, tempo_ratio, stem_type
    )
    return stem_name, processed
def process_single_stem_standalone(
    stem_name: str,
    stem: StemData,
    semitones: float,
    tempo_ratio: float
) -> StemData:
    """
    Process a single stem (for use with async processing).

    Args:
        stem_name: Name of the stem
        stem: StemData object
        semitones: Pitch shift amount
        tempo_ratio: Tempo ratio

    Returns:
        Processed StemData object
    """
    # Short-circuit when neither pitch nor tempo changes: copy only.
    if semitones == 0 and tempo_ratio == 1.0:
        return StemData(
            name=stem_name,
            audio=stem.audio.copy(),
            sample_rate=stem.sample_rate
        )

    # Map the stem name onto a rubberband tuning category.
    lowered = stem_name.lower()
    if "drum" in lowered or "percussion" in lowered:
        category = "drums"
    elif "bass" in lowered:
        category = "bass"
    else:
        category = "default"

    processed = shift_and_stretch_single(
        stem.audio, stem.sample_rate, semitones, tempo_ratio, category
    )
    return StemData(
        name=stem_name,
        audio=processed,
        sample_rate=stem.sample_rate
    )
def process_all_stems(
    stems: dict[str, StemData],
    semitones: float,
    tempo_ratio: float,
    progress_callback: Optional[Callable[[str, str], None]] = None
) -> dict[str, StemData]:
    """
    Process all stems in parallel.

    Args:
        stems: Dict of StemData objects keyed by stem name
        semitones: Pitch shift amount
        tempo_ratio: Tempo ratio (target_bpm / detected_bpm)
        progress_callback: Optional callback(stem_name, status); status is
            "processing", "done", or "error: <message>"

    Returns:
        Dict of processed StemData objects. On a per-stem failure the
        original stem is kept so playback can continue.
    """
    # Guard: with no stems, min(len(stems), 6) is 0 and
    # ProcessPoolExecutor(max_workers=0) raises ValueError.
    if not stems:
        return {}

    # No change needed - return copies
    if semitones == 0 and tempo_ratio == 1.0:
        return {
            name: StemData(
                name=name,
                audio=stem.audio.copy(),
                sample_rate=stem.sample_rate
            )
            for name, stem in stems.items()
        }

    def get_stem_type(name: str) -> str:
        """Map a stem name onto a rubberband tuning category."""
        name_lower = name.lower()
        if "drum" in name_lower or "percussion" in name_lower:
            return "drums"
        elif "bass" in name_lower:
            return "bass"
        return "default"

    # One picklable argument tuple per stem for the worker function.
    args_list = [
        (name, stem.audio, stem.sample_rate, semitones, tempo_ratio, get_stem_type(name))
        for name, stem in stems.items()
    ]

    results = {}
    # Cap workers at 6: rubberband is CPU-heavy, avoid oversubscription.
    max_workers = min(len(stems), 6)
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(_process_single_stem, args): args[0]
            for args in args_list
        }
        for future in as_completed(futures):
            name = futures[future]
            if progress_callback:
                progress_callback(name, "processing")
            try:
                stem_name, processed_audio = future.result()
                sr = stems[stem_name].sample_rate
                results[stem_name] = StemData(
                    name=stem_name,
                    audio=processed_audio,
                    sample_rate=sr
                )
                if progress_callback:
                    progress_callback(stem_name, "done")
            except Exception as e:
                # On worker error, keep the original stem untouched.
                results[name] = stems[name]
                if progress_callback:
                    progress_callback(name, f"error: {e}")
    return results
def mix_stems(stems: dict[str, np.ndarray], sample_rate: int = 48000) -> np.ndarray:
    """
    Sum all stem arrays, apply mastering, and return final mix.

    Args:
        stems: Dict mapping stem names to audio arrays (mono 1-D or
            stereo (n, 2); lengths may differ)
        sample_rate: Sample rate in Hz (default 48000)

    Returns:
        Mastered mixed audio array (stereo if any input was stereo)
    """
    if not stems:
        return np.array([], dtype=np.float32)

    # Output shape: stereo if any stem is stereo, as long as the longest.
    # (For 1-D arrays len(a) == a.shape[0], so shape[0] covers both.)
    tracks = list(stems.values())
    max_length = max(a.shape[0] for a in tracks)
    has_stereo = any(a.ndim == 2 for a in tracks)

    # Accumulate in float64 to avoid precision loss while summing.
    shape = (max_length, 2) if has_stereo else (max_length,)
    mixed = np.zeros(shape, dtype=np.float64)

    for audio in tracks:
        if has_stereo and audio.ndim == 1:
            # Duplicate the mono channel into both stereo channels.
            audio = np.column_stack([audio, audio])
        # Shorter stems are implicitly zero-padded by the partial slice.
        mixed[:audio.shape[0]] += audio

    mixed = mixed.astype(np.float32)

    # Mastering chain: gentle glue compression, then a safety limiter.
    try:
        from pedalboard import Pedalboard, Compressor, Limiter

        board = Pedalboard([
            Compressor(
                threshold_db=-10,
                ratio=3,
                attack_ms=10,
                release_ms=150
            ),
            Limiter(
                threshold_db=-1,
                release_ms=100
            )
        ])
        return board(mixed, sample_rate)
    except ImportError:
        # Pedalboard not installed: plain peak normalization instead.
        return normalize(mixed, peak=0.95)