massabaali committed (verified) · Commit a4a60fd · 1 Parent(s): 46724ac

Upload folder using huggingface_hub
DELULU_MODEL_CARD.md ADDED
@@ -0,0 +1,326 @@
+ ---
+ license: cc-by-nc-nd-4.0
+ language:
+ - en
+ library_name: transformers
+ tags:
+ - speaker-verification
+ - speaker-diarization
+ - speaker-profiling
+ - speech
+ - audio
+ - self-supervised-learning
+ - ssl
+ - hubert
+ - speech-representation
+ - pytorch
+ - deep-learning
+ datasets:
+ - librispeech_asr
+ metrics:
+ - eer
+ pipeline_tag: audio-classification
+ model-index:
+ - name: DELULU
+   results:
+   - task:
+       type: speaker-verification
+       name: Speaker Verification
+     dataset:
+       type: VoxCeleb1-O
+       name: VoxCeleb1-O
+     metrics:
+     - type: eer
+       value: 13.52
+       name: Equal Error Rate (Upstream)
+ ---
+
+ # DELULU: Discriminative Embedding Learning Using Latent Units
+
+ <div align="center">
+
+ **A Speaker-Aware Self-Supervised Speech Foundational Model**
+
+ [![Paper](https://img.shields.io/badge/arXiv-2510.17662-b31b1b.svg)](https://arxiv.org/abs/2510.17662)
+ [![License](https://img.shields.io/badge/License-CC%20BY--NC--ND%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc-nd/4.0/)
+
+ </div>
+
+ ## Model Description
+
+ **DELULU** (Discriminative Embedding Learning Using Latent Units) is a speaker-aware self-supervised speech foundational model that addresses a critical limitation of existing SSL models: their inability to capture the speaker-discriminative features essential for verification, diarization, and profiling applications.
+
+ While conventional SSL models like HuBERT, wav2vec 2.0, and WavLM excel at content-driven tasks such as automatic speech recognition, they learn representations optimized for phonetic/linguistic content, inadvertently discarding speaker identity information. DELULU bridges this gap by integrating external speaker supervision into the pseudo-label generation process.
+
+ ### Key Innovation
+
+ DELULU introduces a novel approach to self-supervised speech learning by leveraging **frame-level embeddings from ReDimNet**, a state-of-the-art speaker verification model, to guide the k-means clustering step during pre-training. This introduces a strong **speaker-discriminative inductive bias** that aligns representation learning with speaker identity, a fundamental shift from content-focused SSL paradigms.
+
+ ## Architecture
+
+ DELULU is based on the HuBERT architecture with a **modified convolutional feature extractor** optimized for speaker verification:
+
+ ### Convolutional Feature Extractor
+
+ | Layer | Channels | Kernel Size | Stride |
+ |-------|----------|-------------|--------|
+ | 1 | 512 | 10 | **4** |
+ | 2 | 512 | 3 | 2 |
+ | 3 | 512 | 3 | 2 |
+ | 4 | 512 | 3 | 2 |
+ | 5 | 512 | 3 | 2 |
+ | 6 | 512 | 2 | 2 |
+ | 7 | 512 | 2 | 2 |
+
+ > **Key Difference**: The first layer uses stride **4** (vs. stride 5 in standard HuBERT), resulting in a **16ms frame shift** optimized for speaker verification tasks.
+
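The 16ms figure follows directly from the stride column above: the extractor downsamples by the product of the per-layer strides. A quick plain-Python sanity check (illustrative, not part of the repository):

```python
# Per-layer strides of the DELULU feature extractor (first stride 4 vs. 5 in HuBERT)
strides = [4, 2, 2, 2, 2, 2, 2]

hop = 1
for s in strides:
    hop *= s  # total downsampling factor in samples

frame_shift_ms = hop / 16000 * 1000  # input sampled at 16 kHz
print(hop, frame_shift_ms)  # 256 16.0
```

With HuBERT's first stride of 5, the same product is 320 samples, i.e. the familiar 20ms hop.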
+ ### Transformer Encoder
+
+ - **Hidden size**: 768
+ - **Attention heads**: 12
+ - **Layers**: 12
+ - **Intermediate size**: 3,072
+ - **Frame shift**: 16ms (vs. 20ms in HuBERT)
+
+ ### Training Configuration
+
+ - **Clustering**: ReDimNet-guided k-means with k=256 clusters
+ - **Feature dimension**: 2,304 (ReDimNet frame-level embeddings)
+ - **Training objective**: Dual objective combining masked prediction + denoising
+ - **Pre-training data**: LibriSpeech 960h
+ - **Training steps**: 400k updates
+
+ ## Performance
+
+ ### Upstream Speaker Verification (Zero-Shot)
+
+ | Model | VoxCeleb1-O EER (%) |
+ |-------|---------------------|
+ | wav2vec 2.0 | 37.21 |
+ | HuBERT | 34.05 |
+ | WavLM | 29.84 |
+ | **DELULU** | **13.52** |
+
+ > **60% relative improvement** over standard HuBERT in equal error rate (34.05% → 13.52%).
+
+ ### Ablation: Why ReDimNet-Guided Clustering?
+
+ | Clustering Features | k | EER (%) |
+ |---------------------|---|---------|
+ | MFCC | 100 | 37.73 |
+ | HuBERT (pretrained) | 500 | 34.05 |
+ | **ReDimNet** | 256 | **13.53** |
+
+ ReDimNet-guided pseudo-labels provide a **60% relative improvement** over HuBERT's acoustic-only approach.
+
+ ### Demographic Robustness
+
+ DELULU consistently outperforms baselines across all demographic groups, with particularly strong improvements for challenging subgroups:
+
+ | Demographic | HuBERT EER (%) | DELULU EER (%) | Improvement |
+ |-------------|----------------|----------------|-------------|
+ | Male 36-45 | 39.47 | 24.53 | 38% |
+
+ ### Zero-Shot Speaker Profiling (DynamicSUPERB)
+
+ DELULU excels on multiple speaker-related tasks without fine-tuning:
+ - Gender classification
+ - Age estimation
+ - Accent recognition
+ - Speaker counting
+ - Spoof detection
+
+ ## Intended Uses
+
+ ### Primary Use Cases
+
+ 1. **Speaker Verification**: Verify whether two speech samples are from the same speaker
+ 2. **Speaker Diarization**: Segment and cluster speech by speaker identity
+ 3. **Speaker Profiling**: Extract demographic attributes (age, gender, accent)
+ 4. **Forensic Audio Analysis**: Speaker identification in investigative contexts
+
+ ### Downstream Applications
+
+ - Voice biometrics and authentication systems
+ - Meeting transcription with speaker labels
+ - Call center analytics
+ - Content personalization based on speaker identity
+ - Multi-speaker dialogue systems
+
+ ## How to Use
+
+ ### Installation
+
+ ```bash
+ pip install transformers torch torchaudio
+ ```
+
+ ### Loading the Model
+
+ ```python
+ import torch
+ from transformers import AutoModel
+
+ # Load DELULU model
+ model = AutoModel.from_pretrained("username/DELULU", trust_remote_code=True)
+ model.eval()
+ ```
+
+ ### Feature Extraction
+
+ ```python
+ import torchaudio
+
+ # Load audio (16kHz sampling rate required)
+ waveform, sample_rate = torchaudio.load("audio.wav")
+ if sample_rate != 16000:
+     resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+     waveform = resampler(waveform)
+
+ # Extract features
+ with torch.no_grad():
+     outputs = model(waveform)
+     # Use last hidden state for downstream tasks
+     features = outputs.last_hidden_state  # [batch, time, 768]
+
+ # For speaker verification, typically use mean pooling
+ speaker_embedding = features.mean(dim=1)  # [batch, 768]
+ ```
+
+ ### Speaker Verification Example
+
+ ```python
+ import torch.nn.functional as F
+
+ def extract_embedding(model, waveform):
+     """Mean-pooled DELULU embedding for a mono 16kHz waveform."""
+     with torch.no_grad():
+         return model(waveform).last_hidden_state.mean(dim=1)
+
+ def compute_similarity(embedding1, embedding2):
+     """Compute cosine similarity between two speaker embeddings."""
+     return F.cosine_similarity(embedding1, embedding2, dim=-1)
+
+ # Extract embeddings for two audio samples
+ emb1 = extract_embedding(model, audio1)
+ emb2 = extract_embedding(model, audio2)
+
+ # Compute similarity score
+ similarity = compute_similarity(emb1, emb2)
+ print(f"Similarity score: {similarity.item():.4f}")
+
+ # Threshold-based decision (tune threshold on validation data)
+ threshold = 0.7
+ same_speaker = similarity > threshold
+ ```
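The results above are reported as equal error rate (EER), the operating point where false-acceptance and false-rejection rates coincide. A minimal, dependency-free EER computation over verification trials might look like this (illustrative sketch; the function name and trial format are our assumptions, not repository code):

```python
def equal_error_rate(scores, labels):
    """EER over trials: labels are 1 (same speaker) / 0 (different speaker)."""
    n_tgt = sum(labels)
    n_non = len(labels) - n_tgt
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    fr, fa = 0, n_non            # threshold below all scores: accept everything
    best = max(fr / n_tgt, fa / n_non)
    for i in order:              # raise the threshold past one score at a time
        if labels[i] == 1:
            fr += 1              # a target trial is now rejected
        else:
            fa -= 1              # a non-target trial is no longer accepted
        best = min(best, max(fr / n_tgt, fa / n_non))
    return best

print(equal_error_rate([0.9, 0.8, 0.2, 0.1], [1, 1, 0, 0]))  # 0.0 (perfect separation)
```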
+
+ ### Fine-Tuning for Downstream Tasks
+
+ ```python
+ import torch
+
+ # Add a task-specific head on top of the pre-trained encoder
+ class SpeakerVerificationModel(torch.nn.Module):
+     def __init__(self, base_model, embedding_dim=256):
+         super().__init__()
+         self.base = base_model
+         self.projector = torch.nn.Linear(768, embedding_dim)
+
+     def forward(self, x):
+         features = self.base(x).last_hidden_state
+         pooled = features.mean(dim=1)
+         return self.projector(pooled)
+
+ # Fine-tune with your speaker verification dataset
+ model = SpeakerVerificationModel(base_model)
+ ```
+
+ ## Training Details
+
+ ### Pre-training Process
+
+ 1. **Pseudo-Label Generation**:
+    - Extract frame-level embeddings using ReDimNet (dimension: 2,304)
+    - Apply k-means clustering with k=256 to create speaker-aware pseudo-labels
+    - ReDimNet stride modified to match encoder stride (16ms)
+
+ 2. **Training Objective**:
+    - **Masked Prediction**: Predict pseudo-labels for masked frames
+    - **Denoising**: Additional denoising objective for robustness
+
+ 3. **Optimization**:
+    - Training data: LibriSpeech 960 hours
+    - Training steps: 400k updates
+    - Batch size: 87.5 seconds of audio per GPU
+    - Hardware: 32 GPUs
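The pseudo-label generation step can be sketched with a toy Lloyd's k-means standing in for the real 256-cluster run over 2,304-dimensional ReDimNet frame embeddings (illustrative only; the function name and tiny sizes are our assumptions, not the training code):

```python
import numpy as np

def kmeans_pseudo_labels(frames, k, iters=20, seed=0):
    """Assign each frame embedding (row of `frames`, shape (T, D)) a cluster id in [0, k)."""
    rng = np.random.default_rng(seed)
    centers = frames[rng.choice(len(frames), size=k, replace=False)].astype(float)
    for _ in range(iters):
        # Squared Euclidean distance of every frame to every center
        d = ((frames[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = d.argmin(axis=1)
        for j in range(k):  # move each center to the mean of its members
            if (labels == j).any():
                centers[j] = frames[labels == j].mean(axis=0)
    return labels

# Two well-separated fake "speakers" worth of frames -> two distinct pseudo-labels
frames = np.vstack([np.zeros((5, 3)), np.full((5, 3), 10.0)])
labels = kmeans_pseudo_labels(frames, k=2)
```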
+
+ ### Why 16ms Frame Shift?
+
+ Ablation studies showed that the **16ms frame shift achieves the best EER (13.52%)**, while both shorter (≤15ms) and longer (≥20ms) shifts resulted in EER >14%. This temporal resolution balances:
+ - Capture of fine-grained speaker characteristics
+ - Computational efficiency
+ - Training stability
+
+ ## Limitations
+
+ 1. **Domain Shift**: Performance may degrade on audio with characteristics significantly different from LibriSpeech (e.g., noisy environments, non-English speech, telephony audio)
+
+ 2. **Computational Requirements**: As a transformer-based model, DELULU requires substantial computational resources for inference on long audio
+
+ 3. **Fine-tuning May Be Required**: While DELULU provides strong zero-shot speaker representations, task-specific fine-tuning typically improves performance
+
+ 4. **Language**: Pre-trained on English speech; cross-lingual transfer may be limited
+
+ ## Ethical Considerations
+
+ ### Potential Misuse
+
+ Speaker verification technology can be misused for:
+ - Unauthorized surveillance
+ - Privacy violations
+ - Identity fraud
+ - Discriminatory profiling
+
+ ### Recommended Safeguards
+
+ - Obtain explicit consent before processing voice data
+ - Implement robust access controls
+ - Follow data protection regulations (GDPR, CCPA)
+ - Conduct bias audits across demographic groups
+ - Maintain transparency about system capabilities and limitations
+
+ ### Bias Evaluation
+
+ DELULU was evaluated across demographic subgroups and shows consistent improvements without introducing systematic biases. However, users should validate performance on their specific populations.
+
+ ## Citation
+
+ If you use DELULU in your research, please cite:
+
+ ```bibtex
+ @article{baali2025delulu,
+   title={DELULU: Discriminative Embedding Learning Using Latent Units for Speaker-Aware Self-Supervised Speech Foundational Model},
+   author={Baali, Massa and Singh, Rita and Raj, Bhiksha},
+   journal={arXiv preprint arXiv:2510.17662},
+   year={2025}
+ }
+ ```
+
+ ## Related Work
+
+ - **HuBERT**: [Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447)
+ - **ReDimNet**: State-of-the-art speaker verification model used for pseudo-label generation
+ - **WavLM**: [Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900)
+
+ ## Acknowledgments
+
+ This work was conducted at Carnegie Mellon University's Language Technologies Institute. We thank the speech processing community for foundational work on self-supervised learning and speaker verification.
+
+ ## Contact
+
+ For questions about the model or paper:
+ - **Author**: Massa Baali
+ - **Advisors**: Prof. Rita Singh, Prof. Bhiksha Raj
+ - **Institution**: Carnegie Mellon University, Language Technologies Institute
+
+ ---
+
+ <div align="center">
+ <i>DELULU: Where Self-Supervised Learning Meets Speaker Identity</i>
+ </div>
config.json CHANGED
@@ -1,39 +1,13 @@
 {
   "model_type": "delulu",
-  "architectures": [
-    "DELULUModel"
-  ],
+  "architectures": ["DELULUModel"],
   "auto_map": {
     "AutoConfig": "configuration_delulu.DELULUConfig",
     "AutoModel": "modeling_delulu.DELULUModel"
   },
-  "conv_dim": [
-    512,
-    512,
-    512,
-    512,
-    512,
-    512,
-    512
-  ],
-  "conv_kernel": [
-    10,
-    3,
-    3,
-    3,
-    3,
-    2,
-    2
-  ],
-  "conv_stride": [
-    4,
-    2,
-    2,
-    2,
-    2,
-    2,
-    2
-  ],
+  "conv_dim": [512, 512, 512, 512, 512, 512, 512],
+  "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
+  "conv_stride": [4, 2, 2, 2, 2, 2, 2],
   "conv_bias": false,
   "extractor_mode": "group_norm",
   "hidden_size": 768,
@@ -44,7 +18,7 @@
   "attention_dropout": 0.1,
   "final_dropout": 0.1,
   "feat_proj_dropout": 0.1,
-  "layer_norm_eps": 1e-05,
+  "layer_norm_eps": 1e-5,
   "layer_drop": 0.05,
   "num_conv_pos_embeddings": 128,
   "num_conv_pos_embedding_groups": 16,
@@ -56,5 +30,6 @@
   "pad_token_id": 0,
   "bos_token_id": 1,
   "eos_token_id": 2,
+  "transformers_version": "4.36.0",
   "torch_dtype": "float32"
 }
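Given the `conv_kernel` and `conv_stride` values in this config, the number of frames the extractor emits for a waveform follows the usual no-padding convolution length recurrence (a sketch; the helper name is ours, not part of the repository):

```python
def num_output_frames(num_samples,
                      kernels=(10, 3, 3, 3, 3, 2, 2),
                      strides=(4, 2, 2, 2, 2, 2, 2)):
    """Frames produced by the DELULU conv stack for a mono input of `num_samples`."""
    n = num_samples
    for k, s in zip(kernels, strides):
        n = (n - k) // s + 1  # conv output length with no padding
    return n

print(num_output_frames(16000))  # 62 frames for 1 s of 16 kHz audio (~16 ms hop)
```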
configuration_delulu.py CHANGED
@@ -1,20 +1,80 @@
-"""DELULU Configuration"""
+"""
+DELULU Configuration
+
+Configuration class for DELULU (Discriminative Embedding Learning Using Latent Units),
+a speaker-aware self-supervised speech foundational model.
+
+Paper: https://arxiv.org/abs/2510.17662
+Authors: Massa Baali, Rita Singh, Bhiksha Raj
+"""
 
 from transformers import PretrainedConfig
 
 
 class DELULUConfig(PretrainedConfig):
-    """Configuration class for DELULU model."""
+    r"""
+    Configuration class for DELULU model.
+
+    DELULU is based on HuBERT architecture with modified convolutional strides
+    optimized for speaker verification (16ms frame shift).
+
+    Args:
+        conv_dim (`List[int]`, *optional*, defaults to `[512, 512, 512, 512, 512, 512, 512]`):
+            Dimensions of each convolutional layer in the feature extractor.
+        conv_kernel (`List[int]`, *optional*, defaults to `[10, 3, 3, 3, 3, 2, 2]`):
+            Kernel sizes of each convolutional layer in the feature extractor.
+        conv_stride (`List[int]`, *optional*, defaults to `[4, 2, 2, 2, 2, 2, 2]`):
+            Stride sizes of each convolutional layer. Note: first stride is 4 (vs 5 in HuBERT)
+            for 16ms frame shift optimized for speaker verification.
+        conv_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use bias in convolutional layers.
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the feed-forward layer in the Transformer encoder.
+        hidden_dropout (`float`, *optional*, defaults to 0.1):
+            Dropout probability for all fully connected layers.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            Dropout probability for attention weights.
+        feat_proj_dropout (`float`, *optional*, defaults to 0.1):
+            Dropout probability for feature projection layer.
+        layer_drop (`float`, *optional*, defaults to 0.05):
+            Layer drop probability during training.
+        num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+            Number of convolutional positional embeddings.
+        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+            Number of groups for convolutional positional embeddings.
+        sampling_rate (`int`, *optional*, defaults to 16000):
+            Audio sampling rate in Hz.
+
+    Example:
+        ```python
+        from transformers import AutoConfig, AutoModel
+
+        # Load config
+        config = AutoConfig.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
+
+        # Load model
+        model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
+        ```
+    """
 
     model_type = "delulu"
 
     def __init__(
         self,
+        # Convolutional feature extractor
         conv_dim=None,
         conv_kernel=None,
         conv_stride=None,
         conv_bias=False,
         extractor_mode="group_norm",
+
+        # Transformer encoder
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
@@ -25,15 +85,24 @@ class DELULUConfig(PretrainedConfig):
         feat_proj_dropout=0.1,
         layer_norm_eps=1e-5,
         layer_drop=0.05,
+
+        # Positional encoding
         num_conv_pos_embeddings=128,
         num_conv_pos_embedding_groups=16,
+
+        # Audio settings
         sampling_rate=16000,
         do_stable_layer_norm=False,
+
+        # DELULU-specific settings
        num_clusters=256,
         feature_type="redimnet",
+
+        # Pad token for compatibility
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
+
         **kwargs
     ):
         super().__init__(
@@ -43,10 +112,18 @@ class DELULUConfig(PretrainedConfig):
             **kwargs
         )
 
-        # DELULU conv config: [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
-        self.conv_dim = conv_dim or [512, 512, 512, 512, 512, 512, 512]
-        self.conv_kernel = conv_kernel or [10, 3, 3, 3, 3, 2, 2]
-        self.conv_stride = conv_stride or [4, 2, 2, 2, 2, 2, 2]
+        # Set default DELULU conv configuration
+        # Key difference from HuBERT: first stride is 4 instead of 5
+        if conv_dim is None:
+            conv_dim = [512, 512, 512, 512, 512, 512, 512]
+        if conv_kernel is None:
+            conv_kernel = [10, 3, 3, 3, 3, 2, 2]
+        if conv_stride is None:
+            conv_stride = [4, 2, 2, 2, 2, 2, 2]
+
+        self.conv_dim = conv_dim
+        self.conv_kernel = conv_kernel
+        self.conv_stride = conv_stride
         self.conv_bias = conv_bias
         self.extractor_mode = extractor_mode
 
@@ -70,4 +147,13 @@ class DELULUConfig(PretrainedConfig):
         self.num_clusters = num_clusters
         self.feature_type = feature_type
 
+        # Computed properties
         self.num_feat_extract_layers = len(self.conv_dim)
+
+    @property
+    def inputs_to_logits_ratio(self):
+        """Compute the ratio between input samples and output frames."""
+        ratio = 1
+        for stride in self.conv_stride:
+            ratio *= stride
+        return ratio
convert_delulu_fixed.py ADDED
@@ -0,0 +1,426 @@
+ #!/usr/bin/env python3
+ """
+ DELULU Checkpoint Converter - Fixed Version
+
+ Converts DELULU model checkpoints from torchaudio/PyTorch Lightning format
+ to Hugging Face compatible format with proper metadata.
+
+ Usage:
+     python convert_delulu_fixed.py \
+         --checkpoint /path/to/epoch=45-step=400000.ckpt \
+         --output-dir ./delulu_hf_model
+
+ Author: Massa Baali
+ """
+
+ import argparse
+ import json
+ import os
+ import sys
+ from collections import OrderedDict
+ from pathlib import Path
+
+ import torch
+
+ try:
+     from safetensors.torch import save_file as save_safetensors
+     SAFETENSORS_AVAILABLE = True
+ except ImportError:
+     SAFETENSORS_AVAILABLE = False
+     print("Warning: safetensors not installed. Install with: pip install safetensors")
+
+
+ def load_lightning_checkpoint(checkpoint_path: str) -> dict:
+     """Load and clean PyTorch Lightning checkpoint."""
+     print(f"Loading checkpoint: {checkpoint_path}")
+
+     checkpoint = torch.load(checkpoint_path, map_location="cpu")
+
+     # Extract state dict
+     if "state_dict" in checkpoint:
+         state_dict = checkpoint["state_dict"]
+     else:
+         state_dict = checkpoint
+
+     # Clean up state dict keys
+     cleaned_state_dict = OrderedDict()
+
+     for key, value in state_dict.items():
+         new_key = key
+
+         # Remove Lightning prefixes
+         if key.startswith("model.wav2vec2."):
+             new_key = key.replace("model.wav2vec2.", "")
+         elif key.startswith("model."):
+             new_key = key.replace("model.", "")
+
+         # Skip auxiliary heads
+         if "aux" in new_key:
+             print(f"  Skipping: {key}")
+             continue
+
+         cleaned_state_dict[new_key] = value
+
+     print(f"Loaded {len(cleaned_state_dict)} parameters")
+     return cleaned_state_dict
+
+
+ def save_pytorch_model_bin(state_dict: dict, output_path: Path):
+     """
+     Save state dict as pytorch_model.bin with proper format.
+
+     This saves ONLY the state dict (not a full checkpoint with metadata),
+     which is what HuggingFace expects.
+     """
+     print(f"Saving pytorch_model.bin to: {output_path}")
+
+     # Convert all tensors to contiguous for safety
+     clean_state_dict = OrderedDict()
+     for key, value in state_dict.items():
+         if isinstance(value, torch.Tensor):
+             clean_state_dict[key] = value.contiguous()
+         else:
+             clean_state_dict[key] = value
+
+     # Save just the state dict (NOT a checkpoint dict)
+     torch.save(clean_state_dict, output_path)
+
+     print(f"  Saved {len(clean_state_dict)} tensors")
+     print(f"  File size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")
+
+
+ def save_safetensors_model(state_dict: dict, output_path: Path):
+     """Save state dict in safetensors format."""
+     if not SAFETENSORS_AVAILABLE:
+         print("Skipping safetensors (not installed)")
+         return
+
+     print(f"Saving model.safetensors to: {output_path}")
+
+     # Safetensors requires contiguous tensors
+     clean_state_dict = {}
+     for key, value in state_dict.items():
+         if isinstance(value, torch.Tensor):
+             clean_state_dict[key] = value.contiguous()
+
+     save_safetensors(clean_state_dict, str(output_path))
+     print(f"  File size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")
+
+
+ def create_config_json(output_dir: Path):
+     """Create config.json with DELULU configuration."""
+     config = {
+         "model_type": "delulu",
+         "architectures": ["DELULUModel"],
+         "auto_map": {
+             "AutoConfig": "configuration_delulu.DELULUConfig",
+             "AutoModel": "modeling_delulu.DELULUModel"
+         },
+         "conv_dim": [512, 512, 512, 512, 512, 512, 512],
+         "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
+         "conv_stride": [4, 2, 2, 2, 2, 2, 2],
+         "conv_bias": False,
+         "extractor_mode": "group_norm",
+         "hidden_size": 768,
+         "num_hidden_layers": 12,
+         "num_attention_heads": 12,
+         "intermediate_size": 3072,
+         "hidden_dropout": 0.1,
+         "attention_dropout": 0.1,
+         "final_dropout": 0.1,
+         "feat_proj_dropout": 0.1,
+         "layer_norm_eps": 1e-5,
+         "layer_drop": 0.05,
+         "num_conv_pos_embeddings": 128,
+         "num_conv_pos_embedding_groups": 16,
+         "sampling_rate": 16000,
+         "do_stable_layer_norm": False,
+         "num_clusters": 256,
+         "feature_type": "redimnet",
+         "num_feat_extract_layers": 7,
+         "pad_token_id": 0,
+         "bos_token_id": 1,
+         "eos_token_id": 2,
+         "torch_dtype": "float32"
+     }
+
+     config_path = output_dir / "config.json"
+     with open(config_path, "w") as f:
+         json.dump(config, f, indent=2)
+     print("Created config.json")
+
+
+ def create_configuration_delulu(output_dir: Path):
+     """Create configuration_delulu.py file."""
+     code = '''"""DELULU Configuration"""
+
+ from transformers import PretrainedConfig
+
+
+ class DELULUConfig(PretrainedConfig):
+     """Configuration class for DELULU model."""
+
+     model_type = "delulu"
+
+     def __init__(
+         self,
+         conv_dim=None,
+         conv_kernel=None,
+         conv_stride=None,
+         conv_bias=False,
+         extractor_mode="group_norm",
+         hidden_size=768,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_dropout=0.1,
+         attention_dropout=0.1,
+         final_dropout=0.1,
+         feat_proj_dropout=0.1,
+         layer_norm_eps=1e-5,
+         layer_drop=0.05,
+         num_conv_pos_embeddings=128,
+         num_conv_pos_embedding_groups=16,
+         sampling_rate=16000,
+         do_stable_layer_norm=False,
+         num_clusters=256,
+         feature_type="redimnet",
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         **kwargs
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             **kwargs
+         )
+
+         # DELULU conv config: [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
+         self.conv_dim = conv_dim or [512, 512, 512, 512, 512, 512, 512]
+         self.conv_kernel = conv_kernel or [10, 3, 3, 3, 3, 2, 2]
+         self.conv_stride = conv_stride or [4, 2, 2, 2, 2, 2, 2]
+         self.conv_bias = conv_bias
+         self.extractor_mode = extractor_mode
+
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout = hidden_dropout
+         self.attention_dropout = attention_dropout
+         self.final_dropout = final_dropout
+         self.feat_proj_dropout = feat_proj_dropout
+         self.layer_norm_eps = layer_norm_eps
+         self.layer_drop = layer_drop
+
+         self.num_conv_pos_embeddings = num_conv_pos_embeddings
+         self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+
+         self.sampling_rate = sampling_rate
+         self.do_stable_layer_norm = do_stable_layer_norm
+
+         self.num_clusters = num_clusters
+         self.feature_type = feature_type
+
+         self.num_feat_extract_layers = len(self.conv_dim)
+ '''
+
+     with open(output_dir / "configuration_delulu.py", "w") as f:
+         f.write(code)
+     print("Created configuration_delulu.py")
+
+
+ def create_modeling_delulu(output_dir: Path):
+     """Create modeling_delulu.py file."""
+     code = '''"""DELULU Model"""
+
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Tuple, Union
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutput
+ from .configuration_delulu import DELULUConfig
+
+ try:
+     from torchaudio.models.wav2vec2 import wav2vec2_model
+     TORCHAUDIO_AVAILABLE = True
+ except ImportError:
+     TORCHAUDIO_AVAILABLE = False
+
+
+ class DELULUModel(PreTrainedModel):
+     """
+     DELULU Model for speaker-aware speech representation learning.
+
+     Example:
+         ```python
+         from transformers import AutoModel
+         import torch
+
+         model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
+         waveform = torch.randn(1, 16000)  # 1 second at 16kHz
+         outputs = model(waveform)
+         features = outputs.last_hidden_state
+         ```
+     """
+
+     config_class = DELULUConfig
+     base_model_prefix = "delulu"
+     main_input_name = "input_values"
+
+     def __init__(self, config: DELULUConfig):
+         super().__init__(config)
+         self.config = config
+
+         if not TORCHAUDIO_AVAILABLE:
+             raise ImportError("torchaudio is required. Install with: pip install torchaudio")
+
+         # Build conv config
+         conv_layer_config = list(zip(
+             config.conv_dim,
+             config.conv_kernel,
+             config.conv_stride
+         ))
+
+         # Create torchaudio model
+         self.wav2vec2 = wav2vec2_model(
+             extractor_mode=config.extractor_mode,
+             extractor_conv_layer_config=conv_layer_config,
+             extractor_conv_bias=config.conv_bias,
+             encoder_embed_dim=config.hidden_size,
+             encoder_projection_dropout=config.feat_proj_dropout,
+             encoder_pos_conv_kernel=config.num_conv_pos_embeddings,
+             encoder_pos_conv_groups=config.num_conv_pos_embedding_groups,
+             encoder_num_layers=config.num_hidden_layers,
+             encoder_num_heads=config.num_attention_heads,
+             encoder_attention_dropout=config.attention_dropout,
+             encoder_ff_interm_features=config.intermediate_size,
+             encoder_ff_interm_dropout=config.hidden_dropout,
+             encoder_dropout=config.hidden_dropout,
+             encoder_layer_norm_first=config.do_stable_layer_norm,
+             encoder_layer_drop=config.layer_drop,
+             aux_num_out=None,
+         )
+
+         self.post_init()
+
+     def _init_weights(self, module):
+         """Initialize weights."""
+         pass  # Handled by torchaudio
+
+     def forward(
+         self,
+         input_values: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         """
+         Args:
+             input_values: Audio waveform (batch, samples) at 16kHz
+             attention_mask: Optional attention mask
+             output_hidden_states: Whether to return all hidden states
+             return_dict: Whether to return BaseModelOutput
+         """
+         return_dict = return_dict if return_dict is not None else True
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+
+         if input_values.dim() == 1:
+             input_values = input_values.unsqueeze(0)
+
+         lengths = None
+         if attention_mask is not None:
+             lengths = attention_mask.sum(dim=-1)
+
+         if output_hidden_states:
+             features, _ = self.wav2vec2.extract_features(input_values, lengths=lengths)
+             hidden_states = tuple(features)
+             last_hidden_state = features[-1]
+         else:
+             last_hidden_state, _ = self.wav2vec2(input_values, lengths=lengths)
+             hidden_states = None
+
+         if not return_dict:
+             return (last_hidden_state, hidden_states) if hidden_states else (last_hidden_state,)
+
+         return BaseModelOutput(
+             last_hidden_state=last_hidden_state,
+             hidden_states=hidden_states,
+         )
+
+     def extract_features(self, input_values: torch.Tensor):
+         """Extract features from all layers."""
+         if input_values.dim() == 1:
+             input_values = input_values.unsqueeze(0)
+         features, _ = self.wav2vec2.extract_features(input_values)
+         return tuple(features)
+
+     @classmethod
+     def _load_pretrained_model_low_mem(cls, *args, **kwargs):
+         """Override to handle custom loading."""
+         return super()._load_pretrained_model_low_mem(*args, **kwargs)
+ '''
365
+
366
+ with open(output_dir / "modeling_delulu.py", "w") as f:
367
+ f.write(code)
368
+ print("Created modeling_delulu.py")
369
+
370
+
371
+ def convert_checkpoint(checkpoint_path: str, output_dir: str):
372
+ """Main conversion function."""
373
+ output_path = Path(output_dir)
374
+ output_path.mkdir(parents=True, exist_ok=True)
375
+
376
+ print("=" * 60)
377
+ print("DELULU Checkpoint Converter")
378
+ print("=" * 60)
379
+
380
+ # Step 1: Load checkpoint
381
+ state_dict = load_lightning_checkpoint(checkpoint_path)
382
+
383
+ # Step 2: Print some keys for verification
384
+ print("\nSample keys in state dict:")
385
+ for i, key in enumerate(list(state_dict.keys())[:10]):
386
+ print(f" {key}")
387
+ print(f" ... and {len(state_dict) - 10} more")
388
+
389
+ # Step 3: Save weights
390
+ save_pytorch_model_bin(state_dict, output_path / "pytorch_model.bin")
391
+
392
+ if SAFETENSORS_AVAILABLE:
393
+ save_safetensors_model(state_dict, output_path / "model.safetensors")
394
+
395
+ # Step 4: Create config and code files
396
+ create_config_json(output_path)
397
+ create_configuration_delulu(output_path)
398
+ create_modeling_delulu(output_path)
399
+
400
+ # Step 5: Summary
401
+ print("\n" + "=" * 60)
402
+ print("Conversion Complete!")
403
+ print("=" * 60)
404
+ print(f"\nOutput directory: {output_path}")
405
+ print("\nFiles created:")
406
+ for f in sorted(output_path.iterdir()):
407
+ size_mb = f.stat().st_size / 1024 / 1024
408
+ print(f" {f.name}: {size_mb:.2f} MB")
409
+
410
+ print("\nNext steps:")
411
+ print(" 1. Upload all files to huggingface.co/cmu-mlsp/DELULU")
412
+ print(" 2. Test with:")
413
+ print(' model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)')
414
+
415
+
416
+ def main():
417
+ parser = argparse.ArgumentParser(description="Convert DELULU checkpoint to HuggingFace format")
418
+ parser.add_argument("--checkpoint", "-c", required=True, help="Path to .ckpt file")
419
+ parser.add_argument("--output-dir", "-o", required=True, help="Output directory")
420
+ args = parser.parse_args()
421
+
422
+ convert_checkpoint(args.checkpoint, args.output_dir)
423
+
424
+
425
+ if __name__ == "__main__":
426
+ main()
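The conv strides used throughout this commit (`[4, 2, 2, 2, 2, 2, 2]`, versus HuBERT's `[5, 2, 2, 2, 2, 2, 2]`) are what give DELULU its 16 ms frame shift. A quick standalone sanity check of that arithmetic (plain Python, not part of the converter):

```python
# The feature extractor's hop size is the product of the conv strides.
delulu_strides = [4, 2, 2, 2, 2, 2, 2]
hop = 1
for s in delulu_strides:
    hop *= s  # 4 * 2**6 = 256 samples

frame_shift_ms = hop / 16000 * 1000  # at a 16 kHz sampling rate
print(hop, frame_shift_ms)  # 256 samples -> 16.0 ms
# For comparison, HuBERT's stride-5 first layer gives 320 samples -> 20 ms.
```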
convert_delulu_to_hf.py ADDED
@@ -0,0 +1,697 @@
+ #!/usr/bin/env python3
+ """
+ DELULU Checkpoint Converter
+ ===========================
+
+ Converts DELULU model checkpoints from torchaudio/PyTorch Lightning format
+ to Hugging Face compatible format (config.json + model weights).
+
+ Usage:
+     python convert_delulu_to_hf.py \
+         --checkpoint /path/to/epoch=45-step=400000.ckpt \
+         --output-dir ./delulu_hf_model
+
+ Author: Massa Baali
+ Model: DELULU - Speaker-Aware Self-Supervised Speech Foundational Model
+ """
+
+ import argparse
+ import json
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Optional, Tuple, List
+ from collections import OrderedDict
+
+ import torch
+
+ try:
+     from safetensors.torch import save_file as save_safetensors
+     SAFETENSORS_AVAILABLE = True
+ except ImportError:
+     SAFETENSORS_AVAILABLE = False
+     print("Warning: safetensors not installed. Will save as pytorch_model.bin only.")
+     print("Install with: pip install safetensors")
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(levelname)s - %(message)s",
+     datefmt="%Y-%m-%d %H:%M:%S",
+ )
+ logger = logging.getLogger(__name__)
+
+
+ # =============================================================================
+ # DELULU Configuration
+ # =============================================================================
+
+ class DELULUConfig:
+     """
+     Configuration class for DELULU model.
+
+     DELULU uses the HuBERT architecture with modified convolutional strides
+     for a 16 ms frame shift, optimized for speaker verification.
+     """
+
+     # Model architecture identifier
+     model_type = "delulu"
+     architectures = ["DELULUModel"]
+
+     def __init__(
+         self,
+         # Convolutional feature extractor config
+         # DELULU: [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
+         conv_dim: Optional[List[int]] = None,
+         conv_kernel: Optional[List[int]] = None,
+         conv_stride: Optional[List[int]] = None,
+         conv_bias: bool = False,
+         extractor_mode: str = "group_norm",
+
+         # Transformer encoder config
+         hidden_size: int = 768,
+         num_hidden_layers: int = 12,
+         num_attention_heads: int = 12,
+         intermediate_size: int = 3072,
+         hidden_dropout: float = 0.1,
+         attention_dropout: float = 0.1,
+         final_dropout: float = 0.1,
+         feat_proj_dropout: float = 0.1,
+         layer_norm_eps: float = 1e-5,
+         layer_drop: float = 0.05,
+
+         # Positional encoding
+         num_conv_pos_embeddings: int = 128,
+         num_conv_pos_embedding_groups: int = 16,
+
+         # Audio config
+         sampling_rate: int = 16000,
+         do_stable_layer_norm: bool = False,
+
+         # Training config (for reference)
+         num_clusters: int = 256,
+         feature_type: str = "redimnet",
+
+         **kwargs
+     ):
+         # Set default conv config for DELULU
+         if conv_dim is None:
+             conv_dim = [512, 512, 512, 512, 512, 512, 512]
+         if conv_kernel is None:
+             conv_kernel = [10, 3, 3, 3, 3, 2, 2]
+         if conv_stride is None:
+             conv_stride = [4, 2, 2, 2, 2, 2, 2]  # Key difference from HuBERT!
+
+         self.conv_dim = conv_dim
+         self.conv_kernel = conv_kernel
+         self.conv_stride = conv_stride
+         self.conv_bias = conv_bias
+         self.extractor_mode = extractor_mode
+
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout = hidden_dropout
+         self.attention_dropout = attention_dropout
+         self.final_dropout = final_dropout
+         self.feat_proj_dropout = feat_proj_dropout
+         self.layer_norm_eps = layer_norm_eps
+         self.layer_drop = layer_drop
+
+         self.num_conv_pos_embeddings = num_conv_pos_embeddings
+         self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+
+         self.sampling_rate = sampling_rate
+         self.do_stable_layer_norm = do_stable_layer_norm
+
+         self.num_clusters = num_clusters
+         self.feature_type = feature_type
+
+         # Store any additional kwargs
+         for key, value in kwargs.items():
+             setattr(self, key, value)
+
+     def to_dict(self) -> dict:
+         """Convert config to dictionary for JSON serialization."""
+         return {
+             # Model identification
+             "model_type": self.model_type,
+             "architectures": self.architectures,
+
+             # Convolutional feature extractor
+             "conv_dim": self.conv_dim,
+             "conv_kernel": self.conv_kernel,
+             "conv_stride": self.conv_stride,
+             "conv_bias": self.conv_bias,
+             "extractor_mode": self.extractor_mode,
+
+             # Transformer encoder
+             "hidden_size": self.hidden_size,
+             "num_hidden_layers": self.num_hidden_layers,
+             "num_attention_heads": self.num_attention_heads,
+             "intermediate_size": self.intermediate_size,
+             "hidden_dropout": self.hidden_dropout,
+             "attention_dropout": self.attention_dropout,
+             "final_dropout": self.final_dropout,
+             "feat_proj_dropout": self.feat_proj_dropout,
+             "layer_norm_eps": self.layer_norm_eps,
+             "layer_drop": self.layer_drop,
+
+             # Positional encoding
+             "num_conv_pos_embeddings": self.num_conv_pos_embeddings,
+             "num_conv_pos_embedding_groups": self.num_conv_pos_embedding_groups,
+
+             # Audio config
+             "sampling_rate": self.sampling_rate,
+             "do_stable_layer_norm": self.do_stable_layer_norm,
+
+             # Training reference
+             "num_clusters": self.num_clusters,
+             "feature_type": self.feature_type,
+
+             # Transformers compatibility
+             "transformers_version": "4.36.0",
+             "torch_dtype": "float32",
+
+             # Auto-mapping for custom code
+             "auto_map": {
+                 "AutoConfig": "configuration_delulu.DELULUConfig",
+                 "AutoModel": "modeling_delulu.DELULUModel"
+             }
+         }
+
+     def save_pretrained(self, save_directory: str):
+         """Save config to directory."""
+         os.makedirs(save_directory, exist_ok=True)
+         config_path = os.path.join(save_directory, "config.json")
+         with open(config_path, "w") as f:
+             json.dump(self.to_dict(), f, indent=2)
+         logger.info(f"Config saved to: {config_path}")
+
+
+ # =============================================================================
+ # Weight Mapping: torchaudio -> Hugging Face
+ # =============================================================================
+
+ def create_weight_mapping() -> dict:
+     """
+     Create mapping from torchaudio wav2vec2_model keys to Hugging Face format.
+
+     torchaudio structure:
+         feature_extractor.conv_layers.{i}.{0,1,2}...
+         encoder.feature_projection.{projection,layer_norm}...
+         encoder.transformer.pos_conv_embed...
+         encoder.transformer.layers.{i}.{attention,feed_forward,layer_norms}...
+         encoder.transformer.layer_norm...
+
+     HuggingFace structure:
+         feature_extractor.conv_layers.{i}.{conv,layer_norm}...
+         feature_projection.{projection,layer_norm}...
+         encoder.pos_conv_embed...
+         encoder.layers.{i}.{attention,feed_forward,layer_norm}...
+         encoder.layer_norm...
+     """
+     # This will be populated dynamically based on actual keys
+     mapping = {}
+     return mapping
+
+
+ def convert_torchaudio_to_hf(state_dict: dict) -> dict:
+     """
+     Convert torchaudio wav2vec2_model state dict to Hugging Face format.
+
+     Args:
+         state_dict: State dict from torchaudio model
+
+     Returns:
+         Converted state dict in HuggingFace format
+     """
+     new_state_dict = OrderedDict()
+
+     for key, value in state_dict.items():
+         new_key = key
+
+         # Feature extractor conv layers
+         # torchaudio: feature_extractor.conv_layers.0.0.weight -> hf: feature_extractor.conv_layers.0.conv.weight
+         if "feature_extractor.conv_layers" in key:
+             # Handle conv layer structure: .{layer_idx}.0. -> .{layer_idx}.conv.
+             # Handle norm layer structure: .{layer_idx}.2.1. -> .{layer_idx}.layer_norm.
+             parts = key.split(".")
+             layer_idx = parts[2]
+
+             if ".0." in key and "weight" in key:
+                 # Convolution weight
+                 new_key = f"delulu.feature_extractor.conv_layers.{layer_idx}.conv.weight"
+             elif ".2.1." in key or (".1." in key and "layer_norm" not in key):
+                 # Group norm / layer norm
+                 if "weight" in key:
+                     new_key = f"delulu.feature_extractor.conv_layers.{layer_idx}.layer_norm.weight"
+                 elif "bias" in key:
+                     new_key = f"delulu.feature_extractor.conv_layers.{layer_idx}.layer_norm.bias"
+             else:
+                 new_key = f"delulu.{key}"
+
+         # Feature projection
+         elif "encoder.feature_projection" in key:
+             new_key = key.replace("encoder.feature_projection", "delulu.feature_projection")
+
+         # Positional conv embedding
+         elif "encoder.transformer.pos_conv_embed" in key:
+             new_key = key.replace("encoder.transformer.pos_conv_embed", "delulu.encoder.pos_conv_embed")
+
+         # Transformer layers
+         elif "encoder.transformer.layers" in key:
+             new_key = key.replace("encoder.transformer.layers", "delulu.encoder.layers")
+
+             # Attention and feed-forward sub-module names (k_proj, v_proj, q_proj,
+             # out_proj, intermediate_dense, output_dense) already match the
+             # Hugging Face naming; only the layer norms need remapping.
+             new_key = new_key.replace(".layer_norms.0", ".layer_norm")
+             new_key = new_key.replace(".layer_norms.1", ".final_layer_norm")
+
+         # Final layer norm
+         elif "encoder.transformer.layer_norm" in key:
+             new_key = key.replace("encoder.transformer.layer_norm", "delulu.encoder.layer_norm")
+
+         # Mask embedding (if present)
+         elif "mask_emb" in key:
+             new_key = f"delulu.{key}"
+
+         # Auxiliary head (if present)
+         elif "aux" in key:
+             new_key = key  # Keep as is for now
+
+         else:
+             # Default: add delulu prefix
+             new_key = f"delulu.{key}"
+
+         new_state_dict[new_key] = value
+
+         if new_key != key:
+             logger.debug(f"Mapped: {key} -> {new_key}")
+
+     return new_state_dict
+
+
+ def convert_simple_format(state_dict: dict) -> dict:
+     """
+     Simple conversion that just renames keys minimally.
+     Suitable for direct loading with torchaudio models.
+     """
+     new_state_dict = OrderedDict()
+
+     for key, value in state_dict.items():
+         # Just add a model prefix for organization
+         new_key = f"model.{key}" if not key.startswith("model.") else key
+         new_state_dict[new_key] = value
+
+     return new_state_dict
+
+
+ # =============================================================================
+ # Checkpoint Loading
+ # =============================================================================
+
+ def load_lightning_checkpoint(checkpoint_path: str) -> Tuple[dict, dict]:
+     """
+     Load PyTorch Lightning checkpoint and extract model state dict.
+
+     Args:
+         checkpoint_path: Path to .ckpt file
+
+     Returns:
+         Tuple of (state_dict, hyperparameters)
+     """
+     logger.info(f"Loading checkpoint: {checkpoint_path}")
+
+     checkpoint = torch.load(checkpoint_path, map_location="cpu")
+
+     # Extract state dict
+     if "state_dict" in checkpoint:
+         state_dict = checkpoint["state_dict"]
+     else:
+         state_dict = checkpoint
+
+     # Extract hyperparameters if available
+     hparams = checkpoint.get("hyper_parameters", {})
+
+     # Clean up state dict keys (remove Lightning prefixes)
+     cleaned_state_dict = OrderedDict()
+     for key, value in state_dict.items():
+         new_key = key
+
+         # Remove common Lightning prefixes
+         if key.startswith("model.wav2vec2."):
+             new_key = key.replace("model.wav2vec2.", "")
+         elif key.startswith("model."):
+             new_key = key.replace("model.", "")
+
+         # Skip auxiliary heads unless needed
+         if "aux" in new_key:
+             logger.debug(f"Skipping auxiliary layer: {key}")
+             continue
+
+         cleaned_state_dict[new_key] = value
+
+     logger.info(f"Loaded {len(cleaned_state_dict)} parameters")
+
+     return cleaned_state_dict, hparams
+
+
+ def verify_state_dict(state_dict: dict) -> bool:
+     """
+     Verify the state dict has expected DELULU components.
+     """
+     expected_prefixes = [
+         "feature_extractor",
+         "encoder",
+     ]
+
+     found_prefixes = set()
+     for key in state_dict.keys():
+         for prefix in expected_prefixes:
+             if prefix in key:
+                 found_prefixes.add(prefix)
+
+     missing = set(expected_prefixes) - found_prefixes
+     if missing:
+         logger.warning(f"Missing expected components: {missing}")
+         return False
+
+     logger.info("✓ State dict contains expected components")
+     return True
+
+
+ # =============================================================================
+ # Main Conversion
+ # =============================================================================
+
+ def convert_checkpoint(
+     checkpoint_path: str,
+     output_dir: str,
+     save_safetensors_format: bool = True,
+     save_bin_format: bool = True,
+     verify: bool = True
+ ) -> None:
+     """
+     Convert DELULU checkpoint to Hugging Face format.
+
+     Args:
+         checkpoint_path: Path to input .ckpt file
+         output_dir: Output directory for converted model
+         save_safetensors_format: Save in safetensors format
+         save_bin_format: Save in pytorch_model.bin format
+         verify: Verify the conversion
+     """
+     output_path = Path(output_dir)
+     output_path.mkdir(parents=True, exist_ok=True)
+
+     # Step 1: Load checkpoint
+     state_dict, hparams = load_lightning_checkpoint(checkpoint_path)
+
+     # Step 2: Verify state dict
+     if verify:
+         verify_state_dict(state_dict)
+
+     # Step 3: Create and save config
+     logger.info("Creating DELULU config...")
+     config = DELULUConfig(
+         # Use DELULU's custom conv config
+         conv_dim=[512, 512, 512, 512, 512, 512, 512],
+         conv_kernel=[10, 3, 3, 3, 3, 2, 2],
+         conv_stride=[4, 2, 2, 2, 2, 2, 2],  # Key difference!
+     )
+     config.save_pretrained(output_dir)
+
+     # Step 4: Convert state dict format (minimal conversion)
+     logger.info("Converting state dict format...")
+     # Keep the original format since it's compatible with torchaudio loading
+     converted_state_dict = state_dict
+
+     # Step 5: Save weights
+     if save_safetensors_format and SAFETENSORS_AVAILABLE:
+         safetensors_path = output_path / "model.safetensors"
+         logger.info(f"Saving safetensors to: {safetensors_path}")
+         save_safetensors(converted_state_dict, str(safetensors_path))
+
+     if save_bin_format:
+         bin_path = output_path / "pytorch_model.bin"
+         logger.info(f"Saving pytorch_model.bin to: {bin_path}")
+         torch.save(converted_state_dict, str(bin_path))
+
+     # Step 6: Create additional files
+     create_additional_files(output_path, config)
+
+     # Step 7: Print summary
+     print_conversion_summary(checkpoint_path, output_dir, converted_state_dict)
+
+
+ def create_additional_files(output_path: Path, config: DELULUConfig) -> None:
+     """Create additional files needed for Hugging Face model."""
+
+     # Create preprocessor_config.json
+     preprocessor_config = {
+         "do_normalize": True,
+         "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+         "feature_size": 1,
+         "padding_side": "right",
+         "padding_value": 0.0,
+         "return_attention_mask": True,
+         "sampling_rate": config.sampling_rate,
+     }
+
+     with open(output_path / "preprocessor_config.json", "w") as f:
+         json.dump(preprocessor_config, f, indent=2)
+     logger.info("Created preprocessor_config.json")
+
+     # Create a simple modeling file for reference
+     modeling_code = '''"""
+ DELULU Model - Minimal Loading Example
+
+ This file shows how to load DELULU weights with torchaudio.
+ For full Hugging Face Transformers integration, see the modeling_delulu.py file.
+ """
+
+ import torch
+ from torchaudio.models.wav2vec2 import wav2vec2_model
+
+ # DELULU configuration
+ DELULU_CONV_CONFIG = [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
+
+
+ def load_delulu(checkpoint_path: str = None, weights_path: str = None):
+     """
+     Load DELULU model.
+
+     Args:
+         checkpoint_path: Path to original .ckpt file (PyTorch Lightning format)
+         weights_path: Path to pytorch_model.bin (Hugging Face format)
+
+     Returns:
+         DELULU model ready for inference
+     """
+     model = wav2vec2_model(
+         extractor_mode="group_norm",
+         extractor_conv_layer_config=DELULU_CONV_CONFIG,
+         extractor_conv_bias=False,
+         encoder_embed_dim=768,
+         encoder_projection_dropout=0.1,
+         encoder_pos_conv_kernel=128,
+         encoder_pos_conv_groups=16,
+         encoder_num_layers=12,
+         encoder_num_heads=12,
+         encoder_attention_dropout=0.1,
+         encoder_ff_interm_features=3072,
+         encoder_ff_interm_dropout=0.1,
+         encoder_dropout=0.1,
+         encoder_layer_norm_first=False,
+         encoder_layer_drop=0.05,
+         aux_num_out=None,
+     )
+
+     if checkpoint_path:
+         # Load from original Lightning checkpoint
+         checkpoint = torch.load(checkpoint_path, map_location="cpu")
+         state_dict = checkpoint.get("state_dict", checkpoint)
+
+         # Clean keys
+         new_state_dict = {}
+         for k, v in state_dict.items():
+             if "model.wav2vec2" in k:
+                 new_state_dict[k.replace("model.wav2vec2.", "")] = v
+             elif not k.startswith("aux"):
+                 new_state_dict[k] = v
+
+         model.load_state_dict(new_state_dict, strict=False)
+
+     elif weights_path:
+         # Load from Hugging Face format
+         state_dict = torch.load(weights_path, map_location="cpu")
+         model.load_state_dict(state_dict, strict=False)
+
+     return model
+
+
+ def extract_features(model, waveform: torch.Tensor) -> torch.Tensor:
+     """
+     Extract speaker features from audio waveform.
+
+     Args:
+         model: DELULU model
+         waveform: Audio tensor of shape (batch, samples) at 16kHz
+
+     Returns:
+         Features of shape (batch, time, 768)
+     """
+     model.eval()
+     with torch.no_grad():
+         features, _ = model.extract_features(waveform)
+         # Return last layer features
+         return features[-1]
+
+
+ if __name__ == "__main__":
+     # Example usage
+     import sys
+
+     if len(sys.argv) > 1:
+         model = load_delulu(weights_path=sys.argv[1])
+         print("Model loaded successfully!")
+         print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
+     else:
+         print("Usage: python load_delulu.py path/to/pytorch_model.bin")
+ '''
+
+     with open(output_path / "load_delulu.py", "w") as f:
+         f.write(modeling_code)
+     logger.info("Created load_delulu.py")
+
+
+ def print_conversion_summary(
+     input_path: str,
+     output_dir: str,
+     state_dict: dict
+ ) -> None:
+     """Print summary of the conversion."""
+
+     total_params = sum(p.numel() for p in state_dict.values())
+
+     print("\n" + "=" * 60)
+     print("DELULU Checkpoint Conversion Complete!")
+     print("=" * 60)
+     print(f"\nInput: {input_path}")
+     print(f"Output: {output_dir}")
+     print("\nModel Statistics:")
+     print(f"  - Total parameters: {total_params:,}")
+     print(f"  - Parameter tensors: {len(state_dict)}")
+     print("\nOutput Files:")
+
+     output_path = Path(output_dir)
+     for f in sorted(output_path.iterdir()):
+         size_mb = f.stat().st_size / 1024 / 1024
+         print(f"  - {f.name}: {size_mb:.2f} MB")
+
+     print("\nNext Steps:")
+     print(f"  1. Test loading: python {output_dir}/load_delulu.py {output_dir}/pytorch_model.bin")
+     print(f"  2. Upload to HF: python upload_delulu_to_hf.py --checkpoint-dir {output_dir} --repo-id YOUR_USERNAME/DELULU")
+     print("=" * 60 + "\n")
+
+
+ # =============================================================================
+ # CLI Interface
+ # =============================================================================
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(
+         description="Convert DELULU checkpoint to Hugging Face format",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   # Basic conversion
+   python convert_delulu_to_hf.py \\
+       --checkpoint /path/to/epoch=45-step=400000.ckpt \\
+       --output-dir ./delulu_hf_model
+
+   # Save only safetensors format
+   python convert_delulu_to_hf.py \\
+       --checkpoint /path/to/checkpoint.ckpt \\
+       --output-dir ./delulu_hf_model \\
+       --no-bin
+
+   # Skip verification
+   python convert_delulu_to_hf.py \\
+       --checkpoint /path/to/checkpoint.ckpt \\
+       --output-dir ./delulu_hf_model \\
+       --no-verify
+ """
+     )
+
+     parser.add_argument(
+         "--checkpoint", "-c",
+         type=str,
+         required=True,
+         help="Path to DELULU checkpoint (.ckpt file)"
+     )
+
+     parser.add_argument(
+         "--output-dir", "-o",
+         type=str,
+         required=True,
+         help="Output directory for converted model"
+     )
+
+     parser.add_argument(
+         "--no-safetensors",
+         action="store_true",
+         help="Don't save in safetensors format"
+     )
+
+     parser.add_argument(
+         "--no-bin",
+         action="store_true",
+         help="Don't save pytorch_model.bin"
+     )
+
+     parser.add_argument(
+         "--no-verify",
+         action="store_true",
+         help="Skip state dict verification"
+     )
+
+     parser.add_argument(
+         "--verbose", "-v",
+         action="store_true",
+         help="Enable verbose logging"
+     )
+
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+
+     if args.verbose:
+         logging.getLogger().setLevel(logging.DEBUG)
+
+     convert_checkpoint(
+         checkpoint_path=args.checkpoint,
+         output_dir=args.output_dir,
+         save_safetensors_format=not args.no_safetensors,
+         save_bin_format=not args.no_bin,
+         verify=not args.no_verify
+     )
+
+
+ if __name__ == "__main__":
+     main()
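The Lightning-prefix clean-up in `load_lightning_checkpoint` can be exercised on a toy state dict (made-up key names for illustration; a real checkpoint holds hundreds of tensors):

```python
from collections import OrderedDict

# Toy checkpoint keys (hypothetical values; only the key handling matters here).
raw = OrderedDict([
    ("model.wav2vec2.encoder.transformer.layer_norm.weight", 1),
    ("model.mask_emb", 2),
    ("aux.weight", 3),  # auxiliary head: dropped by the clean-up
])

cleaned = OrderedDict()
for key, value in raw.items():
    new_key = key
    # Same prefix stripping as load_lightning_checkpoint
    if key.startswith("model.wav2vec2."):
        new_key = key.replace("model.wav2vec2.", "")
    elif key.startswith("model."):
        new_key = key.replace("model.", "")
    if "aux" in new_key:
        continue  # skip auxiliary heads
    cleaned[new_key] = value

print(list(cleaned))  # ['encoder.transformer.layer_norm.weight', 'mask_emb']
```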
checksums.json → delulu_hf_model/checksums.json RENAMED
File without changes
delulu_hf_model/config.json ADDED
@@ -0,0 +1,60 @@
+ {
+   "model_type": "delulu",
+   "architectures": [
+     "DELULUModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_delulu.DELULUConfig",
+     "AutoModel": "modeling_delulu.DELULUModel"
+   },
+   "conv_dim": [
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512
+   ],
+   "conv_kernel": [
+     10,
+     3,
+     3,
+     3,
+     3,
+     2,
+     2
+   ],
+   "conv_stride": [
+     4,
+     2,
+     2,
+     2,
+     2,
+     2,
+     2
+   ],
+   "conv_bias": false,
+   "extractor_mode": "group_norm",
+   "hidden_size": 768,
+   "num_hidden_layers": 12,
+   "num_attention_heads": 12,
+   "intermediate_size": 3072,
+   "hidden_dropout": 0.1,
+   "attention_dropout": 0.1,
+   "final_dropout": 0.1,
+   "feat_proj_dropout": 0.1,
+   "layer_norm_eps": 1e-05,
+   "layer_drop": 0.05,
+   "num_conv_pos_embeddings": 128,
+   "num_conv_pos_embedding_groups": 16,
+   "sampling_rate": 16000,
+   "do_stable_layer_norm": false,
+   "num_clusters": 256,
+   "feature_type": "redimnet",
+   "num_feat_extract_layers": 7,
+   "pad_token_id": 0,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "torch_dtype": "float32"
+ }
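Given the `conv_kernel` and `conv_stride` values in this config, the number of encoder frames for a waveform follows the usual unpadded-convolution length formula. A standalone sketch mirroring that arithmetic (not part of the repo):

```python
# Per conv layer (no padding): L_out = floor((L_in - kernel) / stride) + 1
conv_kernel = [10, 3, 3, 3, 3, 2, 2]
conv_stride = [4, 2, 2, 2, 2, 2, 2]

def num_frames(num_samples: int) -> int:
    length = num_samples
    for k, s in zip(conv_kernel, conv_stride):
        length = (length - k) // s + 1
    return length

print(num_frames(16000))  # 62 frames for 1 s of 16 kHz audio (~16 ms per frame)
```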
delulu_hf_model/configuration_delulu.py ADDED
@@ -0,0 +1,73 @@
+ """DELULU Configuration"""
+
+ from transformers import PretrainedConfig
+
+
+ class DELULUConfig(PretrainedConfig):
+     """Configuration class for DELULU model."""
+
+     model_type = "delulu"
+
+     def __init__(
+         self,
+         conv_dim=None,
+         conv_kernel=None,
+         conv_stride=None,
+         conv_bias=False,
+         extractor_mode="group_norm",
+         hidden_size=768,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_dropout=0.1,
+         attention_dropout=0.1,
+         final_dropout=0.1,
+         feat_proj_dropout=0.1,
+         layer_norm_eps=1e-5,
+         layer_drop=0.05,
+         num_conv_pos_embeddings=128,
+         num_conv_pos_embedding_groups=16,
+         sampling_rate=16000,
+         do_stable_layer_norm=False,
+         num_clusters=256,
+         feature_type="redimnet",
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         **kwargs
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             **kwargs
+         )
+
+         # DELULU conv config: [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
+         self.conv_dim = conv_dim or [512, 512, 512, 512, 512, 512, 512]
+         self.conv_kernel = conv_kernel or [10, 3, 3, 3, 3, 2, 2]
+         self.conv_stride = conv_stride or [4, 2, 2, 2, 2, 2, 2]
+         self.conv_bias = conv_bias
+         self.extractor_mode = extractor_mode
+
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout = hidden_dropout
+         self.attention_dropout = attention_dropout
+         self.final_dropout = final_dropout
+         self.feat_proj_dropout = feat_proj_dropout
+         self.layer_norm_eps = layer_norm_eps
+         self.layer_drop = layer_drop
+
+         self.num_conv_pos_embeddings = num_conv_pos_embeddings
+         self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+
+         self.sampling_rate = sampling_rate
+         self.do_stable_layer_norm = do_stable_layer_norm
+
+         self.num_clusters = num_clusters
+         self.feature_type = feature_type
+
+         self.num_feat_extract_layers = len(self.conv_dim)
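The `conv_stride` default above is what distinguishes DELULU from standard HuBERT: the first stride is 4 rather than 5, so the conv stack downsamples by 4·2⁶ = 256 samples per frame instead of 320, giving a 16 ms frame shift at 16 kHz instead of 20 ms. A quick sanity-check sketch of that arithmetic (the variable names here are illustrative, not part of the repo):

```python
import math

# Stride defaults from DELULUConfig above
delulu_strides = [4, 2, 2, 2, 2, 2, 2]
# Standard HuBERT strides, per the comparison comment in upload_delulu_to_hf.py
hubert_strides = [5, 2, 2, 2, 2, 2, 2]

def frame_shift_ms(strides, sampling_rate=16000):
    """Total downsampling factor of the conv stack, in milliseconds per frame."""
    samples_per_frame = math.prod(strides)
    return samples_per_frame / sampling_rate * 1000

print(frame_shift_ms(delulu_strides))  # 16.0
print(frame_shift_ms(hubert_strides))  # 20.0
```

This matches the `frame_shift_ms: int = 16` field declared later in the upload script's config.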
model.safetensors → delulu_hf_model/model.safetensors RENAMED
File without changes
delulu_hf_model/modeling_delulu.py ADDED
@@ -0,0 +1,127 @@
+ """DELULU Model"""
+
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Tuple, Union
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutput
+ from .configuration_delulu import DELULUConfig
+
+ try:
+     from torchaudio.models.wav2vec2 import wav2vec2_model
+     TORCHAUDIO_AVAILABLE = True
+ except ImportError:
+     TORCHAUDIO_AVAILABLE = False
+
+
+ class DELULUModel(PreTrainedModel):
+     """
+     DELULU Model for speaker-aware speech representation learning.
+
+     Example:
+     ```python
+     from transformers import AutoModel
+     import torch
+
+     model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
+     waveform = torch.randn(1, 16000)  # 1 second at 16kHz
+     outputs = model(waveform)
+     features = outputs.last_hidden_state
+     ```
+     """
+
+     config_class = DELULUConfig
+     base_model_prefix = "delulu"
+     main_input_name = "input_values"
+
+     def __init__(self, config: DELULUConfig):
+         super().__init__(config)
+         self.config = config
+
+         if not TORCHAUDIO_AVAILABLE:
+             raise ImportError("torchaudio is required. Install with: pip install torchaudio")
+
+         # Build conv config
+         conv_layer_config = list(zip(
+             config.conv_dim,
+             config.conv_kernel,
+             config.conv_stride
+         ))
+
+         # Create torchaudio model
+         self.wav2vec2 = wav2vec2_model(
+             extractor_mode=config.extractor_mode,
+             extractor_conv_layer_config=conv_layer_config,
+             extractor_conv_bias=config.conv_bias,
+             encoder_embed_dim=config.hidden_size,
+             encoder_projection_dropout=config.feat_proj_dropout,
+             encoder_pos_conv_kernel=config.num_conv_pos_embeddings,
+             encoder_pos_conv_groups=config.num_conv_pos_embedding_groups,
+             encoder_num_layers=config.num_hidden_layers,
+             encoder_num_heads=config.num_attention_heads,
+             encoder_attention_dropout=config.attention_dropout,
+             encoder_ff_interm_features=config.intermediate_size,
+             encoder_ff_interm_dropout=config.hidden_dropout,
+             encoder_dropout=config.hidden_dropout,
+             encoder_layer_norm_first=config.do_stable_layer_norm,
+             encoder_layer_drop=config.layer_drop,
+             aux_num_out=None,
+         )
+
+         self.post_init()
+
+     def _init_weights(self, module):
+         """Initialize weights."""
+         pass  # Handled by torchaudio
+
+     def forward(
+         self,
+         input_values: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         """
+         Args:
+             input_values: Audio waveform (batch, samples) at 16kHz
+             attention_mask: Optional attention mask
+             output_hidden_states: Whether to return all hidden states
+             return_dict: Whether to return BaseModelOutput
+         """
+         return_dict = return_dict if return_dict is not None else True
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+
+         if input_values.dim() == 1:
+             input_values = input_values.unsqueeze(0)
+
+         lengths = None
+         if attention_mask is not None:
+             lengths = attention_mask.sum(dim=-1)
+
+         if output_hidden_states:
+             features, _ = self.wav2vec2.extract_features(input_values, lengths=lengths)
+             hidden_states = tuple(features)
+             last_hidden_state = features[-1]
+         else:
+             last_hidden_state, _ = self.wav2vec2(input_values, lengths=lengths)
+             hidden_states = None
+
+         if not return_dict:
+             return (last_hidden_state, hidden_states) if hidden_states else (last_hidden_state,)
+
+         return BaseModelOutput(
+             last_hidden_state=last_hidden_state,
+             hidden_states=hidden_states,
+         )
+
+     def extract_features(self, input_values: torch.Tensor):
+         """Extract features from all layers."""
+         if input_values.dim() == 1:
+             input_values = input_values.unsqueeze(0)
+         features, _ = self.wav2vec2.extract_features(input_values)
+         return tuple(features)
+
+     @classmethod
+     def _load_pretrained_model_low_mem(cls, *args, **kwargs):
+         """Override to handle custom loading."""
+         return super()._load_pretrained_model_low_mem(*args, **kwargs)
pytorch_model.bin → delulu_hf_model/pytorch_model.bin RENAMED
File without changes
upload_metadata.json → delulu_hf_model/upload_metadata.json RENAMED
File without changes
load_delulu.py DELETED
@@ -1,94 +0,0 @@
- """
- DELULU Model - Minimal Loading Example
-
- This file shows how to load DELULU weights with torchaudio.
- For full Hugging Face Transformers integration, see the modeling_delulu.py file.
- """
-
- import torch
- from torchaudio.models.wav2vec2 import wav2vec2_model
-
- # DELULU configuration
- DELULU_CONV_CONFIG = [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
-
- def load_delulu(checkpoint_path: str = None, weights_path: str = None):
-     """
-     Load DELULU model.
-
-     Args:
-         checkpoint_path: Path to original .ckpt file (PyTorch Lightning format)
-         weights_path: Path to pytorch_model.bin (Hugging Face format)
-
-     Returns:
-         DELULU model ready for inference
-     """
-     model = wav2vec2_model(
-         extractor_mode="group_norm",
-         extractor_conv_layer_config=DELULU_CONV_CONFIG,
-         extractor_conv_bias=False,
-         encoder_embed_dim=768,
-         encoder_projection_dropout=0.1,
-         encoder_pos_conv_kernel=128,
-         encoder_pos_conv_groups=16,
-         encoder_num_layers=12,
-         encoder_num_heads=12,
-         encoder_attention_dropout=0.1,
-         encoder_ff_interm_features=3072,
-         encoder_ff_interm_dropout=0.1,
-         encoder_dropout=0.1,
-         encoder_layer_norm_first=False,
-         encoder_layer_drop=0.05,
-         aux_num_out=None,
-     )
-
-     if checkpoint_path:
-         # Load from original Lightning checkpoint
-         checkpoint = torch.load(checkpoint_path, map_location="cpu")
-         state_dict = checkpoint.get("state_dict", checkpoint)
-
-         # Clean keys
-         new_state_dict = {}
-         for k, v in state_dict.items():
-             if "model.wav2vec2" in k:
-                 new_state_dict[k.replace("model.wav2vec2.", "")] = v
-             elif not k.startswith("aux"):
-                 new_state_dict[k] = v
-
-         model.load_state_dict(new_state_dict, strict=False)
-
-     elif weights_path:
-         # Load from Hugging Face format
-         state_dict = torch.load(weights_path, map_location="cpu")
-         model.load_state_dict(state_dict, strict=False)
-
-     return model
-
-
- def extract_features(model, waveform: torch.Tensor) -> torch.Tensor:
-     """
-     Extract speaker features from audio waveform.
-
-     Args:
-         model: DELULU model
-         waveform: Audio tensor of shape (batch, samples) at 16kHz
-
-     Returns:
-         Features of shape (batch, time, 768)
-     """
-     model.eval()
-     with torch.no_grad():
-         features, _ = model.extract_features(waveform)
-         # Return last layer features
-         return features[-1]
-
-
- if __name__ == "__main__":
-     # Example usage
-     import sys
-
-     if len(sys.argv) > 1:
-         model = load_delulu(weights_path=sys.argv[1])
-         print(f"Model loaded successfully!")
-         print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
-     else:
-         print("Usage: python load_delulu.py path/to/pytorch_model.bin")
modeling_delulu.py CHANGED
@@ -1,12 +1,27 @@
- """DELULU Model"""

  import torch
  import torch.nn as nn
  from typing import Optional, Tuple, Union
  from transformers import PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutput
  from .configuration_delulu import DELULUConfig

  try:
      from torchaudio.models.wav2vec2 import wav2vec2_model
      TORCHAUDIO_AVAILABLE = True
@@ -14,41 +29,79 @@ except ImportError:
      TORCHAUDIO_AVAILABLE = False


  class DELULUModel(PreTrainedModel):
      """
      DELULU Model for speaker-aware speech representation learning.

      Example:
      ```python
      from transformers import AutoModel
      import torch

      model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
-     waveform = torch.randn(1, 16000)  # 1 second at 16kHz
-     outputs = model(waveform)
-     features = outputs.last_hidden_state
      ```
      """

      config_class = DELULUConfig
      base_model_prefix = "delulu"
      main_input_name = "input_values"

      def __init__(self, config: DELULUConfig):
          super().__init__(config)
          self.config = config

          if not TORCHAUDIO_AVAILABLE:
-             raise ImportError("torchaudio is required. Install with: pip install torchaudio")

-         # Build conv config
          conv_layer_config = list(zip(
              config.conv_dim,
              config.conv_kernel,
              config.conv_stride
          ))

-         # Create torchaudio model
          self.wav2vec2 = wav2vec2_model(
              extractor_mode=config.extractor_mode,
              extractor_conv_layer_config=conv_layer_config,
@@ -68,60 +121,213 @@ class DELULUModel(PreTrainedModel):
              aux_num_out=None,
          )

          self.post_init()

-     def _init_weights(self, module):
-         """Initialize weights."""
-         pass  # Handled by torchaudio
-
      def forward(
          self,
          input_values: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, BaseModelOutput]:
          """
          Args:
-             input_values: Audio waveform (batch, samples) at 16kHz
-             attention_mask: Optional attention mask
-             output_hidden_states: Whether to return all hidden states
-             return_dict: Whether to return BaseModelOutput
          """
-         return_dict = return_dict if return_dict is not None else True
-         output_hidden_states = output_hidden_states if output_hidden_states is not None else False

          if input_values.dim() == 1:
              input_values = input_values.unsqueeze(0)

          lengths = None
          if attention_mask is not None:
              lengths = attention_mask.sum(dim=-1)

          if output_hidden_states:
-             features, _ = self.wav2vec2.extract_features(input_values, lengths=lengths)
              hidden_states = tuple(features)
              last_hidden_state = features[-1]
          else:
-             last_hidden_state, _ = self.wav2vec2(input_values, lengths=lengths)
              hidden_states = None

          if not return_dict:
-             return (last_hidden_state, hidden_states) if hidden_states else (last_hidden_state,)

-         return BaseModelOutput(
              last_hidden_state=last_hidden_state,
              hidden_states=hidden_states,
          )

-     def extract_features(self, input_values: torch.Tensor):
-         """Extract features from all layers."""
          if input_values.dim() == 1:
              input_values = input_values.unsqueeze(0)
-         features, _ = self.wav2vec2.extract_features(input_values)
          return tuple(features)

-     @classmethod
-     def _load_pretrained_model_low_mem(cls, *args, **kwargs):
-         """Override to handle custom loading."""
-         return super()._load_pretrained_model_low_mem(*args, **kwargs)
+ """
+ DELULU Model
+
+ DELULU (Discriminative Embedding Learning Using Latent Units) is a speaker-aware
+ self-supervised speech foundational model based on HuBERT architecture.
+
+ Paper: https://arxiv.org/abs/2510.17662
+ Authors: Massa Baali, Rita Singh, Bhiksha Raj
+
+ This implementation wraps torchaudio's wav2vec2_model for compatibility with
+ Hugging Face's AutoModel interface.
+ """

  import torch
  import torch.nn as nn
  from typing import Optional, Tuple, Union
+ from dataclasses import dataclass
+
  from transformers import PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutput
+
  from .configuration_delulu import DELULUConfig

+ # Try to import torchaudio
  try:
      from torchaudio.models.wav2vec2 import wav2vec2_model
      TORCHAUDIO_AVAILABLE = True
  except ImportError:
      TORCHAUDIO_AVAILABLE = False


+ @dataclass
+ class DELULUOutput(BaseModelOutput):
+     """
+     Output class for DELULU model.
+
+     Args:
+         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+             Sequence of hidden-states at the output of the last layer of the model.
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for each layer)
+             of shape `(batch_size, sequence_length, hidden_size)`.
+         attentions (`tuple(torch.FloatTensor)`, *optional*):
+             Attention weights (not available for torchaudio backend).
+         extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
+             Features from the convolutional feature extractor.
+     """
+     last_hidden_state: torch.FloatTensor = None
+     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+     extract_features: Optional[torch.FloatTensor] = None
+
+
  class DELULUModel(PreTrainedModel):
      """
      DELULU Model for speaker-aware speech representation learning.

+     This model wraps torchaudio's wav2vec2_model with DELULU's custom configuration
+     (modified convolutional strides for 16ms frame shift).
+
      Example:
      ```python
      from transformers import AutoModel
      import torch

+     # Load model
      model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
+     model.eval()
+
+     # Process audio (16kHz, mono)
+     waveform = torch.randn(1, 16000)  # 1 second of audio
+
+     with torch.no_grad():
+         outputs = model(waveform)
+         features = outputs.last_hidden_state  # [1, T, 768]
+
+     # For speaker verification, use mean pooling
+     speaker_embedding = features.mean(dim=1)  # [1, 768]
      ```
      """

      config_class = DELULUConfig
      base_model_prefix = "delulu"
      main_input_name = "input_values"
+     supports_gradient_checkpointing = False

      def __init__(self, config: DELULUConfig):
          super().__init__(config)
          self.config = config

          if not TORCHAUDIO_AVAILABLE:
+             raise ImportError(
+                 "torchaudio is required for DELULU model. "
+                 "Install with: pip install torchaudio"
+             )

+         # Build convolutional layer config from DELULU config
          conv_layer_config = list(zip(
              config.conv_dim,
              config.conv_kernel,
              config.conv_stride
          ))

+         # Create the underlying torchaudio model
          self.wav2vec2 = wav2vec2_model(
              extractor_mode=config.extractor_mode,
              extractor_conv_layer_config=conv_layer_config,
              extractor_conv_bias=config.conv_bias,
              encoder_embed_dim=config.hidden_size,
              encoder_projection_dropout=config.feat_proj_dropout,
              encoder_pos_conv_kernel=config.num_conv_pos_embeddings,
              encoder_pos_conv_groups=config.num_conv_pos_embedding_groups,
              encoder_num_layers=config.num_hidden_layers,
              encoder_num_heads=config.num_attention_heads,
              encoder_attention_dropout=config.attention_dropout,
              encoder_ff_interm_features=config.intermediate_size,
              encoder_ff_interm_dropout=config.hidden_dropout,
              encoder_dropout=config.hidden_dropout,
              encoder_layer_norm_first=config.do_stable_layer_norm,
              encoder_layer_drop=config.layer_drop,
              aux_num_out=None,
          )

+         # Initialize weights
          self.post_init()

      def forward(
          self,
          input_values: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, DELULUOutput]:
          """
+         Forward pass of DELULU model.
+
          Args:
+             input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+                 Raw audio waveform at 16kHz sampling rate.
+             attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Mask to avoid performing attention on padding. Not used in current implementation.
+             output_hidden_states (`bool`, *optional*):
+                 Whether to return all hidden states.
+             output_attentions (`bool`, *optional*):
+                 Whether to return attention weights. Not supported with torchaudio backend.
+             return_dict (`bool`, *optional*):
+                 Whether to return a `DELULUOutput` instead of a tuple.
+
+         Returns:
+             `DELULUOutput` or `tuple`: Model outputs.
          """
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None
+             else self.config.output_hidden_states if hasattr(self.config, 'output_hidden_states')
+             else False
+         )
+         return_dict = (
+             return_dict if return_dict is not None
+             else self.config.use_return_dict if hasattr(self.config, 'use_return_dict')
+             else True
+         )

+         # Ensure input is 2D: (batch, samples)
          if input_values.dim() == 1:
              input_values = input_values.unsqueeze(0)

+         # Handle lengths for torchaudio model
          lengths = None
          if attention_mask is not None:
              lengths = attention_mask.sum(dim=-1)

+         # Extract features using torchaudio model
          if output_hidden_states:
+             # Get all layer outputs
+             features, lengths_out = self.wav2vec2.extract_features(
+                 input_values,
+                 lengths=lengths
+             )
+             # features is a list of tensors, one per layer
              hidden_states = tuple(features)
              last_hidden_state = features[-1]
          else:
+             # Just get final output
+             outputs, lengths_out = self.wav2vec2(input_values, lengths=lengths)
+             last_hidden_state = outputs
              hidden_states = None

+         # Get convolutional features (before transformer)
+         extract_features = self.wav2vec2.feature_extractor(input_values, lengths)[0]
+
          if not return_dict:
+             outputs = (last_hidden_state,)
+             if output_hidden_states:
+                 outputs = outputs + (hidden_states,)
+             return outputs

+         return DELULUOutput(
              last_hidden_state=last_hidden_state,
              hidden_states=hidden_states,
+             attentions=None,  # torchaudio doesn't expose attention weights
+             extract_features=extract_features,
          )

+     def extract_features(
+         self,
+         input_values: torch.Tensor,
+         lengths: Optional[torch.Tensor] = None
+     ) -> Tuple[torch.Tensor, ...]:
+         """
+         Extract features from all layers.
+
+         Args:
+             input_values: Audio waveform of shape (batch, samples)
+             lengths: Optional lengths for each sample in batch
+
+         Returns:
+             Tuple of tensors, one per layer (including CNN output)
+         """
          if input_values.dim() == 1:
              input_values = input_values.unsqueeze(0)
+
+         features, _ = self.wav2vec2.extract_features(input_values, lengths=lengths)
          return tuple(features)

+     def get_speaker_embedding(
+         self,
+         input_values: torch.Tensor,
+         pooling: str = "mean"
+     ) -> torch.Tensor:
+         """
+         Extract speaker embedding from audio.
+
+         Args:
+             input_values: Audio waveform of shape (batch, samples)
+             pooling: Pooling method - "mean", "max", or "first"
+
+         Returns:
+             Speaker embedding of shape (batch, hidden_size)
+         """
+         outputs = self.forward(input_values, return_dict=True)
+         features = outputs.last_hidden_state
+
+         if pooling == "mean":
+             return features.mean(dim=1)
+         elif pooling == "max":
+             return features.max(dim=1).values
+         elif pooling == "first":
+             return features[:, 0, :]
+         else:
+             raise ValueError(f"Unknown pooling method: {pooling}")
+
+     def _init_weights(self, module):
+         """Initialize weights - mostly handled by torchaudio."""
+         pass
+
+
+ class DELULUForSequenceClassification(PreTrainedModel):
+     """
+     DELULU with a classification head for speaker verification and other tasks.
+
+     Example:
+     ```python
+     from transformers import AutoModel
+
+     model = AutoModel.from_pretrained(
+         "cmu-mlsp/DELULU",
+         trust_remote_code=True,
+         num_labels=1251  # Number of speakers in VoxCeleb1
+     )
+     ```
+     """
+
+     config_class = DELULUConfig
+     base_model_prefix = "delulu"
+
+     def __init__(self, config: DELULUConfig):
+         super().__init__(config)
+
+         self.delulu = DELULUModel(config)
+         self.projector = nn.Linear(config.hidden_size, config.hidden_size)
+
+         num_labels = getattr(config, 'num_labels', None)
+         if num_labels:
+             self.classifier = nn.Linear(config.hidden_size, num_labels)
+         else:
+             self.classifier = None
+
+         self.post_init()
+
+     def forward(
+         self,
+         input_values: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         return_dict = return_dict if return_dict is not None else True
+
+         outputs = self.delulu(
+             input_values,
+             attention_mask=attention_mask,
+             return_dict=True
+         )
+
+         # Pool features
+         hidden_states = outputs.last_hidden_state
+         pooled = hidden_states.mean(dim=1)
+
+         # Project
+         embeddings = self.projector(pooled)
+
+         # Classify if head exists
+         logits = None
+         if self.classifier is not None:
+             logits = self.classifier(embeddings)
+
+         loss = None
+         if labels is not None and logits is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (logits, embeddings) + (outputs.last_hidden_state,)
+             return ((loss,) + output) if loss is not None else output
+
+         return {
+             "loss": loss,
+             "logits": logits,
+             "embeddings": embeddings,
+             "last_hidden_state": outputs.last_hidden_state,
+         }
+
+
+ # Register for auto classes
+ DELULUConfig.register_for_auto_class()
+ DELULUModel.register_for_auto_class("AutoModel")
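The `get_speaker_embedding` method added above mean-pools frame-level features into one fixed-size vector per utterance; a verification trial then typically compares two such embeddings with cosine similarity against a tuned threshold. A minimal, model-free sketch of that scoring step (the toy embeddings and the 0.5 threshold are illustrative assumptions, not values from the paper):

```python
import math

def cosine_score(emb_a, emb_b):
    """Cosine similarity between two speaker embeddings (plain lists of floats)."""
    dot = sum(a * b for a, b in zip(emb_a, emb_b))
    norm_a = math.sqrt(sum(a * a for a in emb_a))
    norm_b = math.sqrt(sum(b * b for b in emb_b))
    return dot / (norm_a * norm_b)

def same_speaker(emb_a, emb_b, threshold=0.5):
    """Accept the trial if cosine similarity exceeds the (tuned) threshold."""
    return cosine_score(emb_a, emb_b) >= threshold

# Toy 2-d embeddings: same direction scores 1.0, orthogonal scores 0.0
print(cosine_score([1.0, 0.0], [2.0, 0.0]))   # 1.0
print(same_speaker([1.0, 0.0], [0.0, 1.0]))   # False
```

In practice the two embeddings would come from `model.get_speaker_embedding(waveform)` for the enrollment and test utterances, and the threshold would be chosen on a development set (e.g. at the equal-error-rate operating point).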
preprocessor_config.json DELETED
@@ -1,9 +0,0 @@
- {
-   "do_normalize": true,
-   "feature_extractor_type": "Wav2Vec2FeatureExtractor",
-   "feature_size": 1,
-   "padding_side": "right",
-   "padding_value": 0.0,
-   "return_attention_mask": true,
-   "sampling_rate": 16000
- }
upload_delulu_to_hf.py ADDED
@@ -0,0 +1,593 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ DELULU Model Upload Script for Hugging Face Hub
4
+ ================================================
5
+
6
+ Production-ready script to upload DELULU (Discriminative Embedding Learning Using
7
+ Latent Units) model checkpoints to Hugging Face with safety checks, versioning,
8
+ and best practices.
9
+
10
+ Author: Massa Baali
11
+ Model: DELULU - Speaker-Aware Self-Supervised Speech Foundational Model
12
+ Paper: https://arxiv.org/abs/2510.17662
13
+
14
+ Usage:
15
+ python upload_delulu_to_hf.py --checkpoint-dir ./checkpoints --repo-id username/DELULU
16
+
17
+ # With all options:
18
+ python upload_delulu_to_hf.py \
19
+ --checkpoint-dir ./checkpoints \
20
+ --repo-id username/DELULU \
21
+ --version v1.0.0 \
22
+ --tags speaker-verification speech-ssl hubert \
23
+ --private \
24
+ --dry-run
25
+ """
26
+
27
+ import argparse
28
+ import hashlib
29
+ import json
30
+ import logging
31
+ import os
32
+ import sys
33
+ from dataclasses import dataclass, field
34
+ from datetime import datetime
35
+ from pathlib import Path
36
+ from typing import Optional
37
+
38
+ try:
39
+ from huggingface_hub import (
40
+ HfApi,
41
+ create_repo,
42
+ upload_folder,
43
+ login,
44
+ whoami,
45
+ RepoUrl,
46
+ )
47
+ from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
48
+ except ImportError:
49
+ print("Error: huggingface_hub not installed. Install with: pip install huggingface_hub")
50
+ sys.exit(1)
51
+
52
+ # Configure logging
53
+ logging.basicConfig(
54
+ level=logging.INFO,
55
+ format="%(asctime)s - %(levelname)s - %(message)s",
56
+ datefmt="%Y-%m-%d %H:%M:%S",
57
+ )
58
+ logger = logging.getLogger(__name__)
59
+
60
+
61
+ # =============================================================================
62
+ # Configuration
63
+ # =============================================================================
64
+
65
+ @dataclass
66
+ class DELULUConfig:
67
+ """Configuration for DELULU model architecture.
68
+
69
+ DELULU uses HuBERT architecture with modified convolutional feature extractor
70
+ strides for 16ms frame shift (optimized for speaker verification).
71
+ """
72
+ # Model architecture (HuBERT-based)
73
+ model_type: str = "hubert"
74
+
75
+ # Modified convolutional feature extractor configuration
76
+ # Standard HuBERT: [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
77
+ # DELULU: [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
78
+ conv_dim: list = field(default_factory=lambda: [512, 512, 512, 512, 512, 512, 512])
79
+ conv_kernel: list = field(default_factory=lambda: [10, 3, 3, 3, 3, 2, 2])
80
+ conv_stride: list = field(default_factory=lambda: [4, 2, 2, 2, 2, 2, 2]) # Key difference!
81
+
82
+ # Transformer configuration
83
+ hidden_size: int = 768
84
+ num_hidden_layers: int = 12
85
+ num_attention_heads: int = 12
86
+ intermediate_size: int = 3072
87
+
88
+ # Training configuration
89
+ frame_shift_ms: int = 16 # Optimal for speaker verification
90
+ sampling_rate: int = 16000
91
+
92
+ # Clustering configuration (ReDimNet-guided)
93
+ num_clusters: int = 256
94
+ cluster_feature_dim: int = 2304 # ReDimNet frame-level embedding dimension
95
+
96
+ def to_dict(self) -> dict:
97
+ """Convert config to dictionary for serialization."""
98
+ return {
99
+ "model_type": self.model_type,
100
+ "conv_dim": self.conv_dim,
101
+ "conv_kernel": self.conv_kernel,
102
+ "conv_stride": self.conv_stride,
103
+ "hidden_size": self.hidden_size,
104
+ "num_hidden_layers": self.num_hidden_layers,
105
+ "num_attention_heads": self.num_attention_heads,
106
+ "intermediate_size": self.intermediate_size,
107
+ "frame_shift_ms": self.frame_shift_ms,
108
+ "sampling_rate": self.sampling_rate,
109
+ "num_clusters": self.num_clusters,
110
+ "cluster_feature_dim": self.cluster_feature_dim,
111
+ "architectures": ["DELULUModel"],
112
+ "auto_map": {
113
+ "AutoModel": "modeling_delulu.DELULUModel",
114
+ "AutoConfig": "configuration_delulu.DELULUConfig"
115
+ }
116
+ }
117
+
118
+
119
+ @dataclass
120
+ class UploadConfig:
121
+ """Configuration for the upload process."""
122
+ checkpoint_dir: Path
123
+ repo_id: str
124
+ version: Optional[str] = None
125
+ tags: list = field(default_factory=list)
126
+ private: bool = False
127
+ dry_run: bool = False
128
+ create_if_missing: bool = True
129
+ commit_message: Optional[str] = None
130
+
131
+ # Safety settings
132
+ verify_checksums: bool = True
133
+ max_file_size_gb: float = 10.0
134
+ required_files: list = field(default_factory=lambda: ["pytorch_model.bin", "config.json"])
135
+
136
+ def __post_init__(self):
137
+ self.checkpoint_dir = Path(self.checkpoint_dir)
138
+
139
+
140
+ # =============================================================================
141
+ # Safety Checks
142
+ # =============================================================================
143
+
144
+ class SafetyChecker:
145
+ """Performs safety checks before upload."""
146
+
147
+ def __init__(self, config: UploadConfig):
148
+ self.config = config
149
+ self.errors: list[str] = []
150
+ self.warnings: list[str] = []
151
+
152
+ def check_all(self) -> bool:
153
+ """Run all safety checks. Returns True if all pass."""
154
+ self._check_directory_exists()
155
+ self._check_required_files()
156
+ self._check_file_sizes()
157
+ self._check_no_sensitive_data()
158
+ self._check_checkpoint_integrity()
159
+
160
+ # Log results
161
+ for warning in self.warnings:
162
+ logger.warning(f"⚠️ {warning}")
163
+ for error in self.errors:
164
+ logger.error(f"❌ {error}")
165
+
166
+ if self.errors:
167
+             logger.error(f"Safety checks failed with {len(self.errors)} error(s)")
+             return False
+
+         logger.info("✅ All safety checks passed")
+         return True
+
+     def _check_directory_exists(self):
+         """Verify checkpoint directory exists and is accessible."""
+         if not self.config.checkpoint_dir.exists():
+             self.errors.append(f"Checkpoint directory not found: {self.config.checkpoint_dir}")
+         elif not self.config.checkpoint_dir.is_dir():
+             self.errors.append(f"Path is not a directory: {self.config.checkpoint_dir}")
+
+     def _check_required_files(self):
+         """Check that required model files exist."""
+         if not self.config.checkpoint_dir.exists():
+             return
+
+         for required_file in self.config.required_files:
+             file_path = self.config.checkpoint_dir / required_file
+             # Also check for the .safetensors variant
+             safetensors_variant = required_file.replace(".bin", ".safetensors")
+             safetensors_path = self.config.checkpoint_dir / safetensors_variant
+
+             if not file_path.exists() and not safetensors_path.exists():
+                 # Special handling for model weights - either .bin or .safetensors is fine
+                 if "model" in required_file:
+                     self.warnings.append(
+                         f"Model file not found: {required_file} or {safetensors_variant}. "
+                         "Will look for alternative formats."
+                     )
+                 else:
+                     self.errors.append(f"Required file not found: {required_file}")
+
+     def _check_file_sizes(self):
+         """Verify no files exceed the maximum size limit."""
+         if not self.config.checkpoint_dir.exists():
+             return
+
+         max_size_bytes = self.config.max_file_size_gb * 1024 * 1024 * 1024
+
+         for file_path in self.config.checkpoint_dir.rglob("*"):
+             if file_path.is_file():
+                 size = file_path.stat().st_size
+                 if size > max_size_bytes:
+                     self.errors.append(
+                         f"File exceeds {self.config.max_file_size_gb}GB limit: "
+                         f"{file_path.name} ({size / 1024 / 1024 / 1024:.2f}GB)"
+                     )
+
+     def _check_no_sensitive_data(self):
+         """Check for potentially sensitive files that shouldn't be uploaded."""
+         sensitive_patterns = [
+             ".env", ".secret", "credentials", "password", "api_key", "token",
+             ".git", "__pycache__", ".pyc", ".DS_Store"
+         ]
+
+         if not self.config.checkpoint_dir.exists():
+             return
+
+         for file_path in self.config.checkpoint_dir.rglob("*"):
+             file_name = file_path.name.lower()
+             for pattern in sensitive_patterns:
+                 if pattern in file_name:
+                     self.warnings.append(
+                         f"Potentially sensitive file detected: {file_path.name}. "
+                         "Consider adding to .gitignore or removing before upload."
+                     )
+                     break
+
+     def _check_checkpoint_integrity(self):
+         """Basic integrity check for PyTorch checkpoint files."""
+         if not self.config.checkpoint_dir.exists():
+             return
+
+         try:
+             import torch
+
+             for file_path in self.config.checkpoint_dir.glob("*.bin"):
+                 try:
+                     # weights_only=True avoids executing arbitrary pickled code
+                     # while still confirming the checkpoint deserializes cleanly
+                     torch.load(file_path, map_location="cpu", weights_only=True)
+                     logger.info(f"✓ Checkpoint integrity verified: {file_path.name}")
+                 except Exception as e:
+                     self.errors.append(f"Corrupted checkpoint file: {file_path.name} - {e}")
+         except ImportError:
+             self.warnings.append("PyTorch not installed, skipping checkpoint integrity check")
+
+
+ # =============================================================================
+ # Checksum Utilities
+ # =============================================================================
+
+ def compute_file_checksum(file_path: Path, algorithm: str = "sha256") -> str:
+     """Compute a checksum for a file, reading in 8KB chunks."""
+     hash_func = hashlib.new(algorithm)
+
+     with open(file_path, "rb") as f:
+         for chunk in iter(lambda: f.read(8192), b""):
+             hash_func.update(chunk)
+
+     return hash_func.hexdigest()
+
+
+ def generate_checksums(directory: Path) -> dict:
+     """Generate checksums for all files in a directory tree."""
+     checksums = {}
+
+     for file_path in directory.rglob("*"):
+         if file_path.is_file():
+             relative_path = file_path.relative_to(directory)
+             checksums[str(relative_path)] = {
+                 "sha256": compute_file_checksum(file_path, "sha256"),
+                 "size_bytes": file_path.stat().st_size
+             }
+
+     return checksums
+
+
+ def save_checksums(checksums: dict, output_path: Path):
+     """Save checksums to a JSON file."""
+     with open(output_path, "w") as f:
+         json.dump(checksums, f, indent=2)
+     logger.info(f"Checksums saved to: {output_path}")
+
+
+ # =============================================================================
+ # Upload Manager
+ # =============================================================================
+
+ class DELULUUploader:
+     """Handles uploading the DELULU model to the Hugging Face Hub."""
+
+     def __init__(self, upload_config: UploadConfig):
+         self.config = upload_config
+         self.api = HfApi()
+         self.model_config = DELULUConfig()
+
+     def authenticate(self) -> bool:
+         """Verify authentication with the Hugging Face Hub."""
+         try:
+             user_info = whoami()
+             logger.info(f"✅ Authenticated as: {user_info['name']}")
+             return True
+         except Exception as e:
+             logger.error(f"❌ Authentication failed: {e}")
+             logger.info("Run 'huggingface-cli login' or set the HF_TOKEN environment variable")
+             return False
+
+     def prepare_upload_directory(self) -> Path:
+         """Prepare files for upload, including config and checksums."""
+         upload_dir = self.config.checkpoint_dir
+
+         # Generate and save config.json if not already present
+         config_path = upload_dir / "config.json"
+         if not config_path.exists():
+             logger.info("Generating config.json...")
+             with open(config_path, "w") as f:
+                 json.dump(self.model_config.to_dict(), f, indent=2)
+
+         # Generate checksums
+         if self.config.verify_checksums:
+             logger.info("Generating checksums...")
+             checksums = generate_checksums(upload_dir)
+             save_checksums(checksums, upload_dir / "checksums.json")
+
+         # Create upload metadata
+         metadata = {
+             "upload_timestamp": datetime.utcnow().isoformat(),
+             "version": self.config.version,
+             "uploader_script_version": "1.0.0",
+             "model_type": "DELULU",
+             "base_architecture": "HuBERT"
+         }
+
+         metadata_path = upload_dir / "upload_metadata.json"
+         with open(metadata_path, "w") as f:
+             json.dump(metadata, f, indent=2)
+
+         return upload_dir
+
+     def create_or_verify_repo(self) -> bool:
+         """Create the repository if it doesn't exist, or verify access."""
+         try:
+             # Check if the repo already exists
+             self.api.repo_info(repo_id=self.config.repo_id, repo_type="model")
+             logger.info(f"✅ Repository exists: {self.config.repo_id}")
+             return True
+
+         except RepositoryNotFoundError:
+             if self.config.create_if_missing:
+                 logger.info(f"Creating repository: {self.config.repo_id}")
+
+                 if self.config.dry_run:
+                     logger.info("[DRY RUN] Would create repository")
+                     return True
+
+                 try:
+                     repo_url: RepoUrl = create_repo(
+                         repo_id=self.config.repo_id,
+                         repo_type="model",
+                         private=self.config.private,
+                         exist_ok=True
+                     )
+                     logger.info(f"✅ Repository created: {repo_url}")
+                     return True
+                 except HfHubHTTPError as e:
+                     logger.error(f"❌ Failed to create repository: {e}")
+                     return False
+             else:
+                 logger.error(f"❌ Repository not found: {self.config.repo_id}")
+                 return False
+
+         except Exception as e:
+             logger.error(f"❌ Error accessing repository: {e}")
+             return False
+
+     def upload(self) -> bool:
+         """Execute the full upload process."""
+         logger.info("=" * 60)
+         logger.info("DELULU Model Upload to Hugging Face Hub")
+         logger.info("=" * 60)
+
+         # Step 1: Authenticate
+         if not self.authenticate():
+             return False
+
+         # Step 2: Safety checks
+         safety_checker = SafetyChecker(self.config)
+         if not safety_checker.check_all():
+             return False
+
+         # Step 3: Create/verify repository
+         if not self.create_or_verify_repo():
+             return False
+
+         # Step 4: Prepare upload directory
+         upload_dir = self.prepare_upload_directory()
+
+         # Step 5: Generate commit message
+         commit_message = self.config.commit_message or self._generate_commit_message()
+
+         # Step 6: Execute upload
+         if self.config.dry_run:
+             logger.info("[DRY RUN] Would upload the following files:")
+             for file_path in upload_dir.rglob("*"):
+                 if file_path.is_file():
+                     size_mb = file_path.stat().st_size / 1024 / 1024
+                     logger.info(f"  - {file_path.relative_to(upload_dir)} ({size_mb:.2f} MB)")
+             logger.info(f"[DRY RUN] Commit message: {commit_message}")
+             return True
+
+         logger.info("Starting upload...")
+         try:
+             upload_folder(
+                 folder_path=str(upload_dir),
+                 repo_id=self.config.repo_id,
+                 repo_type="model",
+                 commit_message=commit_message,
+                 ignore_patterns=[
+                     "*.pyc", "__pycache__", ".git", ".DS_Store",
+                     "*.log", "wandb", "runs"
+                 ]
+             )
+
+             logger.info("✅ Upload complete!")
+             logger.info(f"View model at: https://huggingface.co/{self.config.repo_id}")
+
+             # Create a version tag if one was specified
+             if self.config.version:
+                 self._create_version_tag()
+
+             return True
+
+         except Exception as e:
+             logger.error(f"❌ Upload failed: {e}")
+             return False
+
+     def _generate_commit_message(self) -> str:
+         """Generate a descriptive commit message."""
+         title = "Upload DELULU model checkpoint"
+         if self.config.version:
+             title += f" (version {self.config.version})"
+
+         body = [
+             "Model: DELULU - Speaker-Aware Self-Supervised Speech Model",
+             "Architecture: HuBERT with modified stride configuration",
+             "Frame shift: 16ms (optimized for speaker verification)",
+         ]
+
+         if self.config.tags:
+             body.append(f"Tags: {', '.join(self.config.tags)}")
+
+         return title + "\n\n" + "\n".join(body)
+
+     def _create_version_tag(self):
+         """Create a Git tag for the version."""
+         try:
+             self.api.create_tag(
+                 repo_id=self.config.repo_id,
+                 tag=self.config.version,
+                 tag_message=f"DELULU {self.config.version}",
+                 repo_type="model"
+             )
+             logger.info(f"✅ Created version tag: {self.config.version}")
+         except Exception as e:
+             logger.warning(f"⚠️ Could not create version tag: {e}")
+
+
+ # =============================================================================
+ # CLI Interface
+ # =============================================================================
+
+ def parse_args() -> argparse.Namespace:
+     """Parse command line arguments."""
+     parser = argparse.ArgumentParser(
+         description="Upload DELULU model checkpoints to Hugging Face Hub",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   # Basic upload
+   python upload_delulu_to_hf.py --checkpoint-dir ./checkpoints --repo-id username/DELULU
+
+   # Upload with version and tags
+   python upload_delulu_to_hf.py \\
+       --checkpoint-dir ./checkpoints \\
+       --repo-id username/DELULU \\
+       --version v1.0.0 \\
+       --tags speaker-verification speech-ssl hubert
+
+   # Dry run (no actual upload)
+   python upload_delulu_to_hf.py --checkpoint-dir ./checkpoints --repo-id username/DELULU --dry-run
+
+   # Private repository
+   python upload_delulu_to_hf.py --checkpoint-dir ./checkpoints --repo-id username/DELULU --private
+ """
+     )
+
+     # Required arguments
+     parser.add_argument(
+         "--checkpoint-dir", "-c",
+         type=str,
+         required=True,
+         help="Path to directory containing model checkpoints"
+     )
+     parser.add_argument(
+         "--repo-id", "-r",
+         type=str,
+         required=True,
+         help="Hugging Face repository ID (e.g., username/DELULU)"
+     )
+
+     # Optional arguments
+     parser.add_argument(
+         "--version", "-v",
+         type=str,
+         default=None,
+         help="Version tag for this upload (e.g., v1.0.0)"
+     )
+     parser.add_argument(
+         "--tags", "-t",
+         nargs="+",
+         default=["speaker-verification", "speech-ssl", "hubert", "self-supervised"],
+         help="Tags to add to the model (space-separated)"
+     )
+     parser.add_argument(
+         "--private",
+         action="store_true",
+         help="Create as a private repository"
+     )
+     parser.add_argument(
+         "--dry-run",
+         action="store_true",
+         help="Simulate the upload without actually uploading"
+     )
+     parser.add_argument(
+         "--commit-message", "-m",
+         type=str,
+         default=None,
+         help="Custom commit message"
+     )
+     parser.add_argument(
+         "--no-verify-checksums",
+         action="store_true",
+         help="Skip checksum generation and verification"
+     )
+     parser.add_argument(
+         "--max-file-size",
+         type=float,
+         default=10.0,
+         help="Maximum file size in GB (default: 10.0)"
+     )
+     parser.add_argument(
+         "--no-create",
+         action="store_true",
+         help="Don't create the repository if it doesn't exist"
+     )
+
+     return parser.parse_args()
+
+
+ def main():
+     """Main entry point."""
+     args = parse_args()
+
+     # Create the upload configuration from CLI arguments
+     upload_config = UploadConfig(
+         checkpoint_dir=args.checkpoint_dir,
+         repo_id=args.repo_id,
+         version=args.version,
+         tags=args.tags,
+         private=args.private,
+         dry_run=args.dry_run,
+         commit_message=args.commit_message,
+         verify_checksums=not args.no_verify_checksums,
+         max_file_size_gb=args.max_file_size,
+         create_if_missing=not args.no_create
+     )
+
+     # Create the uploader, execute, and exit with a matching status code
+     uploader = DELULUUploader(upload_config)
+     success = uploader.upload()
+
+     sys.exit(0 if success else 1)
+
+
+ if __name__ == "__main__":
+     main()