klemenk commited on
Commit
9c3e596
·
verified ·
1 Parent(s): 5f1cf3e

Upload AuriStream base model code

Browse files
Files changed (3) hide show
  1. README.md +67 -0
  2. configuration_auristream.py +92 -0
  3. modeling_auristream.py +487 -0
README.md ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - audio
5
+ - speech
6
+ - language-model
7
+ - auristream
8
+ library_name: transformers
9
+ ---
10
+
11
+ # AuriStream - Speech Language Model
12
+
13
+ **AuriStream** is a speech language model by **Greta Tuckute** and **Klemen Kotar**.
14
+
15
+ This repository contains the shared model code for AuriStream models.
16
+
17
+ ## Overview
18
+
19
+ AuriStream is a GPT-like transformer model for cochlear token prediction with optional
20
+ multi-token prediction (MTP) heads.
21
+
22
+ This model predicts cochlear tokens from a tokenizer such as [WavCochCausalV8192](https://huggingface.co/TuKoResearch/WavCochCausalV8192).
23
+
24
+ ## Usage
25
+
26
+ This repository is not meant to be used directly. Instead, use one of the checkpoint
27
+ repositories that reference this base code:
28
+
29
+ - [AuriStream7B_40Pred_BigAudioDataset_500k](https://huggingface.co/TuKoResearch/AuriStream7B_40Pred_BigAudioDataset_500k)
30
+
31
+ To load a checkpoint:
32
+
33
+ ```python
34
+ from transformers import AutoModel, AutoConfig
35
+
36
+ model = AutoModel.from_pretrained(
37
+ "TuKoResearch/AuriStream7B_40Pred_BigAudioDataset_500k",
38
+ trust_remote_code=True,
39
+ )
40
+ ```
41
+
42
+ ## Model Architecture
43
+
44
+ The AuriStream model includes:
45
+ - RMSNorm for layer normalization
46
+ - Rotary Position Embeddings (RoPE)
47
+ - SiLU activation in MLP layers
48
+ - Multi-token prediction heads
49
+
50
+ ## Configuration Options
51
+
52
+ | Parameter | Description | Default |
53
+ |-----------|-------------|---------|
54
+ | `vocab_size` | Number of cochlear tokens | 8192 |
55
+ | `n_embd` | Hidden dimension | 768 |
56
+ | `n_layer` | Number of transformer layers | 12 |
57
+ | `n_head` | Number of attention heads | 12 |
58
+ | `n_pred_steps` | Number of prediction steps (MTP) | 1 |
59
+
60
+ ## Files
61
+
62
+ - `configuration_auristream.py` - Configuration class
63
+ - `modeling_auristream.py` - Model implementation
64
+
65
+ ## Tokenizer
66
+
67
+ This model uses cochlear tokens from [WavCochCausalV8192](https://huggingface.co/TuKoResearch/WavCochCausalV8192).
configuration_auristream.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AuriStream Configuration for HuggingFace Transformers.
3
+
4
+ AuriStream is a speech language model by Greta Tuckute and Klemen Kotar.
5
+ """
6
+
7
+ from transformers import PretrainedConfig
8
+
9
+
10
class AuriStreamConfig(PretrainedConfig):
    """
    Configuration class for AuriStream models.

    This configuration supports various model sizes and prediction head configurations
    for the AuriStream speech language model family.

    Args:
        vocab_size (`int`, *optional*, defaults to 8192):
            Vocabulary size of the model (number of cochlear tokens).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of transformer layers.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer.
        n_pred_steps (`int`, *optional*, defaults to 1):
            Number of future prediction steps (multi-token prediction heads).
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for all fully connected layers.
        bias (`bool`, *optional*, defaults to False):
            Whether to use bias in linear layers.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base theta for RoPE embeddings.
        input_conv_kernel_size (`int`, *optional*, defaults to 0):
            Kernel size for input convolution layer (0 means no input conv).
    """

    model_type = "AuriStream"

    def __init__(
        self,
        vocab_size: int = 8192,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_pred_steps: int = 1,
        dropout: float = 0.0,
        bias: bool = False,
        rope_theta: float = 10000.0,
        input_conv_kernel_size: int = 0,
        **kwargs,
    ):
        # Record architecture hyperparameters first, then hand the remaining
        # generic options (name_or_path, torch_dtype, ...) to the base class.
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_pred_steps = n_pred_steps
        self.dropout = dropout
        self.bias = bias
        self.rope_theta = rope_theta
        self.input_conv_kernel_size = input_conv_kernel_size

        super().__init__(**kwargs)

    @classmethod
    def from_local_config(cls, local_cfg):
        """
        Create an AuriStreamConfig from a local dataclass config.

        Args:
            local_cfg: A dataclass config object (e.g., AuriStream100M20PredConfig)

        Returns:
            AuriStreamConfig instance
        """
        # Copy over every recognized attribute that the local config defines.
        known_attrs = (
            'vocab_size', 'n_embd', 'n_layer', 'n_head', 'n_pred_steps',
            'dropout', 'bias', 'rope_theta', 'input_conv_kernel_size',
        )
        config_dict = {
            attr: getattr(local_cfg, attr)
            for attr in known_attrs
            if hasattr(local_cfg, attr)
        }

        # Older local configs may predate multi-token prediction: fall back to 1.
        config_dict.setdefault('n_pred_steps', 1)

        return cls(**config_dict)
modeling_auristream.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AuriStream Model for HuggingFace Transformers.
3
+
4
+ AuriStream is a speech language model by Greta Tuckute and Klemen Kotar.
5
+ This model predicts cochlear tokens from a tokenizer such as WavCochCausalV8192.
6
+
7
+ https://huggingface.co/TuKoResearch/WavCochCausalV8192
8
+ """
9
+
10
+ import math
11
+ from typing import Optional, List
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch.nn import functional as F
16
+
17
+ from transformers import PreTrainedModel
18
+ from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
19
+
20
+ from .configuration_auristream import AuriStreamConfig
21
+
22
+
23
+ # ============================================================================
24
+ # Building Blocks
25
+ # ============================================================================
26
+
27
class RMSNorm(nn.Module):
    """Root Mean Square (RMS) layer normalization.

    Scales each vector by the inverse RMS of its last dimension; unlike
    LayerNorm there is no mean subtraction. NOTE(review): the ``bias``
    argument is accepted for signature compatibility but is currently
    ignored — no bias parameter is ever created.
    """

    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Optional learnable per-channel gain; None disables it entirely.
        self.weight = nn.Parameter(torch.ones(dim)) if weight else None

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), via rsqrt to avoid a division.
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back to
        # the input dtype before applying the (optional) gain.
        normed = self._norm(x.float()).type_as(x)
        return normed if self.weight is None else normed * self.weight
43
+
44
+
45
class Rotary(nn.Module):
    """Rotary Position Embedding (RoPE) table generator.

    Produces cos/sin tables shaped ``(1, seq_len, 1, dim // 2)`` so they
    broadcast over ``(batch, seq, heads, head_dim // 2)`` tensors.
    """

    def __init__(self, dim: int, base: float = 10000):
        super().__init__()
        # Standard RoPE inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x):
        # Only the sequence length (dim 1) and device of `x` are used; the
        # tables are recomputed on every call.
        positions = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq).to(x.device)
        return angles.cos()[None, :, None, :], angles.sin()[None, :, None, :]
60
+
61
+
62
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dim of `x` by angle tables (cos, sin).

    Expects ``x`` shaped ``(batch, seq, heads, head_dim)``; the first and
    second halves of ``head_dim`` form the rotation pair (half-split RoPE).
    """
    assert x.ndim == 4  # multihead attention expected
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = second * cos - first * sin
    return torch.cat([rotated_first, rotated_second], dim=3)
71
+
72
+
73
class CausalSelfAttention(nn.Module):
    """Multi-head causal self attention with RoPE.

    ``forward`` uses the fused SDPA kernel on the fast path and a manual
    implementation when K/V tensors or attention maps must be exposed.
    ``kv_cache_forward`` performs incremental decoding against cached K/V.
    """

    def __init__(self, config: AuriStreamConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0

        # Key, query, value projections for all heads (fused into one matmul)
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        # Output projection
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

        # RoPE (fall back to the standard base when unset or None)
        rope_theta = getattr(config, 'rope_theta', 10000)
        if rope_theta is None:
            rope_theta = 10000
        self.rotary = Rotary(self.head_dim, base=rope_theta)

    def forward(self, x, return_kv=False, return_attn_maps=False):
        """Causal attention over `x` of shape (B, T, C).

        Returns `y`, optionally followed by the attention probabilities
        (when ``return_attn_maps``) or the rotated K/V tensors shaped
        (B, n_head, T, head_dim) (when ``return_kv``).
        """
        B, T, C = x.size()

        # Calculate query, key, values for all heads
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, self.head_dim)
        q = q.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)

        # Apply RoPE in (B, T, n_head, head_dim) layout so the (1, T, 1, d/2)
        # cos/sin tables broadcast over batch and heads.
        cos, sin = self.rotary(q)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        if not return_kv and not return_attn_maps:
            # Fast path: fused kernel with built-in causal masking.
            y = F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                is_causal=True
            )
        else:
            # Manual implementation of attention (needed to expose K/V or maps)
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
            # Upper-triangular mask blocks attention to future positions.
            mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
            mask = mask.view(1, 1, T, T)
            masked_att = att.masked_fill(mask, float('-inf'))
            masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
            y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)

        if return_attn_maps:
            # FIX: return the causally-masked attention probabilities — the
            # weights actually used to mix values. The previous code returned
            # a softmax over the *unmasked* scores, so the maps leaked
            # attention mass onto future positions.
            return y, masked_att
        if return_kv:
            return y, k, v
        return y

    def kv_cache_forward(self, x, k_cache=None, v_cache=None):
        """Forward pass with KV cache for efficient generation.

        `k_cache`/`v_cache` are shaped (B, n_head, cache_len, head_dim);
        returns `y` plus the caches extended with this step's K/V.
        """
        B, T, C = x.size()

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head)
        q = q.view(B, T, self.n_head, C // self.n_head)
        v = v.view(B, T, self.n_head, C // self.n_head)

        # Apply RoPE at the absolute positions following the cached prefix.
        # FIX: rotate in (B, T, n_head, head_dim) layout *before* transposing
        # to (B, n_head, T, head_dim); the (1, T, 1, d/2) tables only
        # broadcast correctly in the former layout. The old order (transpose
        # first, then rotate) happened to work only because T == 1 during
        # token-by-token generation; for T > 1 it either mis-broadcast or
        # raised a shape error.
        cache_len = k_cache.shape[2] if k_cache is not None else 0
        dummy = torch.zeros(B, cache_len + T, self.n_head, self.head_dim,
                            device=q.device, dtype=q.dtype)
        cos, sin = self.rotary(dummy)
        cos = cos[:, cache_len:cache_len + T, :, :]
        sin = sin[:, cache_len:cache_len + T, :, :]
        q = apply_rotary_emb(q, cos, sin).transpose(1, 2)
        k = apply_rotary_emb(k, cos, sin).transpose(1, 2)
        v = v.transpose(1, 2)

        # Concatenate with cache along the sequence axis
        if k_cache is not None:
            k = torch.cat((k_cache, k), dim=2)
        if v_cache is not None:
            v = torch.cat((v_cache, v), dim=2)

        # Attention. NOTE(review): no causal mask is applied here — correct
        # when decoding a single new token (it may attend to all cached
        # positions), but not for multi-token chunks; confirm callers only
        # pass T == 1.
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = F.softmax(att, dim=-1)
        y = att @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)

        return y, k, v
169
+
170
+
171
class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, activate, project back.

    NOTE(review): the activation attribute is named ``gelu`` but holds an
    ``nn.SiLU``; the name is kept so existing checkpoints / attribute
    references continue to resolve.
    """

    def __init__(self, config: AuriStreamConfig):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
        self.gelu = nn.SiLU()  # misnomer preserved for compatibility
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # expand -> SiLU -> project -> dropout
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
187
+
188
+
189
class Block(nn.Module):
    """Pre-norm transformer block: x + attn(norm1(x)), then x + mlp(norm2(x))."""

    def __init__(self, config: AuriStreamConfig):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)
        # Scalar on the attention residual; applied on the plain path only
        # (the KV-cache and return_kv paths add the residual unscaled).
        self.attn_scale = 1.0
        self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
        self.norm2 = RMSNorm(config.n_embd, bias=config.bias)

    def forward(self, x, return_kv=False, k_cache=None, v_cache=None):
        use_cache = k_cache is not None and v_cache is not None

        if not use_cache and not return_kv:
            # Plain path: residual attention (scaled) then residual MLP.
            x = x + self.attn_scale * self.attn(self.norm1(x))
            x = x + self.mlp(self.norm2(x))
            return x

        # K/V-producing paths: incremental decoding against a cache, or a
        # full pass that also exposes the rotated key/value tensors.
        if use_cache:
            attn_out, k, v = self.attn.kv_cache_forward(self.norm1(x), k_cache, v_cache)
        else:
            attn_out, k, v = self.attn(self.norm1(x), return_kv=True)

        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x, k, v
