shethjenil committed on
Commit
7328533
·
verified ·
1 Parent(s): f7b8916

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -58
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import torchaudio
2
- import math
3
  import torch
4
  from typing import Dict, Tuple
5
  from huggingface_hub import hf_hub_download
@@ -9,7 +8,7 @@ from torch import nn, Tensor
9
  from torch.nn import functional as F
10
  from torch.utils.data import Dataset , DataLoader
11
  from torchaudio.transforms import Fade
12
- from torchaudio.models._hdemucs import HDemucs
13
 
14
  class Crop2d(nn.Module):
15
  def __init__(self, left, right, top, bottom):
@@ -37,9 +36,9 @@ class DecoderBlock(nn.Module):
37
  def __init__(self, in_channels: int, out_channels: int) -> None:
38
  super().__init__()
39
  self.tconv = nn.ConvTranspose2d(in_channels, out_channels, 5, 2)
 
40
  self.bn = nn.BatchNorm2d(out_channels,0.001,0.01)
41
  self.relu = nn.ReLU()
42
- self.crop = Crop2d(1, 2, 1, 2) # reverse padding
43
  def forward(self, input: Tensor) -> Tensor:
44
  return self.bn(self.relu(self.crop(self.tconv(input))))
45
 
@@ -57,6 +56,7 @@ class UNet(nn.Module):
57
  self.decoder_layers = nn.ModuleList([DecoderBlock(in_ch if i == 0 else in_ch * 2,out_ch) for i, (in_ch, out_ch) in enumerate(zip(up_set[:-1], up_set[1:]))])
58
  self.up_final = nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3)
59
  self.sigmoid = nn.Sigmoid()
 
60
  def forward(self, input: Tensor) -> Tensor:
61
  encoder_outputs_pre_act = []
62
  x = input
@@ -79,78 +79,129 @@ class UNet(nn.Module):
79
  input = input[..., :min_f, :min_t]
80
  return mask * input
81
 
82
- class ChunkSplitterDataset(Dataset):
83
- def __init__(self, wav, win):
84
- self.win_length = 4096
85
- self.hop_length = 1024
86
  self.win = win
87
- self.T = 512
88
- self.stft_mag = self.batchify(self.compute_stft(wav))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  def __len__(self):
91
- return self.stft_mag.size(0)
92
 
93
  def __getitem__(self, idx):
94
- return self.stft_mag[idx]
95
-
96
- def compute_stft(self, wav: Tensor):
97
- stft = torch.stft(wav.squeeze(),n_fft=self.win_length,hop_length=self.hop_length,window=self.win,return_complex=False,pad_mode="constant",)
98
- stft = stft[:, :1024, :, :]
99
- real = stft[:, :, :, 0]
100
- imag = stft[:, :, :, 1]
101
- self.stft = stft
102
- self.L = self.stft.size(2)
103
- return torch.sqrt(real**2 + imag**2 + 1e-10).unsqueeze(-1).permute([3, 0, 1, 2])
104
-
105
- def batchify(self, tensor: Tensor) -> Tensor:
106
- orig_size = tensor.size(-1)
107
- new_size = math.ceil(orig_size / self.T) * self.T
108
- tensor = F.pad(tensor, [0, new_size - orig_size])
109
- return torch.cat(torch.split(tensor, self.T, dim=-1), dim=0).transpose(2, 3)
110
-
111
- def apply_mask(self,mask,mask_sum):
112
- mask = (mask**2 + 1e-10 / 2) / (mask_sum)
113
- mask = mask.transpose(2, 3) # B x 2 X F x T
114
- mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)
115
- mask = mask.squeeze(0)[:, :, :self.L].unsqueeze(-1) # 2 x F x L x 1
116
- stft = self.stft * mask
117
- target_F = self.win_length // 2 + 1
118
- if stft.size(1) < target_F:
119
- pad = target_F - stft.size(1)
120
- stft = F.pad(stft, (0, 0, 0, 0, 0, pad)) # pad along freq dim
121
- return torch.istft(torch.view_as_complex(stft),n_fft=self.win_length,hop_length=self.hop_length,win_length=self.win_length,center=True,window=self.win)
122
-
123
- def decoder(self,masks):
124
- mask_sum = sum([m**2 for m in masks.values()]) + 1e-10
125
- return {name: self.apply_mask(m,mask_sum) for name, m in masks.items()}
126
 
127
  class Splitter(nn.Module):
128
  CONFIG = {
129
- 2:['2_other', '2_vocals'],
130
- 4:['4_bass', '4_drums', '4_other', '4_vocals'],
131
- 5:['5_piano','5_bass', '5_drums','5_other', '5_vocals']
132
- }
 
133
  def __init__(self, stem=2):
134
  super().__init__()
135
  self.win_length = 4096
136
- self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
137
- self.stems = nn.ModuleDict({name: UNet() for name in self.CONFIG[stem]})
 
 
 
 
 
 
 
 
 
 
 
