vectominist committed on
Commit
737a477
·
verified ·
1 Parent(s): e5ae5a0

Update usad_model.py

Browse files
Files changed (1) hide show
  1. usad_model.py +211 -54
usad_model.py CHANGED
@@ -1,10 +1,14 @@
 
1
  from dataclasses import make_dataclass
 
2
 
3
  import torch
4
  import torchaudio
5
  from torch import nn
 
 
6
 
7
- from .usad_modules import ConformerEncoder
8
 
9
  MAX_MEL_LENGTH = 3000 # 30 seconds
10
 
@@ -15,41 +19,77 @@ def wav_to_fbank(
15
  mel_dim: int = 128,
16
  norm_mean: float = -4.268,
17
  norm_std: float = 4.569,
18
- ) -> torch.Tensor:
 
 
 
19
  """Convert waveform to fbank features.
20
 
21
  Args:
22
  wavs (torch.Tensor): (B, T_wav) waveform tensor.
23
  mel_dim (int, optional): mel dimension. Defaults to 128.
24
- norm_mean (float, optional):
25
- mean for normalization. Defaults to -4.268.
26
- norm_std (float, optional):
27
- std for normalization. Defaults to 4.569.
 
28
 
29
  Returns:
30
- torch.Tensor: (B, T_mel, mel_dim) fbank features.
 
31
  """
32
  # ref: https://github.com/cwx-worst-one/EAT/tree/main/feature_extract
33
- dtype = wavs.dtype
34
- wavs = wavs.to(torch.float32)
35
- wavs = wavs - wavs.mean(dim=-1, keepdim=True)
36
- feats = [
37
- torchaudio.compliance.kaldi.fbank(
38
- wavs[i : i + 1],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  htk_compat=True,
40
- sample_frequency=16000,
41
  use_energy=False,
42
  window_type="hanning",
43
  num_mel_bins=mel_dim,
44
  dither=0.0,
45
  frame_shift=10,
46
- ).to(dtype=dtype)
47
- for i in range(wavs.shape[0])
48
- ]
49
-
50
- mels = torch.stack(feats, dim=0)
51
- mels = (mels - norm_mean) / (norm_std * 2)
52
-
 
 
 
 
 
 
53
  return mels
54
 
55
 
@@ -64,8 +104,6 @@ class UsadModel(nn.Module):
64
  self.cfg = cfg
65
  self.encoder = ConformerEncoder(cfg)
66
  self.max_mel_length = MAX_MEL_LENGTH
67
- # NOTE: The max_mel_length is set to 3000,
68
- # which corresponds to 30 seconds of audio at 100 Hz frame rate.
69
 
70
  @property
71
  def sample_rate(self) -> int:
@@ -73,7 +111,7 @@ class UsadModel(nn.Module):
73
 
74
  @property
75
  def encoder_frame_rate(self) -> int:
76
- return 50 # Hz
77
 
78
  @property
79
  def mel_dim(self) -> int:
@@ -100,9 +138,12 @@ class UsadModel(nn.Module):
100
  """Get the device on which the model is located."""
101
  return next(self.parameters()).device
102
 
 
 
 
 
103
  def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
104
  """Set the maximum chunk size for feature extraction.
105
-
106
  Args:
107
  seconds (float, optional): Chunk size in seconds. Defaults to 30.0.
108
  """
@@ -111,86 +152,202 @@ class UsadModel(nn.Module):
111
  ), f"Chunk size must be greater than 0.1s, got {seconds} seconds."
112
  self.max_mel_length = int(seconds * 100) # 100 Hz frame rate
113
 
114
- def load_audio(self, audio_path: str) -> torch.Tensor:
 
 
115
  """Load audio file and return waveform tensor.
116
  Args:
117
  audio_path (str): Path to the audio file.
118
-
119
  Returns:
120
  torch.Tensor: Waveform tensor of shape (wav_len,).
121
  """
122
 
123
  waveform, sr = torchaudio.load(audio_path)
124
  if sr != self.sample_rate:
125
- waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
 
 
126
  if waveform.shape[0] > 1:
127
  # If stereo, convert to mono by averaging channels
128
  waveform = waveform.mean(dim=0, keepdim=True)
129
 
130
  waveform = waveform.squeeze(0) # Remove channel dimension if mono
