PyTorch
English
delulu
custom_code
massabaali committed on
Commit
3b984eb
·
verified ·
1 Parent(s): 8ef80e7

Upload modeling_delulu.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_delulu.py +18 -264
modeling_delulu.py CHANGED
@@ -1,27 +1,11 @@
1
- """
2
- DELULU Model
3
-
4
- DELULU (Discriminative Embedding Learning Using Latent Units) is a speaker-aware
5
- self-supervised speech foundational model based on HuBERT architecture.
6
-
7
- Paper: https://arxiv.org/abs/2510.17662
8
- Authors: Massa Baali, Rita Singh, Bhiksha Raj
9
-
10
- This implementation wraps torchaudio's wav2vec2_model for compatibility with
11
- Hugging Face's AutoModel interface.
12
- """
13
-
14
  import torch
15
  import torch.nn as nn
16
  from typing import Optional, Tuple, Union
17
- from dataclasses import dataclass
18
-
19
  from transformers import PreTrainedModel
20
  from transformers.modeling_outputs import BaseModelOutput
21
-
22
  from .configuration_delulu import DELULUConfig
23
 
24
- # Try to import torchaudio
25
  try:
26
  from torchaudio.models.wav2vec2 import wav2vec2_model
27
  TORCHAUDIO_AVAILABLE = True
@@ -29,79 +13,25 @@ except ImportError:
29
  TORCHAUDIO_AVAILABLE = False
30
 
31
 
32
- @dataclass
33
- class DELULUOutput(BaseModelOutput):
34
- """
35
- Output class for DELULU model.
36
-
37
- Args:
38
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
39
- Sequence of hidden-states at the output of the last layer of the model.
40
- hidden_states (`tuple(torch.FloatTensor)`, *optional*):
41
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for each layer)
42
- of shape `(batch_size, sequence_length, hidden_size)`.
43
- attentions (`tuple(torch.FloatTensor)`, *optional*):
44
- Attention weights (not available for torchaudio backend).
45
- extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
46
- Features from the convolutional feature extractor.
47
- """
48
- last_hidden_state: torch.FloatTensor = None
49
- hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
50
- attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
51
- extract_features: Optional[torch.FloatTensor] = None
52
-
53
-
54
  class DELULUModel(PreTrainedModel):
55
- """
56
- DELULU Model for speaker-aware speech representation learning.
57
-
58
- This model wraps torchaudio's wav2vec2_model with DELULU's custom configuration
59
- (modified convolutional strides for 16ms frame shift).
60
-
61
- Example:
62
- ```python
63
- from transformers import AutoModel
64
- import torch
65
-
66
- # Load model
67
- model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
68
- model.eval()
69
-
70
- # Process audio (16kHz, mono)
71
- waveform = torch.randn(1, 16000) # 1 second of audio
72
-
73
- with torch.no_grad():
74
- outputs = model(waveform)
75
- features = outputs.last_hidden_state # [1, T, 768]
76
-
77
- # For speaker verification, use mean pooling
78
- speaker_embedding = features.mean(dim=1) # [1, 768]
79
- ```
80
- """
81
-
82
  config_class = DELULUConfig
83
- base_model_prefix = "delulu"
84
  main_input_name = "input_values"
85
  supports_gradient_checkpointing = False
 
86
 
87
  def __init__(self, config: DELULUConfig):
88
  super().__init__(config)
89
- self.config = config
90
 
91
  if not TORCHAUDIO_AVAILABLE:
92
- raise ImportError(
93
- "torchaudio is required for DELULU model. "
94
- "Install with: pip install torchaudio"
95
- )
96
 
97
- # Build convolutional layer config from DELULU config
98
  conv_layer_config = list(zip(
99
  config.conv_dim,
100
  config.conv_kernel,
101
  config.conv_stride
102
  ))
103
 
104
- # Create the underlying torchaudio model
105
  self.wav2vec2 = wav2vec2_model(
106
  extractor_mode=config.extractor_mode,
107
  extractor_conv_layer_config=conv_layer_config,
@@ -120,214 +50,38 @@ class DELULUModel(PreTrainedModel):
120
  encoder_layer_drop=config.layer_drop,
121
  aux_num_out=None,
122
  )
123
-
124
- # Initialize weights
125
  self.post_init()
