Update app.py
app.py CHANGED
@@ -1,14 +1,15 @@
 import gradio as gr
 import torchaudio
-from torchaudio.transforms import Resample
 import math
+import torch
+from torchaudio.transforms import Resample
 from typing import Dict, Tuple
 from huggingface_hub import hf_hub_download
-import torch
 from torch import nn, Tensor
 from torch.nn import functional as F
 from tqdm import tqdm
 from safetensors.torch import load_file
+
 def batchify(tensor: Tensor, T: int) -> Tensor:
     orig_size = tensor.size(-1)
     new_size = math.ceil(orig_size / T) * T
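For orientation: the first lines of batchify round the time axis up to the next multiple of T, and presumably the function then zero-pads and splits the spectrogram into fixed-width chunks (the rest of its body is outside this diff). A minimal sketch of that padding step, with pad_to_multiple as a hypothetical stand-in:

import math
import torch
from torch.nn import functional as F

def pad_to_multiple(tensor: torch.Tensor, T: int) -> torch.Tensor:
    # Round the last dimension up to the next multiple of T, then zero-pad,
    # mirroring the arithmetic shown in batchify above.
    orig_size = tensor.size(-1)
    new_size = math.ceil(orig_size / T) * T
    return F.pad(tensor, (0, new_size - orig_size))

x = torch.randn(2, 1024, 700)         # 2 channels x F x L magnitude frames
print(pad_to_multiple(x, 512).shape)  # torch.Size([2, 1024, 1024])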
@@ -59,18 +60,13 @@ class UNet(nn.Module):
     def __init__(
         self,
         n_layers: int = 6,
-        in_channels: int =
+        in_channels: int = 2,
     ) -> None:
         super().__init__()
 
         # DownSample layers
         down_set = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
-        self.encoder_layers = nn.ModuleList(
-            [
-                EncoderBlock(in_channels=in_ch, out_channels=out_ch)
-                for in_ch, out_ch in zip(down_set[:-1], down_set[1:])
-            ]
-        )
+        self.encoder_layers = nn.ModuleList([EncoderBlock(in_channels=in_ch, out_channels=out_ch) for in_ch, out_ch in zip(down_set[:-1], down_set[1:])])
 
         # UpSample layers
        up_set = [1] + [2 ** (i + 4) for i in range(n_layers)]
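With the defaults above (n_layers=6, in_channels=2), the rewritten one-liner builds the same encoder stack as before: channel widths grow in powers of two from 16 up to 512. A quick check of that progression:

n_layers, in_channels = 6, 2
down_set = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
print(down_set)  # [2, 16, 32, 64, 128, 256, 512]
print(list(zip(down_set[:-1], down_set[1:])))
# [(2, 16), (16, 32), (32, 64), (64, 128), (128, 256), (256, 512)]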
@@ -122,22 +118,16 @@ class UNet(nn.Module):
 
 class Splitter(nn.Module):
 
-    def __init__(self, stem_num):
+    def __init__(self, instrument_models):
         super(Splitter, self).__init__()
-        if stem_num == 2:
-            stem_names = ["vocals","other"]
-        if stem_num == 4:
-            stem_names = ["vocals", "drums", "bass", "other"]
-        if stem_num == 5:
-            stem_names = ["vocals", "piano", "drums", "bass", "other"]
-        # stft config
         self.F = 1024
         self.T = 512
         self.win_length = 4096
         self.hop_length = 1024
         self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
-        self.stems = nn.ModuleDict({name: UNet(
-        self.
+        self.stems = nn.ModuleDict({name: UNet() for name in instrument_models})
+        for name in self.stems.keys():
+            self.stems[name].load_state_dict(load_file(hf_hub_download("shethjenil/spleeter",f"{name}.safetensors")))
         self.eval()
 
     def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
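The rewritten constructor takes explicit checkpoint names instead of a stem count, builds one UNet per name, and loads that stem's weights from the shethjenil/spleeter repo on the Hub. A short usage sketch (the names are a hypothetical selection from the dropdown at the bottom of the file; the first call downloads the .safetensors files):

splitter = Splitter(["5_vocals", "5_drums"])  # one pretrained UNet per stem
print(list(splitter.stems.keys()))            # ['5_vocals', '5_drums']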
@@ -193,7 +183,7 @@ class Splitter(nn.Module):
         return wav.detach()
 
     @torch.inference_mode()
-    def forward(self, wav: Tensor,batch_size=16
+    def forward(self, wav: Tensor,batch_size=16) -> Dict[str, Tensor]:
         # stft - 2 X F x L x 2
         # stft_mag - 2 X F x L
         stft, stft_mag = self.compute_stft(wav.squeeze())
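compute_stft itself is unchanged and outside this diff; judging from the shape comments and the config in __init__, its output is consistent with torch.stft using the 4096-sample Hann window, hop 1024, keeping the first F = 1024 frequency bins. A hedged sketch of those shapes (the bin cropping is an assumption, not shown in the diff):

import torch

win_length, hop_length, F_bins = 4096, 1024, 1024   # values from __init__
wav = torch.randn(2, 44100 * 3)                     # 3 s of stereo at 44.1 kHz
stft = torch.stft(wav, n_fft=win_length, hop_length=hop_length,
                  window=torch.hann_window(win_length), return_complex=True)
print(stft.shape)              # torch.Size([2, 2049, 130]) -> 2 x bins x L
print(stft[:, :F_bins].shape)  # torch.Size([2, 1024, 130]) -> 2 x F x L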
@@ -203,7 +193,7 @@ class Splitter(nn.Module):
         stft_mag = batchify(stft_mag, self.T) # B x 2 x F x T
         stft_mag = stft_mag.transpose(2, 3) # B x 2 x T x F
         # compute stems' mask
-        masks = self.infer_with_batches(stft_mag,batch_size
+        masks = self.infer_with_batches(stft_mag,batch_size)
         # compute denominator
         mask_sum = sum([m**2 for m in masks.values()])
         mask_sum += 1e-10
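The denominator makes the stems' masks compete for the mixture's energy, Wiener-filter style: assuming apply_mask divides each squared mask by this sum (its body is outside the diff), every time-frequency bin is shared across stems and the 1e-10 guards against division by zero. A toy illustration with invented values:

import torch

masks = {"vocals": torch.tensor([0.9, 0.1]), "other": torch.tensor([0.3, 0.8])}
mask_sum = sum(m**2 for m in masks.values()) + 1e-10
ratios = {name: m**2 / mask_sum for name, m in masks.items()}
print(ratios["vocals"])  # tensor([0.9000, 0.0154]) -- dominant where vocals are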
@@ -216,11 +206,11 @@ class Splitter(nn.Module):
             return stft_masked
         return {name: self.inverse_stft(apply_mask(m)) for name, m in masks.items()}
 
-    def infer_with_batches(self, stft_mag, batch_size
+    def infer_with_batches(self, stft_mag, batch_size):
         masks = {name: [] for name in self.stems.keys()}
         for i in tqdm(range(0, stft_mag.shape[0], batch_size)):
             batch = stft_mag[i:i + batch_size]
-            batch_outputs = {name: net(batch) for name, net in self.stems.items()
+            batch_outputs = {name: net(batch) for name, net in self.stems.items()}
             for name in batch_outputs:
                 masks[name].append(batch_outputs[name])
         return {
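infer_with_batches is a plain chunked-inference loop: slice the batch dimension, run every stem network on each slice, collect the outputs per stem, and (per the return shown) presumably concatenate them afterwards. The same pattern in isolation, with run_in_batches as a generic stand-in:

import torch

def run_in_batches(models: dict, data: torch.Tensor, batch_size: int) -> dict:
    # models maps stem name -> callable; outputs are gathered per stem.
    outputs = {name: [] for name in models}
    for i in range(0, data.shape[0], batch_size):
        batch = data[i:i + batch_size]
        for name, net in models.items():
            outputs[name].append(net(batch))
    return {name: torch.cat(chunks) for name, chunks in outputs.items()}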
@@ -229,17 +219,15 @@ class Splitter(nn.Module):
             if masks[name]
         }
 
-def separate_audio(audio_path:str,
-    model = Splitter(instrument_model)
+def separate_audio(audio_path:str,batch_size:int,instrument_models:list,progress=gr.Progress(True)):
     wav, sr = torchaudio.load(audio_path)
     target_sr = 44100
     if sr != target_sr:
         resampler = Resample(sr, target_sr)
         wav = resampler(wav)
         sr = target_sr
-    results =
+    results = Splitter(instrument_models).forward(wav,batch_size)
     for i in results:
         torchaudio.save(f"{i}.mp3", results[i], sr)
-    return
-
-gr.Interface(separate_audio, [gr.Audio(type="filepath"),gr.Dropdown([2,4,5]),gr.Number(16),gr.Dropdown(["vocals", "piano", "drums", "bass", "other"],multiselect=True,value=["vocals", "piano", "drums", "bass", "other"])], [gr.Audio(type="filepath"), gr.Audio(type="filepath"),gr.Audio(type="filepath"),gr.Audio(type="filepath"),gr.Audio(type="filepath")]).launch()
+    return [gr.Audio(i,type='filepath',buttons=['download']) for i in results]
+gr.Interface(separate_audio, [gr.Audio(type="filepath"),gr.Number(16),gr.Dropdown(['2_other', '2_vocals', '4_bass', '4_drums', '4_other', '4_vocals', '5_bass', '5_drums', '5_other', '5_piano', '5_vocals'],multiselect=True,value=['5_vocals'])]).launch()
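End to end, the updated app exposes checkpoint names like 4_vocals (stem-count prefix, instrument suffix) directly in the UI instead of the old 2/4/5 dropdown, and writes one mp3 per selected stem. A hypothetical local call, assuming the functions above are in scope and song.wav exists:

# Separates vocals and drums with the 4-stem checkpoints; this writes
# 4_vocals.mp3 and 4_drums.mp3 next to the script and returns gr.Audio players.
players = separate_audio("song.wav", batch_size=16,
                         instrument_models=["4_vocals", "4_drums"])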
|
|
|
|
|