131
- return waveform.to(self.device) # Ensure tensor is on the same device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  def forward(
134
  self,
135
  wavs: torch.Tensor,
 
 
 
136
  norm_mean: float = -4.268,
137
  norm_std: float = 4.569,
138
  ) -> dict:
139
- """Forward pass for the model.
140
-
141
  Args:
142
- wavs (torch.Tensor):
143
- Input waveform tensor of shape (batch_size, wav_len).
144
- norm_mean (float, optional):
145
- Mean for normalization. Defaults to -4.268.
146
- norm_std (float, optional):
147
- Standard deviation for normalization. Defaults to 4.569.
148
-
149
  Returns:
150
- dict: A dictionary containing the model's outputs.
 
 
 
 
 
 
 
151
  """
152
- # wavs: (batch_size, wav_len)
153
 
154
- mel = wav_to_fbank(wavs, norm_mean=norm_mean, norm_std=norm_std)
155
- mel = mel[:, : mel.shape[1] - mel.shape[1] % 2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  if mel.shape[1] <= self.max_mel_length:
157
- x, x_len, layer_results = self.encoder(mel, return_hidden=True)
 
 
 
 
 
 
158
 
159
  result = {
160
  "x": x,
 
 
 
 
161
  "mel": mel,
 
162
  "hidden_states": layer_results["hidden_states"],
163
  "ffn": layer_results["ffn_1"],
164
  }
165
  return result
166
 
 
167
  result = {
168
  "x": [],
 
169
  "mel": mel,
170
- "hidden_states": [[] for _ in range(self.cfg.num_layers)],
171
- "ffn": [[] for _ in range(self.cfg.num_layers)],
 
172
  }
173
  for i in range(0, mel.shape[1], self.max_mel_length):
174
  if mel.shape[1] - i < 10:
175
  break
176
 
 
 
 
 
 
 
 
177
  x, x_len, layer_results = self.encoder(
178
- mel[:, i : i + self.max_mel_length], return_hidden=True
 
 
 
179
  )
 
180
  result["x"].append(x)
181
- for j in range(self.cfg.num_layers):
182
- result["hidden_states"][j].append(layer_results["hidden_states"][j])
 
 
 
183
  result["ffn"][j].append(layer_results["ffn_1"][j])
184
 
185
  result["x"] = torch.cat(result["x"], dim=1)
186
- for j in range(self.cfg.num_layers):
187
- result["hidden_states"][j] = torch.cat(result["hidden_states"][j], dim=1)
 
 
 
 
 
 
 
 
188
  result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)
189
 
190
- # result["x"]: model final output (batch_size, seq_len)
191
- # result["mel"]: mel fbank (batch_size, seq_len * 2, mel_dim)
192
- # result["hidden_states"]: List of (batch_size, seq_len, encoder_dim)
193
- # result["ffn"]: List of (batch_size, seq_len, encoder_dim)
194
  return result
195
 
196
  @classmethod
 
1
+ import os
2
  from dataclasses import make_dataclass
3
+ from typing import List, Optional, Tuple, Union
4
 
5
  import torch
6
  import torchaudio
7
  from torch import nn
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ from torchaudio.compliance.kaldi import fbank
10
 
11
+ from .usad_modules import ConformerEncoder, lengths_to_padding_mask
12
 
13
  MAX_MEL_LENGTH = 3000 # 30 seconds
14
 
 
19
  mel_dim: int = 128,
20
  norm_mean: float = -4.268,
21
  norm_std: float = 4.569,
22
+ wav_lengths: Optional[torch.Tensor] = None,
23
+ sample_rate: int = 16000,
24
+ return_lengths: bool = False,
25
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
26
  """Convert waveform to fbank features.
27
 
28
  Args:
29
  wavs (torch.Tensor): (B, T_wav) waveform tensor.
30
  mel_dim (int, optional): mel dimension. Defaults to 128.
31
+ norm_mean (float, optional): mean for normalization. Defaults to -4.268.
32
+ norm_std (float, optional): std for normalization. Defaults to 4.569.
33
+ wav_lengths (torch.Tensor, optional): (B,) valid waveform lengths before padding.
34
+ sample_rate (int, optional): waveform sample rate. Defaults to 16000.
35
+ return_lengths (bool, optional): return exact fbank lengths. Defaults to False.
36
 
37
  Returns:
38
+ torch.Tensor: (B, T_mel, mel_dim) fbank features. If return_lengths is True,
39
+ also returns a (B,) tensor with exact feature lengths before padding.
40
  """
