| |
| """ |
| Audio super-resolution using FlashSR. |
| |
| Independently written wrapper around the FlashSR model by Jaekwon Im and |
| Juhan Nam (KAIST). Supports files of arbitrary length via windowed processing |
| with overlap-add. No dependency on torchcodec or FFmpeg -- uses soundfile for |
| all I/O. |
| |
| Paper: https://arxiv.org/abs/2501.10807 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import math |
| import os |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import soundfile as sf |
| import torch |
| from scipy.signal import resample_poly |
|
|
| from FlashSR.FlashSR import FlashSR |
|
|
| |
|
|
# Sample rate (Hz) that every waveform is converted to before enhancement and
# that all output files are written at.
TARGET_SR = 48_000

# Number of samples handed to the model per call (5.12 s at 48 kHz); shorter
# chunks are zero-padded up to this length.
WINDOW_LEN = 245_760

# Number of samples shared by two consecutive windows (0.5 s at 48 kHz).
OVERLAP = 24_000

# Distance in samples between the starts of consecutive windows.
HOP = WINDOW_LEN - OVERLAP

# Lower-case file suffixes treated as audio when scanning a directory.
AUDIO_EXTENSIONS = {
    ".wav",
    ".flac",
    ".mp3",
    ".ogg",
    ".opus",
}
|
|
|
|
| |
|
|
def _load_mono(path: str | Path) -> tuple[np.ndarray, int]:
    """Load *path* with soundfile as float32 and collapse it to one channel.

    Returns:
        ``(samples, sample_rate)``. Two-dimensional data is averaged over
        axis 1; anything else is returned as read.
    """
    samples, sample_rate = sf.read(str(path), dtype="float32")
    if samples.ndim != 2:
        return samples, sample_rate
    # Down-mix by averaging across axis 1.
    return samples.mean(axis=1), sample_rate
|
|
|
|
def _resample_if_needed(audio: np.ndarray, orig_sr: int) -> np.ndarray:
    """Bring *audio* to TARGET_SR using polyphase resampling.

    The input array is returned untouched when *orig_sr* already equals
    TARGET_SR; otherwise the resampled signal is returned as float32.
    """
    if orig_sr != TARGET_SR:
        resampled = resample_poly(audio, TARGET_SR, orig_sr)
        return resampled.astype(np.float32)
    return audio
|
|
|
|
def _build_fade(length: int) -> torch.Tensor:
    """Return a smooth fade-in ramp of *length* samples rising from 0 to 1.

    The ramp is ``sin(t)**2`` with ``t`` evenly spaced over ``[0, pi/2]``.
    """
    phase = torch.linspace(0.0, math.pi / 2, length)
    return torch.sin(phase).pow(2)
|
|
|
|
def _pad_to(tensor: torch.Tensor, n: int) -> torch.Tensor:
    """Append zeros along the last dimension until it holds *n* samples.

    A tensor whose last dimension is already *n* or longer is returned as-is
    (same object, never truncated).
    """
    current = tensor.shape[-1]
    if current >= n:
        return tensor
    missing = n - current
    return torch.nn.functional.pad(tensor, (0, missing))
|
|
|
|
| |
|
|
def build_model(weights_dir: str | Path, device: torch.device) -> FlashSR:
    """Construct a FlashSR model from the checkpoints in *weights_dir*.

    The directory is expected to hold ``student_ldm.pth``, ``sr_vocoder.pth``
    and ``vae.pth``. The model is moved to *device* and switched to eval mode
    before being returned.
    """
    root = Path(weights_dir)
    checkpoint_paths = {
        "student_ldm_ckpt_path": str(root / "student_ldm.pth"),
        "sr_vocoder_ckpt_path": str(root / "sr_vocoder.pth"),
        "autoencoder_ckpt_path": str(root / "vae.pth"),
    }
    net = FlashSR(**checkpoint_paths)
    net = net.to(device)
    return net.eval()
|
|
|
|
@torch.inference_mode()
def enhance(
    model: FlashSR,
    waveform: np.ndarray,
    *,
    device: torch.device,
    lowpass: bool = False,
) -> np.ndarray:
    """
    Run FlashSR on a mono waveform (numpy float32, 48 kHz).

    Inputs of at most WINDOW_LEN samples are zero-padded to WINDOW_LEN,
    processed in a single model call and trimmed back to their original
    length.

    Longer inputs are split into windows of WINDOW_LEN samples whose starts
    are HOP samples apart, so neighbouring windows share OVERLAP samples.
    Inside each shared region the earlier window is faded out and the later
    window is faded in with complementary ramps (``1 - fade`` and ``fade``),
    so the two weights sum to one and the output moves smoothly from one
    window to the next.

    Args:
        model: Callable invoked as ``model(chunk, lowpass_input=lowpass)``
            on a ``(1, WINDOW_LEN)`` tensor located on *device*.
        waveform: One-dimensional array of samples.
        device: Device each chunk is moved to before the model call.
        lowpass: Forwarded to the model as ``lowpass_input``.

    Returns:
        Enhanced waveform as a numpy array with the same number of samples
        as *waveform*.
    """
    signal = torch.from_numpy(waveform).unsqueeze(0)
    n_samples = signal.shape[-1]

    # Nothing to enhance; skip a model call on pure padding.
    if n_samples == 0:
        return np.zeros(0, dtype=np.float32)

    # Short input: one padded window, trimmed back to the original length.
    if n_samples <= WINDOW_LEN:
        chunk = _pad_to(signal, WINDOW_LEN).to(device)
        out = model(chunk, lowpass_input=lowpass)
        return out[0, :n_samples].cpu().numpy()

    fade_in = _build_fade(OVERLAP)
    # Complementary ramp: fade_in + fade_out == 1 at every overlap sample, so
    # the accumulated output needs no separate normalisation pass.
    fade_out = 1.0 - fade_in
    accumulator = torch.zeros(n_samples)

    offset = 0
    while True:
        end = min(offset + WINDOW_LEN, n_samples)
        seg_len = end - offset
        # The first window that reaches the end of the signal is the last
        # one; any further window would only re-process samples that are
        # already covered, on a mostly zero-padded chunk.
        is_last = end == n_samples

        segment = _pad_to(signal[:, offset:end], WINDOW_LEN).to(device)
        enhanced_seg = model(segment, lowpass_input=lowpass).cpu().squeeze(0)
        enhanced_seg = enhanced_seg[:seg_len]

        w = torch.ones(seg_len)
        if offset > 0:
            # A window is only started when the previous one stopped short of
            # the end, which guarantees seg_len > OVERLAP here.
            w[:OVERLAP] = fade_in
        if not is_last:
            # Non-final windows are always full length (seg_len == WINDOW_LEN).
            w[-OVERLAP:] = fade_out

        accumulator[offset:end] += enhanced_seg * w

        if is_last:
            break
        offset += HOP

    return accumulator.numpy()
|
|
|
|
| |
|
|
def enhance_file(
    model: FlashSR,
    src: str | Path,
    dst: str | Path,
    *,
    device: torch.device,
    lowpass: bool = False,
) -> float:
    """Read *src*, enhance it and write the result to *dst* at TARGET_SR.

    The parent directory of *dst* is created when missing.

    Returns:
        Length of the processed audio in seconds, measured on the signal
        after resampling to TARGET_SR.
    """
    samples, source_rate = _load_mono(src)
    samples_at_target = _resample_if_needed(samples, source_rate)
    enhanced = enhance(model, samples_at_target, device=device, lowpass=lowpass)

    # A bare file name has an empty dirname; fall back to the current directory.
    out_dir = os.path.dirname(dst) or "."
    os.makedirs(out_dir, exist_ok=True)
    sf.write(str(dst), enhanced, TARGET_SR)

    return len(samples_at_target) / TARGET_SR
|
|
|
|
def collect_audio_files(root: str | Path) -> list[Path]:
    """Recursively find audio files under *root*.

    A path is returned when its lower-cased suffix is in AUDIO_EXTENSIONS and
    it is a regular file, so a directory that merely carries an audio-like
    name (e.g. ``takes.wav/``) is not mistaken for an input file.

    Returns:
        Matching paths in sorted order.
    """
    root = Path(root)
    return sorted(
        p
        for p in root.rglob("*")
        # Cheap suffix test first; is_file() touches the filesystem.
        if p.suffix.lower() in AUDIO_EXTENSIONS and p.is_file()
    )
|
|
|
|
| |
|
|
def cli() -> None:
    """Command-line entry point: enhance one file or a whole directory tree.

    Input and output paths are resolved before the model is loaded so that a
    bad path fails immediately. When the input is a directory, every audio
    file found under it is written to the same relative location under the
    output directory. When the input is a single file and the output names an
    existing directory, the result is written inside that directory under the
    input's file name.

    The requested device is honoured as given; only a CUDA request on a
    machine without CUDA falls back to the CPU.
    """
    ap = argparse.ArgumentParser(
        description="FlashSR audio super-resolution (by Im & Nam, KAIST)")
    ap.add_argument("--input", "-i", required=True,
                    help="Input audio file or directory")
    ap.add_argument("--output", "-o", required=True,
                    help="Output file or directory")
    ap.add_argument("--weights", "-w", default="./weights",
                    help="Directory containing the three .pth weight files")
    ap.add_argument("--lowpass", action="store_true",
                    help="Apply lowpass filter before enhancement")
    ap.add_argument("--device", default="cuda",
                    help="Torch device (default: cuda)")
    args = ap.parse_args()

    inp = Path(args.input)
    out = Path(args.output)

    # Work out what to process before paying for the model load.
    if not inp.exists():
        sys.exit(f"Input path does not exist: {inp}")

    if inp.is_dir():
        files = collect_audio_files(inp)
        if not files:
            sys.exit(f"No audio files found in {inp}")
        pairs = [(f, out / f.relative_to(inp)) for f in files]
    else:
        # Allow --output to name an existing directory for a single file.
        dst_single = out / inp.name if out.is_dir() else out
        pairs = [(inp, dst_single)]

    # Fall back to the CPU only when CUDA was requested but is unavailable;
    # any other device string (e.g. "cpu", "mps") is used exactly as given.
    requested = torch.device(args.device)
    if requested.type == "cuda" and not torch.cuda.is_available():
        dev = torch.device("cpu")
    else:
        dev = requested
    print(f"Device: {dev}")

    print("Loading model...")
    t0 = time.monotonic()
    model = build_model(args.weights, dev)
    print(f"Loaded in {time.monotonic() - t0:.1f}s")

    total_dur = 0.0
    t_start = time.monotonic()

    for idx, (src, dst) in enumerate(pairs, 1):
        print(f"[{idx}/{len(pairs)}] {src} -> {dst}")
        dur = enhance_file(model, src, dst, device=dev, lowpass=args.lowpass)
        total_dur += dur
        print(f"  {dur:.1f}s of audio")

    elapsed = time.monotonic() - t_start
    # Seconds of audio processed per wall-clock second.
    rtf = total_dur / elapsed if elapsed > 0 else 0
    print(f"\nDone: {len(pairs)} file(s), {total_dur:.1f}s audio, "
          f"{elapsed:.1f}s wall-clock ({rtf:.1f}x realtime)")
|
|
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli()
|
|