215
+
216
+
217
+ # ============================================================================
218
+ # Main Model
219
+ # ============================================================================
220
+
221
class AuriStreamPreTrainedModel(PreTrainedModel):
    """Base class wiring AuriStream modules into the HF PreTrainedModel API."""

    config_class = AuriStreamConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]

    def _init_weights(self, module):
        # Linear and Embedding weights: N(0, 0.02); Linear biases (if any): 0.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
236
+
237
+
238
class AuriStreamModel(AuriStreamPreTrainedModel):
    """
    AuriStream speech language model.

    A GPT-like transformer model for cochlear token prediction with optional
    multi-token prediction (MTP) heads for speculative decoding.

    Developed by Greta Tuckute and Klemen Kotar.
    """

    config_class = AuriStreamConfig

    def __init__(self, config: AuriStreamConfig):
        super().__init__(config)
        self.config = config

        # Transformer components: token embedding, dropout, stacked blocks,
        # and a final RMSNorm.
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd, bias=config.bias),
        ))

        # Multi-token prediction heads: head i predicts the token i+2 steps
        # ahead (the main head covers the next token). Only created when
        # n_pred_steps > 1.
        if hasattr(config, 'n_pred_steps') and config.n_pred_steps > 1:
            self.future_heads = nn.ModuleList([
                nn.Linear(config.n_embd, config.vocab_size, bias=False)
                for _ in range(config.n_pred_steps - 1)
            ])
        else:
            self.future_heads = None

        # Output head for next-token cochlear logits
        self.coch_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)
        # Apply special scaled init to residual projections (GPT-2 style:
        # std shrinks with depth so residual contributions stay bounded)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def get_input_embeddings(self):
        # HF API hook: the token embedding table.
        return self.transformer.wte

    def set_input_embeddings(self, value):
        # HF API hook: replace the token embedding table.
        self.transformer.wte = value

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model.

        NOTE(review): the ``non_embedding`` flag is currently ignored — all
        parameters (including embeddings) are always counted.
        """
        return sum(p.numel() for p in self.parameters())

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        # Legacy arguments for compatibility
        seq: Optional[torch.LongTensor] = None,
        tgt: Optional[torch.LongTensor] = None,
    ):
        """
        Forward pass for the AuriStream model.

        Args:
            input_ids: Input token IDs of shape (batch_size, seq_len)
            labels: Target token IDs for computing loss; assumed pre-shifted
                so position t of `labels` is the target for position t of the
                logits (no internal shift is applied)
            output_hidden_states: Whether to return all hidden states
                (embedding output plus every block output, as a list)
            return_dict: Whether to return a dict or tuple
            seq: Legacy argument (alias for input_ids; overrides it if given)
            tgt: Legacy argument (alias for labels; overrides it if given)

        Returns:
            CausalLMOutput with logits and optional loss
        """
        # Handle legacy arguments
        if seq is not None:
            input_ids = seq
        if tgt is not None:
            labels = tgt

        # Get embeddings
        tok_emb = self.transformer.wte(input_ids)
        x = self.transformer.drop(tok_emb)

        # Collect hidden states if requested
        all_hidden_states = []

        # Forward through transformer blocks (input to each block is recorded,
        # then the final output is appended after the loop)
        for block in self.transformer.h:
            if output_hidden_states:
                all_hidden_states.append(x)
            x = block(x)

        if output_hidden_states:
            all_hidden_states.append(x)

        # Final layer norm and output head
        x = self.transformer.ln_f(x)
        logits = self.coch_head(x)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
            )

            # Multi-token prediction loss: future head i scores positions
            # [0, T-i-1) against labels shifted i+1 steps further ahead; the
            # per-head losses are summed onto the main loss unweighted.
            if self.future_heads is not None:
                for i, head in enumerate(self.future_heads):
                    future_logits = head(x[:, :-(i+1)])
                    loss = loss + F.cross_entropy(
                        future_logits.reshape(-1, self.config.vocab_size),
                        labels[:, (i+1):].reshape(-1),
                    )

        if not return_dict:
            if labels is not None:
                return logits, loss
            return logits, None

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states if output_hidden_states else None,
        )

    def sample_logits(
        self,
        logits: torch.FloatTensor,
        temperature: float = 0.9,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> torch.LongTensor:
        """Sample from logits with temperature, top-k, and top-p.

        temperature == 0.0 short-circuits to greedy argmax. Filtering
        mutates only the internal temperature-scaled copy, not the caller's
        tensor. Returns token ids shaped like `logits` minus the last dim.
        """
        if temperature == 0.0:
            return torch.argmax(logits, dim=-1)

        # Division creates a fresh tensor; in-place masking below is safe.
        logits = logits / temperature

        if top_k is not None:
            # Drop everything below the k-th largest logit.
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[..., [-1]]] = -float('Inf')

        if top_p is not None:
            # Nucleus sampling: keep the smallest prefix of sorted tokens
            # whose cumulative probability exceeds top_p (always keeping
            # at least the top token).
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            # Scatter the removal mask back to vocabulary order.
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = -float('Inf')

        probs = F.softmax(logits, dim=-1)
        # multinomial needs a 2-D input; flatten, sample, restore shape.
        flat_probs = probs.view(-1, probs.size(-1))
        sampled = torch.multinomial(flat_probs, num_samples=1)
        sampled = sampled.view(*logits.shape[:-1])
        return sampled

    @torch.no_grad()
    def generate(
        self,
        seq: torch.Tensor,
        n_tokens: int = 1,
        temp: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = None,
    ):
        """
        Generate new tokens autoregressively.

        Runs one full pass over the prompt to build per-layer KV caches,
        then decodes one token per step, feeding only the newest token
        through the blocks via their kv_cache_forward path.

        Args:
            seq: Input token IDs of shape (batch_size, seq_len)
            n_tokens: Number of tokens to generate
            temp: Sampling temperature
            top_k: Top-k sampling parameter
            top_p: Nucleus sampling parameter
            seed: Random seed

        Returns:
            Tuple of (generated_tokens, all_logits)
        """
        # Local imports keep numpy optional for the rest of the module.
        import random
        import numpy as np

        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        all_logits = []
        device = seq.device
        b, t = seq.size()

        # Encode conditioning sequence into KV cache (prefill pass)
        tok_emb = self.transformer.wte(seq)
        x = self.transformer.drop(tok_emb)

        k_list = []
        v_list = []
        for block in self.transformer.h:
            x, k, v = block(x, return_kv=True)
            k_list.append(k)
            v_list.append(v)

        # Caches are stacked per layer: (n_layer, B, n_head, seq, head_dim)
        k_cache = torch.stack(k_list, dim=0)
        v_cache = torch.stack(v_list, dim=0)
        x = self.transformer.ln_f(x)

        # First prediction from the last prompt position
        logits = self.coch_head(x[:, [-1]])
        predictions = [self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p)]
        all_logits.append(logits)

        # Generate remaining tokens one at a time, growing the caches
        for i in range(n_tokens - 1):
            tok_emb = self.transformer.wte(predictions[-1])
            x = self.transformer.drop(tok_emb)

            k_list = []
            v_list = []
            for block_idx, block in enumerate(self.transformer.h):
                x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
                k_list.append(k)
                v_list.append(v)

            x = self.transformer.ln_f(x)
            k_cache = torch.stack(k_list, dim=0)
            v_cache = torch.stack(v_list, dim=0)

            logits = self.coch_head(x)
            predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
            all_logits.append(logits)

        # (B, n_tokens) sampled ids and (B, n_tokens, vocab) logits
        pred_coch = torch.cat(predictions, dim=1)
        all_logits = torch.cat(all_logits, dim=1)

        return pred_coch, all_logits
484
+
485
+
486
# Alias for backward compatibility: older code imports `AuriStream` directly.
AuriStream = AuriStreamModel