41
  # ref: https://github.com/cwx-worst-one/EAT/tree/main/feature_extract
42
+ feature_dtype = wavs.dtype if wavs.is_floating_point() else torch.float32
43
+ wavs_float = wavs.to(torch.float32)
44
+
45
+ if wav_lengths is None:
46
+ wav_lengths = torch.full(
47
+ (wavs.shape[0],),
48
+ wavs.shape[1],
49
+ dtype=torch.long,
50
+ device=wavs.device,
51
+ )
52
+ else:
53
+ wav_lengths = wav_lengths.to(device=wavs.device, dtype=torch.long)
54
+ if wav_lengths.dim() != 1 or wav_lengths.shape[0] != wavs.shape[0]:
55
+ raise ValueError(
56
+ "wav_lengths must be a 1-D tensor with batch size elements."
57
+ )
58
+ if torch.any(wav_lengths <= 0).item():
59
+ raise ValueError("All wav_lengths values must be positive.")
60
+ if torch.any(wav_lengths > wavs.shape[1]).item():
61
+ raise ValueError(
62
+ "wav_lengths cannot exceed the padded waveform length."
63
+ )
64
+
65
+ feats = []
66
+ feat_lengths = []
67
+ for i, wav_length in enumerate(wav_lengths.detach().cpu().tolist()):
68
+ # Trim padding before centering so batched padding cannot affect valid audio.
69
+ wav = wavs_float[i, :wav_length]
70
+ wav = wav - wav.mean(dim=-1, keepdim=True)
71
+ feat = fbank(
72
+ wav.unsqueeze(0),
73
  htk_compat=True,
74
+ sample_frequency=sample_rate,
75
  use_energy=False,
76
  window_type="hanning",
77
  num_mel_bins=mel_dim,
78
  dither=0.0,
79
  frame_shift=10,
80
+ )
81
+ feat = feat[: feat.shape[0] - feat.shape[0] % 2, :] # For compatibility
82
+ feat = (feat - norm_mean) / (norm_std * 2)
83
+ feats.append(feat.to(dtype=feature_dtype))
84
+ feat_lengths.append(feat.shape[0])
85
+
86
+ mels = pad_sequence(feats, batch_first=True, padding_value=0.0)
87
+ mel_lengths = torch.tensor(
88
+ feat_lengths, dtype=torch.long, device=wavs.device
89
+ )
90
+
91
+ if return_lengths:
92
+ return mels, mel_lengths
93
  return mels
94
 
95
 
 
104
  self.cfg = cfg
105
  self.encoder = ConformerEncoder(cfg)
106
  self.max_mel_length = MAX_MEL_LENGTH
 
 
107
 
108
  @property
109
  def sample_rate(self) -> int:
 
111
 
112
  @property
113
  def encoder_frame_rate(self) -> int:
114
+ return round(100 / self.cfg.conv_subsample_rate) # Hz
115
 
116
  @property
117
  def mel_dim(self) -> int:
 
138
  """Get the device on which the model is located."""
139
  return next(self.parameters()).device
140
 
141
@property
def dtype(self) -> torch.dtype:
    """Get the dtype of the model's first parameter."""
    first_param = next(self.parameters())
    return first_param.dtype
144
+
145
  def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
146
  """Set the maximum chunk size for feature extraction.
 
147
  Args:
148
  seconds (float, optional): Chunk size in seconds. Defaults to 30.0.
149
  """
 
152
  ), f"Chunk size must be greater than 0.1s, got {seconds} seconds."
153
  self.max_mel_length = int(seconds * 100) # 100 Hz frame rate
154
 
