# gpu_endpoint / voxtral_inference.py
# tantk's picture
# Upload voxtral_inference.py with huggingface_hub
# 30cbe32 verified
"""
Voxtral Realtime 4B inference engine.
Loads directly from Mistral-format consolidated.safetensors — no transformers
dependency. Adapted from voxtral.c/python_simple_implementation.py with CUDA
and FP16 support for T4 GPUs.
"""
import json
import math
import os
import base64
from typing import Iterator
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open
# ============================================================================
# Config (from params.json)
# ============================================================================
# Encoder
# ENC_KV_HEADS equals ENC_HEADS, so _causal_attention performs no key/value
# head repetition for the encoder.
ENC_DIM = 1280
ENC_LAYERS = 32
ENC_HEADS = 32
ENC_HEAD_DIM = 64
ENC_HIDDEN = 5120
ENC_KV_HEADS = 32
# Sliding attention window, in encoder positions (passed as `window` to
# _causal_attention).
ENC_WINDOW = 750
ENC_NORM_EPS = 1e-5
ENC_ROPE_THETA = 1_000_000.0
# Decoder
# DEC_HEADS / DEC_KV_HEADS = 4: each key/value head is repeated for four
# query heads inside _causal_attention.
DEC_DIM = 3072
DEC_LAYERS = 26
DEC_HEADS = 32
DEC_HEAD_DIM = 128
DEC_HIDDEN = 9216
DEC_KV_HEADS = 8
# Attention window of the decoder; also the maximum number of key/value rows
# the decoder keeps per layer in its cache.
DEC_WINDOW = 8192
DEC_NORM_EPS = 1e-5
DEC_ROPE_THETA = 1_000_000.0
VOCAB_SIZE = 131072
# Audio
# SAMPLE_RATE is in samples per second; HOP_LENGTH and WINDOW_SIZE are the
# STFT hop and window size in samples.
SAMPLE_RATE = 16000
FRAME_RATE = 12.5
NUM_MEL_BINS = 128
HOP_LENGTH = 160
WINDOW_SIZE = 400
# Log-mel values are floored at GLOBAL_LOG_MEL_MAX - 8.0 in
# _compute_mel_spectrogram.
GLOBAL_LOG_MEL_MAX = 1.5
# Number of consecutive encoder frames the adapter merges into one embedding.
DOWNSAMPLE_FACTOR = 4
# Ada norm
ADA_NORM_DIM = 32
# Streaming
# N_LEFT_PAD_TOKENS is both the zero padding prepended to the audio (in
# tokens) and part of the count of pad tokens placed in the prompt.
N_LEFT_PAD_TOKENS = 32
TRANSCRIPTION_DELAY_MS = 480
# Special tokens
TOKEN_BOS = 1
TOKEN_EOS = 2
TOKEN_STREAMING_PAD = 32
TOKEN_BEGIN_AUDIO = 25
TOKEN_AUDIO = 24
# Derived constants
# Raw audio samples per token, and mel hops per token.
RAW_AUDIO_LENGTH_PER_TOK = int(SAMPLE_RATE // FRAME_RATE) # 1280
AUDIO_LENGTH_PER_TOK = RAW_AUDIO_LENGTH_PER_TOK // HOP_LENGTH # 8
def _num_delay_tokens():
    """Convert TRANSCRIPTION_DELAY_MS into a whole number of tokens.

    The delay is first expressed in samples, then in hops of HOP_LENGTH
    samples, and finally rounded up to tokens of AUDIO_LENGTH_PER_TOK hops.
    """
    delay_samples = int(TRANSCRIPTION_DELAY_MS / 1000.0 * SAMPLE_RATE)
    if delay_samples % HOP_LENGTH == 0:
        n_hops = delay_samples // HOP_LENGTH
    else:
        # A trailing partial hop does not count as a full hop.
        n_hops = math.ceil(delay_samples / HOP_LENGTH - 1)
    return math.ceil(n_hops / AUDIO_LENGTH_PER_TOK)
# Transcription delay in tokens; evaluated once at import time.
N_DELAY_TOKENS = _num_delay_tokens()
# Zero padding appended to the audio, in tokens (see _pad_audio_streaming).
N_RIGHT_PAD_TOKENS = (N_DELAY_TOKENS + 1) + 10 # 17
# ============================================================================
# Mel filter bank
# ============================================================================
def _hertz_to_mel(freq):
min_log_hertz = 1000.0
min_log_mel = 15.0
logstep = 27.0 / np.log(6.4)
mels = 3.0 * freq / 200.0
if isinstance(freq, np.ndarray):
log_region = freq >= min_log_hertz
mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
elif freq >= min_log_hertz:
mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
return mels
def _mel_to_hertz(mels):
min_log_hertz = 1000.0
min_log_mel = 15.0
logstep = np.log(6.4) / 27.0
freq = 200.0 * mels / 3.0
log_region = mels >= min_log_mel
freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
return freq
def _compute_mel_filters():
    """Build the triangular mel filter bank as a NumPy array.

    Returns an array of shape [1 + WINDOW_SIZE // 2, NUM_MEL_BINS]; each
    column is one triangular filter scaled by 2 / (its width in hertz).
    """
    n_fft_bins = 1 + WINDOW_SIZE // 2  # 201
    bin_hz = np.linspace(0, SAMPLE_RATE // 2, n_fft_bins)

    # Filter edges are evenly spaced on the mel scale between 0 and 8000 Hz.
    mel_edges = np.linspace(_hertz_to_mel(0.0), _hertz_to_mel(8000.0), NUM_MEL_BINS + 2)
    edge_hz = _mel_to_hertz(mel_edges)
    edge_gaps = np.diff(edge_hz)

    # offsets[i, j] = edge j minus FFT bin i, in hertz.
    offsets = edge_hz[np.newaxis, :] - bin_hz[:, np.newaxis]
    rising = -offsets[:, :-2] / edge_gaps[:-1]
    falling = offsets[:, 2:] / edge_gaps[1:]
    bank = np.maximum(np.zeros(1), np.minimum(rising, falling))

    widths = edge_hz[2:NUM_MEL_BINS + 2] - edge_hz[:NUM_MEL_BINS]
    bank *= (2.0 / widths)[np.newaxis, :]
    return bank  # [201, 128]
# ============================================================================
# Mel spectrogram
# ============================================================================
def _compute_mel_spectrogram(audio, mel_filters, device):
    """Compute a normalised log-mel spectrogram.

    audio: 1D tensor on device; mel_filters: [freq_bins, mel_bins] on device.
    Returns a [mel_bins, frames] tensor. The last STFT frame is dropped,
    values are floored at GLOBAL_LOG_MEL_MAX - 8.0 and then mapped through
    (x + 4) / 4.
    """
    hann = torch.hann_window(WINDOW_SIZE, device=device)
    spectrum = torch.stft(audio, WINDOW_SIZE, HOP_LENGTH, window=hann, return_complex=True)
    power = spectrum[..., :-1].abs() ** 2

    mel_power = torch.matmul(mel_filters.T, power)
    log_mel = torch.clamp(mel_power, min=1e-10).log10()

    floor = torch.tensor(GLOBAL_LOG_MEL_MAX, device=device) - 8.0
    log_mel = torch.maximum(log_mel, floor)
    return (log_mel + 4.0) / 4.0  # [128, frames]
# ============================================================================
# Audio streaming padding
# ============================================================================
def _pad_audio_streaming(audio_array):
    """Zero-pad audio on both sides for streaming inference.

    The left side gets N_LEFT_PAD_TOKENS tokens of padding. The right side
    is first extended to the next multiple of RAW_AUDIO_LENGTH_PER_TOK
    samples and then by a further N_RIGHT_PAD_TOKENS tokens.
    """
    samples_per_tok = RAW_AUDIO_LENGTH_PER_TOK
    to_token_boundary = -len(audio_array) % samples_per_tok
    pad_left = N_LEFT_PAD_TOKENS * samples_per_tok
    pad_right = to_token_boundary + N_RIGHT_PAD_TOKENS * samples_per_tok
    return np.pad(audio_array, (pad_left, pad_right))
# ============================================================================
# Weight loading helpers
# ============================================================================
def _get_weight(sf_file, name, device, dtype=None):
t = sf_file.get_tensor(name)
if t.dtype == torch.bfloat16:
t = t.float()
t = t.to(device)
if dtype is not None:
t = t.to(dtype)
return t
def _get_weight_optional(sf_file, name, device, dtype=None):
try:
return _get_weight(sf_file, name, device, dtype)
except Exception:
return None
def _permute_qk_weight(w, n_heads, head_dim):
attn_in = n_heads * head_dim
attn_out = w.shape[1]
return (
w.view(n_heads, head_dim // 2, 2, attn_out)
.transpose(1, 2)
.reshape(attn_in, attn_out)
)
def _permute_qk_bias(b, n_heads, head_dim):
attn_in = n_heads * head_dim
return (
b.view(n_heads, head_dim // 2, 2)
.transpose(1, 2)
.reshape(attn_in)
)
# ============================================================================
# RMSNorm
# ============================================================================
class _RMSNorm(nn.Module):
def __init__(self, weight, eps=1e-5):
super().__init__()
self.weight = weight
self.eps = eps
def forward(self, x):
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
return (x.float() * rms * self.weight.float()).to(x.dtype)
# ============================================================================
# RoPE
# ============================================================================
def _compute_rope_freqs(positions, head_dim, theta, device):
freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
angles = positions.float().unsqueeze(-1) * freqs.unsqueeze(0)
return torch.cos(angles), torch.sin(angles)
def _apply_rope(x, cos_f, sin_f, n_heads, head_dim, is_neox_style=False):
seq_len = x.shape[0]
x = x.view(seq_len, n_heads, head_dim)
cos_f = cos_f.unsqueeze(1)
sin_f = sin_f.unsqueeze(1)
if is_neox_style:
x1, x2 = x.chunk(2, dim=-1)
o1 = x1 * cos_f - x2 * sin_f
o2 = x2 * cos_f + x1 * sin_f
out = torch.cat([o1, o2], dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos_f - x2 * sin_f
o2 = x2 * cos_f + x1 * sin_f
out = torch.stack([o1, o2], dim=-1).flatten(-2)
return out.view(seq_len, n_heads * head_dim)
# ============================================================================
# Causal Attention
# ============================================================================
def _causal_attention(q, k, v, n_heads, n_kv_heads, head_dim, window,
q_start_pos=0, kv_start_pos=0):
seq_q = q.shape[0]
seq_kv = k.shape[0]
gqa_ratio = n_heads // n_kv_heads
device = q.device
orig_dtype = q.dtype
q = q.view(seq_q, n_heads, head_dim).transpose(0, 1).unsqueeze(0)
k = k.view(seq_kv, n_kv_heads, head_dim).transpose(0, 1).unsqueeze(0)
v = v.view(seq_kv, n_kv_heads, head_dim).transpose(0, 1).unsqueeze(0)
if gqa_ratio > 1:
k = k.repeat_interleave(gqa_ratio, dim=1)
v = v.repeat_interleave(gqa_ratio, dim=1)
qi_abs = (q_start_pos + torch.arange(seq_q, device=device)).unsqueeze(1)
kv_abs = (kv_start_pos + torch.arange(seq_kv, device=device)).unsqueeze(0)
attn_mask = (kv_abs <= qi_abs) & (kv_abs >= (qi_abs - (window - 1)))
out = F.scaled_dot_product_attention(
q.float(), k.float(), v.float(),
attn_mask=attn_mask.unsqueeze(0).unsqueeze(0),
scale=1.0 / math.sqrt(head_dim),
dropout_p=0.0,
).to(orig_dtype)
return out.squeeze(0).transpose(0, 1).contiguous().view(seq_q, n_heads * head_dim)
# ============================================================================
# Causal Conv1d
# ============================================================================
def _causal_conv1d(x, weight, bias, stride):
kernel_size = weight.shape[2]
effective_ks = kernel_size
padding_total = effective_ks - stride
n_frames = (x.shape[-1] - effective_ks + padding_total) / stride + 1
target_length = (math.ceil(n_frames) - 1) * stride + (effective_ks - padding_total)
extra_padding = int(target_length - x.shape[-1])
x = F.pad(x, (padding_total, extra_padding), mode='constant')
return F.conv1d(x, weight, bias, stride=stride)
# ============================================================================
# TimeEmbedding
# ============================================================================
def _compute_time_embedding(t_value, dim, device, theta=10000.0):
half_dim = dim // 2
inv_freq = torch.exp(
-math.log(theta) * torch.arange(half_dim, device=device).float() / half_dim
)
emb = t_value * inv_freq
return torch.cat([emb.cos(), emb.sin()])
# ============================================================================
# Encoder forward
# ============================================================================
def _encoder_forward(mel, sf_file, device, compute_dtype):
    """Run the audio encoder. mel: [128, frames] on device -> [seq, 1280].

    Weights are read from ``sf_file`` as they are needed, layer by layer,
    rather than being kept resident. Norm weights are loaded without a
    dtype override; every other weight is cast to ``compute_dtype``.
    """
    root = "mm_streams_embeddings.embedding_module.whisper_encoder"

    def load(name):
        return _get_weight(sf_file, name, device, compute_dtype)

    def load_norm(name):
        return _RMSNorm(_get_weight(sf_file, name, device), ENC_NORM_EPS)

    # Convolutional stem: stride 1, then stride 2.
    stem = mel.unsqueeze(0).to(compute_dtype)
    stem = F.gelu(_causal_conv1d(
        stem,
        load(f"{root}.conv_layers.0.conv.weight"),
        load(f"{root}.conv_layers.0.conv.bias"),
        stride=1,
    ))
    stem = F.gelu(_causal_conv1d(
        stem,
        load(f"{root}.conv_layers.1.conv.weight"),
        load(f"{root}.conv_layers.1.conv.bias"),
        stride=2,
    ))
    h = stem.squeeze(0).transpose(0, 1)  # [seq, 1280]

    # Drop leading frames so the length is a multiple of DOWNSAMPLE_FACTOR.
    surplus = h.shape[0] % DOWNSAMPLE_FACTOR
    if surplus > 0:
        h = h[surplus:]

    n_frames = h.shape[0]
    rope_cos, rope_sin = _compute_rope_freqs(
        torch.arange(n_frames, device=device), ENC_HEAD_DIM, ENC_ROPE_THETA, device
    )

    for idx in range(ENC_LAYERS):
        base = f"{root}.transformer.layers.{idx}"

        # Self-attention sub-block (the key projection has no bias).
        normed = load_norm(f"{base}.attention_norm.weight")(h).to(compute_dtype)
        q = F.linear(normed, load(f"{base}.attention.wq.weight"), load(f"{base}.attention.wq.bias"))
        k = F.linear(normed, load(f"{base}.attention.wk.weight"))
        v = F.linear(normed, load(f"{base}.attention.wv.weight"), load(f"{base}.attention.wv.bias"))
        q = _apply_rope(q, rope_cos, rope_sin, ENC_HEADS, ENC_HEAD_DIM, is_neox_style=False)
        k = _apply_rope(k, rope_cos, rope_sin, ENC_KV_HEADS, ENC_HEAD_DIM, is_neox_style=False)
        attended = _causal_attention(q, k, v, ENC_HEADS, ENC_KV_HEADS, ENC_HEAD_DIM, ENC_WINDOW)
        h = h + F.linear(attended, load(f"{base}.attention.wo.weight"), load(f"{base}.attention.wo.bias"))

        # Gated feed-forward sub-block.
        normed = load_norm(f"{base}.ffn_norm.weight")(h).to(compute_dtype)
        gate = F.silu(F.linear(normed, load(f"{base}.feed_forward.w1.weight")))
        up = F.linear(normed, load(f"{base}.feed_forward.w3.weight"))
        h = h + F.linear(
            gate * up,
            load(f"{base}.feed_forward.w2.weight"),
            load(f"{base}.feed_forward.w2.bias"),
        )

    return load_norm(f"{root}.transformer.norm.weight")(h)  # [seq, 1280]
# ============================================================================
# Adapter forward
# ============================================================================
def _adapter_forward(enc_out, sf_file, device, compute_dtype):
    """Project encoder output into the decoder space.

    enc_out: [seq, 1280] -> [seq/4, 3072]. Every DOWNSAMPLE_FACTOR
    consecutive encoder frames are concatenated into one row, which then
    passes through two bias-free linear layers with a GELU in between.
    """
    root = "mm_streams_embeddings.embedding_module"
    proj_in = _get_weight(sf_file, f"{root}.audio_language_projection.0.weight", device, compute_dtype)
    proj_out = _get_weight(sf_file, f"{root}.audio_language_projection.2.weight", device, compute_dtype)

    n_groups = enc_out.shape[0] // DOWNSAMPLE_FACTOR
    stacked = enc_out.reshape(n_groups, ENC_DIM * DOWNSAMPLE_FACTOR).to(compute_dtype)
    hidden = F.gelu(F.linear(stacked, proj_in))
    return F.linear(hidden, proj_out)  # [seq/4, 3072]
# ============================================================================
# Decoder
# ============================================================================
class _Decoder:
    """Decoder transformer with all layer weights preloaded on ``device``.

    The instance owns a per-layer key/value cache (``kv_cache``), so one
    instance can serve only one generation at a time: ``prefill`` resets the
    cache and ``forward_one`` extends it.
    """

    def __init__(self, sf_file, device, compute_dtype):
        self.sf = sf_file
        self.device = device
        self.compute_dtype = compute_dtype
        self.tok_embeddings = _get_weight(
            sf_file,
            "mm_streams_embeddings.embedding_module.tok_embeddings.weight",
            device, compute_dtype,
        )
        # Norm weights are loaded without a dtype override; _RMSNorm does its
        # arithmetic in float32 anyway.
        self.final_norm = _get_weight(sf_file, "norm.weight", device)
        self.kv_cache = {}
        self.layers = [self._load_layer(i) for i in range(DEC_LAYERS)]

    def _load_layer(self, i):
        """Load every weight of decoder layer ``i`` into a dict of tensors."""
        sf = self.sf
        lp = f"layers.{i}"
        device = self.device
        dtype = self.compute_dtype
        return {
            'attention_norm': _get_weight(sf, f"{lp}.attention_norm.weight", device),
            'ffn_norm': _get_weight(sf, f"{lp}.ffn_norm.weight", device),
            'wq': _get_weight(sf, f"{lp}.attention.wq.weight", device, dtype),
            'wk': _get_weight(sf, f"{lp}.attention.wk.weight", device, dtype),
            'wv': _get_weight(sf, f"{lp}.attention.wv.weight", device, dtype),
            'wo': _get_weight(sf, f"{lp}.attention.wo.weight", device, dtype),
            'w1': _get_weight(sf, f"{lp}.feed_forward.w1.weight", device, dtype),
            'w2': _get_weight(sf, f"{lp}.feed_forward.w2.weight", device, dtype),
            'w3': _get_weight(sf, f"{lp}.feed_forward.w3.weight", device, dtype),
            'ada_down': _get_weight(sf, f"{lp}.ada_rms_norm_t_cond.0.weight", device, dtype),
            'ada_up': _get_weight(sf, f"{lp}.ada_rms_norm_t_cond.2.weight", device, dtype),
        }

    def embed_token(self, token_id):
        """Return the embedding row of a single token id."""
        return self.tok_embeddings[token_id]

    def embed_tokens(self, token_ids):
        """Return the embedding rows for a tensor of token ids."""
        return self.tok_embeddings[token_ids]

    def _rope_tables(self, pos, seq_len):
        """Return (cos, sin) rotary tables for positions pos .. pos + seq_len - 1."""
        positions = torch.arange(pos, pos + seq_len, device=self.device)
        return _compute_rope_freqs(positions, DEC_HEAD_DIM, DEC_ROPE_THETA, self.device)

    def _layer_forward(self, h, layer_idx, pos, kv_seq_len, t_cond=None, rope=None):
        """Run one decoder layer and update that layer's key/value cache.

        Args:
            h: [seq, DEC_DIM] hidden states for positions pos .. pos + seq - 1.
            layer_idx: index into ``self.layers`` and ``self.kv_cache``.
            pos: absolute position of the first row of ``h``.
            kv_seq_len: unused; kept so existing call sites keep working.
            t_cond: optional conditioning vector; when given, it scales the
                normalised feed-forward input by (1 + learned projection).
            rope: optional precomputed (cos, sin) tables for these positions.
                The tables depend only on position, so callers compute them
                once per pass instead of once per layer. Computed here when
                omitted.

        Returns:
            The updated hidden states, same shape as ``h``.
        """
        L = self.layers[layer_idx]
        seq_len = h.shape[0]
        dtype = self.compute_dtype
        if h.dtype != dtype:
            h = h.to(dtype)

        norm = _RMSNorm(L['attention_norm'], DEC_NORM_EPS)
        x_norm = norm(h).to(dtype)
        q = F.linear(x_norm, L['wq'])
        k = F.linear(x_norm, L['wk'])
        v = F.linear(x_norm, L['wv'])

        if rope is None:
            rope = self._rope_tables(pos, seq_len)
        rope_cos, rope_sin = rope
        # The rotation is applied in float32, then cast back.
        q = _apply_rope(q.float(), rope_cos, rope_sin, DEC_HEADS, DEC_HEAD_DIM, is_neox_style=False).to(dtype)
        k = _apply_rope(k.float(), rope_cos, rope_sin, DEC_KV_HEADS, DEC_HEAD_DIM, is_neox_style=False).to(dtype)

        if layer_idx not in self.kv_cache:
            k_cache = k
            v_cache = v
        else:
            k_cache, v_cache = self.kv_cache[layer_idx]
            k_cache = torch.cat([k_cache, k], dim=0)
            v_cache = torch.cat([v_cache, v], dim=0)
            # Keep only the most recent DEC_WINDOW rows.
            if k_cache.shape[0] > DEC_WINDOW:
                k_cache = k_cache[-DEC_WINDOW:]
                v_cache = v_cache[-DEC_WINDOW:]
        self.kv_cache[layer_idx] = (k_cache, v_cache)

        full_k, full_v = self.kv_cache[layer_idx]
        # The last cached row sits at absolute position pos + seq_len - 1, so
        # the first cached row is (cache length - 1) positions before it.
        kv_start_pos = (pos + seq_len - 1) - (full_k.shape[0] - 1)
        attn_out = _causal_attention(
            q, full_k, full_v,
            DEC_HEADS, DEC_KV_HEADS, DEC_HEAD_DIM,
            DEC_WINDOW,
            q_start_pos=pos,
            kv_start_pos=kv_start_pos,
        )
        h = h + F.linear(attn_out, L['wo'])

        ffn_norm = _RMSNorm(L['ffn_norm'], DEC_NORM_EPS)
        h_norm = ffn_norm(h).to(dtype)
        if t_cond is not None:
            t_cond_dt = t_cond.to(dtype)
            ada_hidden = F.gelu(F.linear(t_cond_dt, L['ada_down']))
            ada_scale = F.linear(ada_hidden, L['ada_up'])
            h_norm = h_norm * (1 + ada_scale.unsqueeze(0))
        gate = F.silu(F.linear(h_norm, L['w1']))
        up = F.linear(h_norm, L['w3'])
        h = h + F.linear(gate * up, L['w2'])
        return h

    def prefill(self, input_embeds, t_cond):
        """Reset the cache and run ``input_embeds`` ([seq, DEC_DIM]) from position 0.

        Returns the hidden states after the last layer (no final norm).
        """
        self.kv_cache = {}
        h = input_embeds.to(self.compute_dtype)
        seq_len = h.shape[0]
        rope = self._rope_tables(0, seq_len)
        for layer in range(DEC_LAYERS):
            h = self._layer_forward(h, layer, 0, seq_len, t_cond=t_cond, rope=rope)
        return h

    def forward_one(self, embed, pos, t_cond):
        """Run one step at absolute position ``pos`` and return float32 logits.

        ``embed`` may be 1D ([DEC_DIM]) or already 2D. The output projection
        reuses the token embedding matrix.
        """
        h = embed.unsqueeze(0) if embed.dim() == 1 else embed
        h = h.to(self.compute_dtype)
        rope = self._rope_tables(pos, h.shape[0])
        for layer in range(DEC_LAYERS):
            h = self._layer_forward(h, layer, pos, pos + 1, t_cond=t_cond, rope=rope)
        norm = _RMSNorm(self.final_norm, DEC_NORM_EPS)
        h = norm(h)
        # NOTE(review): when compute_dtype is float16, `.float()` below makes a
        # float32 copy of the whole embedding matrix on every call. Left as is
        # because changing it would alter the numerics of the logits.
        logits = F.linear(h.float().squeeze(0), self.tok_embeddings.float())
        return logits
# ============================================================================
# Tokenizer
# ============================================================================
def _load_tokenizer(model_dir):
tekken_path = os.path.join(model_dir, "tekken.json")
with open(tekken_path, "r", encoding="utf-8") as f:
data = json.load(f)
vocab = data["vocab"]
config = data.get("config", {})
n_special = int(config.get("default_num_special_tokens", 1000))
special_ids = {int(st["rank"]) for st in data.get("special_tokens", []) if "rank" in st}
bytes_cache = {}
def token_bytes(token_id: int) -> bytes:
b = bytes_cache.get(token_id)
if b is not None:
return b
if token_id < 0:
bytes_cache[token_id] = b""
return b""
if token_id < n_special or token_id in special_ids:
bytes_cache[token_id] = b""
return b""
vocab_id = token_id - n_special
if vocab_id < 0 or vocab_id >= len(vocab):
bytes_cache[token_id] = b""
return b""
b = base64.b64decode(vocab[vocab_id]["token_bytes"])
bytes_cache[token_id] = b
return b
def decode(token_ids):
out = bytearray()
for token_id in map(int, token_ids):
if token_id < n_special or token_id in special_ids:
continue
out += token_bytes(token_id)
return out.decode("utf-8", errors="replace")
return decode
# ============================================================================
# VoxtralModel — singleton inference engine
# ============================================================================
class VoxtralModel:
    """Load Voxtral from Mistral-format safetensors and run inference on CUDA.

    Falls back to CPU (float32) when CUDA is unavailable. The decoder's
    key/value cache lives on the instance, so only one transcription may be
    in progress per instance at a time; do not interleave two
    ``transcribe_stream`` generators on the same model.
    """

    def __init__(self, model_dir: str):
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # FP16 for T4 (no good bf16 support); float32 on CPU
        self.compute_dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        sf_path = os.path.join(model_dir, "consolidated.safetensors")
        # Kept open for the lifetime of the model: the encoder reads its
        # weights from this handle on every call.
        self._sf_file = safe_open(sf_path, framework="pt")
        # Precompute mel filters on device
        self._mel_filters = torch.tensor(
            _compute_mel_filters(), dtype=torch.float32, device=self.device
        )
        # Preload decoder (holds all layer weights on GPU)
        self._decoder = _Decoder(self._sf_file, self.device, self.compute_dtype)
        # Load tokenizer
        self._decode = _load_tokenizer(model_dir)

    def _prepare(self, audio_16k: np.ndarray):
        """Audio array -> (adapter_out, prompt_ids, t_cond) all on device.

        ``prompt_ids`` is a plain list: BOS followed by
        N_LEFT_PAD_TOKENS + N_DELAY_TOKENS streaming-pad tokens.
        """
        prompt_ids = [TOKEN_BOS] + [TOKEN_STREAMING_PAD] * (N_LEFT_PAD_TOKENS + N_DELAY_TOKENS)
        padded = _pad_audio_streaming(audio_16k).astype(np.float32)
        audio_tensor = torch.tensor(padded, dtype=torch.float32, device=self.device)
        mel = _compute_mel_spectrogram(audio_tensor, self._mel_filters, self.device)
        # Drop the first frame when the frame count is odd.
        if mel.shape[1] % 2 != 0:
            mel = mel[:, 1:]
        with torch.no_grad():
            enc_out = _encoder_forward(mel, self._sf_file, self.device, self.compute_dtype)
            adapter_out = _adapter_forward(enc_out, self._sf_file, self.device, self.compute_dtype)
        t_cond = _compute_time_embedding(float(N_DELAY_TOKENS), DEC_DIM, self.device)
        return adapter_out, prompt_ids, t_cond

    def _generate_tokens(self, audio_16k: np.ndarray) -> Iterator[int]:
        """Greedy-decode ``audio_16k`` and yield token ids as they are produced.

        Generation stops at the first TOKEN_EOS (which is not yielded) or
        once every audio position has been consumed. All model calls run
        under ``torch.no_grad()``, but each ``yield`` happens outside of it
        so the consumer's grad mode is not affected while the generator is
        suspended.
        """
        adapter_out, prompt_ids, t_cond = self._prepare(audio_16k)
        n_audio = adapter_out.shape[0]
        n_prompt = len(prompt_ids)

        with torch.no_grad():
            prompt_ids_t = torch.tensor(prompt_ids, dtype=torch.long, device=self.device)
            # Each decoder input is the audio embedding plus the embedding
            # of the token at that position.
            prefix_embeds = adapter_out[:n_prompt] + self._decoder.embed_tokens(prompt_ids_t)
            if n_prompt > 1:
                self._decoder.prefill(prefix_embeds[:-1], t_cond)
            logits = self._decoder.forward_one(prefix_embeds[-1], pos=n_prompt - 1, t_cond=t_cond)
        token = int(logits.argmax().item())

        for pos in range(n_prompt, n_audio):
            if token == TOKEN_EOS:
                return
            yield token
            with torch.no_grad():
                embed = adapter_out[pos] + self._decoder.embed_token(token)
                logits = self._decoder.forward_one(embed, pos=pos, t_cond=t_cond)
            token = int(logits.argmax().item())

        if token != TOKEN_EOS:
            yield token

    def _stream_text(self, tokens) -> Iterator[str]:
        """Turn an iterable of token ids into text fragments, UTF-8 safely.

        A token may end in the middle of a multi-byte character. Decoding
        such a token alone yields U+FFFD, so tokens are held back while
        their decoded text ends in U+FFFD and are decoded together with the
        following tokens. Whatever is still pending when the input ends is
        flushed as is, so no content is dropped.
        """
        pending = []
        for token in tokens:
            pending.append(token)
            text = self._decode(pending)
            if text.endswith("\ufffd"):
                # Possibly an incomplete character: wait for more bytes.
                continue
            pending.clear()
            if text:
                yield text
        if pending:
            tail = self._decode(pending)
            if tail:
                yield tail

    def transcribe(self, audio_16k: np.ndarray) -> str:
        """Full pipeline: 16 kHz float32 mono audio -> transcribed text."""
        generated = list(self._generate_tokens(audio_16k))
        return self._decode(generated).strip()

    def transcribe_stream(self, audio_16k: np.ndarray) -> Iterator[str]:
        """Streaming pipeline: yields decoded text fragments as tokens are generated.

        Fragments are never split inside a multi-byte character, so joining
        them gives the same text as ``transcribe`` before its final strip.
        """
        yield from self._stream_text(self._generate_tokens(audio_16k))