# app.py — Demucs / Spleeter source-separation Gradio app
# (Hugging Face Space by shethjenil, rev 96ca358)
import torchaudio
import torch
from huggingface_hub import hf_hub_download
from tqdm import tqdm
from safetensors.torch import load_file
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Fade
from torchaudio.models import HDemucs
import gradio as gr
import math
# Constants: STFT window length and hop length in samples, working sample rate in Hz.
WIN_LENGTH, HOP_LENGTH, SR = 4096, 1024, 44100
class Crop2d(nn.Module):
    """Crop a (B, C, H, W) tensor by fixed margins.

    Args:
        l, r: columns removed from the left/right of the width axis.
        t, b: rows removed from the top/bottom of the height axis.

    Used to undo the asymmetric (1, 2, 1, 2) padding around the
    transposed convolutions in the decoder.
    """
    def __init__(self, l, r, t, b):
        super().__init__()
        self.l, self.r, self.t, self.b = l, r, t, b
    def forward(self, x):
        # Use explicit end indices instead of negative slices so a zero
        # margin crops nothing (the original x[..., t:-0] slice would
        # produce an empty tensor when b or r is 0).
        h, w = x.size(-2), x.size(-1)
        return x[:, :, self.t:h - self.b, self.l:w - self.r]
class EncoderBlock(nn.Module):
    """Strided-conv downsampling block.

    Returns both the raw conv output (kept as the U-Net skip connection)
    and its batch-normalized, leaky-ReLU activation (fed to the next
    encoder stage).
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        # Attribute names must stay fixed: they are the state_dict keys
        # of the pretrained weights.
        self.conv = nn.Conv2d(in_c, out_c, kernel_size=5, stride=2)
        self.bn = nn.BatchNorm2d(out_c, eps=0.001, momentum=0.01)
        self.relu = nn.LeakyReLU(0.2)
        # Asymmetric "same"-style padding for the 5x5 stride-2 conv.
        self.pad = nn.ConstantPad2d((1, 2, 1, 2), 0)
    def forward(self, x):
        pre_act = self.conv(self.pad(x))
        post_act = self.relu(self.bn(pre_act))
        return pre_act, post_act
class DecoderBlock(nn.Module):
    """Transposed-conv upsampling block.

    Upsamples, crops away the asymmetric padding, then applies ReLU
    followed by batch-norm — note the ReLU-before-BN order, which
    matches the original spleeter port's pretrained weights.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        # Attribute names must stay fixed: they are the state_dict keys
        # of the pretrained weights.
        self.tconv = nn.ConvTranspose2d(in_c, out_c, kernel_size=5, stride=2)
        self.crop = Crop2d(1, 2, 1, 2)
        self.bn = nn.BatchNorm2d(out_c, eps=0.001, momentum=0.01)
        self.relu = nn.ReLU()
    def forward(self, x):
        upsampled = self.crop(self.tconv(x))
        return self.bn(self.relu(upsampled))
class UNet(nn.Module):
    """Spleeter-style U-Net operating on magnitude spectrogram chunks.

    Encoder widths run in_c -> 16 -> 32 -> ... -> 2**(n_layers+3); the
    decoder mirrors them down to a single channel, concatenating the
    pre-activation encoder outputs as skip connections.  A sigmoid over
    a final dilated conv produces a soft mask that is multiplied into
    the last decoder feature map.
    """
    def __init__(self, n_layers=6, in_c=2):
        super().__init__()
        enc_widths = [in_c] + [16 * 2 ** i for i in range(n_layers)]
        self.encoder_layers = nn.ModuleList(
            EncoderBlock(c_in, c_out)
            for c_in, c_out in zip(enc_widths, enc_widths[1:])
        )
        dec_widths = list(reversed([1] + [16 * 2 ** i for i in range(n_layers)]))
        decoders = []
        for idx, (c_in, c_out) in enumerate(zip(dec_widths, dec_widths[1:])):
            # All but the deepest decoder receive a concatenated skip,
            # doubling their input channels.
            decoders.append(DecoderBlock(c_in if idx == 0 else c_in * 2, c_out))
        self.decoder_layers = nn.ModuleList(decoders)
        self.up_final = nn.Conv2d(1, in_c, 4, dilation=2, padding=3)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        skips = []
        for encode in self.encoder_layers:
            pre_act, x = encode(x)
            skips.append(pre_act)
        # The deepest decoder consumes the deepest skip directly (not the
        # activated bottleneck), matching the original implementation.
        x = self.decoder_layers[0](skips.pop())
        for decode in self.decoder_layers[1:]:
            x = decode(torch.cat([skips.pop(), x], 1))
        mask = self.sigmoid(self.up_final(x))
        # up_final can change spatial size slightly; align before masking.
        f = min(mask.size(-2), x.size(-2))
        t = min(mask.size(-1), x.size(-1))
        return mask[..., :f, :t] * x[..., :f, :t]
