tantk committed on
Commit
30cbe32
·
verified ·
1 Parent(s): a0d2600

Upload voxtral_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. voxtral_inference.py +708 -0
voxtral_inference.py ADDED
@@ -0,0 +1,708 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Voxtral Realtime 4B inference engine.
3
+
4
+ Loads directly from Mistral-format consolidated.safetensors — no transformers
5
+ dependency. Adapted from voxtral.c/python_simple_implementation.py with CUDA
6
+ and FP16 support for T4 GPUs.
7
+ """
8
+
9
+ import json
10
+ import math
11
+ import os
12
+ import base64
13
+ from typing import Iterator
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from safetensors import safe_open
20
+
21
# ============================================================================
# Config (from params.json)
# ============================================================================

# Encoder (Whisper-style causal audio transformer)
ENC_DIM = 1280            # model width
ENC_LAYERS = 32           # transformer layers
ENC_HEADS = 32            # attention heads
ENC_HEAD_DIM = 64         # per-head dim (ENC_DIM / ENC_HEADS)
ENC_HIDDEN = 5120         # FFN hidden size
ENC_KV_HEADS = 32         # same as ENC_HEADS -> no GQA in the encoder
ENC_WINDOW = 750          # sliding attention window, in positions
ENC_NORM_EPS = 1e-5
ENC_ROPE_THETA = 1_000_000.0

# Decoder (Mistral-style text transformer)
DEC_DIM = 3072
DEC_LAYERS = 26
DEC_HEADS = 32
DEC_HEAD_DIM = 128
DEC_HIDDEN = 9216
DEC_KV_HEADS = 8          # GQA: 32 / 8 = 4 query heads per KV head
DEC_WINDOW = 8192         # sliding attention window; also caps the KV cache
DEC_NORM_EPS = 1e-5
DEC_ROPE_THETA = 1_000_000.0
VOCAB_SIZE = 131072

# Audio front-end
SAMPLE_RATE = 16000       # Hz
FRAME_RATE = 12.5         # audio tokens per second
NUM_MEL_BINS = 128
HOP_LENGTH = 160          # STFT hop: 10 ms at 16 kHz
WINDOW_SIZE = 400         # STFT window: 25 ms at 16 kHz
GLOBAL_LOG_MEL_MAX = 1.5  # fixed log-mel ceiling used as the clamp reference
DOWNSAMPLE_FACTOR = 4     # encoder frames stacked into one decoder token

# Ada norm (time-conditioned RMS scaling inside the decoder FFN)
ADA_NORM_DIM = 32

# Streaming
N_LEFT_PAD_TOKENS = 32        # zero-audio tokens prepended before the clip
TRANSCRIPTION_DELAY_MS = 480  # transcription delay budget, in milliseconds

# Special tokens
TOKEN_BOS = 1
TOKEN_EOS = 2
TOKEN_STREAMING_PAD = 32
TOKEN_BEGIN_AUDIO = 25
TOKEN_AUDIO = 24

# Derived constants
RAW_AUDIO_LENGTH_PER_TOK = int(SAMPLE_RATE // FRAME_RATE)  # 1280 samples per token
AUDIO_LENGTH_PER_TOK = RAW_AUDIO_LENGTH_PER_TOK // HOP_LENGTH  # 8 mel frames per token
74
+
75
+
76
def _num_delay_tokens():
    """Number of decoder tokens covered by TRANSCRIPTION_DELAY_MS of audio.

    Converts the delay to samples, then to mel frames (HOP_LENGTH samples
    each), then rounds up to whole audio tokens (AUDIO_LENGTH_PER_TOK frames
    each). Returns 6 for the default 480 ms delay.
    """
    delay_samples = int(TRANSCRIPTION_DELAY_MS / 1000.0 * SAMPLE_RATE)
    if delay_samples % HOP_LENGTH == 0:
        frames = delay_samples // HOP_LENGTH
    else:
        # NOTE(review): `ceil(x - 1)` is an odd rounding for the non-aligned
        # case; presumably it mirrors the reference implementation. It is
        # unreachable for the current 480 ms delay (7680 % 160 == 0) — confirm
        # before changing TRANSCRIPTION_DELAY_MS.
        frames = math.ceil(delay_samples / HOP_LENGTH - 1)
    return math.ceil(frames / AUDIO_LENGTH_PER_TOK)
84
+
85
+
86
# Tokens of audio lookahead implied by TRANSCRIPTION_DELAY_MS (6 for 480 ms).
N_DELAY_TOKENS = _num_delay_tokens()
# Right padding: the delay tokens plus one, plus 10 extra tokens of slack.
N_RIGHT_PAD_TOKENS = (N_DELAY_TOKENS + 1) + 10  # 17
88
+
89
+ # ============================================================================
90
+ # Mel filter bank
91
+ # ============================================================================
92
+
93
+
94
+ def _hertz_to_mel(freq):
95
+ min_log_hertz = 1000.0
96
+ min_log_mel = 15.0
97
+ logstep = 27.0 / np.log(6.4)
98
+ mels = 3.0 * freq / 200.0
99
+ if isinstance(freq, np.ndarray):
100
+ log_region = freq >= min_log_hertz
101
+ mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
102
+ elif freq >= min_log_hertz:
103
+ mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
104
+ return mels
105
+
106
+
107
+ def _mel_to_hertz(mels):
108
+ min_log_hertz = 1000.0
109
+ min_log_mel = 15.0
110
+ logstep = np.log(6.4) / 27.0
111
+ freq = 200.0 * mels / 3.0
112
+ log_region = mels >= min_log_mel
113
+ freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
114
+ return freq
115
+
116
+
117
def _compute_mel_filters():
    """Build the [201, 128] triangular mel filter bank.

    Rows index FFT bins (1 + WINDOW_SIZE // 2 of them), columns index the
    NUM_MEL_BINS mel channels. Filters are Slaney-normalized so each
    triangle carries comparable energy.
    """
    n_fft_bins = 1 + WINDOW_SIZE // 2  # 201
    bin_hz = np.linspace(0, SAMPLE_RATE // 2, n_fft_bins)

    # Band edges: equally spaced on the mel scale between 0 and 8 kHz.
    mel_edges = np.linspace(_hertz_to_mel(0.0), _hertz_to_mel(8000.0), NUM_MEL_BINS + 2)
    edge_hz = _mel_to_hertz(mel_edges)
    widths = np.diff(edge_hz)

    deltas = np.expand_dims(edge_hz, 0) - np.expand_dims(bin_hz, 1)
    rising = -deltas[:, :-2] / widths[:-1]
    falling = deltas[:, 2:] / widths[1:]
    bank = np.maximum(np.zeros(1), np.minimum(rising, falling))

    # Slaney normalization: scale by 2 / bandwidth of each triangle.
    norm = 2.0 / (edge_hz[2:NUM_MEL_BINS + 2] - edge_hz[:NUM_MEL_BINS])
    bank *= np.expand_dims(norm, 0)
    return bank  # [201, 128]
132
+
133
+
134
+ # ============================================================================
135
+ # Mel spectrogram
136
+ # ============================================================================
137
+
138
+
139
def _compute_mel_spectrogram(audio, mel_filters, device):
    """Log-mel spectrogram of a waveform.

    audio: 1-D float tensor on `device`; mel_filters: [freq_bins, mel_bins]
    on the same device. Returns a [NUM_MEL_BINS, frames] tensor, shifted and
    scaled into the model's expected range.
    """
    win = torch.hann_window(WINDOW_SIZE, device=device)
    spec = torch.stft(audio, WINDOW_SIZE, HOP_LENGTH, window=win, return_complex=True)
    # Drop the trailing STFT frame, then take power.
    power = spec[..., :-1].abs() ** 2
    mels = mel_filters.T @ power
    log_mels = torch.clamp(mels, min=1e-10).log10()
    # Floor 8 log10-units below the fixed global max (not the per-clip max),
    # so normalization is identical across streaming chunks.
    floor = torch.tensor(GLOBAL_LOG_MEL_MAX, device=device) - 8.0
    log_mels = torch.maximum(log_mels, floor)
    return (log_mels + 4.0) / 4.0  # [128, frames]
149
+
150
+
151
+ # ============================================================================
152
+ # Audio streaming padding
153
+ # ============================================================================
154
+
155
+
156
def _pad_audio_streaming(audio_array):
    """Zero-pad raw audio for streaming inference.

    Left: N_LEFT_PAD_TOKENS tokens' worth of samples. Right: alignment up to
    a whole number of audio tokens, plus N_RIGHT_PAD_TOKENS extra tokens.
    """
    tok_samples = RAW_AUDIO_LENGTH_PER_TOK
    remainder = len(audio_array) % tok_samples
    align = (tok_samples - remainder) % tok_samples
    left = N_LEFT_PAD_TOKENS * tok_samples
    right = align + N_RIGHT_PAD_TOKENS * tok_samples
    return np.pad(audio_array, (left, right))
163
+
164
+
165
+ # ============================================================================
166
+ # Weight loading helpers
167
+ # ============================================================================
168
+
169
+
170
+ def _get_weight(sf_file, name, device, dtype=None):
171
+ t = sf_file.get_tensor(name)
172
+ if t.dtype == torch.bfloat16:
173
+ t = t.float()
174
+ t = t.to(device)
175
+ if dtype is not None:
176
+ t = t.to(dtype)
177
+ return t
178
+
179
+
180
def _get_weight_optional(sf_file, name, device, dtype=None):
    """Like _get_weight, but returns None instead of raising.

    Best-effort lookup: any failure (typically a name missing from the
    checkpoint) is treated as the tensor being absent.
    """
    try:
        tensor = _get_weight(sf_file, name, device, dtype)
    except Exception:
        return None
    return tensor
185
+
186
+
187
+ def _permute_qk_weight(w, n_heads, head_dim):
188
+ attn_in = n_heads * head_dim
189
+ attn_out = w.shape[1]
190
+ return (
191
+ w.view(n_heads, head_dim // 2, 2, attn_out)
192
+ .transpose(1, 2)
193
+ .reshape(attn_in, attn_out)
194
+ )
195
+
196
+
197
+ def _permute_qk_bias(b, n_heads, head_dim):
198
+ attn_in = n_heads * head_dim
199
+ return (
200
+ b.view(n_heads, head_dim // 2, 2)
201
+ .transpose(1, 2)
202
+ .reshape(attn_in)
203
+ )
204
+
205
+
206
+ # ============================================================================
207
+ # RMSNorm
208
+ # ============================================================================
209
+
210
+
211
+ class _RMSNorm(nn.Module):
212
+ def __init__(self, weight, eps=1e-5):
213
+ super().__init__()
214
+ self.weight = weight
215
+ self.eps = eps
216
+
217
+ def forward(self, x):
218
+ rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
219
+ return (x.float() * rms * self.weight.float()).to(x.dtype)
220
+
221
+
222
+ # ============================================================================
223
+ # RoPE
224
+ # ============================================================================
225
+
226
+
227
+ def _compute_rope_freqs(positions, head_dim, theta, device):
228
+ freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
229
+ angles = positions.float().unsqueeze(-1) * freqs.unsqueeze(0)
230
+ return torch.cos(angles), torch.sin(angles)
231
+
232
+
233
+ def _apply_rope(x, cos_f, sin_f, n_heads, head_dim, is_neox_style=False):
234
+ seq_len = x.shape[0]
235
+ x = x.view(seq_len, n_heads, head_dim)
236
+ cos_f = cos_f.unsqueeze(1)
237
+ sin_f = sin_f.unsqueeze(1)
238
+
239
+ if is_neox_style:
240
+ x1, x2 = x.chunk(2, dim=-1)
241
+ o1 = x1 * cos_f - x2 * sin_f
242
+ o2 = x2 * cos_f + x1 * sin_f
243
+ out = torch.cat([o1, o2], dim=-1)
244
+ else:
245
+ x1 = x[..., ::2]
246
+ x2 = x[..., 1::2]
247
+ o1 = x1 * cos_f - x2 * sin_f
248
+ o2 = x2 * cos_f + x1 * sin_f
249
+ out = torch.stack([o1, o2], dim=-1).flatten(-2)
250
+
251
+ return out.view(seq_len, n_heads * head_dim)
252
+
253
+
254
+ # ============================================================================
255
+ # Causal Attention
256
+ # ============================================================================
257
+
258
+
259
+ def _causal_attention(q, k, v, n_heads, n_kv_heads, head_dim, window,
260
+ q_start_pos=0, kv_start_pos=0):
261
+ seq_q = q.shape[0]
262
+ seq_kv = k.shape[0]
263
+ gqa_ratio = n_heads // n_kv_heads
264
+ device = q.device
265
+ orig_dtype = q.dtype
266
+
267
+ q = q.view(seq_q, n_heads, head_dim).transpose(0, 1).unsqueeze(0)
268
+ k = k.view(seq_kv, n_kv_heads, head_dim).transpose(0, 1).unsqueeze(0)
269
+ v = v.view(seq_kv, n_kv_heads, head_dim).transpose(0, 1).unsqueeze(0)
270
+
271
+ if gqa_ratio > 1:
272
+ k = k.repeat_interleave(gqa_ratio, dim=1)
273
+ v = v.repeat_interleave(gqa_ratio, dim=1)
274
+
275
+ qi_abs = (q_start_pos + torch.arange(seq_q, device=device)).unsqueeze(1)
276
+ kv_abs = (kv_start_pos + torch.arange(seq_kv, device=device)).unsqueeze(0)
277
+ attn_mask = (kv_abs <= qi_abs) & (kv_abs >= (qi_abs - (window - 1)))
278
+
279
+ out = F.scaled_dot_product_attention(
280
+ q.float(), k.float(), v.float(),
281
+ attn_mask=attn_mask.unsqueeze(0).unsqueeze(0),
282
+ scale=1.0 / math.sqrt(head_dim),
283
+ dropout_p=0.0,
284
+ ).to(orig_dtype)
285
+
286
+ return out.squeeze(0).transpose(0, 1).contiguous().view(seq_q, n_heads * head_dim)
287
+
288
+
289
+ # ============================================================================
290
+ # Causal Conv1d
291
+ # ============================================================================
292
+
293
+
294
+ def _causal_conv1d(x, weight, bias, stride):
295
+ kernel_size = weight.shape[2]
296
+ effective_ks = kernel_size
297
+ padding_total = effective_ks - stride
298
+
299
+ n_frames = (x.shape[-1] - effective_ks + padding_total) / stride + 1
300
+ target_length = (math.ceil(n_frames) - 1) * stride + (effective_ks - padding_total)
301
+ extra_padding = int(target_length - x.shape[-1])
302
+
303
+ x = F.pad(x, (padding_total, extra_padding), mode='constant')
304
+ return F.conv1d(x, weight, bias, stride=stride)
305
+
306
+
307
+ # ============================================================================
308
+ # TimeEmbedding
309
+ # ============================================================================
310
+
311
+
312
+ def _compute_time_embedding(t_value, dim, device, theta=10000.0):
313
+ half_dim = dim // 2
314
+ inv_freq = torch.exp(
315
+ -math.log(theta) * torch.arange(half_dim, device=device).float() / half_dim
316
+ )
317
+ emb = t_value * inv_freq
318
+ return torch.cat([emb.cos(), emb.sin()])
319
+
320
+
321
+ # ============================================================================
322
+ # Encoder forward
323
+ # ============================================================================
324
+
325
+
326
def _encoder_forward(mel, sf_file, device, compute_dtype):
    """Run the Whisper-style audio encoder, streaming weights from disk.

    mel: [128, frames] log-mel tensor on `device`. Returns [seq, 1280] on
    `device`, where seq is the stride-2 conv output length truncated to a
    multiple of DOWNSAMPLE_FACTOR.

    Weights are fetched per layer via _get_weight rather than held resident,
    trading speed for memory. RMSNorms run with float32 weights (no dtype
    cast); matmuls run in `compute_dtype`.
    """
    prefix = "mm_streams_embeddings.embedding_module.whisper_encoder"

    # Two causal convs: stride 1 then stride 2 (halves the frame rate).
    mel_3d = mel.unsqueeze(0)
    conv0_w = _get_weight(sf_file, f"{prefix}.conv_layers.0.conv.weight", device, compute_dtype)
    conv0_b = _get_weight(sf_file, f"{prefix}.conv_layers.0.conv.bias", device, compute_dtype)
    conv1_w = _get_weight(sf_file, f"{prefix}.conv_layers.1.conv.weight", device, compute_dtype)
    conv1_b = _get_weight(sf_file, f"{prefix}.conv_layers.1.conv.bias", device, compute_dtype)

    h = F.gelu(_causal_conv1d(mel_3d.to(compute_dtype), conv0_w, conv0_b, stride=1))
    h = F.gelu(_causal_conv1d(h, conv1_w, conv1_b, stride=2))
    h = h.squeeze(0).transpose(0, 1)  # [seq, 1280]
    conv_len = h.shape[0]

    # Drop leading frames so the adapter's reshape by DOWNSAMPLE_FACTOR is exact.
    trunc = conv_len % DOWNSAMPLE_FACTOR
    if trunc > 0:
        h = h[trunc:]
    seq_len = h.shape[0]

    # RoPE tables are shared by every layer (positions are fixed).
    positions = torch.arange(seq_len, device=device)
    rope_cos, rope_sin = _compute_rope_freqs(positions, ENC_HEAD_DIM, ENC_ROPE_THETA, device)

    for layer in range(ENC_LAYERS):
        lp = f"{prefix}.transformer.layers.{layer}"

        # --- attention sub-block (pre-norm, residual) ---
        attn_norm_w = _get_weight(sf_file, f"{lp}.attention_norm.weight", device)
        norm = _RMSNorm(attn_norm_w, ENC_NORM_EPS)
        x_norm = norm(h).to(compute_dtype)

        wq = _get_weight(sf_file, f"{lp}.attention.wq.weight", device, compute_dtype)
        wq_b = _get_weight(sf_file, f"{lp}.attention.wq.bias", device, compute_dtype)
        wk = _get_weight(sf_file, f"{lp}.attention.wk.weight", device, compute_dtype)
        wv = _get_weight(sf_file, f"{lp}.attention.wv.weight", device, compute_dtype)
        wv_b = _get_weight(sf_file, f"{lp}.attention.wv.bias", device, compute_dtype)
        wo = _get_weight(sf_file, f"{lp}.attention.wo.weight", device, compute_dtype)
        wo_b = _get_weight(sf_file, f"{lp}.attention.wo.bias", device, compute_dtype)

        # Note: q and v carry biases; k does not (matches checkpoint layout).
        q = F.linear(x_norm, wq, wq_b)
        k = F.linear(x_norm, wk)
        v = F.linear(x_norm, wv, wv_b)

        q = _apply_rope(q, rope_cos, rope_sin, ENC_HEADS, ENC_HEAD_DIM, is_neox_style=False)
        k = _apply_rope(k, rope_cos, rope_sin, ENC_KV_HEADS, ENC_HEAD_DIM, is_neox_style=False)

        attn_out = _causal_attention(q, k, v, ENC_HEADS, ENC_KV_HEADS, ENC_HEAD_DIM, ENC_WINDOW)

        h = h + F.linear(attn_out, wo, wo_b)

        # --- SwiGLU feed-forward sub-block (pre-norm, residual) ---
        ffn_norm_w = _get_weight(sf_file, f"{lp}.ffn_norm.weight", device)
        ffn_norm = _RMSNorm(ffn_norm_w, ENC_NORM_EPS)
        x_norm = ffn_norm(h).to(compute_dtype)

        w1 = _get_weight(sf_file, f"{lp}.feed_forward.w1.weight", device, compute_dtype)
        w2 = _get_weight(sf_file, f"{lp}.feed_forward.w2.weight", device, compute_dtype)
        w2_b = _get_weight(sf_file, f"{lp}.feed_forward.w2.bias", device, compute_dtype)
        w3 = _get_weight(sf_file, f"{lp}.feed_forward.w3.weight", device, compute_dtype)

        gate = F.silu(F.linear(x_norm, w1))
        up = F.linear(x_norm, w3)
        h = h + F.linear(gate * up, w2, w2_b)

    final_norm_w = _get_weight(sf_file, f"{prefix}.transformer.norm.weight", device)
    final_norm = _RMSNorm(final_norm_w, ENC_NORM_EPS)
    h = final_norm(h)

    return h  # [seq, 1280]
393
+
394
+
395
+ # ============================================================================
396
+ # Adapter forward
397
+ # ============================================================================
398
+
399
+
400
def _adapter_forward(enc_out, sf_file, device, compute_dtype):
    """Project encoder features into the decoder's embedding space.

    Stacks DOWNSAMPLE_FACTOR consecutive encoder frames into one vector and
    applies a two-layer GELU MLP: [seq, 1280] -> [seq // 4, 3072].
    """
    prefix = "mm_streams_embeddings.embedding_module"
    proj_in = _get_weight(sf_file, f"{prefix}.audio_language_projection.0.weight", device, compute_dtype)
    proj_out = _get_weight(sf_file, f"{prefix}.audio_language_projection.2.weight", device, compute_dtype)

    n_frames = enc_out.shape[0]
    stacked = enc_out.reshape(n_frames // DOWNSAMPLE_FACTOR, ENC_DIM * DOWNSAMPLE_FACTOR)

    hidden = F.gelu(F.linear(stacked.to(compute_dtype), proj_in))
    return F.linear(hidden, proj_out)  # [seq/4, 3072]
413
+
414
+
415
+ # ============================================================================
416
+ # Decoder
417
+ # ============================================================================
418
+
419
+
420
class _Decoder:
    """Mistral-style text decoder with a sliding-window KV cache.

    All layer weights are loaded eagerly at construction and kept on
    `device` in `compute_dtype`. The token-embedding matrix is tied: it is
    used both for input lookup and as the output head in forward_one().
    The KV cache maps layer index -> (k_cache, v_cache) and is capped at
    DEC_WINDOW positions per layer.
    """

    def __init__(self, sf_file, device, compute_dtype):
        self.sf = sf_file
        self.device = device
        self.compute_dtype = compute_dtype
        self.tok_embeddings = _get_weight(
            sf_file,
            "mm_streams_embeddings.embedding_module.tok_embeddings.weight",
            device, compute_dtype,
        )
        # No dtype arg: norm weights stay float32 (RMSNorm computes in fp32).
        self.final_norm = _get_weight(sf_file, "norm.weight", device)
        self.kv_cache = {}

        self.layers = []
        for i in range(DEC_LAYERS):
            self.layers.append(self._load_layer(i))

    def _load_layer(self, i):
        """Load every weight tensor of decoder layer `i` into a dict."""
        sf = self.sf
        lp = f"layers.{i}"
        device = self.device
        dtype = self.compute_dtype

        return {
            'attention_norm': _get_weight(sf, f"{lp}.attention_norm.weight", device),
            'ffn_norm': _get_weight(sf, f"{lp}.ffn_norm.weight", device),
            'wq': _get_weight(sf, f"{lp}.attention.wq.weight", device, dtype),
            'wk': _get_weight(sf, f"{lp}.attention.wk.weight", device, dtype),
            'wv': _get_weight(sf, f"{lp}.attention.wv.weight", device, dtype),
            'wo': _get_weight(sf, f"{lp}.attention.wo.weight", device, dtype),
            'w1': _get_weight(sf, f"{lp}.feed_forward.w1.weight", device, dtype),
            'w2': _get_weight(sf, f"{lp}.feed_forward.w2.weight", device, dtype),
            'w3': _get_weight(sf, f"{lp}.feed_forward.w3.weight", device, dtype),
            # Two-layer MLP that turns the time embedding into a per-channel
            # FFN-input scale (ada RMS norm conditioning).
            'ada_down': _get_weight(sf, f"{lp}.ada_rms_norm_t_cond.0.weight", device, dtype),
            'ada_up': _get_weight(sf, f"{lp}.ada_rms_norm_t_cond.2.weight", device, dtype),
        }

    def embed_token(self, token_id):
        """Embedding vector for a single token id."""
        return self.tok_embeddings[token_id]

    def embed_tokens(self, token_ids):
        """Embedding matrix for a 1-D tensor of token ids."""
        return self.tok_embeddings[token_ids]

    def _layer_forward(self, h, layer_idx, pos, kv_seq_len, t_cond=None):
        """Run one decoder layer on `h` starting at absolute position `pos`,
        appending this step's K/V to the layer's cache.

        NOTE(review): `kv_seq_len` is accepted but never read — the cache
        length is derived from the stored tensors instead.
        """
        L = self.layers[layer_idx]
        seq_len = h.shape[0]
        dtype = self.compute_dtype
        device = self.device

        if h.dtype != dtype:
            h = h.to(dtype)

        # --- attention sub-block (pre-norm, residual) ---
        norm = _RMSNorm(L['attention_norm'], DEC_NORM_EPS)
        x_norm = norm(h).to(dtype)

        q = F.linear(x_norm, L['wq'])
        k = F.linear(x_norm, L['wk'])
        v = F.linear(x_norm, L['wv'])

        # RoPE is applied in float32, then cast back to the compute dtype.
        positions = torch.arange(pos, pos + seq_len, device=device)
        rope_cos, rope_sin = _compute_rope_freqs(positions, DEC_HEAD_DIM, DEC_ROPE_THETA, device)
        q = _apply_rope(q.float(), rope_cos, rope_sin, DEC_HEADS, DEC_HEAD_DIM, is_neox_style=False).to(dtype)
        k = _apply_rope(k.float(), rope_cos, rope_sin, DEC_KV_HEADS, DEC_HEAD_DIM, is_neox_style=False).to(dtype)

        # Append to the cache, then trim to the last DEC_WINDOW positions.
        if layer_idx not in self.kv_cache:
            k_cache = k
            v_cache = v
        else:
            k_cache, v_cache = self.kv_cache[layer_idx]
            k_cache = torch.cat([k_cache, k], dim=0)
            v_cache = torch.cat([v_cache, v], dim=0)

        if k_cache.shape[0] > DEC_WINDOW:
            k_cache = k_cache[-DEC_WINDOW:]
            v_cache = v_cache[-DEC_WINDOW:]

        self.kv_cache[layer_idx] = (k_cache, v_cache)
        full_k, full_v = self.kv_cache[layer_idx]

        # Absolute position of the cache's first entry: the newest query sits
        # at pos + seq_len - 1 and the cache ends at that same position.
        kv_start_pos = (pos + seq_len - 1) - (full_k.shape[0] - 1)
        attn_out = _causal_attention(
            q, full_k, full_v,
            DEC_HEADS, DEC_KV_HEADS, DEC_HEAD_DIM,
            DEC_WINDOW,
            q_start_pos=pos,
            kv_start_pos=kv_start_pos,
        )

        attn_proj = F.linear(attn_out, L['wo'])
        h = h + attn_proj

        # --- SwiGLU feed-forward sub-block (pre-norm, residual) ---
        ffn_norm = _RMSNorm(L['ffn_norm'], DEC_NORM_EPS)
        h_norm = ffn_norm(h).to(dtype)

        if t_cond is not None:
            # Ada conditioning: scale the normed FFN input by (1 + s(t_cond)).
            t_cond_dt = t_cond.to(dtype)
            ada_hidden = F.gelu(F.linear(t_cond_dt, L['ada_down']))
            ada_scale = F.linear(ada_hidden, L['ada_up'])
            h_norm = h_norm * (1 + ada_scale.unsqueeze(0))

        gate = F.silu(F.linear(h_norm, L['w1']))
        up = F.linear(h_norm, L['w3'])
        h = h + F.linear(gate * up, L['w2'])

        return h

    def prefill(self, input_embeds, t_cond):
        """Reset the KV cache and run the whole prompt through every layer.

        Returns the final hidden states (no norm / LM head applied).
        """
        self.kv_cache = {}
        h = input_embeds.to(self.compute_dtype)
        seq_len = h.shape[0]

        for layer in range(DEC_LAYERS):
            h = self._layer_forward(h, layer, 0, seq_len, t_cond=t_cond)

        return h

    def forward_one(self, embed, pos, t_cond):
        """Decode a single step at absolute position `pos`.

        embed: [DEC_DIM] or [1, DEC_DIM]. Returns float32 logits of shape
        [VOCAB_SIZE] computed against the tied embedding matrix.
        """
        h = embed.unsqueeze(0) if embed.dim() == 1 else embed
        h = h.to(self.compute_dtype)

        for layer in range(DEC_LAYERS):
            h = self._layer_forward(h, layer, pos, pos + 1, t_cond=t_cond)

        norm = _RMSNorm(self.final_norm, DEC_NORM_EPS)
        h = norm(h)

        # Tied LM head: project onto the embedding matrix in float32.
        logits = F.linear(h.float().squeeze(0), self.tok_embeddings.float())
        return logits
548
+
549
+
550
+ # ============================================================================
551
+ # Tokenizer
552
+ # ============================================================================
553
+
554
+
555
+ def _load_tokenizer(model_dir):
556
+ tekken_path = os.path.join(model_dir, "tekken.json")
557
+ with open(tekken_path, "r", encoding="utf-8") as f:
558
+ data = json.load(f)
559
+
560
+ vocab = data["vocab"]
561
+ config = data.get("config", {})
562
+ n_special = int(config.get("default_num_special_tokens", 1000))
563
+ special_ids = {int(st["rank"]) for st in data.get("special_tokens", []) if "rank" in st}
564
+
565
+ bytes_cache = {}
566
+
567
+ def token_bytes(token_id: int) -> bytes:
568
+ b = bytes_cache.get(token_id)
569
+ if b is not None:
570
+ return b
571
+ if token_id < 0:
572
+ bytes_cache[token_id] = b""
573
+ return b""
574
+ if token_id < n_special or token_id in special_ids:
575
+ bytes_cache[token_id] = b""
576
+ return b""
577
+ vocab_id = token_id - n_special
578
+ if vocab_id < 0 or vocab_id >= len(vocab):
579
+ bytes_cache[token_id] = b""
580
+ return b""
581
+ b = base64.b64decode(vocab[vocab_id]["token_bytes"])
582
+ bytes_cache[token_id] = b
583
+ return b
584
+
585
+ def decode(token_ids):
586
+ out = bytearray()
587
+ for token_id in map(int, token_ids):
588
+ if token_id < n_special or token_id in special_ids:
589
+ continue
590
+ out += token_bytes(token_id)
591
+ return out.decode("utf-8", errors="replace")
592
+
593
+ return decode
594
+
595
+
596
+ # ============================================================================
597
+ # VoxtralModel — singleton inference engine
598
+ # ============================================================================
599
+
600
+
601
class VoxtralModel:
    """Load Voxtral from Mistral-format safetensors and run inference on CUDA.

    Audio flow: pad -> log-mel -> Whisper-style encoder -> adapter (x4
    downsample into decoder space) -> decoder with greedy argmax decoding.
    At every decoder step the audio embedding for that position is added to
    the previous token's text embedding (time-aligned streaming decoding).
    """

    def __init__(self, model_dir: str):
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # FP16 for T4 (no good bf16 support); float32 on CPU
        self.compute_dtype = torch.float16 if self.device.type == "cuda" else torch.float32

        sf_path = os.path.join(model_dir, "consolidated.safetensors")
        self._sf_file = safe_open(sf_path, framework="pt")

        # Precompute mel filters on device
        self._mel_filters = torch.tensor(
            _compute_mel_filters(), dtype=torch.float32, device=self.device
        )

        # Preload decoder (holds all layer weights on GPU)
        self._decoder = _Decoder(self._sf_file, self.device, self.compute_dtype)

        # Load tokenizer
        self._decode = _load_tokenizer(model_dir)

    def _prepare(self, audio_16k: np.ndarray):
        """Audio array -> (adapter_out, prompt_ids, t_cond) all on device.

        `audio_16k` is assumed to be 16 kHz mono float samples — TODO confirm
        callers resample before calling. prompt_ids covers BOS plus one
        streaming-pad token per left-pad/delay position.
        """
        prompt_ids = [TOKEN_BOS] + [TOKEN_STREAMING_PAD] * (N_LEFT_PAD_TOKENS + N_DELAY_TOKENS)
        padded = _pad_audio_streaming(audio_16k).astype(np.float32)

        audio_tensor = torch.tensor(padded, dtype=torch.float32, device=self.device)
        mel = _compute_mel_spectrogram(audio_tensor, self._mel_filters, self.device)

        # Drop one leading frame when odd, so the encoder's stride-2 conv
        # sees an even frame count.
        if mel.shape[1] % 2 != 0:
            mel = mel[:, 1:]

        with torch.no_grad():
            enc_out = _encoder_forward(mel, self._sf_file, self.device, self.compute_dtype)
            adapter_out = _adapter_forward(enc_out, self._sf_file, self.device, self.compute_dtype)

        # Condition the decoder's ada-norm on the delay (in tokens).
        t_cond = _compute_time_embedding(float(N_DELAY_TOKENS), DEC_DIM, self.device)

        return adapter_out, prompt_ids, t_cond

    def transcribe(self, audio_16k: np.ndarray) -> str:
        """Full pipeline: 16 kHz float32 mono audio -> transcribed text.

        Greedy decoding; one text token is produced per audio token position
        until the audio is exhausted or EOS is emitted.
        """
        adapter_out, prompt_ids, t_cond = self._prepare(audio_16k)

        n_audio = adapter_out.shape[0]
        L = len(prompt_ids)

        # Prompt embeddings: audio embedding + text embedding, per position.
        prompt_ids_t = torch.tensor(prompt_ids, dtype=torch.long, device=self.device)
        prefix_text_embeds = self._decoder.embed_tokens(prompt_ids_t)
        prefix_embeds = adapter_out[:L] + prefix_text_embeds

        with torch.no_grad():
            # Prefill everything except the last prompt position, then decode
            # the first real token from that last position.
            if L > 1:
                _ = self._decoder.prefill(prefix_embeds[:-1], t_cond)
            logits = self._decoder.forward_one(prefix_embeds[-1], pos=L - 1, t_cond=t_cond)
            token = int(logits.argmax().item())

        generated = [token]

        with torch.no_grad():
            for pos in range(L, n_audio):
                if token == TOKEN_EOS:
                    break
                # Each step feeds this position's audio embedding plus the
                # previously generated token's embedding.
                embed = adapter_out[pos] + self._decoder.embed_token(token)
                logits = self._decoder.forward_one(embed, pos=pos, t_cond=t_cond)
                token = int(logits.argmax().item())
                generated.append(token)

        # Strip a trailing EOS before detokenizing.
        if generated and generated[-1] == TOKEN_EOS:
            generated = generated[:-1]

        return self._decode(generated).strip()

    def transcribe_stream(self, audio_16k: np.ndarray) -> Iterator[str]:
        """Streaming pipeline: yields decoded text fragments as tokens are generated.

        Same decoding loop as transcribe(), but each non-special token is
        detokenized and yielded immediately (special tokens decode to "",
        which is filtered out).
        """
        adapter_out, prompt_ids, t_cond = self._prepare(audio_16k)

        n_audio = adapter_out.shape[0]
        L = len(prompt_ids)

        prompt_ids_t = torch.tensor(prompt_ids, dtype=torch.long, device=self.device)
        prefix_text_embeds = self._decoder.embed_tokens(prompt_ids_t)
        prefix_embeds = adapter_out[:L] + prefix_text_embeds

        with torch.no_grad():
            if L > 1:
                _ = self._decoder.prefill(prefix_embeds[:-1], t_cond)
            logits = self._decoder.forward_one(prefix_embeds[-1], pos=L - 1, t_cond=t_cond)
            token = int(logits.argmax().item())

        if token != TOKEN_EOS:
            text = self._decode([token])
            if text:
                yield text

        with torch.no_grad():
            for pos in range(L, n_audio):
                if token == TOKEN_EOS:
                    break
                embed = adapter_out[pos] + self._decoder.embed_token(token)
                logits = self._decoder.forward_one(embed, pos=pos, t_cond=t_cond)
                token = int(logits.argmax().item())
                if token != TOKEN_EOS:
                    text = self._decode([token])
                    if text:
                        yield text