| """Audio processing for pitch shifting and time stretching.""" |
|
|
| import subprocess |
| import tempfile |
| import os |
| from typing import Optional, Callable |
| from concurrent.futures import ProcessPoolExecutor, as_completed |
|
|
| import numpy as np |
| import soundfile as sf |
| import pyrubberband as pyrb |
|
|
| from ..models.session import StemData |
| from ..utils.audio_utils import normalize |
|
|
|
|
def get_rubberband_options(stem_type: str) -> list[str]:
    """
    Map a stem type (or stem name) to rubberband CLI flags tuned for it.

    Matching is a case-insensitive substring test, checked in priority
    order: anything mentioning "drum" or "percussion" first, then "bass",
    otherwise a generic setting.

    Args:
        stem_type: Stem type or stem name (e.g. "drums", "bass", "vocals")

    Returns:
        A newly built list of CLI flags (safe for the caller to mutate)
    """
    lowered = stem_type.lower()

    if any(tag in lowered for tag in ("drum", "percussion")):
        flags = ("--crisp", "6")
    elif "bass" in lowered:
        flags = ("--crisp", "3", "--fine")
    else:
        flags = ("--crisp", "4")

    return list(flags)
|
|
|
|
def shift_and_stretch_single(
    audio: np.ndarray,
    sr: int,
    semitones: float,
    tempo_ratio: float,
    stem_type: str
) -> np.ndarray:
    """
    Pitch shift and/or time stretch one audio array with rubberband.

    Every non-identity request (pitch only, tempo only, or both) goes
    through ``_rubberband_cli``, so the stem-specific flags from
    ``get_rubberband_options`` are applied consistently. ``_rubberband_cli``
    itself falls back to plain pyrubberband passes if the CLI rejects
    those flags.

    Args:
        audio: Audio array
        sr: Sample rate
        semitones: Pitch shift amount (positive = up, negative = down)
        tempo_ratio: Tempo ratio (> 1.0 = faster, < 1.0 = slower);
            must be strictly positive
        stem_type: Type of stem for optimization

    Returns:
        Processed audio array, or a copy of ``audio`` when neither pitch
        nor tempo changes

    Raises:
        ValueError: if ``tempo_ratio`` is not strictly positive (or is NaN)
    """
    # `not x > 0` also rejects NaN, which `x <= 0` would let through.
    if not tempo_ratio > 0:
        raise ValueError(
            f"tempo_ratio must be strictly positive, got {tempo_ratio!r}"
        )

    # Identity request: return an independent copy so callers may mutate it.
    if semitones == 0 and tempo_ratio == 1.0:
        return audio.copy()

    # _rubberband_cli only emits --pitch / --tempo when they differ from the
    # identity, so it handles the single-operation cases too.
    return _rubberband_cli(audio, sr, semitones, tempo_ratio, stem_type)