126
 
 
 
 
127
  def forward(
128
  self,
129
  input_values: torch.Tensor,
130
  attention_mask: Optional[torch.Tensor] = None,
131
  output_hidden_states: Optional[bool] = None,
132
- output_attentions: Optional[bool] = None,
133
  return_dict: Optional[bool] = None,
134
- ) -> Union[Tuple, DELULUOutput]:
135
- """
136
- Forward pass of DELULU model.
137
-
138
- Args:
139
- input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
140
- Raw audio waveform at 16kHz sampling rate.
141
- attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
142
- Mask to avoid performing attention on padding. Not used in current implementation.
143
- output_hidden_states (`bool`, *optional*):
144
- Whether to return all hidden states.
145
- output_attentions (`bool`, *optional*):
146
- Whether to return attention weights. Not supported with torchaudio backend.
147
- return_dict (`bool`, *optional*):
148
- Whether to return a `DELULUOutput` instead of a tuple.
149
-
150
- Returns:
151
- `DELULUOutput` or `tuple`: Model outputs.
152
- """
153
- output_hidden_states = (
154
- output_hidden_states if output_hidden_states is not None
155
- else self.config.output_hidden_states if hasattr(self.config, 'output_hidden_states')
156
- else False
157
- )
158
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict if hasattr(self.config, 'use_return_dict') else True
159
 
160
- # Ensure input is 2D: (batch, samples)
161
  if input_values.dim() == 1:
162
  input_values = input_values.unsqueeze(0)
163
 
164
- # Handle lengths for torchaudio model
165
- lengths = None
166
- if attention_mask is not None:
167
- lengths = attention_mask.sum(dim=-1)
168
 
169
- # Extract features using torchaudio model
170
  if output_hidden_states:
171
- # Get all layer outputs
172
- features, lengths_out = self.wav2vec2.extract_features(
173
- input_values,
174
- lengths=lengths
175
- )
176
- # features is a list of tensors, one per layer
177
- hidden_states = tuple(features)
178
- last_hidden_state = features[-1]
179
- else:
180
- # Just get final output
181
- outputs, lengths_out = self.wav2vec2(input_values, lengths=lengths)
182
- last_hidden_state = outputs
183
- hidden_states = None
184
 
185
- # Get convolutional features (before transformer)
186
- extract_features = self.wav2vec2.feature_extractor(input_values, lengths)[0]
187
 
188
  if not return_dict:
189
- outputs = (last_hidden_state,)
190
- if output_hidden_states:
191
- outputs = outputs + (hidden_states,)
192
- return outputs
193
-
194
- return DELULUOutput(
195
- last_hidden_state=last_hidden_state,
196
- hidden_states=hidden_states,
197
- attentions=None, # torchaudio doesn't expose attention weights
198
- extract_features=extract_features,
199
- )
200
 
201
- def extract_features(
202
- self,
203
- input_values: torch.Tensor,
204
- lengths: Optional[torch.Tensor] = None
205
- ) -> Tuple[torch.Tensor, ...]:
206
- """
207
- Extract features from all layers.
208
-
209
- Args:
210
- input_values: Audio waveform of shape (batch, samples)
211
- lengths: Optional lengths for each sample in batch
212
-
213
- Returns:
214
- Tuple of tensors, one per layer (including CNN output)
215
- """
216
  if input_values.dim() == 1:
217
  input_values = input_values.unsqueeze(0)
218
-
219
- features, _ = self.wav2vec2.extract_features(input_values, lengths=lengths)
220
  return tuple(features)
