VibeVoice / vllm_plugin /model.py
ISABS's picture
Upload folder using huggingface_hub
0a81958 verified
"""
VibeVoice vLLM Plugin Model - Native Multimodal Integration
This module implements the VibeVoice ASR model with full vLLM multimodal registry
integration for speech-to-text inference.
"""
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence
import os
import torch
import torch.nn as nn
import numpy as np
import base64
# ============================================================================
# Audio Loading: FFmpeg-based AudioMediaIO
# ============================================================================
# VibeVoice uses FFmpeg for audio decoding to ensure consistent behavior
# across different audio formats (MP3, WAV, FLAC, etc.).
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
def _ffmpeg_load_bytes(data: bytes) -> tuple[np.ndarray, int]:
"""Load audio bytes using FFmpeg via stdin-pipe decoding.
Returns:
Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
"""
audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000)
normalizer = AudioNormalizer()
audio = normalizer(audio)
return audio, sr
def _ffmpeg_load_file(filepath) -> tuple[np.ndarray, int]:
"""Load audio file using FFmpeg.
Returns:
Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
"""
audio, sr = load_audio_use_ffmpeg(str(filepath), resample=True, target_sr=24000)
normalizer = AudioNormalizer()
audio = normalizer(audio)
return audio, sr
# Register FFmpeg-based audio loader
try:
# Try new location (vLLM >= 0.6.x)
from vllm.multimodal.media.audio import AudioMediaIO as _OriginalAudioMediaIO
except ImportError:
# Fall back to old location (vLLM < 0.6.x)
import vllm.multimodal.audio as _vllm_audio_module
_OriginalAudioMediaIO = _vllm_audio_module.AudioMediaIO
class _PatchedAudioMediaIO(_OriginalAudioMediaIO):
"""AudioMediaIO implementation using FFmpeg for audio decoding."""
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(data)
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(base64.b64decode(data))
def load_file(self, filepath) -> tuple[np.ndarray, int]:
return _ffmpeg_load_file(filepath)
# Replace globally
try:
# For new vLLM versions
import vllm.multimodal.media.audio as _vllm_audio_module
_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO
except ImportError:
# For old vLLM versions
import vllm.multimodal.audio as _vllm_audio_module
_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO
# Also patch in utils module where it's imported
try:
import vllm.multimodal.utils as _vllm_utils_module
_vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO
except (ImportError, AttributeError):
# AudioMediaIO might not be imported in utils in newer versions
pass
# ============================================================================
from transformers import BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import MultiModalDataParser
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings
from vllm.model_executor.models.utils import (
init_vllm_registered_model,
maybe_prefix,
AutoWeightsLoader,
WeightsMapper,
)
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
try:
# Try new location (vLLM >= 0.6.x)
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
except ImportError:
# Fall back to old location (vLLM < 0.6.x)
try:
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
except ImportError:
# If neither location works, try individual imports
from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder
from vllm.multimodal.processing.inputs import ProcessorInputs
# Import VibeVoice components
from vibevoice.modular.modular_vibevoice_tokenizer import (
VibeVoiceAcousticTokenizerModel,
VibeVoiceSemanticTokenizerModel,
VibeVoiceTokenizerStreamingCache,
VibeVoiceTokenizerEncoderOutput,
)
from vibevoice.modular.configuration_vibevoice import (
VibeVoiceAcousticTokenizerConfig,
VibeVoiceSemanticTokenizerConfig,
)
class SpeechConnector(nn.Module):
"""Projects speech features to language model hidden dimension.
Architecture: fc1 -> RMSNorm -> fc2 (no activation function)
"""
def __init__(self, input_dim: int, output_dim: int):
super().__init__()
self.fc1 = nn.Linear(input_dim, output_dim)
self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
self.fc2 = nn.Linear(output_dim, output_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.norm(x)
x = self.fc2(x)
return x
class LlamaRMSNorm(nn.Module):
"""RMSNorm layer used in SpeechConnector."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class VibeVoiceAudioEncoder(nn.Module):
"""
VibeVoice Audio Encoder module.
Encapsulates Acoustic/Semantic VAE Tokenizers and projection Connectors.
Converts raw audio waveforms into embeddings compatible with the language model.
Features:
- Streaming support for long audio (>60s by default)
- Configurable dtype for numerical precision
- Supports both sampling and deterministic (mean) modes
"""
def __init__(self, config):
super().__init__()
self.config = config
import sys
def get_cfg(obj, key, default=None):
if isinstance(obj, dict):
return obj.get(key, default)
return getattr(obj, key, default)
self.acoustic_vae_dim = get_cfg(config, "acoustic_vae_dim", 64)
self.semantic_vae_dim = get_cfg(config, "semantic_vae_dim", 128)
decoder_config = get_cfg(config, "decoder_config")
text_config = get_cfg(config, "text_config")
target_hidden_size = None
if decoder_config is not None:
target_hidden_size = get_cfg(decoder_config, "hidden_size")
if target_hidden_size is None and text_config is not None:
target_hidden_size = get_cfg(text_config, "hidden_size")
if target_hidden_size is None:
target_hidden_size = get_cfg(config, "hidden_size")
if target_hidden_size is None:
print("[VibeVoice] WARN: Could not find hidden_size in config! Defaulting to 3584 (7B).", file=sys.stderr)
self.hidden_size = 3584
else:
self.hidden_size = target_hidden_size
ac_cfg = get_cfg(config, "acoustic_tokenizer_config")
sc_cfg = get_cfg(config, "semantic_tokenizer_config")
if ac_cfg is None or sc_cfg is None:
raise ValueError("Missing acoustic/semantic tokenizer config in model config")
# Handle both dict and already-constructed config objects
if isinstance(ac_cfg, VibeVoiceAcousticTokenizerConfig):
acoustic_config = ac_cfg
elif isinstance(ac_cfg, dict):
acoustic_config = VibeVoiceAcousticTokenizerConfig(**ac_cfg)
else:
raise TypeError(f"acoustic_tokenizer_config has unexpected type: {type(ac_cfg)}")
if isinstance(sc_cfg, VibeVoiceSemanticTokenizerConfig):
semantic_config = sc_cfg
elif isinstance(sc_cfg, dict):
semantic_config = VibeVoiceSemanticTokenizerConfig(**sc_cfg)
else:
raise TypeError(f"semantic_tokenizer_config has unexpected type: {type(sc_cfg)}")
# Tokenizers use float32 for numerical precision
self.acoustic_tokenizer = VibeVoiceAcousticTokenizerModel(acoustic_config)
self.semantic_tokenizer = VibeVoiceSemanticTokenizerModel(semantic_config)
# Get audio encoder dtype from config (defaults to float32 for precision)
root_torch_dtype = get_cfg(config, "torch_dtype", None)
if root_torch_dtype is not None:
if isinstance(root_torch_dtype, str):
self._audio_encoder_dtype = getattr(torch, root_torch_dtype)
else:
self._audio_encoder_dtype = root_torch_dtype
else:
self._audio_encoder_dtype = torch.float32
self.acoustic_connector = SpeechConnector(self.acoustic_vae_dim, self.hidden_size)
self.semantic_connector = SpeechConnector(self.semantic_vae_dim, self.hidden_size)
self.compress_ratio = get_cfg(config, "speech_tok_compress_ratio", 3200)
# Streaming controls
self.sample_rate = get_cfg(config, "target_sample_rate", 24000)
# Default to True (per requirement): segment + cache inside one forward call.
self.enable_streaming = get_cfg(config, "enable_streaming", True)
self.streaming_segment_duration = get_cfg(config, "streaming_segment_duration", 60.0)
# Control whether to use sample() or .mean for acoustic tokens
# Default: use sample() for training-consistent behavior
# Set VIBEVOICE_USE_MEAN=1 for deterministic output
use_mean_env = os.getenv("VIBEVOICE_USE_MEAN", "").strip().lower()
self.use_sample = use_mean_env not in ("1", "true", "yes")
# Language model dtype (set by VibeVoiceForCausalLM.__init__)
# This is the dtype that audio embeddings will be converted to before
# being passed to the language model. Defaults to bfloat16.
self._lm_dtype: torch.dtype = torch.bfloat16
def _ensure_audio_encoder_dtype(self):
"""Ensure all audio encoder components use the correct dtype from config.
vLLM may convert weights to a different dtype (e.g., bfloat16) during loading.
This method converts audio encoder components back to the config-specified dtype
(typically float32) for numerical precision during audio encoding.
"""
import sys
target_dtype = self._audio_encoder_dtype
# Check and convert acoustic_tokenizer
try:
acoustic_dtype = next(self.acoustic_tokenizer.parameters()).dtype
if acoustic_dtype != target_dtype:
self.acoustic_tokenizer = self.acoustic_tokenizer.to(dtype=target_dtype)
print(f"[VibeVoice] Converted acoustic_tokenizer to {target_dtype} (was {acoustic_dtype})", file=sys.stderr)
except StopIteration:
pass
# Check and convert semantic_tokenizer
try:
semantic_dtype = next(self.semantic_tokenizer.parameters()).dtype
if semantic_dtype != target_dtype:
self.semantic_tokenizer = self.semantic_tokenizer.to(dtype=target_dtype)
print(f"[VibeVoice] Converted semantic_tokenizer to {target_dtype} (was {semantic_dtype})", file=sys.stderr)
except StopIteration:
pass
# Check and convert acoustic_connector
try:
ac_conn_dtype = next(self.acoustic_connector.parameters()).dtype
if ac_conn_dtype != target_dtype:
self.acoustic_connector = self.acoustic_connector.to(dtype=target_dtype)
print(f"[VibeVoice] Converted acoustic_connector to {target_dtype} (was {ac_conn_dtype})", file=sys.stderr)
except StopIteration:
pass
# Check and convert semantic_connector
try:
sc_conn_dtype = next(self.semantic_connector.parameters()).dtype
if sc_conn_dtype != target_dtype:
self.semantic_connector = self.semantic_connector.to(dtype=target_dtype)
print(f"[VibeVoice] Converted semantic_connector to {target_dtype} (was {sc_conn_dtype})", file=sys.stderr)
except StopIteration:
pass
def forward(
self,
audio: torch.Tensor,
*,
use_streaming: bool = True,
segment_duration_s: Optional[float] = None,
use_sample: Optional[bool] = None,
) -> torch.Tensor:
"""Encode audio with optional streaming for long clips.
Args:
audio: Input audio tensor [B, T] or [T]
use_streaming: Whether to enable segmented encoding for long audio
segment_duration_s: Segment length in seconds (defaults to 60s)
use_sample: If True, use sampling for acoustic tokens; if False, use mean
Defaults to self.use_sample (controlled by VIBEVOICE_USE_MEAN env var)
Returns:
Audio embeddings tensor compatible with the language model
"""
# Ensure audio encoder components use correct dtype
self._ensure_audio_encoder_dtype()
# Audio input should match the audio encoder dtype
audio = audio.to(dtype=self._audio_encoder_dtype)
if audio.ndim == 1:
audio = audio.unsqueeze(0)
# Resolve streaming options
segment_duration = segment_duration_s or self.streaming_segment_duration
sample_rate = self.sample_rate
total_samples = audio.shape[-1]
segment_samples = int(segment_duration * sample_rate)
use_streaming = use_streaming and self.enable_streaming and total_samples > segment_samples
# Resolve use_sample flag
if use_sample is None:
use_sample = self.use_sample
# Keep encoding in inference mode to avoid autograd build-up
with torch.no_grad():
if not use_streaming:
acoustic_input = audio.unsqueeze(1)
acoustic_out = self.acoustic_tokenizer.encode(acoustic_input)
# Use sample() or .mean based on use_sample flag
if use_sample:
acoustic_tokens = acoustic_out.sample(
dist_type=self.acoustic_tokenizer.std_dist_type
)[0]
else:
acoustic_tokens = acoustic_out.mean
# Connector is now float32, no conversion needed
acoustic_embeds = self.acoustic_connector(acoustic_tokens)
semantic_out = self.semantic_tokenizer.encode(acoustic_input)
# Semantic always uses .mean for consistency
semantic_tokens = semantic_out.mean
# Connector is now float32, no conversion needed
semantic_embeds = self.semantic_connector(semantic_tokens)
else:
# ==========================================
# Streaming path (Retained for future use)
# ==========================================
acoustic_cache = VibeVoiceTokenizerStreamingCache()
semantic_cache = VibeVoiceTokenizerStreamingCache()
acoustic_mean_segments = []
semantic_mean_segments = []
batch_size = audio.shape[0]
sample_indices = torch.arange(batch_size, device=audio.device)
def _iter_segments(total_length: int, segment_length: int):
for start in range(0, total_length, segment_length):
end = min(start + segment_length, total_length)
if end > start:
yield start, end
segments = list(_iter_segments(total_samples, segment_samples))
num_segments = len(segments)
for seg_idx, (start, end) in enumerate(segments):
chunk = audio[:, start:end].contiguous()
if chunk.numel() == 0:
continue
# Check if this is the final segment
is_final = (seg_idx == num_segments - 1)
# --- Acoustic Encode ---
acoustic_enc_out = self.acoustic_tokenizer.encode(
chunk.unsqueeze(1),
cache=acoustic_cache,
sample_indices=sample_indices,
use_cache=True,
is_final_chunk=is_final,
)
acoustic_mean_segments.append(acoustic_enc_out.mean)
# --- Semantic Encode ---
semantic_enc_out = self.semantic_tokenizer.encode(
chunk.unsqueeze(1),
cache=semantic_cache,
sample_indices=sample_indices,
use_cache=True,
is_final_chunk=is_final,
)
semantic_mean_segments.append(semantic_enc_out.mean)
# Concatenate sequence outputs (Acoustic)
if len(acoustic_mean_segments) == 0:
acoustic_mean_full = torch.zeros(
(batch_size, 0, self.acoustic_vae_dim),
device=audio.device,
dtype=self._audio_encoder_dtype # Use config dtype
)
else:
acoustic_mean_full = torch.cat(acoustic_mean_segments, dim=1).contiguous()
# Get acoustic tokens based on use_sample flag
acoustic_enc_full = VibeVoiceTokenizerEncoderOutput(
mean=acoustic_mean_full,
std=self.acoustic_tokenizer.fix_std,
)
if use_sample:
acoustic_tokens = acoustic_enc_full.sample(
dist_type=self.acoustic_tokenizer.std_dist_type
)[0]
else:
acoustic_tokens = acoustic_enc_full.mean
# Connector uses same dtype as tokenizer
acoustic_embeds = self.acoustic_connector(acoustic_tokens)
# Concatenate sequence outputs (Semantic)
if len(semantic_mean_segments) == 0:
semantic_tokens = torch.zeros(
(batch_size, 0, self.semantic_vae_dim),
device=audio.device,
dtype=self._audio_encoder_dtype # Use config dtype
)
else:
semantic_tokens = torch.cat(semantic_mean_segments, dim=1).contiguous()
# Connector uses same dtype as tokenizer
semantic_embeds = self.semantic_connector(semantic_tokens)
# Combine acoustic and semantic embeddings
combined_embeds = acoustic_embeds + semantic_embeds
# Convert to language model dtype for compatibility
# Audio encoder uses config.torch_dtype (typically float32) for numerical precision,
# but LM expects the dtype specified by vLLM's --dtype flag (e.g., bfloat16, float16)
combined_embeds = combined_embeds.to(dtype=self._lm_dtype)
return combined_embeds
# ============================================================================
# vLLM Multimodal Processing Infrastructure
# ============================================================================
class VibeVoiceProcessingInfo(BaseProcessingInfo):
"""Processing info for VibeVoice multimodal model."""
def get_hf_config(self):
return self.ctx.get_hf_config()
def get_feature_extractor(self, **kwargs) -> WhisperFeatureExtractor:
"""
Get a WhisperFeatureExtractor for vLLM profiling compatibility.
IMPORTANT: This is NOT used in actual inference!
VibeVoice uses its own acoustic/semantic VAE tokenizers operating on
raw 24kHz waveforms, NOT Whisper mel spectrograms.
This feature extractor exists only to satisfy vLLM's multimodal
profiling infrastructure which may query audio parameters like
sampling_rate and chunk_length for memory estimation.
"""
# Read config from preprocessor_config.json if available
import json
import os
model_path = self.ctx.model_config.model
preprocessor_path = os.path.join(model_path, "preprocessor_config.json")
# Default values: keep a coherent (sr, hop) pair.
# VibeVoice runs at 24kHz in this repo (see demo/asr_transcribe_file.py).
config = {
"sampling_rate": 24000,
"feature_size": 128,
# 10ms hop at 24kHz
"hop_length": 240,
"chunk_length": 30,
"n_fft": 400,
"padding_value": 0.0,
}
# Try to load from config file
if os.path.exists(preprocessor_path):
try:
with open(preprocessor_path, "r") as f:
file_config = json.load(f)
config.update({k: file_config[k] for k in config.keys() if k in file_config})
except Exception:
pass # Use defaults
return WhisperFeatureExtractor(
feature_size=config["feature_size"],
sampling_rate=config["sampling_rate"],
hop_length=config["hop_length"],
chunk_length=config["chunk_length"],
n_fft=config["n_fft"],
padding_value=config["padding_value"],
)
def get_audio_token_info(self) -> dict:
"""
Get audio special tokens and their IDs.
Returns dict with:
audio_token, audio_bos_token, audio_eos_token,
audio_token_id, audio_bos_id, audio_eos_id
"""
tokenizer = self.get_tokenizer()
vocab = tokenizer.get_vocab()
# VibeVoice special tokens
tokens = {
"audio_token": "<|AUDIO|>",
"audio_bos_token": "<|audio_bos|>",
"audio_eos_token": "<|audio_eos|>",
}
# Get IDs
tokens["audio_token_id"] = vocab.get(tokens["audio_token"])
tokens["audio_bos_id"] = vocab.get(tokens["audio_bos_token"])
tokens["audio_eos_id"] = vocab.get(tokens["audio_eos_token"])
return tokens
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
"""Return the maximum number of audio tokens per item.
This tells vLLM's scheduler the upper bound so that
``encoder_compute_budget`` is large enough for any audio length
the model can handle, preventing the silent scheduling deadlock
described in docs/max_num_batched_tokens_issue.md.
Formula: audio_tokens = ceil(audio_samples / compress_ratio) + 3
where +3 accounts for speech_start, speech_end, and newline tokens.
The max audio samples is bounded by seq_len (the model's context
window cannot hold more tokens than that).
"""
hf_config = self.get_hf_config()
def _cfg(key: str, default):
if isinstance(hf_config, dict):
return hf_config.get(key, default)
return getattr(hf_config, key, default)
compress_ratio = int(_cfg("speech_tok_compress_ratio", 3200))
sample_rate = int(_cfg("target_sample_rate", 24000))
# Upper bound: 61-minute audio at 24 kHz
max_audio_samples = 61 * 60 * sample_rate # 88,464,000
max_audio_tokens = int(np.ceil(max_audio_samples / compress_ratio)) + 3
# Cannot exceed the model's context window
max_audio_tokens = min(max_audio_tokens, seq_len)
return {"audio": max_audio_tokens}
class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo]):
"""
Build dummy inputs for multimodal profiling.
vLLM uses dummy inputs to:
1. Measure peak GPU activation memory → determine KV cache capacity
2. Warm up CUDA graphs
The dummy audio length must be consistent with ``get_mm_max_tokens_per_item``
so that the memory estimate covers the worst-case (longest audio) scenario.
"""
def _get_max_audio_samples(self, seq_len: int) -> int:
"""Compute maximum audio samples consistent with ``get_mm_max_tokens_per_item``.
Uses the same formula: max_tokens = min(ceil(61min * sr / ratio) + 3, seq_len),
then converts back to samples.
"""
hf_config = self.info.get_hf_config()
def _cfg(key: str, default):
if isinstance(hf_config, dict):
return hf_config.get(key, default)
return getattr(hf_config, key, default)
compress_ratio = int(_cfg("speech_tok_compress_ratio", 3200))
sample_rate = int(_cfg("target_sample_rate", 24000))
# Upper bound: 61-minute audio at 24 kHz
max_hour_samples = 61 * 60 * sample_rate # 88,464,000
max_tokens_from_audio = int(np.ceil(max_hour_samples / compress_ratio)) + 3
# Cannot exceed model context window
max_tokens = min(max_tokens_from_audio, seq_len)
# Convert tokens back to samples
return max_tokens * compress_ratio
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
if num_audios <= 0:
return ""
token_info = self.info.get_audio_token_info()
audio_token = token_info["audio_token"]
return audio_token * num_audios
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, Any] | None = None,
) -> Dict[str, Any]:
"""Generate dummy audio data for profiling.
The audio length is derived from ``seq_len`` so that profiling
accurately measures memory for the longest audio the model can handle.
Supports ``AudioDummyOptions.length`` override for faster startup.
"""
num_audios = mm_counts.get("audio", 0)
max_audio_len = self._get_max_audio_samples(seq_len)
audio_overrides = mm_options.get("audio") if mm_options else None
return {
"audio": self._get_dummy_audios(
length=max_audio_len,
num_audios=num_audios,
overrides=audio_overrides,
)
}
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, Any] | None = None,
) -> ProcessorInputs:
"""Build ProcessorInputs for dummy profiling."""
return ProcessorInputs(
prompt=self.get_dummy_text(mm_counts),
mm_data=self.get_dummy_mm_data(seq_len, mm_counts, mm_options),
)
def _vibevoice_field_config(hf_inputs: Mapping[str, torch.Tensor]):
"""Map HF processor output keys to audio modality.
Returns a config dict that tells vLLM how to batch multimodal data.
"""
# Always define the config for all fields we use
# Even if the field isn't in hf_inputs, vLLM needs to know how to batch it
config = {
# These are our custom fields for VibeVoice
"raw_audio": MultiModalFieldConfig.batched("audio"),
"raw_audio_lengths": MultiModalFieldConfig.batched("audio"),
"salt": MultiModalFieldConfig.batched("audio"),
}
# Add optional Whisper features if present
if "input_features" in hf_inputs:
config["input_features"] = MultiModalFieldConfig.batched("audio")
if "feature_attention_mask" in hf_inputs:
config["feature_attention_mask"] = MultiModalFieldConfig.batched("audio")
return config
class VibeVoiceMultiModalProcessor(BaseMultiModalProcessor[VibeVoiceProcessingInfo]):
"""
Multimodal processor for VibeVoice.
Handles the conversion of raw audio inputs to model-ready features,
and manages the prompt token replacement for audio placeholders.
"""
def _get_data_parser(self) -> MultiModalDataParser:
"""Create a data parser with the correct target sample rate (24kHz)."""
# VibeVoice requires 24kHz, not 16kHz (Whisper default)
target_sr = 24000
return MultiModalDataParser(target_sr=target_sr)
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
"""
Process prompt and audio for vLLM multimodal pipeline.
We intentionally do NOT run a HF processor that would pre-expand
`<|AUDIO|>` inside this method. Instead we:
1) Tokenize the prompt as-is (so `<|AUDIO|>` stays a single token)
2) Store raw audio tensors for `embed_multimodal` to encode later
3) Let vLLM call `_get_prompt_updates` to expand the single `<|AUDIO|>`
into the full ASR format: [speech_start] + N*[speech_pad] + [speech_end] + [\\n]
"""
# Handle the case where 'audios' key is used (transformers deprecation)
mm_data = dict(mm_data) # Make a mutable copy
audios = mm_data.pop("audios", None)
if audios is not None and "audio" not in mm_data:
mm_data["audio"] = audios
# Text-only input handling
if not mm_data.get("audio"):
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
# Get raw audio data
raw_audio_list = mm_data.get("audio")
if isinstance(raw_audio_list, np.ndarray):
raw_audio_list = [raw_audio_list]
elif not isinstance(raw_audio_list, list):
raw_audio_list = list(raw_audio_list)
# Tokenize prompt directly to preserve <|AUDIO|> as a single token
# vLLM will expand it via _get_prompt_updates
tokenizer = self.info.get_tokenizer()
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
# Create result with input_ids
result = BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
# Add raw audio tensors for VibeVoice encoder
# Stack into a single tensor for vLLM's batched field config
max_len = max(len(a) for a in raw_audio_list)
raw_audio_tensors = []
audio_lengths = []
for audio in raw_audio_list:
audio_len = len(audio)
audio_lengths.append(audio_len)
if audio_len < max_len:
audio = np.pad(audio, (0, max_len - audio_len), mode='constant')
raw_audio_tensors.append(torch.from_numpy(audio).float())
# Stack into [num_audios, max_len] tensor
stacked_audio = torch.stack(raw_audio_tensors, dim=0) # Shape: [num_audios, max_len]
result["raw_audio"] = stacked_audio
# Convert lengths to tensor as well
result["raw_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long)
# Add a random salt to ensure unique hash and bypass cache
import uuid
# Use a random integer for salt
salt_val = hash(str(uuid.uuid4())) % 100000
result["salt"] = torch.tensor([salt_val], dtype=torch.long).expand(len(raw_audio_list))
return result
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
"""Return whether the HF processor applies prompt updates.
Returns False because we handle token expansion via _get_prompt_updates.
"""
return False
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
"""Configure which HF output fields map to which modality."""
return _vibevoice_field_config(hf_inputs)
def _get_prompt_updates(
self,
mm_items,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
"""
Define how to replace the audio placeholder in the prompt.
vLLM's OpenAI multimodal parsing inserts the model placeholder string
from `get_placeholder_str` (here: `<|AUDIO|>`) into the conversation.
We expand that single token into N repeated `<|AUDIO|>` tokens, where N
is derived from waveform length and `speech_tok_compress_ratio`.
"""
token_info = self.info.get_audio_token_info()
audio_token = token_info["audio_token"]
audio_token_id = token_info["audio_token_id"]
audio_bos_id = token_info.get("audio_bos_id")
audio_eos_id = token_info.get("audio_eos_id")
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
def _tok_id(name: str) -> int | None:
return vocab.get(name)
# Look up speech token IDs from vocabulary
# These tokens mark the start/end of audio embeddings in the prompt
speech_start_id = (
_tok_id("<|object_ref_start|>")
or getattr(tokenizer, "speech_start_id", None)
or _tok_id("<|speech_start|>")
)
speech_end_id = (
_tok_id("<|object_ref_end|>")
or getattr(tokenizer, "speech_end_id", None)
or _tok_id("<|speech_end|>")
)
speech_pad_id = (
_tok_id("<|box_start|>")
or getattr(tokenizer, "speech_pad_id", None)
or _tok_id("<|speech_pad|>")
)
if audio_token_id is None:
return []
# Get raw audio lengths (in samples, after resampling to 24kHz) from our stored data
out_mm_data = out_mm_kwargs.get_data()
raw_audio_lengths = out_mm_data.get("raw_audio_lengths", [])
# Fetch defaults from model config when available (falls back to 3200)
hf_config = self.info.get_hf_config()
if isinstance(hf_config, dict):
compress_ratio = int(hf_config.get("speech_tok_compress_ratio", 3200))
else:
compress_ratio = int(getattr(hf_config, "speech_tok_compress_ratio", 3200))
def _to_int_len(x) -> int:
if x is None:
return 0
if isinstance(x, torch.Tensor):
# Accept 0-dim or 1-dim scalar-like tensors
if x.numel() == 1:
return int(x.item())
# If a full tensor is passed accidentally, fall back to its length
return int(x.shape[0])
return int(x)
def get_replacement(item_idx: int):
if raw_audio_lengths and item_idx < len(raw_audio_lengths):
audio_len = _to_int_len(raw_audio_lengths[item_idx])
num_features = max(1, int(np.ceil(audio_len / compress_ratio)))
else:
# Fallback: estimate for 30 second audio at 24kHz
num_features = int(np.ceil(30 * 24000 / compress_ratio))
if num_features == 0:
raise ValueError(
f"Audio at index {item_idx} is too short to be represented"
)
# Build replacement token sequence:
# <|speech_start|> + N * <|speech_pad|> + <|speech_end|> + \n
# The newline is important for correct prompt structure.
newline_id = 198 # '\n' token
if speech_start_id is not None and speech_pad_id is not None and speech_end_id is not None:
embed_id = int(speech_pad_id)
replacement_ids = [int(speech_start_id)] + [embed_id] * num_features + [int(speech_end_id), newline_id]
# Fallback: add audio BOS/EOS boundaries around repeated <|AUDIO|>.
elif audio_bos_id is not None and audio_eos_id is not None:
embed_id = int(audio_token_id)
replacement_ids = [int(audio_bos_id)] + [embed_id] * num_features + [int(audio_eos_id)]
else:
embed_id = int(audio_token_id)
replacement_ids = [embed_id] * num_features
return PromptUpdateDetails.select_token_id(
replacement_ids,
embed_token_id=int(embed_id),
)
return [
PromptReplacement(
modality="audio",
# Keep string placeholder matching for maximum vLLM compatibility.
target=audio_token,
replacement=get_replacement,
)
]
# ============================================================================
# Main Model Class
# ============================================================================
@MULTIMODAL_REGISTRY.register_processor(
VibeVoiceMultiModalProcessor,
info=VibeVoiceProcessingInfo,
dummy_inputs=VibeVoiceDummyInputsBuilder,
)
class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
"""
VibeVoice ASR model with native vLLM multimodal integration.
This model combines VibeVoice acoustic/semantic tokenizers for audio encoding
with a causal language model for text generation.
"""
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
"""Return the placeholder string format for a given modality.
Returns "<|AUDIO|>" which vLLM inserts into the conversation prompt.
This single placeholder is later expanded by `_get_prompt_updates` into:
[speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
where N = ceil(audio_samples / compress_ratio).
"""
if modality.startswith("audio"):
return "<|AUDIO|>"
raise ValueError(f"Unsupported modality: {modality}")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
self.audio_encoder = VibeVoiceAudioEncoder(config)
# Pass decoder_config to the language model initialization
decoder_config = getattr(config, "decoder_config", config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=decoder_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
# Set the language model dtype for audio encoder output conversion
# This should match vLLM's --dtype flag (e.g., bfloat16, float16, float32)
# Audio encoder internal computation stays in fp32 for numerical precision,
# but output is converted to LM dtype for compatibility
lm_dtype = vllm_config.model_config.dtype
if lm_dtype is not None:
self.audio_encoder._lm_dtype = lm_dtype
# Ensure audio encoder uses correct dtype (typically fp32 for precision)
try:
self.audio_encoder._ensure_audio_encoder_dtype()
except Exception:
pass
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
"""
Extract audio embeddings using VibeVoice's acoustic/semantic tokenizers.
Called by vLLM to get audio embeddings that replace audio placeholder tokens.
Returns:
Tuple of embedding tensors, one per audio input.
"""
# Get raw audio data (stored by our processor)
raw_audio = kwargs.get("raw_audio")
raw_audio_lengths = kwargs.get("raw_audio_lengths")
# Handle no audio input - this happens during memory profiling
if raw_audio is None:
return []
# Handle empty audio list
if isinstance(raw_audio, (list, tuple)) and len(raw_audio) == 0:
return []
# Flatten raw_audio_lengths if it's nested
def flatten_lengths(lengths):
"""Flatten nested lists/tensors of lengths to a single list."""
if lengths is None:
return []
result = []
if isinstance(lengths, torch.Tensor):
lengths = lengths.tolist()
if isinstance(lengths, (list, tuple)):
for item in lengths:
if isinstance(item, (list, tuple)):
result.extend(flatten_lengths(item))
elif isinstance(item, torch.Tensor):
if item.dim() == 0:
result.append(item.item())
else:
result.extend(item.tolist())
else:
result.append(item)
else:
result.append(lengths)
return result
raw_audio_lengths = flatten_lengths(raw_audio_lengths)
# Streaming controls. Enabled by default; can be overridden per-call.
use_streaming_flag = bool(
kwargs.get(
"use_streaming",
getattr(self.audio_encoder, "enable_streaming", True),
)
)
streaming_segment_duration = kwargs.get(
"streaming_segment_duration",
getattr(self.audio_encoder, "streaming_segment_duration", 60.0),
)
# Process each audio through the VibeVoice encoder
embeddings = []
# Get model device for tensor placement.
# dtype is NOT set here — audio_encoder.forward() handles it internally:
# input: converted to fp32 (self._audio_encoder_dtype)
# output: converted to bfloat16 (self._lm_dtype)
try:
device = next(self.audio_encoder.parameters()).device
except StopIteration:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Handle both stacked tensor and list of tensors
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
if isinstance(raw_audio, torch.Tensor):
if raw_audio.dim() == 3:
# Shape: [batch_size, 1, seq_len] - squeeze the middle dimension
num_audios = raw_audio.shape[0]
audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)]
elif raw_audio.dim() == 2:
# Shape: [batch_size, seq_len]
num_audios = raw_audio.shape[0]
audio_list = [raw_audio[i] for i in range(num_audios)]
else:
# Single 1D tensor
audio_list = [raw_audio]
elif isinstance(raw_audio, (list, tuple)):
audio_list = list(raw_audio)
else:
# Single tensor
audio_list = [raw_audio]
for i, audio_tensor in enumerate(audio_list):
try:
if isinstance(audio_tensor, list):
audio_tensor = torch.stack(audio_tensor)
# Ensure tensor
if not isinstance(audio_tensor, torch.Tensor):
audio_tensor = torch.tensor(audio_tensor)
# Only place on correct device; audio_encoder.forward() handles dtype
audio_tensor = audio_tensor.to(device=device)
# Get actual length if available, otherwise use full length
if raw_audio_lengths and i < len(raw_audio_lengths):
actual_len = int(raw_audio_lengths[i])
if actual_len > 0 and actual_len <= audio_tensor.shape[-1]:
# Truncate from the last dimension (sequence length)
audio_tensor = audio_tensor[..., :actual_len]
# Skip if audio is too short (< 1 frame)
if audio_tensor.numel() < 160: # Minimum ~1ms at 24kHz
continue
# Encode audio through VibeVoice encoder
audio_embeds = self.audio_encoder(
audio_tensor,
use_streaming=use_streaming_flag,
segment_duration_s=streaming_segment_duration,
)
# audio_embeds shape: [1, seq_len, hidden_size]
# We need to return it as a single embedding tensor per audio
final_embed = audio_embeds.squeeze(0)
embeddings.append(final_embed)
except Exception as e:
# Log error but don't crash - this helps debug profiling issues
print(f"[VibeVoice] Error encoding audio {i}: {e}")
import traceback
traceback.print_exc()
# Return empty embedding to avoid crash
continue
return tuple(embeddings)
def get_input_embeddings(self) -> torch.nn.Module:
"""Return the text embedding layer (embed_tokens).
vLLM uses this to get the embedding module for converting token IDs
to embeddings during decode phase.
Returns:
The embed_tokens module from the language model
"""
# Get embed_tokens from the language model
if hasattr(self.language_model, 'model') and hasattr(self.language_model.model, 'embed_tokens'):
return self.language_model.model.embed_tokens
elif hasattr(self.language_model, 'embed_tokens'):
return self.language_model.embed_tokens
else:
# Try to get from inner model
inner = self.language_model
if hasattr(inner, 'language_model'):
inner = inner.language_model
if hasattr(inner, 'model') and hasattr(inner.model, 'embed_tokens'):
return inner.model.embed_tokens
raise AttributeError("Cannot find embed_tokens layer")
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
is_multimodal: Optional[torch.Tensor] = None,
**kwargs, # Accept any additional kwargs for compatibility
) -> torch.Tensor:
"""Apply token embeddings to input_ids and merge with multimodal embeddings.
This is the preferred method in vLLM V1 for converting token IDs
to embeddings and merging multimodal (audio) embeddings.
Args:
input_ids: Tensor of token IDs to embed
multimodal_embeddings: Pre-computed multimodal embeddings (audio).
Can be a Tensor or a List of Tensors (vLLM standard).
is_multimodal: Boolean mask indicating which positions are multimodal
**kwargs: Additional arguments for compatibility
Returns:
Tensor of embeddings with multimodal content merged in
"""
from vllm.model_executor.models.utils import _merge_multimodal_embeddings
# Get text embeddings
embed_tokens = self.get_input_embeddings()
inputs_embeds = embed_tokens(input_ids)
# Merge multimodal embeddings if provided
if multimodal_embeddings is not None and is_multimodal is not None:
# Use vLLM's standard merge function which handles List[Tensor] correctly
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds,
multimodal_embeddings,
is_multimodal,
)
return inputs_embeds
def get_language_model(self) -> torch.nn.Module:
"""Return the language model backbone."""
return self.language_model
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]:
"""Load model weights from checkpoint.
The checkpoint has weights named like:
- lm_head.weight -> language_model.lm_head.weight
- model.language_model.layers.X... -> language_model.model.layers.X...
- model.acoustic_tokenizer... -> audio_encoder.acoustic_tokenizer...
- model.semantic_tokenizer... -> audio_encoder.semantic_tokenizer...
- model.acoustic_connector... -> audio_encoder.acoustic_connector...
- model.semantic_connector... -> audio_encoder.semantic_connector...
Let vLLM handle all dtype conversions according to --dtype flag.
"""
# Map weight prefixes for VibeVoice
# The checkpoint uses "model." prefix, we need to remap it
mapper = WeightsMapper(
orig_to_new_prefix={
# Audio encoder components: model.X -> audio_encoder.X
"model.acoustic_tokenizer.": "audio_encoder.acoustic_tokenizer.",
"model.semantic_tokenizer.": "audio_encoder.semantic_tokenizer.",
"model.acoustic_connector.": "audio_encoder.acoustic_connector.",
"model.semantic_connector.": "audio_encoder.semantic_connector.",
# Language model: model.language_model.X -> language_model.model.X
"model.language_model.": "language_model.model.",
# LM head: lm_head.X -> language_model.lm_head.X
"lm_head.": "language_model.lm_head.",
}
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=mapper)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""
Forward pass for VibeVoice ASR model.
Handles embedding computation and language model forward pass.
Uses inputs_embeds if provided (from vLLM multimodal merge),
otherwise computes embeddings from input_ids.
Args:
input_ids: Token IDs. May be None when inputs_embeds is provided.
positions: Position indices for the input tokens.
intermediate_tensors: Intermediate tensors for pipeline parallelism.
inputs_embeds: Pre-computed embeddings (from multimodal merge or decode).
"""
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
# Only compute from input_ids if inputs_embeds is not available
if inputs_embeds is None and input_ids is not None:
# Compute embeddings from input_ids
inputs_embeds = self.get_input_embeddings()(input_ids)
# If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds
if intermediate_tensors is not None:
inputs_embeds = None
# Get the inner model - handle both wrapped and direct language models
language_model = self.language_model
if hasattr(language_model, "language_model"):
language_model = language_model.language_model
# Call the language model's model (Qwen2Model)
# vLLM V1 passes kv_caches and attn_metadata via context, not arguments
# IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding
hidden_states = language_model.model(
input_ids=None, # Always None when we have inputs_embeds
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds
)
return hidden_states
# Alias for training checkpoint compatibility
VibeVoiceForASRTraining = VibeVoiceForCausalLM