class STFTChunkDataset(Dataset):
    """Fixed-size magnitude-spectrogram chunks for the Spleeter nets.

    Computes the STFT of ``wav`` once, keeps the real/imag STFT for
    reconstruction, and serves the magnitude split into chunks of
    T=512 frames; each item has shape (channels, T, F).
    """
    def __init__(self, wav, win):
        # T: frames per chunk; F: frequency bins kept (the model only sees
        # the first 1024 of the WIN_LENGTH//2+1 = 2049 STFT bins).
        self.win, self.T, self.F = win, 512, 1024
        wav = wav.unsqueeze(0) if wav.dim() == 1 else wav
        stft = torch.stft(wav.squeeze(), WIN_LENGTH, HOP_LENGTH, window=win, return_complex=False, pad_mode="constant")[:, :self.F]
        # self.L: original frame count, used to trim chunk padding later.
        self.stft, self.L = stft, stft.size(2)
        # Magnitude with a small epsilon for numerical stability.
        mag = torch.sqrt(stft[:,:,:,0]**2 + stft[:,:,:,1]**2 + 1e-10)
        # (C, F, L) -> (1, C, F, L) via the trailing singleton + permute.
        mag = mag.unsqueeze(-1).permute(3, 0, 1, 2)
        self.stft_mag = self._batch(mag)
    def __len__(self): return self.stft_mag.size(0)
    def __getitem__(self, i): return self.stft_mag[i]
    def _batch(self, x):
        # Zero-pad the time axis up to a multiple of T, split it into chunks
        # stacked on the batch dim, and transpose to (N, C, T, F).
        new_size = math.ceil(x.size(-1) / self.T) * self.T
        x = F.pad(x, [0, new_size - x.size(-1)])
        return torch.cat(torch.split(x, self.T, -1), 0).transpose(2, 3)
    def apply_mask(self, mask, mask_sum):
        # Ratio mask: this stem's squared network output over the sum of
        # all stems' squared outputs (Wiener-filter style).
        mask = (mask**2 + 1e-10/2) / mask_sum
        # Undo _batch: re-concatenate chunks along time, drop padding beyond
        # the original L frames, add a trailing axis to broadcast over re/im.
        mask = torch.cat(torch.split(mask.transpose(2, 3), 1, 0), 3).squeeze(0)[:,:,:self.L].unsqueeze(-1)
        # Pad the frequency axis back to WIN_LENGTH//2+1 bins before istft;
        # bins above self.F were discarded and come back as zeros.
        stft = F.pad(self.stft * mask, (0,0,0,0,0,WIN_LENGTH//2+1-self.stft.size(1))) if self.stft.size(1) < WIN_LENGTH//2+1 else self.stft * mask
        return torch.istft(torch.view_as_complex(stft), WIN_LENGTH, HOP_LENGTH, WIN_LENGTH, self.win,True)
    def decoder(self, masks):
        """Turn the per-stem network outputs into time-domain waveforms."""
        mask_sum = sum([m**2 for m in masks.values()]) + 1e-10
        return {n: self.apply_mask(m, mask_sum) for n, m in masks.items()}
class Splitter(nn.Module):
    """Spleeter-style separator: one pretrained UNet per stem.

    Weights are fetched from the shethjenil/spleeter Hub repo; the STFT
    window is registered as a frozen parameter so it follows the module
    across devices.
    """
    # Stem-count -> per-stem weight-file basenames on the Hub.
    CFG = {2: ['2_other', '2_vocals'], 4: ['4_bass', '4_drums', '4_other', '4_vocals'], 5: ['5_piano', '5_bass', '5_drums', '5_other', '5_vocals']}
    def __init__(self, stem=2):
        super().__init__()
        self.win = nn.Parameter(torch.hann_window(WIN_LENGTH), requires_grad=False)
        self.stems = nn.ModuleDict({name: UNet() for name in self.CFG[stem]})
        for name, net in self.stems.items():
            weights = load_file(hf_hub_download("shethjenil/spleeter", f"{name}.safetensors"))
            net.load_state_dict(weights)
        self.eval()
    @torch.inference_mode()
    def forward(self, wav, sr, bs):
        """Separate ``wav`` (sampled at ``sr``) into stems, batching ``bs`` chunks."""
        dev = next(self.parameters()).device
        if sr != SR:
            wav = torchaudio.functional.resample(wav, sr, SR)
        wav = wav.to(dev)
        ds = STFTChunkDataset(wav, self.win)
        outputs = {name: [] for name in self.stems}
        for batch in DataLoader(ds, bs, shuffle=False):
            batch = batch.to(dev)
            for name, net in self.stems.items():
                outputs[name].append(net(batch))
        return ds.decoder({name: torch.cat(parts, 0) for name, parts in outputs.items()})
class DemucsChunkDataset(Dataset):
    """Fixed-length chunks of a (channels, time) mix for HDemucs.

    Normalizes the mix, serves zero-padded chunks, and blends the model
    outputs back into ``self.final`` with linear fades.
    """
    def __init__(self, wav, seg=10.0, ovlp=0.1, sr=SR, n_src=4):
        super().__init__()
        # Per-recording normalization; undone in get_output().
        self.mean, self.std = wav.mean(), wav.std()
        self.mix = (wav - self.mean) / self.std
        self.c, self.len = self.mix.shape
        self.ovlp_f = int(ovlp * sr)             # overlap length in frames
        self.chunk = int(sr * seg * (1 + ovlp))  # chunk length in frames
        self.starts, start, idx = [], 0, 0
        while start < self.len - self.ovlp_f:
            self.starts.append(start)
            # NOTE(review): only the first step subtracts the overlap, so
            # only chunks 0 and 1 actually overlap; torchaudio's reference
            # loop subtracts it on every step — confirm this is intended.
            start += self.chunk - self.ovlp_f if idx == 0 else self.chunk
            idx += 1
        # Accumulator for the n_src separated sources.
        self.final = torch.zeros(n_src, self.c, self.len, device=wav.device)
        # fade_in starts at 0: the first chunk has nothing before it to blend.
        self.fade = Fade(0, self.ovlp_f, "linear").to(wav.device)
    def __len__(self): return len(self.starts)
    def __getitem__(self, i):
        s, e = self.starts[i], min(self.starts[i] + self.chunk, self.len)
        chunk = self.mix[:, s:e]
        # Right-pad the tail chunk so every item has identical length.
        return {"chunk": F.pad(chunk, (0, self.chunk - chunk.shape[-1])) if chunk.shape[-1] < self.chunk else chunk, "start": s, "idx": i}
    def decode_and_append(self, out, meta):
        # out: (batch, n_src, channels, chunk); meta holds the collated
        # "start"/"idx" tensors.  Relies on the DataLoader preserving chunk
        # order (shuffle must stay off) because the Fade state mutates here.
        for i in range(out.size(0)):
            s, idx, e = meta["start"][i], meta["idx"][i], min(meta["start"][i] + self.chunk, self.len)
            self.final[:, :, s:e] += self.fade(out[i:i+1])[0, :, :, :e-s]
            # After the first chunk, enable fade-in for blended joins.
            if idx == 0: self.fade.fade_in_len = self.ovlp_f
            # NOTE(review): fade_out is zeroed only after the final chunk has
            # already been faded, so the tail gets attenuated — confirm.
            if e >= self.len: self.fade.fade_out_len = 0
    def get_output(self):
        # Undo the normalization applied in __init__.
        return self.final * self.std + self.mean
class Demucs(nn.Module):
    """Wrapper around torchaudio's pretrained HDemucs (4-stem) model."""
    # Stem-count -> source names, in the order HDemucs emits them.
    CFG = {4: ["drums", "bass", "other", "vocals"]}
    def __init__(self, stem=4):
        super().__init__()
        self.model = HDemucs(self.CFG[stem])
        ckpt = torchaudio.utils._download_asset("models/hdemucs_high_trained.pt", progress=False)
        self.model.load_state_dict(torch.load(ckpt))
        self.eval()
    @torch.inference_mode()
    def forward(self, wav, sr, bs):
        """Separate ``wav`` (sampled at ``sr``) into sources, batching ``bs`` chunks."""
        dev = next(self.parameters()).device
        if sr != SR:
            wav = torchaudio.functional.resample(wav, sr, SR)
        wav = wav.to(dev)
        ds = DemucsChunkDataset(wav)
        for batch in tqdm(DataLoader(ds, bs), desc="Separating"):
            separated = self.model(batch["chunk"])
            ds.decode_and_append(separated, batch)
        return dict(zip(self.model.sources, ds.get_output()))
def separate_audio_spleeter(path, bs, stem, progress=gr.Progress(True)):
    """Split the file at ``path`` into ``stem`` stems with the Spleeter
    port, writing each stem to ``<name>.mp3`` and returning the paths."""
    wav, sr = torchaudio.load(path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    stems = Splitter(stem).to(device)(wav, sr, bs)
    saved = []
    for name, audio in stems.items():
        out_path = f"{name}.mp3"
        torchaudio.save(out_path, audio.cpu(), SR)
        saved.append(out_path)
    return saved
def separate_audio_demucs(path, bs, stem, progress=gr.Progress(True)):
    """Split the file at ``path`` into ``stem`` sources with HDemucs,
    writing each source to ``<name>.mp3`` and returning the paths."""
    wav, sr = torchaudio.load(path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sources = Demucs(stem).to(device)(wav, sr, bs)
    saved = []
    for name, audio in sources.items():
        out_path = f"{name}.mp3"
        torchaudio.save(out_path, audio.cpu(), SR)
        saved.append(out_path)
    return saved
# Two-tab Gradio UI: Spleeter port (2/4/5 stems) and Demucs (4 stems).
# Launched at module level as required by the Hugging Face Space runtime.
gr.TabbedInterface([
    gr.Interface(separate_audio_spleeter, [gr.Audio(type="filepath"),gr.Number(16),gr.Radio([2,4,5],label="STEM")],gr.Files()),
    gr.Interface(separate_audio_demucs, [gr.Audio(type="filepath"),gr.Number(16),gr.Radio([4],label="STEM")],gr.Files())
],['spleeter','demucs']).launch()