# Source path: fragmenta/app/core/generation/audio_post_process.py
# Hugging Face file-viewer residue (not part of the module): uploaded by MazCodes
# via huggingface_hub ("Upload folder using huggingface_hub"), commit c7986a9 (verified), 6.89 kB.
"""Beat-align and tempo-conform a generated WAV to a target BPM and bar count.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional, Tuple
import librosa
import numpy as np
import soundfile as sf
logger = logging.getLogger(__name__)
# Inclusive bounds on the time-stretch rate (target BPM / interpreted BPM)
# that _best_stretch_rate accepts. Wider than the previous [0.7, 1.4] so we
# actually warp in more cases — librosa's vocoder produces acceptable audio
# across this range for music, and the alternative (no warp at all) drifts
# off the grid completely on loop.
_STRETCH_SAFE_MIN = 0.6
_STRETCH_SAFE_MAX = 1.7
def align_to_grid(
    input_path: Path,
    target_bpm: float,
    target_bars: int,
    beats_per_bar: int = 4,
) -> Path:
    """Conform the audio file at ``input_path`` to an exact bar grid, in place.

    Processing steps:

    1. Detect the tempo and first beat, using ``target_bpm`` as a prior.
    2. Trim leading audio up to the first beat (if it lies within the first
       1.5 s), or up to the first onset when beat tracking found nothing.
    3. Time-stretch so the detected tempo matches ``target_bpm``, when
       ``_best_stretch_rate`` finds a rate inside the safe range.
    4. Truncate (with an 8 ms tail fade) or zero-pad to exactly
       ``target_bars * beats_per_bar`` beats at ``target_bpm``.
    5. Overwrite ``input_path`` with the result as 16-bit PCM.

    Args:
        input_path: File to read and overwrite.
        target_bpm: Desired tempo in beats per minute; must be positive and
            finite.
        target_bars: Desired length in bars; must be positive.
        beats_per_bar: Beats in each bar; must be positive.

    Returns:
        ``input_path``, unchanged, for call chaining.

    Raises:
        ValueError: If the grid parameters are non-positive / non-finite or
            yield a zero-length target. Raised before the file is modified.
    """
    # Validate before any I/O: a zero BPM used to surface as a bare
    # ZeroDivisionError, and negative values produced a negative sample count
    # that silently chopped the tail and overwrote the input with bad audio.
    if not (np.isfinite(target_bpm) and target_bpm > 0):
        raise ValueError(
            f"target_bpm must be a positive finite number, got {target_bpm!r}"
        )
    if target_bars <= 0 or beats_per_bar <= 0:
        raise ValueError(
            "target_bars and beats_per_bar must be positive, "
            f"got target_bars={target_bars!r}, beats_per_bar={beats_per_bar!r}"
        )

    # always_2d gives (frames, channels) even for mono files, so the rest of
    # the function can index channels uniformly.
    audio, sr = sf.read(str(input_path), always_2d=True)
    audio = audio.astype(np.float32, copy=False)

    target_samples = int(round(target_bars * beats_per_bar * 60.0 / target_bpm * sr))
    if target_samples <= 0:
        raise ValueError(
            f"target grid is shorter than one sample at {sr} Hz "
            f"(target_bpm={target_bpm!r}, target_bars={target_bars!r}, "
            f"beats_per_bar={beats_per_bar!r})"
        )

    mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]

    # Pass target_bpm as a prior to librosa — biases the beat tracker away
    # from half-time / double-time interpretations of the same grid.
    detected_bpm, first_beat = _detect_grid_anchor(mono, sr, start_bpm=target_bpm)

    head_offset = 0
    if first_beat is not None and 0 < first_beat < sr * 1.5:
        head_offset = first_beat
        logger.info(f"align_to_grid: trimmed {head_offset / sr * 1000:.1f} ms to first beat")
    elif first_beat is None:
        # Beat tracking produced nothing usable; fall back to the first onset.
        head_offset = _detect_first_onset_sample(mono, sr)
        if head_offset > 0:
            logger.info(f"align_to_grid: trimmed {head_offset / sr * 1000:.1f} ms (onset fallback)")
    if head_offset > 0:
        audio = audio[head_offset:]

    if detected_bpm is not None:
        rate, effective_bpm = _best_stretch_rate(detected_bpm, target_bpm)
        if rate is not None:
            # Skip the vocoder entirely for a (near) no-op rate.
            if abs(rate - 1.0) > 1e-3:
                audio = _time_stretch_multichannel(audio, rate)
            interp_note = (
                f" (interpreted as {effective_bpm:.2f} BPM, "
                f"octave={effective_bpm / detected_bpm:.2f}×)"
                if abs(effective_bpm - detected_bpm) > 1e-2
                else ""
            )
            logger.info(
                f"align_to_grid: detected {detected_bpm:.2f} BPM{interp_note}, "
                f"stretched by {rate:.4f} to match target {target_bpm:.2f} BPM"
            )
        else:
            logger.info(
                f"align_to_grid: detected {detected_bpm:.2f} BPM has no safe "
                f"interpretation vs target {target_bpm:.2f}; skipping warp"
            )
    else:
        logger.info("align_to_grid: no usable tempo detected; skipping warp")

    if audio.shape[0] > target_samples:
        audio = audio[:target_samples]
        # 8ms tail fade prevents the click at the loop boundary when the
        # truncation point lands mid-waveform.
        fade_samples = min(int(0.008 * sr), audio.shape[0])
        if fade_samples > 1:
            fade = np.linspace(1.0, 0.0, fade_samples, dtype=audio.dtype)
            # audio is always (frames, channels); broadcast across channels.
            audio[-fade_samples:] *= fade[:, np.newaxis]
    elif audio.shape[0] < target_samples:
        pad = np.zeros((target_samples - audio.shape[0], audio.shape[1]), dtype=audio.dtype)
        audio = np.concatenate([audio, pad], axis=0)

    sf.write(str(input_path), audio, sr, subtype="PCM_16")
    return input_path
def _best_stretch_rate(
    detected_bpm: float,
    target_bpm: float,
) -> Tuple[Optional[float], float]:
    """Choose the stretch rate that maps the detected tempo onto the target.

    The rate is ``target_bpm / interpreted_bpm`` and must fall inside
    ``[_STRETCH_SAFE_MIN, _STRETCH_SAFE_MAX]`` to be accepted.

    Preference order:
      1. The detected tempo taken at face value.
      2. Only if that is out of range: the half-time (× 0.5) or double-time
         (× 2.0) reading of the detected tempo, whichever gives the rate
         closest to 1.0. This recovers from the beat tracker's octave errors
         without second-guessing it when its reading is already usable.

    Returns:
        ``(rate, effective_bpm)`` where ``effective_bpm`` is the tempo
        interpretation that was used, or ``(None, detected_bpm)`` when no
        interpretation yields a safe rate.
    """
    face_value_rate = target_bpm / detected_bpm
    if _STRETCH_SAFE_MIN <= face_value_rate <= _STRETCH_SAFE_MAX:
        return face_value_rate, detected_bpm

    # Face value is unusable — most likely an octave error from the tracker.
    # Score each octave reinterpretation by its distance from a no-op stretch.
    fallbacks = []
    for interpreted_bpm in (detected_bpm * 0.5, detected_bpm * 2.0):
        candidate_rate = target_bpm / interpreted_bpm
        if not (_STRETCH_SAFE_MIN <= candidate_rate <= _STRETCH_SAFE_MAX):
            continue
        fallbacks.append((abs(candidate_rate - 1.0), candidate_rate, interpreted_bpm))

    if fallbacks:
        _, chosen_rate, chosen_bpm = min(fallbacks)
        return chosen_rate, chosen_bpm
    return None, detected_bpm
def _detect_first_onset_sample(mono: np.ndarray, sr: int) -> int:
    """Locate the first onset in ``mono`` and return its sample index.

    Returns 0 when onset detection raises, finds nothing, or places the
    first onset more than one second into the signal.
    """
    try:
        onset_positions = librosa.onset.onset_detect(
            y=mono,
            sr=sr,
            units="samples",
            backtrack=True,
        )
    except Exception as exc:
        # Best-effort helper: a detector failure just means "no trim".
        logger.warning(f"onset detection failed: {exc}")
        return 0

    found_any = onset_positions is not None and len(onset_positions) > 0
    if not found_any:
        return 0

    earliest = int(onset_positions[0])
    # An onset later than 1 s is not treated as leading silence to trim.
    return earliest if earliest <= sr * 1.0 else 0
def _detect_grid_anchor(
    mono: np.ndarray,
    sr: int,
    start_bpm: Optional[float] = None,
) -> Tuple[Optional[float], Optional[int]]:
    """Beat-track ``mono`` and return ``(bpm, first_beat_sample)``.

    A positive ``start_bpm`` is forwarded to librosa as a tempo prior, which
    reduces (but doesn't eliminate) half-time / double-time errors; whatever
    slips through is handled by the octave correction in _best_stretch_rate.

    Returns ``(None, None)`` when tracking raises, fewer than four beats are
    found, or the tempo falls outside 40–240 BPM.
    """
    try:
        has_prior = start_bpm is not None and start_bpm > 0
        prior = {"start_bpm": float(start_bpm)} if has_prior else {}
        tempo, beat_samples = librosa.beat.beat_track(
            y=mono, sr=sr, units="samples", **prior
        )
    except Exception as exc:
        logger.warning(f"beat tracking failed: {exc}")
        return None, None

    if beat_samples is None or len(beat_samples) < 4:
        return None, None

    # tempo may come back as a scalar or an array; take its first element.
    bpm = float(np.ravel(tempo)[0])
    # Written as a positive range test so a NaN tempo is rejected too.
    if 40.0 <= bpm <= 240.0:
        return bpm, int(beat_samples[0])
    return None, None
def _time_stretch_multichannel(audio: np.ndarray, rate: float) -> np.ndarray:
    """Phase-vocoder time stretch of a 2-D audio array.

    The array is transposed before being handed to librosa and transposed
    back afterwards; the result is returned as a C-contiguous array.
    """
    # NOTE(review): assumes audio is (frames, channels) as produced by
    # sf.read(..., always_2d=True) in align_to_grid — confirm for other callers.
    channels_first = audio.T
    warped = librosa.effects.time_stretch(channels_first, rate=rate)
    return np.ascontiguousarray(warped.T)