138
  for name in self.stems:
139
- self.stems[name].load_state_dict(load_file(hf_hub_download("shethjenil/spleeter",f"{name}.safetensors")))
 
 
 
 
 
140
  self.eval()
141
 
142
  @torch.inference_mode()
143
- def forward(self, wav: Tensor,sr:int,batch_size) -> Dict[str, Tensor]:
144
  device = next(self.parameters()).device
 
145
  if sr != 44100:
146
  wav = torchaudio.functional.resample(wav, sr, 44100)
147
- ds = ChunkSplitterDataset(wav.to(device),self.win)
148
- masks = {name: [] for name in self.stems}
149
- for batch in tqdm(DataLoader(ds,batch_size)):
150
- outputs = {name: net(batch) for name, net in self.stems.items()}
151
- for name in outputs:
152
- masks[name].append(outputs[name])
153
- return ds.decoder({k: torch.cat(v, dim=0) for k, v in masks.items()})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  def separate_audio_spleeter(audio_path:str,batch_size:int,stem:int,progress=gr.Progress(True)):
156
  wav, sr = torchaudio.load(audio_path)
@@ -293,4 +344,3 @@ gr.TabbedInterface([
293
  gr.Interface(separate_audio_spleeter, [gr.Audio(type="filepath"),gr.Number(16),gr.Radio([2,4,5],label="STEM")],gr.Files()),
294
  gr.Interface(separate_audio_demucs, [gr.Audio(type="filepath"),gr.Number(16),gr.Radio([4],label="STEM")],gr.Files())
295
  ],['spleeter','demucs']).launch()
296
-
 
1
  import torchaudio
 
2
  import torch
3
  from typing import Dict, Tuple
4
  from huggingface_hub import hf_hub_download
 
8
  from torch.nn import functional as F
9
  from torch.utils.data import Dataset , DataLoader
10
  from torchaudio.transforms import Fade
11
+ from torchaudio.models import HDemucs
12
 
13
  class Crop2d(nn.Module):
14
  def __init__(self, left, right, top, bottom):
 
36
  def __init__(self, in_channels: int, out_channels: int) -> None:
37
  super().__init__()
38
  self.tconv = nn.ConvTranspose2d(in_channels, out_channels, 5, 2)
39
+ self.crop = Crop2d(1, 2, 1, 2) # reverse padding
40
  self.bn = nn.BatchNorm2d(out_channels,0.001,0.01)
41
  self.relu = nn.ReLU()
 
42
  def forward(self, input: Tensor) -> Tensor:
43
  return self.bn(self.relu(self.crop(self.tconv(input))))
44
 
 
56
  self.decoder_layers = nn.ModuleList([DecoderBlock(in_ch if i == 0 else in_ch * 2,out_ch) for i, (in_ch, out_ch) in enumerate(zip(up_set[:-1], up_set[1:]))])
57
  self.up_final = nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3)
58
  self.sigmoid = nn.Sigmoid()
59
+
60
  def forward(self, input: Tensor) -> Tensor:
61
  encoder_outputs_pre_act = []
62
  x = input
 
79
  input = input[..., :min_f, :min_t]
80
  return mask * input
81
 
