Spaces:
Sleeping
Sleeping
| """Beat-align and tempo-conform a generated WAV to a target BPM and bar count. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| import librosa | |
| import numpy as np | |
| import soundfile as sf | |
# Module-level logger shared by every function in this file.
logger: logging.Logger = logging.getLogger(__name__)

# Accepted bounds for the time-stretch rate (target_bpm / interpreted_bpm)
# that is handed to librosa's phase vocoder: a rate above 1 speeds the audio
# up, below 1 slows it down. The window was widened from the earlier
# [0.7, 1.4] so the warp actually gets applied more often — across this span
# the vocoder still sounds acceptable on music, while skipping the warp makes
# the loop drift off the grid.
_STRETCH_SAFE_MIN, _STRETCH_SAFE_MAX = 0.6, 1.7
def align_to_grid(
    input_path: Path,
    target_bpm: float,
    target_bars: int,
    beats_per_bar: int = 4,
) -> Path:
    """Beat-align, tempo-conform and length-fit an audio file in place.

    Steps:

    1. Down-mix to mono and run ``_detect_grid_anchor`` with ``target_bpm``
       as the tracker's starting tempo, biasing it away from half-time /
       double-time readings of the same grid.
    2. Trim leading audio to the first tracked beat when that beat lies
       strictly between 0 and 1.5 s. If beat tracking returned no grid at
       all, trim to the first onset from ``_detect_first_onset_sample``
       instead. A first beat at 0 or at/after 1.5 s leaves the head as is.
    3. Time-stretch by the rate chosen by ``_best_stretch_rate`` (which may
       reinterpret the detected tempo at half or double speed). The warp is
       skipped when no tempo was detected, when no safe rate exists, or when
       the rate is within 1e-3 of 1.0.
    4. Truncate or zero-pad to exactly ``target_bars * beats_per_bar`` beats
       at ``target_bpm``, fading out the last 8 ms of content whenever the
       length changes, then clamp samples to [-1, 1].

    The result overwrites ``input_path`` as 16-bit PCM at the file's
    original sample rate.

    Args:
        input_path: Audio file to read and then overwrite.
        target_bpm: Desired tempo in beats per minute; must be positive.
        target_bars: Number of bars the output must span; must be positive.
        beats_per_bar: Beats in one bar; must be positive.

    Returns:
        ``input_path`` (the same path, now rewritten).

    Raises:
        ValueError: If ``target_bpm``, ``target_bars`` or ``beats_per_bar``
            is not positive.
    """
    # `not (x > 0)` also rejects NaN, which would otherwise fail obscurely
    # inside int(round(...)) below.
    if not (target_bpm > 0):
        raise ValueError(f"target_bpm must be positive, got {target_bpm!r}")
    if target_bars <= 0:
        raise ValueError(f"target_bars must be positive, got {target_bars!r}")
    if beats_per_bar <= 0:
        raise ValueError(f"beats_per_bar must be positive, got {beats_per_bar!r}")

    audio, sr = sf.read(str(input_path), always_2d=True)
    audio = audio.astype(np.float32, copy=False)
    # Duration of target_bars bars at target_bpm, in sample frames.
    target_samples = int(round(target_bars * beats_per_bar * 60.0 / target_bpm * sr))
    # Analysis runs on a mono down-mix; trimming and stretching act on all channels.
    mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]

    # Pass target_bpm as a prior to librosa — biases the beat tracker away
    # from half-time / double-time interpretations of the same grid.
    detected_bpm, first_beat = _detect_grid_anchor(mono, sr, start_bpm=target_bpm)

    head_offset = 0
    if first_beat is not None and 0 < first_beat < sr * 1.5:
        head_offset = first_beat
        logger.info("align_to_grid: trimmed %.1f ms to first beat", head_offset / sr * 1000)
    elif first_beat is None:
        head_offset = _detect_first_onset_sample(mono, sr)
        if head_offset > 0:
            logger.info(
                "align_to_grid: trimmed %.1f ms (onset fallback)", head_offset / sr * 1000
            )
    if head_offset > 0:
        audio = audio[head_offset:]

    if detected_bpm is None:
        logger.info("align_to_grid: no usable tempo detected; skipping warp")
    else:
        rate, effective_bpm = _best_stretch_rate(detected_bpm, target_bpm)
        if rate is None:
            logger.info(
                "align_to_grid: detected %.2f BPM has no safe "
                "interpretation vs target %.2f; skipping warp",
                detected_bpm,
                target_bpm,
            )
        else:
            # Don't run the phase vocoder for an effectively-unity rate.
            if abs(rate - 1.0) > 1e-3:
                audio = _time_stretch_multichannel(audio, rate)
            interp_note = (
                f" (interpreted as {effective_bpm:.2f} BPM, "
                f"octave={effective_bpm / detected_bpm:.2f}×)"
                if abs(effective_bpm - detected_bpm) > 1e-2
                else ""
            )
            logger.info(
                "align_to_grid: detected %.2f BPM%s, stretched by %.4f to match target %.2f BPM",
                detected_bpm,
                interp_note,
                rate,
                target_bpm,
            )

    audio = _fit_to_length(audio, target_samples, sr)
    # Clamp explicitly rather than relying on the writer's float -> PCM_16
    # conversion to handle overs (the phase vocoder can push peaks past 1.0).
    audio = np.clip(audio, -1.0, 1.0)
    sf.write(str(input_path), audio, sr, subtype="PCM_16")
    return input_path


def _fit_to_length(audio: np.ndarray, target_samples: int, sr: int) -> np.ndarray:
    """Truncate or zero-pad 2-D (frames, channels) ``audio`` to ``target_samples`` frames.

    Whenever the length changes, the last 8 ms of remaining content is faded
    out. On truncation this prevents a click at the loop boundary when the cut
    lands mid-waveform. On padding it prevents the same click where content
    abruptly meets the appended silence. Audio already of the right length is
    returned unchanged.
    """
    n_frames = audio.shape[0]
    if n_frames == target_samples:
        return audio
    if n_frames > target_samples:
        audio = audio[:target_samples]
    fade_samples = min(int(0.008 * sr), audio.shape[0])
    if fade_samples > 1:
        fade = np.linspace(1.0, 0.0, fade_samples, dtype=audio.dtype)
        audio[-fade_samples:] *= fade[:, np.newaxis]
    if audio.shape[0] < target_samples:
        pad = np.zeros((target_samples - audio.shape[0], audio.shape[1]), dtype=audio.dtype)
        audio = np.concatenate([audio, pad], axis=0)
    return audio
def _best_stretch_rate(
    detected_bpm: float,
    target_bpm: float,
    *,
    safe_min: Optional[float] = None,
    safe_max: Optional[float] = None,
) -> Tuple[Optional[float], float]:
    """Pick the time-stretch rate that maps the detected tempo onto the target.

    Half-time and double-time interpretations of the detected tempo are
    considered as a fallback.

    Returns ``(rate, effective_bpm)``:

    - ``rate`` is ``target_bpm / effective_bpm``, the value handed to the
      phase vocoder (> 1 speeds up).
    - ``effective_bpm`` is the possibly octave-corrected interpretation that
      was chosen.

    It returns ``(None, detected_bpm)`` if nothing safe is available. This
    includes a zero, negative or NaN ``detected_bpm``.

    Order of preference:
      1. Detected as-is, if it lands inside the safe stretch range.
      2. Octave-corrected (detected × 0.5 or × 2.0), only when the as-is
         interpretation is out of range — the librosa half-/double-time
         error recovery path. The qualifying candidate whose rate is
         closest to a no-op stretch (1.0) is taken.

    This biases the algorithm toward honesty: only re-interpret the
    detector's reading when it can't otherwise produce a usable stretch.

    Args:
        detected_bpm: Tempo reported by the beat tracker.
        target_bpm: Tempo the audio should end up at.
        safe_min: Lowest acceptable rate; defaults to ``_STRETCH_SAFE_MIN``.
        safe_max: Highest acceptable rate; defaults to ``_STRETCH_SAFE_MAX``.
    """
    lo = _STRETCH_SAFE_MIN if safe_min is None else safe_min
    hi = _STRETCH_SAFE_MAX if safe_max is None else safe_max

    # A zero tempo would raise ZeroDivisionError below, and negative or NaN
    # tempos can never yield a usable rate. `not (x > 0)` catches all three.
    if not (detected_bpm > 0):
        return None, detected_bpm

    # First, try the detector's reading at face value.
    rate_asis = target_bpm / detected_bpm
    if lo <= rate_asis <= hi:
        return rate_asis, detected_bpm

    # As-is is out of safe range — almost certainly a librosa octave error.
    # Try the half-time and double-time reinterpretations and pick whichever
    # is closest to a no-op stretch.
    candidates = []
    for octave_factor in (0.5, 2.0):
        interpreted = detected_bpm * octave_factor
        rate = target_bpm / interpreted
        if lo <= rate <= hi:
            candidates.append((abs(rate - 1.0), rate, interpreted))
    if not candidates:
        return None, detected_bpm
    _, best_rate, best_interp = min(candidates)
    return best_rate, best_interp
def _detect_first_onset_sample(mono: np.ndarray, sr: int) -> int:
    """Locate the first onset of a mono signal, as a sample index.

    This is the head-trim fallback align_to_grid uses when beat tracking
    returns no grid. Onsets are requested with ``backtrack=True``, so librosa
    rolls each one back to the preceding energy minimum before the attack.

    Returns:
        The first onset's sample index. Returns 0 in three cases:
        onset detection raises (logged as a warning), no onsets are found,
        or the first onset lies more than one second (``sr`` samples) into
        the signal.
    """
    try:
        onset_samples = librosa.onset.onset_detect(
            y=mono, sr=sr, units="samples", backtrack=True
        )
    except Exception as exc:
        # Best effort: a failing detector simply means "don't trim".
        logger.warning(f"onset detection failed: {exc}")
        return 0

    if onset_samples is None or len(onset_samples) == 0:
        return 0
    earliest = int(onset_samples[0])
    # A first onset later than one second is not used for trimming.
    return 0 if earliest > sr * 1.0 else earliest
def _detect_grid_anchor(
    mono: np.ndarray,
    sr: int,
    start_bpm: Optional[float] = None,
) -> Tuple[Optional[float], Optional[int]]:
    """Estimate the tempo and first beat position of a mono signal.

    librosa's beat tracker is run with beat positions reported in samples.
    A positive ``start_bpm`` is forwarded as the tracker's starting tempo.
    That reduces (but doesn't eliminate) half-time / double-time errors; the
    octave correction in _best_stretch_rate handles whatever librosa still
    gets wrong.

    Returns:
        ``(bpm, first_beat_sample)``. Returns ``(None, None)`` when tracking
        raises (logged as a warning), yields fewer than four beats, or
        reports a tempo outside [40, 240] BPM (NaN included).
    """
    try:
        tracker_args = {"y": mono, "sr": sr, "units": "samples"}
        use_prior = start_bpm is not None and start_bpm > 0
        if use_prior:
            tracker_args["start_bpm"] = float(start_bpm)
        tempo, beat_samples = librosa.beat.beat_track(**tracker_args)
    except Exception as exc:
        logger.warning(f"beat tracking failed: {exc}")
        return None, None

    too_few_beats = beat_samples is None or len(beat_samples) < 4
    if too_few_beats:
        return None, None
    # Depending on the librosa version, tempo may be a scalar or an array;
    # normalise to a plain float.
    bpm = float(np.atleast_1d(tempo).flatten()[0])
    if 40.0 <= bpm <= 240.0:
        return bpm, int(beat_samples[0])
    return None, None
def _time_stretch_multichannel(audio: np.ndarray, rate: float) -> np.ndarray:
    """Time-stretch multi-channel audio with librosa's phase vocoder.

    Expects (frames, channels) audio, as produced by
    ``sf.read(..., always_2d=True)`` in align_to_grid. ``rate`` follows
    ``librosa.effects.time_stretch``: above 1 the output is shorter (faster),
    below 1 it is longer (slower). librosa stretches every leading axis as an
    independent channel in a single call, so the signal is handed over
    channels-first and the result is transposed back.

    Returns:
        A C-contiguous (frames, channels) array.
    """
    channels_first = np.transpose(audio)
    warped = librosa.effects.time_stretch(channels_first, rate=rate)
    return np.ascontiguousarray(np.transpose(warped))