Upload folder using huggingface_hub
- DELULU_MODEL_CARD.md +326 -0
- config.json +7 -32
- configuration_delulu.py +92 -6
- convert_delulu_fixed.py +426 -0
- convert_delulu_to_hf.py +697 -0
- checksums.json → delulu_hf_model/checksums.json +0 -0
- delulu_hf_model/config.json +60 -0
- delulu_hf_model/configuration_delulu.py +73 -0
- model.safetensors → delulu_hf_model/model.safetensors +0 -0
- delulu_hf_model/modeling_delulu.py +127 -0
- pytorch_model.bin → delulu_hf_model/pytorch_model.bin +0 -0
- upload_metadata.json → delulu_hf_model/upload_metadata.json +0 -0
- load_delulu.py +0 -94
- modeling_delulu.py +235 -29
- preprocessor_config.json +0 -9
- upload_delulu_to_hf.py +593 -0
DELULU_MODEL_CARD.md
ADDED
@@ -0,0 +1,326 @@
---
license: cc-by-nc-nd-4.0
language:
- en
library_name: transformers
tags:
- speaker-verification
- speaker-diarization
- speaker-profiling
- speech
- audio
- self-supervised-learning
- ssl
- hubert
- speech-representation
- pytorch
- deep-learning
datasets:
- librispeech_asr
metrics:
- eer
pipeline_tag: audio-classification
model-index:
- name: DELULU
  results:
  - task:
      type: speaker-verification
      name: Speaker Verification
    dataset:
      type: VoxCeleb1-O
      name: VoxCeleb1-O
    metrics:
    - type: eer
      value: 13.52
      name: Equal Error Rate (Upstream)
---

# DELULU: Discriminative Embedding Learning Using Latent Units

<div align="center">

**A Speaker-Aware Self-Supervised Speech Foundational Model**

[arXiv:2510.17662](https://arxiv.org/abs/2510.17662)
[License: CC BY-NC-ND 4.0](https://creativecommons.org/licenses/by-nc-nd/4.0/)

</div>

## Model Description

**DELULU** (Discriminative Embedding Learning Using Latent Units) is a speaker-aware self-supervised speech foundational model that addresses a critical limitation of existing SSL models: their inability to capture the speaker-discriminative features essential for verification, diarization, and profiling applications.

While conventional SSL models like HuBERT, wav2vec 2.0, and WavLM excel at content-driven tasks such as automatic speech recognition, they learn representations optimized for phonetic and linguistic content, inadvertently discarding speaker identity information. DELULU bridges this gap by integrating external speaker supervision into the pseudo-label generation process.

### Key Innovation

DELULU introduces a novel approach to self-supervised speech learning by leveraging **frame-level embeddings from ReDimNet**, a state-of-the-art speaker verification model, to guide the k-means clustering step during pre-training. This introduces a strong **speaker-discriminative inductive bias** that aligns representation learning with speaker identity, a fundamental shift from content-focused SSL paradigms.

## Architecture

DELULU is based on the HuBERT architecture with a **modified convolutional feature extractor** optimized for speaker verification:

### Convolutional Feature Extractor

| Layer | Channels | Kernel Size | Stride |
|-------|----------|-------------|--------|
| 1 | 512 | 10 | **4** |
| 2 | 512 | 3 | 2 |
| 3 | 512 | 3 | 2 |
| 4 | 512 | 3 | 2 |
| 5 | 512 | 3 | 2 |
| 6 | 512 | 2 | 2 |
| 7 | 512 | 2 | 2 |

> **Key Difference**: The first layer uses stride **4** (vs. stride 5 in standard HuBERT), resulting in a **16ms frame shift** optimized for speaker verification tasks.
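The 16ms figure follows directly from the stride column: the product of the seven strides gives the hop size in samples, and dividing by the 16kHz sampling rate gives the frame shift. A quick sanity check of that arithmetic:

```python
# Hop size implied by the conv stack above: the product of the strides.
strides = [4, 2, 2, 2, 2, 2, 2]

hop_samples = 1
for s in strides:
    hop_samples *= s          # 4 * 2**6 = 256 samples

frame_shift_ms = hop_samples / 16000 * 1000
print(frame_shift_ms)         # 16.0 (HuBERT's stride-5 first layer gives 320 samples = 20 ms)
```
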
### Transformer Encoder

- **Hidden size**: 768
- **Attention heads**: 12
- **Layers**: 12
- **Intermediate size**: 3,072
- **Frame shift**: 16ms (vs. 20ms in HuBERT)

### Training Configuration

- **Clustering**: ReDimNet-guided k-means with k=256 clusters
- **Feature dimension**: 2,304 (ReDimNet frame-level embeddings)
- **Training objective**: Dual objective combining masked prediction + denoising
- **Pre-training data**: LibriSpeech 960h
- **Training steps**: 400k updates

## Performance

### Upstream Speaker Verification (Zero-Shot)

| Model | VoxCeleb1-O EER (%) |
|-------|---------------------|
| wav2vec 2.0 | 37.21 |
| HuBERT | 34.05 |
| WavLM | 29.84 |
| **DELULU** | **13.52** |

> **62% relative improvement** over standard HuBERT in equal error rate.
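Equal error rate (EER) is the operating point where the false-acceptance rate equals the false-rejection rate. For readers scoring their own trial lists, here is a generic sketch of the computation (not the paper's evaluation code):

```python
import numpy as np

def equal_error_rate(scores, labels):
    """Estimate EER from similarity scores and binary same-speaker labels."""
    order = np.argsort(scores)[::-1]       # sweep the threshold from high to low
    labels = np.asarray(labels)[order]
    n_neg = max((labels == 0).sum(), 1)
    n_pos = max((labels == 1).sum(), 1)
    far = np.cumsum(labels == 0) / n_neg   # false-acceptance rate at each threshold
    frr = 1 - np.cumsum(labels == 1) / n_pos  # false-rejection rate at each threshold
    idx = np.argmin(np.abs(far - frr))     # operating point where FAR ~= FRR
    return (far[idx] + frr[idx]) / 2
```
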
| 105 |
+
|
| 106 |
+
### Ablation: Why ReDimNet-Guided Clustering?
|
| 107 |
+
|
| 108 |
+
| Clustering Features | k | EER (%) |
|
| 109 |
+
|---------------------|---|---------|
|
| 110 |
+
| MFCC | 100 | 37.73 |
|
| 111 |
+
| HuBERT (pretrained) | 500 | 34.05 |
|
| 112 |
+
| **ReDimNet** | 256 | **13.53** |
|
| 113 |
+
|
| 114 |
+
ReDimNet-guided pseudo-labels provide a **60% relative improvement** over HuBERT's acoustic-only approach.
|
| 115 |
+
|
| 116 |
+
### Demographic Robustness
|
| 117 |
+
|
| 118 |
+
DELULU consistently outperforms baselines across all demographic groups, with particularly strong improvements for challenging subgroups:
|
| 119 |
+
|
| 120 |
+
| Demographic | HuBERT EER (%) | DELULU EER (%) | Improvement |
|
| 121 |
+
|-------------|----------------|----------------|-------------|
|
| 122 |
+
| Male 36-45 | 39.47 | 24.53 | 38% |
|
| 123 |
+
| All groups | Varies | Consistent | ✓ |
|
| 124 |
+
|
| 125 |
+
### Zero-Shot Speaker Profiling (DynamicSUPERB)
|
| 126 |
+
|
| 127 |
+
DELULU excels on multiple speaker-related tasks without fine-tuning:
|
| 128 |
+
- Gender classification
|
| 129 |
+
- Age estimation
|
| 130 |
+
- Accent recognition
|
| 131 |
+
- Speaker counting
|
| 132 |
+
- Spoof detection
|
| 133 |
+
|
| 134 |
+
## Intended Uses
|
| 135 |
+
|
| 136 |
+
### Primary Use Cases
|
| 137 |
+
|
| 138 |
+
1. **Speaker Verification**: Verify whether two speech samples are from the same speaker
|
| 139 |
+
2. **Speaker Diarization**: Segment and cluster speech by speaker identity
|
| 140 |
+
3. **Speaker Profiling**: Extract demographic attributes (age, gender, accent)
|
| 141 |
+
4. **Forensic Audio Analysis**: Speaker identification in investigative contexts
|
| 142 |
+
|
| 143 |
+
### Downstream Applications
|
| 144 |
+
|
| 145 |
+
- Voice biometrics and authentication systems
|
| 146 |
+
- Meeting transcription with speaker labels
|
| 147 |
+
- Call center analytics
|
| 148 |
+
- Content personalization based on speaker identity
|
| 149 |
+
- Multi-speaker dialogue systems
|
| 150 |
+
|
## How to Use

### Installation

```bash
pip install transformers torch torchaudio
```

### Loading the Model

```python
import torch
from transformers import AutoModel, AutoConfig

# Load DELULU model
model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
model.eval()
```

### Feature Extraction

```python
import torchaudio

# Load audio (16kHz sampling rate required)
waveform, sample_rate = torchaudio.load("audio.wav")
if sample_rate != 16000:
    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
    waveform = resampler(waveform)

# Mix down to mono if needed; the model expects (batch, samples)
if waveform.size(0) > 1:
    waveform = waveform.mean(dim=0, keepdim=True)

# Extract features
with torch.no_grad():
    outputs = model(waveform)
    # Use last hidden state for downstream tasks
    features = outputs.last_hidden_state  # [batch, time, 768]

# For speaker verification, typically use mean pooling
speaker_embedding = features.mean(dim=1)  # [batch, 768]
```

### Speaker Verification Example

```python
import torch.nn.functional as F

def extract_embedding(model, waveform):
    """Mean-pool the last hidden state into one utterance-level embedding."""
    with torch.no_grad():
        features = model(waveform).last_hidden_state
    return features.mean(dim=1)

def compute_similarity(embedding1, embedding2):
    """Compute cosine similarity between two speaker embeddings."""
    return F.cosine_similarity(embedding1, embedding2, dim=-1)

# Extract embeddings for two audio samples (16kHz waveforms)
emb1 = extract_embedding(model, audio1)
emb2 = extract_embedding(model, audio2)

# Compute similarity score
similarity = compute_similarity(emb1, emb2)
print(f"Similarity score: {similarity.item():.4f}")

# Threshold-based decision (tune threshold on validation data)
threshold = 0.7
same_speaker = similarity > threshold
```

### Fine-Tuning for Downstream Tasks

```python
import torch

# Add a task-specific head on top of the pretrained encoder
class SpeakerVerificationModel(torch.nn.Module):
    def __init__(self, base_model, embedding_dim=256):
        super().__init__()
        self.base = base_model
        self.projector = torch.nn.Linear(768, embedding_dim)

    def forward(self, x):
        features = self.base(x).last_hidden_state
        pooled = features.mean(dim=1)
        return self.projector(pooled)

# Fine-tune with your speaker verification dataset
sv_model = SpeakerVerificationModel(model)
```

## Training Details

### Pre-training Process

1. **Pseudo-Label Generation** (a minimal sketch follows this list):
   - Extract frame-level embeddings using ReDimNet (dimension: 2,304)
   - Apply k-means clustering with k=256 to create speaker-aware pseudo-labels
   - ReDimNet stride modified to match encoder stride (16ms)

2. **Training Objective**:
   - **Masked Prediction**: Predict pseudo-labels for masked frames
   - **Denoising**: Additional denoising objective for robustness

3. **Optimization**:
   - Training data: LibriSpeech 960 hours
   - Training steps: 400k updates
   - Batch size: 87.5 seconds of audio per GPU
   - Hardware: 32 GPUs

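As a concrete illustration of step 1, here is a minimal sketch of ReDimNet-guided pseudo-label generation; `redimnet` is a hypothetical callable returning one 2,304-dimensional embedding per 16ms frame, and this is not the released training pipeline:

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans

# Hypothetical helper: redimnet(waveform) -> np.ndarray of shape (num_frames, 2304),
# one embedding per 16ms frame (ReDimNet stride matched to the encoder stride).
def generate_pseudo_labels(waveforms, redimnet, k=256):
    """Cluster ReDimNet frame embeddings into k speaker-aware pseudo-labels."""
    frames = np.concatenate([redimnet(w) for w in waveforms], axis=0)
    kmeans = MiniBatchKMeans(n_clusters=k, batch_size=10_000).fit(frames)
    # One integer target per frame; these replace HuBERT's acoustic-only clusters
    # as the masked-prediction targets.
    return [kmeans.predict(redimnet(w)) for w in waveforms]
```
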
### Why 16ms Frame Shift?

Ablation studies showed that a **16ms frame shift achieves the best EER (13.52%)**, while both finer (≤15ms) and coarser (≥20ms) frame shifts resulted in EER above 14%. This temporal resolution balances:
- Fine-grained capture of speaker characteristics
- Computational efficiency
- Training stability

## Limitations

1. **Domain Shift**: Performance may degrade on audio with characteristics significantly different from LibriSpeech (e.g., noisy environments, non-English speech, telephony audio)

2. **Computational Requirements**: As a transformer-based model, DELULU requires substantial computational resources for inference on long audio

3. **Fine-tuning May Be Required**: While DELULU provides strong zero-shot speaker representations, task-specific fine-tuning typically improves performance

4. **Language**: Pre-trained on English speech; cross-lingual transfer may be limited

## Ethical Considerations

### Potential Misuse

Speaker verification technology can be misused for:
- Unauthorized surveillance
- Privacy violations
- Identity fraud
- Discriminatory profiling

### Recommended Safeguards

- Obtain explicit consent before processing voice data
- Implement robust access controls
- Follow data protection regulations (GDPR, CCPA)
- Conduct bias audits across demographic groups
- Maintain transparency about system capabilities and limitations

### Bias Evaluation

DELULU was evaluated across demographic subgroups and shows consistent improvements without introducing systematic biases. However, users should validate performance on their specific populations.

## Citation

If you use DELULU in your research, please cite:

```bibtex
@article{baali2025delulu,
  title={DELULU: Discriminative Embedding Learning Using Latent Units for Speaker-Aware Self-Supervised Speech Foundational Model},
  author={Baali, Massa and Singh, Rita and Raj, Bhiksha},
  journal={arXiv preprint arXiv:2510.17662},
  year={2025}
}
```

## Related Work

- **HuBERT**: [Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447)
- **ReDimNet**: State-of-the-art speaker verification model used for pseudo-label generation
- **WavLM**: [Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900)

## Acknowledgments

This work was conducted at Carnegie Mellon University's Language Technologies Institute. We thank the speech processing community for foundational work on self-supervised learning and speaker verification.

## Contact

For questions about the model or paper:
- **Author**: Massa Baali
- **Advisors**: Prof. Rita Singh, Prof. Bhiksha Raj
- **Institution**: Carnegie Mellon University, Language Technologies Institute

---

<div align="center">
<i>DELULU: Where Self-Supervised Learning Meets Speaker Identity</i>
</div>
config.json
CHANGED
@@ -1,39 +1,13 @@
 {
   "model_type": "delulu",
-  "architectures": [
-    "DELULUModel"
-  ],
+  "architectures": ["DELULUModel"],
   "auto_map": {
     "AutoConfig": "configuration_delulu.DELULUConfig",
     "AutoModel": "modeling_delulu.DELULUModel"
   },
-  "conv_dim": [
-    512,
-    512,
-    512,
-    512,
-    512,
-    512,
-    512
-  ],
-  "conv_kernel": [
-    10,
-    3,
-    3,
-    3,
-    3,
-    2,
-    2
-  ],
-  "conv_stride": [
-    4,
-    2,
-    2,
-    2,
-    2,
-    2,
-    2
-  ],
+  "conv_dim": [512, 512, 512, 512, 512, 512, 512],
+  "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
+  "conv_stride": [4, 2, 2, 2, 2, 2, 2],
   "conv_bias": false,
   "extractor_mode": "group_norm",
   "hidden_size": 768,
@@ -44,7 +18,7 @@
   "attention_dropout": 0.1,
   "final_dropout": 0.1,
   "feat_proj_dropout": 0.1,
-  "layer_norm_eps": 1e-
+  "layer_norm_eps": 1e-5,
   "layer_drop": 0.05,
   "num_conv_pos_embeddings": 128,
   "num_conv_pos_embedding_groups": 16,
@@ -56,5 +30,6 @@
   "pad_token_id": 0,
   "bos_token_id": 1,
   "eos_token_id": 2,
+  "transformers_version": "4.36.0",
   "torch_dtype": "float32"
 }
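The `layer_norm_eps` fix is the load-bearing change in this diff: `1e-` is not a valid JSON number, so the previous config.json could not be parsed at all. A quick demonstration, with inline strings standing in for the two file versions:

```python
import json

json.loads('{"layer_norm_eps": 1e-5}')      # the fixed value parses cleanly
try:
    json.loads('{"layer_norm_eps": 1e-}')   # the pre-fix, truncated value
except json.JSONDecodeError as e:
    print(f"old config.json is unparseable: {e}")
```
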
configuration_delulu.py
CHANGED
@@ -1,20 +1,80 @@
"""
DELULU Configuration

Configuration class for DELULU (Discriminative Embedding Learning Using Latent Units),
a speaker-aware self-supervised speech foundational model.

Paper: https://arxiv.org/abs/2510.17662
Authors: Massa Baali, Rita Singh, Bhiksha Raj
"""

from transformers import PretrainedConfig


class DELULUConfig(PretrainedConfig):
    r"""
    Configuration class for DELULU model.

    DELULU is based on HuBERT architecture with modified convolutional strides
    optimized for speaker verification (16ms frame shift).

    Args:
        conv_dim (`List[int]`, *optional*, defaults to `[512, 512, 512, 512, 512, 512, 512]`):
            Dimensions of each convolutional layer in the feature extractor.
        conv_kernel (`List[int]`, *optional*, defaults to `[10, 3, 3, 3, 3, 2, 2]`):
            Kernel sizes of each convolutional layer in the feature extractor.
        conv_stride (`List[int]`, *optional*, defaults to `[4, 2, 2, 2, 2, 2, 2]`):
            Stride sizes of each convolutional layer. Note: first stride is 4 (vs 5 in HuBERT)
            for 16ms frame shift optimized for speaker verification.
        conv_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in convolutional layers.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward layer in the Transformer encoder.
        hidden_dropout (`float`, *optional*, defaults to 0.1):
            Dropout probability for all fully connected layers.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            Dropout probability for attention weights.
        feat_proj_dropout (`float`, *optional*, defaults to 0.1):
            Dropout probability for feature projection layer.
        layer_drop (`float`, *optional*, defaults to 0.05):
            Layer drop probability during training.
        num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
            Number of convolutional positional embeddings.
        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
            Number of groups for convolutional positional embeddings.
        sampling_rate (`int`, *optional*, defaults to 16000):
            Audio sampling rate in Hz.

    Example:
    ```python
    from transformers import AutoConfig, AutoModel

    # Load config
    config = AutoConfig.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)

    # Load model
    model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
    ```
    """

    model_type = "delulu"

    def __init__(
        self,
        # Convolutional feature extractor
        conv_dim=None,
        conv_kernel=None,
        conv_stride=None,
        conv_bias=False,
        extractor_mode="group_norm",

        # Transformer encoder
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
@@ -25,15 +85,24 @@ class DELULUConfig(PretrainedConfig):
        feat_proj_dropout=0.1,
        layer_norm_eps=1e-5,
        layer_drop=0.05,

        # Positional encoding
        num_conv_pos_embeddings=128,
        num_conv_pos_embedding_groups=16,

        # Audio settings
        sampling_rate=16000,
        do_stable_layer_norm=False,

        # DELULU-specific settings
        num_clusters=256,
        feature_type="redimnet",

        # Pad token for compatibility
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,

        **kwargs
    ):
        super().__init__(
@@ -43,10 +112,18 @@ class DELULUConfig(PretrainedConfig):
            **kwargs
        )

        # Set default DELULU conv configuration
        # Key difference from HuBERT: first stride is 4 instead of 5
        if conv_dim is None:
            conv_dim = [512, 512, 512, 512, 512, 512, 512]
        if conv_kernel is None:
            conv_kernel = [10, 3, 3, 3, 3, 2, 2]
        if conv_stride is None:
            conv_stride = [4, 2, 2, 2, 2, 2, 2]

        self.conv_dim = conv_dim
        self.conv_kernel = conv_kernel
        self.conv_stride = conv_stride
        self.conv_bias = conv_bias
        self.extractor_mode = extractor_mode

@@ -70,4 +147,13 @@ class DELULUConfig(PretrainedConfig):
        self.num_clusters = num_clusters
        self.feature_type = feature_type

        # Computed properties
        self.num_feat_extract_layers = len(self.conv_dim)

    @property
    def inputs_to_logits_ratio(self):
        """Compute the ratio between input samples and output frames."""
        ratio = 1
        for stride in self.conv_stride:
            ratio *= stride
        return ratio
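With the default strides, the new `inputs_to_logits_ratio` property evaluates to 4 × 2^6 = 256 input samples per output frame, i.e. 16ms at 16kHz. A quick check, assuming the file above is importable as a local module:

```python
from configuration_delulu import DELULUConfig  # assumes the file above is on the path

config = DELULUConfig()
hop = config.inputs_to_logits_ratio            # 4 * 2**6 = 256 samples per frame
print(hop, hop / config.sampling_rate * 1000)  # 256, 16.0 (ms)
```
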
convert_delulu_fixed.py
ADDED
@@ -0,0 +1,426 @@
#!/usr/bin/env python3
"""
DELULU Checkpoint Converter - Fixed Version

Converts DELULU model checkpoints from torchaudio/PyTorch Lightning format
to Hugging Face compatible format with proper metadata.

Usage:
    python convert_delulu_fixed.py \
        --checkpoint /path/to/epoch=45-step=400000.ckpt \
        --output-dir ./delulu_hf_model

Author: Massa Baali
"""

import argparse
import json
import os
import sys
from collections import OrderedDict
from pathlib import Path

import torch

try:
    from safetensors.torch import save_file as save_safetensors
    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False
    print("Warning: safetensors not installed. Install with: pip install safetensors")


def load_lightning_checkpoint(checkpoint_path: str) -> dict:
    """Load and clean PyTorch Lightning checkpoint."""
    print(f"Loading checkpoint: {checkpoint_path}")

    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Extract state dict
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    # Clean up state dict keys
    cleaned_state_dict = OrderedDict()

    for key, value in state_dict.items():
        new_key = key

        # Remove Lightning prefixes
        if key.startswith("model.wav2vec2."):
            new_key = key.replace("model.wav2vec2.", "")
        elif key.startswith("model."):
            new_key = key.replace("model.", "")

        # Skip auxiliary heads
        if "aux" in new_key:
            print(f"  Skipping: {key}")
            continue

        cleaned_state_dict[new_key] = value

    print(f"Loaded {len(cleaned_state_dict)} parameters")
    return cleaned_state_dict


def save_pytorch_model_bin(state_dict: dict, output_path: Path):
    """
    Save state dict as pytorch_model.bin with proper format.

    This saves ONLY the state dict (not a full checkpoint with metadata),
    which is what HuggingFace expects.
    """
    print(f"Saving pytorch_model.bin to: {output_path}")

    # Convert all tensors to contiguous for safety
    clean_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            clean_state_dict[key] = value.contiguous()
        else:
            clean_state_dict[key] = value

    # Save just the state dict (NOT a checkpoint dict)
    torch.save(clean_state_dict, output_path)

    print(f"  Saved {len(clean_state_dict)} tensors")
    print(f"  File size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")


def save_safetensors_model(state_dict: dict, output_path: Path):
    """Save state dict in safetensors format."""
    if not SAFETENSORS_AVAILABLE:
        print("Skipping safetensors (not installed)")
        return

    print(f"Saving model.safetensors to: {output_path}")

    # Safetensors requires contiguous tensors
    clean_state_dict = {}
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            clean_state_dict[key] = value.contiguous()

    save_safetensors(clean_state_dict, str(output_path))
    print(f"  File size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")


def create_config_json(output_dir: Path):
    """Create config.json with DELULU configuration."""
    config = {
        "model_type": "delulu",
        "architectures": ["DELULUModel"],
        "auto_map": {
            "AutoConfig": "configuration_delulu.DELULUConfig",
            "AutoModel": "modeling_delulu.DELULUModel"
        },
        "conv_dim": [512, 512, 512, 512, 512, 512, 512],
        "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
        "conv_stride": [4, 2, 2, 2, 2, 2, 2],
        "conv_bias": False,
        "extractor_mode": "group_norm",
        "hidden_size": 768,
        "num_hidden_layers": 12,
        "num_attention_heads": 12,
        "intermediate_size": 3072,
        "hidden_dropout": 0.1,
        "attention_dropout": 0.1,
        "final_dropout": 0.1,
        "feat_proj_dropout": 0.1,
        "layer_norm_eps": 1e-5,
        "layer_drop": 0.05,
        "num_conv_pos_embeddings": 128,
        "num_conv_pos_embedding_groups": 16,
        "sampling_rate": 16000,
        "do_stable_layer_norm": False,
        "num_clusters": 256,
        "feature_type": "redimnet",
        "num_feat_extract_layers": 7,
        "pad_token_id": 0,
        "bos_token_id": 1,
        "eos_token_id": 2,
        "torch_dtype": "float32"
    }

    config_path = output_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    print("Created config.json")


def create_configuration_delulu(output_dir: Path):
    """Create configuration_delulu.py file."""
    code = '''"""DELULU Configuration"""

from transformers import PretrainedConfig


class DELULUConfig(PretrainedConfig):
    """Configuration class for DELULU model."""

    model_type = "delulu"

    def __init__(
        self,
        conv_dim=None,
        conv_kernel=None,
        conv_stride=None,
        conv_bias=False,
        extractor_mode="group_norm",
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        final_dropout=0.1,
        feat_proj_dropout=0.1,
        layer_norm_eps=1e-5,
        layer_drop=0.05,
        num_conv_pos_embeddings=128,
        num_conv_pos_embedding_groups=16,
        sampling_rate=16000,
        do_stable_layer_norm=False,
        num_clusters=256,
        feature_type="redimnet",
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )

        # DELULU conv config: [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
        self.conv_dim = conv_dim or [512, 512, 512, 512, 512, 512, 512]
        self.conv_kernel = conv_kernel or [10, 3, 3, 3, 3, 2, 2]
        self.conv_stride = conv_stride or [4, 2, 2, 2, 2, 2, 2]
        self.conv_bias = conv_bias
        self.extractor_mode = extractor_mode

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.final_dropout = final_dropout
        self.feat_proj_dropout = feat_proj_dropout
        self.layer_norm_eps = layer_norm_eps
        self.layer_drop = layer_drop

        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups

        self.sampling_rate = sampling_rate
        self.do_stable_layer_norm = do_stable_layer_norm

        self.num_clusters = num_clusters
        self.feature_type = feature_type

        self.num_feat_extract_layers = len(self.conv_dim)
'''

    with open(output_dir / "configuration_delulu.py", "w") as f:
        f.write(code)
    print("Created configuration_delulu.py")


def create_modeling_delulu(output_dir: Path):
    """Create modeling_delulu.py file."""
    code = '''"""DELULU Model"""

import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from .configuration_delulu import DELULUConfig

try:
    from torchaudio.models.wav2vec2 import wav2vec2_model
    TORCHAUDIO_AVAILABLE = True
except ImportError:
    TORCHAUDIO_AVAILABLE = False


class DELULUModel(PreTrainedModel):
    """
    DELULU Model for speaker-aware speech representation learning.

    Example:
    ```python
    from transformers import AutoModel
    import torch

    model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
    waveform = torch.randn(1, 16000)  # 1 second at 16kHz
    outputs = model(waveform)
    features = outputs.last_hidden_state
    ```
    """

    config_class = DELULUConfig
    base_model_prefix = "delulu"
    main_input_name = "input_values"

    def __init__(self, config: DELULUConfig):
        super().__init__(config)
        self.config = config

        if not TORCHAUDIO_AVAILABLE:
            raise ImportError("torchaudio is required. Install with: pip install torchaudio")

        # Build conv config
        conv_layer_config = list(zip(
            config.conv_dim,
            config.conv_kernel,
            config.conv_stride
        ))

        # Create torchaudio model
        self.wav2vec2 = wav2vec2_model(
            extractor_mode=config.extractor_mode,
            extractor_conv_layer_config=conv_layer_config,
            extractor_conv_bias=config.conv_bias,
            encoder_embed_dim=config.hidden_size,
            encoder_projection_dropout=config.feat_proj_dropout,
            encoder_pos_conv_kernel=config.num_conv_pos_embeddings,
            encoder_pos_conv_groups=config.num_conv_pos_embedding_groups,
            encoder_num_layers=config.num_hidden_layers,
            encoder_num_heads=config.num_attention_heads,
            encoder_attention_dropout=config.attention_dropout,
            encoder_ff_interm_features=config.intermediate_size,
            encoder_ff_interm_dropout=config.hidden_dropout,
            encoder_dropout=config.hidden_dropout,
            encoder_layer_norm_first=config.do_stable_layer_norm,
            encoder_layer_drop=config.layer_drop,
            aux_num_out=None,
        )

        self.post_init()

    def _init_weights(self, module):
        """Initialize weights."""
        pass  # Handled by torchaudio

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Args:
            input_values: Audio waveform (batch, samples) at 16kHz
            attention_mask: Optional attention mask
            output_hidden_states: Whether to return all hidden states
            return_dict: Whether to return BaseModelOutput
        """
        return_dict = return_dict if return_dict is not None else True
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False

        if input_values.dim() == 1:
            input_values = input_values.unsqueeze(0)

        lengths = None
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=-1)

        if output_hidden_states:
            features, _ = self.wav2vec2.extract_features(input_values, lengths=lengths)
            hidden_states = tuple(features)
            last_hidden_state = features[-1]
        else:
            last_hidden_state, _ = self.wav2vec2(input_values, lengths=lengths)
            hidden_states = None

        if not return_dict:
            return (last_hidden_state, hidden_states) if hidden_states else (last_hidden_state,)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
        )

    def extract_features(self, input_values: torch.Tensor):
        """Extract features from all layers."""
        if input_values.dim() == 1:
            input_values = input_values.unsqueeze(0)
        features, _ = self.wav2vec2.extract_features(input_values)
        return tuple(features)

    @classmethod
    def _load_pretrained_model_low_mem(cls, *args, **kwargs):
        """Override to handle custom loading."""
        return super()._load_pretrained_model_low_mem(*args, **kwargs)
'''

    with open(output_dir / "modeling_delulu.py", "w") as f:
        f.write(code)
    print("Created modeling_delulu.py")


def convert_checkpoint(checkpoint_path: str, output_dir: str):
    """Main conversion function."""
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print("DELULU Checkpoint Converter")
    print("=" * 60)

    # Step 1: Load checkpoint
    state_dict = load_lightning_checkpoint(checkpoint_path)

    # Step 2: Print some keys for verification
    print("\nSample keys in state dict:")
    for key in list(state_dict.keys())[:10]:
        print(f"  {key}")
    print(f"  ... and {len(state_dict) - 10} more")

    # Step 3: Save weights
    save_pytorch_model_bin(state_dict, output_path / "pytorch_model.bin")

    if SAFETENSORS_AVAILABLE:
        save_safetensors_model(state_dict, output_path / "model.safetensors")

    # Step 4: Create config and code files
    create_config_json(output_path)
    create_configuration_delulu(output_path)
    create_modeling_delulu(output_path)

    # Step 5: Summary
    print("\n" + "=" * 60)
    print("Conversion Complete!")
    print("=" * 60)
    print(f"\nOutput directory: {output_path}")
    print("\nFiles created:")
    for f in sorted(output_path.iterdir()):
        size_mb = f.stat().st_size / 1024 / 1024
        print(f"  {f.name}: {size_mb:.2f} MB")

    print("\nNext steps:")
    print("  1. Upload all files to huggingface.co/cmu-mlsp/DELULU")
    print("  2. Test with:")
    print('     model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)')


def main():
    parser = argparse.ArgumentParser(description="Convert DELULU checkpoint to HuggingFace format")
    parser.add_argument("--checkpoint", "-c", required=True, help="Path to .ckpt file")
    parser.add_argument("--output-dir", "-o", required=True, help="Output directory")
    args = parser.parse_args()

    convert_checkpoint(args.checkpoint, args.output_dir)


if __name__ == "__main__":
    main()
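After conversion, a sanity check along these lines (a sketch, assuming the default output directory) confirms that the saved file is a bare state dict with Lightning prefixes and auxiliary heads stripped:

```python
import torch

state_dict = torch.load("delulu_hf_model/pytorch_model.bin", map_location="cpu")
assert "state_dict" not in state_dict                        # bare weights, not a checkpoint
assert not any(k.startswith("model.") for k in state_dict)   # Lightning prefixes removed
assert not any("aux" in k for k in state_dict)               # auxiliary heads dropped
print(f"{len(state_dict)} tensors, e.g. {next(iter(state_dict))}")
```
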
convert_delulu_to_hf.py
ADDED
@@ -0,0 +1,697 @@
#!/usr/bin/env python3
"""
DELULU Checkpoint Converter
===========================

Converts DELULU model checkpoints from torchaudio/PyTorch Lightning format
to Hugging Face compatible format (config.json + model weights).

Usage:
    python convert_delulu_to_hf.py \
        --checkpoint /path/to/epoch=45-step=400000.ckpt \
        --output-dir ./delulu_hf_model

Author: Massa Baali
Model: DELULU - Speaker-Aware Self-Supervised Speech Foundational Model
"""

import argparse
import json
import logging
import os
import sys
from pathlib import Path
from typing import Optional, Tuple, List
from collections import OrderedDict

import torch
import torch.nn as nn

try:
    from safetensors.torch import save_file as save_safetensors
    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False
    print("Warning: safetensors not installed. Will save as pytorch_model.bin only.")
    print("Install with: pip install safetensors")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


# =============================================================================
# DELULU Configuration
# =============================================================================

class DELULUConfig:
    """
    Configuration class for DELULU model.

    DELULU uses HuBERT architecture with modified convolutional strides
    for 16ms frame shift, optimized for speaker verification.
    """

    # Model architecture identifier
    model_type = "delulu"
    architectures = ["DELULUModel"]

    def __init__(
        self,
        # Convolutional feature extractor config
        # DELULU: [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
        conv_dim: List[int] = None,
        conv_kernel: List[int] = None,
        conv_stride: List[int] = None,
        conv_bias: bool = False,
        extractor_mode: str = "group_norm",

        # Transformer encoder config
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        final_dropout: float = 0.1,
        feat_proj_dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        layer_drop: float = 0.05,

        # Positional encoding
        num_conv_pos_embeddings: int = 128,
        num_conv_pos_embedding_groups: int = 16,

        # Audio config
        sampling_rate: int = 16000,
        do_stable_layer_norm: bool = False,

        # Training config (for reference)
        num_clusters: int = 256,
        feature_type: str = "redimnet",

        **kwargs
    ):
        # Set default conv config for DELULU
        if conv_dim is None:
            conv_dim = [512, 512, 512, 512, 512, 512, 512]
        if conv_kernel is None:
            conv_kernel = [10, 3, 3, 3, 3, 2, 2]
        if conv_stride is None:
            conv_stride = [4, 2, 2, 2, 2, 2, 2]  # Key difference from HuBERT!

        self.conv_dim = conv_dim
        self.conv_kernel = conv_kernel
        self.conv_stride = conv_stride
        self.conv_bias = conv_bias
        self.extractor_mode = extractor_mode

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.final_dropout = final_dropout
        self.feat_proj_dropout = feat_proj_dropout
        self.layer_norm_eps = layer_norm_eps
        self.layer_drop = layer_drop

        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups

        self.sampling_rate = sampling_rate
        self.do_stable_layer_norm = do_stable_layer_norm

        self.num_clusters = num_clusters
        self.feature_type = feature_type

        # Store any additional kwargs
        for key, value in kwargs.items():
            setattr(self, key, value)

    def to_dict(self) -> dict:
        """Convert config to dictionary for JSON serialization."""
        return {
            # Model identification
            "model_type": self.model_type,
            "architectures": self.architectures,

            # Convolutional feature extractor
            "conv_dim": self.conv_dim,
            "conv_kernel": self.conv_kernel,
            "conv_stride": self.conv_stride,
            "conv_bias": self.conv_bias,
            "extractor_mode": self.extractor_mode,

            # Transformer encoder
            "hidden_size": self.hidden_size,
            "num_hidden_layers": self.num_hidden_layers,
            "num_attention_heads": self.num_attention_heads,
            "intermediate_size": self.intermediate_size,
            "hidden_dropout": self.hidden_dropout,
            "attention_dropout": self.attention_dropout,
            "final_dropout": self.final_dropout,
            "feat_proj_dropout": self.feat_proj_dropout,
            "layer_norm_eps": self.layer_norm_eps,
            "layer_drop": self.layer_drop,

            # Positional encoding
            "num_conv_pos_embeddings": self.num_conv_pos_embeddings,
            "num_conv_pos_embedding_groups": self.num_conv_pos_embedding_groups,

            # Audio config
            "sampling_rate": self.sampling_rate,
            "do_stable_layer_norm": self.do_stable_layer_norm,

            # Training reference
            "num_clusters": self.num_clusters,
            "feature_type": self.feature_type,

            # Transformers compatibility
            "transformers_version": "4.36.0",
            "torch_dtype": "float32",

            # Auto-mapping for custom code
            "auto_map": {
                "AutoConfig": "configuration_delulu.DELULUConfig",
                "AutoModel": "modeling_delulu.DELULUModel"
            }
        }

    def save_pretrained(self, save_directory: str):
        """Save config to directory."""
        os.makedirs(save_directory, exist_ok=True)
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.to_dict(), f, indent=2)
        logger.info(f"Config saved to: {config_path}")


# =============================================================================
# Weight Mapping: torchaudio -> Hugging Face
# =============================================================================

def create_weight_mapping() -> dict:
    """
    Create mapping from torchaudio wav2vec2_model keys to Hugging Face format.

    torchaudio structure:
        feature_extractor.conv_layers.{i}.{0,1,2}...
        encoder.feature_projection.{projection,layer_norm}...
        encoder.transformer.pos_conv_embed...
        encoder.transformer.layers.{i}.{attention,feed_forward,layer_norms}...
        encoder.transformer.layer_norm...

    HuggingFace structure:
        feature_extractor.conv_layers.{i}.{conv,layer_norm}...
        feature_projection.{projection,layer_norm}...
        encoder.pos_conv_embed...
        encoder.layers.{i}.{attention,feed_forward,layer_norm}...
        encoder.layer_norm...
    """

    # This will be populated dynamically based on actual keys
    mapping = {}
    return mapping


def convert_torchaudio_to_hf(state_dict: dict) -> dict:
    """
    Convert torchaudio wav2vec2_model state dict to Hugging Face format.

    Args:
        state_dict: State dict from torchaudio model

    Returns:
        Converted state dict in HuggingFace format
    """
    new_state_dict = OrderedDict()

    for key, value in state_dict.items():
        new_key = key

        # Feature extractor conv layers
        # torchaudio: feature_extractor.conv_layers.0.0.weight -> hf: feature_extractor.conv_layers.0.conv.weight
        if "feature_extractor.conv_layers" in key:
            # Handle conv layer structure: .{layer_idx}.0. -> .{layer_idx}.conv.
            # Handle norm layer structure: .{layer_idx}.2.1. -> .{layer_idx}.layer_norm.
            parts = key.split(".")
            layer_idx = parts[2]

            if ".0." in key and "weight" in key:
                # Convolution weight
                new_key = f"delulu.feature_extractor.conv_layers.{layer_idx}.conv.weight"
            elif ".2.1." in key or (".1." in key and "layer_norm" not in key):
                # Group norm / layer norm
                if "weight" in key:
                    new_key = f"delulu.feature_extractor.conv_layers.{layer_idx}.layer_norm.weight"
                elif "bias" in key:
                    new_key = f"delulu.feature_extractor.conv_layers.{layer_idx}.layer_norm.bias"
            else:
                new_key = f"delulu.{key}"

        # Feature projection
        elif "encoder.feature_projection" in key:
            new_key = key.replace("encoder.feature_projection", "delulu.feature_projection")

        # Positional conv embedding
        elif "encoder.transformer.pos_conv_embed" in key:
|
| 264 |
+
new_key = key.replace("encoder.transformer.pos_conv_embed", "delulu.encoder.pos_conv_embed")
|
| 265 |
+
|
| 266 |
+
# Transformer layers
|
| 267 |
+
elif "encoder.transformer.layers" in key:
|
| 268 |
+
new_key = key.replace("encoder.transformer.layers", "delulu.encoder.layers")
|
| 269 |
+
|
| 270 |
+
# Attention mappings
|
| 271 |
+
new_key = new_key.replace(".attention.k_proj", ".attention.k_proj")
|
| 272 |
+
new_key = new_key.replace(".attention.v_proj", ".attention.v_proj")
|
| 273 |
+
new_key = new_key.replace(".attention.q_proj", ".attention.q_proj")
|
| 274 |
+
new_key = new_key.replace(".attention.out_proj", ".attention.out_proj")
|
| 275 |
+
|
| 276 |
+
# Feed forward mappings
|
| 277 |
+
new_key = new_key.replace(".feed_forward.intermediate_dense", ".feed_forward.intermediate_dense")
|
| 278 |
+
new_key = new_key.replace(".feed_forward.output_dense", ".feed_forward.output_dense")
|
| 279 |
+
|
| 280 |
+
# Layer norm mappings
|
| 281 |
+
new_key = new_key.replace(".layer_norms.0", ".layer_norm")
|
| 282 |
+
new_key = new_key.replace(".layer_norms.1", ".final_layer_norm")
|
| 283 |
+
|
| 284 |
+
# Final layer norm
|
| 285 |
+
elif "encoder.transformer.layer_norm" in key:
|
| 286 |
+
new_key = key.replace("encoder.transformer.layer_norm", "delulu.encoder.layer_norm")
|
| 287 |
+
|
| 288 |
+
# Mask embedding (if present)
|
| 289 |
+
elif "mask_emb" in key:
|
| 290 |
+
new_key = f"delulu.{key}"
|
| 291 |
+
|
| 292 |
+
# Auxiliary head (if present)
|
| 293 |
+
elif "aux" in key:
|
| 294 |
+
new_key = key # Keep as is for now
|
| 295 |
+
|
| 296 |
+
else:
|
| 297 |
+
# Default: add delulu prefix
|
| 298 |
+
new_key = f"delulu.{key}"
|
| 299 |
+
|
| 300 |
+
new_state_dict[new_key] = value
|
| 301 |
+
|
| 302 |
+
if new_key != key:
|
| 303 |
+
logger.debug(f"Mapped: {key} -> {new_key}")
|
| 304 |
+
|
| 305 |
+
return new_state_dict
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def convert_simple_format(state_dict: dict) -> dict:
|
| 309 |
+
"""
|
| 310 |
+
Simple conversion that just renames keys minimally.
|
| 311 |
+
Suitable for direct loading with torchaudio models.
|
| 312 |
+
"""
|
| 313 |
+
new_state_dict = OrderedDict()
|
| 314 |
+
|
| 315 |
+
for key, value in state_dict.items():
|
| 316 |
+
# Just add a model prefix for organization
|
| 317 |
+
new_key = f"model.{key}" if not key.startswith("model.") else key
|
| 318 |
+
new_state_dict[new_key] = value
|
| 319 |
+
|
| 320 |
+
return new_state_dict
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# =============================================================================
|
| 324 |
+
# Checkpoint Loading
|
| 325 |
+
# =============================================================================
|
| 326 |
+
|
| 327 |
+
def load_lightning_checkpoint(checkpoint_path: str) -> Tuple[dict, dict]:
|
| 328 |
+
"""
|
| 329 |
+
Load PyTorch Lightning checkpoint and extract model state dict.
|
| 330 |
+
|
| 331 |
+
Args:
|
| 332 |
+
checkpoint_path: Path to .ckpt file
|
| 333 |
+
|
| 334 |
+
Returns:
|
| 335 |
+
Tuple of (state_dict, hyperparameters)
|
| 336 |
+
"""
|
| 337 |
+
logger.info(f"Loading checkpoint: {checkpoint_path}")
|
| 338 |
+
|
| 339 |
+
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
| 340 |
+
|
| 341 |
+
# Extract state dict
|
| 342 |
+
if "state_dict" in checkpoint:
|
| 343 |
+
state_dict = checkpoint["state_dict"]
|
| 344 |
+
else:
|
| 345 |
+
state_dict = checkpoint
|
| 346 |
+
|
| 347 |
+
# Extract hyperparameters if available
|
| 348 |
+
hparams = checkpoint.get("hyper_parameters", {})
|
| 349 |
+
|
| 350 |
+
# Clean up state dict keys (remove Lightning prefixes)
|
| 351 |
+
cleaned_state_dict = OrderedDict()
|
| 352 |
+
for key, value in state_dict.items():
|
| 353 |
+
new_key = key
|
| 354 |
+
|
| 355 |
+
# Remove common Lightning prefixes
|
| 356 |
+
if key.startswith("model.wav2vec2."):
|
| 357 |
+
new_key = key.replace("model.wav2vec2.", "")
|
| 358 |
+
elif key.startswith("model."):
|
| 359 |
+
new_key = key.replace("model.", "")
|
| 360 |
+
|
| 361 |
+
# Skip auxiliary heads unless needed
|
| 362 |
+
if "aux" in new_key:
|
| 363 |
+
logger.debug(f"Skipping auxiliary layer: {key}")
|
| 364 |
+
continue
|
| 365 |
+
|
| 366 |
+
cleaned_state_dict[new_key] = value
|
| 367 |
+
|
| 368 |
+
logger.info(f"Loaded {len(cleaned_state_dict)} parameters")
|
| 369 |
+
|
| 370 |
+
return cleaned_state_dict, hparams
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def verify_state_dict(state_dict: dict) -> bool:
|
| 374 |
+
"""
|
| 375 |
+
Verify the state dict has expected DELULU components.
|
| 376 |
+
"""
|
| 377 |
+
expected_prefixes = [
|
| 378 |
+
"feature_extractor",
|
| 379 |
+
"encoder",
|
| 380 |
+
]
|
| 381 |
+
|
| 382 |
+
found_prefixes = set()
|
| 383 |
+
for key in state_dict.keys():
|
| 384 |
+
for prefix in expected_prefixes:
|
| 385 |
+
if prefix in key:
|
| 386 |
+
found_prefixes.add(prefix)
|
| 387 |
+
|
| 388 |
+
missing = set(expected_prefixes) - found_prefixes
|
| 389 |
+
if missing:
|
| 390 |
+
logger.warning(f"Missing expected components: {missing}")
|
| 391 |
+
return False
|
| 392 |
+
|
| 393 |
+
logger.info("✓ State dict contains expected components")
|
| 394 |
+
return True
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# =============================================================================
|
| 398 |
+
# Main Conversion
|
| 399 |
+
# =============================================================================
|
| 400 |
+
|
| 401 |
+
def convert_checkpoint(
|
| 402 |
+
checkpoint_path: str,
|
| 403 |
+
output_dir: str,
|
| 404 |
+
save_safetensors_format: bool = True,
|
| 405 |
+
save_bin_format: bool = True,
|
| 406 |
+
verify: bool = True
|
| 407 |
+
) -> None:
|
| 408 |
+
"""
|
| 409 |
+
Convert DELULU checkpoint to Hugging Face format.
|
| 410 |
+
|
| 411 |
+
Args:
|
| 412 |
+
checkpoint_path: Path to input .ckpt file
|
| 413 |
+
output_dir: Output directory for converted model
|
| 414 |
+
save_safetensors_format: Save in safetensors format
|
| 415 |
+
save_bin_format: Save in pytorch_model.bin format
|
| 416 |
+
verify: Verify the conversion
|
| 417 |
+
"""
|
| 418 |
+
output_path = Path(output_dir)
|
| 419 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 420 |
+
|
| 421 |
+
# Step 1: Load checkpoint
|
| 422 |
+
state_dict, hparams = load_lightning_checkpoint(checkpoint_path)
|
| 423 |
+
|
| 424 |
+
# Step 2: Verify state dict
|
| 425 |
+
if verify:
|
| 426 |
+
verify_state_dict(state_dict)
|
| 427 |
+
|
| 428 |
+
# Step 3: Create and save config
|
| 429 |
+
logger.info("Creating DELULU config...")
|
| 430 |
+
config = DELULUConfig(
|
| 431 |
+
# Use DELULU's custom conv config
|
| 432 |
+
conv_dim=[512, 512, 512, 512, 512, 512, 512],
|
| 433 |
+
conv_kernel=[10, 3, 3, 3, 3, 2, 2],
|
| 434 |
+
conv_stride=[4, 2, 2, 2, 2, 2, 2], # Key difference!
|
| 435 |
+
)
|
| 436 |
+
config.save_pretrained(output_dir)
|
| 437 |
+
|
| 438 |
+
# Step 4: Convert state dict format (minimal conversion)
|
| 439 |
+
logger.info("Converting state dict format...")
|
| 440 |
+
# Keep the original format since it's compatible with torchaudio loading
|
| 441 |
+
converted_state_dict = state_dict
|
| 442 |
+
|
| 443 |
+
# Step 5: Save weights
|
| 444 |
+
if save_safetensors_format and SAFETENSORS_AVAILABLE:
|
| 445 |
+
safetensors_path = output_path / "model.safetensors"
|
| 446 |
+
logger.info(f"Saving safetensors to: {safetensors_path}")
|
| 447 |
+
save_safetensors(converted_state_dict, str(safetensors_path))
|
| 448 |
+
|
| 449 |
+
if save_bin_format:
|
| 450 |
+
bin_path = output_path / "pytorch_model.bin"
|
| 451 |
+
logger.info(f"Saving pytorch_model.bin to: {bin_path}")
|
| 452 |
+
torch.save(converted_state_dict, str(bin_path))
|
| 453 |
+
|
| 454 |
+
# Step 6: Create additional files
|
| 455 |
+
create_additional_files(output_path, config)
|
| 456 |
+
|
| 457 |
+
# Step 7: Print summary
|
| 458 |
+
print_conversion_summary(checkpoint_path, output_dir, converted_state_dict)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def create_additional_files(output_path: Path, config: DELULUConfig) -> None:
|
| 462 |
+
"""Create additional files needed for Hugging Face model."""
|
| 463 |
+
|
| 464 |
+
# Create preprocessor_config.json
|
| 465 |
+
preprocessor_config = {
|
| 466 |
+
"do_normalize": True,
|
| 467 |
+
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
|
| 468 |
+
"feature_size": 1,
|
| 469 |
+
"padding_side": "right",
|
| 470 |
+
"padding_value": 0.0,
|
| 471 |
+
"return_attention_mask": True,
|
| 472 |
+
"sampling_rate": config.sampling_rate,
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
with open(output_path / "preprocessor_config.json", "w") as f:
|
| 476 |
+
json.dump(preprocessor_config, f, indent=2)
|
| 477 |
+
logger.info("Created preprocessor_config.json")
|
| 478 |
+
|
| 479 |
+
# Create a simple modeling file for reference
|
| 480 |
+
modeling_code = '''"""
|
| 481 |
+
DELULU Model - Minimal Loading Example
|
| 482 |
+
|
| 483 |
+
This file shows how to load DELULU weights with torchaudio.
|
| 484 |
+
For full Hugging Face Transformers integration, see the modeling_delulu.py file.
|
| 485 |
+
"""
|
| 486 |
+
|
| 487 |
+
import torch
|
| 488 |
+
from torchaudio.models.wav2vec2 import wav2vec2_model
|
| 489 |
+
|
| 490 |
+
# DELULU configuration
|
| 491 |
+
DELULU_CONV_CONFIG = [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
|
| 492 |
+
|
| 493 |
+
def load_delulu(checkpoint_path: str = None, weights_path: str = None):
|
| 494 |
+
"""
|
| 495 |
+
Load DELULU model.
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
checkpoint_path: Path to original .ckpt file (PyTorch Lightning format)
|
| 499 |
+
weights_path: Path to pytorch_model.bin (Hugging Face format)
|
| 500 |
+
|
| 501 |
+
Returns:
|
| 502 |
+
DELULU model ready for inference
|
| 503 |
+
"""
|
| 504 |
+
model = wav2vec2_model(
|
| 505 |
+
extractor_mode="group_norm",
|
| 506 |
+
extractor_conv_layer_config=DELULU_CONV_CONFIG,
|
| 507 |
+
extractor_conv_bias=False,
|
| 508 |
+
encoder_embed_dim=768,
|
| 509 |
+
encoder_projection_dropout=0.1,
|
| 510 |
+
encoder_pos_conv_kernel=128,
|
| 511 |
+
encoder_pos_conv_groups=16,
|
| 512 |
+
encoder_num_layers=12,
|
| 513 |
+
encoder_num_heads=12,
|
| 514 |
+
encoder_attention_dropout=0.1,
|
| 515 |
+
encoder_ff_interm_features=3072,
|
| 516 |
+
encoder_ff_interm_dropout=0.1,
|
| 517 |
+
encoder_dropout=0.1,
|
| 518 |
+
encoder_layer_norm_first=False,
|
| 519 |
+
encoder_layer_drop=0.05,
|
| 520 |
+
aux_num_out=None,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
if checkpoint_path:
|
| 524 |
+
# Load from original Lightning checkpoint
|
| 525 |
+
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
| 526 |
+
state_dict = checkpoint.get("state_dict", checkpoint)
|
| 527 |
+
|
| 528 |
+
# Clean keys
|
| 529 |
+
new_state_dict = {}
|
| 530 |
+
for k, v in state_dict.items():
|
| 531 |
+
if "model.wav2vec2" in k:
|
| 532 |
+
new_state_dict[k.replace("model.wav2vec2.", "")] = v
|
| 533 |
+
elif not k.startswith("aux"):
|
| 534 |
+
new_state_dict[k] = v
|
| 535 |
+
|
| 536 |
+
model.load_state_dict(new_state_dict, strict=False)
|
| 537 |
+
|
| 538 |
+
elif weights_path:
|
| 539 |
+
# Load from Hugging Face format
|
| 540 |
+
state_dict = torch.load(weights_path, map_location="cpu")
|
| 541 |
+
model.load_state_dict(state_dict, strict=False)
|
| 542 |
+
|
| 543 |
+
return model
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def extract_features(model, waveform: torch.Tensor) -> torch.Tensor:
|
| 547 |
+
"""
|
| 548 |
+
Extract speaker features from audio waveform.
|
| 549 |
+
|
| 550 |
+
Args:
|
| 551 |
+
model: DELULU model
|
| 552 |
+
waveform: Audio tensor of shape (batch, samples) at 16kHz
|
| 553 |
+
|
| 554 |
+
Returns:
|
| 555 |
+
Features of shape (batch, time, 768)
|
| 556 |
+
"""
|
| 557 |
+
model.eval()
|
| 558 |
+
with torch.no_grad():
|
| 559 |
+
features, _ = model.extract_features(waveform)
|
| 560 |
+
# Return last layer features
|
| 561 |
+
return features[-1]
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
if __name__ == "__main__":
|
| 565 |
+
# Example usage
|
| 566 |
+
import sys
|
| 567 |
+
|
| 568 |
+
if len(sys.argv) > 1:
|
| 569 |
+
model = load_delulu(weights_path=sys.argv[1])
|
| 570 |
+
print(f"Model loaded successfully!")
|
| 571 |
+
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 572 |
+
else:
|
| 573 |
+
print("Usage: python load_delulu.py path/to/pytorch_model.bin")
|
| 574 |
+
'''
|
| 575 |
+
|
| 576 |
+
with open(output_path / "load_delulu.py", "w") as f:
|
| 577 |
+
f.write(modeling_code)
|
| 578 |
+
logger.info("Created load_delulu.py")
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def print_conversion_summary(
|
| 582 |
+
input_path: str,
|
| 583 |
+
output_dir: str,
|
| 584 |
+
state_dict: dict
|
| 585 |
+
) -> None:
|
| 586 |
+
"""Print summary of the conversion."""
|
| 587 |
+
|
| 588 |
+
total_params = sum(p.numel() for p in state_dict.values())
|
| 589 |
+
|
| 590 |
+
print("\n" + "=" * 60)
|
| 591 |
+
print("DELULU Checkpoint Conversion Complete!")
|
| 592 |
+
print("=" * 60)
|
| 593 |
+
print(f"\nInput: {input_path}")
|
| 594 |
+
print(f"Output: {output_dir}")
|
| 595 |
+
print(f"\nModel Statistics:")
|
| 596 |
+
print(f" - Total parameters: {total_params:,}")
|
| 597 |
+
print(f" - Parameter tensors: {len(state_dict)}")
|
| 598 |
+
print(f"\nOutput Files:")
|
| 599 |
+
|
| 600 |
+
output_path = Path(output_dir)
|
| 601 |
+
for f in sorted(output_path.iterdir()):
|
| 602 |
+
size_mb = f.stat().st_size / 1024 / 1024
|
| 603 |
+
print(f" - {f.name}: {size_mb:.2f} MB")
|
| 604 |
+
|
| 605 |
+
print(f"\nNext Steps:")
|
| 606 |
+
print(f" 1. Test loading: python {output_dir}/load_delulu.py {output_dir}/pytorch_model.bin")
|
| 607 |
+
print(f" 2. Upload to HF: python upload_delulu_to_hf.py --checkpoint-dir {output_dir} --repo-id YOUR_USERNAME/DELULU")
|
| 608 |
+
print("=" * 60 + "\n")
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# =============================================================================
|
| 612 |
+
# CLI Interface
|
| 613 |
+
# =============================================================================
|
| 614 |
+
|
| 615 |
+
def parse_args() -> argparse.Namespace:
|
| 616 |
+
parser = argparse.ArgumentParser(
|
| 617 |
+
description="Convert DELULU checkpoint to Hugging Face format",
|
| 618 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 619 |
+
epilog="""
|
| 620 |
+
Examples:
|
| 621 |
+
# Basic conversion
|
| 622 |
+
python convert_delulu_to_hf.py \\
|
| 623 |
+
--checkpoint /path/to/epoch=45-step=400000.ckpt \\
|
| 624 |
+
--output-dir ./delulu_hf_model
|
| 625 |
+
|
| 626 |
+
# Save only safetensors format
|
| 627 |
+
python convert_delulu_to_hf.py \\
|
| 628 |
+
--checkpoint /path/to/checkpoint.ckpt \\
|
| 629 |
+
--output-dir ./delulu_hf_model \\
|
| 630 |
+
--no-bin
|
| 631 |
+
|
| 632 |
+
# Skip verification
|
| 633 |
+
python convert_delulu_to_hf.py \\
|
| 634 |
+
--checkpoint /path/to/checkpoint.ckpt \\
|
| 635 |
+
--output-dir ./delulu_hf_model \\
|
| 636 |
+
--no-verify
|
| 637 |
+
"""
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
parser.add_argument(
|
| 641 |
+
"--checkpoint", "-c",
|
| 642 |
+
type=str,
|
| 643 |
+
required=True,
|
| 644 |
+
help="Path to DELULU checkpoint (.ckpt file)"
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
parser.add_argument(
|
| 648 |
+
"--output-dir", "-o",
|
| 649 |
+
type=str,
|
| 650 |
+
required=True,
|
| 651 |
+
help="Output directory for converted model"
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
parser.add_argument(
|
| 655 |
+
"--no-safetensors",
|
| 656 |
+
action="store_true",
|
| 657 |
+
help="Don't save in safetensors format"
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
parser.add_argument(
|
| 661 |
+
"--no-bin",
|
| 662 |
+
action="store_true",
|
| 663 |
+
help="Don't save pytorch_model.bin"
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
parser.add_argument(
|
| 667 |
+
"--no-verify",
|
| 668 |
+
action="store_true",
|
| 669 |
+
help="Skip state dict verification"
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
parser.add_argument(
|
| 673 |
+
"--verbose", "-v",
|
| 674 |
+
action="store_true",
|
| 675 |
+
help="Enable verbose logging"
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
return parser.parse_args()
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
def main():
|
| 682 |
+
args = parse_args()
|
| 683 |
+
|
| 684 |
+
if args.verbose:
|
| 685 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
| 686 |
+
|
| 687 |
+
convert_checkpoint(
|
| 688 |
+
checkpoint_path=args.checkpoint,
|
| 689 |
+
output_dir=args.output_dir,
|
| 690 |
+
save_safetensors_format=not args.no_safetensors,
|
| 691 |
+
save_bin_format=not args.no_bin,
|
| 692 |
+
verify=not args.no_verify
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
if __name__ == "__main__":
|
| 697 |
+
main()
|
checksums.json → delulu_hf_model/checksums.json
RENAMED
File without changes
delulu_hf_model/config.json
ADDED
@@ -0,0 +1,60 @@
{
  "model_type": "delulu",
  "architectures": [
    "DELULUModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_delulu.DELULUConfig",
    "AutoModel": "modeling_delulu.DELULUModel"
  },
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    4,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "conv_bias": false,
  "extractor_mode": "group_norm",
  "hidden_size": 768,
  "num_hidden_layers": 12,
  "num_attention_heads": 12,
  "intermediate_size": 3072,
  "hidden_dropout": 0.1,
  "attention_dropout": 0.1,
  "final_dropout": 0.1,
  "feat_proj_dropout": 0.1,
  "layer_norm_eps": 1e-05,
  "layer_drop": 0.05,
  "num_conv_pos_embeddings": 128,
  "num_conv_pos_embedding_groups": 16,
  "sampling_rate": 16000,
  "do_stable_layer_norm": false,
  "num_clusters": 256,
  "feature_type": "redimnet",
  "num_feat_extract_layers": 7,
  "pad_token_id": 0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "torch_dtype": "float32"
}
delulu_hf_model/configuration_delulu.py
ADDED
@@ -0,0 +1,73 @@
"""DELULU Configuration"""

from transformers import PretrainedConfig


class DELULUConfig(PretrainedConfig):
    """Configuration class for DELULU model."""

    model_type = "delulu"

    def __init__(
        self,
        conv_dim=None,
        conv_kernel=None,
        conv_stride=None,
        conv_bias=False,
        extractor_mode="group_norm",
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        final_dropout=0.1,
        feat_proj_dropout=0.1,
        layer_norm_eps=1e-5,
        layer_drop=0.05,
        num_conv_pos_embeddings=128,
        num_conv_pos_embedding_groups=16,
        sampling_rate=16000,
        do_stable_layer_norm=False,
        num_clusters=256,
        feature_type="redimnet",
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )

        # DELULU conv config: [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
        self.conv_dim = conv_dim or [512, 512, 512, 512, 512, 512, 512]
        self.conv_kernel = conv_kernel or [10, 3, 3, 3, 3, 2, 2]
        self.conv_stride = conv_stride or [4, 2, 2, 2, 2, 2, 2]
        self.conv_bias = conv_bias
        self.extractor_mode = extractor_mode

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.final_dropout = final_dropout
        self.feat_proj_dropout = feat_proj_dropout
        self.layer_norm_eps = layer_norm_eps
        self.layer_drop = layer_drop

        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups

        self.sampling_rate = sampling_rate
        self.do_stable_layer_norm = do_stable_layer_norm

        self.num_clusters = num_clusters
        self.feature_type = feature_type

        self.num_feat_extract_layers = len(self.conv_dim)
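
A quick sanity check of the config file above: the defaults reproduce the DELULU conv geometry, and num_feat_extract_layers is derived from conv_dim, keeping the two consistent. A minimal sketch, assuming configuration_delulu.py is importable from the working directory:

from configuration_delulu import DELULUConfig

config = DELULUConfig()
assert config.conv_stride == [4, 2, 2, 2, 2, 2, 2]  # DELULU's modified strides
assert config.num_feat_extract_layers == 7          # follows len(conv_dim)
# Extra kwargs are forwarded to PretrainedConfig, so they survive a
# save_pretrained / from_pretrained round trip.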
model.safetensors → delulu_hf_model/model.safetensors
RENAMED
File without changes
delulu_hf_model/modeling_delulu.py
ADDED
@@ -0,0 +1,127 @@
"""DELULU Model"""

import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from .configuration_delulu import DELULUConfig

try:
    from torchaudio.models.wav2vec2 import wav2vec2_model
    TORCHAUDIO_AVAILABLE = True
except ImportError:
    TORCHAUDIO_AVAILABLE = False


class DELULUModel(PreTrainedModel):
    """
    DELULU Model for speaker-aware speech representation learning.

    Example:
    ```python
    from transformers import AutoModel
    import torch

    model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
    waveform = torch.randn(1, 16000)  # 1 second at 16kHz
    outputs = model(waveform)
    features = outputs.last_hidden_state
    ```
    """

    config_class = DELULUConfig
    base_model_prefix = "delulu"
    main_input_name = "input_values"

    def __init__(self, config: DELULUConfig):
        super().__init__(config)
        self.config = config

        if not TORCHAUDIO_AVAILABLE:
            raise ImportError("torchaudio is required. Install with: pip install torchaudio")

        # Build conv config
        conv_layer_config = list(zip(
            config.conv_dim,
            config.conv_kernel,
            config.conv_stride
        ))

        # Create torchaudio model
        self.wav2vec2 = wav2vec2_model(
            extractor_mode=config.extractor_mode,
            extractor_conv_layer_config=conv_layer_config,
            extractor_conv_bias=config.conv_bias,
            encoder_embed_dim=config.hidden_size,
            encoder_projection_dropout=config.feat_proj_dropout,
            encoder_pos_conv_kernel=config.num_conv_pos_embeddings,
            encoder_pos_conv_groups=config.num_conv_pos_embedding_groups,
            encoder_num_layers=config.num_hidden_layers,
            encoder_num_heads=config.num_attention_heads,
            encoder_attention_dropout=config.attention_dropout,
            encoder_ff_interm_features=config.intermediate_size,
            encoder_ff_interm_dropout=config.hidden_dropout,
            encoder_dropout=config.hidden_dropout,
            encoder_layer_norm_first=config.do_stable_layer_norm,
            encoder_layer_drop=config.layer_drop,
            aux_num_out=None,
        )

        self.post_init()

    def _init_weights(self, module):
        """Initialize weights."""
        pass  # Handled by torchaudio

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Args:
            input_values: Audio waveform (batch, samples) at 16kHz
            attention_mask: Optional attention mask
            output_hidden_states: Whether to return all hidden states
            return_dict: Whether to return BaseModelOutput
        """
        return_dict = return_dict if return_dict is not None else True
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False

        if input_values.dim() == 1:
            input_values = input_values.unsqueeze(0)

        lengths = None
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=-1)

        if output_hidden_states:
            features, _ = self.wav2vec2.extract_features(input_values, lengths=lengths)
            hidden_states = tuple(features)
            last_hidden_state = features[-1]
        else:
            last_hidden_state, _ = self.wav2vec2(input_values, lengths=lengths)
            hidden_states = None

        if not return_dict:
            return (last_hidden_state, hidden_states) if hidden_states else (last_hidden_state,)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
        )

    def extract_features(self, input_values: torch.Tensor):
        """Extract features from all layers."""
        if input_values.dim() == 1:
            input_values = input_values.unsqueeze(0)
        features, _ = self.wav2vec2.extract_features(input_values)
        return tuple(features)

    @classmethod
    def _load_pretrained_model_low_mem(cls, *args, **kwargs):
        """Override to handle custom loading."""
        return super()._load_pretrained_model_low_mem(*args, **kwargs)
pytorch_model.bin → delulu_hf_model/pytorch_model.bin
RENAMED
File without changes

upload_metadata.json → delulu_hf_model/upload_metadata.json
RENAMED
File without changes
load_delulu.py
DELETED
@@ -1,94 +0,0 @@
"""
DELULU Model - Minimal Loading Example

This file shows how to load DELULU weights with torchaudio.
For full Hugging Face Transformers integration, see the modeling_delulu.py file.
"""

import torch
from torchaudio.models.wav2vec2 import wav2vec2_model

# DELULU configuration
DELULU_CONV_CONFIG = [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

def load_delulu(checkpoint_path: str = None, weights_path: str = None):
    """
    Load DELULU model.

    Args:
        checkpoint_path: Path to original .ckpt file (PyTorch Lightning format)
        weights_path: Path to pytorch_model.bin (Hugging Face format)

    Returns:
        DELULU model ready for inference
    """
    model = wav2vec2_model(
        extractor_mode="group_norm",
        extractor_conv_layer_config=DELULU_CONV_CONFIG,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_projection_dropout=0.1,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_attention_dropout=0.1,
        encoder_ff_interm_features=3072,
        encoder_ff_interm_dropout=0.1,
        encoder_dropout=0.1,
        encoder_layer_norm_first=False,
        encoder_layer_drop=0.05,
        aux_num_out=None,
    )

    if checkpoint_path:
        # Load from original Lightning checkpoint
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)

        # Clean keys
        new_state_dict = {}
        for k, v in state_dict.items():
            if "model.wav2vec2" in k:
                new_state_dict[k.replace("model.wav2vec2.", "")] = v
            elif not k.startswith("aux"):
                new_state_dict[k] = v

        model.load_state_dict(new_state_dict, strict=False)

    elif weights_path:
        # Load from Hugging Face format
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)

    return model


def extract_features(model, waveform: torch.Tensor) -> torch.Tensor:
    """
    Extract speaker features from audio waveform.

    Args:
        model: DELULU model
        waveform: Audio tensor of shape (batch, samples) at 16kHz

    Returns:
        Features of shape (batch, time, 768)
    """
    model.eval()
    with torch.no_grad():
        features, _ = model.extract_features(waveform)
        # Return last layer features
        return features[-1]


if __name__ == "__main__":
    # Example usage
    import sys

    if len(sys.argv) > 1:
        model = load_delulu(weights_path=sys.argv[1])
        print(f"Model loaded successfully!")
        print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
    else:
        print("Usage: python load_delulu.py path/to/pytorch_model.bin")
modeling_delulu.py
CHANGED
(updated file shown in full)
"""
DELULU Model

DELULU (Discriminative Embedding Learning Using Latent Units) is a speaker-aware
self-supervised speech foundational model based on HuBERT architecture.

Paper: https://arxiv.org/abs/2510.17662
Authors: Massa Baali, Rita Singh, Bhiksha Raj

This implementation wraps torchaudio's wav2vec2_model for compatibility with
Hugging Face's AutoModel interface.
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from dataclasses import dataclass

from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

from .configuration_delulu import DELULUConfig

# Try to import torchaudio
try:
    from torchaudio.models.wav2vec2 import wav2vec2_model
    TORCHAUDIO_AVAILABLE = True
except ImportError:
    TORCHAUDIO_AVAILABLE = False


@dataclass
class DELULUOutput(BaseModelOutput):
    """
    Output class for DELULU model.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Attention weights (not available for torchaudio backend).
        extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
            Features from the convolutional feature extractor.
    """
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    extract_features: Optional[torch.FloatTensor] = None


class DELULUModel(PreTrainedModel):
    """
    DELULU Model for speaker-aware speech representation learning.

    This model wraps torchaudio's wav2vec2_model with DELULU's custom configuration
    (modified convolutional strides for 16ms frame shift).

    Example:
    ```python
    from transformers import AutoModel
    import torch

    # Load model
    model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
    model.eval()

    # Process audio (16kHz, mono)
    waveform = torch.randn(1, 16000)  # 1 second of audio

    with torch.no_grad():
        outputs = model(waveform)
        features = outputs.last_hidden_state  # [1, T, 768]

    # For speaker verification, use mean pooling
    speaker_embedding = features.mean(dim=1)  # [1, 768]
    ```
    """

    config_class = DELULUConfig
    base_model_prefix = "delulu"
    main_input_name = "input_values"
    supports_gradient_checkpointing = False

    def __init__(self, config: DELULUConfig):
        super().__init__(config)
        self.config = config

        if not TORCHAUDIO_AVAILABLE:
            raise ImportError(
                "torchaudio is required for DELULU model. "
                "Install with: pip install torchaudio"
            )

        # Build convolutional layer config from DELULU config
        conv_layer_config = list(zip(
            config.conv_dim,
            config.conv_kernel,
            config.conv_stride
        ))

        # Create the underlying torchaudio model
        self.wav2vec2 = wav2vec2_model(
            extractor_mode=config.extractor_mode,
            extractor_conv_layer_config=conv_layer_config,
            extractor_conv_bias=config.conv_bias,
            encoder_embed_dim=config.hidden_size,
            encoder_projection_dropout=config.feat_proj_dropout,
            encoder_pos_conv_kernel=config.num_conv_pos_embeddings,
            encoder_pos_conv_groups=config.num_conv_pos_embedding_groups,
            encoder_num_layers=config.num_hidden_layers,
            encoder_num_heads=config.num_attention_heads,
            encoder_attention_dropout=config.attention_dropout,
            encoder_ff_interm_features=config.intermediate_size,
            encoder_ff_interm_dropout=config.hidden_dropout,
            encoder_dropout=config.hidden_dropout,
            encoder_layer_norm_first=config.do_stable_layer_norm,
            encoder_layer_drop=config.layer_drop,
            aux_num_out=None,
        )

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, DELULUOutput]:
        """
        Forward pass of DELULU model.

        Args:
            input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Raw audio waveform at 16kHz sampling rate.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding. Not used in current implementation.
            output_hidden_states (`bool`, *optional*):
                Whether to return all hidden states.
            output_attentions (`bool`, *optional*):
                Whether to return attention weights. Not supported with torchaudio backend.
            return_dict (`bool`, *optional*):
                Whether to return a `DELULUOutput` instead of a tuple.

        Returns:
            `DELULUOutput` or `tuple`: Model outputs.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None
            else self.config.output_hidden_states if hasattr(self.config, 'output_hidden_states')
            else False
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict if hasattr(self.config, 'use_return_dict') else True

        # Ensure input is 2D: (batch, samples)
        if input_values.dim() == 1:
            input_values = input_values.unsqueeze(0)

        # Handle lengths for torchaudio model
        lengths = None
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=-1)

        # Extract features using torchaudio model
        if output_hidden_states:
            # Get all layer outputs
            features, lengths_out = self.wav2vec2.extract_features(
                input_values,
                lengths=lengths
            )
            # features is a list of tensors, one per layer
            hidden_states = tuple(features)
            last_hidden_state = features[-1]
        else:
            # Just get final output
            outputs, lengths_out = self.wav2vec2(input_values, lengths=lengths)
            last_hidden_state = outputs
            hidden_states = None

        # Get convolutional features (before transformer)
        extract_features = self.wav2vec2.feature_extractor(input_values, lengths)[0]

        if not return_dict:
            outputs = (last_hidden_state,)
            if output_hidden_states:
                outputs = outputs + (hidden_states,)
            return outputs

        return DELULUOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
            attentions=None,  # torchaudio doesn't expose attention weights
            extract_features=extract_features,
        )

    def extract_features(
        self,
        input_values: torch.Tensor,
        lengths: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, ...]:
        """
        Extract features from all layers.

        Args:
            input_values: Audio waveform of shape (batch, samples)
            lengths: Optional lengths for each sample in batch

        Returns:
            Tuple of tensors, one per layer (including CNN output)
        """
        if input_values.dim() == 1:
            input_values = input_values.unsqueeze(0)

        features, _ = self.wav2vec2.extract_features(input_values, lengths=lengths)
        return tuple(features)

    def get_speaker_embedding(
        self,
        input_values: torch.Tensor,
        pooling: str = "mean"
    ) -> torch.Tensor:
        """
        Extract speaker embedding from audio.

        Args:
            input_values: Audio waveform of shape (batch, samples)
            pooling: Pooling method - "mean", "max", or "first"

        Returns:
            Speaker embedding of shape (batch, hidden_size)
        """
        outputs = self.forward(input_values, return_dict=True)
        features = outputs.last_hidden_state

        if pooling == "mean":
            return features.mean(dim=1)
        elif pooling == "max":
            return features.max(dim=1).values
        elif pooling == "first":
            return features[:, 0, :]
        else:
            raise ValueError(f"Unknown pooling method: {pooling}")

    def _init_weights(self, module):
        """Initialize weights - mostly handled by torchaudio."""
        pass


class DELULUForSequenceClassification(PreTrainedModel):
    """
    DELULU with a classification head for speaker verification and other tasks.

    Example:
    ```python
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "cmu-mlsp/DELULU",
        trust_remote_code=True,
        num_labels=1251  # Number of speakers in VoxCeleb2
    )
    ```
    """

    config_class = DELULUConfig
    base_model_prefix = "delulu"

    def __init__(self, config: DELULUConfig):
        super().__init__(config)

        self.delulu = DELULUModel(config)
        self.projector = nn.Linear(config.hidden_size, config.hidden_size)

        num_labels = getattr(config, 'num_labels', None)
        if num_labels:
            self.classifier = nn.Linear(config.hidden_size, num_labels)
        else:
            self.classifier = None

        self.post_init()

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else True

        outputs = self.delulu(
            input_values,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Pool features
        hidden_states = outputs.last_hidden_state
        pooled = hidden_states.mean(dim=1)

        # Project
        embeddings = self.projector(pooled)

        # Classify if head exists
        logits = None
        if self.classifier is not None:
            logits = self.classifier(embeddings)

        loss = None
        if labels is not None and logits is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits, embeddings) + (outputs.last_hidden_state,)
            return ((loss,) + output) if loss is not None else output

        return {
            "loss": loss,
            "logits": logits,
            "embeddings": embeddings,
            "last_hidden_state": outputs.last_hidden_state,
        }


# Register for auto classes
DELULUConfig.register_for_auto_class()
DELULUModel.register_for_auto_class("AutoModel")
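
The new get_speaker_embedding helper reduces a speaker verification trial to a cosine score between two pooled embeddings. A usage sketch, assuming the cmu-mlsp/DELULU repo id from the docstring; the random tensors stand in for two real 16 kHz utterances:

import torch
import torch.nn.functional as F
from transformers import AutoModel

model = AutoModel.from_pretrained("cmu-mlsp/DELULU", trust_remote_code=True)
model.eval()

wav_a = torch.randn(1, 3 * 16000)  # placeholder for enrollment utterance
wav_b = torch.randn(1, 3 * 16000)  # placeholder for test utterance

with torch.no_grad():
    emb_a = model.get_speaker_embedding(wav_a, pooling="mean")  # [1, 768]
    emb_b = model.get_speaker_embedding(wav_b, pooling="mean")

# Threshold this score to accept/reject the trial; the threshold is tuned on
# a development set (e.g. at the EER operating point).
score = F.cosine_similarity(emb_a, emb_b).item()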
preprocessor_config.json
DELETED
@@ -1,9 +0,0 @@
{
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": true,
  "sampling_rate": 16000
}
upload_delulu_to_hf.py
ADDED
|
@@ -0,0 +1,593 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
"""
DELULU Model Upload Script for Hugging Face Hub
================================================

Production-ready script to upload DELULU (Discriminative Embedding Learning Using
Latent Units) model checkpoints to Hugging Face with safety checks, versioning,
and best practices.

Author: Massa Baali
Model: DELULU - Speaker-Aware Self-Supervised Speech Foundational Model
Paper: https://arxiv.org/abs/2510.17662

Usage:
    python upload_delulu_to_hf.py --checkpoint-dir ./checkpoints --repo-id username/DELULU

    # With all options:
    python upload_delulu_to_hf.py \
        --checkpoint-dir ./checkpoints \
        --repo-id username/DELULU \
        --version v1.0.0 \
        --tags speaker-verification speech-ssl hubert \
        --private \
        --dry-run
"""

import argparse
import hashlib
import json
import logging
import sys
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

try:
    from huggingface_hub import (
        HfApi,
        create_repo,
        upload_folder,
        whoami,
        RepoUrl,
    )
    from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
except ImportError:
    print("Error: huggingface_hub not installed. Install with: pip install huggingface_hub")
    sys.exit(1)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


# =============================================================================
# Configuration
# =============================================================================

@dataclass
class DELULUConfig:
    """Configuration for the DELULU model architecture.

    DELULU uses the HuBERT architecture with modified convolutional feature
    extractor strides for a 16 ms frame shift (optimized for speaker
    verification).
    """
    # Model architecture (HuBERT-based)
    model_type: str = "hubert"

    # Modified convolutional feature extractor configuration
    # Standard HuBERT: [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    # DELULU:          [(512, 10, 4)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    conv_dim: list = field(default_factory=lambda: [512, 512, 512, 512, 512, 512, 512])
    conv_kernel: list = field(default_factory=lambda: [10, 3, 3, 3, 3, 2, 2])
    conv_stride: list = field(default_factory=lambda: [4, 2, 2, 2, 2, 2, 2])  # Key difference!
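    # Stride note: the total downsampling factor is the product of the strides.
    # DELULU: 4 * 2**6 = 256 samples per frame -> 256 / 16000 Hz = 16 ms.
    # Standard HuBERT: 5 * 2**6 = 320 samples per frame -> 20 ms.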

    # Transformer configuration
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    intermediate_size: int = 3072

    # Training configuration
    frame_shift_ms: int = 16  # Optimal for speaker verification
    sampling_rate: int = 16000

    # Clustering configuration (ReDimNet-guided)
    num_clusters: int = 256
    cluster_feature_dim: int = 2304  # ReDimNet frame-level embedding dimension

    def to_dict(self) -> dict:
        """Convert the config to a dictionary for serialization."""
        return {
            "model_type": self.model_type,
            "conv_dim": self.conv_dim,
            "conv_kernel": self.conv_kernel,
            "conv_stride": self.conv_stride,
            "hidden_size": self.hidden_size,
            "num_hidden_layers": self.num_hidden_layers,
            "num_attention_heads": self.num_attention_heads,
            "intermediate_size": self.intermediate_size,
            "frame_shift_ms": self.frame_shift_ms,
            "sampling_rate": self.sampling_rate,
            "num_clusters": self.num_clusters,
            "cluster_feature_dim": self.cluster_feature_dim,
            "architectures": ["DELULUModel"],
            "auto_map": {
                "AutoModel": "modeling_delulu.DELULUModel",
                "AutoConfig": "configuration_delulu.DELULUConfig"
            }
        }
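
    # Note: the "auto_map" entries point transformers at the custom classes in
    # this repo (modeling_delulu.py / configuration_delulu.py), which is why
    # downstream loading needs trust_remote_code=True.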


@dataclass
class UploadConfig:
    """Configuration for the upload process."""
    checkpoint_dir: Path
    repo_id: str
    version: Optional[str] = None
    tags: list = field(default_factory=list)
    private: bool = False
    dry_run: bool = False
    create_if_missing: bool = True
    commit_message: Optional[str] = None

    # Safety settings
    verify_checksums: bool = True
    max_file_size_gb: float = 10.0
    required_files: list = field(default_factory=lambda: ["pytorch_model.bin", "config.json"])

    def __post_init__(self):
        self.checkpoint_dir = Path(self.checkpoint_dir)


# =============================================================================
# Safety Checks
# =============================================================================

class SafetyChecker:
    """Performs safety checks before upload."""

    def __init__(self, config: UploadConfig):
        self.config = config
        self.errors: list[str] = []
        self.warnings: list[str] = []

    def check_all(self) -> bool:
        """Run all safety checks. Returns True if all pass."""
        self._check_directory_exists()
        self._check_required_files()
        self._check_file_sizes()
        self._check_no_sensitive_data()
        self._check_checkpoint_integrity()

        # Log results
        for warning in self.warnings:
            logger.warning(f"⚠️ {warning}")
        for error in self.errors:
            logger.error(f"❌ {error}")

        if self.errors:
            logger.error(f"Safety checks failed with {len(self.errors)} error(s)")
            return False

        logger.info("✅ All safety checks passed")
        return True
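
    # Only entries in self.errors abort the upload; self.warnings are logged
    # but do not block it.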

    def _check_directory_exists(self):
        """Verify the checkpoint directory exists and is accessible."""
        if not self.config.checkpoint_dir.exists():
            self.errors.append(f"Checkpoint directory not found: {self.config.checkpoint_dir}")
        elif not self.config.checkpoint_dir.is_dir():
            self.errors.append(f"Path is not a directory: {self.config.checkpoint_dir}")

    def _check_required_files(self):
        """Check that required model files exist."""
        if not self.config.checkpoint_dir.exists():
            return

        for required_file in self.config.required_files:
            file_path = self.config.checkpoint_dir / required_file
            # Also check for a .safetensors variant
            safetensors_variant = required_file.replace(".bin", ".safetensors")
            safetensors_path = self.config.checkpoint_dir / safetensors_variant

            if not file_path.exists() and not safetensors_path.exists():
                # Special handling for model weights - either .bin or .safetensors is fine
                if "model" in required_file:
                    self.warnings.append(
                        f"Model file not found: {required_file} or {safetensors_variant}. "
                        "Will look for alternative formats."
                    )
                else:
                    self.errors.append(f"Required file not found: {required_file}")

    def _check_file_sizes(self):
        """Verify no files exceed the maximum size limit."""
        if not self.config.checkpoint_dir.exists():
            return

        max_size_bytes = self.config.max_file_size_gb * 1024 * 1024 * 1024

        for file_path in self.config.checkpoint_dir.rglob("*"):
            if file_path.is_file():
                size = file_path.stat().st_size
                if size > max_size_bytes:
                    self.errors.append(
                        f"File exceeds {self.config.max_file_size_gb}GB limit: "
                        f"{file_path.name} ({size / 1024 / 1024 / 1024:.2f}GB)"
                    )

    def _check_no_sensitive_data(self):
        """Check for potentially sensitive files that shouldn't be uploaded."""
        # Patterns are lowercase because they are compared against the
        # lower-cased file name below (".DS_Store" would otherwise never match).
        sensitive_patterns = [
            ".env", ".secret", "credentials", "password", "api_key", "token",
            ".git", "__pycache__", ".pyc", ".ds_store"
        ]

        if not self.config.checkpoint_dir.exists():
            return

        for file_path in self.config.checkpoint_dir.rglob("*"):
            file_name = file_path.name.lower()
            for pattern in sensitive_patterns:
                if pattern in file_name:
                    self.warnings.append(
                        f"Potentially sensitive file detected: {file_path.name}. "
                        "Consider adding to .gitignore or removing before upload."
                    )
                    break

    def _check_checkpoint_integrity(self):
        """Basic integrity check for PyTorch checkpoint files."""
        if not self.config.checkpoint_dir.exists():
            return

        try:
            import torch

            for file_path in self.config.checkpoint_dir.glob("*.bin"):
                try:
                    # Fully deserialize the checkpoint to verify it is readable.
                    # weights_only=True avoids executing arbitrary pickled code;
                    # plain state dicts (tensors only) load fine with it.
                    torch.load(file_path, map_location="cpu", weights_only=True)
                    logger.info(f"✓ Checkpoint integrity verified: {file_path.name}")
                except Exception as e:
                    self.errors.append(f"Corrupted checkpoint file: {file_path.name} - {e}")
        except ImportError:
            self.warnings.append("PyTorch not installed, skipping checkpoint integrity check")


# =============================================================================
# Checksum Utilities
# =============================================================================

def compute_file_checksum(file_path: Path, algorithm: str = "sha256") -> str:
    """Compute a checksum for a file."""
    hash_func = hashlib.new(algorithm)

    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            hash_func.update(chunk)

    return hash_func.hexdigest()
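
# For the default "sha256", the hex digest matches the output of the standard
# `sha256sum <file>` CLI, so artifacts can also be cross-checked from the
# shell. Streaming in 8 KiB chunks keeps memory flat even for multi-GB files.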


def generate_checksums(directory: Path) -> dict:
    """Generate checksums for all files in a directory."""
    checksums = {}

    for file_path in directory.rglob("*"):
        if file_path.is_file():
            relative_path = file_path.relative_to(directory)
            checksums[str(relative_path)] = {
                "sha256": compute_file_checksum(file_path, "sha256"),
                "size_bytes": file_path.stat().st_size
            }

    return checksums
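
# Resulting checksums.json shape, keyed by path relative to `directory`
# (one entry per file):
#   {"pytorch_model.bin": {"sha256": "<64 hex chars>", "size_bytes": <int>}, ...}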


def save_checksums(checksums: dict, output_path: Path):
    """Save checksums to a JSON file."""
    with open(output_path, "w") as f:
        json.dump(checksums, f, indent=2)
    logger.info(f"Checksums saved to: {output_path}")


# =============================================================================
# Upload Manager
# =============================================================================

class DELULUUploader:
    """Handles uploading the DELULU model to the Hugging Face Hub."""

    def __init__(self, upload_config: UploadConfig):
        self.config = upload_config
        self.api = HfApi()
        self.model_config = DELULUConfig()

    def authenticate(self) -> bool:
        """Verify authentication with the Hugging Face Hub."""
        try:
            user_info = whoami()
            logger.info(f"✅ Authenticated as: {user_info['name']}")
            return True
        except Exception as e:
            logger.error(f"❌ Authentication failed: {e}")
            logger.info("Run 'huggingface-cli login' or set the HF_TOKEN environment variable")
            return False

    def prepare_upload_directory(self) -> Path:
        """Prepare files for upload, including config and checksums."""
        upload_dir = self.config.checkpoint_dir

        # Generate and save config.json if not present
        config_path = upload_dir / "config.json"
        if not config_path.exists():
            logger.info("Generating config.json...")
            with open(config_path, "w") as f:
                json.dump(self.model_config.to_dict(), f, indent=2)

        # Generate checksums
        if self.config.verify_checksums:
            logger.info("Generating checksums...")
            checksums = generate_checksums(upload_dir)
            save_checksums(checksums, upload_dir / "checksums.json")

        # Create upload metadata (timestamp in timezone-aware UTC)
        metadata = {
            "upload_timestamp": datetime.now(timezone.utc).isoformat(),
            "version": self.config.version,
            "uploader_script_version": "1.0.0",
            "model_type": "DELULU",
            "base_architecture": "HuBERT"
        }

        # Written after checksum generation, so this file is not itself
        # listed in checksums.json.
        metadata_path = upload_dir / "upload_metadata.json"
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

        return upload_dir

    def create_or_verify_repo(self) -> bool:
        """Create the repository if it doesn't exist, or verify access."""
        try:
            # Check if the repo exists
            self.api.repo_info(repo_id=self.config.repo_id, repo_type="model")
            logger.info(f"✅ Repository exists: {self.config.repo_id}")
            return True

        except RepositoryNotFoundError:
            if self.config.create_if_missing:
                logger.info(f"Creating repository: {self.config.repo_id}")

                if self.config.dry_run:
                    logger.info("[DRY RUN] Would create repository")
                    return True

                try:
                    repo_url: RepoUrl = create_repo(
                        repo_id=self.config.repo_id,
                        repo_type="model",
                        private=self.config.private,
                        exist_ok=True
                    )
                    logger.info(f"✅ Repository created: {repo_url}")
                    return True
                except HfHubHTTPError as e:
                    logger.error(f"❌ Failed to create repository: {e}")
                    return False
            else:
                logger.error(f"❌ Repository not found: {self.config.repo_id}")
                return False

        except Exception as e:
            logger.error(f"❌ Error accessing repository: {e}")
            return False

    def upload(self) -> bool:
        """Execute the upload process."""
        logger.info("=" * 60)
        logger.info("DELULU Model Upload to Hugging Face Hub")
        logger.info("=" * 60)

        # Step 1: Authenticate
        if not self.authenticate():
            return False

        # Step 2: Safety checks
        safety_checker = SafetyChecker(self.config)
        if not safety_checker.check_all():
            return False

        # Step 3: Create/verify repository
        if not self.create_or_verify_repo():
            return False

        # Step 4: Prepare upload directory
        upload_dir = self.prepare_upload_directory()

        # Step 5: Generate commit message
        commit_message = self.config.commit_message or self._generate_commit_message()

        # Step 6: Execute upload
        if self.config.dry_run:
            logger.info("[DRY RUN] Would upload the following files:")
            for file_path in upload_dir.rglob("*"):
                if file_path.is_file():
                    size_mb = file_path.stat().st_size / 1024 / 1024
                    logger.info(f" - {file_path.relative_to(upload_dir)} ({size_mb:.2f} MB)")
            logger.info(f"[DRY RUN] Commit message: {commit_message}")
            return True

        logger.info("Starting upload...")
        try:
            upload_folder(
                folder_path=str(upload_dir),
                repo_id=self.config.repo_id,
                repo_type="model",
                commit_message=commit_message,
                ignore_patterns=[
                    "*.pyc", "__pycache__", ".git", ".DS_Store",
                    "*.log", "wandb", "runs"
                ]
            )

            logger.info("✅ Upload complete!")
            logger.info(f"View model at: https://huggingface.co/{self.config.repo_id}")

            # Create a version tag if specified
            if self.config.version:
                self._create_version_tag()

            return True

        except Exception as e:
            logger.error(f"❌ Upload failed: {e}")
            return False

    def _generate_commit_message(self) -> str:
        """Generate a descriptive commit message."""
        header = "Upload DELULU model checkpoint"
        if self.config.version:
            header += f" (version {self.config.version})"

        body_lines = [
            "Model: DELULU - Speaker-Aware Self-Supervised Speech Model",
            "Architecture: HuBERT with modified stride configuration",
            "Frame shift: 16ms (optimized for speaker verification)",
        ]
        if self.config.tags:
            body_lines.append(f"Tags: {', '.join(self.config.tags)}")

        # Join body lines with newlines so each fact lands on its own line.
        return header + "\n\n" + "\n".join(body_lines)

    def _create_version_tag(self):
        """Create a Git tag for the version."""
        try:
            self.api.create_tag(
                repo_id=self.config.repo_id,
                tag=self.config.version,
                tag_message=f"DELULU {self.config.version}",
                repo_type="model"
            )
            logger.info(f"✅ Created version tag: {self.config.version}")
        except Exception as e:
            logger.warning(f"⚠️ Could not create version tag: {e}")


# =============================================================================
# CLI Interface
# =============================================================================

def parse_args() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Upload DELULU model checkpoints to Hugging Face Hub",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic upload
  python upload_delulu_to_hf.py --checkpoint-dir ./checkpoints --repo-id username/DELULU

  # Upload with version and tags
  python upload_delulu_to_hf.py \\
    --checkpoint-dir ./checkpoints \\
    --repo-id username/DELULU \\
    --version v1.0.0 \\
    --tags speaker-verification speech-ssl hubert

  # Dry run (no actual upload)
  python upload_delulu_to_hf.py --checkpoint-dir ./checkpoints --repo-id username/DELULU --dry-run

  # Private repository
  python upload_delulu_to_hf.py --checkpoint-dir ./checkpoints --repo-id username/DELULU --private
"""
    )

    # Required arguments
    parser.add_argument(
        "--checkpoint-dir", "-c",
        type=str,
        required=True,
        help="Path to directory containing model checkpoints"
    )
    parser.add_argument(
        "--repo-id", "-r",
        type=str,
        required=True,
        help="Hugging Face repository ID (e.g., username/DELULU)"
    )

    # Optional arguments
    parser.add_argument(
        "--version", "-v",
        type=str,
        default=None,
        help="Version tag for this upload (e.g., v1.0.0)"
    )
    parser.add_argument(
        "--tags", "-t",
        nargs="+",
        default=["speaker-verification", "speech-ssl", "hubert", "self-supervised"],
        help="Tags to add to the model (space-separated)"
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Create as private repository"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Simulate upload without actually uploading"
    )
    parser.add_argument(
        "--commit-message", "-m",
        type=str,
        default=None,
        help="Custom commit message"
    )
    parser.add_argument(
        "--no-verify-checksums",
        action="store_true",
        help="Skip checksum generation and verification"
    )
    parser.add_argument(
        "--max-file-size",
        type=float,
        default=10.0,
        help="Maximum file size in GB (default: 10.0)"
    )
    parser.add_argument(
        "--no-create",
        action="store_true",
        help="Don't create repository if it doesn't exist"
    )

    return parser.parse_args()


def main():
    """Main entry point."""
    args = parse_args()

    # Create upload configuration
    upload_config = UploadConfig(
        checkpoint_dir=args.checkpoint_dir,
        repo_id=args.repo_id,
        version=args.version,
        tags=args.tags,
        private=args.private,
        dry_run=args.dry_run,
        commit_message=args.commit_message,
        verify_checksums=not args.no_verify_checksums,
        max_file_size_gb=args.max_file_size,
        create_if_missing=not args.no_create
    )

    # Create uploader and execute
    uploader = DELULUUploader(upload_config)
    success = uploader.upload()

    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
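
A minimal sketch of the consumer side, assuming the placeholder repo id username/DELULU from the usage examples above and that checksums.json sits at the snapshot root as this script uploads it: a downloaded snapshot can be verified against the file-to-digest map that generate_checksums() produces.

    import hashlib
    import json
    from pathlib import Path

    from huggingface_hub import snapshot_download

    local_dir = Path(snapshot_download("username/DELULU"))  # placeholder repo id
    checksums = json.loads((local_dir / "checksums.json").read_text())

    for rel_path, expected in checksums.items():
        h = hashlib.sha256()
        with open(local_dir / rel_path, "rb") as f:
            # Same 8 KiB streaming pattern as compute_file_checksum() above
            for chunk in iter(lambda: f.read(8192), b""):
                h.update(chunk)
        assert h.hexdigest() == expected["sha256"], f"Checksum mismatch: {rel_path}"
    print("All checksums match.")

And since the generated config.json carries an auto_map pointing at modeling_delulu.DELULUModel, loading the uploaded checkpoint should look roughly like:

    from transformers import AutoModel

    model = AutoModel.from_pretrained("username/DELULU", trust_remote_code=True)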