Upload voxtral_inference.py with huggingface_hub
voxtral_inference.py  ADDED  (+708 -0)
@@ -0,0 +1,708 @@
"""
Voxtral Realtime 4B inference engine.

Loads directly from Mistral-format consolidated.safetensors — no transformers
dependency. Adapted from voxtral.c/python_simple_implementation.py with CUDA
and FP16 support for T4 GPUs.
"""

import json
import math
import os
import base64
from typing import Iterator

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open

# ============================================================================
# Config (from params.json)
# ============================================================================

# Encoder
ENC_DIM = 1280
ENC_LAYERS = 32
ENC_HEADS = 32
ENC_HEAD_DIM = 64
ENC_HIDDEN = 5120
ENC_KV_HEADS = 32
ENC_WINDOW = 750
ENC_NORM_EPS = 1e-5
ENC_ROPE_THETA = 1_000_000.0

# Decoder
DEC_DIM = 3072
DEC_LAYERS = 26
DEC_HEADS = 32
DEC_HEAD_DIM = 128
DEC_HIDDEN = 9216
DEC_KV_HEADS = 8
DEC_WINDOW = 8192
DEC_NORM_EPS = 1e-5
DEC_ROPE_THETA = 1_000_000.0
VOCAB_SIZE = 131072

# Audio
SAMPLE_RATE = 16000
FRAME_RATE = 12.5
NUM_MEL_BINS = 128
HOP_LENGTH = 160
WINDOW_SIZE = 400
GLOBAL_LOG_MEL_MAX = 1.5
DOWNSAMPLE_FACTOR = 4

# Ada norm
ADA_NORM_DIM = 32

# Streaming
N_LEFT_PAD_TOKENS = 32
TRANSCRIPTION_DELAY_MS = 480

# Special tokens
TOKEN_BOS = 1
TOKEN_EOS = 2
TOKEN_STREAMING_PAD = 32
TOKEN_BEGIN_AUDIO = 25
TOKEN_AUDIO = 24

# Derived constants
RAW_AUDIO_LENGTH_PER_TOK = int(SAMPLE_RATE // FRAME_RATE)  # 1280
AUDIO_LENGTH_PER_TOK = RAW_AUDIO_LENGTH_PER_TOK // HOP_LENGTH  # 8


def _num_delay_tokens():
    delay_len = int(TRANSCRIPTION_DELAY_MS / 1000.0 * SAMPLE_RATE)
    n = delay_len
    if n % HOP_LENGTH != 0:
        n = math.ceil(n / HOP_LENGTH - 1)
    else:
        n = n // HOP_LENGTH
    return math.ceil(n / AUDIO_LENGTH_PER_TOK)


N_DELAY_TOKENS = _num_delay_tokens()
N_RIGHT_PAD_TOKENS = (N_DELAY_TOKENS + 1) + 10  # 17

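# Worked example of the streaming geometry: TRANSCRIPTION_DELAY_MS = 480 gives
# delay_len = 0.48 * 16000 = 7680 samples, which divides evenly by HOP_LENGTH
# (7680 / 160 = 48 frames), so N_DELAY_TOKENS = ceil(48 / 8) = 6 and
# N_RIGHT_PAD_TOKENS = (6 + 1) + 10 = 17, matching the comment above.
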
# ============================================================================
# Mel filter bank
# ============================================================================


def _hertz_to_mel(freq):
    min_log_hertz = 1000.0
    min_log_mel = 15.0
    logstep = 27.0 / np.log(6.4)
    mels = 3.0 * freq / 200.0
    if isinstance(freq, np.ndarray):
        log_region = freq >= min_log_hertz
        mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
    elif freq >= min_log_hertz:
        mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
    return mels


def _mel_to_hertz(mels):
    min_log_hertz = 1000.0
    min_log_mel = 15.0
    logstep = np.log(6.4) / 27.0
    freq = 200.0 * mels / 3.0
    log_region = mels >= min_log_mel
    freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
    return freq


def _compute_mel_filters():
    num_frequency_bins = 1 + WINDOW_SIZE // 2  # 201
    fft_freqs = np.linspace(0, SAMPLE_RATE // 2, num_frequency_bins)
    mel_min = _hertz_to_mel(0.0)
    mel_max = _hertz_to_mel(8000.0)
    mel_freqs = np.linspace(mel_min, mel_max, NUM_MEL_BINS + 2)
    filter_freqs = _mel_to_hertz(mel_freqs)
    filter_diff = np.diff(filter_freqs)
    slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
    down_slopes = -slopes[:, :-2] / filter_diff[:-1]
    up_slopes = slopes[:, 2:] / filter_diff[1:]
    fb = np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
    enorm = 2.0 / (filter_freqs[2:NUM_MEL_BINS + 2] - filter_freqs[:NUM_MEL_BINS])
    fb *= np.expand_dims(enorm, 0)
    return fb  # [201, 128]

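# Shape check: num_frequency_bins = 1 + 400 // 2 = 201, so the returned bank
# is [201, 128], one triangular, area-normalized (Slaney-style) filter per mel
# bin spanning 0-8 kHz.
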
# ============================================================================
# Mel spectrogram
# ============================================================================


def _compute_mel_spectrogram(audio, mel_filters, device):
    """audio: 1D tensor on device, mel_filters: [freq_bins, mel_bins] on device."""
    window = torch.hann_window(WINDOW_SIZE, device=device)
    stft = torch.stft(audio, WINDOW_SIZE, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2
    mel_spec = mel_filters.T @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, torch.tensor(GLOBAL_LOG_MEL_MAX, device=device) - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec  # [128, frames]

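# Rate bookkeeping: one second of 16 kHz audio yields 100 STFT frames (hop
# 160); the encoder's stride-2 conv halves that to 50, and the adapter's
# downsample-by-4 leaves 12.5 embeddings per second, i.e. FRAME_RATE.
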
# ============================================================================
# Audio streaming padding
# ============================================================================


def _pad_audio_streaming(audio_array):
    mult_of = RAW_AUDIO_LENGTH_PER_TOK
    n_samples = len(audio_array)
    align_pad = (mult_of - (n_samples % mult_of)) % mult_of
    right_pad = align_pad + N_RIGHT_PAD_TOKENS * mult_of
    left_pad = N_LEFT_PAD_TOKENS * mult_of
    return np.pad(audio_array, (left_pad, right_pad))

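# With the constants above this prepends 32 * 1280 = 40960 zero samples
# (2.56 s of silence) and appends enough zeros to round the signal up to a
# whole number of tokens plus 17 * 1280 = 21760 samples of flush-out room.
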
# ============================================================================
# Weight loading helpers
# ============================================================================


def _get_weight(sf_file, name, device, dtype=None):
    t = sf_file.get_tensor(name)
    if t.dtype == torch.bfloat16:
        t = t.float()
    t = t.to(device)
    if dtype is not None:
        t = t.to(dtype)
    return t


def _get_weight_optional(sf_file, name, device, dtype=None):
    try:
        return _get_weight(sf_file, name, device, dtype)
    except Exception:
        return None


def _permute_qk_weight(w, n_heads, head_dim):
    attn_in = n_heads * head_dim
    attn_out = w.shape[1]
    return (
        w.view(n_heads, head_dim // 2, 2, attn_out)
        .transpose(1, 2)
        .reshape(attn_in, attn_out)
    )


def _permute_qk_bias(b, n_heads, head_dim):
    attn_in = n_heads * head_dim
    return (
        b.view(n_heads, head_dim // 2, 2)
        .transpose(1, 2)
        .reshape(attn_in)
    )

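# _get_weight routes bf16 checkpoints through fp32 before the optional cast,
# so the bf16 -> fp16 narrowing happens in a single rounding step. The
# _permute_qk_* helpers reorder Q/K rows between half-split and interleaved
# RoPE layouts; they are defined for completeness but unused in this file.
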
# ============================================================================
# RMSNorm
# ============================================================================


class _RMSNorm(nn.Module):
    def __init__(self, weight, eps=1e-5):
        super().__init__()
        self.weight = weight
        self.eps = eps

    def forward(self, x):
        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * rms * self.weight.float()).to(x.dtype)

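# _RMSNorm computes y = x * w / sqrt(mean(x^2) + eps) entirely in fp32 before
# casting back, which keeps the squared mean from overflowing in fp16.
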
# ============================================================================
# RoPE
# ============================================================================


def _compute_rope_freqs(positions, head_dim, theta, device):
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    angles = positions.float().unsqueeze(-1) * freqs.unsqueeze(0)
    return torch.cos(angles), torch.sin(angles)


def _apply_rope(x, cos_f, sin_f, n_heads, head_dim, is_neox_style=False):
    seq_len = x.shape[0]
    x = x.view(seq_len, n_heads, head_dim)
    cos_f = cos_f.unsqueeze(1)
    sin_f = sin_f.unsqueeze(1)

    if is_neox_style:
        x1, x2 = x.chunk(2, dim=-1)
        o1 = x1 * cos_f - x2 * sin_f
        o2 = x2 * cos_f + x1 * sin_f
        out = torch.cat([o1, o2], dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        o1 = x1 * cos_f - x2 * sin_f
        o2 = x2 * cos_f + x1 * sin_f
        out = torch.stack([o1, o2], dim=-1).flatten(-2)

    return out.view(seq_len, n_heads * head_dim)

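# Both branches rotate 2-D feature pairs by position-dependent angles; they
# differ only in which dims are paired: neox style pairs (i, i + head_dim/2),
# while the interleaved default used throughout this file pairs (2i, 2i + 1).
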
# ============================================================================
# Causal Attention
# ============================================================================


def _causal_attention(q, k, v, n_heads, n_kv_heads, head_dim, window,
                      q_start_pos=0, kv_start_pos=0):
    seq_q = q.shape[0]
    seq_kv = k.shape[0]
    gqa_ratio = n_heads // n_kv_heads
    device = q.device
    orig_dtype = q.dtype

    q = q.view(seq_q, n_heads, head_dim).transpose(0, 1).unsqueeze(0)
    k = k.view(seq_kv, n_kv_heads, head_dim).transpose(0, 1).unsqueeze(0)
    v = v.view(seq_kv, n_kv_heads, head_dim).transpose(0, 1).unsqueeze(0)

    if gqa_ratio > 1:
        k = k.repeat_interleave(gqa_ratio, dim=1)
        v = v.repeat_interleave(gqa_ratio, dim=1)

    qi_abs = (q_start_pos + torch.arange(seq_q, device=device)).unsqueeze(1)
    kv_abs = (kv_start_pos + torch.arange(seq_kv, device=device)).unsqueeze(0)
    attn_mask = (kv_abs <= qi_abs) & (kv_abs >= (qi_abs - (window - 1)))

    out = F.scaled_dot_product_attention(
        q.float(), k.float(), v.float(),
        attn_mask=attn_mask.unsqueeze(0).unsqueeze(0),
        scale=1.0 / math.sqrt(head_dim),
        dropout_p=0.0,
    ).to(orig_dtype)

    return out.squeeze(0).transpose(0, 1).contiguous().view(seq_q, n_heads * head_dim)

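# The boolean mask implements sliding-window causal attention: a query at
# absolute position i sees keys j with i - (window - 1) <= j <= i. GQA (the
# decoder runs 32 query heads over 8 KV heads) is handled by materializing
# each KV head gqa_ratio times via repeat_interleave.
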
# ============================================================================
# Causal Conv1d
# ============================================================================


def _causal_conv1d(x, weight, bias, stride):
    kernel_size = weight.shape[2]
    effective_ks = kernel_size
    padding_total = effective_ks - stride

    n_frames = (x.shape[-1] - effective_ks + padding_total) / stride + 1
    target_length = (math.ceil(n_frames) - 1) * stride + (effective_ks - padding_total)
    extra_padding = int(target_length - x.shape[-1])

    x = F.pad(x, (padding_total, extra_padding), mode='constant')
    return F.conv1d(x, weight, bias, stride=stride)

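# All (kernel_size - stride) padding goes on the left, so each output frame
# depends only on current and past input; the extra_padding on the right just
# rounds the length up to a whole number of frames.
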
# ============================================================================
# TimeEmbedding
# ============================================================================


def _compute_time_embedding(t_value, dim, device, theta=10000.0):
    half_dim = dim // 2
    inv_freq = torch.exp(
        -math.log(theta) * torch.arange(half_dim, device=device).float() / half_dim
    )
    emb = t_value * inv_freq
    return torch.cat([emb.cos(), emb.sin()])

# ============================================================================
# Encoder forward
# ============================================================================


def _encoder_forward(mel, sf_file, device, compute_dtype):
    """mel: [128, frames] on device -> [seq, 1280] on device."""
    prefix = "mm_streams_embeddings.embedding_module.whisper_encoder"

    mel_3d = mel.unsqueeze(0)
    conv0_w = _get_weight(sf_file, f"{prefix}.conv_layers.0.conv.weight", device, compute_dtype)
    conv0_b = _get_weight(sf_file, f"{prefix}.conv_layers.0.conv.bias", device, compute_dtype)
    conv1_w = _get_weight(sf_file, f"{prefix}.conv_layers.1.conv.weight", device, compute_dtype)
    conv1_b = _get_weight(sf_file, f"{prefix}.conv_layers.1.conv.bias", device, compute_dtype)

    h = F.gelu(_causal_conv1d(mel_3d.to(compute_dtype), conv0_w, conv0_b, stride=1))
    h = F.gelu(_causal_conv1d(h, conv1_w, conv1_b, stride=2))
    h = h.squeeze(0).transpose(0, 1)  # [seq, 1280]
    conv_len = h.shape[0]

    trunc = conv_len % DOWNSAMPLE_FACTOR
    if trunc > 0:
        h = h[trunc:]
    seq_len = h.shape[0]

    positions = torch.arange(seq_len, device=device)
    rope_cos, rope_sin = _compute_rope_freqs(positions, ENC_HEAD_DIM, ENC_ROPE_THETA, device)

    for layer in range(ENC_LAYERS):
        lp = f"{prefix}.transformer.layers.{layer}"

        attn_norm_w = _get_weight(sf_file, f"{lp}.attention_norm.weight", device)
        norm = _RMSNorm(attn_norm_w, ENC_NORM_EPS)
        x_norm = norm(h).to(compute_dtype)

        wq = _get_weight(sf_file, f"{lp}.attention.wq.weight", device, compute_dtype)
        wq_b = _get_weight(sf_file, f"{lp}.attention.wq.bias", device, compute_dtype)
        wk = _get_weight(sf_file, f"{lp}.attention.wk.weight", device, compute_dtype)
        wv = _get_weight(sf_file, f"{lp}.attention.wv.weight", device, compute_dtype)
        wv_b = _get_weight(sf_file, f"{lp}.attention.wv.bias", device, compute_dtype)
        wo = _get_weight(sf_file, f"{lp}.attention.wo.weight", device, compute_dtype)
        wo_b = _get_weight(sf_file, f"{lp}.attention.wo.bias", device, compute_dtype)

        q = F.linear(x_norm, wq, wq_b)
        k = F.linear(x_norm, wk)
        v = F.linear(x_norm, wv, wv_b)

        q = _apply_rope(q, rope_cos, rope_sin, ENC_HEADS, ENC_HEAD_DIM, is_neox_style=False)
        k = _apply_rope(k, rope_cos, rope_sin, ENC_KV_HEADS, ENC_HEAD_DIM, is_neox_style=False)

        attn_out = _causal_attention(q, k, v, ENC_HEADS, ENC_KV_HEADS, ENC_HEAD_DIM, ENC_WINDOW)

        h = h + F.linear(attn_out, wo, wo_b)

        ffn_norm_w = _get_weight(sf_file, f"{lp}.ffn_norm.weight", device)
        ffn_norm = _RMSNorm(ffn_norm_w, ENC_NORM_EPS)
        x_norm = ffn_norm(h).to(compute_dtype)

        w1 = _get_weight(sf_file, f"{lp}.feed_forward.w1.weight", device, compute_dtype)
        w2 = _get_weight(sf_file, f"{lp}.feed_forward.w2.weight", device, compute_dtype)
        w2_b = _get_weight(sf_file, f"{lp}.feed_forward.w2.bias", device, compute_dtype)
        w3 = _get_weight(sf_file, f"{lp}.feed_forward.w3.weight", device, compute_dtype)

        gate = F.silu(F.linear(x_norm, w1))
        up = F.linear(x_norm, w3)
        h = h + F.linear(gate * up, w2, w2_b)

    final_norm_w = _get_weight(sf_file, f"{prefix}.transformer.norm.weight", device)
    final_norm = _RMSNorm(final_norm_w, ENC_NORM_EPS)
    h = final_norm(h)

    return h  # [seq, 1280]

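# Unlike the decoder below, the encoder re-reads each layer's weights from the
# safetensors handle on every forward pass rather than keeping them resident,
# trading per-call load time for GPU memory headroom.
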
# ============================================================================
# Adapter forward
# ============================================================================


def _adapter_forward(enc_out, sf_file, device, compute_dtype):
    """enc_out: [seq, 1280] -> [seq/4, 3072]."""
    prefix = "mm_streams_embeddings.embedding_module"
    w0 = _get_weight(sf_file, f"{prefix}.audio_language_projection.0.weight", device, compute_dtype)
    w1 = _get_weight(sf_file, f"{prefix}.audio_language_projection.2.weight", device, compute_dtype)

    seq_len = enc_out.shape[0]
    ds = enc_out.reshape(seq_len // DOWNSAMPLE_FACTOR, ENC_DIM * DOWNSAMPLE_FACTOR)

    out = F.gelu(F.linear(ds.to(compute_dtype), w0))
    out = F.linear(out, w1)

    return out  # [seq/4, 3072]

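# The downsample is a pure reshape: four consecutive 1280-d encoder frames are
# concatenated into one 5120-d vector, then a two-layer GELU MLP projects it
# to the decoder width of 3072.
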
# ============================================================================
# Decoder
# ============================================================================


class _Decoder:
    def __init__(self, sf_file, device, compute_dtype):
        self.sf = sf_file
        self.device = device
        self.compute_dtype = compute_dtype
        self.tok_embeddings = _get_weight(
            sf_file,
            "mm_streams_embeddings.embedding_module.tok_embeddings.weight",
            device, compute_dtype,
        )
        self.final_norm = _get_weight(sf_file, "norm.weight", device)
        self.kv_cache = {}

        self.layers = []
        for i in range(DEC_LAYERS):
            self.layers.append(self._load_layer(i))

    def _load_layer(self, i):
        sf = self.sf
        lp = f"layers.{i}"
        device = self.device
        dtype = self.compute_dtype

        return {
            'attention_norm': _get_weight(sf, f"{lp}.attention_norm.weight", device),
            'ffn_norm': _get_weight(sf, f"{lp}.ffn_norm.weight", device),
            'wq': _get_weight(sf, f"{lp}.attention.wq.weight", device, dtype),
            'wk': _get_weight(sf, f"{lp}.attention.wk.weight", device, dtype),
            'wv': _get_weight(sf, f"{lp}.attention.wv.weight", device, dtype),
            'wo': _get_weight(sf, f"{lp}.attention.wo.weight", device, dtype),
            'w1': _get_weight(sf, f"{lp}.feed_forward.w1.weight", device, dtype),
            'w2': _get_weight(sf, f"{lp}.feed_forward.w2.weight", device, dtype),
            'w3': _get_weight(sf, f"{lp}.feed_forward.w3.weight", device, dtype),
            'ada_down': _get_weight(sf, f"{lp}.ada_rms_norm_t_cond.0.weight", device, dtype),
            'ada_up': _get_weight(sf, f"{lp}.ada_rms_norm_t_cond.2.weight", device, dtype),
        }

    def embed_token(self, token_id):
        return self.tok_embeddings[token_id]

    def embed_tokens(self, token_ids):
        return self.tok_embeddings[token_ids]

    def _layer_forward(self, h, layer_idx, pos, kv_seq_len, t_cond=None):
        L = self.layers[layer_idx]
        seq_len = h.shape[0]
        dtype = self.compute_dtype
        device = self.device

        if h.dtype != dtype:
            h = h.to(dtype)

        norm = _RMSNorm(L['attention_norm'], DEC_NORM_EPS)
        x_norm = norm(h).to(dtype)

        q = F.linear(x_norm, L['wq'])
        k = F.linear(x_norm, L['wk'])
        v = F.linear(x_norm, L['wv'])

        positions = torch.arange(pos, pos + seq_len, device=device)
        rope_cos, rope_sin = _compute_rope_freqs(positions, DEC_HEAD_DIM, DEC_ROPE_THETA, device)
        q = _apply_rope(q.float(), rope_cos, rope_sin, DEC_HEADS, DEC_HEAD_DIM, is_neox_style=False).to(dtype)
        k = _apply_rope(k.float(), rope_cos, rope_sin, DEC_KV_HEADS, DEC_HEAD_DIM, is_neox_style=False).to(dtype)

        if layer_idx not in self.kv_cache:
            k_cache = k
            v_cache = v
        else:
            k_cache, v_cache = self.kv_cache[layer_idx]
            k_cache = torch.cat([k_cache, k], dim=0)
            v_cache = torch.cat([v_cache, v], dim=0)

        if k_cache.shape[0] > DEC_WINDOW:
            k_cache = k_cache[-DEC_WINDOW:]
            v_cache = v_cache[-DEC_WINDOW:]

        self.kv_cache[layer_idx] = (k_cache, v_cache)
        full_k, full_v = self.kv_cache[layer_idx]

        kv_start_pos = (pos + seq_len - 1) - (full_k.shape[0] - 1)
        attn_out = _causal_attention(
            q, full_k, full_v,
            DEC_HEADS, DEC_KV_HEADS, DEC_HEAD_DIM,
            DEC_WINDOW,
            q_start_pos=pos,
            kv_start_pos=kv_start_pos,
        )

        attn_proj = F.linear(attn_out, L['wo'])
        h = h + attn_proj

        ffn_norm = _RMSNorm(L['ffn_norm'], DEC_NORM_EPS)
        h_norm = ffn_norm(h).to(dtype)

        if t_cond is not None:
            t_cond_dt = t_cond.to(dtype)
            ada_hidden = F.gelu(F.linear(t_cond_dt, L['ada_down']))
            ada_scale = F.linear(ada_hidden, L['ada_up'])
            h_norm = h_norm * (1 + ada_scale.unsqueeze(0))

        gate = F.silu(F.linear(h_norm, L['w1']))
        up = F.linear(h_norm, L['w3'])
        h = h + F.linear(gate * up, L['w2'])

        return h

    def prefill(self, input_embeds, t_cond):
        self.kv_cache = {}
        h = input_embeds.to(self.compute_dtype)
        seq_len = h.shape[0]

        for layer in range(DEC_LAYERS):
            h = self._layer_forward(h, layer, 0, seq_len, t_cond=t_cond)

        return h

    def forward_one(self, embed, pos, t_cond):
        h = embed.unsqueeze(0) if embed.dim() == 1 else embed
        h = h.to(self.compute_dtype)

        for layer in range(DEC_LAYERS):
            h = self._layer_forward(h, layer, pos, pos + 1, t_cond=t_cond)

        norm = _RMSNorm(self.final_norm, DEC_NORM_EPS)
        h = norm(h)

        logits = F.linear(h.float().squeeze(0), self.tok_embeddings.float())
        return logits

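# Cache bookkeeping in _layer_forward: fresh K/V rows are appended per layer
# and the cache is clipped to the last DEC_WINDOW entries, so kv_start_pos
# recovers the absolute position of the oldest cached row and the attention
# mask stays aligned to absolute positions. Note the output head is tied to
# the token embedding matrix (forward_one applies F.linear against
# tok_embeddings).
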
# ============================================================================
# Tokenizer
# ============================================================================


def _load_tokenizer(model_dir):
    tekken_path = os.path.join(model_dir, "tekken.json")
    with open(tekken_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    vocab = data["vocab"]
    config = data.get("config", {})
    n_special = int(config.get("default_num_special_tokens", 1000))
    special_ids = {int(st["rank"]) for st in data.get("special_tokens", []) if "rank" in st}

    bytes_cache = {}

    def token_bytes(token_id: int) -> bytes:
        b = bytes_cache.get(token_id)
        if b is not None:
            return b
        if token_id < 0:
            bytes_cache[token_id] = b""
            return b""
        if token_id < n_special or token_id in special_ids:
            bytes_cache[token_id] = b""
            return b""
        vocab_id = token_id - n_special
        if vocab_id < 0 or vocab_id >= len(vocab):
            bytes_cache[token_id] = b""
            return b""
        b = base64.b64decode(vocab[vocab_id]["token_bytes"])
        bytes_cache[token_id] = b
        return b

    def decode(token_ids):
        out = bytearray()
        for token_id in map(int, token_ids):
            if token_id < n_special or token_id in special_ids:
                continue
            out += token_bytes(token_id)
        return out.decode("utf-8", errors="replace")

    return decode

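# Tekken layout: the first n_special ids (1000 by default) are reserved, so
# text token i lives at vocab[i - n_special] with its raw bytes stored as
# base64. decode() drops specials and concatenates bytes before a single
# UTF-8 decode, so characters split across tokens survive; the per-token
# decode used by transcribe_stream below can surface replacement characters
# when a multi-byte character straddles a token boundary.
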
# ============================================================================
# VoxtralModel — singleton inference engine
# ============================================================================


class VoxtralModel:
    """Load Voxtral from Mistral-format safetensors and run inference on CUDA."""

    def __init__(self, model_dir: str):
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # FP16 for T4 (no good bf16 support); float32 on CPU
        self.compute_dtype = torch.float16 if self.device.type == "cuda" else torch.float32

        sf_path = os.path.join(model_dir, "consolidated.safetensors")
        self._sf_file = safe_open(sf_path, framework="pt")

        # Precompute mel filters on device
        self._mel_filters = torch.tensor(
            _compute_mel_filters(), dtype=torch.float32, device=self.device
        )

        # Preload decoder (holds all layer weights on GPU)
        self._decoder = _Decoder(self._sf_file, self.device, self.compute_dtype)

        # Load tokenizer
        self._decode = _load_tokenizer(model_dir)

    def _prepare(self, audio_16k: np.ndarray):
        """Audio array -> (adapter_out, prompt_ids, t_cond) all on device."""
        prompt_ids = [TOKEN_BOS] + [TOKEN_STREAMING_PAD] * (N_LEFT_PAD_TOKENS + N_DELAY_TOKENS)
        padded = _pad_audio_streaming(audio_16k).astype(np.float32)

        audio_tensor = torch.tensor(padded, dtype=torch.float32, device=self.device)
        mel = _compute_mel_spectrogram(audio_tensor, self._mel_filters, self.device)

        if mel.shape[1] % 2 != 0:
            mel = mel[:, 1:]

        with torch.no_grad():
            enc_out = _encoder_forward(mel, self._sf_file, self.device, self.compute_dtype)
            adapter_out = _adapter_forward(enc_out, self._sf_file, self.device, self.compute_dtype)

        t_cond = _compute_time_embedding(float(N_DELAY_TOKENS), DEC_DIM, self.device)

        return adapter_out, prompt_ids, t_cond

    def transcribe(self, audio_16k: np.ndarray) -> str:
        """Full pipeline: 16 kHz float32 mono audio -> transcribed text."""
        adapter_out, prompt_ids, t_cond = self._prepare(audio_16k)

        n_audio = adapter_out.shape[0]
        L = len(prompt_ids)

        prompt_ids_t = torch.tensor(prompt_ids, dtype=torch.long, device=self.device)
        prefix_text_embeds = self._decoder.embed_tokens(prompt_ids_t)
        prefix_embeds = adapter_out[:L] + prefix_text_embeds

        with torch.no_grad():
            if L > 1:
                _ = self._decoder.prefill(prefix_embeds[:-1], t_cond)
            logits = self._decoder.forward_one(prefix_embeds[-1], pos=L - 1, t_cond=t_cond)
            token = int(logits.argmax().item())

        generated = [token]

        with torch.no_grad():
            for pos in range(L, n_audio):
                if token == TOKEN_EOS:
                    break
                embed = adapter_out[pos] + self._decoder.embed_token(token)
                logits = self._decoder.forward_one(embed, pos=pos, t_cond=t_cond)
                token = int(logits.argmax().item())
                generated.append(token)

        if generated and generated[-1] == TOKEN_EOS:
            generated = generated[:-1]

        return self._decode(generated).strip()

    def transcribe_stream(self, audio_16k: np.ndarray) -> Iterator[str]:
        """Streaming pipeline: yields decoded text fragments as tokens are generated."""
        adapter_out, prompt_ids, t_cond = self._prepare(audio_16k)

        n_audio = adapter_out.shape[0]
        L = len(prompt_ids)

        prompt_ids_t = torch.tensor(prompt_ids, dtype=torch.long, device=self.device)
        prefix_text_embeds = self._decoder.embed_tokens(prompt_ids_t)
        prefix_embeds = adapter_out[:L] + prefix_text_embeds

        with torch.no_grad():
            if L > 1:
                _ = self._decoder.prefill(prefix_embeds[:-1], t_cond)
            logits = self._decoder.forward_one(prefix_embeds[-1], pos=L - 1, t_cond=t_cond)
            token = int(logits.argmax().item())

        if token != TOKEN_EOS:
            text = self._decode([token])
            if text:
                yield text

        with torch.no_grad():
            for pos in range(L, n_audio):
                if token == TOKEN_EOS:
                    break
                embed = adapter_out[pos] + self._decoder.embed_token(token)
                logits = self._decoder.forward_one(embed, pos=pos, t_cond=t_cond)
                token = int(logits.argmax().item())
                if token != TOKEN_EOS:
                    text = self._decode([token])
                    if text:
                        yield text
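
# ============================================================================
# Usage sketch (illustrative; not part of the engine)
# ============================================================================
# The model directory below is a hypothetical placeholder: point it at a
# folder containing consolidated.safetensors and tekken.json. Input must be
# 16 kHz float32 mono; a synthetic tone stands in here so the example needs
# no dependencies beyond those already imported.
if __name__ == "__main__":
    model = VoxtralModel("./voxtral-realtime-4b")  # hypothetical local path

    t = np.arange(2 * SAMPLE_RATE, dtype=np.float32) / SAMPLE_RATE
    audio = (0.1 * np.sin(2.0 * np.pi * 440.0 * t)).astype(np.float32)

    # One-shot transcription
    print(model.transcribe(audio))

    # Streaming transcription: fragments arrive as tokens are decoded
    for fragment in model.transcribe_stream(audio):
        print(fragment, end="", flush=True)
    print()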