82
class STFTChunkDataset(Dataset):
    """Fixed-size magnitude-STFT chunks of a waveform, for batched mask inference.

    Computes the complex STFT once, keeps the lowest ``F`` frequency bins,
    and exposes the magnitude split into ``(1, F, T)`` tiles along time.
    The complex STFT (``stft_complex``) and the unpadded frame count (``L``)
    are kept so the caller can reassemble masked audio later.

    Args:
        wav: waveform tensor; flattened to ``(channels, samples)``.
            NOTE(review): the chunking below hard-codes a leading batch of 1,
            so this presumably expects mono input — confirm against callers.
        win: analysis window of length ``win_length`` (on the same device as
            ``wav`` — ``torch.stft`` requires matching devices).
        win_length: FFT size (default 4096).
        T: chunk length in STFT frames (default 512).
        F: number of frequency bins retained (default 1024).
    """

    def __init__(self, wav, win, win_length=4096, T=512, F=1024):
        self.win_length = win_length
        self.win = win
        self.T = T
        self.F = F

        wav = wav.view(wav.size(0), -1)

        stft = torch.stft(
            wav,
            n_fft=win_length,
            window=win,
            return_complex=True,
            pad_mode="constant",
        )[:, :F, :]

        # Number of STFT time frames before padding; used to trim the
        # reconstruction back to the true length.
        self.L = stft.size(-1)
        self.stft_complex = torch.view_as_real(stft)

        mag = stft.abs().unsqueeze(1)  # (1, 1, F, L)

        # Pad the time axis up to a multiple of T so it splits evenly.
        pad_T = (T - self.L % T) % T
        # BUG FIX: the parameter ``F`` shadows ``torch.nn.functional as F``
        # inside this method, so the original ``F.pad(...)`` raised
        # AttributeError on the int 1024. Reference the module explicitly.
        mag = torch.nn.functional.pad(mag, (0, pad_T))

        # (1, 1, F, num_chunks, T) -> (num_chunks, 1, F, T)
        self.chunks = mag.view(1, 1, F, -1, T) \
            .permute(3, 0, 1, 2, 4) \
            .squeeze(1)

    def __len__(self):
        """Number of (1, F, T) chunks."""
        return self.chunks.size(0)

    def __getitem__(self, idx):
        """Return chunk ``idx`` with shape (1, F, T)."""
        return self.chunks[idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
class Splitter(nn.Module):
    """Spleeter-style source separator: one UNet mask estimator per stem.

    Pretrained weights are fetched from the ``shethjenil/spleeter`` hub repo
    at construction time; the module is put in eval mode and inference runs
    under ``torch.inference_mode``.
    """

    # Stem count -> checkpoint names on the hub.
    CONFIG = {
        2: ['2_other', '2_vocals'],
        4: ['4_bass', '4_drums', '4_other', '4_vocals'],
        5: ['5_piano', '5_bass', '5_drums', '5_other', '5_vocals']
    }

    def __init__(self, stem=2):
        super().__init__()
        self.win_length = 4096
        self.T = 512
        self.F = 1024
        # Full one-sided spectrum size expected by istft.
        self.target_F = self.win_length // 2 + 1

        # Non-trainable Parameter so the window follows the module's device.
        self.win = nn.Parameter(
            torch.hann_window(self.win_length),
            requires_grad=False
        )

        self.stems = nn.ModuleDict({
            name: UNet() for name in self.CONFIG[stem]
        })

        for name in self.stems:
            self.stems[name].load_state_dict(
                load_file(
                    hf_hub_download("shethjenil/spleeter", f"{name}.safetensors")
                )
            )

        self.eval()

    @torch.inference_mode()
    def forward(self, wav, sr, batch_size):
        """Separate ``wav`` (sampled at ``sr``) into per-stem waveforms.

        Returns:
            dict mapping stem name -> reconstructed waveform tensor.
        """
        device = next(self.parameters()).device

        # The UNets were trained at 44.1 kHz.
        if sr != 44100:
            wav = torchaudio.functional.resample(wav, sr, 44100)

        wav = wav.to(device)

        ds = STFTChunkDataset(wav, self.win)
        # BUG FIX: the dataset tensors already live on the model's device,
        # so pin_memory must stay off — pinning CUDA tensors raises a
        # RuntimeError, and batches are moved with .to(device) below anyway.
        # shuffle must stay False: decode() reassembles chunks in order.
        loader = DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=False
        )

        masks = {k: [] for k in self.stems}

        for batch in loader:
            batch = batch.to(device)

            for name, net in self.stems.items():
                masks[name].append(net(batch))

        masks = {k: torch.cat(v, dim=0) for k, v in masks.items()}

        return self.decode(masks, ds)

    def decode(self, masks, ds):
        """Wiener-style soft masking + inverse STFT for every stem.

        Args:
            masks: stem name -> stacked chunk masks, shape (chunks, 1, F, T).
            ds: the STFTChunkDataset holding the mixture's complex STFT.
        """
        # Power-normalise so the stems' squared masks sum to ~1 per bin.
        mask_sum = sum(m ** 2 for m in masks.values()) + 1e-10
        outputs = {}

        for name, m in masks.items():
            mask = (m ** 2 / mask_sum)

            # (chunks, 1, F, T) -> (1, F, chunks*T), then drop the pad frames.
            mask = mask.permute(1, 2, 0, 3).reshape(1, self.F, -1)
            mask = mask[:, :, :ds.L]

            stft = ds.stft_complex * mask.unsqueeze(-1)

            # Restore the frequency bins the dataset dropped so istft sees
            # the full one-sided spectrum (win_length // 2 + 1 bins).
            if stft.size(1) < self.target_F:
                pad = self.target_F - stft.size(1)
                stft = F.pad(stft, (0, 0, 0, 0, 0, pad))

            outputs[name] = torch.istft(
                torch.view_as_complex(stft),
                n_fft=self.win_length,
                window=self.win
            )

        return outputs
205
 
206
  def separate_audio_spleeter(audio_path:str,batch_size:int,stem:int,progress=gr.Progress(True)):
207
  wav, sr = torchaudio.load(audio_path)
 
344
  gr.Interface(separate_audio_spleeter, [gr.Audio(type="filepath"),gr.Number(16),gr.Radio([2,4,5],label="STEM")],gr.Files()),
345
  gr.Interface(separate_audio_demucs, [gr.Audio(type="filepath"),gr.Number(16),gr.Radio([4],label="STEM")],gr.Files())
346
  ],['spleeter','demucs']).launch()