155
def load_audio(
    self, audio_path: str, move_to_device: bool = True
) -> torch.Tensor:
    """Read an audio file into a single-channel waveform.

    The audio is resampled to ``self.sample_rate`` when the file's sample
    rate differs, and multi-channel audio is averaged across channels.

    Args:
        audio_path (str): Path to the audio file.
        move_to_device (bool, optional): If True, move the waveform to
            ``self.device`` before returning. Defaults to True.

    Returns:
        torch.Tensor: Waveform tensor of shape (wav_len,).
    """
    signal, source_rate = torchaudio.load(audio_path)

    target_rate = self.sample_rate
    if source_rate != target_rate:
        signal = torchaudio.functional.resample(
            signal, source_rate, target_rate
        )

    num_channels = signal.shape[0]
    if num_channels > 1:
        # More than one channel: average them down to a single channel.
        signal = signal.mean(dim=0, keepdim=True)

    # Drop the leading channel dimension.
    signal = signal.squeeze(0)

    # Optionally place the tensor on the same device as the model.
    return signal.to(self.device) if move_to_device else signal
180
+
181
def load_audio_batch(
    self, audio_paths: List[str]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load several audio files and zero-pad them into a single batch.

    Each file is read with ``self.load_audio``. The waveforms are padded to
    the longest one and the batch is moved to ``self.device`` in one step.

    Args:
        audio_paths (List[str]): Paths to the audio files.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The padded waveforms of shape
        (batch_size, max_wav_len) and a (batch_size,) long tensor holding
        each waveform's length before padding. Both are on ``self.device``.

    Raises:
        ValueError: If ``audio_paths`` contains no paths.
    """
    # Keep individual waveforms off-device; only the padded batch is moved.
    waveforms = [
        self.load_audio(path, move_to_device=False) for path in audio_paths
    ]
    if not waveforms:
        # pad_sequence cannot handle an empty list; fail with a clear message.
        raise ValueError("audio_paths must contain at least one path.")

    device = self.device
    wavs = pad_sequence(waveforms, batch_first=True).to(device)
    wav_lengths = torch.tensor(
        [wav.shape[0] for wav in waveforms],
        dtype=torch.long,
        device=device,
    )
    return wavs, wav_lengths
195
 
196
  def forward(
197
  self,
198
  wavs: torch.Tensor,
199
+ wav_lengths: Optional[torch.Tensor] = None,
200
+ padding_mask: Optional[torch.Tensor] = None,
201
+ target_layer: Optional[int] = None,
202
  norm_mean: float = -4.268,
203
  norm_std: float = 4.569,
204
  ) -> dict:
205
+ """
 
206
  Args:
207
+ wavs (torch.Tensor): (B, T_wav) waveform tensor.
208
+ wav_lengths (torch.Tensor, optional): (B,) lengths of each waveform. Defaults to None.
209
+ padding_mask (torch.Tensor, optional): (B, T_wav) padding mask for the waveforms.
210
+ If wav_lengths is not provided, this is used to infer valid lengths.
211
+ target_layer (int, optional): If specified, only return the output of the target layer. Defaults to None (return all layers).
212
+ norm_mean (float, optional): Mean for normalization. Defaults to -4.268.
213
+ norm_std (float, optional): Std for normalization. Defaults to 4.569.
214
  Returns:
215
+ dict: A dictionary containing the following keys:
216
+ - "x": (B, T_out, encoder_dim) output of the encoder
217
+ - "x_lengths": (B,) valid output lengths after encoder subsampling
218
+ - "x_padding_mask": (B, T_out) output padding mask, where padding is True
219
+ - "mel": (B, T_mel, mel_dim) input mel features
220
+ - "mel_lengths": (B,) valid mel lengths before encoder subsampling
221
+ - "hidden_states": list of (B, T_out, encoder_dim) hidden states of each layer
222
+ - "ffn": list of (B, T_out, encoder_dim) output of the feed-forward network of each layer
223
  """
 
224
 
225
+ # Check types
226
+ assert isinstance(wavs, torch.Tensor), "wavs must be a torch.Tensor"
227
+ assert wavs.dim() == 2, "wavs must be of shape (batch_size, seq_len)"
228
+ if wav_lengths is not None:
229
+ assert isinstance(
230
+ wav_lengths, torch.Tensor
231
+ ), "wav_lengths must be a torch.Tensor"
232
+ assert (
233
+ wav_lengths.dim() == 1
234
+ ), "wav_lengths must be of shape (batch_size,)"
235
+ assert (
236
+ wav_lengths.shape[0] == wavs.shape[0]
237
+ ), "wav_lengths must have the same batch size as wavs"
238
+ if padding_mask is not None:
239
+ assert isinstance(
240
+ padding_mask, torch.Tensor
241
+ ), "padding_mask must be a torch.Tensor"
242
+ assert (
243
+ padding_mask.dim() == 2
244
+ ), "padding_mask must be of shape (batch_size, seq_len)"
245
+ assert (
246
+ padding_mask.shape[0] == wavs.shape[0]
247
+ ), "padding_mask must have the same batch size as wavs"
248
+ assert (
249
+ padding_mask.shape[1] == wavs.shape[1]
250
+ ), "padding_mask must have the same seq_len as wavs"
251
+ if wav_lengths is None:
252
+ wav_lengths = (~padding_mask.to(torch.bool)).sum(dim=1)
253
+ if target_layer is not None:
254
+ assert isinstance(
255
+ target_layer, int
256
+ ), "target_layer must be an int or None"
257
+ assert (
258
+ 1 <= target_layer <= self.cfg.num_layers
259
+ ), f"target_layer must be between 1 and {self.cfg.num_layers}"
260
+
261
+ mel, mel_lengths = wav_to_fbank(
262
+ wavs,
263
+ wav_lengths=wav_lengths,
264
+ mel_dim=self.mel_dim,
265
+ norm_mean=norm_mean,
266
+ norm_std=norm_std,
267
+ sample_rate=self.sample_rate,
268
+ return_lengths=True,
269
+ )
270
+
271
+ dtype = self.dtype
272
+
273
+ if mel.dtype != dtype:
274
+ mel = mel.to(dtype)
275
+
276
+ num_layers = min(
277
+ self.cfg.num_layers,
278
+ target_layer if target_layer is not None else self.cfg.num_layers,
279
+ )
280
+
281
  if mel.shape[1] <= self.max_mel_length:
282
+ # If the mel length is less than or equal to max_mel_length, we can process it in one go
283
+ x, x_len, layer_results = self.encoder(
284
+ inputs=mel,
285
+ input_lengths=mel_lengths,
286
+ return_hidden=True,
287
+ target_layer=target_layer,
288
+ )
289
 
290
  result = {
291
  "x": x,
292
+ "x_lengths": x_len,
293
+ "x_padding_mask": lengths_to_padding_mask(
294
+ x_len, max_len=x.size(1)
295
+ ),
296
  "mel": mel,
297
+ "mel_lengths": mel_lengths,
298
  "hidden_states": layer_results["hidden_states"],
299
  "ffn": layer_results["ffn_1"],
300
  }
301
  return result
302
 
303
+ # If the mel length is greater than max_mel_length, we need to process it in chunks
304
  result = {
305
  "x": [],
306
+ "x_lengths": [],
307
  "mel": mel,
308
+ "mel_lengths": mel_lengths,
309
+ "hidden_states": [[] for _ in range(num_layers)],
310
+ "ffn": [[] for _ in range(num_layers)],
311
  }
312
  for i in range(0, mel.shape[1], self.max_mel_length):
313
  if mel.shape[1] - i < 10:
314
  break
315
 
316
+ _mel = mel[:, i : i + self.max_mel_length]
317
+ _mel_lengths = None
318
+ if mel_lengths is not None:
319
+ _mel_lengths = torch.clamp(
320
+ mel_lengths - i, min=0, max=self.max_mel_length
321
+ )
322
+
323
  x, x_len, layer_results = self.encoder(
324
+ inputs=_mel,
325
+ input_lengths=_mel_lengths,
326
+ return_hidden=True,
327
+ target_layer=target_layer,
328
  )
329
+
330
  result["x"].append(x)
331
+ result["x_lengths"].append(x_len)
332
+ for j in range(num_layers):
333
+ result["hidden_states"][j].append(
334
+ layer_results["hidden_states"][j]
335
+ )
336
  result["ffn"][j].append(layer_results["ffn_1"][j])
337
 
338
  result["x"] = torch.cat(result["x"], dim=1)
339
+ result["x_lengths"] = torch.stack(result["x_lengths"], dim=0).sum(
340
+ dim=0
341
+ )
342
+ result["x_padding_mask"] = lengths_to_padding_mask(
343
+ result["x_lengths"], max_len=result["x"].size(1)
344
+ )
345
+ for j in range(num_layers):
346
+ result["hidden_states"][j] = torch.cat(
347
+ result["hidden_states"][j], dim=1
348
+ )
349
  result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)
350
 
 
 
 
 
351
  return result
352
 
353
  @classmethod