import math
import tempfile
from pathlib import Path

import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta
import soundfile as sf
import gradio as gr
from tqdm import tqdm
from typing import Union, Tuple, Optional
from torch import Tensor

from pyharp import build_endpoint, ModelCard


# ─────────────────────────────────────────────
# UNet Utilities
# ─────────────────────────────────────────────

class UNetUtils:
    """STFT and chunking helpers shared by the UNet models and ``LarsNet``.

    Args:
        F: Number of frequency bins kept by ``trim_freq_dim``.
        T: Chunk length along the last axis used by ``fold_unet_inputs``.
        n_fft: FFT size passed to ``torch.stft`` / ``torch.istft``.
        win_length: Window length; defaults to ``n_fft``.
        hop_length: Hop size; defaults to ``win_length // 4``.
        center: Forwarded to ``torch.stft`` / ``torch.istft``.
        device: Device the Hann window is moved to.
    """

    def __init__(self, F=None, T=None, n_fft=4096, win_length=None,
                 hop_length=None, center=True, device='cpu'):
        # NOTE: the parameter ``F`` shadows the module alias
        # ``torch.nn.functional as F`` inside this method only; it is kept
        # because callers pass it by keyword.
        self.n_fft = n_fft
        self.win_length = n_fft if win_length is None else win_length
        self.hop_length = self.win_length // 4 if hop_length is None else hop_length
        self.hann_window = torch.hann_window(self.win_length, periodic=True).to(device)
        self.center = center
        self.device = device
        self.F = F
        self.T = T

    def fold_unet_inputs(self, x):
        """Zero-pad the last axis to a multiple of ``T`` and stack chunks on dim 0.

        Inputs shorter than ``T`` are only padded (a single chunk), so the
        leading dimension is left untouched in that case.
        """
        time_dim = x.size(-1)
        pad_len = math.ceil(time_dim / self.T) * self.T - time_dim
        padded = F.pad(x, (0, pad_len))
        if time_dim < self.T:
            return padded
        return torch.cat(torch.split(padded, self.T, dim=-1), dim=0)

    def unfold_unet_outputs(self, x, input_size):
        """Undo ``fold_unet_inputs``.

        Args:
            x: Tensor whose dim 0 holds the stacked chunks.
            input_size: Size of the tensor before folding; its first entry is
                the original batch size and its last entry the original
                length of the last axis.

        Returns:
            ``x`` with chunks re-joined along the last axis and the padding
            trimmed off.
        """
        batch_size, n_frames = input_size[0], input_size[-1]
        if x.size(0) == batch_size:
            # Nothing was stacked on dim 0; only the padding must go.
            return x[..., :n_frames]
        x = torch.cat(torch.split(x, batch_size, dim=0), dim=-1)
        return x[..., :n_frames]

    def trim_freq_dim(self, x):
        """Keep the first ``F`` entries of the second-to-last axis."""
        return x[..., :self.F, :]

    def pad_freq_dim(self, x):
        """Zero-pad the second-to-last axis (at its end) to ``n_fft // 2 + 1``."""
        padding = (self.n_fft // 2 + 1) - x.size(-2)
        return F.pad(x, (0, 0, 0, padding))

    def pad_stft_input(self, x):
        """Right-pad the last axis with zeros before the STFT.

        With ``hop_length <= win_length`` the padding makes
        ``length - win_length`` a multiple of ``hop_length``.
        """
        pad_len = (-(x.size(-1) - self.win_length) % self.hop_length) % self.win_length
        return F.pad(x, (0, pad_len))

    def _stft(self, x):
        """Complex STFT of ``x`` using the stored window and hop settings."""
        return torch.stft(
            input=x,
            n_fft=self.n_fft,
            window=self.hann_window,
            win_length=self.win_length,
            hop_length=self.hop_length,
            center=self.center,
            return_complex=True,
        )

    def _istft(self, x, trim_length=None):
        """Inverse STFT of ``x``; ``trim_length`` is passed as ``length``."""
        return torch.istft(
            input=x,
            n_fft=self.n_fft,
            window=self.hann_window,
            win_length=self.win_length,
            hop_length=self.hop_length,
            center=self.center,
            length=trim_length,
        )

    def batch_stft(self, x, pad=True, return_complex=False):
        """STFT over the last axis of a tensor with arbitrary leading dims.

        Args:
            x: Input tensor; all leading dimensions are flattened for the
                transform and restored afterwards.
            pad: Apply ``pad_stft_input`` first.
            return_complex: Return the complex spectrogram instead of the
                ``(magnitude, phase)`` pair.

        Returns:
            A complex tensor, or a ``(magnitude, phase)`` tuple.
        """
        x_shape = x.size()
        x = x.reshape(-1, x_shape[-1])
        if pad:
            x = self.pad_stft_input(x)
        S = self._stft(x)
        S = S.reshape(x_shape[:-1] + S.shape[-2:])
        if return_complex:
            return S
        return S.abs(), S.angle()

    def batch_istft(self, magnitude, phase, trim_length=None):
        """Inverse of ``batch_stft`` for a ``(magnitude, phase)`` pair.

        The last two axes are consumed by the inverse transform; leading
        dimensions are flattened and restored.
        """
        S = torch.polar(magnitude, phase)
        S_shape = S.size()
        S = S.reshape(-1, S_shape[-2], S_shape[-1])
        x = self._istft(S, trim_length)
        return x.reshape(S_shape[:-2] + x.shape[-1:])


# ─────────────────────────────────────────────
# UNet Blocks
# ─────────────────────────────────────────────

class UNetEncoderBlock(nn.Module):
    """Strided convolution followed by batch norm and LeakyReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=(5, 5),
                 stride=(2, 2), padding=(2, 2), relu_slope=0.2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activ = nn.LeakyReLU(relu_slope)
        nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='leaky_relu', a=relu_slope)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        """Return ``(activated, conv_out)``.

        ``conv_out`` is the raw convolution output (before batch norm and
        activation); ``UNet.produce_mask`` uses it as the skip connection.
        """
        c = self.conv(x)
        return self.activ(self.bn(c)), c


class UNetDecoderBlock(nn.Module):
    """Transposed convolution -> ReLU -> batch norm -> dropout."""

    def __init__(self, in_channels, out_channels, kernel_size=(5, 5),
                 stride=(2, 2), padding=(2, 2), output_padding=(1, 1), dropout=0.0):
        super().__init__()
        self.conv_trans = nn.ConvTranspose2d(in_channels, out_channels,
                                             kernel_size=kernel_size, stride=stride,
                                             padding=padding,
                                             output_padding=output_padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.activ = nn.ReLU()

    def forward(self, x):
        # Order matters: the activation is applied before batch norm here.
        return self.dropout(self.bn(self.activ(self.conv_trans(x))))


# ─────────────────────────────────────────────
# UNet Models
# ─────────────────────────────────────────────

class UNet(nn.Module):
    """Six-level encoder/decoder that predicts a sigmoid mask for its input.

    Args:
        input_size: ``(audio_channels, f_size, t_size)``. ``f_size`` is the
            number of frequency bins fed to the network and ``t_size`` the
            chunk length used when folding long inputs.
        power: Accepted for interface compatibility; not used in this class.
        device: If given, the module is moved to this device.
    """

    def __init__(self, input_size: Tuple[int, ...] = (2, 2048, 512),
                 power: float = 1.0, device: Optional[str] = None):
        super().__init__()
        self.input_size = input_size
        audio_channels, f_size, t_size = input_size
        self.utils = UNetUtils(F=f_size, T=t_size, device=device)

        # Normalises per frequency bin: applied after swapping dims 1 and 2
        # in ``produce_mask`` so that the frequency axis is the channel axis.
        self.input_norm = nn.BatchNorm2d(f_size)

        self.enc1 = UNetEncoderBlock(audio_channels, 16)
        self.enc2 = UNetEncoderBlock(16, 32)
        self.enc3 = UNetEncoderBlock(32, 64)
        self.enc4 = UNetEncoderBlock(64, 128)
        self.enc5 = UNetEncoderBlock(128, 256)
        self.enc6 = UNetEncoderBlock(256, 512)

        # Decoder input widths from dec2 on are doubled by the skip concat.
        self.dec1 = UNetDecoderBlock(512, 256, dropout=0.5)
        self.dec2 = UNetDecoderBlock(512, 128, dropout=0.5)
        self.dec3 = UNetDecoderBlock(256, 64, dropout=0.5)
        self.dec4 = UNetDecoderBlock(128, 32)
        self.dec5 = UNetDecoderBlock(64, 16)
        self.dec6 = UNetDecoderBlock(32, audio_channels)

        # kernel 4 with dilation 2 spans 7 samples; padding 3 keeps the
        # spatial size unchanged.
        self.mask_layer = nn.Sequential(
            nn.Conv2d(audio_channels, audio_channels, kernel_size=(4, 4),
                      dilation=(2, 2), padding=3),
            nn.Sigmoid(),
        )
        nn.init.kaiming_uniform_(self.mask_layer[0].weight)
        nn.init.zeros_(self.mask_layer[0].bias)

        if device is not None:
            self.to(device)

    def produce_mask(self, x: Tensor) -> Tensor:
        """Run the encoder/decoder stack and return the sigmoid mask."""
        x = self.input_norm(x.transpose(1, 2)).transpose(1, 2)

        d, c1 = self.enc1(x)
        d, c2 = self.enc2(d)
        d, c3 = self.enc3(d)
        d, c4 = self.enc4(d)
        d, c5 = self.enc5(d)
        _, c6 = self.enc6(d)

        # The decoder starts from the raw (pre-activation) output of enc6 and
        # concatenates the raw encoder outputs as skip connections.
        u = self.dec1(c6)
        u = self.dec2(torch.cat([c5, u], dim=1))
        u = self.dec3(torch.cat([c4, u], dim=1))
        u = self.dec4(torch.cat([c3, u], dim=1))
        u = self.dec5(torch.cat([c2, u], dim=1))
        u = self.dec6(torch.cat([c1, u], dim=1))
        return self.mask_layer(u)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Mask a magnitude spectrogram.

        Returns:
            ``(x * mask, mask)``, both restored to the original batch size
            and length of the last axis.
        """
        input_size = x.size()
        x = self.utils.fold_unet_inputs(x)
        i = self.utils.trim_freq_dim(x)
        mask = self.produce_mask(i)
        # The network only sees the first F bins; pad the mask back with
        # zeros so it lines up with the full spectrogram.
        mask = self.utils.pad_freq_dim(mask)
        return (self.utils.unfold_unet_outputs(x * mask, input_size),
                self.utils.unfold_unet_outputs(mask, input_size))


class UNetWaveform(UNet):
    """``UNet`` variant that takes a waveform and returns a waveform."""

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Separate a waveform.

        1-D input is duplicated to two rows and 2-D input gets a leading
        batch dimension before the STFT.

        Returns:
            ``(waveform, mask)``; the waveform is trimmed to the input length.
        """
        if x.dim() == 1:
            x = x.repeat(2, 1)
        if x.dim() == 2:
            x = x.unsqueeze(0)
        mag, phase = self.utils.batch_stft(x)
        mag_hat, mask = super().forward(mag)
        return self.utils.batch_istft(mag_hat, phase, trim_length=x.size(-1)), mask


# ─────────────────────────────────────────────
# LarsNet
# ─────────────────────────────────────────────

class LarsNet(nn.Module):
    """Bundle of per-stem UNets loaded from a YAML configuration.

    The configuration must provide ``inference_models`` (stem name ->
    checkpoint path), ``global.sr`` and, for every stem, ``F`` and ``T``.

    Args:
        wiener_filter: Combine the per-stem masks as in ``separate_wiener``.
        wiener_exponent: Exponent applied to the masked magnitudes there.
        config: Path to the YAML configuration file.
        return_stft: Return complex spectrograms instead of waveforms.
        device: Device used for the models and for inference.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, wiener_filter=False, wiener_exponent=1.0,
                 config: Union[str, Path] = "config.yaml", return_stft=False,
                 device='cpu', **kwargs):
        super().__init__(**kwargs)

        with open(config, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f)

        self.device = device
        self.wiener_filter = wiener_filter
        self.wiener_exponent = wiener_exponent
        self.return_stft = return_stft
        self.stems = cfg['inference_models'].keys()
        self.utils = UNetUtils(device=self.device)
        self.sr = cfg['global']['sr']
        # NOTE: a plain dict, so the sub-models are not registered as
        # sub-modules of this nn.Module (``.to()`` / ``state_dict()`` on
        # LarsNet do not reach them). Each model is placed on ``device`` below.
        self.models = {}

        # The spectrogram-domain UNet is needed when this class handles the
        # STFT itself; otherwise each model goes waveform -> waveform.
        model_cls = UNet if (wiener_filter or return_stft) else UNetWaveform

        print('Loading UNet models...')
        for stem in tqdm(self.stems):
            checkpoint_path = Path(cfg['inference_models'][stem])
            f_size = cfg[stem]['F']
            t_size = cfg[stem]['T']

            model = model_cls(input_size=(2, f_size, t_size), device=self.device)

            # SECURITY: torch.load unpickles the file; only load checkpoints
            # from a trusted source.
            checkpoint = torch.load(str(checkpoint_path), map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            self.models[stem] = model

    @staticmethod
    def _fix_dim(x):
        """Bring ``x`` to three dimensions with two audio channels.

        1-D input and 2-D input with a single row are duplicated to two rows
        (the models are built with two audio channels), then a leading batch
        dimension is added to 2-D input. 3-D input is returned unchanged.
        """
        if x.dim() == 1 or (x.dim() == 2 and x.size(0) == 1):
            x = x.repeat(2, 1)
        if x.dim() == 2:
            x = x.unsqueeze(0)
        return x

    def separate(self, x):
        """Run every waveform model on ``x``; returns ``{stem: waveform}``."""
        out = {}
        x = self._fix_dim(x).to(self.device)
        for stem, model in tqdm(self.models.items()):
            y, _ = model(x)
            out[stem] = y.squeeze(0).detach()
        return out

    def separate_wiener(self, x):
        """Separate using masks renormalised across stems.

        Each stem's masked magnitude is raised to ``wiener_exponent`` and
        divided by the sum over all stems; the resulting ratio is applied to
        the mixture magnitude and inverted with the mixture phase.
        """
        out = {}
        mag_pred = []
        x = self._fix_dim(x).to(self.device)
        mag, phase = self.utils.batch_stft(x)

        for stem, model in tqdm(self.models.items()):
            _, mask = model(mag)
            mag_pred.append((mask * mag) ** self.wiener_exponent)

        pred_sum = sum(mag_pred)
        for stem, pred in zip(self.stems, mag_pred):
            # Small constant avoids division by zero in silent bins.
            wiener_mask = pred / (pred_sum + 1e-7)
            y = self.utils.batch_istft(mag * wiener_mask, phase, trim_length=x.size(-1))
            out[stem] = y.squeeze(0).detach()
        return out

    def separate_stft(self, x):
        """Return ``{stem: complex spectrogram}`` using the mixture phase."""
        out = {}
        x = self._fix_dim(x).to(self.device)
        mag, phase = self.utils.batch_stft(x)
        for stem, model in tqdm(self.models.items()):
            mag_pred, _ = model(mag)
            out[stem] = torch.polar(mag_pred, phase).squeeze(0).detach()
        return out

    def forward(self, x):
        """Separate a tensor or an audio file.

        Args:
            x: A waveform tensor, or a ``str`` / ``Path`` that is loaded with
                torchaudio and resampled to ``self.sr`` when needed.

        Returns:
            A dict mapping stem names to the separated outputs.
        """
        if isinstance(x, (str, Path)):
            x, sr_ = ta.load(str(x))
            if sr_ != self.sr:
                x = ta.functional.resample(x, sr_, self.sr)

        if self.return_stft:
            return self.separate_stft(x)
        elif self.wiener_filter:
            return self.separate_wiener(x)
        else:
            return self.separate(x)


# ─────────────────────────────────────────────
# App
# ─────────────────────────────────────────────

model_card = ModelCard(
    name="LarsNet Drum Stem Separator",
    description="Separates a drum mix into individual drum stems: Kick, Snare, Toms, Hi-Hat, and Cymbals.",
    author="A. I. Mezza, et al.",
    tags=["drums", "demucs", "source-separation", "pyharp", "stems", "multi-output"],
)

# Loaded once at import time so every request reuses the same weights.
MODEL = LarsNet(wiener_filter=False, device="cpu", config="config.yaml")


@torch.inference_mode()
def process_fn(audio_path: str):
    """Separate ``audio_path`` and write one WAV file per stem.

    Files are written to the ``outputs`` directory under fixed names, so each
    call overwrites the previous results.

    Returns:
        Tuple of output paths in the order kick, snare, toms, hihat, cymbals
        (the order of the output components below).
    """
    stems = MODEL(audio_path)

    output_dir = Path("outputs")
    output_dir.mkdir(exist_ok=True)

    output_paths = []
    for stem_name in ["kick", "snare", "toms", "hihat", "cymbals"]:
        out_path = output_dir / f"{stem_name}.wav"
        # Transposed so the sample axis comes first when writing.
        sf.write(out_path, stems[stem_name].cpu().numpy().T, MODEL.sr)
        output_paths.append(str(out_path))

    return tuple(output_paths)


with gr.Blocks() as demo:
    input_audio = gr.Audio(type="filepath", label="Drum Mix (Input)").harp_required(True)

    output_kick = gr.Audio(type="filepath", label="Kick")
    output_snare = gr.Audio(type="filepath", label="Snare")
    output_toms = gr.Audio(type="filepath", label="Toms")
    output_hihat = gr.Audio(type="filepath", label="Hi-Hat")
    output_cymbals = gr.Audio(type="filepath", label="Cymbals")

    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio],
        output_components=[output_kick, output_snare, output_toms,
                           output_hihat, output_cymbals],
        process_fn=process_fn,
    )

demo.queue().launch(show_error=True, share=True)