221
-
222
- def get_speaker_embedding(
223
- self,
224
- input_values: torch.Tensor,
225
- pooling: str = "mean"
226
- ) -> torch.Tensor:
227
- """
228
- Extract speaker embedding from audio.
229
-
230
- Args:
231
- input_values: Audio waveform of shape (batch, samples)
232
- pooling: Pooling method - "mean", "max", or "first"
233
-
234
- Returns:
235
- Speaker embedding of shape (batch, hidden_size)
236
- """
237
- outputs = self.forward(input_values, return_dict=True)
238
- features = outputs.last_hidden_state
239
-
240
- if pooling == "mean":
241
- return features.mean(dim=1)
242
- elif pooling == "max":
243
- return features.max(dim=1).values
244
- elif pooling == "first":
245
- return features[:, 0, :]
246
- else:
247
- raise ValueError(f"Unknown pooling method: {pooling}")
248
-
249
- def _init_weights(self, module):
250
- """Initialize weights - mostly handled by torchaudio."""
251
- pass
252
-
253
-
254
- class DELULUForSequenceClassification(PreTrainedModel):
255
- """
256
- DELULU with a classification head for speaker verification and other tasks.
257
-
258
- Example:
259
- ```python
260
- from transformers import AutoModel
261
-
262
- model = AutoModel.from_pretrained(
263
- "cmu-mlsp/DELULU",
264
- trust_remote_code=True,
265
- num_labels=1251 # Number of speakers in VoxCeleb2
266
- )
267
- ```
268
- """
269
-
270
- config_class = DELULUConfig
271
- base_model_prefix = "delulu"
272
-
273
- def __init__(self, config: DELULUConfig):
274
- super().__init__(config)
275
-
276
- self.delulu = DELULUModel(config)
277
- self.projector = nn.Linear(config.hidden_size, config.hidden_size)
278
-
279
- num_labels = getattr(config, 'num_labels', None)
280
- if num_labels:
281
- self.classifier = nn.Linear(config.hidden_size, num_labels)
282
- else:
283
- self.classifier = None
284
-
285
- self.post_init()
286
-
287
- def forward(
288
- self,
289
- input_values: torch.Tensor,
290
- attention_mask: Optional[torch.Tensor] = None,
291
- labels: Optional[torch.Tensor] = None,
292
- return_dict: Optional[bool] = None,
293
- ):
294
- return_dict = return_dict if return_dict is not None else True
295
-
296
- outputs = self.delulu(
297
- input_values,
298
- attention_mask=attention_mask,
299
- return_dict=True
300
- )
301
-
302
- # Pool features
303
- hidden_states = outputs.last_hidden_state
304
- pooled = hidden_states.mean(dim=1)
305
-
306
- # Project
307
- embeddings = self.projector(pooled)
308
-
309
- # Classify if head exists
310
- logits = None
311
- if self.classifier is not None:
312
- logits = self.classifier(embeddings)
313
-
314
- loss = None
315
- if labels is not None and logits is not None:
316
- loss_fct = nn.CrossEntropyLoss()
317
- loss = loss_fct(logits, labels)
318
-
319
- if not return_dict:
320
- output = (logits, embeddings) + (outputs.last_hidden_state,)
321
- return ((loss,) + output) if loss is not None else output
322
-
323
- return {
324
- "loss": loss,
325
- "logits": logits,
326
- "embeddings": embeddings,
327
- "last_hidden_state": outputs.last_hidden_state,
328
- }
329
-
330
-
331
- # Register for auto classes
332
- DELULUConfig.register_for_auto_class()
333
- DELULUModel.register_for_auto_class("AutoModel")
 
1
+ """DELULU Model"""
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import torch.nn as nn
4
  from typing import Optional, Tuple, Union
 
 
5
  from transformers import PreTrainedModel
6
  from transformers.modeling_outputs import BaseModelOutput
 
7
  from .configuration_delulu import DELULUConfig
8
 
 
9
  try:
10
  from torchaudio.models.wav2vec2 import wav2vec2_model
11
  TORCHAUDIO_AVAILABLE = True
 
13
  TORCHAUDIO_AVAILABLE = False
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  class DELULUModel(PreTrainedModel):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  config_class = DELULUConfig
18
+ base_model_prefix = "wav2vec2"
19
  main_input_name = "input_values"
20
  supports_gradient_checkpointing = False
21
+ _no_split_modules = []
22
 
23
  def __init__(self, config: DELULUConfig):
24
  super().__init__(config)
 
25
 
26
  if not TORCHAUDIO_AVAILABLE:
27
+ raise ImportError("torchaudio required: pip install torchaudio")
 
 
 
28
 
 
29
  conv_layer_config = list(zip(
30
  config.conv_dim,
31
  config.conv_kernel,
32
  config.conv_stride
33
  ))
34
 
 
35
  self.wav2vec2 = wav2vec2_model(
36
  extractor_mode=config.extractor_mode,
37
  extractor_conv_layer_config=conv_layer_config,
 
50
  encoder_layer_drop=config.layer_drop,
51
  aux_num_out=None,
52
  )
 
 
53
  self.post_init()
54
 
55
+ def _init_weights(self, module):
56
+ pass
57
+
58
  def forward(
59
  self,
60
  input_values: torch.Tensor,
61
  attention_mask: Optional[torch.Tensor] = None,
62
  output_hidden_states: Optional[bool] = None,
 
63
  return_dict: Optional[bool] = None,
64
+ ) -> Union[Tuple, BaseModelOutput]:
65
+ return_dict = return_dict if return_dict is not None else True
66
+ output_hidden_states = output_hidden_states or False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
68
  if input_values.dim() == 1:
69
  input_values = input_values.unsqueeze(0)
70
 
71
+ lengths = attention_mask.sum(-1) if attention_mask is not None else None
 
 
 
72
 
 
73
  if output_hidden_states:
74
+ features, _ = self.wav2vec2.extract_features(input_values, lengths=lengths)
75
+ return BaseModelOutput(last_hidden_state=features[-1], hidden_states=tuple(features))
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ output, _ = self.wav2vec2(input_values, lengths=lengths)
 
78
 
79
  if not return_dict:
80
+ return (output,)
81
+ return BaseModelOutput(last_hidden_state=output)
 
 
 
 
 
 
 
 
 
82
 
83
+ def extract_features(self, input_values: torch.Tensor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  if input_values.dim() == 1:
85
  input_values = input_values.unsqueeze(0)
86
+ features, _ = self.wav2vec2.extract_features(input_values)
 
87
  return tuple(features)