"""Audio processing for pitch shifting and time stretching.""" import subprocess import tempfile import os from typing import Optional, Callable from concurrent.futures import ProcessPoolExecutor, as_completed import numpy as np import soundfile as sf import pyrubberband as pyrb from ..models.session import StemData from ..utils.audio_utils import normalize def get_rubberband_options(stem_type: str) -> list[str]: """ Get rubberband CLI flags optimized for stem type. Args: stem_type: Type of stem ("drums", "bass", or default) Returns: List of CLI flags """ stem_type = stem_type.lower() if "drum" in stem_type or "percussion" in stem_type: return ["--crisp", "6"] # Max transient preservation elif "bass" in stem_type: return ["--crisp", "3", "--fine"] # Precise low-freq handling else: return ["--crisp", "4"] # Default for guitar, synth, piano, etc. def shift_and_stretch_single( audio: np.ndarray, sr: int, semitones: float, tempo_ratio: float, stem_type: str ) -> np.ndarray: """ Single-pass pitch shift + time stretch using rubberband. Args: audio: Audio array sr: Sample rate semitones: Pitch shift amount (positive = up, negative = down) tempo_ratio: Tempo ratio (> 1.0 = faster, < 1.0 = slower) stem_type: Type of stem for optimization Returns: Processed audio array """ # No change needed - return copy if semitones == 0 and tempo_ratio == 1.0: return audio.copy() # If only pitch change, use pyrubberband directly if tempo_ratio == 1.0: return pyrb.pitch_shift(audio, sr, n_steps=semitones) # If only tempo change, use pyrubberband directly if semitones == 0: return pyrb.time_stretch(audio, sr, rate=tempo_ratio) # Both changes - use rubberband CLI for single-pass return _rubberband_cli(audio, sr, semitones, tempo_ratio, stem_type) def _rubberband_cli( audio: np.ndarray, sr: int, semitones: float, tempo_ratio: float, stem_type: str ) -> np.ndarray: """ Use rubberband CLI for combined pitch+tempo processing. """ with tempfile.TemporaryDirectory() as tmpdir: input_path = os.path.join(tmpdir, "input.wav") output_path = os.path.join(tmpdir, "output.wav") # Write input sf.write(input_path, audio, sr) # Build command cmd = ["rubberband"] if semitones != 0: cmd += ["--pitch", str(semitones)] if tempo_ratio != 1.0: cmd += ["--tempo", str(tempo_ratio)] cmd += get_rubberband_options(stem_type) cmd += [input_path, output_path] # Run try: subprocess.run(cmd, check=True, capture_output=True) except subprocess.CalledProcessError as e: # Fall back to two-pass pyrubberband result = pyrb.pitch_shift(audio, sr, n_steps=semitones) result = pyrb.time_stretch(result, sr, rate=tempo_ratio) return result except FileNotFoundError: # rubberband CLI not available, use pyrubberband result = pyrb.pitch_shift(audio, sr, n_steps=semitones) result = pyrb.time_stretch(result, sr, rate=tempo_ratio) return result # Read output result, _ = sf.read(output_path) return result.astype(np.float32) def _process_single_stem(args: tuple) -> tuple[str, np.ndarray]: """Worker function for parallel processing.""" name, audio, sr, semitones, tempo_ratio, stem_type = args result = shift_and_stretch_single(audio, sr, semitones, tempo_ratio, stem_type) return name, result def process_single_stem_standalone( stem_name: str, stem: StemData, semitones: float, tempo_ratio: float ) -> StemData: """ Process a single stem (for use with async processing). Args: stem_name: Name of the stem stem: StemData object semitones: Pitch shift amount tempo_ratio: Tempo ratio Returns: Processed StemData object """ # No change needed if semitones == 0 and tempo_ratio == 1.0: return StemData( name=stem_name, audio=stem.audio.copy(), sample_rate=stem.sample_rate ) # Determine stem type name_lower = stem_name.lower() if "drum" in name_lower or "percussion" in name_lower: stem_type = "drums" elif "bass" in name_lower: stem_type = "bass" else: stem_type = "default" # Process processed_audio = shift_and_stretch_single( stem.audio, stem.sample_rate, semitones, tempo_ratio, stem_type ) return StemData( name=stem_name, audio=processed_audio, sample_rate=stem.sample_rate ) def process_all_stems( stems: dict[str, StemData], semitones: float, tempo_ratio: float, progress_callback: Optional[Callable[[str, str], None]] = None ) -> dict[str, StemData]: """ Process all stems in parallel. Args: stems: Dict of StemData objects semitones: Pitch shift amount tempo_ratio: Tempo ratio (target_bpm / detected_bpm) progress_callback: Optional callback(stem_name, status) Returns: Dict of processed StemData objects """ # No change needed - return copies if semitones == 0 and tempo_ratio == 1.0: return { name: StemData( name=name, audio=stem.audio.copy(), sample_rate=stem.sample_rate ) for name, stem in stems.items() } # Determine stem types def get_stem_type(name: str) -> str: name_lower = name.lower() if "drum" in name_lower or "percussion" in name_lower: return "drums" elif "bass" in name_lower: return "bass" return "default" # Prepare arguments for parallel processing args_list = [ (name, stem.audio, stem.sample_rate, semitones, tempo_ratio, get_stem_type(name)) for name, stem in stems.items() ] results = {} max_workers = min(len(stems), 6) with ProcessPoolExecutor(max_workers=max_workers) as executor: futures = { executor.submit(_process_single_stem, args): args[0] for args in args_list } for future in as_completed(futures): name = futures[future] if progress_callback: progress_callback(name, "processing") try: stem_name, processed_audio = future.result() sr = stems[stem_name].sample_rate results[stem_name] = StemData( name=stem_name, audio=processed_audio, sample_rate=sr ) if progress_callback: progress_callback(stem_name, "done") except Exception as e: # On error, keep original results[name] = stems[name] if progress_callback: progress_callback(name, f"error: {e}") return results def mix_stems(stems: dict[str, np.ndarray], sample_rate: int = 48000) -> np.ndarray: """ Sum all stem arrays, apply mastering, and return final mix. Args: stems: Dict mapping stem names to audio arrays sample_rate: Sample rate in Hz (default 48000) Returns: Mastered mixed audio array """ if not stems: return np.array([], dtype=np.float32) # Determine if any stem is stereo and find max length has_stereo = False max_length = 0 for audio in stems.values(): if audio.ndim == 2: has_stereo = True max_length = max(max_length, audio.shape[0]) else: max_length = max(max_length, len(audio)) # Initialize mixed array (stereo if any input is stereo) if has_stereo: mixed = np.zeros((max_length, 2), dtype=np.float64) else: mixed = np.zeros(max_length, dtype=np.float64) # Sum all stems (pad shorter ones, convert mono to stereo if needed) for audio in stems.values(): # Get length based on array shape length = audio.shape[0] if audio.ndim == 2 else len(audio) if has_stereo: # Convert mono to stereo if needed if audio.ndim == 1: stereo_audio = np.column_stack([audio, audio]) else: stereo_audio = audio if length < max_length: mixed[:length] += stereo_audio else: mixed += stereo_audio else: # All mono if length < max_length: mixed[:length] += audio else: mixed += audio # Convert to float32 mixed = mixed.astype(np.float32) # Apply mastering using Pedalboard try: from pedalboard import Pedalboard, Compressor, Limiter board = Pedalboard([ Compressor( threshold_db=-10, ratio=3, attack_ms=10, release_ms=150 ), Limiter( threshold_db=-1, release_ms=100 ) ]) mastered = board(mixed, sample_rate) return mastered except ImportError: # Fallback to normalize if pedalboard not installed return normalize(mixed, peak=0.95)