klemenk committed
Commit f72bc8c · verified · 1 parent: 7819f34

Upload distilled speech model

README.md ADDED
@@ -0,0 +1,79 @@
+ ---
+ license: apache-2.0
+ language:
+ - en
+ tags:
+ - speech
+ - audio
+ - data2vec
+ - distillation
+ - feature-extraction
+ library_name: transformers
+ pipeline_tag: feature-extraction
+ ---
+
+ # Distilled Speech Encoder
+
+ A Data2Vec-style bidirectional speech encoder trained via distillation from AuriStream models.
+
+ ## Model Details
+
+ - **Architecture**: 12-layer transformer with RoPE positional encoding
+ - **Hidden size**: 768
+ - **Attention heads**: 12
+ - **Parameters**: ~85M
+ - **Teacher model**: `TuKoResearch/AuriStream100M_40Pred_BigAudioDataset_500k`
+ - **Training step**: 100000
+ - **Input**: 16kHz raw audio waveform
+ - **Output**: 50Hz (nominal) contextualized representations (768-dim)
+
+ ## Usage
+
+ ```python
+ from transformers import AutoModel, AutoFeatureExtractor
+ import torch
+
+ # Load model and feature extractor
+ model = AutoModel.from_pretrained("TuKoResearch/AuriStreamDistill_100M40PredTeacher_librispeech960", trust_remote_code=True)
+ feature_extractor = AutoFeatureExtractor.from_pretrained("TuKoResearch/AuriStreamDistill_100M40PredTeacher_librispeech960", trust_remote_code=True)
+
+ # Prepare audio (16kHz, mono)
+ audio = torch.randn(16000)  # 1 second of audio
+
+ # Extract features
+ inputs = feature_extractor(audio, return_tensors="pt", sample_rate=16000)
+ outputs = model(inputs.input_values, output_hidden_states=True)
+
+ # Get representations
+ last_hidden = outputs.last_hidden_state  # (1, 49, 768) for 1 second (~50 frames/s; the unpadded convs trim one frame)
+ all_hidden = outputs.hidden_states  # Tuple of 13 tensors
+ ```
+
+ ## Hidden States
+
+ When `output_hidden_states=True`, the model returns hidden states from all layers:
+ - `hidden_states[0]`: Feature projection output (after conv encoder + projection)
+ - `hidden_states[1]` to `hidden_states[12]`: Transformer layer outputs
+ - `hidden_states[12]`: Final layer output (`last_hidden_state` is this tensor with the final layer norm applied)
+
+ This makes the model suitable for linear probing experiments at different layers (see the sketch after this file).
+
+ ## Training
+
+ This model was trained using Data2Vec-style distillation:
+ 1. A frozen AuriStream teacher model generates target representations.
+ 2. The student sees masked audio and learns to predict the teacher representations.
+ 3. The loss is computed only on masked positions.
+
+ ## Citation
+
+ If you use this model, please cite:
+
+ ```bibtex
+ @misc{distilled_speech_encoder,
+   title={Distilled Speech Encoder},
+   author={TuKo Research},
+   year={2025},
+   url={https://huggingface.co/TuKoResearch/AuriStreamDistill_100M40PredTeacher_librispeech960}
+ }
+ ```
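The README's "Hidden States" section suggests layer-wise linear probing; a minimal, self-contained sketch of what that could look like is below. `layer_index`, `num_classes`, and the random `frame_labels` are hypothetical stand-ins for a real probing task, and the single optimizer step is only to show the wiring:

```python
import torch
import torch.nn as nn
from transformers import AutoModel, AutoFeatureExtractor

repo = "TuKoResearch/AuriStreamDistill_100M40PredTeacher_librispeech960"
model = AutoModel.from_pretrained(repo, trust_remote_code=True).eval()
feature_extractor = AutoFeatureExtractor.from_pretrained(repo, trust_remote_code=True)

layer_index = 6   # which of the 13 hidden states to probe (hypothetical choice)
num_classes = 40  # e.g. a phone inventory; placeholder for your task

probe = nn.Linear(768, num_classes)
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)

audio = torch.randn(16000)  # stand-in for a real 16kHz utterance
inputs = feature_extractor(audio, return_tensors="pt", sample_rate=16000)

with torch.no_grad():  # the encoder stays frozen; only the probe trains
    feats = model(inputs.input_values, output_hidden_states=True).hidden_states[layer_index]

frame_labels = torch.randint(num_classes, (feats.shape[0], feats.shape[1]))  # placeholder targets
loss = nn.functional.cross_entropy(probe(feats).flatten(0, 1), frame_labels.flatten())
loss.backward()
optimizer.step()
```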
__pycache__/configuration_distilled_speech.cpython-311.pyc ADDED
Binary file (6.37 kB).
 
__pycache__/feature_extraction_distilled_speech.cpython-311.pyc ADDED
Binary file (7.38 kB).
 
__pycache__/modeling_distilled_speech.cpython-311.pyc ADDED
Binary file (28.8 kB).
 
config.json ADDED
@@ -0,0 +1,56 @@
+ {
+   "hidden_size": 768,
+   "num_hidden_layers": 12,
+   "num_attention_heads": 12,
+   "intermediate_size": 3072,
+   "hidden_dropout": 0.1,
+   "attention_dropout": 0.1,
+   "activation_dropout": 0.0,
+   "layer_norm_eps": 1e-05,
+   "feat_extract_norm": "group",
+   "feat_extract_activation": "gelu",
+   "feat_proj_dropout": 0.0,
+   "use_rope": true,
+   "rope_theta": 10000.0,
+   "sample_rate": 16000,
+   "teacher_model_name": "TuKoResearch/AuriStream100M_40Pred_BigAudioDataset_500k",
+   "teacher_hidden_size": 768,
+   "conv_dim": [
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512
+   ],
+   "conv_stride": [
+     5,
+     2,
+     2,
+     2,
+     2,
+     2,
+     2
+   ],
+   "conv_kernel": [
+     10,
+     3,
+     3,
+     3,
+     3,
+     2,
+     2
+   ],
+   "conv_bias": false,
+   "model_type": "distilled_speech",
+   "auto_map": {
+     "AutoConfig": "configuration_distilled_speech.DistilledSpeechConfig",
+     "AutoModel": "modeling_distilled_speech.DistilledSpeechModel",
+     "AutoFeatureExtractor": "feature_extraction_distilled_speech.DistilledSpeechFeatureExtractor"
+   },
+   "architectures": [
+     "DistilledSpeechModel"
+   ],
+   "training_step": 100000
+ }
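The `conv_stride` values above multiply to 320, which is where the nominal 50Hz output rate comes from. A quick check of the geometry, using only the numbers in this config (the exact frame count comes out one lower than `T // 320` because the convolutions are unpadded):

```python
import math

conv_stride = [5, 2, 2, 2, 2, 2, 2]
conv_kernel = [10, 3, 3, 3, 3, 2, 2]

stride_product = math.prod(conv_stride)  # 5 * 2**6 = 320
print(16000 // stride_product)           # 50 -> nominal output Hz for 16kHz input

# Exact frame count: each unpadded Conv1d maps length L to (L - kernel) // stride + 1
length = 16000  # 1 second of 16kHz audio
for k, s in zip(conv_kernel, conv_stride):
    length = (length - k) // s + 1
print(length)                            # 49 frames for 1 second
```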
configuration_distilled_speech.py ADDED
@@ -0,0 +1,148 @@
+ """
+ HuggingFace Configuration for Distilled Speech Encoder.
+
+ This is a Data2Vec-style bidirectional speech encoder distilled from AuriStream.
+ """
+
+ from transformers import PretrainedConfig
+
+
+ class DistilledSpeechConfig(PretrainedConfig):
+     """
+     Configuration class for DistilledSpeechModel.
+
+     This is a bidirectional transformer encoder for speech, trained via
+     Data2Vec-style distillation from AuriStream models.
+
+     Architecture:
+     - 7-layer convolutional feature encoder (16kHz -> 50Hz)
+     - N-layer bidirectional transformer with RoPE
+     - Optional projection head (for distillation training)
+
+     Args:
+         hidden_size (`int`, *optional*, defaults to 768):
+             Dimensionality of the encoder layers.
+         num_hidden_layers (`int`, *optional*, defaults to 12):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 12):
+             Number of attention heads for each attention layer.
+         intermediate_size (`int`, *optional*, defaults to 3072):
+             Dimensionality of the "intermediate" (feed-forward) layer.
+         hidden_act (`str`, *optional*, defaults to `"gelu"`):
+             The non-linear activation function in the encoder.
+         hidden_dropout (`float`, *optional*, defaults to 0.1):
+             The dropout probability for all fully connected layers.
+         attention_dropout (`float`, *optional*, defaults to 0.1):
+             The dropout ratio for the attention probabilities.
+         layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+             The epsilon used by the layer normalization layers.
+         conv_dim (`tuple`, *optional*):
+             Tuple of integers defining the number of channels in each conv layer.
+         conv_stride (`tuple`, *optional*):
+             Tuple of integers defining the stride of each conv layer.
+         conv_kernel (`tuple`, *optional*):
+             Tuple of integers defining the kernel size of each conv layer.
+         conv_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in conv layers.
+         feat_extract_norm (`str`, *optional*, defaults to `"group"`):
+             Normalization type for the first conv layer ("group" or "layer").
+         feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+             Activation function for conv layers.
+         feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+             Dropout for the feature projection layer.
+         use_rope (`bool`, *optional*, defaults to `True`):
+             Whether to use Rotary Position Embeddings (RoPE).
+         rope_theta (`float`, *optional*, defaults to 10000.0):
+             Base frequency for RoPE.
+         mask_time_prob (`float`, *optional*, defaults to 0.065):
+             Probability of masking time steps (for training).
+         mask_time_length (`int`, *optional*, defaults to 10):
+             Length of masked time spans (for training).
+     """
+
+     model_type = "distilled_speech"
+
+     def __init__(
+         self,
+         # Transformer architecture
+         hidden_size: int = 768,
+         num_hidden_layers: int = 12,
+         num_attention_heads: int = 12,
+         intermediate_size: int = 3072,
+         hidden_act: str = "gelu",
+         hidden_dropout: float = 0.1,
+         attention_dropout: float = 0.1,
+         activation_dropout: float = 0.0,
+         layer_norm_eps: float = 1e-5,
+
+         # Convolutional feature encoder
+         conv_dim: tuple = (512, 512, 512, 512, 512, 512, 512),
+         conv_stride: tuple = (5, 2, 2, 2, 2, 2, 2),
+         conv_kernel: tuple = (10, 3, 3, 3, 3, 2, 2),
+         conv_bias: bool = False,
+         feat_extract_norm: str = "group",
+         feat_extract_activation: str = "gelu",
+         feat_proj_dropout: float = 0.0,
+
+         # Positional encoding
+         use_rope: bool = True,
+         rope_theta: float = 10000.0,
+
+         # Masking (for training, disabled by default for inference)
+         mask_time_prob: float = 0.065,
+         mask_time_length: int = 10,
+         mask_time_min_masks: int = 2,
+
+         # Teacher info (for reference, not used in inference)
+         teacher_model_name: str = None,
+         teacher_hidden_size: int = None,
+
+         # Audio
+         sample_rate: int = 16000,
+
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.hidden_dropout = hidden_dropout
+         self.attention_dropout = attention_dropout
+         self.activation_dropout = activation_dropout
+         self.layer_norm_eps = layer_norm_eps
+
+         # Conv encoder
+         self.conv_dim = list(conv_dim)
+         self.conv_stride = list(conv_stride)
+         self.conv_kernel = list(conv_kernel)
+         self.conv_bias = conv_bias
+         self.feat_extract_norm = feat_extract_norm
+         self.feat_extract_activation = feat_extract_activation
+         self.feat_proj_dropout = feat_proj_dropout
+
+         # Position encoding
+         self.use_rope = use_rope
+         self.rope_theta = rope_theta
+
+         # Masking
+         self.mask_time_prob = mask_time_prob
+         self.mask_time_length = mask_time_length
+         self.mask_time_min_masks = mask_time_min_masks
+
+         # Teacher info
+         self.teacher_model_name = teacher_model_name
+         self.teacher_hidden_size = teacher_hidden_size
+
+         # Audio
+         self.sample_rate = sample_rate
+
+     @property
+     def output_hz(self) -> int:
+         """Output frequency of the model in Hz."""
+         stride_product = 1
+         for s in self.conv_stride:
+             stride_product *= s
+         return self.sample_rate // stride_product  # 50 Hz for the default config
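Since the defaults above mirror `config.json`, the config class can be inspected standalone; a small sketch, assuming the file above sits on the import path:

```python
from configuration_distilled_speech import DistilledSpeechConfig

config = DistilledSpeechConfig()
print(config.output_hz)          # 50, i.e. 16000 // (5 * 2**6)
print(config.num_hidden_layers)  # 12
print(config.conv_dim[-1])       # 512, conv feature width before projection

# Overrides work like any PretrainedConfig; the conv stack (and thus output_hz) is unchanged:
small = DistilledSpeechConfig(num_hidden_layers=6, hidden_size=384, num_attention_heads=6)
print(small.output_hz)           # still 50
```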
feature_extraction_distilled_speech.py ADDED
@@ -0,0 +1,150 @@
+ """
+ Feature extractor for Distilled Speech Model.
+
+ Handles audio preprocessing: normalization to zero mean and unit variance.
+ """
+
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import torch
+
+
+ class DistilledSpeechFeatureExtractor:
+     """
+     Feature extractor for DistilledSpeechModel.
+
+     Normalizes audio to zero mean and unit variance (per-sample).
+     Expected input: 16kHz mono audio.
+
+     Example:
+         >>> extractor = DistilledSpeechFeatureExtractor()
+         >>> audio = np.random.randn(16000)  # 1 second
+         >>> inputs = extractor(audio, return_tensors="pt", sample_rate=16000)
+         >>> inputs.input_values.shape
+         torch.Size([1, 16000])
+     """
+
+     def __init__(
+         self,
+         sampling_rate: int = 16000,
+         do_normalize: bool = True,
+         return_attention_mask: bool = False,
+     ):
+         self.sampling_rate = sampling_rate
+         self.do_normalize = do_normalize
+         self.return_attention_mask = return_attention_mask
+
+     def __call__(
+         self,
+         raw_speech: Union[np.ndarray, List[float], torch.Tensor],
+         return_tensors: Optional[str] = "pt",
+         sample_rate: Optional[int] = None,
+         **kwargs,
+     ):
+         """
+         Process raw audio into model inputs.
+
+         Args:
+             raw_speech: Raw audio waveform (1D array or tensor)
+             return_tensors: "pt" for PyTorch tensors, "np" for numpy
+             sample_rate: Sample rate of the input audio (for validation)
+
+         Returns:
+             Object with an input_values attribute
+         """
+         # Validate sample rate
+         if sample_rate is not None and sample_rate != self.sampling_rate:
+             raise ValueError(
+                 f"Expected sample rate {self.sampling_rate}, got {sample_rate}. "
+                 f"Please resample your audio to {self.sampling_rate}Hz."
+             )
+
+         # Convert to numpy if needed (detach and move to CPU so GPU or grad-tracking tensors work)
+         if isinstance(raw_speech, torch.Tensor):
+             raw_speech = raw_speech.detach().cpu().numpy()
+         elif isinstance(raw_speech, list):
+             raw_speech = np.array(raw_speech)
+
+         raw_speech = np.asarray(raw_speech, dtype=np.float32)
+
+         # Ensure 1D
+         if raw_speech.ndim > 1:
+             raw_speech = raw_speech.squeeze()
+         if raw_speech.ndim != 1:
+             raise ValueError(f"Expected 1D audio, got shape {raw_speech.shape}")
+
+         # Normalize
+         if self.do_normalize:
+             raw_speech = (raw_speech - raw_speech.mean()) / (raw_speech.std() + 1e-7)
+
+         # Add batch dimension
+         raw_speech = raw_speech[np.newaxis, :]
+
+         # Convert to tensors
+         if return_tensors == "pt":
+             input_values = torch.from_numpy(raw_speech)
+         else:
+             input_values = raw_speech
+
+         return FeatureExtractorOutput(input_values=input_values)
+
+     def to_dict(self):
+         """Serialize to dict for saving."""
+         return {
+             "sampling_rate": self.sampling_rate,
+             "do_normalize": self.do_normalize,
+             "return_attention_mask": self.return_attention_mask,
+             "feature_extractor_type": "DistilledSpeechFeatureExtractor",
+         }
+
+     @classmethod
+     def from_dict(cls, config_dict):
+         """Load from dict."""
+         return cls(
+             sampling_rate=config_dict.get("sampling_rate", 16000),
+             do_normalize=config_dict.get("do_normalize", True),
+             return_attention_mask=config_dict.get("return_attention_mask", False),
+         )
+
+     def save_pretrained(self, save_directory: str):
+         """Save the feature extractor config."""
+         import json
+         import os
+
+         os.makedirs(save_directory, exist_ok=True)
+         with open(os.path.join(save_directory, "preprocessor_config.json"), "w") as f:
+             json.dump(self.to_dict(), f, indent=2)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
+         """Load the feature extractor from a local directory or the Hub."""
+         import json
+         import os
+
+         if os.path.isdir(pretrained_model_name_or_path):
+             config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
+         else:
+             # Try to download from the Hub
+             from huggingface_hub import hf_hub_download
+             config_path = hf_hub_download(
+                 repo_id=pretrained_model_name_or_path,
+                 filename="preprocessor_config.json",
+             )
+
+         with open(config_path, "r") as f:
+             config = json.load(f)
+
+         return cls.from_dict(config)
+
+
+ class FeatureExtractorOutput:
+     """Simple container for feature extractor output."""
+
+     def __init__(self, input_values):
+         self.input_values = input_values
+
+     def to(self, device):
+         """Move tensors to device."""
+         if isinstance(self.input_values, torch.Tensor):
+             self.input_values = self.input_values.to(device)
+         return self
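The extractor raises on any rate other than 16kHz, so audio from other sources has to be resampled first. A sketch using `torchaudio.functional.resample` (torchaudio is an assumption here; any resampler works):

```python
import torch
import torchaudio

from feature_extraction_distilled_speech import DistilledSpeechFeatureExtractor

extractor = DistilledSpeechFeatureExtractor()

# Hypothetical 44.1kHz recording; resample to 16kHz before calling the extractor.
wav_44k = torch.randn(44100)
wav_16k = torchaudio.functional.resample(wav_44k, orig_freq=44100, new_freq=16000)

inputs = extractor(wav_16k, return_tensors="pt", sample_rate=16000)
print(inputs.input_values.shape)  # torch.Size([1, 16000])

# FeatureExtractorOutput.to() moves the tensor for GPU inference:
if torch.cuda.is_available():
    inputs = inputs.to("cuda")
```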
modeling_distilled_speech.py ADDED
@@ -0,0 +1,525 @@
+ """
+ HuggingFace Model for Distilled Speech Encoder.
+
+ A Data2Vec-style bidirectional speech encoder distilled from AuriStream.
+ Returns hidden states from all layers for downstream probing/finetuning.
+ """
+
+ import math
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutput
+
+ try:
+     # When used as a HuggingFace model (trust_remote_code=True)
+     from configuration_distilled_speech import DistilledSpeechConfig
+ except ImportError:
+     # When used as part of a package
+     from .configuration_distilled_speech import DistilledSpeechConfig
+
+
+ @dataclass
+ class DistilledSpeechOutput(BaseModelOutput):
+     """
+     Output type for DistilledSpeechModel.
+
+     Args:
+         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+             Sequence of hidden-states at the output of the last layer of the model.
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for each layer)
+             of shape `(batch_size, sequence_length, hidden_size)`.
+         extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
+             Output of the convolutional feature encoder (before projection).
+     """
+
+     last_hidden_state: torch.FloatTensor = None
+     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+     extract_features: Optional[torch.FloatTensor] = None
+
+
+ # ==============================================================================
+ # Convolutional Feature Encoder
+ # ==============================================================================
+
+ class GroupNorm1D(nn.Module):
+     """Group normalization for 1D convolutions (B, C, T) -> (B, C, T)."""
+
+     def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
+         super().__init__()
+         self.norm = nn.GroupNorm(num_groups, num_channels, eps=eps)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.norm(x)
+
+
+ class ConvLayer(nn.Module):
+     """Single convolutional layer with normalization and activation."""
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int,
+         bias: bool = False,
+         norm: str = "group",
+         activation: str = "gelu",
+     ):
+         super().__init__()
+         self.conv = nn.Conv1d(
+             in_channels,
+             out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             bias=bias,
+         )
+
+         if norm == "group":
+             self.norm = GroupNorm1D(num_groups=out_channels, num_channels=out_channels)
+         elif norm == "layer":
+             self.norm = nn.LayerNorm(out_channels)
+         else:
+             self.norm = None
+
+         if activation == "gelu":
+             self.activation = nn.GELU()
+         elif activation == "relu":
+             self.activation = nn.ReLU()
+         else:
+             self.activation = None
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.conv(x)
+         if self.norm is not None:
+             if isinstance(self.norm, nn.LayerNorm):
+                 # LayerNorm expects channels last: (B, C, T) -> (B, T, C) -> (B, C, T)
+                 x = x.transpose(1, 2)
+                 x = self.norm(x)
+                 x = x.transpose(1, 2)
+             else:
+                 x = self.norm(x)
+         if self.activation is not None:
+             x = self.activation(x)
+         return x
+
+
+ class ConvFeatureEncoder(nn.Module):
+     """
+     7-layer convolutional feature encoder.
+
+     Transforms raw 16kHz audio into 50Hz feature representations.
+     Total stride: 5 * 2 * 2 * 2 * 2 * 2 * 2 = 320 (16kHz / 320 = 50Hz)
+     """
+
+     def __init__(self, config: DistilledSpeechConfig):
+         super().__init__()
+
+         conv_layers = []
+         in_channels = 1
+
+         for i, (out_channels, kernel, stride) in enumerate(
+             zip(config.conv_dim, config.conv_kernel, config.conv_stride)
+         ):
+             norm = "group" if i > 0 else config.feat_extract_norm
+             conv_layers.append(
+                 ConvLayer(
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     kernel_size=kernel,
+                     stride=stride,
+                     bias=config.conv_bias,
+                     norm=norm,
+                     activation=config.feat_extract_activation,
+                 )
+             )
+             in_channels = out_channels
+
+         self.conv_layers = nn.ModuleList(conv_layers)
+         self.output_dim = config.conv_dim[-1]
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: Raw audio waveform (B, T) or (B, 1, T)
+
+         Returns:
+             Features (B, T', C) where T' ~ T // 320 (the convolutions are unpadded)
+         """
+         if x.dim() == 2:
+             x = x.unsqueeze(1)
+
+         for conv_layer in self.conv_layers:
+             x = conv_layer(x)
+
+         x = x.transpose(1, 2)
+         return x
+
+
+ class FeatureProjection(nn.Module):
+     """Projects conv features to the transformer hidden size."""
+
+     def __init__(self, config: DistilledSpeechConfig):
+         super().__init__()
+         self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+         self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+         self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.layer_norm(x)
+         x = self.projection(x)
+         x = self.dropout(x)
+         return x
+
+
+ # ==============================================================================
+ # Rotary Position Embeddings
+ # ==============================================================================
+
+ class RotaryEmbedding(nn.Module):
+     """Rotary Position Embedding (RoPE)."""
+
+     def __init__(self, dim: int, theta: float = 10000.0, max_seq_len: int = 8192):
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+         self.max_seq_len = max_seq_len
+
+         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+         self._cos_cached = None
+         self._sin_cached = None
+         self._seq_len_cached = 0
+
+     def _update_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
+         if seq_len > self._seq_len_cached or self._cos_cached is None:
+             self._seq_len_cached = max(seq_len, self.max_seq_len)
+             t = torch.arange(self._seq_len_cached, device=device, dtype=dtype)
+             freqs = torch.outer(t, self.inv_freq.to(device))
+             emb = torch.cat((freqs, freqs), dim=-1)
+             self._cos_cached = emb.cos()
+             self._sin_cached = emb.sin()
+
+     def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         self._update_cache(seq_len, x.device, x.dtype)
+         return (
+             self._cos_cached[:seq_len].to(x.dtype),
+             self._sin_cached[:seq_len].to(x.dtype),
+         )
+
+
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     """Rotate half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(
+     q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Apply rotary position embedding to query and key tensors."""
+     cos = cos.unsqueeze(0).unsqueeze(0)
+     sin = sin.unsqueeze(0).unsqueeze(0)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ # ==============================================================================
+ # Transformer Layers
+ # ==============================================================================
+
+ class MultiHeadAttention(nn.Module):
+     """Multi-head self-attention with RoPE support."""
+
+     def __init__(self, config: DistilledSpeechConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = config.hidden_size // config.num_attention_heads
+
+         assert self.head_dim * self.num_heads == self.hidden_size
+
+         self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
+         self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
+         self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
+         self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
+
+         self.dropout = nn.Dropout(config.attention_dropout)
+         self.use_rope = config.use_rope
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         B, T, _ = x.shape
+
+         q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
+
+         if self.use_rope and cos is not None and sin is not None:
+             q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         # Scaled dot-product attention
+         attn_output = F.scaled_dot_product_attention(
+             q, k, v,
+             attn_mask=attention_mask,
+             dropout_p=self.dropout.p if self.training else 0.0,
+         )
+
+         attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.hidden_size)
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class FeedForward(nn.Module):
+     """Feed-forward network with GELU activation."""
+
+     def __init__(self, config: DistilledSpeechConfig):
+         super().__init__()
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+         self.activation = nn.GELU()
+         self.dropout = nn.Dropout(config.activation_dropout)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.fc1(x)
+         x = self.activation(x)
+         x = self.dropout(x)
+         x = self.fc2(x)
+         return x
+
+
+ class TransformerLayer(nn.Module):
+     """Single transformer encoder layer with pre-norm."""
+
+     def __init__(self, config: DistilledSpeechConfig):
+         super().__init__()
+         self.attention = MultiHeadAttention(config)
+         self.feed_forward = FeedForward(config)
+         self.attention_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.ffn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         # Self-attention with pre-norm
+         residual = x
+         x = self.attention_norm(x)
+         x = self.attention(x, cos, sin, attention_mask)
+         x = self.dropout(x)
+         x = residual + x
+
+         # Feed-forward with pre-norm
+         residual = x
+         x = self.ffn_norm(x)
+         x = self.feed_forward(x)
+         x = self.dropout(x)
+         x = residual + x
+
+         return x
+
+
+ class TransformerEncoder(nn.Module):
+     """Stack of transformer encoder layers with hidden state collection."""
+
+     def __init__(self, config: DistilledSpeechConfig):
+         super().__init__()
+         self.config = config
+         self.layers = nn.ModuleList([
+             TransformerLayer(config) for _ in range(config.num_hidden_layers)
+         ])
+
+         if config.use_rope:
+             self.rotary_emb = RotaryEmbedding(
+                 dim=config.hidden_size // config.num_attention_heads,
+                 theta=config.rope_theta,
+             )
+         else:
+             self.rotary_emb = None
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_hidden_states: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
+         """
+         Args:
+             x: Input tensor (B, T, D)
+             attention_mask: Optional attention mask
+             output_hidden_states: Whether to return all hidden states
+
+         Returns:
+             Tuple of (last_hidden_state, all_hidden_states)
+             all_hidden_states: tuple of (num_layers + 1) tensors if output_hidden_states=True
+             - hidden_states[0]: input to the first transformer layer
+             - hidden_states[i]: output of transformer layer i-1 (for i > 0)
+         """
+         B, T, _ = x.shape
+
+         cos, sin = None, None
+         if self.rotary_emb is not None:
+             cos, sin = self.rotary_emb(x, T)
+
+         all_hidden_states = () if output_hidden_states else None
+
+         # Collect hidden state before the first layer (embedding output)
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (x,)
+
+         for layer in self.layers:
+             x = layer(x, cos, sin, attention_mask)
+             # Collect hidden state after each layer
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (x,)
+
+         return x, all_hidden_states
+
+
+ # ==============================================================================
+ # Main Model
+ # ==============================================================================
+
+ class DistilledSpeechModel(PreTrainedModel):
+     """
+     Distilled Speech Encoder Model.
+
+     A Data2Vec-style bidirectional transformer encoder for speech,
+     trained via distillation from AuriStream models.
+
+     This model takes raw audio waveforms as input and outputs contextualized
+     representations at a nominal 50Hz (20ms stride). It returns hidden states
+     from all transformer layers, making it suitable for downstream probing
+     and finetuning.
+
+     Hidden states structure (for a 12-layer model, output_hidden_states=True):
+     - hidden_states[0]: Feature projection output (input to the transformer)
+     - hidden_states[1]: Output of transformer layer 0
+     - hidden_states[2]: Output of transformer layer 1
+     - ...
+     - hidden_states[12]: Output of transformer layer 11
+     Total: 13 hidden states (1 embedding + 12 layers)
+
+     Example usage:
+         >>> from transformers import AutoModel, AutoFeatureExtractor
+         >>> model = AutoModel.from_pretrained("your-model-name", trust_remote_code=True)
+         >>> processor = AutoFeatureExtractor.from_pretrained("your-model-name", trust_remote_code=True)
+         >>> audio = torch.randn(16000)  # 1 second of audio at 16kHz
+         >>> inputs = processor(audio, return_tensors="pt", sample_rate=16000)
+         >>> outputs = model(inputs.input_values, output_hidden_states=True)
+         >>> last_hidden = outputs.last_hidden_state  # (1, 49, 768)
+         >>> all_hidden = outputs.hidden_states  # Tuple of 13 tensors
+         >>> # Or use dict-style access:
+         >>> all_hidden = outputs["hidden_states"]
+     """
+
+     config_class = DistilledSpeechConfig
+     base_model_prefix = "distilled_speech"
+     main_input_name = "input_values"
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config: DistilledSpeechConfig):
+         super().__init__(config)
+         self.config = config
+
+         # Feature extraction
+         self.conv_encoder = ConvFeatureEncoder(config)
+         self.feature_projection = FeatureProjection(config)
+
+         # Transformer encoder
+         self.encoder = TransformerEncoder(config)
+         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+         # Initialize weights
+         self.post_init()
+
+     def _init_weights(self, module):
+         """Initialize the weights."""
+         if isinstance(module, nn.Linear):
+             nn.init.trunc_normal_(module.weight, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.ones_(module.weight)
+             nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Conv1d):
+             nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+
+     def forward(
+         self,
+         input_values: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, DistilledSpeechOutput]:
+         """
+         Forward pass through the model.
+
+         Args:
+             input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+                 Raw audio waveform, normalized to zero mean and unit variance.
+                 Expected sample rate: 16kHz.
+             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Mask to avoid performing attention on padding tokens.
+             output_hidden_states (`bool`, *optional*):
+                 Whether to return hidden states from all layers.
+             return_dict (`bool`, *optional*):
+                 Whether to return a ModelOutput instead of a plain tuple.
+
+         Returns:
+             `DistilledSpeechOutput` or `tuple`:
+             - last_hidden_state: (B, T', hidden_size) where T' ~ T // 320
+             - hidden_states: Tuple of (B, T', hidden_size) for each layer if output_hidden_states=True
+             - extract_features: (B, T', conv_dim[-1]) raw conv features
+         """
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Conv encoder: (B, T) -> (B, T', conv_dim)
+         extract_features = self.conv_encoder(input_values)
+
+         # Feature projection: (B, T', conv_dim) -> (B, T', hidden_size)
+         hidden_states = self.feature_projection(extract_features)
+
+         # Transformer encoder
+         encoder_output, all_hidden_states = self.encoder(
+             hidden_states,
+             attention_mask=attention_mask,
+             output_hidden_states=output_hidden_states,
+         )
+
+         # Final layer norm
+         last_hidden_state = self.final_layer_norm(encoder_output)
+
+         if not return_dict:
+             outputs = (last_hidden_state,)
+             if output_hidden_states:
+                 outputs = outputs + (all_hidden_states,)
+             outputs = outputs + (extract_features,)
+             return outputs
+
+         return DistilledSpeechOutput(
+             last_hidden_state=last_hidden_state,
+             hidden_states=all_hidden_states,
+             extract_features=extract_features,
+         )
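A quick sanity check of the forward contract with a randomly initialized model (no pretrained weights; assumes the three module files in this commit are importable from the working directory):

```python
import torch

from configuration_distilled_speech import DistilledSpeechConfig
from modeling_distilled_speech import DistilledSpeechModel

config = DistilledSpeechConfig()
model = DistilledSpeechModel(config).eval()

wav = torch.randn(1, 16000)  # 1 second at 16kHz, already normalized
with torch.no_grad():
    out = model(wav, output_hidden_states=True)

print(out.last_hidden_state.shape)  # torch.Size([1, 49, 768]); ~50Hz frame rate
print(len(out.hidden_states))       # 13: feature projection output + 12 layers

# last_hidden_state is hidden_states[-1] passed through the final LayerNorm:
ref = model.final_layer_norm(out.hidden_states[-1])
print(torch.allclose(ref, out.last_hidden_state))  # True
```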
preprocessor_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "sampling_rate": 16000,
+   "do_normalize": true,
+   "return_attention_mask": false,
+   "feature_extractor_type": "DistilledSpeechFeatureExtractor"
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d818dc5701dedd635879dcc3a5df3056714f5f53ba80d90d11843e9b62fdc3d
+ size 358700726