Spaces:
Sleeping
Sleeping
Update spleeter.py
Browse files — spleeter.py (+17 −15)
spleeter.py
CHANGED
|
@@ -5,7 +5,7 @@ import torch
|
|
| 5 |
from torch import nn, Tensor
|
| 6 |
from torch.nn import functional as F
|
| 7 |
from tqdm import tqdm
|
| 8 |
-
|
| 9 |
def batchify(tensor: Tensor, T: int) -> Tensor:
|
| 10 |
orig_size = tensor.size(-1)
|
| 11 |
new_size = math.ceil(orig_size / T) * T
|
|
@@ -122,7 +122,7 @@ class Splitter(nn.Module):
|
|
| 122 |
def __init__(self, stem_num=2):
|
| 123 |
super(Splitter, self).__init__()
|
| 124 |
if stem_num == 2:
|
| 125 |
-
stem_names = ["vocals","
|
| 126 |
if stem_num == 4:
|
| 127 |
stem_names = ["vocals", "drums", "bass", "other"]
|
| 128 |
if stem_num == 5:
|
|
@@ -134,7 +134,7 @@ class Splitter(nn.Module):
|
|
| 134 |
self.hop_length = 1024
|
| 135 |
self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
|
| 136 |
self.stems = nn.ModuleDict({name: UNet(in_channels=2) for name in stem_names})
|
| 137 |
-
self.load_state_dict(
|
| 138 |
self.eval()
|
| 139 |
|
| 140 |
def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
|
|
@@ -189,7 +189,8 @@ class Splitter(nn.Module):
|
|
| 189 |
|
| 190 |
return wav.detach()
|
| 191 |
|
| 192 |
-
|
|
|
|
| 193 |
# stft - 2 X F x L x 2
|
| 194 |
# stft_mag - 2 X F x L
|
| 195 |
stft, stft_mag = self.compute_stft(wav.squeeze())
|
|
@@ -199,7 +200,7 @@ class Splitter(nn.Module):
|
|
| 199 |
stft_mag = batchify(stft_mag, self.T) # B x 2 x F x T
|
| 200 |
stft_mag = stft_mag.transpose(2, 3) # B x 2 x T x F
|
| 201 |
# compute stems' mask
|
| 202 |
-
masks = self.infer_with_batches(stft_mag,batch_size)
|
| 203 |
# compute denominator
|
| 204 |
mask_sum = sum([m**2 for m in masks.values()])
|
| 205 |
mask_sum += 1e-10
|
|
@@ -212,14 +213,15 @@ class Splitter(nn.Module):
|
|
| 212 |
return stft_masked
|
| 213 |
return {name: self.inverse_stft(apply_mask(m)) for name, m in masks.items()}
|
| 214 |
|
| 215 |
-
def infer_with_batches(self, stft_mag, batch_size):
|
| 216 |
masks = {name: [] for name in self.stems.keys()}
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
|
|
|
|
|
| 5 |
from torch import nn, Tensor
|
| 6 |
from torch.nn import functional as F
|
| 7 |
from tqdm import tqdm
|
| 8 |
+
from safetensors.torch import load_file
|
| 9 |
def batchify(tensor: Tensor, T: int) -> Tensor:
|
| 10 |
orig_size = tensor.size(-1)
|
| 11 |
new_size = math.ceil(orig_size / T) * T
|
|
|
|
| 122 |
def __init__(self, stem_num=2):
|
| 123 |
super(Splitter, self).__init__()
|
| 124 |
if stem_num == 2:
|
| 125 |
+
stem_names = ["vocals","other"]
|
| 126 |
if stem_num == 4:
|
| 127 |
stem_names = ["vocals", "drums", "bass", "other"]
|
| 128 |
if stem_num == 5:
|
|
|
|
| 134 |
self.hop_length = 1024
|
| 135 |
self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
|
| 136 |
self.stems = nn.ModuleDict({name: UNet(in_channels=2) for name in stem_names})
|
| 137 |
+
self.load_state_dict(load_file(hf_hub_download("shethjenil/spleeter",f"{stem_num}.safetensors")))
|
| 138 |
self.eval()
|
| 139 |
|
| 140 |
def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
|
|
|
|
| 189 |
|
| 190 |
return wav.detach()
|
| 191 |
|
| 192 |
+
@torch.inference_mode()
|
| 193 |
+
def forward(self, wav: Tensor,batch_size=16,allow=['vocals']) -> Dict[str, Tensor]:
|
| 194 |
# stft - 2 X F x L x 2
|
| 195 |
# stft_mag - 2 X F x L
|
| 196 |
stft, stft_mag = self.compute_stft(wav.squeeze())
|
|
|
|
| 200 |
stft_mag = batchify(stft_mag, self.T) # B x 2 x F x T
|
| 201 |
stft_mag = stft_mag.transpose(2, 3) # B x 2 x T x F
|
| 202 |
# compute stems' mask
|
| 203 |
+
masks = self.infer_with_batches(stft_mag,batch_size,allow)
|
| 204 |
# compute denominator
|
| 205 |
mask_sum = sum([m**2 for m in masks.values()])
|
| 206 |
mask_sum += 1e-10
|
|
|
|
| 213 |
return stft_masked
|
| 214 |
return {name: self.inverse_stft(apply_mask(m)) for name, m in masks.items()}
|
| 215 |
|
| 216 |
+
def infer_with_batches(self, stft_mag, batch_size, allow):
    """Run each allowed stem network over ``stft_mag`` in mini-batches.

    Args:
        stft_mag: batched magnitude spectrogram, B x 2 x T x F per the
            comments in ``forward`` (after ``batchify`` and transpose) —
            assumed, confirm against caller.
        batch_size: number of spectrogram chunks fed to each stem UNet
            per step.
        allow: collection of stem names to compute; stems in
            ``self.stems`` but not in ``allow`` are skipped entirely.

    Returns:
        Dict mapping each computed stem name to its mask tensor,
        re-concatenated along the batch dimension. Stems not in
        ``allow`` (or that received no batches, e.g. empty input)
        are omitted, matching the original's filtering of empty
        mask lists.
    """
    # Resolve the allowed stem networks ONCE: the `name in allow` test
    # is loop-invariant, so re-filtering self.stems on every mini-batch
    # (as the original did) is wasted work. Also avoids allocating mask
    # lists for stems that can never be populated.
    allowed_nets = {name: net for name, net in self.stems.items() if name in allow}
    masks = {name: [] for name in allowed_nets}
    for start in tqdm(range(0, stft_mag.shape[0], batch_size)):
        batch = stft_mag[start:start + batch_size]
        for name, net in allowed_nets.items():
            masks[name].append(net(batch))
    # `if chunks` preserves the original behavior of dropping stems
    # whose mask list stayed empty (torch.cat on [] would raise).
    return {name: torch.cat(chunks, dim=0) for name, chunks in masks.items() if chunks}
|