|
|
|
|
def _rubberband_cli(
    audio: np.ndarray,
    sr: int,
    semitones: float,
    tempo_ratio: float,
    stem_type: str
) -> np.ndarray:
    """
    Run the rubberband CLI once for combined pitch + tempo processing.

    The input is written as a 32-bit float WAV so the signal is neither
    quantised to 16-bit PCM (soundfile's default WAV subtype) nor
    hard-clipped above full scale before rubberband processes it.

    If the CLI exits with an error (for example an older rubberband that
    rejects one of the stem-specific flags), this falls back to two plain
    pyrubberband passes: pitch first, then tempo.

    Args:
        audio: Audio array (1-D, or 2-D with frames on axis 0)
        sr: Sample rate
        semitones: Pitch shift amount; ``--pitch`` is only passed when non-zero
        tempo_ratio: Tempo ratio; ``--tempo`` is only passed when != 1.0
        stem_type: Stem type used to pick flags via ``get_rubberband_options``

    Returns:
        float32 array. 2-D input yields 2-D output (even with one channel).

    Raises:
        RuntimeError: if the ``rubberband`` executable cannot be found.
            pyrubberband shells out to the same executable, so there is no
            usable fallback in that case.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        input_path = os.path.join(tmpdir, "input.wav")
        output_path = os.path.join(tmpdir, "output.wav")

        # Float subtype keeps full precision and headroom above 0 dBFS.
        sf.write(input_path, audio, sr, subtype="FLOAT")

        cmd = ["rubberband"]
        if semitones != 0:
            cmd += ["--pitch", str(semitones)]
        if tempo_ratio != 1.0:
            cmd += ["--tempo", str(tempo_ratio)]
        cmd += get_rubberband_options(stem_type)
        cmd += [input_path, output_path]

        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except FileNotFoundError as err:
            raise RuntimeError(
                "rubberband executable not found; install rubberband-cli "
                "to enable pitch/tempo processing"
            ) from err
        except subprocess.CalledProcessError:
            # Retry without the stem-specific flags via pyrubberband.
            result = pyrb.pitch_shift(audio, sr, n_steps=semitones)
            result = pyrb.time_stretch(result, sr, rate=tempo_ratio)
            return np.asarray(result, dtype=np.float32)

        # always_2d keeps a (frames, 1) input from collapsing to 1-D.
        result, _ = sf.read(
            output_path, dtype="float32", always_2d=audio.ndim == 2
        )
        return result
|
|
|
|
def _process_single_stem(args: tuple) -> tuple[str, np.ndarray]:
    """
    Process-pool worker: unpack one job tuple and run the pitch/tempo
    transform on it.

    Args:
        args: (name, audio, sr, semitones, tempo_ratio, stem_type)

    Returns:
        (name, processed_audio)
    """
    stem_label, samples, rate, shift, ratio, kind = args
    return stem_label, shift_and_stretch_single(samples, rate, shift, ratio, kind)
|
|
|
|
def process_single_stem_standalone(
    stem_name: str,
    stem: StemData,
    semitones: float,
    tempo_ratio: float
) -> StemData:
    """
    Pitch shift / time stretch one stem (intended for async callers).

    The stem type passed to the processor is derived from the stem name:
    names containing "drum" or "percussion" map to "drums", names
    containing "bass" map to "bass", and everything else maps to "default".

    Args:
        stem_name: Name of the stem (also used as the result's name)
        stem: StemData object to process
        semitones: Pitch shift amount
        tempo_ratio: Tempo ratio

    Returns:
        A new StemData object with the same sample rate as ``stem``
    """
    if semitones == 0 and tempo_ratio == 1.0:
        # No-op request: hand back an independent copy of the samples.
        audio_out = stem.audio.copy()
    else:
        lowered = stem_name.lower()
        if any(tag in lowered for tag in ("drum", "percussion")):
            kind = "drums"
        elif "bass" in lowered:
            kind = "bass"
        else:
            kind = "default"
        audio_out = shift_and_stretch_single(
            stem.audio, stem.sample_rate, semitones, tempo_ratio, kind
        )

    return StemData(
        name=stem_name,
        audio=audio_out,
        sample_rate=stem.sample_rate
    )
|
|
|
|
def process_all_stems(
    stems: dict[str, StemData],
    semitones: float,
    tempo_ratio: float,
    progress_callback: Optional[Callable[[str, str], None]] = None,
    max_workers: int = 6,
) -> dict[str, StemData]:
    """
    Process all stems in parallel worker processes.

    Args:
        stems: Dict of StemData objects, keyed by stem name
        semitones: Pitch shift amount
        tempo_ratio: Tempo ratio (target_bpm / detected_bpm)
        progress_callback: Optional callback(stem_name, status). It is
            called with "processing" when a stem is submitted to the pool,
            then with "done" or "error: <message>" when that stem's result
            is collected.
        max_workers: Upper bound on worker processes (default 6). The pool
            never uses more workers than there are stems, and at least one.

    Returns:
        Dict of processed StemData objects in the same key order as
        ``stems``. If processing a stem raises, that stem is returned as an
        unprocessed copy (best effort) and the error is reported through
        ``progress_callback``.
    """

    def copy_stem(name: str, stem: StemData) -> StemData:
        # Independent copy so the caller's input is never aliased in the output.
        return StemData(
            name=name,
            audio=stem.audio.copy(),
            sample_rate=stem.sample_rate
        )

    def get_stem_type(name: str) -> str:
        name_lower = name.lower()
        if "drum" in name_lower or "percussion" in name_lower:
            return "drums"
        elif "bass" in name_lower:
            return "bass"
        return "default"

    if semitones == 0 and tempo_ratio == 1.0:
        return {name: copy_stem(name, stem) for name, stem in stems.items()}

    # ProcessPoolExecutor raises ValueError for max_workers=0, so an empty
    # input must short-circuit before the pool is created.
    if not stems:
        return {}

    workers = max(1, min(len(stems), max_workers))
    results: dict[str, StemData] = {}

    with ProcessPoolExecutor(max_workers=workers) as executor:
        futures = {}
        for name, stem in stems.items():
            job = (
                name,
                stem.audio,
                stem.sample_rate,
                semitones,
                tempo_ratio,
                get_stem_type(name),
            )
            futures[executor.submit(_process_single_stem, job)] = name
            # Report when the stem is actually handed to the pool, not after
            # it has already finished.
            if progress_callback:
                progress_callback(name, "processing")

        for future in as_completed(futures):
            name = futures[future]
            try:
                _, processed_audio = future.result()
            except Exception as e:
                # Best effort: keep the session usable with unprocessed audio.
                results[name] = copy_stem(name, stems[name])
                if progress_callback:
                    progress_callback(name, f"error: {e}")
                continue

            results[name] = StemData(
                name=name,
                audio=processed_audio,
                sample_rate=stems[name].sample_rate
            )
            if progress_callback:
                progress_callback(name, "done")

    # as_completed yields in completion order; restore the caller's order.
    return {name: results[name] for name in stems}
|
|
|
|
def mix_stems(stems: dict[str, np.ndarray], sample_rate: int = 48000) -> np.ndarray:
    """
    Sum all stem arrays, apply mastering, and return the final mix.

    Stems may be 1-D (mono) or 2-D with frames on axis 0. If any stem is
    2-D, the mix is built as ``(frames, 2)`` and 1-D stems are copied to
    both channels. Stems shorter than the longest one are zero-padded at
    the end. Summation is done in float64 and then cast to float32.

    Mastering uses pedalboard (compressor + limiter) when it can be
    imported. Otherwise the mix is peak-normalized to 0.95.

    Args:
        stems: Dict mapping stem names to audio arrays
        sample_rate: Sample rate in Hz (default 48000), used by the
            pedalboard effects

    Returns:
        Mastered mixed audio array (an empty float32 array if ``stems`` is
        empty)
    """
    if not stems:
        return np.array([], dtype=np.float32)

    arrays = list(stems.values())
    has_stereo = any(audio.ndim == 2 for audio in arrays)
    max_length = max(audio.shape[0] for audio in arrays)

    mixed = np.zeros(
        (max_length, 2) if has_stereo else max_length, dtype=np.float64
    )
    for audio in arrays:
        if has_stereo and audio.ndim == 1:
            # (frames,) -> (frames, 1), which broadcasts into both channels.
            audio = audio[:, np.newaxis]
        # Adding into the stem's own length leaves the tail zero-padded.
        mixed[:audio.shape[0]] += audio

    mixed = mixed.astype(np.float32)

    try:
        from pedalboard import Pedalboard, Compressor, Limiter
    except ImportError:
        # Fallback mastering when pedalboard is not installed.
        return normalize(mixed, peak=0.95)

    board = Pedalboard([
        Compressor(
            threshold_db=-10,
            ratio=3,
            attack_ms=10,
            release_ms=150
        ),
        Limiter(
            threshold_db=-1,
            release_ms=100
        )
    ])

    if mixed.ndim == 2:
        # pedalboard documents 2-D buffers as (channels, samples). Pass that
        # layout explicitly instead of relying on shape-based layout
        # detection (ambiguous for very short buffers), then restore the
        # (frames, channels) layout for the caller.
        channels_first = np.ascontiguousarray(mixed.T)
        return np.ascontiguousarray(board(channels_first, sample_rate).T)
    return board(mixed, sample_rate)
|
|