Spaces:
Runtime error
Runtime error
pengdaqian
committed on
Commit
·
d7659a0
1
Parent(s):
d853526
fix
Browse files
- app.py +1 -1
- torchspleeter/estimator.py +16 -13
app.py
CHANGED
|
@@ -196,7 +196,7 @@ def svc_main(sid, input_audio):
|
|
| 196 |
if not os.path.exists(tmpfile_path):
|
| 197 |
os.makedirs(tmpfile_path)
|
| 198 |
|
| 199 |
-
split_to_parts(input_audio_tmp_file, tmpfile_path
|
| 200 |
|
| 201 |
curr_tmp_path = os.path.join(tmpfile_path, os.path.splitext(input_audio_tmp_file)[0])
|
| 202 |
vocals_filepath = os.path.join(curr_tmp_path, 'vocals.wav')
|
|
|
|
| 196 |
if not os.path.exists(tmpfile_path):
|
| 197 |
os.makedirs(tmpfile_path)
|
| 198 |
|
| 199 |
+
split_to_parts(input_audio_tmp_file, tmpfile_path)
|
| 200 |
|
| 201 |
curr_tmp_path = os.path.join(tmpfile_path, os.path.splitext(input_audio_tmp_file)[0])
|
| 202 |
vocals_filepath = os.path.join(curr_tmp_path, 'vocals.wav')
|
torchspleeter/estimator.py
CHANGED
|
@@ -7,12 +7,15 @@ import tqdm
|
|
| 7 |
# from torchaudio.functional import istft
|
| 8 |
|
| 9 |
from torchspleeter.unet import UNet
|
| 10 |
-
#from .util import tf2pytorch
|
| 11 |
|
| 12 |
import os
|
|
|
|
| 13 |
dirname = os.path.dirname(__file__)
|
| 14 |
defaultmodel0 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint0.ckpt')
|
| 15 |
defaultmodel1 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint1.ckpt')
|
|
|
|
|
|
|
| 16 |
|
| 17 |
def load_ckpt(model, ckpt):
|
| 18 |
state_dict = model.state_dict()
|
|
@@ -39,7 +42,7 @@ def pad_and_partition(tensor, T):
|
|
| 39 |
tensor of size (B*[L/T] x C x F x T)
|
| 40 |
"""
|
| 41 |
old_size = tensor.size(3)
|
| 42 |
-
new_size = math.ceil(old_size/T) * T
|
| 43 |
tensor = F.pad(tensor, [0, new_size - old_size])
|
| 44 |
[b, c, t, f] = tensor.shape
|
| 45 |
split = new_size // T
|
|
@@ -50,29 +53,29 @@ class Estimator(nn.Module):
|
|
| 50 |
def __init__(self, num_instrumments=2, checkpoint_path=None):
|
| 51 |
super(Estimator, self).__init__()
|
| 52 |
if checkpoint_path is None:
|
| 53 |
-
checkpoint_path=[defaultmodel0,defaultmodel1]
|
| 54 |
else:
|
| 55 |
-
if len(checkpoint_path)<1:
|
| 56 |
-
checkpoint_path=[defaultmodel0,defaultmodel1]
|
| 57 |
# stft config
|
| 58 |
self.F = 1024
|
| 59 |
self.T = 512
|
| 60 |
self.win_length = 4096
|
| 61 |
self.hop_length = 1024
|
| 62 |
self.win = nn.Parameter(
|
| 63 |
-
torch.hann_window(self.win_length),
|
| 64 |
requires_grad=False
|
| 65 |
)
|
| 66 |
|
| 67 |
-
ckpts=[]
|
| 68 |
if len(checkpoint_path) != num_instrumments:
|
| 69 |
raise ValueError("You must submit as many models as there are instruments!")
|
| 70 |
for ckpt_path in checkpoint_path:
|
| 71 |
ckpts.append(torch.load(ckpt_path))
|
| 72 |
|
| 73 |
-
#self.ckpts = ckpt #torch.load(checkpoint_path)#, num_instrumments)
|
| 74 |
|
| 75 |
-
#ckpts = #tf2pytorch(checkpoint_path, num_instrumments)
|
| 76 |
|
| 77 |
# filter
|
| 78 |
self.instruments = nn.ModuleList()
|
|
@@ -109,7 +112,7 @@ class Estimator(nn.Module):
|
|
| 109 |
pad = self.win_length // 2 + 1 - stft.size(1)
|
| 110 |
stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
|
| 111 |
wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True,
|
| 112 |
-
|
| 113 |
return wav.detach()
|
| 114 |
|
| 115 |
def separate(self, wav):
|
|
@@ -145,14 +148,14 @@ class Estimator(nn.Module):
|
|
| 145 |
|
| 146 |
wavs = []
|
| 147 |
for mask in tqdm.tqdm(masks):
|
| 148 |
-
mask = (mask ** 2 + 1e-10/2)/(mask_sum)
|
| 149 |
mask = mask.transpose(2, 3) # B x 2 X F x T
|
| 150 |
|
| 151 |
mask = torch.cat(
|
| 152 |
torch.split(mask, 1, dim=0), dim=3)
|
| 153 |
|
| 154 |
-
mask = mask.squeeze(0)[
|
| 155 |
-
stft_masked = stft *
|
| 156 |
wav_masked = self.inverse_stft(stft_masked)
|
| 157 |
|
| 158 |
wavs.append(wav_masked)
|
|
|
|
| 7 |
# from torchaudio.functional import istft
|
| 8 |
|
| 9 |
from torchspleeter.unet import UNet
|
| 10 |
+
# from .util import tf2pytorch
|
| 11 |
|
| 12 |
import os
|
| 13 |
+
|
| 14 |
dirname = os.path.dirname(__file__)
|
| 15 |
defaultmodel0 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint0.ckpt')
|
| 16 |
defaultmodel1 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint1.ckpt')
|
| 17 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 18 |
+
|
| 19 |
|
| 20 |
def load_ckpt(model, ckpt):
|
| 21 |
state_dict = model.state_dict()
|
|
|
|
| 42 |
tensor of size (B*[L/T] x C x F x T)
|
| 43 |
"""
|
| 44 |
old_size = tensor.size(3)
|
| 45 |
+
new_size = math.ceil(old_size / T) * T
|
| 46 |
tensor = F.pad(tensor, [0, new_size - old_size])
|
| 47 |
[b, c, t, f] = tensor.shape
|
| 48 |
split = new_size // T
|
|
|
|
| 53 |
def __init__(self, num_instrumments=2, checkpoint_path=None):
|
| 54 |
super(Estimator, self).__init__()
|
| 55 |
if checkpoint_path is None:
|
| 56 |
+
checkpoint_path = [defaultmodel0, defaultmodel1]
|
| 57 |
else:
|
| 58 |
+
if len(checkpoint_path) < 1:
|
| 59 |
+
checkpoint_path = [defaultmodel0, defaultmodel1]
|
| 60 |
# stft config
|
| 61 |
self.F = 1024
|
| 62 |
self.T = 512
|
| 63 |
self.win_length = 4096
|
| 64 |
self.hop_length = 1024
|
| 65 |
self.win = nn.Parameter(
|
| 66 |
+
torch.hann_window(self.win_length, device=device),
|
| 67 |
requires_grad=False
|
| 68 |
)
|
| 69 |
|
| 70 |
+
ckpts = []
|
| 71 |
if len(checkpoint_path) != num_instrumments:
|
| 72 |
raise ValueError("You must submit as many models as there are instruments!")
|
| 73 |
for ckpt_path in checkpoint_path:
|
| 74 |
ckpts.append(torch.load(ckpt_path))
|
| 75 |
|
| 76 |
+
# self.ckpts = ckpt #torch.load(checkpoint_path)#, num_instrumments)
|
| 77 |
|
| 78 |
+
# ckpts = #tf2pytorch(checkpoint_path, num_instrumments)
|
| 79 |
|
| 80 |
# filter
|
| 81 |
self.instruments = nn.ModuleList()
|
|
|
|
| 112 |
pad = self.win_length // 2 + 1 - stft.size(1)
|
| 113 |
stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
|
| 114 |
wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True,
|
| 115 |
+
window=self.win)
|
| 116 |
return wav.detach()
|
| 117 |
|
| 118 |
def separate(self, wav):
|
|
|
|
| 148 |
|
| 149 |
wavs = []
|
| 150 |
for mask in tqdm.tqdm(masks):
|
| 151 |
+
mask = (mask ** 2 + 1e-10 / 2) / (mask_sum)
|
| 152 |
mask = mask.transpose(2, 3) # B x 2 X F x T
|
| 153 |
|
| 154 |
mask = torch.cat(
|
| 155 |
torch.split(mask, 1, dim=0), dim=3)
|
| 156 |
|
| 157 |
+
mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1) # 2 x F x L x 1
|
| 158 |
+
stft_masked = stft * mask
|
| 159 |
wav_masked = self.inverse_stft(stft_masked)
|
| 160 |
|
| 161 |
wavs.append(wav_masked)
|