shethjenil committed
Commit 71e43dc · verified · 1 Parent(s): 8370302

Update app.py

Files changed (1): app.py (+17 -29)
app.py CHANGED
@@ -1,14 +1,15 @@
 import gradio as gr
 import torchaudio
-from torchaudio.transforms import Resample
 import math
+import torch
+from torchaudio.transforms import Resample
 from typing import Dict, Tuple
 from huggingface_hub import hf_hub_download
-import torch
 from torch import nn, Tensor
 from torch.nn import functional as F
 from tqdm import tqdm
 from safetensors.torch import load_file
+
 def batchify(tensor: Tensor, T: int) -> Tensor:
     orig_size = tensor.size(-1)
     new_size = math.ceil(orig_size / T) * T
@@ -59,18 +60,13 @@ class UNet(nn.Module):
     def __init__(
         self,
         n_layers: int = 6,
-        in_channels: int = 1,
+        in_channels: int = 2,
     ) -> None:
         super().__init__()
 
         # DownSample layers
         down_set = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
-        self.encoder_layers = nn.ModuleList(
-            [
-                EncoderBlock(in_channels=in_ch, out_channels=out_ch)
-                for in_ch, out_ch in zip(down_set[:-1], down_set[1:])
-            ]
-        )
+        self.encoder_layers = nn.ModuleList([EncoderBlock(in_channels=in_ch, out_channels=out_ch) for in_ch, out_ch in zip(down_set[:-1], down_set[1:])])
 
         # UpSample layers
         up_set = [1] + [2 ** (i + 4) for i in range(n_layers)]
@@ -122,22 +118,16 @@ class UNet(nn.Module):
 
 class Splitter(nn.Module):
 
-    def __init__(self, stem_num=2):
+    def __init__(self, instrument_models):
         super(Splitter, self).__init__()
-        if stem_num == 2:
-            stem_names = ["vocals","other"]
-        if stem_num == 4:
-            stem_names = ["vocals", "drums", "bass", "other"]
-        if stem_num == 5:
-            stem_names = ["vocals", "piano", "drums", "bass", "other"]
-        # stft config
         self.F = 1024
         self.T = 512
         self.win_length = 4096
         self.hop_length = 1024
         self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
-        self.stems = nn.ModuleDict({name: UNet(in_channels=2) for name in stem_names})
-        self.load_state_dict(load_file(hf_hub_download("shethjenil/spleeter",f"{stem_num}.safetensors")))
+        self.stems = nn.ModuleDict({name: UNet() for name in instrument_models})
+        for name in self.stems.keys():
+            self.stems[name].load_state_dict(load_file(hf_hub_download("shethjenil/spleeter",f"{name}.safetensors")))
         self.eval()
 
     def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
@@ -193,7 +183,7 @@ class Splitter(nn.Module):
         return wav.detach()
 
     @torch.inference_mode()
-    def forward(self, wav: Tensor,batch_size=16,allow=['vocals']) -> Dict[str, Tensor]:
+    def forward(self, wav: Tensor,batch_size=16) -> Dict[str, Tensor]:
         # stft - 2 X F x L x 2
         # stft_mag - 2 X F x L
         stft, stft_mag = self.compute_stft(wav.squeeze())
@@ -203,7 +193,7 @@ class Splitter(nn.Module):
         stft_mag = batchify(stft_mag, self.T) # B x 2 x F x T
         stft_mag = stft_mag.transpose(2, 3) # B x 2 x T x F
         # compute stems' mask
-        masks = self.infer_with_batches(stft_mag,batch_size,allow)
+        masks = self.infer_with_batches(stft_mag,batch_size)
         # compute denominator
         mask_sum = sum([m**2 for m in masks.values()])
         mask_sum += 1e-10
@@ -216,11 +206,11 @@ class Splitter(nn.Module):
             return stft_masked
         return {name: self.inverse_stft(apply_mask(m)) for name, m in masks.items()}
 
-    def infer_with_batches(self, stft_mag, batch_size, allow):
+    def infer_with_batches(self, stft_mag, batch_size):
         masks = {name: [] for name in self.stems.keys()}
         for i in tqdm(range(0, stft_mag.shape[0], batch_size)):
             batch = stft_mag[i:i + batch_size]
-            batch_outputs = {name: net(batch) for name, net in self.stems.items() if name in allow}
+            batch_outputs = {name: net(batch) for name, net in self.stems.items()}
             for name in batch_outputs:
                 masks[name].append(batch_outputs[name])
         return {
@@ -229,17 +219,15 @@ class Splitter(nn.Module):
             if masks[name]
         }
 
-def separate_audio(audio_path:str,instrument_model:int,batch_size:int,allow:list,progress=gr.Progress(True)):
-    model = Splitter(instrument_model)
+def separate_audio(audio_path:str,batch_size:int,instrument_models:list,progress=gr.Progress(True)):
     wav, sr = torchaudio.load(audio_path)
     target_sr = 44100
     if sr != target_sr:
         resampler = Resample(sr, target_sr)
         wav = resampler(wav)
         sr = target_sr
-    results = model.forward(wav,batch_size,allow)
+    results = Splitter(instrument_models).forward(wav,batch_size)
     for i in results:
         torchaudio.save(f"{i}.mp3", results[i], sr)
-    return tuple([i+".mp3" for i in results] + [None for _ in range(5-len(results))])
-
-gr.Interface(separate_audio, [gr.Audio(type="filepath"),gr.Dropdown([2,4,5]),gr.Number(16),gr.Dropdown(["vocals", "piano", "drums", "bass", "other"],multiselect=True,value=["vocals", "piano", "drums", "bass", "other"])], [gr.Audio(type="filepath"), gr.Audio(type="filepath"),gr.Audio(type="filepath"),gr.Audio(type="filepath"),gr.Audio(type="filepath")]).launch()
+    return [gr.Audio(i,type='filepath',buttons=['download']) for i in results]
+gr.Interface(separate_audio, [gr.Audio(type="filepath"),gr.Number(16),gr.Dropdown(['2_other', '2_vocals', '4_bass', '4_drums', '4_other', '4_vocals', '5_bass', '5_drums', '5_other', '5_piano', '5_vocals'],multiselect=True,value=['5_vocals'])]).launch()
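For reference, a minimal offline sketch of the refactored API this commit introduces. It assumes the Splitter class from app.py above is in scope, that per-stem checkpoints such as 5_vocals.safetensors and 5_drums.safetensors exist on the shethjenil/spleeter Hub repo (as the new loader implies), and a hypothetical input file song.mp3:

import torchaudio
from torchaudio.transforms import Resample

# Each selected name maps to one checkpoint, "{name}.safetensors",
# fetched via hf_hub_download inside Splitter.__init__.
splitter = Splitter(["5_vocals", "5_drums"])

wav, sr = torchaudio.load("song.mp3")  # hypothetical input path
if sr != 44100:  # mirror the resampling done in separate_audio
    wav = Resample(sr, 44100)(wav)

# One waveform Tensor per requested stem, keyed by checkpoint name.
stems = splitter.forward(wav, batch_size=16)
for name, audio in stems.items():
    torchaudio.save(f"{name}.mp3", audio, 44100)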
 
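A note on the design choice the sketch relies on: because each stem's UNet now loads its own safetensors file, the old allow filter and the combined {stem_num}.safetensors checkpoint are no longer needed; only the UNets the user selects in the multiselect dropdown are downloaded and run.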