shethjenil commited on
Commit
650b3a2
·
verified ·
1 Parent(s): ee6da54

Update spleeter.py

Browse files
Files changed (1) hide show
  1. spleeter.py +2 -10
spleeter.py CHANGED
@@ -7,18 +7,9 @@ from torch.nn import functional as F
7
  from tqdm import tqdm
8
 
9
  def batchify(tensor: Tensor, T: int) -> Tensor:
10
- """
11
- partition tensor into segments of length T, zero pad any ragged samples
12
- Args:
13
- tensor(Tensor): BxCxFxL
14
- Returns:
15
- tensor of size (B*[L/T] x C x F x T)
16
- """
17
- # Zero pad the original tensor to an even multiple of T
18
  orig_size = tensor.size(-1)
19
  new_size = math.ceil(orig_size / T) * T
20
  tensor = F.pad(tensor, [0, new_size - orig_size])
21
- # Partition the tensor into multiple samples of length T and stack them into a batch
22
  return torch.cat(torch.split(tensor, T, dim=-1), dim=0)
23
 
24
 
@@ -223,7 +214,7 @@ class Splitter(nn.Module):
223
 
224
  def infer_with_batches(self, stft_mag, batch_size):
225
  masks = {name: [] for name in self.stems.keys()}
226
- with torch.no_grad():
227
  for i in tqdm(range(0, stft_mag.shape[0], batch_size)):
228
  batch = stft_mag[i:i + batch_size]
229
  batch_outputs = {name: net(batch) for name, net in self.stems.items()}
@@ -231,3 +222,4 @@ class Splitter(nn.Module):
231
  masks[name].append(batch_outputs[name])
232
  masks = {name: torch.cat(masks[name], dim=0) for name in masks}
233
  return masks
 
 
7
  from tqdm import tqdm
8
 
9
  def batchify(tensor: Tensor, T: int) -> Tensor:
 
 
 
 
 
 
 
 
10
  orig_size = tensor.size(-1)
11
  new_size = math.ceil(orig_size / T) * T
12
  tensor = F.pad(tensor, [0, new_size - orig_size])
 
13
  return torch.cat(torch.split(tensor, T, dim=-1), dim=0)
14
 
15
 
 
214
 
215
  def infer_with_batches(self, stft_mag, batch_size):
216
  masks = {name: [] for name in self.stems.keys()}
217
+ with torch.inference_mode():
218
  for i in tqdm(range(0, stft_mag.shape[0], batch_size)):
219
  batch = stft_mag[i:i + batch_size]
220
  batch_outputs = {name: net(batch) for name, net in self.stems.items()}
 
222
  masks[name].append(batch_outputs[name])
223
  masks = {name: torch.cat(masks[name], dim=0) for name in masks}
224
  return masks
225
+