Spaces:
Paused
Paused
Update spleeter.py
Browse files- spleeter.py +2 -10
spleeter.py
CHANGED
|
@@ -7,18 +7,9 @@ from torch.nn import functional as F
|
|
| 7 |
from tqdm import tqdm
|
| 8 |
|
| 9 |
def batchify(tensor: Tensor, T: int) -> Tensor:
|
| 10 |
-
"""
|
| 11 |
-
partition tensor into segments of length T, zero pad any ragged samples
|
| 12 |
-
Args:
|
| 13 |
-
tensor(Tensor): BxCxFxL
|
| 14 |
-
Returns:
|
| 15 |
-
tensor of size (B*[L/T] x C x F x T)
|
| 16 |
-
"""
|
| 17 |
-
# Zero pad the original tensor to an even multiple of T
|
| 18 |
orig_size = tensor.size(-1)
|
| 19 |
new_size = math.ceil(orig_size / T) * T
|
| 20 |
tensor = F.pad(tensor, [0, new_size - orig_size])
|
| 21 |
-
# Partition the tensor into multiple samples of length T and stack them into a batch
|
| 22 |
return torch.cat(torch.split(tensor, T, dim=-1), dim=0)
|
| 23 |
|
| 24 |
|
|
@@ -223,7 +214,7 @@ class Splitter(nn.Module):
|
|
| 223 |
|
| 224 |
def infer_with_batches(self, stft_mag, batch_size):
|
| 225 |
masks = {name: [] for name in self.stems.keys()}
|
| 226 |
-
with torch.
|
| 227 |
for i in tqdm(range(0, stft_mag.shape[0], batch_size)):
|
| 228 |
batch = stft_mag[i:i + batch_size]
|
| 229 |
batch_outputs = {name: net(batch) for name, net in self.stems.items()}
|
|
@@ -231,3 +222,4 @@ class Splitter(nn.Module):
|
|
| 231 |
masks[name].append(batch_outputs[name])
|
| 232 |
masks = {name: torch.cat(masks[name], dim=0) for name in masks}
|
| 233 |
return masks
|
|
|
|
|
|
| 7 |
from tqdm import tqdm
|
| 8 |
|
| 9 |
def batchify(tensor: Tensor, T: int) -> Tensor:
    """Partition ``tensor`` into segments of length ``T`` along the last axis.

    The last axis is zero-padded up to the next multiple of ``T`` so that a
    ragged tail still yields a full-length segment, then the segments are
    stacked along the batch axis.

    Args:
        tensor: input of shape B x C x F x L (partitioned along the L axis).
        T: segment length; must be a positive integer.

    Returns:
        Tensor of shape (B * ceil(L / T)) x C x F x T.
    """
    # Zero-pad the original tensor to an even multiple of T.
    orig_size = tensor.size(-1)
    new_size = math.ceil(orig_size / T) * T
    tensor = F.pad(tensor, [0, new_size - orig_size])
    # Partition into length-T samples and stack them into one batch.
    return torch.cat(torch.split(tensor, T, dim=-1), dim=0)
|
| 14 |
|
| 15 |
|
|
|
|
| 214 |
|
| 215 |
def infer_with_batches(self, stft_mag, batch_size):
|
| 216 |
masks = {name: [] for name in self.stems.keys()}
|
| 217 |
+
with torch.inference_mode():
|
| 218 |
for i in tqdm(range(0, stft_mag.shape[0], batch_size)):
|
| 219 |
batch = stft_mag[i:i + batch_size]
|
| 220 |
batch_outputs = {name: net(batch) for name, net in self.stems.items()}
|
|
|
|
| 222 |
masks[name].append(batch_outputs[name])
|
| 223 |
masks = {name: torch.cat(masks[name], dim=0) for name in masks}
|
| 224 |
return masks
|
| 225 |
+
|