shethjenil committed on
Commit
d5783f9
·
verified ·
1 Parent(s): 2a96a7c

Update spleeter.py

Browse files
Files changed (1) hide show
  1. spleeter.py +17 -15
spleeter.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  from torch import nn, Tensor
6
  from torch.nn import functional as F
7
  from tqdm import tqdm
8
-
9
  def batchify(tensor: Tensor, T: int) -> Tensor:
10
  orig_size = tensor.size(-1)
11
  new_size = math.ceil(orig_size / T) * T
@@ -122,7 +122,7 @@ class Splitter(nn.Module):
122
  def __init__(self, stem_num=2):
123
  super(Splitter, self).__init__()
124
  if stem_num == 2:
125
- stem_names = ["vocals","accompaniment"]
126
  if stem_num == 4:
127
  stem_names = ["vocals", "drums", "bass", "other"]
128
  if stem_num == 5:
@@ -134,7 +134,7 @@ class Splitter(nn.Module):
134
  self.hop_length = 1024
135
  self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
136
  self.stems = nn.ModuleDict({name: UNet(in_channels=2) for name in stem_names})
137
- self.load_state_dict(torch.load(hf_hub_download("shethjenil/spleeter-torch",f"{stem_num}.pt")))
138
  self.eval()
139
 
140
  def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
@@ -189,7 +189,8 @@ class Splitter(nn.Module):
189
 
190
  return wav.detach()
191
 
192
- def forward(self, wav: Tensor,batch_size=16) -> Dict[str, Tensor]:
 
193
  # stft - 2 X F x L x 2
194
  # stft_mag - 2 X F x L
195
  stft, stft_mag = self.compute_stft(wav.squeeze())
@@ -199,7 +200,7 @@ class Splitter(nn.Module):
199
  stft_mag = batchify(stft_mag, self.T) # B x 2 x F x T
200
  stft_mag = stft_mag.transpose(2, 3) # B x 2 x T x F
201
  # compute stems' mask
202
- masks = self.infer_with_batches(stft_mag,batch_size)
203
  # compute denominator
204
  mask_sum = sum([m**2 for m in masks.values()])
205
  mask_sum += 1e-10
@@ -212,14 +213,15 @@ class Splitter(nn.Module):
212
  return stft_masked
213
  return {name: self.inverse_stft(apply_mask(m)) for name, m in masks.items()}
214
 
215
- def infer_with_batches(self, stft_mag, batch_size):
216
  masks = {name: [] for name in self.stems.keys()}
217
- with torch.inference_mode():
218
- for i in tqdm(range(0, stft_mag.shape[0], batch_size)):
219
- batch = stft_mag[i:i + batch_size]
220
- batch_outputs = {name: net(batch) for name, net in self.stems.items()}
221
- for name in self.stems.keys():
222
- masks[name].append(batch_outputs[name])
223
- masks = {name: torch.cat(masks[name], dim=0) for name in masks}
224
- return masks
225
-
 
 
5
  from torch import nn, Tensor
6
  from torch.nn import functional as F
7
  from tqdm import tqdm
8
+ from safetensors.torch import load_file
9
  def batchify(tensor: Tensor, T: int) -> Tensor:
10
  orig_size = tensor.size(-1)
11
  new_size = math.ceil(orig_size / T) * T
 
122
  def __init__(self, stem_num=2):
123
  super(Splitter, self).__init__()
124
  if stem_num == 2:
125
+ stem_names = ["vocals","other"]
126
  if stem_num == 4:
127
  stem_names = ["vocals", "drums", "bass", "other"]
128
  if stem_num == 5:
 
134
  self.hop_length = 1024
135
  self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
136
  self.stems = nn.ModuleDict({name: UNet(in_channels=2) for name in stem_names})
137
+ self.load_state_dict(load_file(hf_hub_download("shethjenil/spleeter",f"{stem_num}.safetensors")))
138
  self.eval()
139
 
140
  def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
 
189
 
190
  return wav.detach()
191
 
192
+ @torch.inference_mode()
193
+ def forward(self, wav: Tensor,batch_size=16,allow=['vocals']) -> Dict[str, Tensor]:
194
  # stft - 2 X F x L x 2
195
  # stft_mag - 2 X F x L
196
  stft, stft_mag = self.compute_stft(wav.squeeze())
 
200
  stft_mag = batchify(stft_mag, self.T) # B x 2 x F x T
201
  stft_mag = stft_mag.transpose(2, 3) # B x 2 x T x F
202
  # compute stems' mask
203
+ masks = self.infer_with_batches(stft_mag,batch_size,allow)
204
  # compute denominator
205
  mask_sum = sum([m**2 for m in masks.values()])
206
  mask_sum += 1e-10
 
213
  return stft_masked
214
  return {name: self.inverse_stft(apply_mask(m)) for name, m in masks.items()}
215
 
216
+ def infer_with_batches(self, stft_mag, batch_size, allow):
217
  masks = {name: [] for name in self.stems.keys()}
218
+ for i in tqdm(range(0, stft_mag.shape[0], batch_size)):
219
+ batch = stft_mag[i:i + batch_size]
220
+ batch_outputs = {name: net(batch) for name, net in self.stems.items() if name in allow}
221
+ for name in batch_outputs:
222
+ masks[name].append(batch_outputs[name])
223
+ return {
224
+ name: torch.cat(masks[name], dim=0)
225
+ for name in masks
226
+ if masks[name]
227
+ }