# monkey_01 / miner.py
# Gem1832's picture
# Upload folder using huggingface_hub
# 543e56d verified
from __future__ import annotations
import io
import json
import os
import re
import sys
import threading
import traceback
from functools import cached_property
from pathlib import Path
from types import SimpleNamespace
from typing import AbstractSet, Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from transformers.utils import logging as hf_logging
import math
import random
import warnings
from dataclasses import dataclass
try:
import librosa
except Exception:
librosa = None
try:
import resampy
except Exception:
resampy = None
def _resample_if_needed(wav: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
if orig_sr == target_sr:
return wav.astype(np.float32, copy=False)
if resampy is not None:
return resampy.resample(wav.astype(np.float32), orig_sr, target_sr)
if librosa is not None:
return librosa.resample(
y=wav.astype(np.float32), orig_sr=orig_sr, target_sr=target_sr
)
warnings.warn(
"No resampler available; treating audio as target_sr without resampling. Install resampy or librosa.",
RuntimeWarning,
)
return wav.astype(np.float32, copy=False)
class QWEN3VoxDataset:
    """Adapter exposing ``text`` / ``audio`` / ``voice_prompts`` items of a dataset.

    Each item is a dict with:

    * ``"text"``: ``item[text_column]``
    * ``"audio"``: ``item[audio_column]`` (passed through untouched)
    * ``"voice_prompts"``: the user-provided prompt(s) wrapped in a list, or,
      when none is provided, a single random crop of the target audio, or
      ``None`` when no crop could be produced.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing that
            yields mapping-like items.
        text_column: Key holding the transcript.
        audio_column: Key holding the target audio.
        voice_prompts_column: Key holding optional voice prompts; a falsy
            value disables the lookup.
    """

    def __init__(
        self,
        dataset: Any,
        text_column: str = "text",
        audio_column: str = "audio",
        voice_prompts_column: Optional[str] = "voice_prompts",
    ) -> None:
        self.dataset = dataset
        self.text_column = text_column
        self.audio_column = audio_column
        self.voice_prompts_column = voice_prompts_column

    def __len__(self) -> int:
        return len(self.dataset)

    @staticmethod
    def _has_prompt(prompt: Any) -> bool:
        """Return True when ``prompt`` carries an actual voice prompt.

        Arrays and tensors are checked by element count: ``bool()`` on a
        multi-element array raises, which previously crashed ``__getitem__``
        for raw-waveform prompts.
        """
        if prompt is None:
            return False
        if isinstance(prompt, np.ndarray):
            return prompt.size > 0
        if isinstance(prompt, torch.Tensor):
            return prompt.numel() > 0
        return bool(prompt)

    def _sample_prompt_crop(self, audio: Any) -> Optional[List[np.ndarray]]:
        """Cut one random crop out of ``audio`` to serve as a voice prompt.

        The crop length is drawn uniformly between min(5 s, len/4) and
        min(15 s, len/2). Returns ``None`` when the upper bound is not above
        0.1 s.
        """
        target_sr = 22050
        wav_array = _load_audio_to_24k(audio, target_sr=target_sr)
        audio_len_seconds = len(wav_array) / target_sr
        min_len_sec = min(5.0, audio_len_seconds / 4.0)
        max_len_sec = min(15.0, audio_len_seconds / 2.0)
        if min_len_sec > max_len_sec:
            min_len_sec = max_len_sec
        max_len_sec = min(max_len_sec, audio_len_seconds)
        if max_len_sec <= 0.1:
            return None
        prompt_len_sec = random.uniform(min_len_sec, max_len_sec)
        prompt_len_samples = int(prompt_len_sec * target_sr)
        # The crop is at most half the clip, so the start range is never negative.
        max_start_sample = len(wav_array) - prompt_len_samples
        start_sample = random.randint(0, max_start_sample)
        return [wav_array[start_sample : start_sample + prompt_len_samples]]

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.dataset[idx]
        data: Dict[str, Any] = {
            "text": item[self.text_column],
            "audio": item[self.audio_column],
        }

        user_provided_prompt = None
        if self.voice_prompts_column and self.voice_prompts_column in item:
            user_provided_prompt = item[self.voice_prompts_column]

        if self._has_prompt(user_provided_prompt):
            if isinstance(user_provided_prompt, list):
                data["voice_prompts"] = user_provided_prompt
            else:
                data["voice_prompts"] = [user_provided_prompt]
            return data

        # Best effort: a failure to build the crop must not fail the sample.
        try:
            data["voice_prompts"] = self._sample_prompt_crop(item[self.audio_column])
        except Exception as e:
            warnings.warn(f"Could not create voice prompt for item {idx}: {e}")
            data["voice_prompts"] = None
        return data
def _apply_silence_with_crossfade(
wav: np.ndarray,
*,
sample_rate: int,
pre_silence_sec: float = 0.25,
pre_crossfade_sec: float = 0.25,
post_crossfade_sec: float = 0.25,
post_silence_sec: float = 0.75,
) -> np.ndarray:
wav = np.asarray(wav, dtype=np.float32).reshape(-1)
start_sil_samples = int(round(pre_silence_sec * sample_rate))
end_sil_samples = int(round(post_silence_sec * sample_rate))
pre_crossfade_samples = int(round(pre_crossfade_sec * sample_rate))
post_crossfade_samples = int(round(post_crossfade_sec * sample_rate))
total_len = wav.shape[0]
if total_len == 0:
pieces: List[np.ndarray] = []
if start_sil_samples > 0:
pieces.append(np.zeros(start_sil_samples, dtype=np.float32))
if end_sil_samples > 0:
pieces.append(np.zeros(end_sil_samples, dtype=np.float32))
return np.concatenate(pieces) if pieces else wav
start_len = min(pre_crossfade_samples, total_len)
remaining_after_start = max(total_len - start_len, 0)
end_len = min(post_crossfade_samples, remaining_after_start)
middle_end_idx = total_len - end_len
start_segment = wav[:start_len]
middle_segment = wav[start_len:middle_end_idx]
end_segment = wav[middle_end_idx:]
def _linear_fade(num_samples: int, start: float, end: float) -> np.ndarray:
if num_samples <= 0:
return np.zeros((0,), dtype=np.float32)
return np.linspace(start, end, num_samples, endpoint=True, dtype=np.float32)
start_crossfade = start_segment * _linear_fade(start_len, 0.0, 1.0)
end_crossfade = end_segment * _linear_fade(end_segment.shape[0], 1.0, 0.0)
pieces: List[np.ndarray] = []
if start_sil_samples > 0:
pieces.append(np.zeros(start_sil_samples, dtype=np.float32))
if start_crossfade.size > 0:
pieces.append(start_crossfade.astype(np.float32, copy=False))
if middle_segment.size > 0:
pieces.append(middle_segment.astype(np.float32, copy=False))
if end_crossfade.size > 0:
pieces.append(end_crossfade.astype(np.float32, copy=False))
if end_sil_samples > 0:
pieces.append(np.zeros(end_sil_samples, dtype=np.float32))
return np.concatenate(pieces)
def _load_audio_to_24k(
audio: Union[str, np.ndarray, torch.Tensor, Dict[str, Any]],
*,
target_sr: int = 22050,
augment_with_silence: bool = False,
) -> np.ndarray:
if isinstance(audio, np.ndarray):
wav_out = audio.astype(np.float32)
elif isinstance(audio, torch.Tensor):
wav_out = audio.detach().cpu().float().numpy()
elif isinstance(audio, str):
if librosa is None:
raise RuntimeError(
"librosa is required to load audio file paths. Please pip install librosa."
)
wav, sr = librosa.load(audio, sr=None, mono=True)
wav_out = _resample_if_needed(wav, int(sr), target_sr)
elif isinstance(audio, dict) and "array" in audio and ("sampling_rate" in audio):
arr = np.asarray(audio["array"], dtype=np.float32)
sr = int(audio["sampling_rate"])
wav_out = _resample_if_needed(arr, sr, target_sr)
else:
raise ValueError(f"Unsupported audio type: {type (audio )}")
wav_out = np.asarray(wav_out, dtype=np.float32)
if augment_with_silence:
wav_out = _apply_silence_with_crossfade(wav_out, sample_rate=target_sr)
return wav_out
@dataclass
class QWEN3VoxCollator:
    """Collate dataset items into padded tensors for training.

    For every example the processor output (text plus optional voice
    prompts) is extended with one ``speech_diffusion_id`` placeholder per
    target latent frame, followed by ``speech_end_id`` and, when the
    tokenizer exposes one, an EOS id. Sequences are right-padded to the
    longest one in the batch; all waveforms (voice prompts first, then the
    target, per example) are stacked into a zero-padded 2-D array.

    NOTE(review): as written, any non-empty batch requires
    ``compute_semantics=True`` and a non-None ``processor.semantic_tokenizer``;
    otherwise ``__call__`` raises ``RuntimeError``. The default
    ``compute_semantics=False`` therefore always raises. Confirm this is
    intended.
    """

    # Callable processor; must also expose ``.tokenizer`` (and optionally
    # ``.acoustic_tokenizer`` / ``.semantic_tokenizer``).
    processor: Any
    # When set, sequences longer than this are cut from the FRONT (see __call__).
    max_length: Optional[int] = None
    # Waveform samples per latent frame; used only for the fallback length estimate.
    speech_compress_ratio: int = 3200
    # Feature width the semantic features are zero-padded / truncated to.
    semantic_vae_dim: int = 128
    # Must be True for semantic features to be computed (see class note).
    compute_semantics: bool = False
    # Enables the sanity asserts at the end of __call__.
    debug_checks: bool = False
    # Keys read from each feature dict.
    text_field: str = "text"
    audio_field: str = "audio"
    voice_prompts_field: str = "voice_prompts"
    # Probability of withholding the voice prompts from the processor
    # (clamped to [0, 1] before use).
    voice_prompt_drop_rate: float = 0.0

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
        """Build one training batch.

        Args:
            features: Example dicts keyed by ``text_field``, ``audio_field``
                and (optionally) ``voice_prompts_field``.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``speech_tensors``,
            ``speech_masks``, ``speech_semantic_tensors``,
            ``acoustic_input_mask``, ``acoustic_loss_mask`` and
            ``speeches_loss_input``.

        Raises:
            ValueError: ``max_length`` would cut into acoustic tokens, or the
                tokenizer has neither a pad nor an EOS token id.
            RuntimeError: Semantic features are unavailable / disabled, or
                the semantic tokenizer returns a non-2-D result.
        """
        # NOTE: batch_size is computed but not used below.
        batch_size = len(features)
        sample_input_ids: List[List[int]] = []
        sample_attention_masks: List[List[int]] = []
        sample_acoustic_input_masks: List[List[bool]] = []
        sample_acoustic_loss_masks: List[List[bool]] = []
        # Flat, batch-wide lists: one entry per speech segment (prompt or target).
        all_speech_waveforms: List[np.ndarray] = []
        all_speech_latent_lengths: List[int] = []
        per_segment_is_target: List[bool] = []
        for ex in features:
            text: str = ex.get(self.text_field, "")
            voice_prompts: Optional[List[Union[str, np.ndarray, torch.Tensor]]] = (
                ex.get(self.voice_prompts_field)
            )
            target_audio: Union[str, np.ndarray, torch.Tensor, Dict[str, Any]] = ex.get(
                self.audio_field
            )
            # Clamp the drop probability into [0, 1].
            _drop_rate = self.voice_prompt_drop_rate
            if _drop_rate < 0.0:
                _drop_rate = 0.0
            elif _drop_rate > 1.0:
                _drop_rate = 1.0
            proc = self.processor(
                text=[text],
                # random.random() is only drawn when prompts exist (short-circuit),
                # so the RNG stream depends on which examples carry prompts.
                voice_samples=(
                    [voice_prompts]
                    if voice_prompts is not None and random.random() >= _drop_rate
                    else None
                ),
                padding=False,
                truncation=False,
                max_length=self.max_length,
                return_tensors="pt",
            )
            ids = proc["input_ids"][0].tolist()
            attn = proc.get("attention_mask", torch.ones_like(proc["input_ids"]))[
                0
            ].tolist()
            # Missing speech_input_mask is treated as "no speech positions".
            speech_input_mask = proc.get("speech_input_mask")
            if speech_input_mask is None:
                speech_input_mask = torch.zeros_like(
                    proc["input_ids"], dtype=torch.bool
                )
            speech_input_mask_list = speech_input_mask[0].tolist()
            wav_target = _load_audio_to_24k(
                target_audio, target_sr=22050, augment_with_silence=True
            )
            # Prefer the frame count reported by the acoustic tokenizer; any
            # failure falls back to the compress-ratio estimate below.
            target_latent_len = None
            try:
                acoustic_tok = getattr(self.processor, "acoustic_tokenizer", None)
                if acoustic_tok is not None and hasattr(acoustic_tok, "encode"):
                    enc_out = acoustic_tok.encode(wav_target)
                    T = None
                    try:
                        if (
                            hasattr(enc_out, "shape")
                            and len(getattr(enc_out, "shape", [])) >= 1
                        ):
                            # presumably axis 0 is time — TODO confirm against the tokenizer
                            T = int(enc_out.shape[0])
                        else:
                            # Unwrap up to two levels of list/tuple nesting.
                            cand = enc_out
                            for _ in range(2):
                                if isinstance(cand, (list, tuple)) and len(cand) > 0:
                                    cand = cand[0]
                            if (
                                hasattr(cand, "shape")
                                and len(getattr(cand, "shape", [])) >= 1
                            ):
                                T = int(cand.shape[0])
                    except Exception:
                        T = None
                    if T is not None and T > 0:
                        target_latent_len = T
            except Exception:
                target_latent_len = None
            if target_latent_len is None:
                # Fallback: ceil(num_samples / speech_compress_ratio), at least 1.
                target_latent_len = max(
                    1,
                    int(math.ceil(len(wav_target) / float(self.speech_compress_ratio))),
                )
            # Append one diffusion placeholder per target frame; these positions
            # are both acoustic inputs and loss positions.
            speech_diff_id = self.processor.tokenizer.speech_diffusion_id
            target_placeholders = [speech_diff_id] * target_latent_len
            ids_extended = ids + target_placeholders
            attn_extended = attn + [1] * target_latent_len
            acoustic_input_mask = speech_input_mask_list + [True] * target_latent_len
            acoustic_loss_mask = [False] * len(speech_input_mask_list) + [
                True
            ] * target_latent_len
            # Close the speech span; the end marker is attended but not acoustic.
            speech_end_id = self.processor.tokenizer.speech_end_id
            ids_extended.append(speech_end_id)
            attn_extended.append(1)
            acoustic_input_mask.append(False)
            acoustic_loss_mask.append(False)
            # Prefer the tokenizer's custom eos_id, then the standard eos_token_id.
            eos_token_id = getattr(self.processor.tokenizer, "eos_id", None)
            if eos_token_id is None:
                eos_token_id = getattr(self.processor.tokenizer, "eos_token_id", None)
            if eos_token_id is not None and eos_token_id >= 0:
                ids_extended.append(eos_token_id)
                attn_extended.append(1)
                acoustic_input_mask.append(False)
                acoustic_loss_mask.append(False)
            if self.max_length is not None and len(ids_extended) > self.max_length:
                # Truncate from the front, but only through the leading run of
                # non-acoustic tokens; cutting an acoustic position is an error.
                cut = len(ids_extended) - int(self.max_length)
                leading_non_acoustic = 0
                for v in acoustic_input_mask:
                    if v:
                        break
                    leading_non_acoustic += 1
                if cut > leading_non_acoustic:
                    raise ValueError(
                        f"--max_length={self .max_length } would truncate into acoustic tokens. Needed cut={cut }, but only {leading_non_acoustic } leading non-acoustic tokens available. Increase max_length or shorten text/voice-prompt preamble."
                    )
                ids_extended = ids_extended[cut:]
                attn_extended = attn_extended[cut:]
                acoustic_input_mask = acoustic_input_mask[cut:]
                acoustic_loss_mask = acoustic_loss_mask[cut:]
            sample_input_ids.append(ids_extended)
            sample_attention_masks.append(attn_extended)
            sample_acoustic_input_masks.append(acoustic_input_mask)
            sample_acoustic_loss_masks.append(acoustic_loss_mask)
            # Voice-prompt segments come from the processor; their latent length
            # is the number of set entries in the matching speech mask row.
            voice_speeches = []
            voice_latent_lengths = []
            if proc.get("speech_tensors") is not None:
                voice_np = proc["speech_tensors"].cpu().numpy()
                voice_masks = proc["speech_masks"].cpu().numpy().astype(bool)
                for seg_idx in range(voice_np.shape[0]):
                    voice_speeches.append(voice_np[seg_idx])
                    voice_latent_lengths.append(int(voice_masks[seg_idx].sum()))
            # Segment order per example: prompts first, then the target.
            all_speech_waveforms.extend(voice_speeches)
            all_speech_latent_lengths.extend(voice_latent_lengths)
            per_segment_is_target.extend([False] * len(voice_speeches))
            all_speech_waveforms.append(wav_target)
            all_speech_latent_lengths.append(target_latent_len)
            per_segment_is_target.append(True)
        # NOTE: max() raises ValueError when ``features`` is empty.
        max_seq_len = max((len(x) for x in sample_input_ids))
        padded_input_ids = []
        padded_attention_masks = []
        padded_acoustic_input_masks = []
        padded_acoustic_loss_masks = []
        # Pad id: pad_token_id, else eos_token_id, else error.
        tok = self.processor.tokenizer
        pad_token_id = getattr(tok, "pad_token_id", None)
        if pad_token_id is None or pad_token_id < 0:
            pad_token_id = getattr(tok, "eos_token_id", None)
        if pad_token_id is None or pad_token_id < 0:
            raise ValueError(
                "Tokenizer has no pad_token_id or eos_token_id; please set one or pass a valid pad id."
            )
        # Right-pad every per-sample list to the batch maximum.
        for ids, attn, ain_mask, aloss_mask in zip(
            sample_input_ids,
            sample_attention_masks,
            sample_acoustic_input_masks,
            sample_acoustic_loss_masks,
        ):
            pad_len = max_seq_len - len(ids)
            padded_input_ids.append(ids + [pad_token_id] * pad_len)
            padded_attention_masks.append(attn + [0] * pad_len)
            padded_acoustic_input_masks.append(ain_mask + [False] * pad_len)
            padded_acoustic_loss_masks.append(aloss_mask + [False] * pad_len)
        input_ids_tensor = torch.tensor(padded_input_ids, dtype=torch.long)
        attention_mask_tensor = torch.tensor(padded_attention_masks, dtype=torch.long)
        acoustic_input_mask_tensor = torch.tensor(
            padded_acoustic_input_masks, dtype=torch.bool
        )
        acoustic_loss_mask_tensor = torch.tensor(
            padded_acoustic_loss_masks, dtype=torch.bool
        )
        if all_speech_waveforms:
            # Zero-pad all waveforms to the longest one: [segments, samples].
            max_wave_len = max((w.shape[0] for w in all_speech_waveforms))
            padded_speeches = np.zeros(
                (len(all_speech_waveforms), max_wave_len), dtype=np.float32
            )
            for i, w in enumerate(all_speech_waveforms):
                L = w.shape[0]
                padded_speeches[i, :L] = w
            max_latent_len = (
                max(all_speech_latent_lengths) if all_speech_latent_lengths else 1
            )
            # Per-segment validity mask over latent frames: [segments, max_latent_len].
            speech_masks_np = np.zeros(
                (len(all_speech_waveforms), max_latent_len), dtype=np.bool_
            )
            for i, L_lat in enumerate(all_speech_latent_lengths):
                speech_masks_np[i, :L_lat] = True
            speech_tensors_tensor = torch.tensor(padded_speeches, dtype=torch.float32)
            speech_masks_tensor = torch.tensor(speech_masks_np, dtype=torch.bool)
            # Same mask, but only target segments keep their frames.
            speeches_loss_input_np = np.zeros_like(speech_masks_np, dtype=np.bool_)
            for i, is_target in enumerate(per_segment_is_target):
                if is_target:
                    speeches_loss_input_np[i] = speech_masks_np[i]
            speeches_loss_input_tensor = torch.tensor(
                speeches_loss_input_np, dtype=torch.bool
            )
            if (
                self.compute_semantics
                and hasattr(self.processor, "semantic_tokenizer")
                and (self.processor.semantic_tokenizer is not None)
            ):
                sem_feats: List[np.ndarray] = []
                for w in all_speech_waveforms:
                    # An encode failure degrades to an empty [0, D] feature that
                    # is zero-padded below rather than failing the batch.
                    try:
                        sem = self.processor.semantic_tokenizer.encode(w)
                        sem = np.asarray(sem, dtype=np.float32)
                    except Exception:
                        sem = np.zeros((0, self.semantic_vae_dim), dtype=np.float32)
                    if sem.ndim != 2:
                        raise RuntimeError(
                            f"Semantic tokenizer returned unexpected shape {sem .shape }. Expect [T, D]."
                        )
                    L = sem.shape[0]
                    D = sem.shape[1]
                    # Force the feature width to semantic_vae_dim (pad or truncate).
                    if D != self.semantic_vae_dim:
                        if D < self.semantic_vae_dim:
                            pad_d = np.zeros(
                                (L, self.semantic_vae_dim - D), dtype=np.float32
                            )
                            sem = np.concatenate([sem, pad_d], axis=1)
                        else:
                            sem = sem[:, : self.semantic_vae_dim]
                    # Force the time length to max_latent_len (pad or truncate).
                    if L < max_latent_len:
                        pad = np.zeros(
                            (max_latent_len - L, self.semantic_vae_dim),
                            dtype=np.float32,
                        )
                        sem = np.concatenate([sem, pad], axis=0)
                    elif L > max_latent_len:
                        sem = sem[:max_latent_len]
                    sem_feats.append(sem.astype(np.float32))
                speech_semantic_tensors = torch.tensor(
                    np.stack(sem_feats, axis=0), dtype=torch.float32
                )
            else:
                # Reached whenever compute_semantics is False or no semantic
                # tokenizer is attached.
                raise RuntimeError(
                    "Semantic features are required but could not be computed. Ensure processor.semantic_tokenizer is available or precompute and provide features."
                )
        else:
            # NOTE(review): looks unreachable — every example appends its target
            # waveform, and an empty batch fails earlier at max(); confirm.
            speech_tensors_tensor = None
            speech_masks_tensor = None
            speeches_loss_input_tensor = None
            speech_semantic_tensors = None
        if self.debug_checks:
            assert (input_ids_tensor >= 0).all(), "input_ids contains negative indices"
            if speech_tensors_tensor is not None:
                assert (
                    speech_tensors_tensor.dim() == 2
                ), "Expected speech_tensors 2D [segments, samples]"
        return {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_tensor,
            "speech_tensors": speech_tensors_tensor,
            "speech_masks": speech_masks_tensor,
            "speech_semantic_tensors": speech_semantic_tensors,
            "acoustic_input_mask": acoustic_input_mask_tensor,
            "acoustic_loss_mask": acoustic_loss_mask_tensor,
            "speeches_loss_input": speeches_loss_input_tensor,
        }
'QWEN3Vox model configuration classes.'
from typing import Dict, List, Optional, Tuple
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
logger = logging.get_logger(__name__)
class QWEN3VoxAcousticTokenizerConfig(PretrainedConfig):
    """Configuration of the acoustic tokenizer (encoder and decoder settings).

    All constructor arguments are stored as same-named attributes. The
    ratio lists are copied on the way in so that instances never share (or
    mutate) the default list object, and so that ``decoder_ratios`` is not
    an alias of ``encoder_ratios``.

    Args:
        encoder_ratios: Per-stage encoder ratios; ``None`` is stored as-is.
        decoder_ratios: Per-stage decoder ratios; falls back to a copy of
            ``encoder_ratios`` when ``None``.
        decoder_depths: Stored verbatim (``None`` allowed).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    The remaining arguments are plain hyper-parameters stored unchanged.
    """

    model_type = 'vibevoice_acoustic_tokenizer'

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0.5,
        std_dist_type: str = "gaussian",
        mixer_layer: str = "depthwise_conv",
        conv_norm: str = "none",
        pad_mode: str = "constant",
        disable_last_norm: bool = True,
        layernorm: str = "RMSNorm",
        layernorm_eps: float = 1e-05,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-06,
        weight_init_value: float = 0.01,
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = [8, 5, 5, 4, 2, 2],
        encoder_depths: str = "3-3-3-3-3-3-8",
        decoder_n_filters: int = 32,
        decoder_ratios: Optional[List[int]] = None,
        decoder_depths: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer
        self.encoder_n_filters = encoder_n_filters
        # Copy: the signature default is a single list object shared by every
        # call, so storing it directly would let one config's edits leak into
        # all others.
        self.encoder_ratios = (
            list(encoder_ratios) if encoder_ratios is not None else None
        )
        self.encoder_depths = encoder_depths
        # Independent copy so encoder and decoder ratios can diverge later.
        decoder_source = decoder_ratios if decoder_ratios is not None else encoder_ratios
        self.decoder_ratios = (
            list(decoder_source) if decoder_source is not None else None
        )
        self.decoder_n_filters = decoder_n_filters
        self.decoder_depths = decoder_depths
class QWEN3VoxSemanticTokenizerConfig(PretrainedConfig):
    """Configuration of the semantic tokenizer (encoder-only settings).

    All constructor arguments are stored as same-named attributes.
    ``encoder_ratios`` is copied on the way in so that instances never share
    (or mutate) the default list object.

    Args:
        encoder_ratios: Per-stage encoder ratios; ``None`` is stored as-is.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    The remaining arguments are plain hyper-parameters stored unchanged.
    """

    model_type = 'vibevoice_semantic_tokenizer'

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0,
        std_dist_type: str = "none",
        mixer_layer: str = "depthwise_conv",
        conv_norm: str = "none",
        pad_mode: str = "constant",
        disable_last_norm: bool = True,
        layernorm: str = "RMSNorm",
        layernorm_eps: float = 1e-05,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-06,
        weight_init_value: float = 0.01,
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = [8, 5, 5, 4, 2, 2],
        encoder_depths: str = "3-3-3-3-3-3-8",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer
        self.encoder_n_filters = encoder_n_filters
        # Copy: the signature default is a single list object shared by every
        # call, so storing it directly would let one config's edits leak into
        # all others.
        self.encoder_ratios = (
            list(encoder_ratios) if encoder_ratios is not None else None
        )
        self.encoder_depths = encoder_depths
class QWEN3VoxDiffusionHeadConfig(PretrainedConfig):
    """Configuration of the diffusion head.

    Every named constructor argument is stored as a same-named attribute
    before the remaining ``kwargs`` are handed to ``PretrainedConfig``.
    """

    model_type = 'vibevoice_diffusion_head'

    def __init__(
        self,
        hidden_size=768,
        head_layers=4,
        head_ffn_ratio=3.0,
        rms_norm_eps=1e-05,
        latent_size=64,
        speech_vae_dim=None,
        prediction_type="v_prediction",
        diffusion_type="ddpm",
        ddpm_num_steps=1000,
        ddpm_num_inference_steps=30,
        ddpm_beta_schedule="cosine",
        ddpm_batch_mul=4,
        **kwargs,
    ):
        # Insertion order mirrors the order in which attributes are assigned.
        head_settings = {
            "hidden_size": hidden_size,
            "head_layers": head_layers,
            "head_ffn_ratio": head_ffn_ratio,
            "rms_norm_eps": rms_norm_eps,
            "latent_size": latent_size,
            "speech_vae_dim": speech_vae_dim,
            "prediction_type": prediction_type,
            "diffusion_type": diffusion_type,
            "ddpm_num_steps": ddpm_num_steps,
            "ddpm_num_inference_steps": ddpm_num_inference_steps,
            "ddpm_beta_schedule": ddpm_beta_schedule,
            "ddpm_batch_mul": ddpm_batch_mul,
        }
        for attr_name, attr_value in head_settings.items():
            setattr(self, attr_name, attr_value)
        # Base-class initialisation runs last, after the head fields are set.
        super().__init__(**kwargs)
class QWEN3VoxConfig(PretrainedConfig):
    """Composite configuration: acoustic/semantic tokenizers, decoder, diffusion head.

    Each sub-configuration may be given as ``None`` (use the sub-config
    class defaults), a ``dict`` (instantiated, with ``model_type`` forced to
    the matching value) or an already-built config instance. Dict arguments
    are copied before ``model_type`` is injected, so the caller's dict is
    left untouched.

    NOTE(review): a sub-config of any other type is silently ignored and the
    corresponding attribute is left unset, exactly as before; consider
    raising ``TypeError`` once callers are audited.

    Raises:
        ValueError: ``decoder_config`` is a dict whose ``model_type`` is not
            ``"qwen2"``.
    """

    model_type = 'vibevoice'
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig,
        "semantic_tokenizer_config": QWEN3VoxSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
        "diffusion_head_config": QWEN3VoxDiffusionHeadConfig,
    }
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    @staticmethod
    def _resolve_sub_config(value, factory, expected_cls, model_type):
        """Build a sub-config from None / dict / instance; None if unsupported."""
        if value is None:
            return factory()
        if isinstance(value, dict):
            # Merge into a fresh dict so the caller's mapping is not mutated.
            return factory(**{**value, "model_type": model_type})
        if isinstance(value, expected_cls):
            return value
        return None

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        diffusion_head_config=None,
        **kwargs,
    ):
        kwargs["_attn_implementation_autoset"] = False

        resolved = self._resolve_sub_config(
            acoustic_tokenizer_config,
            self.sub_configs["acoustic_tokenizer_config"],
            QWEN3VoxAcousticTokenizerConfig,
            'vibevoice_acoustic_tokenizer',
        )
        if resolved is not None:
            self.acoustic_tokenizer_config = resolved

        resolved = self._resolve_sub_config(
            semantic_tokenizer_config,
            self.sub_configs["semantic_tokenizer_config"],
            QWEN3VoxSemanticTokenizerConfig,
            'vibevoice_semantic_tokenizer',
        )
        if resolved is not None:
            self.semantic_tokenizer_config = resolved

        # The decoder keeps its own model_type and only qwen2 is accepted.
        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            if decoder_config.get("model_type", "") == "qwen2":
                self.decoder_config = Qwen2Config(**decoder_config)
            else:
                raise ValueError(
                    f"Unsupported decoder model type: {decoder_config.get('model_type', '')}"
                )
        elif isinstance(decoder_config, (Qwen2Config,)):
            self.decoder_config = decoder_config

        resolved = self._resolve_sub_config(
            diffusion_head_config,
            self.sub_configs["diffusion_head_config"],
            QWEN3VoxDiffusionHeadConfig,
            'vibevoice_diffusion_head',
        )
        if resolved is not None:
            self.diffusion_head_config = resolved

        # Mirror the tokenizer latent sizes at the top level.
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, "vae_dim", 128)
        super().__init__(**kwargs)
class QWEN3VoxASRConfig(PretrainedConfig):
    """Composite configuration for the ASR variant: tokenizers plus decoder.

    Each sub-configuration may be given as ``None`` (use the sub-config
    class defaults), a ``dict`` (instantiated, with ``model_type`` forced to
    the matching value for the two tokenizers) or an already-built config
    instance. Dict arguments are copied before ``model_type`` is injected,
    so the caller's dict is left untouched.

    Decoder hyper-parameters (``vocab_size``, ``hidden_size``, ...) are
    exposed as read-only properties that delegate to ``decoder_config``.

    NOTE(review): a sub-config of any other type is silently ignored and the
    corresponding attribute is left unset, exactly as before.

    Raises:
        ValueError: ``decoder_config`` is a dict whose ``model_type`` is not
            ``"qwen2"``.
    """

    model_type = 'vibevoice'
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig,
        "semantic_tokenizer_config": QWEN3VoxSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
    }
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    @staticmethod
    def _resolve_sub_config(value, factory, expected_cls, model_type):
        """Build a sub-config from None / dict / instance; None if unsupported."""
        if value is None:
            return factory()
        if isinstance(value, dict):
            # Merge into a fresh dict so the caller's mapping is not mutated.
            return factory(**{**value, "model_type": model_type})
        if isinstance(value, expected_cls):
            return value
        return None

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        **kwargs,
    ):
        kwargs["_attn_implementation_autoset"] = False

        resolved = self._resolve_sub_config(
            acoustic_tokenizer_config,
            self.sub_configs["acoustic_tokenizer_config"],
            QWEN3VoxAcousticTokenizerConfig,
            'vibevoice_acoustic_tokenizer',
        )
        if resolved is not None:
            self.acoustic_tokenizer_config = resolved

        resolved = self._resolve_sub_config(
            semantic_tokenizer_config,
            self.sub_configs["semantic_tokenizer_config"],
            QWEN3VoxSemanticTokenizerConfig,
            'vibevoice_semantic_tokenizer',
        )
        if resolved is not None:
            self.semantic_tokenizer_config = resolved

        # The decoder keeps its own model_type and only qwen2 is accepted.
        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            if decoder_config.get("model_type", "") == "qwen2":
                self.decoder_config = Qwen2Config(**decoder_config)
            else:
                raise ValueError(
                    f"Unsupported decoder model type: {decoder_config.get('model_type', '')}"
                )
        elif isinstance(decoder_config, Qwen2Config):
            self.decoder_config = decoder_config

        # Mirror the tokenizer latent sizes at the top level.
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, "vae_dim", 128)
        super().__init__(**kwargs)

    def get_text_config(self, decoder: bool = False):
        """Return the decoder config; the ``decoder`` flag is accepted but ignored."""
        return self.decoder_config

    @property
    def vocab_size(self):
        return self.decoder_config.vocab_size

    @property
    def num_attention_heads(self):
        return self.decoder_config.num_attention_heads

    @property
    def num_key_value_heads(self):
        return self.decoder_config.num_key_value_heads

    @property
    def hidden_size(self):
        return self.decoder_config.hidden_size

    @property
    def num_hidden_layers(self):
        return self.decoder_config.num_hidden_layers

    @property
    def head_dim(self):
        # Falls back to hidden_size // num_attention_heads when the decoder
        # config does not define head_dim.
        return getattr(
            self.decoder_config,
            "head_dim",
            self.hidden_size // self.num_attention_heads,
        )
# Names exported by the configuration section: the three component configs
# and the two composite model configs defined above.
__all__ = [
'QWEN3VoxAcousticTokenizerConfig',
'QWEN3VoxSemanticTokenizerConfig',
'QWEN3VoxDiffusionHeadConfig',
'QWEN3VoxConfig',
'QWEN3VoxASRConfig',
]
import torch
import asyncio
from queue import Queue
from typing import TYPE_CHECKING, Optional
from transformers.generation import BaseStreamer
class AudioStreamer(BaseStreamer):
    """Streamer that routes generated audio chunks into one queue per sample.

    Producers call :meth:`put` with a batch of chunks and the sample index
    of each chunk, and :meth:`end` to close one or all streams. Consumers
    iterate the streamer (batch-wise dicts) or a single sample's stream via
    :meth:`get_stream`.

    Args:
        batch_size: Number of samples, i.e. number of queues.
        stop_signal: Sentinel object enqueued to mark the end of a stream.
        timeout: Timeout in seconds used for queue operations; ``None``
            blocks indefinitely.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        self.batch_size = batch_size
        self.stop_signal = stop_signal
        self.timeout = timeout
        self.audio_queues = [Queue() for _ in range(batch_size)]
        self.finished_flags = [False for _ in range(batch_size)]
        self.sample_indices_map = {}

    def _accepts(self, idx) -> bool:
        """True when ``idx`` is a valid sample index whose stream is still open.

        The lower bound matters: a negative index would otherwise wrap
        around and address another sample's queue.
        """
        return 0 <= idx < self.batch_size and not self.finished_flags[idx]

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Enqueue ``audio_chunks[i]`` (detached, on CPU) for ``sample_indices[i]``.

        Chunks for out-of-range or already finished samples are dropped.
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if self._accepts(idx):
                audio_chunk = audio_chunks[i].detach().cpu()
                self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Close the given streams (all open streams when ``None``).

        Each closed stream receives the stop signal exactly once.
        """
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [
                s.item() if torch.is_tensor(s) else s for s in sample_indices
            ]
        for idx in indices_to_end:
            if self._accepts(idx):
                self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                self.finished_flags[idx] = True

    def __iter__(self):
        return AudioBatchIterator(self)

    def get_stream(self, sample_idx: int):
        """Return an iterator over the chunks of a single sample.

        Raises:
            ValueError: ``sample_idx`` is negative or not below ``batch_size``.
        """
        if sample_idx < 0:
            raise ValueError(f"Sample index {sample_idx} must be non-negative")
        if sample_idx >= self.batch_size:
            raise ValueError(
                f"Sample index {sample_idx} exceeds batch size {self.batch_size}"
            )
        return AudioSampleIterator(self, sample_idx)
class AudioSampleIterator:
    """Blocking iterator over the audio chunks of one sample of a streamer.

    Iteration ends when the streamer's stop signal is dequeued. If no item
    arrives within ``streamer.timeout`` seconds, the ``queue.Empty`` raised
    by ``Queue.get`` propagates to the caller.
    """

    def __init__(self, streamer: "AudioStreamer", sample_idx: int):
        self.streamer = streamer
        self.sample_idx = sample_idx

    def __iter__(self):
        return self

    def __next__(self):
        value = self.streamer.audio_queues[self.sample_idx].get(
            timeout=self.streamer.timeout
        )
        # Identity, not equality: chunks are tensors, and `tensor == signal`
        # yields an element-wise tensor whose truth value is ambiguous for
        # any non-None stop signal. The streamer enqueues the very same
        # sentinel object, so `is` is exact.
        if value is self.streamer.stop_signal:
            raise StopIteration()
        return value
class AudioBatchIterator:
    """Iterator yielding ``{sample_idx: chunk}`` dicts across all active samples.

    Every step polls each still-active sample's queue without blocking. A
    dequeued stop signal retires that sample. When nothing is available but
    samples remain active, the iterator sleeps 10 ms and polls again.
    Iteration stops once every sample has delivered its stop signal.
    """

    def __init__(self, streamer: "AudioStreamer"):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __iter__(self):
        return self

    def __next__(self):
        import time
        from queue import Empty

        # Poll in a loop instead of recursing: with a 10 ms sleep per level,
        # recursion would hit the interpreter's recursion limit after roughly
        # ten seconds without data.
        while self.active_samples:
            batch_chunks = {}
            finished = set()
            for idx in self.active_samples:
                try:
                    value = self.streamer.audio_queues[idx].get(block=False)
                except Empty:
                    # Nothing queued for this sample right now.
                    continue
                # Identity check: comparing a tensor chunk with `==` produces
                # an element-wise tensor; the resulting error used to be
                # swallowed by a bare except, silently dropping the chunk.
                if value is self.streamer.stop_signal:
                    finished.add(idx)
                else:
                    batch_chunks[idx] = value
            self.active_samples -= finished
            if batch_chunks:
                return batch_chunks
            if self.active_samples:
                time.sleep(0.01)
        raise StopIteration()
class AsyncAudioStreamer(AudioStreamer):
    """Asyncio flavour of :class:`AudioStreamer`.

    Must be constructed while an event loop is running, because the loop is
    captured with ``asyncio.get_running_loop()``. ``put`` and ``end`` hand
    items to that loop through ``call_soon_threadsafe``, so they can be
    called from another thread.

    Args:
        batch_size: Number of samples, i.e. number of queues.
        stop_signal: Sentinel object enqueued to mark the end of a stream.
        timeout: Stored on the instance; not used by ``put`` / ``end`` here.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        super().__init__(batch_size, stop_signal, timeout)
        # Replace the thread queues created by the base class with asyncio queues.
        self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
        self.loop = asyncio.get_running_loop()

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Schedule ``audio_chunks[i]`` (detached, on CPU) for ``sample_indices[i]``.

        Chunks for out-of-range or already finished samples are dropped.
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            # The lower bound stops negative indices wrapping to another queue.
            if 0 <= idx < self.batch_size and (not self.finished_flags[idx]):
                audio_chunk = audio_chunks[i].detach().cpu()
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, audio_chunk
                )

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Close the given streams (all open streams when ``None``)."""
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [
                s.item() if torch.is_tensor(s) else s for s in sample_indices
            ]
        for idx in indices_to_end:
            if 0 <= idx < self.batch_size and (not self.finished_flags[idx]):
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, self.stop_signal
                )
                self.finished_flags[idx] = True

    async def get_stream(self, sample_idx: int):
        """Asynchronously yield the chunks of one sample until its stop signal.

        This is an async generator, so the index check runs on the first
        iteration step rather than at call time.

        Raises:
            ValueError: ``sample_idx`` is negative or not below ``batch_size``.
        """
        if sample_idx < 0:
            raise ValueError(f"Sample index {sample_idx} must be non-negative")
        if sample_idx >= self.batch_size:
            raise ValueError(
                f"Sample index {sample_idx} exceeds batch size {self.batch_size}"
            )
        while True:
            value = await self.audio_queues[sample_idx].get()
            # Identity check: `tensor == signal` is element-wise and its truth
            # value is ambiguous for any non-None stop signal.
            if value is self.stop_signal:
                break
            yield value

    def __aiter__(self):
        return AsyncAudioBatchIterator(self)
class AsyncAudioBatchIterator:
    """Async iterator yielding ``{sample_idx: chunk}`` dicts across active samples.

    Each step starts one ``get`` task per active sample, waits until at
    least one finishes (or ``streamer.timeout`` elapses), cancels the rest
    and returns the chunks that arrived. A dequeued stop signal retires that
    sample. Iteration stops once every sample has delivered its stop signal.
    """

    def __init__(self, streamer: "AsyncAudioStreamer"):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Loop instead of recursing: rounds that only time out or only see
        # stop signals would otherwise nest one coroutine frame per round.
        while self.active_samples:
            tasks = {
                idx: asyncio.create_task(self._get_chunk(idx))
                for idx in self.active_samples
            }
            done, pending = await asyncio.wait(
                tasks.values(),
                return_when=asyncio.FIRST_COMPLETED,
                timeout=self.streamer.timeout,
            )
            # Unfinished getters are cancelled and recreated next round.
            for task in pending:
                task.cancel()

            batch_chunks = {}
            finished = set()
            for idx, task in tasks.items():
                if task not in done:
                    continue
                try:
                    value = await task
                except asyncio.CancelledError:
                    continue
                # Identity check: `tensor == signal` is element-wise and its
                # truth value is ambiguous for any non-None stop signal.
                if value is self.streamer.stop_signal:
                    finished.add(idx)
                else:
                    batch_chunks[idx] = value

            self.active_samples -= finished
            if batch_chunks:
                return batch_chunks
        raise StopAsyncIteration()

    async def _get_chunk(self, idx):
        """Await the next item of sample ``idx``'s queue."""
        return await self.streamer.audio_queues[idx].get()
'Tokenization classes for QWEN3Vox.'
from typing import List, Optional, Union
from transformers.utils import logging
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
logger = logging.get_logger(__name__)
class QWEN3VoxTextTokenizer(Qwen2Tokenizer):
    """Qwen2 (slow) tokenizer extended with the speech marker tokens.

    On construction the three ``<|vision_*|>`` tokens are registered as
    additional special tokens and their ids are cached; they are exposed as
    ``speech_start_id``, ``speech_end_id`` and ``speech_diffusion_id``.

    NOTE(review): ``pad_id`` is the constant ``-100`` here; verify that this
    is what callers of the slow tokenizer expect.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        add_special_tokens=True,
        **kwargs,
    ):
        base_options = dict(
            vocab_file=vocab_file,
            merges_file=merges_file,
            errors=errors,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            add_special_tokens=add_special_tokens,
        )
        super().__init__(**base_options, **kwargs)
        self._add_q3_sp_tok()

    def _add_q3_sp_tok(self):
        """Register the speech marker tokens and cache their ids.

        Returns:
            The value returned by ``add_special_tokens``.
        """
        start_tok, end_tok, diffusion_tok = (
            "<|vision_start|>",
            "<|vision_end|>",
            "<|vision_pad|>",
        )
        added_count = self.add_special_tokens(
            {"additional_special_tokens": [start_tok, end_tok, diffusion_tok]}
        )
        self._speech_start_id = self.convert_tokens_to_ids(start_tok)
        self._speech_end_id = self.convert_tokens_to_ids(end_tok)
        self._speech_diffusion_id = self.convert_tokens_to_ids(diffusion_tok)
        self._eos_id = self.convert_tokens_to_ids("<|endoftext|>")
        return added_count

    @property
    def eos_id(self) -> int:
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        return self._speech_end_id

    @property
    def speech_diffusion_id(self) -> int:
        return self._speech_diffusion_id

    @property
    def pad_id(self) -> int:
        return -100
class QWEN3VoxTextTokenizerFast(Qwen2TokenizerFast):
    """Fast Qwen2 tokenizer extended with the speech marker tokens.

    The three ``<|vision_*|>`` tokens are registered as additional special
    tokens; their ids, the EOS id and the ``<|image_pad|>`` id are exposed
    through read-only properties.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        **kwargs,
    ):
        """Forward every argument to ``Qwen2TokenizerFast`` and register the speech tokens."""
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )
        self._add_q3_sp_tok()

    def _add_q3_sp_tok(self):
        """Register the speech marker tokens and cache the ids used by the properties.

        Returns:
            The value returned by ``add_special_tokens`` (number of tokens added).
        """
        start_tok, end_tok, diffusion_tok = (
            "<|vision_start|>",
            "<|vision_end|>",
            "<|vision_pad|>",
        )
        added_count = self.add_special_tokens(
            {"additional_special_tokens": [start_tok, end_tok, diffusion_tok]}
        )
        self._speech_start_id = self.convert_tokens_to_ids(start_tok)
        self._speech_end_id = self.convert_tokens_to_ids(end_tok)
        self._speech_diffusion_id = self.convert_tokens_to_ids(diffusion_tok)
        # EOS comes from the tokenizer's configured eos token; the pad id is
        # looked up from the "<|image_pad|>" token.
        self._eos_id = self.eos_token_id
        self._pad_id = self.convert_tokens_to_ids("<|image_pad|>")
        return added_count

    @property
    def eos_id(self) -> int:
        """The tokenizer's ``eos_token_id`` captured at construction time."""
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        """Id of ``<|vision_start|>``, used as the speech-start marker."""
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        """Id of ``<|vision_end|>``, used as the speech-end marker."""
        return self._speech_end_id

    @property
    def speech_diffusion_id(self) -> int:
        """Id of ``<|vision_pad|>``, used as the speech-diffusion marker."""
        return self._speech_diffusion_id

    @property
    def pad_id(self) -> int:
        """Id of the ``<|image_pad|>`` token captured at construction time."""
        return self._pad_id
# Alias: the ASR tokenizer name refers to the very same fast tokenizer class.
QWEN3VoxASRTextTokenizerFast = QWEN3VoxTextTokenizerFast
# Names exported by a star-import; note that the alias above is not listed here.
__all__ = [
    'QWEN3VoxTextTokenizer',
    'QWEN3VoxTextTokenizerFast',
]
"Utilities for loading fine-tuned LoRA adapters and connector weights."
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
from transformers.utils import logging
logger = logging.get_logger(__name__)
@dataclass
class _LoadReport:
    """Record of which adapter assets a load pass actually applied.

    Every flag starts as ``False`` and is switched on by the loader that
    handled the corresponding component.
    """

    # Set once a language-model LoRA adapter has been attached.
    language_model: bool = False
    # Set when the diffusion head was wrapped with a LoRA adapter.
    diffusion_head_lora: bool = False
    # Set when full diffusion-head weights were loaded instead of a LoRA.
    diffusion_head_full: bool = False
    # Set when acoustic connector weights were loaded.
    acoustic_connector: bool = False
    # Set when semantic connector weights were loaded.
    semantic_connector: bool = False
    # Directory the assets were resolved from (None until filled in).
    adapter_root: Optional[Path] = None
class _DiffusionHeadForwardShim(nn.Module):
    """Wrap a diffusion head so it can be called positionally or by keyword.

    The wrapped module is always invoked as
    ``base(noisy_images, timesteps, condition)``; any other keyword
    arguments or extra positional arguments are dropped.
    """

    def __init__(self, base: nn.Module):
        """Store ``base`` as the submodule that receives the forwarded call."""
        super().__init__()
        self.base = base

    def forward(self, *args, **kwargs):
        """Resolve the three head inputs and forward them to ``self.base``.

        Positional arguments fill ``noisy_images``, ``timesteps`` and
        ``condition`` in that order; anything not supplied positionally is
        looked up in ``kwargs`` (``None`` when absent). Mixed calls such as
        ``shim(x, t, condition=c)`` therefore keep their positional values
        instead of silently replacing them with ``None``.
        """
        names = ("noisy_images", "timesteps", "condition")
        # zip() stops at the shorter input, so surplus positionals are ignored.
        supplied = dict(zip(names, args))
        resolved = [
            supplied[name] if name in supplied else kwargs.get(name)
            for name in names
        ]
        return self.base(*resolved)
def _resolve_adapter_root(checkpoint_path: Path) -> Path:
    """Return the directory that holds the adapter files for a checkpoint.

    A file path is replaced by its parent directory. If that directory
    contains an entry named ``lora``, the ``lora`` path is returned;
    otherwise the directory itself is returned (whether or not it exists).
    """
    base_dir = checkpoint_path.parent if checkpoint_path.is_file() else checkpoint_path
    lora_dir = base_dir / "lora"
    return lora_dir if lora_dir.exists() else base_dir
def _load_connector(
    module: Optional[nn.Module], path: Path, device: torch.device
) -> bool:
    """Load connector weights from ``path`` into ``module`` when both exist.

    Args:
        module: Target module, or ``None`` when the model has no such connector.
        path: Location of the saved state dict.
        device: Device used for ``map_location`` and for the final ``module.to``.

    Returns:
        ``True`` when weights were loaded, ``False`` when there was nothing to do.
    """
    if module is None:
        return False
    if not path.exists():
        return False

    # NOTE(review): torch.load unpickles the file; only point this at trusted
    # checkpoints.
    weights = torch.load(path, map_location=device)
    # strict=False tolerates key mismatches; they are reported below instead.
    load_result = module.load_state_dict(weights, strict=False)
    missing, unexpected = load_result
    if missing:
        logger.warning(f"Connector load missing keys: {missing}")
    if unexpected:
        logger.warning(f"Connector load unexpected keys: {unexpected}")
    module.to(device)
    return True
def _load_diffusion_head(
    model, adapter_root: Path, device: torch.device, report: _LoadReport
) -> None:
    """Attach diffusion-head weights found under ``adapter_root`` to ``model``.

    Two layouts are supported, tried in this order:

    1. A LoRA adapter in ``<adapter_root>/diffusion_head`` (``adapter_config.json``
       plus ``adapter_model.bin`` or ``adapter_model.safetensors``). The current
       ``model.model.prediction_head`` is wrapped in a forward shim, loaded
       through PEFT and replaced by the resulting PEFT model.
    2. Full weights in ``diffusion_head_full.bin``, looked up first inside the
       ``diffusion_head`` directory and then directly in ``adapter_root``;
       loaded non-strictly into the existing prediction head.

    When neither layout is present the function does nothing.

    Args:
        model: Object exposing ``model.prediction_head``.
        adapter_root: Directory containing the adapter assets.
        device: Device the loaded head is moved to.
        report: Updated in place (``diffusion_head_lora`` / ``diffusion_head_full``).

    Raises:
        RuntimeError: If a LoRA adapter is present but ``peft`` is not installed.
    """
    diff_dir = adapter_root / "diffusion_head"
    adapter_config = diff_dir / "adapter_config.json"
    adapter_model = diff_dir / "adapter_model.bin"
    adapter_model_safetensors = diff_dir / "adapter_model.safetensors"

    if adapter_config.exists() and (
        adapter_model.exists() or adapter_model_safetensors.exists()
    ):
        # Import peft only when a LoRA adapter really has to be loaded, so
        # checkpoints carrying only full weights (or no diffusion head at all)
        # do not require peft to be installed.
        try:
            from peft import PeftModel
        except ImportError as exc:
            raise RuntimeError(
                "peft is required to load diffusion head adapters but is not installed"
            ) from exc
        logger.info(f"Loading diffusion head LoRA from {diff_dir}")
        shim = _DiffusionHeadForwardShim(model.model.prediction_head)
        peft_head = PeftModel.from_pretrained(shim, diff_dir)
        peft_head.to(device)
        model.model.prediction_head = peft_head
        report.diffusion_head_lora = True
        return

    full_path = diff_dir / "diffusion_head_full.bin"
    if not full_path.exists():
        # Fall back to the file sitting directly in the adapter root.
        full_path = adapter_root / "diffusion_head_full.bin"
    if full_path.exists():
        logger.info(f"Loading full diffusion head weights from {full_path}")
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(full_path, map_location=device)
        missing, unexpected = model.model.prediction_head.load_state_dict(
            state_dict, strict=False
        )
        if missing:
            logger.warning(f"Diffusion head load missing keys: {missing}")
        if unexpected:
            logger.warning(f"Diffusion head load unexpected keys: {unexpected}")
        model.model.prediction_head.to(device)
        report.diffusion_head_full = True
def _load_language_model(
    model, adapter_root: Path, device: torch.device, report: _LoadReport
) -> None:
    """Attach a language-model LoRA adapter stored directly in ``adapter_root``.

    Nothing happens unless ``adapter_config.json`` and one of
    ``adapter_model.bin`` / ``adapter_model.safetensors`` are present. When
    they are, ``model.model.language_model`` is replaced by the PEFT-wrapped
    model, ``model.tie_weights()`` is attempted (failures are only logged) and
    ``report.language_model`` is set.

    Raises:
        RuntimeError: If adapter files are present but ``peft`` is not installed.
    """
    weight_names = ("adapter_model.bin", "adapter_model.safetensors")
    has_config = (adapter_root / "adapter_config.json").exists()
    if not has_config or not any(
        (adapter_root / name).exists() for name in weight_names
    ):
        return

    try:
        from peft import PeftModel
    except ImportError as exc:
        raise RuntimeError(
            "peft is required to load language model adapters but is not installed"
        ) from exc

    logger.info(f"Loading language model LoRA from {adapter_root}")
    wrapped_lm = PeftModel.from_pretrained(model.model.language_model, adapter_root)
    wrapped_lm.to(device)
    model.model.language_model = wrapped_lm

    # Re-tying is best-effort: a failure is reported but does not abort loading.
    if hasattr(model, "tie_weights"):
        try:
            model.tie_weights()
        except Exception as exc:
            logger.warning(
                f"Failed to retie weights after loading language LoRA: {exc}"
            )
    report.language_model = True
def load_lora_assets(
    model, checkpoint_dir: str, device: Optional[torch.device] = None
) -> _LoadReport:
    """Load every fine-tuned asset found in a checkpoint directory into ``model``.

    The adapter root is resolved from ``checkpoint_dir`` and then, in order,
    the language-model LoRA, the diffusion head (LoRA or full weights) and
    the acoustic / semantic connector weights are loaded when present.

    Args:
        model: Model exposing ``parameters()`` and a ``model`` attribute that
            holds the components to patch.
        checkpoint_dir: Checkpoint path handed to ``_resolve_adapter_root``.
        device: Target device; defaults to the device of the model's first
            parameter.

    Returns:
        A ``_LoadReport`` describing which assets were loaded.

    Raises:
        FileNotFoundError: If the resolved adapter directory does not exist.
    """
    adapter_root = _resolve_adapter_root(Path(checkpoint_dir))
    if not adapter_root.exists():
        raise FileNotFoundError(f"Adapter directory not found: {adapter_root}")

    inferred_device = (
        device if device is not None else next(model.parameters()).device
    )
    report = _LoadReport(adapter_root=adapter_root)

    _load_language_model(model, adapter_root, inferred_device, report)
    _load_diffusion_head(model, adapter_root, inferred_device, report)

    ac_path = adapter_root / "acoustic_connector" / "pytorch_model.bin"
    if _load_connector(
        getattr(model.model, "acoustic_connector", None), ac_path, inferred_device
    ):
        report.acoustic_connector = True

    se_path = adapter_root / "semantic_connector" / "pytorch_model.bin"
    if _load_connector(
        getattr(model.model, "semantic_connector", None), se_path, inferred_device
    ):
        report.semantic_connector = True

    # Only the boolean flags say whether something was loaded. `adapter_root`
    # is always a (truthy) Path, so it must not take part in this check or the
    # warning could never be emitted.
    if not any(value is True for value in vars(report).values()):
        logger.warning(
            "No adapter assets were loaded. Ensure the checkpoint directory is correct and contains LoRA weights."
        )
    return report
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import deprecate
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput,
)
def betas_for_alpha_bar(
    num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine"
):
    """Build a beta schedule by discretising a cumulative ``alpha_bar(t)`` curve.

    For every step ``i`` the beta is ``1 - alpha_bar(t2) / alpha_bar(t1)`` with
    ``t1 = i / N`` and ``t2 = (i + 1) / N``, capped at ``max_beta``.

    Args:
        num_diffusion_timesteps: Number of betas to produce (``N``).
        max_beta: Upper bound applied to every beta.
        alpha_transform_type: One of ``"cosine"``, ``"exp"``, ``"cauchy"``
            or ``"laplace"``.

    Returns:
        A float32 tensor with ``num_diffusion_timesteps`` entries.

    Raises:
        ValueError: If ``alpha_transform_type`` is not supported.
    """

    def _cosine(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp(t):
        return math.exp(t * -12.0)

    def _cauchy(t, gamma=1, mu=3):
        snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9)
        return 1 - 1 / (math.exp(snr) + 1.1)

    def _laplace(t, mu=0, b=1):
        snr = mu - b * math.copysign(1, 0.5 - t) * math.log(
            1 - 2 * abs(t - 0.5) * 0.98
        )
        return 1 - 1 / (math.exp(snr) + 1.02)

    transforms = {
        "cosine": _cosine,
        "exp": _exp,
        "cauchy": _cauchy,
        "laplace": _laplace,
    }
    if alpha_transform_type not in transforms:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
    alpha_bar_fn = transforms[alpha_transform_type]

    total = num_diffusion_timesteps
    schedule = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta)
        for step in range(total)
    ]
    return torch.tensor(schedule, dtype=torch.float32)
def rescale_zero_terminal_snr(betas):
    """Rescale ``betas`` so that the final cumulative alpha is exactly zero.

    ``sqrt(alpha_bar)`` is shifted so its last entry becomes zero and then
    scaled so its first entry keeps its original value; the result is
    converted back to per-step betas. The input tensor is not modified.

    Args:
        betas: 1-D tensor of betas.

    Returns:
        Tensor of rescaled betas with the same shape as ``betas``.
    """
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()
    first_value = sqrt_alpha_bar[0].clone()
    last_value = sqrt_alpha_bar[-1].clone()

    # Shift so the terminal entry is zero, then scale the first entry back.
    shifted = sqrt_alpha_bar - last_value
    rescaled = shifted * (first_value / (first_value - last_value))

    alpha_bar = rescaled**2
    # Undo the cumulative product: ratio of consecutive entries, with the
    # first entry carried over unchanged.
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - step_alphas
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_lu_lambdas: Optional[bool] = False,
final_sigmas_type: Optional[str] = "zero",
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type } is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate(
"algorithm_types dpmsolver and sde-dpmsolver",
"1.0.0",
deprecation_message,
)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(
beta_start, beta_end, num_train_timesteps, dtype=torch.float32
)
elif beta_schedule == "scaled_linear":
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine":
self.betas = betas_for_alpha_bar(
num_train_timesteps, alpha_transform_type="cosine"
)
elif beta_schedule == "cauchy":
self.betas = betas_for_alpha_bar(
num_train_timesteps, alpha_transform_type="cauchy"
)
elif beta_schedule == "laplace":
self.betas = betas_for_alpha_bar(
num_train_timesteps, alpha_transform_type="laplace"
)
else:
raise NotImplementedError(
f"{beta_schedule } is not implemented for {self .__class__ }"
)
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
if rescale_betas_zero_snr:
self.alphas_cumprod[-1] = 2 ** (-24)
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
self.init_noise_sigma = 1.0
if algorithm_type not in [
"dpmsolver",
"dpmsolver++",
"sde-dpmsolver",
"sde-dpmsolver++",
]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
else:
raise NotImplementedError(
f"{algorithm_type } is not implemented for {self .__class__ }"
)
if solver_type not in ["midpoint", "heun"]:
if solver_type in ["logrho", "bh1", "bh2"]:
self.register_to_config(solver_type="midpoint")
else:
raise NotImplementedError(
f"{solver_type } is not implemented for {self .__class__ }"
)
if (
algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]
and final_sigmas_type == "zero"
):
raise ValueError(
f"`final_sigmas_type` {final_sigmas_type } is not supported for `algorithm_type` {algorithm_type }. Please choose `sigma_min` instead."
)
self.num_inference_steps = None
timesteps = np.linspace(
0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32
)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu")
@property
def step_index(self):
return self._step_index
@property
def begin_index(self):
return self._begin_index
def set_begin_index(self, begin_index: int = 0):
self._begin_index = begin_index
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
timesteps: Optional[List[int]] = None,
):
if num_inference_steps is None and timesteps is None:
raise ValueError(
"Must pass exactly one of `num_inference_steps` or `timesteps`."
)
if num_inference_steps is not None and timesteps is not None:
raise ValueError(
"Can only pass one of `num_inference_steps` or `custom_timesteps`."
)
if timesteps is not None and self.config.use_karras_sigmas:
raise ValueError(
"Cannot use `timesteps` with `config.use_karras_sigmas = True`"
)
if timesteps is not None and self.config.use_lu_lambdas:
raise ValueError(
"Cannot use `timesteps` with `config.use_lu_lambdas = True`"
)
if timesteps is not None:
timesteps = np.array(timesteps).astype(np.int64)
else:
clipped_idx = torch.searchsorted(
torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped
)
last_timestep = (
(self.config.num_train_timesteps - clipped_idx).numpy().item()
)
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, last_timestep - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = last_timestep // (num_inference_steps + 1)
timesteps = (
(np.arange(0, num_inference_steps + 1) * step_ratio)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
timesteps = (
np.arange(last_timestep, 0, -step_ratio)
.round()
.copy()
.astype(np.int64)
)
timesteps -= 1
else:
raise ValueError(
f"{self .config .timestep_spacing } is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
if self.config.use_karras_sigmas:
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(
in_sigmas=sigmas, num_inference_steps=num_inference_steps
)
timesteps = np.array(
[self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
).round()
elif self.config.use_lu_lambdas:
lambdas = np.flip(log_sigmas.copy())
lambdas = self._convert_to_lu(
in_lambdas=lambdas, num_inference_steps=num_inference_steps
)
sigmas = np.exp(lambdas)
timesteps = np.array(
[self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
).round()
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self .config .final_sigmas_type }"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(
device=device, dtype=torch.int64
)
self.num_inference_steps = len(timesteps)
self.model_outputs = [None] * self.config.solver_order
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu")
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float()
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs()
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(s, min=1, max=self.config.sample_max_value)
s = s.unsqueeze(1)
sample = torch.clamp(sample, -s, s) / s
sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
def _sigma_to_t(self, sigma, log_sigmas):
log_sigma = np.log(np.maximum(sigma, 1e-10))
dists = log_sigma - log_sigmas[:, np.newaxis]
low_idx = (
np.cumsum(dists >= 0, axis=0)
.argmax(axis=0)
.clip(max=log_sigmas.shape[0] - 2)
)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / (sigma**2 + 1) ** 0.5
sigma_t = sigma * alpha_t
return (alpha_t, sigma_t)
def _convert_to_karras(
self, in_sigmas: torch.Tensor, num_inference_steps
) -> torch.Tensor:
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def _convert_to_lu(
self, in_lambdas: torch.Tensor, num_inference_steps
) -> torch.Tensor:
lambda_min: float = in_lambdas[-1].item()
lambda_max: float = in_lambdas[0].item()
rho = 1.0
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = lambda_min ** (1 / rho)
max_inv_rho = lambda_max ** (1 / rho)
lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return lambdas
def convert_model_output(
self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs
) -> torch.Tensor:
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
if self.config.prediction_type == "epsilon":
if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3]
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self .config .prediction_type } must be one of `epsilon`, `sample`, or `v_prediction` for the DPMSolverMultistepScheduler."
)
if self.config.thresholding:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
if self.config.prediction_type == "epsilon":
if self.config.variance_type in ["learned", "learned_range"]:
epsilon = model_output[:, :3]
else:
epsilon = model_output
elif self.config.prediction_type == "sample":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = (sample - alpha_t * model_output) / sigma_t
elif self.config.prediction_type == "v_prediction":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = alpha_t * model_output + sigma_t * sample
else:
raise ValueError(
f"prediction_type given as {self .config .prediction_type } must be one of `epsilon`, `sample`, or `v_prediction` for the DPMSolverMultistepScheduler."
)
if self.config.thresholding:
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t
return epsilon
def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (
sigma_t / sigma_s * sample
- alpha_t * (torch.exp(-h) - 1.0) * model_output
)
elif self.config.algorithm_type == "dpmsolver":
x_t = (
alpha_t / alpha_s * sample
- sigma_t * (torch.exp(h) - 1.0) * model_output
)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
sigma_t / sigma_s * torch.exp(-h) * sample
+ alpha_t * (1 - torch.exp(-2.0 * h)) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
x_t = (
alpha_t / alpha_s * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
m0, m1 = (model_output_list[-1], model_output_list[-2])
h, h_0 = (lambda_t - lambda_s0, lambda_s0 - lambda_s1)
r0 = h_0 / h
D0, D1 = (m0, 1.0 / r0 * (m0 - m1))
if self.config.algorithm_type == "dpmsolver++":
if self.config.solver_type == "midpoint":
x_t = (
sigma_t / sigma_s0 * sample
- alpha_t * (torch.exp(-h) - 1.0) * D0
- 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
)
elif self.config.solver_type == "heun":
x_t = (
sigma_t / sigma_s0 * sample
- alpha_t * (torch.exp(-h) - 1.0) * D0
+ alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0) * D1
)
elif self.config.algorithm_type == "dpmsolver":
if self.config.solver_type == "midpoint":
x_t = (
alpha_t / alpha_s0 * sample
- sigma_t * (torch.exp(h) - 1.0) * D0
- 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
)
elif self.config.solver_type == "heun":
x_t = (
alpha_t / alpha_s0 * sample
- sigma_t * (torch.exp(h) - 1.0) * D0
- sigma_t * ((torch.exp(h) - 1.0) / h - 1.0) * D1
)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
sigma_t / sigma_s0 * torch.exp(-h) * sample
+ alpha_t * (1 - torch.exp(-2.0 * h)) * D0
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
sigma_t / sigma_s0 * torch.exp(-h) * sample
+ alpha_t * (1 - torch.exp(-2.0 * h)) * D0
+ alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
alpha_t / alpha_s0 * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- sigma_t * (torch.exp(h) - 1.0) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
elif self.config.solver_type == "heun":
x_t = (
alpha_t / alpha_s0 * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing`sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
self.sigmas[self.step_index - 2],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
m0, m1, m2 = (
model_output_list[-1],
model_output_list[-2],
model_output_list[-3],
)
h, h_0, h_1 = (
lambda_t - lambda_s0,
lambda_s0 - lambda_s1,
lambda_s1 - lambda_s2,
)
r0, r1 = (h_0 / h, h_1 / h)
D0 = m0
D1_0, D1_1 = (1.0 / r0 * (m0 - m1), 1.0 / r1 * (m1 - m2))
D1 = D1_0 + r0 / (r0 + r1) * (D1_0 - D1_1)
D2 = 1.0 / (r0 + r1) * (D1_0 - D1_1)
if self.config.algorithm_type == "dpmsolver++":
x_t = (
sigma_t / sigma_s0 * sample
- alpha_t * (torch.exp(-h) - 1.0) * D0
+ alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0) * D1
- alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5) * D2
)
elif self.config.algorithm_type == "dpmsolver":
x_t = (
alpha_t / alpha_s0 * sample
- sigma_t * (torch.exp(h) - 1.0) * D0
- sigma_t * ((torch.exp(h) - 1.0) / h - 1.0) * D1
- sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5) * D2
)
return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
return step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
lower_order_final = self.step_index == len(self.timesteps) - 1 and (
self.config.euler_at_final
or (self.config.lower_order_final and len(self.timesteps) < 15)
or self.config.final_sigmas_type == "zero"
)
lower_order_second = (
self.step_index == len(self.timesteps) - 2
and self.config.lower_order_final
and (len(self.timesteps) < 15)
)
model_output = self.convert_model_output(model_output, sample=sample)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
sample = sample.to(torch.float32)
if (
self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]
and variance_noise is None
):
noise = randn_tensor(
model_output.shape,
generator=generator,
device=model_output.device,
dtype=torch.float32,
)
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
else:
noise = None
if (
self.config.solver_order == 1
or self.lower_order_nums < 1
or lower_order_final
):
prev_sample = self.dpm_solver_first_order_update(
model_output, sample=sample, noise=noise
)
elif (
self.config.solver_order == 2
or self.lower_order_nums < 2
or lower_order_second
):
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, sample=sample, noise=noise
)
else:
prev_sample = self.multistep_dpm_solver_third_order_update(
self.model_outputs, sample=sample
)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
prev_sample = prev_sample.to(model_output.dtype)
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
alpha_t = alpha_t[timesteps].flatten()
while len(alpha_t.shape) < len(original_samples.shape):
alpha_t = alpha_t.unsqueeze(-1)
sigma_t = sigma_t[timesteps].flatten()
while len(sigma_t.shape) < len(original_samples.shape):
sigma_t = sigma_t.unsqueeze(-1)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def get_velocity(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
alpha_t = alpha_t[timesteps].flatten()
while len(alpha_t.shape) < len(original_samples.shape):
alpha_t = alpha_t.unsqueeze(-1)
sigma_t = sigma_t[timesteps].flatten()
while len(sigma_t.shape) < len(original_samples.shape):
sigma_t = sigma_t.unsqueeze(-1)
velocity = alpha_t * noise - sigma_t * original_samples
return velocity
def __len__(self):
    """Return the configured number of training timesteps."""
    return self.config.num_train_timesteps
'\nProcessor class for QWEN3Vox models.\n'
import os
import json
import warnings
from typing import List, Optional, Union, Dict, Any
import numpy as np
import torch
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.utils import logging
logger = logging.get_logger(__name__)
class AudioNormalizer:
    """Scale audio to a target RMS level and then guard against clipping."""

    def __init__(self, target_dB_FS: float = -25, eps: float = 1e-06):
        """Store the target level (dB FS) and the numerical-stability epsilon."""
        self.target_dB_FS = target_dB_FS
        self.eps = eps

    def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
        """Rescale ``audio`` so its RMS matches ``target_dB_FS``.

        Returns:
            ``(normalized_audio, rms, scalar)`` where ``rms`` is the RMS of the
            input and ``scalar`` the gain that was applied.
        """
        rms = np.sqrt(np.mean(audio**2))
        target_amplitude = 10 ** (self.target_dB_FS / 20)
        scalar = target_amplitude / (rms + self.eps)
        return (audio * scalar, rms, scalar)

    def avoid_clipping(
        self, audio: np.ndarray, scalar: Optional[float] = None
    ) -> tuple:
        """Divide ``audio`` by ``scalar`` so the peak stays within [-1, 1].

        When ``scalar`` is omitted it is derived from the signal: the peak
        magnitude plus ``eps`` if the peak exceeds 1.0, otherwise 1.0.

        Returns:
            ``(scaled_audio, scalar)``.
        """
        if scalar is None:
            peak = np.max(np.abs(audio))
            scalar = peak + self.eps if peak > 1.0 else 1.0
        return (audio / scalar, scalar)

    def __call__(self, audio: np.ndarray) -> np.ndarray:
        """Apply level normalization followed by clipping protection."""
        levelled = self.tailor_dB_FS(audio)[0]
        return self.avoid_clipping(levelled)[0]
class QWEN3VoxTokenizerProcessor(FeatureExtractionMixin):
    """Audio pre/post-processor: loading, mono down-mix, level normalization, saving.

    ``__call__`` turns raw audio (arrays, lists of floats, or file paths) into
    mono float32 waveforms, optionally normalized by ``AudioNormalizer`` and
    optionally packed into torch / numpy batches.
    """

    model_input_names = ["input_features"]

    def __init__(
        self,
        sampling_rate: int = 22050,
        normalize_audio: bool = True,
        target_dB_FS: float = -25,
        eps: float = 1e-06,
        **kwargs,
    ):
        """Configure the processor.

        Args:
            sampling_rate: Expected sampling rate; also used when loading files
                through librosa and as the default when saving.
            normalize_audio: Whether waveforms are level-normalized.
            target_dB_FS: Target level forwarded to ``AudioNormalizer``.
            eps: Epsilon forwarded to ``AudioNormalizer``.
            **kwargs: Forwarded to ``FeatureExtractionMixin.__init__``.
        """
        super().__init__(**kwargs)
        self.sampling_rate = sampling_rate
        self.normalize_audio = normalize_audio
        if self.normalize_audio:
            self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps)
        else:
            self.normalizer = None
        self.feature_extractor_dict = {
            "sampling_rate": sampling_rate,
            "normalize_audio": normalize_audio,
            "target_dB_FS": target_dB_FS,
            "eps": eps,
        }

    def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
        """Return a 1-D waveform from 1-D or 2-D input.

        A 2-D array with a dimension of size 2 is averaged over that dimension
        (axis 0 is checked first); a dimension of size 1 is squeezed away.

        Raises:
            ValueError: If the array is not 1-D/2-D or has no axis of size 1 or 2.
        """
        if len(audio.shape) == 1:
            return audio
        elif len(audio.shape) == 2:
            if audio.shape[0] == 2:
                return np.mean(audio, axis=0)
            elif audio.shape[1] == 2:
                return np.mean(audio, axis=1)
            elif audio.shape[0] == 1:
                return audio.squeeze(0)
            elif audio.shape[1] == 1:
                return audio.squeeze(1)
            else:
                raise ValueError(f"Unexpected audio shape: {audio .shape }")
        else:
            raise ValueError(f"Audio should be 1D or 2D, got shape: {audio .shape }")

    def _process_single_audio(
        self, audio: Union[np.ndarray, List[float]]
    ) -> np.ndarray:
        """Convert one waveform to mono float32 and normalize it if enabled."""
        if not isinstance(audio, np.ndarray):
            audio = np.array(audio, dtype=np.float32)
        else:
            audio = audio.astype(np.float32)
        audio = self._ensure_mono(audio)
        if self.normalize_audio and self.normalizer is not None:
            audio = self.normalizer(audio)
        return audio

    def __call__(
        self,
        audio: Optional[
            Union[
                str,
                np.ndarray,
                List[float],
                List[np.ndarray],
                List[List[float]],
                List[str],
            ]
        ] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[str] = None,
        **kwargs,
    ):
        """Process one waveform or a batch of waveforms.

        Args:
            audio: A file path, a waveform (array or list of floats), or a list
                of paths / waveforms. A list whose first element is an array or
                a list is treated as a batch.
            sampling_rate: Sampling rate of the input; only used to warn when
                it differs from ``self.sampling_rate`` (no resampling is done
                here).
            return_tensors: ``"pt"`` or ``"np"`` to pack the result with a
                channel axis inserted at position 1; anything else returns the
                raw array (single input) or a list of arrays (batch).
            **kwargs: Accepted for API compatibility; unused.

        Returns:
            ``{"audio": features}``.

        Raises:
            ValueError: If ``audio`` is ``None`` or an empty list.
        """
        if audio is None:
            raise ValueError("Audio input is required")
        if sampling_rate is not None and sampling_rate != self.sampling_rate:
            logger.warning(
                f"Input sampling rate ({sampling_rate }) differs from expected sampling rate ({self .sampling_rate }). Please resample your audio."
            )
        if isinstance(audio, str):
            audio = self._load_audio_from_path(audio)
            is_batched = False
        elif isinstance(audio, list):
            if len(audio) == 0:
                raise ValueError("Empty audio list provided")
            if all(isinstance(item, str) for item in audio):
                audio = [self._load_audio_from_path(path) for path in audio]
                is_batched = True
            else:
                # A list of floats is a single waveform; a list of
                # arrays/lists is a batch.
                is_batched = isinstance(audio[0], (np.ndarray, list))
        else:
            is_batched = False
        if is_batched:
            processed_audio = [self._process_single_audio(a) for a in audio]
        else:
            processed_audio = [self._process_single_audio(audio)]
        if return_tensors == "pt":
            if len(processed_audio) == 1:
                input_features = (
                    torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1)
                )
            else:
                # torch.stack requires all waveforms to share one length.
                input_features = torch.stack(
                    [torch.from_numpy(a) for a in processed_audio]
                ).unsqueeze(1)
        elif return_tensors == "np":
            if len(processed_audio) == 1:
                input_features = processed_audio[0][np.newaxis, np.newaxis, :]
            else:
                input_features = np.stack(processed_audio)[:, np.newaxis, :]
        else:
            input_features = (
                processed_audio[0] if len(processed_audio) == 1 else processed_audio
            )
        outputs = {"audio": input_features}
        return outputs

    def _load_audio_from_path(self, audio_path: str) -> np.ndarray:
        """Load a waveform from disk based on the file extension.

        Audio containers go through librosa (resampled to
        ``self.sampling_rate``, mono); ``.pt`` and ``.npy`` files are loaded as
        stored and cast to float32.

        Raises:
            ValueError: If the extension is not supported.
        """
        file_ext = os.path.splitext(audio_path)[1].lower()
        if file_ext in [".wav", ".mp3", ".flac", ".m4a", ".ogg"]:
            import librosa

            audio_array, _ = librosa.load(
                audio_path, sr=self.sampling_rate, mono=True
            )
            return audio_array
        elif file_ext == ".pt":
            # NOTE(security): torch.load unpickles the file; only load .pt
            # files from trusted sources.
            loaded = torch.load(audio_path, map_location="cpu")
            # Check the type before squeezing so non-tensor payloads do not
            # fail with an AttributeError.
            if isinstance(loaded, torch.Tensor):
                audio_array = loaded.detach().to(torch.float32).squeeze().numpy()
            else:
                audio_array = np.asarray(loaded).squeeze()
            return audio_array.astype(np.float32)
        elif file_ext == ".npy":
            audio_array = np.load(audio_path)
            return audio_array.astype(np.float32)
        else:
            # The message lists only the formats that are actually handled.
            raise ValueError(
                f"Unsupported file format: {file_ext }. Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy"
            )

    def preprocess_audio(
        self,
        audio_path_or_array: Union[str, np.ndarray],
        normalize: Optional[bool] = None,
    ) -> np.ndarray:
        """Load (if given a path) and process a single waveform.

        Args:
            audio_path_or_array: File path or waveform.
            normalize: Overrides ``self.normalize_audio`` for this call only;
                ``None`` keeps the configured behavior.

        Returns:
            The processed mono float32 waveform.
        """
        if isinstance(audio_path_or_array, str):
            audio_array = self._load_audio_from_path(audio_path_or_array)
        else:
            audio_array = np.array(audio_path_or_array, dtype=np.float32)
        original_normalize = self.normalize_audio
        if normalize is not None:
            self.normalize_audio = normalize
        try:
            processed = self._process_single_audio(audio_array)
        finally:
            # Always restore the configured flag, even if processing raised.
            self.normalize_audio = original_normalize
        return processed

    def to_dict(self) -> Dict[str, Any]:
        """Return a copy of the serializable configuration.

        A copy is returned so callers that post-process the dict cannot
        corrupt the processor's own state.
        """
        return dict(self.feature_extractor_dict)

    def save_audio(
        self,
        audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
        output_path: str = "output.wav",
        sampling_rate: Optional[int] = None,
        normalize: bool = False,
        batch_prefix: str = "audio_",
    ):
        """Write audio to disk with ``soundfile``.

        A list, or an array with three or more dimensions and a leading
        dimension greater than 1, is treated as a batch: ``output_path`` is
        then used as a directory and files are named
        ``{batch_prefix}{index}.wav``. Anything else is written to
        ``output_path`` directly.

        Args:
            audio: Tensor, array, or list of tensors/arrays.
            output_path: Target file, or target directory for batches.
            sampling_rate: Defaults to ``self.sampling_rate``.
            normalize: Peak-normalize each item before writing.
            batch_prefix: File-name prefix for batch items.

        Returns:
            List of written file paths.

        Raises:
            ImportError: If ``soundfile`` is not installed.
            ValueError: If ``audio`` has an unsupported type.
        """
        if sampling_rate is None:
            sampling_rate = self.sampling_rate
        try:
            import soundfile as sf
        except ImportError as err:
            raise ImportError(
                "soundfile is required to save audio files. Install it with: pip install soundfile"
            ) from err
        if isinstance(audio, torch.Tensor):
            audio_np = audio.float().detach().cpu().numpy()
        elif isinstance(audio, np.ndarray):
            audio_np = audio
        elif isinstance(audio, list):
            # Convert item by item so mixed tensor/ndarray lists also work.
            audio_np = [
                a.float().detach().cpu().numpy() if isinstance(a, torch.Tensor) else a
                for a in audio
            ]
        else:
            raise ValueError(f"Unsupported audio type: {type (audio )}")
        saved_paths = []
        if isinstance(audio_np, list):
            output_dir = output_path
            os.makedirs(output_dir, exist_ok=True)
            for i, audio_item in enumerate(audio_np):
                audio_item = self._prepare_audio_for_save(audio_item, normalize)
                file_path = os.path.join(output_dir, f"{batch_prefix }{i }.wav")
                sf.write(file_path, audio_item, sampling_rate)
                saved_paths.append(file_path)
        elif len(audio_np.shape) >= 3:
            batch_size = audio_np.shape[0]
            if batch_size > 1:
                output_dir = output_path
                os.makedirs(output_dir, exist_ok=True)
                for i in range(batch_size):
                    # _prepare_audio_for_save drops a leading singleton
                    # channel axis, so no extra squeeze is needed here.
                    single_audio = self._prepare_audio_for_save(
                        audio_np[i], normalize
                    )
                    file_path = os.path.join(output_dir, f"{batch_prefix }{i }.wav")
                    sf.write(file_path, single_audio, sampling_rate)
                    saved_paths.append(file_path)
            else:
                audio_item = audio_np.squeeze()
                audio_item = self._prepare_audio_for_save(audio_item, normalize)
                sf.write(output_path, audio_item, sampling_rate)
                saved_paths.append(output_path)
        else:
            audio_item = self._prepare_audio_for_save(audio_np, normalize)
            sf.write(output_path, audio_item, sampling_rate)
            saved_paths.append(output_path)
        return saved_paths

    def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
        """Drop a leading singleton axis and optionally peak-normalize."""
        if len(audio.shape) > 1 and audio.shape[0] == 1:
            audio = audio.squeeze(0)
        if normalize:
            max_val = np.abs(audio).max()
            # Skip silent audio to avoid dividing by zero.
            if max_val > 0:
                audio = audio / max_val
        return audio
__all__ = [
'QWEN3VoxTokenizerProcessor',
"AudioNormalizer",
]
import math
import torch
class UniformSampler:
    """Draw diffusion timesteps uniformly from ``[0, timesteps)``."""

    def __init__(self, timesteps=1000):
        """Remember the exclusive upper bound of the timestep range."""
        self.timesteps = timesteps

    def sample(self, batch_size, device):
        """Return ``batch_size`` integer timesteps created on ``device``."""
        shape = (batch_size,)
        return torch.randint(0, self.timesteps, shape, device=device)
class LogitNormalSampler:
    """Draw timesteps with probability given by a Gaussian over logit(t)."""

    def __init__(self, timesteps=1000, m=0, s=1):
        """Precompute the unnormalized sampling weights.

        Args:
            timesteps: Number of discrete timesteps on an even grid over [0, 1].
            m: Mean of the Gaussian in logit space.
            s: Standard deviation of the Gaussian in logit space.
        """
        self.timesteps = timesteps
        grid = torch.linspace(0, 1, timesteps)
        # The endpoints map to -inf/+inf logits and so get weight 0.
        logit = torch.log(grid / (1 - grid))
        exponent = -0.5 * (logit - m) ** 2 / s**2
        normaliser = s * math.sqrt(2 * math.pi)
        self.prob = torch.exp(exponent) / normaliser

    def sample(self, batch_size, device):
        """Return ``batch_size`` timestep indices, moved to ``device``."""
        indices = torch.multinomial(self.prob, batch_size, replacement=True)
        return indices.to(device)
' QWEN3Vox Streaming model configuration'
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
logger = logging.get_logger(__name__)
class QWEN3VoxStreamingConfig(PretrainedConfig):
    """Composite configuration for the streaming model.

    Bundles three sub-configurations (acoustic tokenizer, Qwen2 decoder and
    diffusion head), each of which may be given as ``None`` (use defaults), a
    ``dict`` (instantiate from keyword arguments) or an already-built config
    object.
    """

    model_type = 'vibevoice_streaming'
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig,
        "decoder_config": Qwen2Config,
        "diffusion_head_config": QWEN3VoxDiffusionHeadConfig,
    }
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        decoder_config=None,
        diffusion_head_config=None,
        tts_backbone_num_hidden_layers=20,
        **kwargs,
    ):
        """Build the composite configuration.

        Args:
            acoustic_tokenizer_config: ``None``, ``dict`` or
                ``QWEN3VoxAcousticTokenizerConfig``.
            decoder_config: ``None``, ``dict`` (must carry
                ``model_type == "qwen2"``) or ``Qwen2Config``.
            diffusion_head_config: ``None``, ``dict`` or
                ``QWEN3VoxDiffusionHeadConfig``.
            tts_backbone_num_hidden_layers: Stored on the config as-is.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If ``decoder_config`` is a dict with an unsupported
                ``model_type``.
        """
        kwargs["_attn_implementation_autoset"] = False

        if acoustic_tokenizer_config is None:
            self.acoustic_tokenizer_config = self.sub_configs[
                "acoustic_tokenizer_config"
            ]()
        elif isinstance(acoustic_tokenizer_config, dict):
            # Build a new mapping rather than writing "model_type" into the
            # caller's dict.
            acoustic_kwargs = {
                **acoustic_tokenizer_config,
                "model_type": 'vibevoice_acoustic_tokenizer',
            }
            self.acoustic_tokenizer_config = self.sub_configs[
                "acoustic_tokenizer_config"
            ](**acoustic_kwargs)
        elif isinstance(acoustic_tokenizer_config, QWEN3VoxAcousticTokenizerConfig):
            self.acoustic_tokenizer_config = acoustic_tokenizer_config
        # NOTE(review): any other type leaves the attribute unset, and the
        # `acoustic_vae_dim` lookup below will then fail.

        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            if decoder_config.get("model_type", "") == "qwen2":
                self.decoder_config = Qwen2Config(**decoder_config)
            else:
                raise ValueError(
                    f"Unsupported decoder model type: {decoder_config .get ('model_type','')}"
                )
        elif isinstance(decoder_config, (Qwen2Config,)):
            self.decoder_config = decoder_config

        if diffusion_head_config is None:
            self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
        elif isinstance(diffusion_head_config, dict):
            # Same as above: do not mutate the caller's dict.
            diffusion_kwargs = {
                **diffusion_head_config,
                "model_type": 'vibevoice_diffusion_head',
            }
            self.diffusion_head_config = self.sub_configs["diffusion_head_config"](
                **diffusion_kwargs
            )
        elif isinstance(diffusion_head_config, QWEN3VoxDiffusionHeadConfig):
            self.diffusion_head_config = diffusion_head_config

        # Falls back to 64 when the tokenizer config has no `vae_dim`.
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64)
        self.tts_backbone_num_hidden_layers = tts_backbone_num_hidden_layers
        super().__init__(**kwargs)
__all__ = [
'QWEN3VoxStreamingConfig'
]
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel
from transformers.modeling_utils import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.utils import logging
logger = logging.get_logger(__name__)
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension."""

    def __init__(
        self,
        dim: int,
        eps: float = 1e-06,
        elementwise_affine=True,
        memory_efficient=False,
    ):
        """Create the layer.

        Args:
            dim: Size of the normalized (last) dimension.
            eps: Added to the mean square before the reciprocal square root.
            elementwise_affine: Whether a learnable per-feature scale is used.
            memory_efficient: Accepted but not used by this implementation.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            self.register_parameter("weight", None)
        else:
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize in float32, cast back to the input dtype, then scale."""
        normed = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight

    def extra_repr(self) -> str:
        """Summarize the layer's settings for ``repr``."""
        return f"dim={self .dim }, eps={self .eps }, elementwise_affine={self .elementwise_affine }"
def modulate(x, shift, scale):
    """Apply a scale-and-shift modulation: ``x * (1 + scale) + shift``."""
    scaled = x * (1 + scale)
    return scaled + shift
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps: sinusoidal features followed by a two-layer MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        """Create the MLP.

        Args:
            hidden_size: Output width of the embedding.
            frequency_embedding_size: Width of the sinusoidal features fed to
                the MLP.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=False),
            ACT2FN["silu"],
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Build sinusoidal embeddings for a 1-D tensor of timesteps.

        Args:
            t: 1-D tensor of timesteps (integer or floating point).
            dim: Width of the returned embedding.
            max_period: Controls the lowest frequency of the embedding.

        Returns:
            Tensor of shape ``(len(t), dim)``: cosines, then sines, then one
            zero column when ``dim`` is odd. The result has ``t``'s dtype when
            ``t`` is floating point and float32 otherwise.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        # Casting to an integer dtype would truncate the sin/cos values, so
        # only follow `t`'s dtype when it is floating point.
        if t.is_floating_point():
            return embedding.to(t.dtype)
        return embedding

    def forward(self, t):
        """Return the MLP-projected embedding of timesteps ``t``."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
class FeedForwardNetwork(nn.Module):
    """Gated feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, embed_dim, ffn_dim):
        """Create the three bias-free projections.

        Args:
            embed_dim: Input and output width.
            ffn_dim: Hidden width of the gated projection.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        """Apply the gated projection to ``x``."""
        pre_gate = self.gate_proj(x)
        value = self.up_proj(x)
        gated = self.act_fn(pre_gate) * value
        return self.down_proj(gated)
class HeadLayer(nn.Module):
    """Residual feed-forward layer modulated by a conditioning vector."""

    def __init__(self, embed_dim, ffn_dim, cond_dim, norm_eps=1e-05):
        """Create the layer.

        Args:
            embed_dim: Width of the hidden stream.
            ffn_dim: Hidden width of the feed-forward network.
            cond_dim: Width of the conditioning vector.
            norm_eps: Epsilon for the RMS normalization.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.cond_dim = cond_dim
        self.ffn_dim = ffn_dim
        self.ffn = FeedForwardNetwork(self.embed_dim, self.ffn_dim)
        self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
        # Produces shift, scale and gate (3 * embed_dim) from the condition.
        self.adaLN_modulation = nn.Sequential(
            ACT2FN["silu"], nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
        )

    def forward(self, x, c):
        """Return ``x`` plus the gated, modulated feed-forward update."""
        shift, scale, gate = self.adaLN_modulation(c).chunk(3, dim=-1)
        modulated = modulate(self.norm(x), shift, scale)
        return x + gate * self.ffn(modulated)
class FinalLayer(nn.Module):
    """Output layer: modulated non-affine RMS norm followed by a linear map."""

    def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-05):
        """Create the layer.

        Args:
            hidden_size: Width of the incoming hidden stream.
            output_size: Width of the projected output.
            cond_size: Width of the conditioning vector.
            norm_eps: Epsilon for the RMS normalization.
        """
        super().__init__()
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, output_size, bias=False)
        # Produces shift and scale (2 * hidden_size) from the condition.
        self.adaLN_modulation = nn.Sequential(
            ACT2FN["silu"], nn.Linear(cond_size, 2 * hidden_size, bias=False)
        )

    def forward(self, x, c):
        """Modulate the normalized input with ``c`` and project it."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
class QWEN3VoxDiffusionHead(PreTrainedModel):
    """Diffusion head: a stack of conditioned feed-forward layers over latents.

    The noisy latent is projected to the hidden width, conditioned on the sum
    of the projected condition and the timestep embedding, passed through
    ``config.head_layers`` ``HeadLayer`` blocks and projected back to the
    latent width by ``FinalLayer``.
    """

    config_class = QWEN3VoxDiffusionHeadConfig
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config):
        """Build all sub-modules from ``config`` and initialize their weights."""
        super().__init__(config)
        self.config = config
        hidden = config.hidden_size
        latent_size = config.latent_size
        self.cond_dim = hidden
        self.noisy_images_proj = nn.Linear(latent_size, hidden, bias=False)
        self.cond_proj = nn.Linear(hidden, self.cond_dim, bias=False)
        self.t_embedder = TimestepEmbedder(self.cond_dim)
        ffn_dim = int(hidden * config.head_ffn_ratio)
        blocks = []
        for _ in range(config.head_layers):
            blocks.append(
                HeadLayer(
                    embed_dim=hidden,
                    ffn_dim=ffn_dim,
                    cond_dim=self.cond_dim,
                    norm_eps=config.rms_norm_eps,
                )
            )
        self.layers = nn.ModuleList(blocks)
        self.final_layer = FinalLayer(
            hidden_size=hidden,
            output_size=latent_size,
            cond_size=self.cond_dim,
            norm_eps=config.rms_norm_eps,
        )
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the timestep MLP and zero the modulation/output weights."""
        for position in (0, 2):
            nn.init.normal_(self.t_embedder.mlp[position].weight, std=0.02)
        # Zeroed modulation weights make every HeadLayer start as identity.
        for block in self.layers:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)

    def forward(self, noisy_images, timesteps, condition):
        """Run the head.

        Args:
            noisy_images: Noisy latents with ``config.latent_size`` features.
            timesteps: Diffusion timesteps passed to the timestep embedder.
            condition: Conditioning features with ``config.hidden_size``
                features.

        Returns:
            Tensor with ``config.latent_size`` features.
        """
        hidden_states = self.noisy_images_proj(noisy_images)
        time_embedding = self.t_embedder(timesteps)
        conditioning = self.cond_proj(condition) + time_embedding
        for block in self.layers:
            hidden_states = block(hidden_states, conditioning)
        return self.final_layer(hidden_states, conditioning)
AutoModel.register(QWEN3VoxDiffusionHeadConfig, QWEN3VoxDiffusionHead)
__all__ = [
'QWEN3VoxDiffusionHead'
]
import math
import typing as tp
from functools import partial
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.modeling_utils import PreTrainedModel
from transformers.activations import ACT2FN
logger = logging.get_logger(__name__)
import os
# Probe for NVIDIA APEX's fused RMSNorm kernel and decide whether to use it.
# APEX_AVAILABLE ends up True only when the import succeeds AND the
# OPTIMIZE_FOR_SPEED environment variable is set to a non-zero integer.
try:
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
    APEX_AVAILABLE = True
    # NOTE(review): this message is logged before the environment check below,
    # so it also appears when the fused kernel ends up disabled.
    logger.info("APEX FusedRMSNorm is available and will be used for optimization")
    # The fused path is opt-in: the variable defaults to "0", which disables it.
    # NOTE(review): a non-integer value raises ValueError here, which the
    # `except ImportError` clause does not catch.
    if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
        APEX_AVAILABLE = False
        logger.warning(
            "APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0"
        )
except ImportError:
    # `fused_rms_norm_affine` stays undefined in this case; callers must check
    # APEX_AVAILABLE before using it.
    APEX_AVAILABLE = False
    logger.warning("APEX FusedRMSNorm not available, using native implementation")
class ConvLayerNorm(nn.LayerNorm):
    """LayerNorm applied over dimension 1 of a tensor.

    The input is transposed so dimension 1 becomes last, normalized in
    float32, cast back and transposed again.
    """

    def __init__(
        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
    ):
        """Forward all arguments to ``nn.LayerNorm``."""
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        """Normalize ``x`` and return a tensor of the same shape and dtype."""
        swapped = x.transpose(1, 2)
        normed = nn.functional.layer_norm(
            swapped.float(),
            self.normalized_shape,
            self.weight.float(),
            self.bias.float(),
            self.eps,
        )
        return normed.type_as(swapped).transpose(1, 2)
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension."""

    def __init__(
        self, dim: int, eps: float = 1e-05, elementwise_affine=True, weight_shape=None
    ):
        """Create the layer.

        Args:
            dim: Size of the normalized (last) dimension.
            eps: Added to the mean square before the reciprocal square root.
            elementwise_affine: Whether a learnable scale is used.
            weight_shape: Shape of the scale parameter; defaults to ``(dim,)``.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            self.register_parameter("weight", None)
        else:
            if weight_shape is None:
                weight_shape = (dim,)
            self.weight = nn.Parameter(torch.ones(weight_shape))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize in float32, cast back to the input dtype, then scale."""
        normed = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight

    def extra_repr(self) -> str:
        """Summarize the layer's settings for ``repr``."""
        return f"dim={self .dim }, eps={self .eps }, elementwise_affine={self .elementwise_affine }"
class ConvRMSNorm(RMSNorm):
    """RMSNorm applied over dimension 1 of a tensor.

    Uses APEX's fused kernel when ``APEX_AVAILABLE`` is set and the layer has
    an affine weight; otherwise the native implementation.
    """

    def __init__(
        self, dim: int, eps: float = 1e-05, elementwise_affine=True, weight_shape=None
    ):
        """Forward all arguments to ``RMSNorm``."""
        super().__init__(dim, eps, elementwise_affine, weight_shape)

    def forward(self, x):
        """Normalize ``x`` and return a tensor of the same shape."""
        swapped = x.transpose(1, 2)
        use_fused = APEX_AVAILABLE and self.elementwise_affine
        if use_fused:
            normed = fused_rms_norm_affine(
                swapped, self.weight, self.weight.shape, self.eps
            )
        else:
            normed = self._norm(swapped.float()).type_as(swapped)
            if self.weight is not None:
                normed = normed * self.weight
        return normed.transpose(1, 2)
# Names of the normalization schemes accepted by `apply_parametrization_norm`
# and `get_norm_module`; both assert that their `norm` argument is a member.
CONV_NORMALIZATIONS = frozenset(
    [
        "none",
        "weight_norm",
        "spectral_norm",
        "time_layer_norm",
        "layer_norm",
        "time_group_norm",
    ]
)
def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
    """Wrap ``module`` with weight or spectral normalization when requested.

    Any other accepted ``norm`` value returns ``module`` unchanged.
    ``norm`` must be one of ``CONV_NORMALIZATIONS`` (checked with ``assert``).
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == "spectral_norm":
        return nn.utils.spectral_norm(module)
    if norm == "weight_norm":
        return nn.utils.weight_norm(module)
    return module
def get_norm_module(
    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
) -> nn.Module:
    """Return the normalization layer to apply after a convolution.

    Args:
        module: The convolution whose ``out_channels`` sizes the norm layer.
        causal: Whether the convolution is causal.
        norm: One of ``CONV_NORMALIZATIONS`` (checked with ``assert``).
        **norm_kwargs: Forwarded to the norm layer's constructor.

    Returns:
        ``ConvLayerNorm`` for ``"layer_norm"``, a single-group ``nn.GroupNorm``
        for ``"time_group_norm"``, otherwise ``nn.Identity``.

    Raises:
        ValueError: If ``"time_group_norm"`` is requested with ``causal``.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm not in ("layer_norm", "time_group_norm"):
        return nn.Identity()
    if norm == "time_group_norm" and causal:
        raise ValueError("GroupNorm doesn't support causal evaluation.")
    assert isinstance(module, nn.modules.conv._ConvNd)
    channels = module.out_channels
    if norm == "layer_norm":
        return ConvLayerNorm(channels, **norm_kwargs)
    return nn.GroupNorm(1, channels, **norm_kwargs)
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Return the right padding needed so the last conv window is complete.

    Args:
        x: Input whose last dimension is the length being convolved.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding_total: Padding already planned for the input.

    Returns:
        Number of extra samples to append to ``x``.
    """
    current = x.shape[-1]
    frames = (current - kernel_size + padding_total) / stride + 1
    # Length at which the rounded-up frame count fits exactly.
    target = (math.ceil(frames) - 1) * stride + (kernel_size - padding_total)
    return target - current
def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zero",
    value: float = 0.0,
):
    """Pad the last dimension of ``x``, handling short inputs in reflect mode.

    Args:
        x: Tensor to pad.
        paddings: ``(left, right)`` padding sizes; both must be non-negative.
        mode: Padding mode for ``F.pad``. ``"zero"`` and ``"zeros"`` are
            accepted as aliases for ``"constant"``; the default ``"zero"`` is
            not a valid ``F.pad`` mode on its own.
        value: Fill value passed to ``F.pad``.

    Returns:
        The padded tensor.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode in ("zero", "zeros"):
        mode = "constant"
    if mode == "reflect":
        # Reflect padding needs the input to be longer than the padding, so
        # extend short inputs with zeros first and cut the extension off
        # afterwards.
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    return F.pad(x, paddings, mode, value)
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove ``(left, right)`` samples from the ends of the last dimension.

    Both amounts must be non-negative and must not exceed the length together
    (checked with ``assert``).
    """
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    total = x.shape[-1]
    assert padding_left + padding_right <= total
    return x[..., padding_left : total - padding_right]
class NormConv1d(nn.Module):
    """``nn.Conv1d`` combined with the normalization selected by ``norm``."""

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        """Build the convolution and its normalization.

        Args:
            *args: Positional arguments for ``nn.Conv1d``.
            causal: Forwarded to ``get_norm_module``.
            norm: Normalization name (see ``CONV_NORMALIZATIONS``).
            norm_kwargs: Keyword arguments for the normalization layer.
            **kwargs: Keyword arguments for ``nn.Conv1d``.
        """
        super().__init__()
        raw_conv = nn.Conv1d(*args, **kwargs)
        self.conv = apply_parametrization_norm(raw_conv, norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Convolve ``x`` and normalize the result."""
        return self.norm(self.conv(x))
class NormConvTranspose1d(nn.Module):
    """``nn.ConvTranspose1d`` combined with the normalization selected by ``norm``."""

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        """Build the transposed convolution and its normalization.

        Args:
            *args: Positional arguments for ``nn.ConvTranspose1d``.
            causal: Forwarded to ``get_norm_module``.
            norm: Normalization name (see ``CONV_NORMALIZATIONS``).
            norm_kwargs: Keyword arguments for the normalization layer.
            **kwargs: Keyword arguments for ``nn.ConvTranspose1d``.
        """
        super().__init__()
        raw_convtr = nn.ConvTranspose1d(*args, **kwargs)
        self.convtr = apply_parametrization_norm(raw_convtr, norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Apply the transposed convolution and normalize the result."""
        return self.norm(self.convtr(x))
class QWEN3VoxTokenizerStreamingCache:
    """Per-layer, per-sample tensor cache for streaming convolutions.

    Entries are keyed by ``(layer_id, sample_index)``.
    """

    def __init__(self):
        """Start with an empty cache."""
        self.cache = {}

    def get(
        self, layer_id: str, sample_indices: torch.Tensor
    ) -> Optional[torch.Tensor]:
        """Return the stacked cached states for ``sample_indices``.

        States with two or more dimensions are left-padded with zeros along
        the last dimension to the longest cached length before stacking.

        Returns:
            The stacked states, or ``None`` if any requested sample has no
            entry for ``layer_id``.
        """
        states = []
        max_length = 0
        for idx in sample_indices.tolist():
            key = (layer_id, idx)
            if key not in self.cache:
                return None
            state = self.cache[key]
            states.append(state)
            max_length = max(max_length, state.shape[-1])
        if len(states) > 0 and states[0].dim() >= 2:
            padded_states = [
                F.pad(
                    state,
                    (max_length - state.shape[-1], 0),
                    mode="constant",
                    value=0,
                )
                if state.shape[-1] < max_length
                else state
                for state in states
            ]
            return torch.stack(padded_states, dim=0)
        return torch.stack(states, dim=0)

    def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
        """Store ``states[i]`` (detached) for the i-th entry of ``sample_indices``."""
        for i, idx in enumerate(sample_indices.tolist()):
            self.cache[(layer_id, idx)] = states[i].detach()

    def set_to_zero(self, sample_indices: torch.Tensor):
        """Replace every cached state of the given samples with zeros."""
        # Build the lookup once rather than once per cache key.
        targets = set(sample_indices.tolist())
        for key, cached_tensor in list(self.cache.items()):
            if key[1] in targets:
                self.cache[key] = torch.zeros_like(cached_tensor)

    def clear(
        self,
        layer_id: Optional[str] = None,
        sample_indices: Optional[torch.Tensor] = None,
    ):
        """Remove cache entries.

        Args:
            layer_id: Restrict removal to this layer; ``None`` means all layers.
            sample_indices: Restrict removal to these samples; ``None`` means
                all samples.
        """
        if layer_id is None and sample_indices is None:
            self.cache.clear()
        elif sample_indices is None:
            for key in [k for k in self.cache if k[0] == layer_id]:
                del self.cache[key]
        elif layer_id is None:
            # Samples given without a layer: drop them from every layer.
            targets = set(sample_indices.tolist())
            for key in [k for k in self.cache if k[1] in targets]:
                del self.cache[key]
        else:
            for idx in sample_indices.tolist():
                self.cache.pop((layer_id, idx), None)
class SConv1d(nn.Module):
    """1-D convolution with built-in padding and optional streaming support.

    In non-streaming mode the input is padded so the output covers the whole
    input. In streaming mode (causal only) the trailing ``context_size`` input
    samples of each call are cached per sample and prepended to the next call.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
    ):
        """Build the wrapped ``NormConv1d`` and record the padding geometry."""
        super().__init__()
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.causal = causal
        self.pad_mode = pad_mode
        # Streaming context and non-streaming padding share one formula.
        overlap = (kernel_size - 1) * dilation - (stride - 1)
        self.context_size = overlap
        self.padding_total = overlap
        self._layer_id = None

    @property
    def layer_id(self):
        """Cache key for this layer, derived lazily from ``id(self)``."""
        if self._layer_id is None:
            self._layer_id = f"sconv1d_{id (self )}"
        return self._layer_id

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[QWEN3VoxTokenizerStreamingCache] = None,
        sample_indices: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        debug: bool = False,
    ) -> torch.Tensor:
        """Convolve ``x`` (3-D), streaming when ``use_cache`` and ``cache`` are set.

        Args:
            x: Input tensor with three dimensions.
            cache: Streaming cache; required for streaming mode.
            sample_indices: One cache index per batch element (streaming only).
            use_cache: Enables streaming mode when a cache is supplied.
            debug: Print shape information.
        """
        batch, _, _ = x.shape
        streaming = use_cache and cache is not None
        if not streaming:
            return self._forward_non_streaming(x, debug=debug)
        assert self.causal, "Streaming mode is only supported for causal convolutions"
        assert (
            sample_indices is not None
        ), "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == batch, "sample_indices must match batch size"
        return self._forward_streaming(x, cache, sample_indices, debug)

    def _forward_streaming(
        self,
        x: torch.Tensor,
        cache: QWEN3VoxTokenizerStreamingCache,
        sample_indices: torch.Tensor,
        debug: bool = False,
    ) -> torch.Tensor:
        """Convolve one chunk using (and then updating) the cached context."""
        batch, channels, _ = x.shape
        cached_states = cache.get(self.layer_id, sample_indices)
        if cached_states is None:
            # First chunk: start from an all-zero context (empty when no
            # context is needed).
            needs_context = self.context_size > 0
            width = self.context_size if needs_context else 0
            cached_states = torch.zeros(
                batch, channels, width, device=x.device, dtype=x.dtype
            )
            if debug:
                if needs_context:
                    print(
                        f"[DEBUG] Initialized cache with shape: {cached_states .shape }, context_size={self .context_size }"
                    )
                else:
                    print(f"[DEBUG] No context needed (kernel_size=stride)")
        input_with_context = x
        if cached_states.shape[2] > 0:
            input_with_context = torch.cat([cached_states, x], dim=2)
        if debug:
            print(
                f"[DEBUG] Input shape: {x .shape }, Cache shape: {cached_states .shape }, Combined: {input_with_context .shape }"
            )
        output = self.conv(input_with_context)
        if debug:
            print(f"[DEBUG] Output shape: {output .shape }")
        if self.context_size > 0:
            # Keep the trailing `context_size` samples (or everything when the
            # combined input is shorter than that) for the next chunk.
            keep_from = max(input_with_context.shape[2] - self.context_size, 0)
            new_cache = input_with_context[:, :, keep_from:]
            if debug:
                print(f"[DEBUG] New cache shape: {new_cache .shape }")
            cache.set(self.layer_id, sample_indices, new_cache)
        return output

    def _forward_non_streaming(
        self, x: torch.Tensor, debug: bool = False
    ) -> torch.Tensor:
        """Pad ``x`` (left-only when causal, otherwise split) and convolve."""
        _, _, _ = x.shape  # same 3-D requirement as the streaming path
        padding_total = self.padding_total
        extra_padding = get_extra_padding_for_conv1d(
            x, self.kernel_size, self.stride, padding_total
        )
        if debug:
            print(
                f"[DEBUG NON-STREAMING] Input shape: {x .shape }, padding_total={padding_total }, extra_padding={extra_padding }"
            )
        pad_options = {"mode": self.pad_mode}
        if self.causal:
            left, right = padding_total, extra_padding
            if self.pad_mode == "constant":
                pad_options["value"] = 0
        else:
            half_right = padding_total // 2
            left = padding_total - half_right
            right = half_right + extra_padding
        x = pad1d(x, (left, right), **pad_options)
        if debug:
            print(f"[DEBUG NON-STREAMING] After padding: {x .shape }")
        output = self.conv(x)
        if debug:
            print(f"[DEBUG NON-STREAMING] Output shape: {output .shape }")
        return output
class SConvTranspose1d(nn.Module):
    """1-D transposed convolution with output trimming and streaming support.

    The raw transposed-convolution output is trimmed by
    ``padding_total = kernel_size - stride`` samples. In streaming mode the
    trailing ``context_size = kernel_size - 1`` input samples are cached per
    sample and prepended to the next chunk.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        norm: str = "none",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        bias: bool = True,
    ):
        """Build the wrapped ``NormConvTranspose1d`` and record the geometry.

        Args:
            in_channels: Input channels.
            out_channels: Output channels.
            kernel_size: Kernel size.
            stride: Stride (upsampling factor).
            causal: Trim according to ``trim_right_ratio`` instead of splitting
                the trim between both sides.
            norm: Normalization name (see ``CONV_NORMALIZATIONS``).
            trim_right_ratio: Fraction of the trim taken from the right side;
                must be 1.0 unless ``causal``.
            norm_kwargs: Keyword arguments for the normalization layer.
            bias: Whether the convolution has a bias.
        """
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
            bias=bias,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding_total = kernel_size - stride
        self.context_size = kernel_size - 1
        self._layer_id = None

    @property
    def layer_id(self):
        """Cache key for this layer, derived lazily from ``id(self)``."""
        if self._layer_id is None:
            self._layer_id = f"sconvtr1d_{id (self )}"
        return self._layer_id

    def _trim_paddings(self) -> tp.Tuple[int, int]:
        """Return the ``(left, right)`` amounts to trim from the raw output."""
        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
        else:
            padding_right = self.padding_total // 2
        return self.padding_total - padding_right, padding_right

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[QWEN3VoxTokenizerStreamingCache] = None,
        sample_indices: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        debug: bool = False,
    ) -> torch.Tensor:
        """Upsample ``x`` (3-D), streaming when ``use_cache`` and ``cache`` are set.

        Args:
            x: Input tensor with three dimensions.
            cache: Streaming cache; required for streaming mode.
            sample_indices: One cache index per batch element (streaming only).
            use_cache: Enables streaming mode when a cache is supplied.
            debug: Print shape information.
        """
        B, _, _ = x.shape
        if not use_cache or cache is None:
            return self._forward_non_streaming(x, debug=debug)
        assert (
            sample_indices is not None
        ), "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == B, "sample_indices must match batch size"
        return self._forward_streaming(x, cache, sample_indices, debug)

    def _forward_streaming(
        self,
        x: torch.Tensor,
        cache: QWEN3VoxTokenizerStreamingCache,
        sample_indices: torch.Tensor,
        debug: bool = False,
    ) -> torch.Tensor:
        """Process one chunk: recompute over cached input, keep only new output."""
        B, C, T = x.shape
        cached_input = cache.get(self.layer_id, sample_indices)
        if cached_input is None:
            cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
            if debug:
                print(f"[DEBUG] Initialized empty cache for transposed conv")
        full_input = torch.cat([cached_input, x], dim=2)
        if debug:
            print(
                f"[DEBUG] Input shape: {x .shape }, Cache shape: {cached_input .shape }, Combined: {full_input .shape }"
            )
        full_output = self.convtr(full_input)
        if debug:
            print(f"[DEBUG] Full transposed conv output shape: {full_output .shape }")
        padding_left, padding_right = self._trim_paddings()
        if padding_left + padding_right > 0:
            full_output = unpad1d(full_output, (padding_left, padding_right))
        if debug:
            print(f"[DEBUG] After unpadding: {full_output .shape }")
        if cached_input.shape[2] == 0:
            # No cached input: everything produced is new.
            output = full_output
        else:
            # Keep only the samples that correspond to the new input frames.
            expected_new_output = T * self.stride
            if full_output.shape[2] >= expected_new_output:
                output = full_output[:, :, -expected_new_output:]
            else:
                output = full_output
        if debug:
            print(f"[DEBUG] Final streaming output shape: {output .shape }")
        if self.context_size <= 0:
            # No context is needed. Slicing with `-0:` would keep the entire
            # input and let the cache grow on every call, so store an empty
            # context instead.
            new_cache = full_input[:, :, :0]
        elif full_input.shape[2] > self.context_size:
            new_cache = full_input[:, :, -self.context_size :]
        else:
            new_cache = full_input
        if debug:
            print(f"[DEBUG] New cache shape: {new_cache .shape }")
        cache.set(self.layer_id, sample_indices, new_cache)
        return output

    def _forward_non_streaming(
        self, x: torch.Tensor, debug: bool = False
    ) -> torch.Tensor:
        """Apply the transposed convolution to the whole input and trim it."""
        if debug:
            print(f"[DEBUG NON-STREAMING] Input shape: {x .shape }")
        y = self.convtr(x)
        if debug:
            print(f"[DEBUG NON-STREAMING] After transposed conv: {y .shape }")
        padding_left, padding_right = self._trim_paddings()
        if padding_left + padding_right > 0:
            y = unpad1d(y, (padding_left, padding_right))
        if debug:
            print(f"[DEBUG NON-STREAMING] Final output shape: {y .shape }")
        return y
class FFN(nn.Module):
    """Two-layer feed-forward network with a GELU in between."""

    def __init__(self, embed_dim, ffn_dim, bias=False):
        """Create the two linear layers.

        Args:
            embed_dim: Input and output width.
            ffn_dim: Hidden width.
            bias: Whether both linear layers have a bias.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias)
        self.gelu = ACT2FN["gelu"]
        self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias)

    def forward(self, x):
        """Return ``linear2(gelu(linear1(x)))``."""
        hidden = self.gelu(self.linear1(x))
        return self.linear2(hidden)
class Convlayer(nn.Module):
    """Thin wrapper that owns a single ``SConv1d``."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_mode="zeros",
        norm="weight_norm",
        causal=True,
    ):
        """Create the wrapped ``SConv1d``; every argument is forwarded to it."""
        super().__init__()
        conv_options = dict(
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            pad_mode=pad_mode,
            norm=norm,
            causal=causal,
        )
        self.conv = SConv1d(in_channels, out_channels, kernel_size, **conv_options)

    def forward(self, x):
        """Run the wrapped convolution on ``x``."""
        convolved = self.conv(x)
        return convolved
class _DropPath(nn.Module):
    """Stochastic depth: while training, zero a residual branch per sample.

    Each sample (index along axis 0) is kept with probability
    ``1 - drop_prob`` and rescaled by ``1 / (1 - drop_prob)``; in eval mode
    the input is returned unchanged.
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining axes.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.0:
            mask.div_(keep_prob)
        return x * mask


class Block1D(nn.Module):
    """Residual block: norm -> conv mixer -> residual, then norm -> FFN -> residual.

    Args:
        dim: Channel count of the block input and output.
        kernel_size: Kernel size of the mixer convolution.
        drop_path: Stochastic-depth probability; ``<= 0`` disables it.
        mixer_layer: ``"conv"`` (uses ``groups`` from kwargs, default 1) or
            ``"depthwise_conv"`` (uses ``groups=dim``).
        layer_scale_init_value: Initial value of the per-channel scales
            ``gamma`` / ``ffn_gamma``; ``<= 0`` disables them.
        **kwargs: Optional ``layernorm`` (``"LN"`` or ``"RMSNorm"``), ``eps``,
            ``groups``, ``pad_mode``, ``norm``, ``causal``, ``bias`` and
            ``ffn_expansion``.

    Raises:
        ValueError: If ``layernorm`` or ``mixer_layer`` is not supported.
    """

    def __init__(
        self,
        dim,
        kernel_size=7,
        drop_path=0.0,
        mixer_layer="conv",
        layer_scale_init_value=1e-06,
        **kwargs,
    ):
        super().__init__()
        layernorm = kwargs.get("layernorm", "LN")
        eps = kwargs.get("eps", 1e-06)

        # Reject unknown norms up front; previously they left ``self.norm``
        # unset and only failed later inside ``forward``.
        if layernorm not in ("LN", "RMSNorm"):
            raise ValueError(f"Unsupported layernorm: {layernorm}")
        if mixer_layer == "conv":
            groups = kwargs.get("groups", 1)
        elif mixer_layer == "depthwise_conv":
            groups = dim
        else:
            raise ValueError(f"Unsupported mixer layer: {mixer_layer }")

        norm_cls = ConvLayerNorm if layernorm == "LN" else ConvRMSNorm
        self.norm = norm_cls(dim, eps=eps)
        self.ffn_norm = norm_cls(dim, eps=eps)

        self.mixer = Convlayer(
            dim,
            dim,
            groups=groups,
            kernel_size=kernel_size,
            pad_mode=kwargs.get("pad_mode", "reflect"),
            norm=kwargs.get("norm", "none"),
            causal=kwargs.get("causal", True),
            bias=kwargs.get("bias", True),
        )
        # NOTE: the FFN bias default (False) intentionally differs from the
        # mixer bias default (True); both are kept from the original code.
        self.ffn = FFN(
            dim, kwargs.get("ffn_expansion", 4) * dim, bias=kwargs.get("bias", False)
        )
        self.drop_path = self._build_drop_path(drop_path)

        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
            self.ffn_gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None
            self.ffn_gamma = None

    @staticmethod
    def _build_drop_path(drop_path):
        """Return the stochastic-depth module for ``drop_path``.

        ``torch.nn.modules`` ships no ``DropPath``; the original lookup raised
        ``AttributeError`` for any positive rate. A ``DropPath`` attribute is
        still preferred if something installed one there, otherwise the local
        implementation is used.
        """
        if drop_path <= 0.0:
            return nn.Identity()
        drop_path_cls = getattr(nn.modules, "DropPath", _DropPath)
        return drop_path_cls(drop_path)

    def forward(self, x):
        """Apply the mixer and FFN sub-blocks, each with a residual connection.

        Assumes ``x`` is laid out as (batch, dim, time) - TODO confirm; the
        scales are broadcast via ``unsqueeze(-1)`` and the FFN runs on the
        tensor with its last two axes swapped.
        """
        residual = x
        x = self.norm(x)
        x = self.mixer(x)
        if self.gamma is not None:
            x = x * self.gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)

        residual = x
        x = self.ffn_norm(x)
        x = x.permute(0, 2, 1)
        x = self.ffn(x)
        x = x.permute(0, 2, 1)
        if self.ffn_gamma is not None:
            x = x * self.ffn_gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)
        return x
class TokenizerEncoder(nn.Module):
    """Convolutional encoder built from a stem, strided convs and ``Block1D`` stages.

    Layout: ``downsample_layers[0]`` is the stem conv, followed by one strided
    ``SConv1d`` per ratio that doubles the channel count. ``stages[i]`` runs
    after ``downsample_layers[i]``. A final norm (unless disabled) and the
    ``head`` conv project to ``config.dimension`` channels.

    Required config fields: ``channels``, ``dimension``, ``n_filters``,
    ``ratios``, ``depths`` and ``causal``; all others fall back to defaults.

    Raises:
        ValueError: If ``config.layernorm`` is neither ``"LN"`` nor ``"RMSNorm"``.
    """

    def __init__(self, config):
        super().__init__()
        self.channels = config.channels
        self.dimension = config.dimension
        self.n_filters = config.n_filters
        # The encoder applies the configured ratios in reverse order.
        self.ratios = list(reversed(config.ratios))
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-06)
        layernorm_elementwise_affine = getattr(
            config, "layernorm_elementwise_affine", True
        )
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        if layernorm == "LN":
            norm_type = ConvLayerNorm
        elif layernorm == "RMSNorm":
            norm_type = partial(
                ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine
            )
        else:
            raise ValueError(f"Unsupported norm type: {layernorm }")

        stem = nn.Sequential(
            SConv1d(
                self.channels,
                self.n_filters,
                kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=self.causal,
                pad_mode=pad_mode,
                bias=bias,
            )
        )
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(stem)
        for i, ratio in enumerate(self.ratios):
            # Each downsampling step doubles the channel count.
            downsample_layer = nn.Sequential(
                SConv1d(
                    self.n_filters * 2**i,
                    self.n_filters * 2 ** (i + 1),
                    kernel_size=ratio * 2,
                    stride=ratio,
                    causal=self.causal,
                    pad_mode=pad_mode,
                    norm=norm,
                    bias=bias,
                )
            )
            self.downsample_layers.append(downsample_layer)

        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )
        self.stages = nn.ModuleList()
        # Drop-path rate rises linearly from 0 to ``drop_path_rate`` over all blocks.
        dp_rates = [
            rate.item()
            for rate in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]
        cur = 0
        for i, depth in enumerate(self.depths):
            in_ch = self.n_filters * 2**i
            stage = nn.Sequential(
                *[
                    layer_type(dim=in_ch, drop_path=dp_rates[cur + j])
                    for j in range(depth)
                ]
            )
            self.stages.append(stage)
            cur += depth

        # ``in_ch`` is now the channel count of the last stage.
        if not disable_last_norm:
            self.norm = norm_type(in_ch, eps=layernorm_eps)
        else:
            self.norm = nn.Identity()
        self.head = SConv1d(
            in_ch,
            self.dimension,
            kernel_size=last_kernel_size,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
        )

    @staticmethod
    def _apply_block(block, x, stream_kwargs):
        """Run one stage block, forwarding the streaming cache to its conv.

        Blocks exposing ``mixer.conv`` as an ``SConv1d`` are unrolled by hand so
        the cache arguments reach the convolution; anything else is called
        normally. The unrolled path applies ``block.drop_path`` (when present)
        so it matches the residual arithmetic of ``Block1D.forward``.
        """
        if not (
            hasattr(block, "mixer")
            and hasattr(block.mixer, "conv")
            and isinstance(block.mixer.conv, SConv1d)
        ):
            return block(x)

        drop_path = getattr(block, "drop_path", None)

        residual = x
        x = block.norm(x)
        x = block.mixer.conv(x, **stream_kwargs)
        if block.gamma is not None:
            x = x * block.gamma.unsqueeze(-1)
        if drop_path is not None:
            x = drop_path(x)
        x = residual + x

        residual = x
        x = block.ffn_norm(x)
        x = x.permute(0, 2, 1)
        x = block.ffn(x)
        x = x.permute(0, 2, 1)
        if block.ffn_gamma is not None:
            x = x * block.ffn_gamma.unsqueeze(-1)
        if drop_path is not None:
            x = drop_path(x)
        return residual + x

    def forward_features(
        self, x, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Run every downsampling layer and stage, then the final norm.

        ``cache``, ``sample_indices``, ``use_cache`` and ``debug`` are passed
        through to every ``SConv1d`` encountered.
        """
        stream_kwargs = {
            "cache": cache,
            "sample_indices": sample_indices,
            "use_cache": use_cache,
            "debug": debug,
        }
        for i in range(len(self.depths)):
            for layer in self.downsample_layers[i]:
                if isinstance(layer, SConv1d):
                    x = layer(x, **stream_kwargs)
                else:
                    x = layer(x)
            for block in self.stages[i]:
                x = self._apply_block(block, x, stream_kwargs)
        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Encode ``x``: ``forward_features`` followed by the ``head`` conv."""
        x = self.forward_features(
            x,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        return self.head(
            x,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
class TokenizerDecoder(nn.Module):
    """Convolutional decoder built from a stem, transposed convs and ``Block1D`` stages.

    Layout: ``upsample_layers[0]`` is the stem conv mapping
    ``config.dimension`` channels to the widest stage, followed by one
    ``SConvTranspose1d`` per ratio that halves the channel count.
    ``stages[i]`` runs after ``upsample_layers[i]``. A final norm (unless
    disabled) and the ``head`` conv project to ``config.channels`` channels.

    Required config fields: ``dimension``, ``channels``, ``n_filters``,
    ``ratios``, ``depths`` and ``causal``; all others fall back to defaults.

    Raises:
        ValueError: If ``config.layernorm`` is neither ``"LN"`` nor ``"RMSNorm"``.
    """

    def __init__(self, config):
        super().__init__()
        self.dimension = config.dimension
        self.channels = config.channels
        self.n_filters = config.n_filters
        self.ratios = config.ratios
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-06)
        trim_right_ratio = getattr(config, "trim_right_ratio", 1.0)
        layernorm_elementwise_affine = getattr(
            config, "layernorm_elementwise_affine", True
        )
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        if layernorm == "LN":
            norm_type = ConvLayerNorm
        elif layernorm == "RMSNorm":
            norm_type = partial(
                ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine
            )
        else:
            raise ValueError(f"Unsupported norm type: {layernorm }")

        # Exponent of the widest stage; channels halve at every upsampling step.
        top = len(self.depths) - 1
        stem = nn.Sequential(
            SConv1d(
                self.dimension,
                self.n_filters * 2**top,
                kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=self.causal,
                pad_mode=pad_mode,
                bias=bias,
            )
        )
        self.upsample_layers = nn.ModuleList()
        self.upsample_layers.append(stem)
        for i, ratio in enumerate(self.ratios):
            upsample_layer = nn.Sequential(
                SConvTranspose1d(
                    self.n_filters * 2 ** (top - i),
                    self.n_filters * 2 ** (top - i - 1),
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=norm,
                    norm_kwargs=norm_params,
                    bias=bias,
                    causal=self.causal,
                    trim_right_ratio=trim_right_ratio,
                )
            )
            self.upsample_layers.append(upsample_layer)

        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )
        self.stages = nn.ModuleList()
        # Drop-path rate rises linearly from 0 to ``drop_path_rate`` over all blocks.
        dp_rates = [
            rate.item()
            for rate in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]
        cur = 0
        for i, depth in enumerate(self.depths):
            in_ch = self.n_filters * 2 ** (top - i)
            stage = nn.Sequential(
                *[
                    layer_type(dim=in_ch, drop_path=dp_rates[cur + j])
                    for j in range(depth)
                ]
            )
            self.stages.append(stage)
            cur += depth

        # ``in_ch`` is now the channel count of the last stage.
        if not disable_last_norm:
            self.norm = norm_type(in_ch, eps=layernorm_eps)
        else:
            self.norm = nn.Identity()
        self.head = SConv1d(
            in_ch,
            self.channels,
            kernel_size=last_kernel_size,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
        )

    @staticmethod
    def _apply_block(block, x, stream_kwargs):
        """Run one stage block, forwarding the streaming cache to its conv.

        Blocks exposing ``mixer.conv`` as an ``SConv1d`` are unrolled by hand so
        the cache arguments reach the convolution; anything else is called
        normally. The unrolled path applies ``block.drop_path`` (when present)
        so it matches the residual arithmetic of ``Block1D.forward``.
        """
        if not (
            hasattr(block, "mixer")
            and hasattr(block.mixer, "conv")
            and isinstance(block.mixer.conv, SConv1d)
        ):
            return block(x)

        drop_path = getattr(block, "drop_path", None)

        residual = x
        x = block.norm(x)
        x = block.mixer.conv(x, **stream_kwargs)
        if block.gamma is not None:
            x = x * block.gamma.unsqueeze(-1)
        if drop_path is not None:
            x = drop_path(x)
        x = residual + x

        residual = x
        x = block.ffn_norm(x)
        x = x.permute(0, 2, 1)
        x = block.ffn(x)
        x = x.permute(0, 2, 1)
        if block.ffn_gamma is not None:
            x = x * block.ffn_gamma.unsqueeze(-1)
        if drop_path is not None:
            x = drop_path(x)
        return residual + x

    def forward_features(
        self, x, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Run every upsampling layer and stage, then the final norm.

        ``cache``, ``sample_indices``, ``use_cache`` and ``debug`` are passed
        through to every ``SConv1d`` / ``SConvTranspose1d`` encountered.
        """
        stream_kwargs = {
            "cache": cache,
            "sample_indices": sample_indices,
            "use_cache": use_cache,
            "debug": debug,
        }
        for i in range(len(self.depths)):
            for layer in self.upsample_layers[i]:
                if isinstance(layer, (SConv1d, SConvTranspose1d)):
                    x = layer(x, **stream_kwargs)
                else:
                    x = layer(x)
            for block in self.stages[i]:
                x = self._apply_block(block, x, stream_kwargs)
        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Decode ``x``: ``forward_features`` followed by the ``head`` conv."""
        x = self.forward_features(
            x,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        return self.head(
            x,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
@dataclass
class QWEN3VoxTokenizerEncoderOutput:
    """Encoder result: a latent ``mean`` plus an optional noise scale ``std``."""

    # Latent mean produced by the encoder.
    mean: torch.Tensor
    # Noise scale used by ``sample``; ``None`` when no noise is configured.
    std: Optional[Union[float, torch.Tensor]] = None

    def sample(self, dist_type="fix"):
        """Draw a latent from the distribution selected by ``dist_type``.

        * ``"fix"``: ``mean + std * noise``; returns ``(sample, std)``.
        * ``"gaussian"``: draws one scale per batch element from
          ``N(0, (std / 0.8)^2)``, broadcasts it over the remaining axes and
          returns ``(mean + scale * noise, scale)``.
        * anything else: returns ``(mean, std)`` without sampling.
        """
        if dist_type == "fix":
            noisy = self.mean + self.std * torch.randn_like(self.mean)
            return (noisy, self.std)

        if dist_type == "gaussian":
            n_items = self.mean.size(0)
            per_item_scale = (
                torch.randn(n_items, device=self.mean.device, dtype=self.mean.dtype)
                * (self.std / 0.8)
            )
            # Append singleton axes until the scale broadcasts against ``mean``.
            missing_axes = self.mean.dim() - per_item_scale.dim()
            for _ in range(max(missing_axes, 0)):
                per_item_scale = per_item_scale.unsqueeze(-1)
            noisy = self.mean + per_item_scale * torch.randn_like(self.mean)
            return (noisy, per_item_scale)

        return (self.mean, self.std)

    def kl(self):
        """Return the element-wise squared error between ``mean`` and zero."""
        zeros = torch.zeros_like(self.mean)
        return F.mse_loss(self.mean, zeros, reduction="none")

    def mode(self):
        """Return the latent mean."""
        return self.mean
class QWEN3VoxAcousticTokenizerModel(PreTrainedModel):
    """Acoustic tokenizer: a ``TokenizerEncoder`` / ``TokenizerDecoder`` pair.

    The encoder output is wrapped in a ``QWEN3VoxTokenizerEncoderOutput`` whose
    ``std`` is the non-persistent buffer ``fix_std`` (from ``config.fix_std``).
    """

    config_class = QWEN3VoxAcousticTokenizerConfig
    base_model_prefix = 'vibevoice_acoustic_tokenizer'
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]

    def __init__(self, config):
        super().__init__(config)
        self.register_buffer("fix_std", torch.tensor(config.fix_std), persistent=False)
        self.std_dist_type = getattr(config, "std_dist_type", "fix")

        encoder_depths = self._parse_depths(config.encoder_depths)
        if config.decoder_depths is None:
            # Default: mirror the encoder.
            decoder_depths = list(reversed(encoder_depths))
        else:
            # Honour explicit decoder depths in either form. Previously a
            # list value was silently replaced by the mirrored encoder depths.
            decoder_depths = self._parse_depths(config.decoder_depths)

        encoder_config = self._codec_config(
            config, config.encoder_n_filters, config.encoder_ratios, encoder_depths
        )
        decoder_config = self._codec_config(
            config, config.decoder_n_filters, config.decoder_ratios, decoder_depths
        )
        self.encoder = TokenizerEncoder(encoder_config)
        self.decoder = TokenizerDecoder(decoder_config)
        self.apply(self._init_weights)

    @staticmethod
    def _parse_depths(depths):
        """Turn a ``"3-3-8"`` style string into ``[3, 3, 8]``; pass other values through."""
        if isinstance(depths, str):
            return [int(d) for d in depths.split("-")]
        return depths

    @staticmethod
    def _codec_config(config, n_filters, ratios, depths):
        """Return a deep copy of ``config`` carrying the fields the codec modules read."""
        sub_config = copy.deepcopy(config)
        sub_config.dimension = config.vae_dim
        sub_config.n_filters = n_filters
        sub_config.ratios = ratios
        sub_config.depths = depths
        sub_config.norm = config.conv_norm
        sub_config.pad_mode = config.pad_mode
        sub_config.bias = config.conv_bias
        sub_config.layernorm_eps = config.layernorm_eps
        sub_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        sub_config.mixer_layer = config.mixer_layer
        sub_config.layer_scale_init_value = config.layer_scale_init_value
        sub_config.disable_last_norm = config.disable_last_norm
        return sub_config

    def _init_weights(self, module):
        """Initialise Linear/Conv1d weights from N(0, weight_init_value) and reset LayerNorm."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(
        self, audio, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Encode ``audio`` and return the latents with axes 1 and 2 swapped.

        The streaming arguments are forwarded to the encoder unchanged.
        """
        latents = self.encoder(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        return QWEN3VoxTokenizerEncoderOutput(
            mean=latents.permute(0, 2, 1), std=self.fix_std
        )

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """Sample latents from ``encoder_output``.

        Args:
            encoder_output: Result of ``encode``.
            dist_type: ``"fix"`` or ``"gaussian"``; falls back to
                ``self.std_dist_type`` when falsy.

        Raises:
            ValueError: For any other ``dist_type``.
        """
        dist_type = dist_type or self.std_dist_type
        if dist_type not in ("fix", "gaussian"):
            raise ValueError(
                f"Unsupported dist_type: {dist_type }, expected 'fix' or 'gaussian'"
            )
        return encoder_output.sample(dist_type=dist_type)

    @torch.no_grad()
    def decode(
        self, latents, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Decode ``latents`` back to audio.

        If axis 1 of ``latents`` does not have size ``config.vae_dim`` the
        tensor is assumed to be channel-last and axes 1 and 2 are swapped.
        """
        if latents.shape[1] != self.config.vae_dim:
            latents = latents.permute(0, 2, 1)
        return self.decoder(
            latents,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )

    def forward(
        self, audio, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Encode, sample and decode; returns ``(reconstructed, sampled_latents)``."""
        encoder_output = self.encode(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        sampled_latents, _ = self.sampling(encoder_output)
        reconstructed = self.decode(
            sampled_latents,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        return (reconstructed, sampled_latents)
class QWEN3VoxSemanticTokenizerModel(PreTrainedModel):
    """Semantic tokenizer: a single ``TokenizerEncoder`` with no decoder or noise."""

    config_class = QWEN3VoxSemanticTokenizerConfig
    base_model_prefix = 'vibevoice_semantic_tokenizer'
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder"]

    def __init__(self, config):
        super().__init__(config)
        depths = config.encoder_depths
        if isinstance(depths, str):
            # Depths may be given as a dash-separated string, e.g. "3-3-8".
            depths = [int(part) for part in depths.split("-")]

        # Deep-copy the config and set the fields the encoder reads.
        encoder_config = copy.deepcopy(config)
        overrides = {
            "dimension": config.vae_dim,
            "n_filters": config.encoder_n_filters,
            "ratios": config.encoder_ratios,
            "depths": depths,
            "norm": config.conv_norm,
            "pad_mode": config.pad_mode,
            "bias": config.conv_bias,
            "layernorm_eps": config.layernorm_eps,
            "layernorm_elementwise_affine": config.layernorm_elementwise_affine,
            "mixer_layer": config.mixer_layer,
            "layer_scale_init_value": config.layer_scale_init_value,
            "disable_last_norm": config.disable_last_norm,
        }
        for field_name, value in overrides.items():
            setattr(encoder_config, field_name, value)

        self.encoder = TokenizerEncoder(encoder_config)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Conv1d weights from N(0, weight_init_value) and reset LayerNorm."""
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(
        self, audio, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Encode ``audio`` and return the latents with axes 1 and 2 swapped (no ``std``)."""
        stream_kwargs = {
            "cache": cache,
            "sample_indices": sample_indices,
            "use_cache": use_cache,
            "debug": debug,
        }
        latents = self.encoder(audio, **stream_kwargs)
        return QWEN3VoxTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """Return ``(mean, std)`` unchanged; ``dist_type`` is accepted but ignored."""
        return encoder_output.sample(dist_type="none")

    def forward(
        self, audio, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Encode ``audio``; returns ``(None, latents)`` since there is no decoder."""
        encoder_output = self.encode(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        latents, _ = self.sampling(encoder_output, dist_type="none")
        return (None, latents)
# Register both tokenizer models with AutoModel, keyed by their config classes.
AutoModel.register(QWEN3VoxAcousticTokenizerConfig, QWEN3VoxAcousticTokenizerModel)
AutoModel.register(QWEN3VoxSemanticTokenizerConfig, QWEN3VoxSemanticTokenizerModel)
# Public names of the tokenizer section.
# NOTE(review): `__all__` is assigned again further down in this file, so this
# list is overwritten by the later assignment - confirm that is intended.
__all__ = [
'QWEN3VoxTokenizerStreamingCache',
'QWEN3VoxAcousticTokenizerModel',
'QWEN3VoxSemanticTokenizerModel',
]
# Stray string expression (not in docstring position); it has no runtime effect.
'\nProcessor class for QWEN3Vox ASR models.\n'
import os
import json
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import numpy as np
import torch
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import TensorType, logging
# Module logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
# System message placed at the start of every ASR prompt built by QWEN3VoxASRProcessor.
SYSTEM_PROMPT = "You are a helpful assistant that transcribes audio input into text output in JSON format."
class QWEN3VoxASRProcessor:
    """Builds model inputs (prompt tokens plus padded audio) for QWEN3Vox ASR.

    Args:
        tokenizer: Text tokenizer providing ``apply_chat_template``, ``encode``
            and the speech special tokens.
        audio_processor: Audio feature processor; a
            ``QWEN3VoxTokenizerProcessor`` is created when omitted.
        speech_tok_compress_ratio: Number of audio samples represented by one
            speech placeholder token.
            NOTE(review): the default here (320) differs from the fallback used
            by ``from_pretrained`` (3200) - confirm which one is intended.
        target_sample_rate: Sample rate the audio is expected in.
        normalize_audio: Whether to run ``AudioNormalizer`` on the waveform.
    """

    # Maps keys found in the model's JSON output to the cleaned output keys.
    # Later entries win when several source keys map to the same target.
    _TRANSCRIPT_KEY_MAP = {
        "Start time": "start_time",
        "Start": "start_time",
        "End time": "end_time",
        "End": "end_time",
        "Speaker ID": "speaker_id",
        "Speaker": "speaker_id",
        "Content": "text",
    }

    def __init__(
        self,
        tokenizer=None,
        audio_processor=None,
        speech_tok_compress_ratio=320,
        target_sample_rate=22050,
        normalize_audio=True,
        **kwargs,
    ):
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor or QWEN3VoxTokenizerProcessor(
            sampling_rate=target_sample_rate, normalize_audio=normalize_audio
        )
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.target_sample_rate = target_sample_rate
        self.normalize_audio = normalize_audio
        self.audio_normalizer = AudioNormalizer() if normalize_audio else None
        self._cache_special_tokens()

    def _special_token_id(self, attr_name, token):
        """Return ``tokenizer.<attr_name>`` or, if missing/None, the id of ``token``."""
        token_id = getattr(self.tokenizer, attr_name, None)
        if token_id is None:
            token_id = self.tokenizer.convert_tokens_to_ids(token)
        return token_id

    def _cache_special_tokens(self):
        """Resolve and store the speech start/end/pad ids and the text pad id."""
        self.speech_start_id = self._special_token_id(
            "speech_start_id", "<|speech_start|>"
        )
        self.speech_end_id = self._special_token_id("speech_end_id", "<|speech_end|>")
        self.speech_pad_id = self._special_token_id("speech_pad_id", "<|speech_pad|>")

        # Tokenizers may expose ``pad_token_id`` as None; fall through to
        # "<|endoftext|>" instead of padding with None.
        pad_id = getattr(self.tokenizer, "pad_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        self.pad_id = pad_id

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Create a processor from a local directory or a hub repository.

        ``preprocessor_config.json`` is read from the local path when present,
        otherwise resolved through ``cached_file``. If that resolution fails, a
        warning is logged and defaults are used.
        """
        from transformers.utils import cached_file

        config_path = os.path.join(
            pretrained_model_name_or_path, "preprocessor_config.json"
        )
        config = {}
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            try:
                config_file = cached_file(
                    pretrained_model_name_or_path, "preprocessor_config.json", **kwargs
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    config = json.load(f)
            except Exception as e:
                # Deliberate best effort: a missing config is not fatal.
                logger.warning(f"Could not load preprocessor_config.json: {e }")
                logger.warning("Using default configuration")

        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        target_sample_rate = config.get("target_sample_rate", 22050)
        normalize_audio = config.get("normalize_audio", True)

        model_name = pretrained_model_name_or_path
        logger.info(f"Loading tokenizer from {model_name }")
        tokenizer = QWEN3VoxASRTextTokenizerFast.from_pretrained(
            model_name, **kwargs
        )
        audio_processor = QWEN3VoxTokenizerProcessor(
            sampling_rate=target_sample_rate,
            normalize_audio=normalize_audio,
            target_dB_FS=config.get("target_dB_FS", -25),
            eps=config.get("eps", 1e-06),
        )
        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            target_sample_rate=target_sample_rate,
            normalize_audio=normalize_audio,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Write ``preprocessor_config.json`` into ``save_directory``.

        ``target_dB_FS`` and ``eps`` are taken from the audio processor when it
        exposes them, so values loaded by ``from_pretrained`` survive a
        save/load round trip. The previous defaults are used otherwise.
        """
        os.makedirs(save_directory, exist_ok=True)
        processor_config = {
            "processor_class": "QWEN3VoxASRProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "target_sample_rate": self.target_sample_rate,
            "normalize_audio": self.normalize_audio,
            "target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25),
            "eps": getattr(self.audio_processor, "eps", 1e-06),
        }
        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(processor_config, f, indent=2)
        logger.info(f"Processor configuration saved in {config_path }")

    def __call__(
        self,
        audio: Optional[
            Union[
                str,
                np.ndarray,
                torch.Tensor,
                List[Union[str, np.ndarray, torch.Tensor]],
            ]
        ] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: "Optional[Union[str, TensorType]]" = None,
        padding: bool = True,
        max_length: Optional[int] = None,
        truncation: bool = False,
        add_generation_prompt: bool = True,
        use_streaming: bool = True,
        context_info: Optional[str] = None,
        **kwargs,
    ) -> "BatchEncoding":
        """Turn one audio item, or a list of them, into a ``BatchEncoding``.

        Args:
            audio: File path, array or tensor, or a list of those.
            sampling_rate: Sample rate of array/tensor input; audio is
                resampled to ``target_sample_rate`` when it differs.
            return_tensors: ``"pt"`` for torch tensors; otherwise lists/arrays.
            padding: Left-pad token sequences to a common length.
            max_length: Target length for padding and truncation.
            truncation: Cut token sequences to ``max_length``.
            add_generation_prompt, use_streaming: Accepted and forwarded, but
                currently unused when building the prompt.
            context_info: Optional extra text added to the user prompt.

        Raises:
            ValueError: If ``audio`` is ``None`` or an empty list.
        """
        if audio is None:
            raise ValueError("Audio input is required for ASR processing")
        audio_list = audio if isinstance(audio, list) else [audio]
        if not audio_list:
            raise ValueError("Audio input list is empty")

        all_encodings = [
            self._process_single_audio(
                audio_input,
                sampling_rate=sampling_rate,
                add_generation_prompt=add_generation_prompt,
                use_streaming=use_streaming,
                context_info=context_info,
            )
            for audio_input in audio_list
        ]
        return self._batch_encode(
            all_encodings,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_tensors=return_tensors,
        )

    def _load_audio(self, audio, sampling_rate):
        """Return ``audio`` as a float32 waveform at ``target_sample_rate``."""
        if isinstance(audio, str):
            import soundfile as sf

            audio_array, file_sr = sf.read(audio)
            if audio_array.ndim > 1:
                # Down-mix multi-channel files by averaging the channels.
                audio_array = audio_array.mean(axis=1)
            if file_sr != self.target_sample_rate:
                import librosa

                audio_array = librosa.resample(
                    audio_array, orig_sr=file_sr, target_sr=self.target_sample_rate
                )
            return audio_array.astype(np.float32)

        if isinstance(audio, torch.Tensor):
            # detach/float so tensors with grad or half precision convert cleanly.
            audio_array = audio.detach().cpu().float().numpy()
        else:
            audio_array = np.array(audio, dtype=np.float32)
        if audio_array.ndim > 1:
            audio_array = audio_array.squeeze()
        # Honour the caller-declared rate; previously it was silently ignored
        # and the samples were treated as if already at the target rate.
        if sampling_rate is not None and sampling_rate != self.target_sample_rate:
            audio_array = _resample_if_needed(
                audio_array, sampling_rate, self.target_sample_rate
            )
        return audio_array.astype(np.float32)

    def _process_single_audio(
        self,
        audio: Union[str, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        add_generation_prompt: bool = True,
        use_streaming: bool = True,
        context_info: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Build the prompt tokens and waveform for one audio item.

        Returns:
            A dict with ``input_ids``, ``acoustic_input_mask`` (1 where the
            token is the speech pad id), ``speech`` (float32 waveform) and
            ``vae_tok_len`` (number of speech placeholder tokens).
        """
        # NOTE(review): ``add_generation_prompt`` and ``use_streaming`` do not
        # influence the result; they are kept for interface compatibility.
        audio_array = self._load_audio(audio, sampling_rate)
        if self.normalize_audio and self.audio_normalizer:
            audio_array = self.audio_normalizer(audio_array)

        audio_duration = len(audio_array) / self.target_sample_rate
        vae_tok_len = math.ceil(len(audio_array) / self.speech_tok_compress_ratio)

        system_prompt_text = self.tokenizer.apply_chat_template(
            [{"role": "system", "content": SYSTEM_PROMPT}], tokenize=False
        )
        system_tokens = self.tokenizer.encode(system_prompt_text)

        sp_start_token = self.tokenizer.convert_ids_to_tokens(self.speech_start_id)
        sp_pad_token = self.tokenizer.convert_ids_to_tokens(self.speech_pad_id)
        sp_end_token = self.tokenizer.convert_ids_to_tokens(self.speech_end_id)

        show_keys = ["Start time", "End time", "Speaker ID", "Content"]
        if context_info and context_info.strip():
            user_suffix = (
                f"This is a {audio_duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\nPlease transcribe it with these keys: "
                + ", ".join(show_keys)
            )
        else:
            user_suffix = (
                f"This is a {audio_duration:.2f} seconds audio, please transcribe it with these keys: "
                + ", ".join(show_keys)
            )
        # One pad placeholder per speech token, wrapped in start/end markers.
        user_input_string = (
            "".join([sp_start_token] + [sp_pad_token] * vae_tok_len + [sp_end_token])
            + "\n"
            + user_suffix
        )
        user_tokens = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": user_input_string}], tokenize=True
        )
        full_tokens = system_tokens + user_tokens
        acoustic_input_mask = [
            1 if token == self.speech_pad_id else 0 for token in full_tokens
        ]
        return {
            "input_ids": full_tokens,
            "acoustic_input_mask": acoustic_input_mask,
            "speech": audio_array,
            "vae_tok_len": vae_tok_len,
        }

    def _batch_encode(
        self,
        encodings: List[Dict[str, Any]],
        padding: bool = True,
        max_length: Optional[int] = None,
        truncation: bool = False,
        return_tensors: Optional[str] = None,
    ) -> "BatchEncoding":
        """Pad per-item encodings into one batch.

        Token sequences are left-padded with ``self.pad_id``; waveforms are
        right-padded with zeros. Truncation to ``max_length`` is applied
        whenever requested, independent of ``padding``.
        """
        input_ids_list = [enc["input_ids"] for enc in encodings]
        acoustic_masks_list = [enc["acoustic_input_mask"] for enc in encodings]
        speech_list = [enc["speech"] for enc in encodings]
        vae_tok_lens = [enc["vae_tok_len"] for enc in encodings]

        if truncation and max_length is not None:
            input_ids_list = [ids[:max_length] for ids in input_ids_list]
            acoustic_masks_list = [mask[:max_length] for mask in acoustic_masks_list]

        if padding:
            if max_length is not None:
                target_length = max_length
            else:
                target_length = max(len(ids) for ids in input_ids_list)
            padded_input_ids = []
            padded_acoustic_masks = []
            attention_masks = []
            for input_ids, acoustic_mask in zip(input_ids_list, acoustic_masks_list):
                # Sequences longer than the target are left as they are.
                padding_length = max(target_length - len(input_ids), 0)
                padded_input_ids.append([self.pad_id] * padding_length + input_ids)
                padded_acoustic_masks.append([0] * padding_length + acoustic_mask)
                attention_masks.append([0] * padding_length + [1] * len(input_ids))
            input_ids_list = padded_input_ids
            acoustic_masks_list = padded_acoustic_masks
        else:
            attention_masks = [[1] * len(ids) for ids in input_ids_list]

        max_speech_length = max(len(s) for s in speech_list)
        padded_speeches = np.zeros(
            (len(speech_list), max_speech_length), dtype=np.float32
        )
        speech_masks = np.zeros((len(speech_list), max(vae_tok_lens)), dtype=bool)
        for i, (speech, vae_len) in enumerate(zip(speech_list, vae_tok_lens)):
            padded_speeches[i, : len(speech)] = speech
            speech_masks[i, :vae_len] = True

        batch_encoding = BatchEncoding()
        if return_tensors == "pt":
            batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
            batch_encoding["attention_mask"] = torch.tensor(
                attention_masks, dtype=torch.long
            )
            batch_encoding["acoustic_input_mask"] = torch.tensor(
                acoustic_masks_list, dtype=torch.bool
            )
            batch_encoding["speech_tensors"] = torch.tensor(
                padded_speeches, dtype=torch.float32
            )
            batch_encoding["speech_masks"] = torch.tensor(
                speech_masks, dtype=torch.bool
            )
        else:
            # Without tensor conversion a single item is returned un-batched.
            single = len(input_ids_list) <= 1
            batch_encoding["input_ids"] = input_ids_list[0] if single else input_ids_list
            batch_encoding["attention_mask"] = (
                attention_masks[0] if single else attention_masks
            )
            batch_encoding["acoustic_input_mask"] = (
                acoustic_masks_list[0] if single else acoustic_masks_list
            )
            batch_encoding["speech_tensors"] = (
                padded_speeches[0] if single else padded_speeches
            )
            batch_encoding["speech_masks"] = speech_masks[0] if single else speech_masks
        return batch_encoding

    def batch_decode(self, *args, **kwargs):
        """Forward to ``tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to ``tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @staticmethod
    def _extract_json_payload(text):
        """Parse the JSON value embedded in ``text``.

        Raises:
            json.JSONDecodeError: If no candidate position holds valid JSON.
        """
        fence = "```json"
        if fence in text:
            start = text.find(fence) + len(fence)
            end = text.find("```", start)
            if end == -1:
                # Unterminated fence: take the rest of the text instead of
                # slicing to -1, which dropped the final character.
                end = len(text)
            return json.loads(text[start:end].strip())

        first_list = text.find("[")
        first_obj = text.find("{")
        if first_list == -1 and first_obj == -1:
            return json.loads(text)

        # Preferred start is the first "[" (else the first "{"), as before.
        starts = [first_list if first_list != -1 else first_obj]
        # If an object opens before that "[", the bracket may sit inside the
        # object (e.g. in a string), so the object is tried as a fallback.
        if first_list != -1 and -1 < first_obj < first_list:
            starts.append(first_obj)

        decoder = json.JSONDecoder()
        first_error = None
        for start in starts:
            try:
                # raw_decode is string-aware, unlike manual bracket counting.
                payload, _ = decoder.raw_decode(text, start)
                return payload
            except json.JSONDecodeError as err:
                if first_error is None:
                    first_error = err
        raise first_error

    def post_process_transcription(self, text: str) -> List[Dict[str, Any]]:
        """Extract transcript segments from the model's text output.

        Returns:
            A list of dicts with any of ``start_time``, ``end_time``,
            ``speaker_id`` and ``text``. Items without a recognised key are
            dropped. An empty list is returned when parsing fails.
        """
        try:
            result = self._extract_json_payload(text)
            if isinstance(result, dict):
                result = [result]
            cleaned_result = []
            for item in result:
                if not isinstance(item, dict):
                    continue
                cleaned_item = {
                    mapped_key: item[key]
                    for key, mapped_key in self._TRANSCRIPT_KEY_MAP.items()
                    if key in item
                }
                if cleaned_item:
                    cleaned_result.append(cleaned_item)
            return cleaned_result
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse JSON from transcription: {e }")
            logger.debug(f"Raw text: {text }")
            return []
        except Exception as e:
            logger.warning(f"Error post-processing transcription: {e }")
            return []

    @property
    def model_input_names(self):
        """Names of the entries produced by ``__call__``."""
        return [
            "input_ids",
            "attention_mask",
            "acoustic_input_mask",
            "speech_tensors",
            "speech_masks",
        ]
# Public names of the ASR processor section.
# NOTE(review): this assignment replaces the `__all__` list defined earlier in
# the file rather than extending it - confirm that is intended.
__all__ = [
'QWEN3VoxASRProcessor'
]
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re
import numpy as np
import torch
from transformers.tokenization_utils_base import (
BatchEncoding,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from transformers.utils import TensorType, logging
# Module logger; rebinds the `logger` name already assigned earlier in the file.
logger = logging.get_logger(__name__)
class QWEN3VoxProcessor:
def __init__(
    self,
    tokenizer=None,
    audio_processor=None,
    speech_tok_compress_ratio=3200,
    db_normalize=True,
    **kwargs,
):
    """Store the tokenizer, audio processor and prompt settings.

    Args:
        tokenizer: Text tokenizer used by the processor.
        audio_processor: Audio processor used for voice samples.
        speech_tok_compress_ratio: Stored as given for later use.
        db_normalize: When true, an ``AudioNormalizer`` is created.
        **kwargs: Accepted for compatibility and ignored.
    """
    self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"
    self.tokenizer = tokenizer
    self.audio_processor = audio_processor
    self.speech_tok_compress_ratio = speech_tok_compress_ratio
    self.db_normalize = db_normalize
    if db_normalize:
        self.audio_normalizer = AudioNormalizer()
    else:
        self.audio_normalizer = None
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    """Create a processor from a local directory or a hub repository.

    ``preprocessor_config.json`` is read from the local path when it exists,
    otherwise resolved through ``cached_file``. If that resolution fails a
    warning is logged and a default configuration is used.
    """
    import os
    import json
    from transformers.utils import cached_file

    def _read_json(path):
        with open(path, "r") as handle:
            return json.load(handle)

    local_config_path = os.path.join(
        pretrained_model_name_or_path, "preprocessor_config.json"
    )
    if os.path.exists(local_config_path):
        config = _read_json(local_config_path)
    else:
        try:
            resolved_path = cached_file(
                pretrained_model_name_or_path, "preprocessor_config.json", **kwargs
            )
            config = _read_json(resolved_path)
        except Exception as e:
            logger.warning(
                f"Could not load preprocessor_config.json from {pretrained_model_name_or_path }: {e }"
            )
            logger.warning("Using default configuration")
            config = {"speech_tok_compress_ratio": 3200, "db_normalize": True}

    ratio = config.get("speech_tok_compress_ratio", 3200)
    normalize = config.get("db_normalize", True)

    model_name = pretrained_model_name_or_path
    logger.info(f"Loading tokenizer from {model_name }")
    tokenizer = QWEN3VoxTextTokenizerFast.from_pretrained(model_name, **kwargs)

    if "audio_processor" not in config:
        audio_processor = QWEN3VoxTokenizerProcessor()
    else:
        # Rebuild the audio processor from its saved settings.
        saved = config["audio_processor"]
        audio_processor = QWEN3VoxTokenizerProcessor(
            sampling_rate=saved.get("sampling_rate", 22050),
            normalize_audio=saved.get("normalize_audio", True),
            target_dB_FS=saved.get("target_dB_FS", -25),
            eps=saved.get("eps", 1e-06),
        )

    return cls(
        tokenizer=tokenizer,
        audio_processor=audio_processor,
        speech_tok_compress_ratio=ratio,
        db_normalize=normalize,
    )
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
    """Write ``preprocessor_config.json`` into ``save_directory``.

    The directory is created if needed. Audio-processor settings are read
    from ``self.audio_processor`` with defaults for missing attributes.
    """
    import os
    import json

    os.makedirs(save_directory, exist_ok=True)

    audio_proc = self.audio_processor
    audio_section = {
        "feature_extractor_type": "QWEN3VoxTokenizerProcessor",
        "sampling_rate": getattr(audio_proc, "sampling_rate", 22050),
        "normalize_audio": getattr(audio_proc, "normalize_audio", True),
        "target_dB_FS": getattr(audio_proc, "target_dB_FS", -25),
        "eps": getattr(audio_proc, "eps", 1e-06),
    }
    processor_config = {
        "processor_class": "QWEN3VoxProcessor",
        "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
        "db_normalize": self.db_normalize,
        "audio_processor": audio_section,
    }

    config_path = os.path.join(save_directory, "preprocessor_config.json")
    with open(config_path, "w") as out_file:
        json.dump(processor_config, out_file, indent=2)
    logger.info(f"Processor configuration saved in {config_path }")
def __call__(
self,
text: Optional[
Union[
str,
List[str],
TextInput,
PreTokenizedInput,
List[TextInput],
List[PreTokenizedInput],
]
] = None,
voice_samples: Optional[
Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]
] = None,
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: bool = True,
**kwargs,
) -> BatchEncoding:
if isinstance(text, str) or (
isinstance(text, list) and len(text) > 0 and (not isinstance(text[0], str))
):
texts = [text]
is_batched = False
else:
texts = text
is_batched = True
if voice_samples is not None:
if not is_batched or isinstance(voice_samples[0], (str, np.ndarray)):
voice_samples_list = [voice_samples]
else:
voice_samples_list = voice_samples
else:
voice_samples_list = [None] * len(texts)
all_encodings = []
for text_input, voice_input in zip(texts, voice_samples_list):
encoding = self._process_single(text_input, voice_input)
all_encodings.append(encoding)
batch_encoding = self._batch_encode(
all_encodings,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=return_tensors,
return_attention_mask=return_attention_mask,
)
return batch_encoding
def _process_single(
self,
text: Union[str, TextInput],
voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
) -> Dict[str, Any]:
script = None
if isinstance(text, str):
if text.endswith(".json") and os.path.exists(text):
script = self._convert_json_to_script(text)
elif text.endswith(".txt") and os.path.exists(text):
script = self._convert_text_to_script(text)
else:
script = text
if script is None:
raise ValueError(f"Could not process input text: {text }")
parsed_lines = self._parse_script(script)
all_speakers = list(set((speaker_id for speaker_id, _ in parsed_lines)))
system_tokens = self.tokenizer.encode(self.system_prompt)
if voice_samples:
voice_tokens, voice_speech_inputs, voice_speech_masks = (
self._create_voice_prompt(voice_samples[: len(all_speakers)])
)
else:
voice_tokens, voice_speech_inputs, voice_speech_masks = ([], [], [])
full_tokens = system_tokens + voice_tokens
speech_input_mask = [False] * len(system_tokens) + voice_speech_masks
full_tokens += self.tokenizer.encode(" Text input:\n", add_special_tokens=False)
speech_input_mask += [False] * len(
self.tokenizer.encode(" Text input:\n", add_special_tokens=False)
)
for speaker_id, speaker_text in parsed_lines:
speaker_text_tokens = self.tokenizer.encode(
f" Speaker {speaker_id }:{speaker_text }\n", add_special_tokens=False
)
full_tokens += speaker_text_tokens
speech_input_mask += [False] * len(speaker_text_tokens)
full_tokens += self.tokenizer.encode(
" Speech output:\n", add_special_tokens=False
) + [self.tokenizer.speech_start_id]
speech_input_mask += [False] * (
len(self.tokenizer.encode(" Speech output:\n", add_special_tokens=False))
+ 1
)
return {
"input_ids": full_tokens,
"speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
"speech_input_mask": speech_input_mask,
"parsed_script": parsed_lines,
"all_speakers": all_speakers,
}
def _batch_encode(
    self,
    encodings: List[Dict[str, Any]],
    padding: Union[bool, str, PaddingStrategy] = True,
    truncation: Union[bool, str, TruncationStrategy] = False,
    max_length: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_attention_mask: bool = True,
) -> BatchEncoding:
    """Collate per-sample encodings into one ``BatchEncoding``.

    Sequences are padded on the left with ``self.tokenizer.pad_id``. The
    target length is ``max_length`` only for the ``MAX_LENGTH`` strategy with
    a non-``None`` ``max_length``; otherwise it is the longest sequence.
    Truncation keeps the first ``target length`` tokens and only happens when
    padding is enabled.

    Args:
        encodings: Dicts as produced by ``_process_single``.
        padding: ``bool``, strategy name or ``PaddingStrategy``.
        truncation: Truthy to cut sequences longer than the target length.
        max_length: Target length for the ``MAX_LENGTH`` strategy.
        return_tensors: Any non-``None`` value yields torch tensors.
        return_attention_mask: Whether to include ``attention_mask``.

    Returns:
        ``BatchEncoding`` with ``input_ids``, optional ``attention_mask``,
        ``speech_input_mask``, ``speech_tensors``, ``speech_masks``,
        ``parsed_scripts`` and ``all_speakers_list``.
    """
    token_rows = [item["input_ids"] for item in encodings]
    mask_rows = [item["speech_input_mask"] for item in encodings]

    if isinstance(padding, bool):
        strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
    elif isinstance(padding, str):
        strategy = PaddingStrategy(padding)
    else:
        strategy = padding

    if strategy == PaddingStrategy.DO_NOT_PAD:
        attention_rows = (
            [[1] * len(row) for row in token_rows] if return_attention_mask else None
        )
    else:
        fixed_length = (
            strategy == PaddingStrategy.MAX_LENGTH and max_length is not None
        )
        if fixed_length:
            target_len = max_length
        else:
            target_len = max((len(row) for row in token_rows))

        padded_tokens = []
        attention_rows = []
        padded_masks = []
        for row, mask in zip(token_rows, mask_rows):
            if truncation and len(row) > target_len:
                row = row[:target_len]
                mask = mask[:target_len]
            # A negative count (row longer than target, no truncation) pads nothing.
            n_pad = target_len - len(row)
            padded_tokens.append([self.tokenizer.pad_id] * n_pad + row)
            attention_rows.append([0] * n_pad + [1] * len(row))
            padded_masks.append([False] * n_pad + mask)
        token_rows = padded_tokens
        mask_rows = padded_masks

    # Flatten the voice samples of every sample that carries any.
    speech_groups = [
        item["speech_inputs"] for item in encodings if item["speech_inputs"] is not None
    ]
    found_speech = bool(speech_groups)
    flat_speech = [clip for group in speech_groups for clip in group]

    batch = BatchEncoding()
    with_attention = return_attention_mask and attention_rows is not None
    if return_tensors is not None:
        batch["input_ids"] = torch.tensor(token_rows, dtype=torch.long)
        if with_attention:
            batch["attention_mask"] = torch.tensor(attention_rows, dtype=torch.long)
        batch["speech_input_mask"] = torch.tensor(mask_rows, dtype=torch.bool)
    else:
        batch["input_ids"] = token_rows
        if with_attention:
            batch["attention_mask"] = attention_rows
        batch["speech_input_mask"] = mask_rows

    if found_speech:
        speech_dict = self.prepare_speech_inputs(
            flat_speech, return_tensors=return_tensors
        )
        batch["speech_tensors"] = speech_dict["padded_speeches"]
        batch["speech_masks"] = speech_dict["speech_masks"]
    else:
        batch["speech_tensors"] = None
        batch["speech_masks"] = None

    batch["parsed_scripts"] = [item["parsed_script"] for item in encodings]
    batch["all_speakers_list"] = [item["all_speakers"] for item in encodings]
    return batch
def _create_voice_prompt(
self, speaker_samples: List[Union[str, np.ndarray]]
) -> Tuple[List[int], List[np.ndarray], List[bool]]:
vae_token_id = self.tokenizer.speech_diffusion_id
voice_full_tokens = self.tokenizer.encode(
" Voice input:\n", add_special_tokens=False
)
voice_speech_inputs = []
voice_speech_masks = [False] * len(voice_full_tokens)
for speaker_id, speaker_audio in enumerate(speaker_samples):
prefix_tokens = self.tokenizer.encode(
f" Speaker {speaker_id }:", add_special_tokens=False
)
if isinstance(speaker_audio, str):
wav = self.audio_processor._load_audio_from_path(speaker_audio)
elif isinstance(speaker_audio, dict):
if "array" in speaker_audio:
wav = np.array(speaker_audio["array"], dtype=np.float32)
elif "audio" in speaker_audio:
wav = np.array(speaker_audio["audio"], dtype=np.float32)
else:
raise ValueError(
f"Dictionary audio input must have 'array' or 'audio' key, got: {speaker_audio .keys ()}"
)
else:
wav = np.array(speaker_audio, dtype=np.float32)
if self.db_normalize and self.audio_normalizer:
wav = self.audio_normalizer(wav)
vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
speaker_tokens = (
prefix_tokens
+ [self.tokenizer.speech_start_id]
+ [vae_token_id] * vae_tok_len
+ [self.tokenizer.speech_end_id]
+ self.tokenizer.encode("\n", add_special_tokens=False)
)
vae_input_mask = (
[False] * len(prefix_tokens)
+ [False]
+ [True] * vae_tok_len
+ [False]
+ [False]
)
voice_full_tokens.extend(speaker_tokens)
voice_speech_masks.extend(vae_input_mask)
voice_speech_inputs.append(wav)
return (voice_full_tokens, voice_speech_inputs, voice_speech_masks)
def prepare_speech_inputs(
    self,
    speech_inputs: List[np.ndarray],
    return_tensors: Optional[Union[str, TensorType]] = None,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
) -> Dict[str, Any]:
    """Zero-pad waveforms to a common length and build their frame masks.

    Args:
        speech_inputs: Waveform arrays; the first axis is the sample axis.
        return_tensors: ``"pt"`` converts both outputs to torch tensors.
        device: Device for the torch tensors (only used with ``"pt"``).
        dtype: Dtype of the padded speech tensor; float32 when not given.

    Returns:
        Dict with ``padded_speeches`` (float32, zero padded) and
        ``speech_masks`` (bool, ``True`` for the first
        ``ceil(len / speech_tok_compress_ratio)`` positions of each row).
        Both values are ``None`` when ``speech_inputs`` is empty.
    """
    if not speech_inputs:
        return {"padded_speeches": None, "speech_masks": None}

    ratio = self.speech_tok_compress_ratio
    frame_counts = [math.ceil(clip.shape[0] / ratio) for clip in speech_inputs]

    batch_size = len(speech_inputs)
    longest = max((clip.shape[0] for clip in speech_inputs))
    padded_shape = (batch_size, longest)
    if speech_inputs[0].ndim != 1:
        # Keep the trailing feature axis of the first clip for non-1D input.
        padded_shape = padded_shape + (speech_inputs[0].shape[-1],)

    padded = np.zeros(padded_shape, dtype=np.float32)
    frame_masks = np.zeros((batch_size, max(frame_counts)), dtype=np.bool_)
    for row, (clip, n_frames) in enumerate(zip(speech_inputs, frame_counts)):
        padded[row, : len(clip)] = clip
        frame_masks[row, :n_frames] = True

    if return_tensors == "pt":
        return {
            "padded_speeches": torch.tensor(
                padded, device=device, dtype=dtype or torch.float32
            ),
            "speech_masks": torch.tensor(
                frame_masks, device=device, dtype=torch.bool
            ),
        }
    return {"padded_speeches": padded, "speech_masks": frame_masks}
def _convert_json_to_script(self, json_file: str) -> str:
import json
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError("JSON file must contain a list of speaker entries")
script_lines = []
for item in data:
if not isinstance(item, dict):
logger.warning(f"Skipping non-dict entry: {item }")
continue
speaker = item.get("speaker")
text = item.get("text")
if speaker is None or text is None:
logger.warning(f"Skipping entry missing speaker or text: {item }")
continue
try:
speaker_id = int(speaker)
except (ValueError, TypeError):
logger.warning(f"Invalid speaker ID: {speaker }, skipping entry")
continue
text = text.strip()
if text:
script_lines.append(f"Speaker {speaker_id }: {text }")
if not script_lines:
raise ValueError("No valid entries found in JSON file")
return "\n".join(script_lines)
def _convert_text_to_script(self, text_file: str) -> str:
with open(text_file, "r", encoding="utf-8") as f:
lines = f.readlines()
script_lines = []
current_speaker = 1
for line in lines:
line = line.strip()
if not line:
continue
speaker_match = re.match(
"^Speaker\\s+(\\d+)\\s*:\\s*(.*)$", line, re.IGNORECASE
)
if speaker_match:
speaker_id = int(speaker_match.group(1))
text = speaker_match.group(2).strip()
if text:
script_lines.append(f"Speaker {speaker_id }: {text }")
else:
script_lines.append(f"Speaker {current_speaker }: {line }")
if not script_lines:
raise ValueError("No valid content found in text file")
return "\n".join(script_lines)
def _parse_script(self, script: str) -> List[Tuple[int, str]]:
lines = script.strip().split("\n")
parsed_lines = []
speaker_ids = []
for line in lines:
if not line.strip():
continue
match = re.match(
"^Speaker\\s+(\\d+)\\s*:\\s*(.*)$", line.strip(), re.IGNORECASE
)
if match:
speaker_id = int(match.group(1))
text = " " + match.group(2).strip()
parsed_lines.append((speaker_id, text))
speaker_ids.append(speaker_id)
elif line.strip():
# Vocence validators send plain transcription (no "Speaker N:" prefix).
parsed_lines.append((1, " " + line.strip()))
speaker_ids.append(1)
if not parsed_lines:
raise ValueError("No valid speaker lines found in script")
min_speaker_id = min(speaker_ids)
if min_speaker_id > 0:
normalized_lines = []
for speaker_id, text in parsed_lines:
normalized_lines.append((speaker_id - 1, text))
return normalized_lines
else:
return parsed_lines
def _merge_inputs(
    self, text_inputs: BatchEncoding, audio_inputs: Dict
) -> BatchEncoding:
    """Copy the text encoding and attach the audio fields that are present.

    Args:
        text_inputs: Encoding used to seed the result.
        audio_inputs: Mapping that may contain ``"audio"`` and ``"streaming"``.

    Returns:
        A new ``BatchEncoding`` where ``audio_inputs["audio"]`` is stored as
        ``"speech_inputs"`` and ``"streaming"`` is copied under its own key.
    """
    merged = BatchEncoding(text_inputs)
    # (key in audio_inputs, key in the merged encoding)
    key_map = (("audio", "speech_inputs"), ("streaming", "streaming"))
    for source_key, target_key in key_map:
        if source_key in audio_inputs:
            merged[target_key] = audio_inputs[source_key]
    return merged
def batch_decode(self, *args, **kwargs):
    """Forward all arguments to ``self.tokenizer.batch_decode`` and return its result."""
    decoded = self.tokenizer.batch_decode(*args, **kwargs)
    return decoded
def decode(self, *args, **kwargs):
    """Forward all arguments to ``self.tokenizer.decode`` and return its result."""
    decoded = self.tokenizer.decode(*args, **kwargs)
    return decoded
@property
def model_input_names(self):
    """Input names of the tokenizer and audio processor plus the speech keys.

    Returns:
        A list without duplicates, in first-seen order: tokenizer names,
        audio-processor names, then ``"speech_inputs"`` and
        ``"speech_input_mask"``.
    """
    names = self.tokenizer.model_input_names + self.audio_processor.model_input_names
    names = names + ["speech_inputs", "speech_input_mask"]
    # A dict keeps insertion order, which de-duplicates without reordering.
    seen = {}
    for name in names:
        seen.setdefault(name, None)
    return list(seen)
def save_audio(
    self,
    audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
    output_path: str = "output.wav",
    sampling_rate: Optional[int] = None,
    normalize: bool = False,
    batch_prefix: str = "audio_",
) -> str:
    """Delegate saving to ``self.audio_processor.save_audio``.

    Args:
        audio: Audio passed positionally to the audio processor.
        output_path: Passed through as ``output_path``.
        sampling_rate: Passed through as ``sampling_rate``.
        normalize: Passed through as ``normalize``.
        batch_prefix: Passed through as ``batch_prefix``.

    Returns:
        Whatever the audio processor's ``save_audio`` returns.
    """
    options = {
        "output_path": output_path,
        "sampling_rate": sampling_rate,
        "normalize": normalize,
        "batch_prefix": batch_prefix,
    }
    return self.audio_processor.save_audio(audio, **options)
# Export list for the multi-speaker processor. `__all__` is assigned again
# further down this file, and each later assignment replaces this value.
__all__ = [
    'QWEN3VoxProcessor'
]
'\nQWEN3Vox Streaming Processor\n\nThis processor handles input preparation for the streaming 0.5B model,\nincluding text tokenization and cached voice prompt handling.\n'
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re
import numpy as np
import torch
from transformers.tokenization_utils_base import (
BatchEncoding,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from transformers.utils import TensorType, logging
logger = logging.get_logger(__name__)
class QWEN3VoxStreamingProcessor:
    """Prepare inputs for the streaming model from text and a cached voice prompt.

    Attributes:
        tokenizer: Text tokenizer; must expose ``encode`` and ``pad_id``.
        audio_processor: Audio helper used for ``save_audio`` and input names.
        speech_tok_compress_ratio: Audio samples per speech token.
        db_normalize: Whether an ``AudioNormalizer`` is created.
        audio_normalizer: ``AudioNormalizer`` instance or ``None``.
    """

    def __init__(
        self,
        tokenizer=None,
        audio_processor=None,
        speech_tok_compress_ratio=3200,
        db_normalize=True,
        **kwargs,
    ):
        """Store the collaborators; extra ``kwargs`` are accepted and ignored."""
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.db_normalize = db_normalize
        self.audio_normalizer = AudioNormalizer() if db_normalize else None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Create a processor from a local directory or a hub identifier.

        ``preprocessor_config.json`` is read from the local path when present,
        otherwise resolved with ``transformers.utils.cached_file``. On any
        failure of that fallback a warning is logged and defaults are used.

        Args:
            pretrained_model_name_or_path: Directory or identifier.
            **kwargs: Forwarded to ``cached_file`` and the tokenizer loader.

        Returns:
            A new ``cls`` instance.
        """
        import os
        import json
        from transformers.utils import cached_file

        config_path = os.path.join(
            pretrained_model_name_or_path, "preprocessor_config.json"
        )
        if os.path.exists(config_path):
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            try:
                config_file = cached_file(
                    pretrained_model_name_or_path,
                    "preprocessor_config.json",
                    **kwargs,
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    config = json.load(f)
            except Exception as e:
                # Best effort: fall back to defaults instead of failing.
                logger.warning(
                    f"Could not load preprocessor_config.json from {pretrained_model_name_or_path }: {e }"
                )
                logger.warning("Using default configuration")
                config = {"speech_tok_compress_ratio": 3200, "db_normalize": True}

        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        db_normalize = config.get("db_normalize", True)
        model_name = pretrained_model_name_or_path
        logger.info(f"Loading tokenizer from {model_name }")
        tokenizer = QWEN3VoxTextTokenizerFast.from_pretrained(model_name, **kwargs)

        if "audio_processor" in config:
            audio_config = config["audio_processor"]
            audio_processor = QWEN3VoxTokenizerProcessor(
                sampling_rate=audio_config.get("sampling_rate", 22050),
                normalize_audio=audio_config.get("normalize_audio", True),
                target_dB_FS=audio_config.get("target_dB_FS", -25),
                eps=audio_config.get("eps", 1e-06),
            )
        else:
            audio_processor = QWEN3VoxTokenizerProcessor()

        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            db_normalize=db_normalize,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Write the processor settings to ``preprocessor_config.json``.

        Args:
            save_directory: Target directory; created when missing.
            **kwargs: Accepted but not used by this method.
        """
        import os
        import json

        os.makedirs(save_directory, exist_ok=True)
        processor_config = {
            "processor_class": "QWEN3VoxStreamingProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "db_normalize": self.db_normalize,
            "audio_processor": {
                "feature_extractor_type": "QWEN3VoxTokenizerProcessor",
                "sampling_rate": getattr(self.audio_processor, "sampling_rate", 22050),
                "normalize_audio": getattr(
                    self.audio_processor, "normalize_audio", True
                ),
                "target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25),
                "eps": getattr(self.audio_processor, "eps", 1e-06),
            },
        }
        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(processor_config, f, indent=2)
        logger.info(f"Processor configuration saved in {config_path }")

    def __call__(self) -> BatchEncoding:
        """Always raise; streaming inputs go through ``process_input_with_cached_prompt``."""
        raise NotImplementedError(
            'QWEN3VoxStreamingProcessor.__call__ is not implemented. Use process_input_with_cached_prompt for streaming inputs.'
        )

    def process_input_with_cached_prompt(
        self,
        text: Optional[str] = None,
        cached_prompt: Optional[Dict[str, Any]] = None,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """Encode one text against a cached voice prompt.

        The ``input_ids`` and ``tts_lm_input_ids`` are pad-id placeholders
        whose lengths equal ``size(1)`` of the cached ``last_hidden_state``
        under ``cached_prompt["lm"]`` and ``cached_prompt["tts_lm"]``.

        Args:
            text: Text to synthesize; it is stripped and a newline appended.
            cached_prompt: Mapping with ``"lm"`` and ``"tts_lm"`` entries, each
                holding a ``"last_hidden_state"`` that supports ``.size(1)``.
            padding: Forwarded to ``_batch_encode``.
            truncation: Forwarded to ``_batch_encode``.
            max_length: Forwarded to ``_batch_encode``.
            return_tensors: Forwarded to ``_batch_encode``.
            return_attention_mask: Forwarded to ``_batch_encode``.
            **kwargs: Accepted but not used by this method.

        Returns:
            The ``BatchEncoding`` for a batch of one.

        Raises:
            ValueError: If ``text`` or ``cached_prompt`` is ``None``.
        """
        # Both parameters default to None but are required; fail clearly
        # instead of with an AttributeError/TypeError further down.
        if text is None:
            raise ValueError("`text` must be provided for streaming input processing.")
        if cached_prompt is None:
            raise ValueError(
                "`cached_prompt` must be provided for streaming input processing."
            )

        script_tokens = self.tokenizer.encode(
            text.strip() + "\n", add_special_tokens=False
        )
        input_id_length = cached_prompt["lm"]["last_hidden_state"].size(1)
        tts_lm_input_id_length = cached_prompt["tts_lm"]["last_hidden_state"].size(1)

        encoding = {
            "input_ids": [self.tokenizer.pad_id] * input_id_length,
            "tts_lm_input_ids": [self.tokenizer.pad_id] * tts_lm_input_id_length,
            "tts_text_ids": script_tokens,
            "speech_inputs": None,
            "speech_input_mask": [False] * tts_lm_input_id_length,
        }
        return self._batch_encode(
            [encoding],
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=return_attention_mask,
        )

    def _batch_encode(
        self,
        encodings: List[Dict[str, Any]],
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
    ) -> BatchEncoding:
        """Collate streaming encodings into one ``BatchEncoding``.

        No padding or truncation is applied here: ``padding``, ``truncation``
        and ``max_length`` are accepted but not read.

        Args:
            encodings: Dicts with ``input_ids``, ``tts_lm_input_ids``,
                ``tts_text_ids``, ``speech_inputs`` and ``speech_input_mask``.
            padding: Unused.
            truncation: Unused.
            max_length: Unused.
            return_tensors: Any non-``None`` value yields torch tensors.
            return_attention_mask: Whether to add all-ones attention masks.

        Returns:
            The collated ``BatchEncoding``.
        """
        input_ids_list = [enc["input_ids"] for enc in encodings]
        tts_lm_input_ids_list = [enc["tts_lm_input_ids"] for enc in encodings]
        tts_text_ids_list = [enc["tts_text_ids"] for enc in encodings]
        speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]

        attention_masks = (
            [[1] * len(ids) for ids in input_ids_list]
            if return_attention_mask
            else None
        )
        tts_lm_attention_masks = (
            [[1] * len(ids) for ids in tts_lm_input_ids_list]
            if return_attention_mask
            else None
        )

        all_speech_inputs = []
        has_speech = False
        for enc in encodings:
            if enc["speech_inputs"] is not None:
                all_speech_inputs.extend(enc["speech_inputs"])
                has_speech = True

        batch_encoding = BatchEncoding()
        if return_tensors is not None:
            batch_encoding["input_ids"] = torch.tensor(
                input_ids_list, dtype=torch.long
            )
            batch_encoding["tts_lm_input_ids"] = torch.tensor(
                tts_lm_input_ids_list, dtype=torch.long
            )
            batch_encoding["tts_text_ids"] = torch.tensor(
                tts_text_ids_list, dtype=torch.long
            )
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = torch.tensor(
                    attention_masks, dtype=torch.long
                )
                batch_encoding["tts_lm_attention_mask"] = torch.tensor(
                    tts_lm_attention_masks, dtype=torch.long
                )
            batch_encoding["speech_input_mask"] = torch.tensor(
                speech_input_masks_list, dtype=torch.bool
            )
        else:
            batch_encoding["input_ids"] = input_ids_list
            batch_encoding["tts_lm_input_ids"] = tts_lm_input_ids_list
            batch_encoding["tts_text_ids"] = tts_text_ids_list
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = attention_masks
                batch_encoding["tts_lm_attention_mask"] = tts_lm_attention_masks
            batch_encoding["speech_input_mask"] = speech_input_masks_list

        if has_speech:
            speech_dict = self.prepare_speech_inputs(
                all_speech_inputs, return_tensors=return_tensors
            )
            batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
            batch_encoding["speech_masks"] = speech_dict["speech_masks"]
        else:
            batch_encoding["speech_tensors"] = None
            batch_encoding["speech_masks"] = None
        return batch_encoding

    def prepare_speech_inputs(
        self,
        speech_inputs: List[np.ndarray],
        return_tensors: Optional[Union[str, TensorType]] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Zero-pad waveforms to a common length and build their frame masks.

        Args:
            speech_inputs: Waveform arrays; the first axis is the sample axis.
            return_tensors: ``"pt"`` converts both outputs to torch tensors.
            device: Device for the torch tensors (only used with ``"pt"``).
            dtype: Dtype of the padded speech tensor; float32 when not given.

        Returns:
            Dict with ``padded_speeches`` and ``speech_masks``; both ``None``
            when ``speech_inputs`` is empty.
        """
        if not speech_inputs:
            return {"padded_speeches": None, "speech_masks": None}

        vae_tok_seqlens = [
            math.ceil(s.shape[0] / self.speech_tok_compress_ratio)
            for s in speech_inputs
        ]
        max_speech_length = max((s.shape[0] for s in speech_inputs))
        if speech_inputs[0].ndim == 1:
            padded_speeches = np.full(
                (len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32
            )
        else:
            padded_speeches = np.full(
                (len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]),
                fill_value=0,
                dtype=np.float32,
            )
        speech_masks = np.zeros(
            (len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_
        )
        for i, (speech, vae_tok_length) in enumerate(
            zip(speech_inputs, vae_tok_seqlens)
        ):
            padded_speeches[i, : len(speech)] = speech
            speech_masks[i, :vae_tok_length] = True

        result = {"padded_speeches": padded_speeches, "speech_masks": speech_masks}
        if return_tensors == "pt":
            result["padded_speeches"] = torch.tensor(
                padded_speeches, device=device, dtype=dtype or torch.float32
            )
            result["speech_masks"] = torch.tensor(
                speech_masks, device=device, dtype=torch.bool
            )
        return result

    def batch_decode(self, *args, **kwargs):
        """Forward all arguments to ``self.tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward all arguments to ``self.tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """De-duplicated tokenizer and audio-processor input names plus speech keys."""
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(
            dict.fromkeys(
                tokenizer_input_names
                + audio_processor_input_names
                + ["speech_inputs", "speech_input_mask"]
            )
        )

    def save_audio(
        self,
        audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
        output_path: str = "output.wav",
        sampling_rate: Optional[int] = None,
        normalize: bool = False,
        batch_prefix: str = "audio_",
    ) -> str:
        """Delegate to ``self.audio_processor.save_audio`` and return its result."""
        return self.audio_processor.save_audio(
            audio,
            output_path=output_path,
            sampling_rate=sampling_rate,
            normalize=normalize,
            batch_prefix=batch_prefix,
        )
# Export list for the streaming processor. This rebinds `__all__`, and a later
# assignment further down this file rebinds it again.
__all__ = [
    'QWEN3VoxStreamingProcessor'
]
'\nQWEN3Vox Streaming Model Architecture (0.5B)\n\nThis module implements the streaming-optimized version of QWEN3Vox for real-time TTS.\nKey differences from the multi-speaker model:\n- No semantic tokenizer (only acoustic)\n- Split language model architecture: lower layers for text, upper layers for TTS\n- Optimized for low-latency generation\n'
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
CausalLMOutput,
BaseModelOutputWithPast,
ModelOutput,
)
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
logger = logging.get_logger(__name__)
# Backfill `modeling_utils.ALL_PARALLEL_STYLES` when the attribute is absent
# or set to None.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
class BinaryClassifier(nn.Module):
    """Two-layer head: ``hidden_size -> hidden_size -> 1`` with a ReLU between."""

    def __init__(self, hidden_size):
        """Create the two linear layers.

        Args:
            hidden_size: Input width, also the width of the hidden layer.
        """
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return one raw score per input vector (no sigmoid is applied)."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
class SpeechConnector(nn.Module):
    """Project speech features: linear, RMS norm, then a second linear layer."""

    def __init__(self, input_dim, output_dim):
        """Create the projection layers.

        Args:
            input_dim: Width of the incoming features.
            output_dim: Width of the projected features.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-06)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        """Apply ``fc1``, ``norm`` and ``fc2`` in order; ``kwargs`` are ignored."""
        projected = self.fc1(features)
        normalized = self.norm(projected)
        return self.fc2(normalized)
class QWEN3VoxStreamingPreTrainedModel(PreTrainedModel):
    """Base class holding the shared configuration flags and weight init."""

    config_class = QWEN3VoxStreamingConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize one submodule.

        A ``QWEN3VoxDiffusionHead`` initializes itself. ``nn.Linear`` weights
        are drawn from a normal distribution and biases zeroed;
        ``nn.LayerNorm`` weights are set to one and biases to zero. The
        standard deviation comes from ``language_model_config`` or
        ``decoder_config`` (``initializer_range``), else ``0.02``.

        Args:
            module: The submodule to initialize.
        """
        if isinstance(module, QWEN3VoxDiffusionHead):
            module.initialize_weights()
            return

        if hasattr(self.config, "language_model_config") and hasattr(
            self.config.language_model_config, "initializer_range"
        ):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, "decoder_config") and hasattr(
            self.config.decoder_config, "initializer_range"
        ):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None for elementwise_affine=False (or bias=False).
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
class QWEN3VoxStreamingModel(QWEN3VoxStreamingPreTrainedModel):
    """Streaming model: split language model, acoustic tokenizer and diffusion head."""

    def __init__(self, config):
        """Build all submodules from ``config``.

        The decoder layers are split in two: ``language_model`` gets the
        decoder's layer count minus ``config.tts_backbone_num_hidden_layers``
        and has its final norm replaced by an identity, while
        ``tts_language_model`` gets ``config.tts_backbone_num_hidden_layers``.

        Args:
            config: Model configuration with ``decoder_config``,
                ``acoustic_tokenizer_config``, ``diffusion_head_config``,
                ``acoustic_vae_dim`` and ``tts_backbone_num_hidden_layers``.
        """
        super().__init__(config)
        if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Work on a copy so the layer count on the shared config is untouched.
        lm_config = copy.deepcopy(config.decoder_config)
        lm_backbone_num_hidden_layers = (
            getattr(lm_config, "num_hidden_layers", 24)
            - config.tts_backbone_num_hidden_layers
        )
        lm_config.num_hidden_layers = lm_backbone_num_hidden_layers
        self.language_model = AutoModel.from_config(lm_config)
        self.language_model.norm = nn.Identity()

        tts_lm_config = copy.deepcopy(lm_config)
        tts_lm_config.num_hidden_layers = config.tts_backbone_num_hidden_layers
        self.tts_language_model = AutoModel.from_config(tts_lm_config)
        self.tts_input_types = nn.Embedding(
            num_embeddings=2, embedding_dim=config.decoder_config.hidden_size
        )

        self.acoustic_tokenizer = AutoModel.from_config(
            config.acoustic_tokenizer_config
        ).to(dtype)
        self.acoustic_connector = SpeechConnector(
            config.acoustic_vae_dim, lm_config.hidden_size
        ).to(dtype)

        # Registered as NaN placeholders.
        self.register_buffer("speech_scaling_factor", torch.tensor(float("nan")))
        self.register_buffer("speech_bias_factor", torch.tensor(float("nan")))

        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(
            dtype
        )
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type,
        )

    def get_input_embeddings(self):
        """Return the token embedding module of ``language_model``.

        Falls back to scanning ``language_model.fullmap`` for the entry whose
        ``orig_name`` is ``"embed_tokens.weight"``.

        Raises:
            AssertionError: If no embedding can be located.
        """
        if hasattr(self.language_model, "embed_tokens"):
            return self.language_model.embed_tokens
        for name, attr in self.language_model.fullmap.items():
            if attr.orig_name == "embed_tokens.weight":
                return getattr(self.language_model, name)
        # Explicit raise: a bare `assert False` is removed under `python -O`
        # and the method would then silently return None.
        raise AssertionError("should not arrive here")

    def set_input_embeddings(self, value):
        """Replace ``language_model.embed_tokens`` with ``value``."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None):
        """Install the acoustic tokenizer and switch it out of training mode."""
        self.acoustic_tokenizer = acoustic_tokenizer
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.train(False)

    def forward(self, *args, **kwargs):
        """Always raise; callers must use the two sub-models directly."""
        raise RuntimeError(
            'QWEN3VoxStreamingModel.forward is intentionally disabled. Use `model.language_model(...)` or `model.tts_language_model(...)` instead.'
        )
# Associate QWEN3VoxStreamingConfig with QWEN3VoxStreamingModel in AutoModel.
AutoModel.register(QWEN3VoxStreamingConfig, QWEN3VoxStreamingModel)
# Export list for the streaming model section; this rebinds `__all__`.
__all__ = [
    'QWEN3VoxStreamingPreTrainedModel',
    'QWEN3VoxStreamingModel',
    "BinaryClassifier",
    "SpeechConnector",
]
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
CausalLMOutput,
BaseModelOutputWithPast,
ModelOutput,
)
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
logger = logging.get_logger(__name__)
# Backfill `modeling_utils.ALL_PARALLEL_STYLES` when the attribute is absent
# or set to None.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
@dataclass
class QWEN3VoxCausalLMOutputWithPast(ModelOutput):
    """Causal-LM style output extended with diffusion-loss fields.

    Every field defaults to ``None``.
    """

    # NOTE(review): field meanings below are inferred from the names; confirm
    # against the forward pass that populates this output.
    # Presumably the language-modeling loss.
    loss: Optional[torch.FloatTensor] = None
    # Presumably the loss of the diffusion head.
    diffusion_loss: Optional[torch.FloatTensor] = None
    # Presumably the number of speech tokens the diffusion loss covers.
    speech_token_num: Optional[int] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class QWEN3VoxGenerationOutput(ModelOutput):
    """Generation result: token sequences plus optional speech outputs.

    Both fields default to ``None``.
    """

    sequences: Optional[torch.LongTensor] = None
    # NOTE(review): presumably one tensor per generated sample — confirm
    # against the generate() implementation.
    speech_outputs: Optional[List[torch.FloatTensor]] = None
class SpeechConnector(nn.Module):
    """Project speech features: linear, RMS norm, then a second linear layer."""

    def __init__(self, input_dim, output_dim):
        """Create the projection layers.

        Args:
            input_dim: Width of the incoming features.
            output_dim: Width of the projected features.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-06)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        """Apply ``fc1``, ``norm`` and ``fc2`` in order; ``kwargs`` are ignored."""
        projected = self.fc1(features)
        normalized = self.norm(projected)
        return self.fc2(normalized)
class QWEN3VoxPreTrainedModel(PreTrainedModel):
    """Base class holding the shared configuration flags and weight init."""

    config_class = QWEN3VoxConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize one submodule.

        A ``QWEN3VoxDiffusionHead`` initializes itself. ``nn.Linear`` weights
        are drawn from a normal distribution and biases zeroed;
        ``nn.LayerNorm`` weights are set to one and biases to zero. The
        standard deviation comes from ``language_model_config`` or
        ``decoder_config`` (``initializer_range``), else ``0.02``.

        Args:
            module: The submodule to initialize.
        """
        if isinstance(module, QWEN3VoxDiffusionHead):
            module.initialize_weights()
            return

        if hasattr(self.config, "language_model_config") and hasattr(
            self.config.language_model_config, "initializer_range"
        ):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, "decoder_config") and hasattr(
            self.config.decoder_config, "initializer_range"
        ):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None for elementwise_affine=False (or bias=False).
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
class QWEN3VoxModel(QWEN3VoxPreTrainedModel):
    """Multi-speaker model: language model, two speech tokenizers and diffusion head."""

    def __init__(self, config):
        """Build all submodules from ``config``.

        Args:
            config: Model configuration with ``decoder_config``,
                ``acoustic_tokenizer_config``, ``semantic_tokenizer_config``,
                ``diffusion_head_config``, ``acoustic_vae_dim`` and
                ``semantic_vae_dim``.
        """
        super().__init__(config)
        if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)

        self.acoustic_tokenizer = AutoModel.from_config(
            config.acoustic_tokenizer_config
        ).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(
            config.semantic_tokenizer_config
        ).to(dtype)
        self.acoustic_connector = SpeechConnector(
            config.acoustic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        self.semantic_connector = SpeechConnector(
            config.semantic_vae_dim, lm_config.hidden_size
        ).to(dtype)

        # Registered as NaN placeholders.
        self.register_buffer("speech_scaling_factor", torch.tensor(float("nan")))
        self.register_buffer("speech_bias_factor", torch.tensor(float("nan")))

        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(
            dtype
        )
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type,
        )

    def get_input_embeddings(self):
        """Return the token embedding module of ``language_model``.

        Falls back to scanning ``language_model.fullmap`` for the entry whose
        ``orig_name`` is ``"embed_tokens.weight"``.

        Raises:
            AssertionError: If no embedding can be located.
        """
        if hasattr(self.language_model, "embed_tokens"):
            return self.language_model.embed_tokens
        for name, attr in self.language_model.fullmap.items():
            if attr.orig_name == "embed_tokens.weight":
                return getattr(self.language_model, name)
        # Explicit raise: a bare `assert False` is removed under `python -O`
        # and the method would then silently return None.
        raise AssertionError("should not arrive here")

    def set_input_embeddings(self, value):
        """Replace ``language_model.embed_tokens`` with ``value``."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Install both speech tokenizers and switch them out of training mode."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.train(False)
        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.train(False)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run ``language_model`` and repackage its output.

        All arguments are forwarded unchanged to ``self.language_model``.
        ``return_dict`` falls back to ``self.config.use_return_dict`` when it
        is ``None``.

        Returns:
            The raw language-model output when ``return_dict`` is falsy,
            otherwise a ``BaseModelOutputWithPast`` built from it.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        if not return_dict:
            return outputs
        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class QWEN3VoxForConditionalGeneration(QWEN3VoxPreTrainedModel):
    """Training wrapper: backbone, LM head and diffusion loss on speech latents.

    ``forward`` splices speech-connector features into the token embeddings,
    runs the backbone, produces vocabulary logits and computes a diffusion
    loss for the positions selected by ``acoustic_loss_mask``. No language
    modelling loss is computed here (``loss`` is always ``None``).
    """

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)
        self.model = QWEN3VoxModel(config)
        self.vocab_size = config.decoder_config.vocab_size
        self.lm_head = nn.Linear(
            config.decoder_config.hidden_size, self.vocab_size, bias=False
        )
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's input embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the backbone's input embedding module."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_decoder(self, decoder):
        """Replace the backbone's language model."""
        self.model.language_model = decoder

    def get_decoder(self):
        """Return the backbone's language model."""
        return self.model.language_model

    def tie_weights(self):
        """Share the input embedding weight with the LM head when configured.

        Tying happens only if ``config.decoder_config.tie_word_embeddings`` is
        truthy. The outcome is reported on stdout either way.
        """
        if getattr(self.config.decoder_config, "tie_word_embeddings", False):
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            if hasattr(input_embeddings, "weight"):
                output_embeddings.weight = input_embeddings.weight
            else:
                # get_input_embeddings() returned something without a
                # `.weight`; it is assigned as the weight directly.
                output_embeddings.weight = input_embeddings
            if getattr(output_embeddings, "bias", None) is not None:
                # Zero-pad the bias up to the (possibly larger) tied vocab size.
                output_embeddings.bias.data = nn.functional.pad(
                    output_embeddings.bias.data,
                    (
                        0,
                        output_embeddings.weight.shape[0]
                        - output_embeddings.bias.shape[0],
                    ),
                    "constant",
                    0,
                )
            print("Tied input and output embeddings using standard assignment.")
        else:
            print("tie_word_embeddings is False, not tying weights.")

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def forward_speech_features(
        self,
        speech_tensors=None,
        speech_masks=None,
        speech_type="audio",
        return_unmask=False,
    ):
        """Turn speech input into normalized latents and connector features.

        Args:
            speech_tensors: Speech input, or ``None`` for a text-only batch.
                With ``speech_type="audio"`` it is passed through the acoustic
                tokenizer's encoder; with ``"vae"`` it is treated as latents
                and reshaped to ``(batch, -1, vae_dim)``.
            speech_masks: Boolean index selecting the valid latent frames.
            speech_type: ``"audio"`` or ``"vae"``.
            return_unmask: Return the full tensors instead of the masked
                selection.

        Returns:
            ``(audio_features, connect_features)``. With
            ``speech_tensors=None`` these are an all-zero ``(1, 1, vae_dim)``
            placeholder and its projection through the acoustic connector.

        Raises:
            NotImplementedError: For an unknown ``speech_type``.
        """
        if speech_tensors is None:
            vae_dim = self.config.acoustic_tokenizer_config.vae_dim
            audio_features = torch.zeros(1, 1, vae_dim).to(
                self.get_input_embeddings().weight
            )
            connect_features = self.model.acoustic_connector(audio_features)
            return (audio_features, connect_features)

        # Tokenizer encoding and normalization statistics run without
        # gradients; the connector projection below is outside this context.
        with torch.no_grad():
            if speech_type == "audio":
                frames = self.model.acoustic_tokenizer.encode(
                    speech_tensors.unsqueeze(1)
                )[0][0]
                audio_tokens = frames.sample(
                    self.model.acoustic_tokenizer.std_dist_type
                )[0]
            elif speech_type == "vae":
                vae_dim = self.config.acoustic_tokenizer_config.vae_dim
                speech_mode = speech_tensors.reshape(
                    speech_tensors.size(0), -1, vae_dim
                )
                batch_size = speech_mode.size(0)
                # One random noise scale per sample, broadcast over the
                # remaining dimensions.
                value = self.model.acoustic_tokenizer.fix_std / 0.8
                std = (
                    torch.randn(
                        batch_size,
                        dtype=speech_mode.dtype,
                        device=speech_mode.device,
                    )
                    * value
                )
                std = std.view(-1, *[1] * (speech_mode.dim() - 1))
                audio_tokens = speech_mode + std * torch.randn(
                    speech_mode.shape
                ).to(speech_mode)
            else:
                raise NotImplementedError(
                    f"Speech type {speech_type} not implemented"
                )

            # The buffers start as NaN; derive them from the first batch seen.
            if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(
                self.model.speech_bias_factor
            ):
                scaling_factor = 1.0 / audio_tokens[speech_masks].flatten().std()
                bias_factor = -audio_tokens[speech_masks].flatten().mean()
                if dist.is_available() and dist.is_initialized():
                    # Average the per-rank statistics across all processes.
                    dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
                    dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
                    world_size = dist.get_world_size()
                    self.model.speech_scaling_factor.copy_(
                        scaling_factor / world_size
                    )
                    self.model.speech_bias_factor.copy_(bias_factor / world_size)
                    print(
                        f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}",
                        flush=True,
                    )
                else:
                    self.model.speech_scaling_factor.copy_(scaling_factor)
                    self.model.speech_bias_factor.copy_(bias_factor)
                    print(
                        f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}",
                        flush=True,
                    )
            audio_features = (
                audio_tokens + self.model.speech_bias_factor
            ) * self.model.speech_scaling_factor

        connect_features = self.model.acoustic_connector(audio_features)
        if return_unmask:
            return (audio_features, connect_features)
        return (audio_features[speech_masks], connect_features[speech_masks])

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        speech_tensors: Optional[torch.FloatTensor] = None,
        speech_masks: Optional[torch.BoolTensor] = None,
        speeches_loss_input: Optional[torch.FloatTensor] = None,
        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
        acoustic_input_mask: Optional[torch.BoolTensor] = None,
        acoustic_loss_mask: Optional[torch.BoolTensor] = None,
        ddpm_batch_mul: int = 1,
        **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
    ) -> Union[Tuple, QWEN3VoxCausalLMOutputWithPast]:
        """Compute logits and the diffusion loss for one batch.

        Args:
            input_ids: Token ids; always embedded (``inputs_embeds`` is
                accepted for signature compatibility but not read).
            labels: Accepted but unused; no LM loss is computed here.
            speech_tensors: Speech input for ``forward_speech_features``.
            speech_masks: Valid-frame mask over the speech latents.
            speeches_loss_input: When given, only latents selected by
                ``speeches_loss_input & speech_masks`` are diffusion targets.
            speech_semantic_tensors: Optional input for the semantic
                connector; its features are added to the acoustic ones.
            acoustic_input_mask: Embedding positions to overwrite with speech
                features.
            acoustic_loss_mask: Hidden-state positions used as diffusion
                conditions.
            ddpm_batch_mul: Number of (noise, timestep) draws per target.
            **kwargs: ``speech_type`` is read from here (default ``"audio"``).

        Returns:
            ``QWEN3VoxCausalLMOutputWithPast`` when ``return_dict`` resolves
            to True, else the tuple ``(loss, diffusion_loss, logits,
            speech_len, *remaining backbone outputs)``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        speech_type = kwargs.get("speech_type", "audio")
        x = self.get_input_embeddings()(input_ids)
        typed_speech = speech_tensors.type_as(x) if speech_tensors is not None else None

        # Only run the semantic connector when there is something to feed it;
        # calling it on None would crash and made the None check below dead.
        semantic_speech_all_connect_features = None
        if speech_semantic_tensors is not None:
            semantic_speech_all_connect_features = self.model.semantic_connector(
                speech_semantic_tensors
            )

        if speeches_loss_input is not None:
            speech_all_features, speech_all_connect_features = (
                self.forward_speech_features(
                    speech_tensors=typed_speech,
                    speech_masks=speech_masks,
                    speech_type=speech_type,
                    return_unmask=True,
                )
            )
            if speech_tensors is not None:
                spliced = speech_all_connect_features[speech_masks]
                if semantic_speech_all_connect_features is not None:
                    spliced = spliced + semantic_speech_all_connect_features[speech_masks]
                x[acoustic_input_mask] = spliced
                # Only a subset of the valid latents are diffusion targets.
                target_latent_mask = speeches_loss_input & speech_masks
                speech_features = speech_all_features[target_latent_mask]
                speech_connect_features = speech_all_connect_features[
                    target_latent_mask
                ]
        else:
            speech_features, speech_connect_features = self.forward_speech_features(
                speech_tensors=typed_speech,
                speech_masks=speech_masks,
                speech_type=speech_type,
            )
            if speech_tensors is not None:
                x[acoustic_input_mask] = speech_connect_features

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=x,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=False,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        # With return_dict=False the backbone hands back a plain tuple, which
        # has no `.last_hidden_state`; index it instead.
        hidden_states = outputs.last_hidden_state if return_dict else outputs[0]
        logits = self.lm_head(hidden_states)

        # No language-modelling loss is computed in this forward pass.
        loss = None

        # Defined up front so both return paths work when the diffusion
        # branch is skipped.
        speech_len = 0
        has_diffusion_targets = (
            speech_tensors is not None
            and acoustic_loss_mask is not None
            and acoustic_loss_mask.sum().item() > 0
        )
        if has_diffusion_targets:
            condition_features = hidden_states[acoustic_loss_mask]
            speech_len, latent_size = speech_features.shape
            noise = torch.randn(
                (speech_len * ddpm_batch_mul, latent_size),
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            )
            # Uniformly sampled training timesteps, one per noisy copy.
            timesteps = torch.multinomial(
                torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
                speech_len * ddpm_batch_mul,
                replacement=True,
            ).to(hidden_states.device)
            speech_features_repeated = speech_features.repeat_interleave(
                ddpm_batch_mul, dim=0
            )
            condition_features_repeated = condition_features.repeat_interleave(
                ddpm_batch_mul, dim=0
            )
            noisy_speech_features = self.model.noise_scheduler.add_noise(
                speech_features_repeated, noise, timesteps
            )
            model_output = self.model.prediction_head(
                noisy_speech_features, timesteps.type_as(x), condition_features_repeated
            )
            prediction_type = self.config.diffusion_head_config.prediction_type
            if prediction_type == "epsilon":
                target_for_loss = noise
            elif prediction_type == "v_prediction":
                target_for_loss = self.model.noise_scheduler.get_velocity(
                    speech_features_repeated, noise, timesteps
                )
            else:
                raise NotImplementedError(
                    f"Prediction type {prediction_type} not implemented"
                )
            diffusion_loss = F.mse_loss(
                model_output.float(), target_for_loss.float(), reduction="sum"
            )
            if latent_size > 0 and ddpm_batch_mul > 0:
                diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
            else:
                diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
        else:
            # Zero-valued loss that still touches the speech modules'
            # parameters, so they take part in the autograd graph.
            diffusion_loss = (
                sum((p.sum() for p in self.model.prediction_head.parameters())) * 0.0
            )
            diffusion_loss += (
                sum((p.sum() for p in self.model.acoustic_connector.parameters())) * 0.0
            )
            diffusion_loss += (
                sum((p.sum() for p in self.model.semantic_connector.parameters())) * 0.0
            )

        if not return_dict:
            # `outputs` is a plain tuple on this path (no `.to_tuple()`).
            output = (logits, speech_len) + tuple(outputs[1:])
            return (loss, diffusion_loss) + output
        return QWEN3VoxCausalLMOutputWithPast(
            loss=loss,
            diffusion_loss=diffusion_loss,
            speech_token_num=speech_len,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Register the model classes with the transformers Auto* factories under
# QWEN3VoxConfig.
AutoModel.register(QWEN3VoxConfig, QWEN3VoxModel)
AutoModelForCausalLM.register(QWEN3VoxConfig, QWEN3VoxForConditionalGeneration)
# NOTE(review): `__all__` is assigned several times in this file (here, a few
# lines below, and again after the streaming model). Assuming these are all
# module-level statements of one module, only the last assignment is in effect
# once the module has finished importing.
__all__ = [
    'QWEN3VoxModel',
    'QWEN3VoxPreTrainedModel',
    'QWEN3VoxForConditionalGeneration',
    'QWEN3VoxCausalLMOutputWithPast',
    'QWEN3VoxGenerationOutput',
]
# The bare string below is an expression statement with no runtime effect; it
# reads like the docstring of a processors module that was merged into this file.
'\nQWEN3Vox Processors\n\nThis module provides processors for preparing inputs for QWEN3Vox models:\n- QWEN3VoxProcessor: For multi-speaker models (1.5B, 7B)\n- QWEN3VoxStreamingProcessor: For streaming model (0.5B)\n'
__all__ = [
    'QWEN3VoxProcessor',
    'QWEN3VoxStreamingProcessor',
    'QWEN3VoxTokenizerProcessor',
    "AudioNormalizer",
    'QWEN3VoxASRProcessor',
]
# Same situation: a stray docstring-style string, here describing the
# streaming inference code that follows.
'\nQWEN3Vox Streaming Inference Model (0.5B)\n\nThis module implements the inference engine for real-time streaming TTS.\nKey features:\n- Window-based text/speech interleaving for streaming\n- Binary EOS classifier for end-of-speech detection\n- Classifier-free guidance for speech quality\n- Audio streaming support\n'
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.generation import (
GenerationMixin,
GenerationConfig,
LogitsProcessor,
LogitsProcessorList,
StoppingCriteriaList,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
# Logger obtained from transformers' logging utilities (imported just above).
logger = logging.get_logger(__name__)
# Compatibility shim: make sure `modeling_utils.ALL_PARALLEL_STYLES` exists and
# is not None, defining it here when the attribute is missing or unset.
if (
    not hasattr(modeling_utils, "ALL_PARALLEL_STYLES")
    or modeling_utils.ALL_PARALLEL_STYLES is None
):
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
# Streaming interleave used by `generate`: text ids are consumed in slices of
# TTS_TEXT_WINDOW_SIZE tokens, and up to TTS_SPEECH_WINDOW_SIZE speech latents
# are sampled after each slice.
TTS_TEXT_WINDOW_SIZE = 5
TTS_SPEECH_WINDOW_SIZE = 6
def _update_model_kwargs_for_generation(
outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1
) -> Dict[str, Any]:
model_kwargs["past_key_values"] = getattr(outputs, "past_key_values")
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], num_new_tokens)),
],
dim=-1,
)
model_kwargs["cache_position"] = torch.arange(
model_kwargs["cache_position"][-1] + 1,
model_kwargs["cache_position"][-1] + num_new_tokens + 1,
).to(model_kwargs["cache_position"].device)
return model_kwargs
@dataclass
class QWEN3VoxCausalLMOutputWithPast(BaseModelOutputWithPast):
    """Model output carrying logits plus optional training-loss fields.

    This definition rebinds the module-level name
    ``QWEN3VoxCausalLMOutputWithPast``. The training model's ``forward``
    builds this type with ``loss``, ``diffusion_loss`` and
    ``speech_token_num``, so those fields are declared here as optional;
    without them that call would fail with an unexpected-keyword error.
    The streaming code only fills ``logits``.
    """

    # EOS-classifier output in the streaming path, vocabulary logits in the
    # training path.
    logits: Optional[torch.FloatTensor] = None
    # Training-only fields; left as None by the streaming inference code.
    loss: Optional[torch.FloatTensor] = None
    diffusion_loss: Optional[torch.FloatTensor] = None
    speech_token_num: Optional[int] = None
@dataclass
class QWEN3VoxGenerationOutput(ModelOutput):
    """Result returned by the streaming model's ``generate``."""

    # Final TTS-LM input id sequence (prompt, text windows and one placeholder
    # id per generated speech token).
    sequences: torch.LongTensor = None
    # One entry per sample: the concatenated audio chunks, or None when no
    # chunk was produced; the whole field is None when return_speech=False.
    speech_outputs: Optional[List[torch.FloatTensor]] = None
    # Per-sample flag set when generation stopped on the length limit.
    reach_max_step_sample: Optional[torch.BoolTensor] = None
class QWEN3VoxStreamingForConditionalGenerationInference(
QWEN3VoxStreamingPreTrainedModel, GenerationMixin
):
def __init__(self, config):
super().__init__(config)
self.model = QWEN3VoxStreamingModel(config)
self.tts_eos_classifier = BinaryClassifier(config.decoder_config.hidden_size)
self.ddpm_inference_steps = (
config.diffusion_head_config.ddpm_num_inference_steps
)
self.post_init()
@property
def noise_scheduler(self):
return self.model.noise_scheduler
@property
def prediction_head(self):
return self.model.prediction_head
@property
def speech_scaling_factor(self):
return self.model.speech_scaling_factor
@property
def speech_bias_factor(self):
return self.model.speech_bias_factor
@property
def acoustic_tokenizer(self):
return self.model.acoustic_tokenizer
@property
def acoustic_connector(self):
return self.model.acoustic_connector
def tie_weights(self):
if not getattr(self.config, "tie_word_embeddings", False):
return
if hasattr(self, "lm_head") and hasattr(
self.model.language_model, "embed_tokens"
):
self.lm_head.weight = self.model.language_model.embed_tokens.weight
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return None
def set_output_embeddings(self, new_embeddings):
raise RuntimeError(
"Output embeddings (lm_head) are not defined for this model. Create one before calling set_output_embeddings if needed."
)
def set_speech_tokenizers(self, acoustic_tokenizer=None):
self.model.set_speech_tokenizers(acoustic_tokenizer)
def set_ddpm_inference_steps(self, num_steps=None):
self.ddpm_inference_steps = (
num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
)
def forward_lm(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
if labels is not None:
raise NotImplementedError(
"Loss computation is not implemented in this version."
)
return BaseModelOutputWithPast(
past_key_values=outputs.past_key_values,
last_hidden_state=hidden_states,
attentions=outputs.attentions,
)
def forward_tts_lm(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
lm_last_hidden_state: Optional[torch.FloatTensor] = None,
tts_text_masks: Optional[torch.BoolTensor] = None,
**kwargs,
) -> Union[Tuple, QWEN3VoxCausalLMOutputWithPast]:
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
start_idx = inputs_embeds.shape[1] - lm_last_hidden_state.shape[1]
inputs_embeds[:, start_idx:, :] = lm_last_hidden_state
inputs_embeds = inputs_embeds + self.model.tts_input_types(
tts_text_masks.long()
)
outputs = self.model.tts_language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
logits = self.tts_eos_classifier(hidden_states[:, -1, :])
if labels is not None:
raise NotImplementedError(
"Loss computation is not implemented in this version."
)
return QWEN3VoxCausalLMOutputWithPast(
logits=logits,
past_key_values=outputs.past_key_values,
last_hidden_state=hidden_states,
attentions=outputs.attentions,
)
def forward(self, *args, **kwargs):
raise RuntimeError(
"Unified forward is disabled. Use `forward_lm`, `forward_tts_lm`, or `generate` instead."
)
def _build_generate_config_model_kwargs(
self, generation_config, inputs, tokenizer, return_processors=False, **kwargs
):
if generation_config is None:
generation_config = GenerationConfig(
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
)
else:
generation_config = GenerationConfig(
**generation_config,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
)
generation_config, model_kwargs = self._prepare_generation_config(
generation_config,
True,
speech_start_id=tokenizer.speech_start_id,
speech_end_id=tokenizer.speech_end_id,
speech_diffusion_id=tokenizer.speech_diffusion_id,
**kwargs,
)
generation_config.speech_start_id = tokenizer.speech_start_id
generation_config.speech_end_id = tokenizer.speech_end_id
generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
device = self.device
self._prepare_special_tokens(generation_config, True, device=device)
generation_config.use_cache = True
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = inputs_tensor.to(self.device)
input_ids_length = input_ids.shape[1]
has_default_max_length = (
kwargs.get("max_length") is None
and generation_config.max_length is not None
)
has_default_min_length = (
kwargs.get("min_length") is None
and generation_config.min_length is not None
)
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
max_cache_length = generation_config.max_length - 1
self._prepare_cache_for_generation(
generation_config, model_kwargs, None, batch_size, max_cache_length, device
)
model_kwargs["cache_position"] = torch.arange(
input_ids_length, device=device, dtype=torch.long
)
for k, v in model_kwargs.items():
if isinstance(v, torch.Tensor):
model_kwargs[k] = v.to(device=device)
if return_processors:
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=LogitsProcessorList(),
device=inputs_tensor.device,
model_kwargs=model_kwargs,
)
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=StoppingCriteriaList(),
)
return (
generation_config,
model_kwargs,
input_ids,
logits_processor,
stopping_criteria,
)
else:
return (generation_config, model_kwargs, input_ids)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speech_input_mask: Optional[torch.BoolTensor] = None,
tts_text_ids: Optional[torch.LongTensor] = None,
return_speech: bool = True,
cfg_scale: float = 1.0,
stop_check_fn: Optional[Callable[[], bool]] = None,
**kwargs,
) -> Union[torch.LongTensor, QWEN3VoxGenerationOutput]:
tokenizer = kwargs.pop("tokenizer", None)
neg_text_input_id = tokenizer.convert_tokens_to_ids("<|image_pad|>")
tts_lm_input_ids = kwargs.pop("tts_lm_input_ids", None)
tts_lm_attention_mask = kwargs.pop("tts_lm_attention_mask", None)
all_prefilled_outputs = kwargs.pop("all_prefilled_outputs", None)
tts_text_ids = tts_text_ids.to(self.device)
if kwargs.get("max_new_tokens", None) is None:
kwargs["max_new_tokens"] = (
self.config.decoder_config.max_position_embeddings
- tts_lm_input_ids.shape[-1]
)
(
generation_config,
model_kwargs,
input_ids,
logits_processor,
stopping_criteria,
) = self._build_generate_config_model_kwargs(
generation_config, inputs, tokenizer, return_processors=True, **kwargs
)
negative_kwargs = {
"input_ids": torch.full(
(kwargs["input_ids"].shape[0], 1),
neg_text_input_id,
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"attention_mask": torch.ones(
(kwargs["input_ids"].shape[0], 1),
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
negative_generation_config, negative_model_kwargs, negative_input_ids = (
self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **negative_kwargs
)
)
tts_lm_kwargs = {
"input_ids": tts_lm_input_ids,
"attention_mask": tts_lm_attention_mask,
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
tts_lm_generation_config, tts_lm_model_kwargs, tts_lm_input_ids = (
self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **tts_lm_kwargs
)
)
tts_lm_negative_kwargs = {
"input_ids": torch.full(
(kwargs["input_ids"].shape[0], 1),
neg_text_input_id,
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"attention_mask": torch.ones(
(kwargs["input_ids"].shape[0], 1),
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
(
tts_lm_negative_generation_config,
tts_lm_negative_model_kwargs,
tts_lm_negative_input_ids,
) = self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **tts_lm_negative_kwargs
)
acoustic_cache = QWEN3VoxTokenizerStreamingCache()
batch_size = input_ids.shape[0]
assert batch_size == 1, "Currently only supports batch size == 1"
device = input_ids.device
finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
verbose = kwargs.get("verbose", False)
audio_chunks = [[] for _ in range(batch_size)]
tts_text_window_index = 0
reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
first_text_window_size = (
TTS_TEXT_WINDOW_SIZE
if tts_text_ids.shape[1] >= TTS_TEXT_WINDOW_SIZE
else tts_text_ids.shape[1]
)
outputs = all_prefilled_outputs["lm"]
tts_lm_outputs = all_prefilled_outputs["tts_lm"]
negative_outputs = all_prefilled_outputs["neg_lm"]
tts_lm_negative_outputs = all_prefilled_outputs["neg_tts_lm"]
model_kwargs = _update_model_kwargs_for_generation(
outputs, model_kwargs, num_new_tokens=first_text_window_size
)
tts_lm_model_kwargs = _update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=first_text_window_size
)
negative_model_kwargs = self._update_model_kwargs_for_generation(
negative_outputs, negative_model_kwargs, is_encoder_decoder=False
)
tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_negative_outputs,
tts_lm_negative_model_kwargs,
is_encoder_decoder=False,
)
step = tts_lm_input_ids.shape[1]
total_generated_speech_tokens = 0
total_prefilled_text_tokens = 0
if kwargs.get("show_progress_bar", True):
progress_bar = tqdm(
total=tts_lm_generation_config.max_length,
desc=f"Prefilled {step } tokens, current step ({step } / {tts_lm_generation_config .max_length })",
initial=step,
leave=False,
)
else:
progress_bar = None
while True:
if stop_check_fn is not None and stop_check_fn():
if verbose:
print(f"Generation stopped externally at step {step +1 }")
if audio_streamer is not None:
audio_streamer.end()
break
if finished_tags.all():
if hasattr(progress_bar, "set_description"):
progress_bar.set_description("Generation complete")
break
cur_input_tts_text_ids = tts_text_ids[
:,
tts_text_window_index
* TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 1)
* TTS_TEXT_WINDOW_SIZE,
]
next_text_window_size = tts_text_ids[
:,
(tts_text_window_index + 1)
* TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 2)
* TTS_TEXT_WINDOW_SIZE,
].shape[1]
tts_text_window_index += 1
if cur_input_tts_text_ids.shape[1] > 0:
input_ids = torch.cat([input_ids, cur_input_tts_text_ids], dim=-1)
tts_lm_input_ids = torch.cat(
[tts_lm_input_ids, cur_input_tts_text_ids], dim=-1
)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
if verbose:
print(
f"Reached maximum generation length {generation_config .max_length }, stopped it."
)
reached_samples = torch.arange(batch_size, device=device)[
~finished_tags
]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
step += cur_input_tts_text_ids.shape[1]
total_prefilled_text_tokens += cur_input_tts_text_ids.shape[1]
if progress_bar is not None:
progress_bar.update(cur_input_tts_text_ids.shape[1])
progress_bar.set_description(
f"Prefilled {total_prefilled_text_tokens } text tokens, generated {total_generated_speech_tokens } speech tokens, current step ({step } / {tts_lm_generation_config .max_length })"
)
model_inputs = self.prepare_inputs_for_generation(
input_ids, **model_kwargs
)
outputs = self.forward_lm(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
model_kwargs = _update_model_kwargs_for_generation(
outputs, model_kwargs, num_new_tokens=next_text_window_size
)
tts_lm_model_inputs = self.prepare_inputs_for_generation(
tts_lm_input_ids, **tts_lm_model_kwargs
)
tts_lm_additional_inputs = {
"tts_text_masks": torch.ones_like(tts_lm_input_ids[:, -1:]),
"lm_last_hidden_state": outputs.last_hidden_state,
}
tts_lm_outputs = self.forward_tts_lm(
**tts_lm_model_inputs,
**tts_lm_additional_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
)
diffusion_indices = torch.LongTensor([0])
for cur_speech_index in range(TTS_SPEECH_WINDOW_SIZE):
positive_condition = tts_lm_outputs.last_hidden_state[
diffusion_indices, -1, :
]
negative_condition = tts_lm_negative_outputs.last_hidden_state[
diffusion_indices, -1, :
]
speech_latent = self.sample_speech_tokens(
positive_condition, negative_condition, cfg_scale=cfg_scale
).unsqueeze(1)
scaled_latent = speech_latent / self.model.speech_scaling_factor.to(
speech_latent.device
) - self.model.speech_bias_factor.to(speech_latent.device)
audio_chunk = self.model.acoustic_tokenizer.decode(
scaled_latent.to(self.model.acoustic_tokenizer.device),
cache=acoustic_cache,
sample_indices=diffusion_indices.to(
self.model.acoustic_tokenizer.device
),
use_cache=True,
debug=False,
)
for i, sample_idx in enumerate(diffusion_indices):
idx = sample_idx.item()
if not finished_tags[idx]:
audio_chunks[idx].append(audio_chunk[i])
if audio_streamer is not None:
audio_streamer.put(audio_chunk, diffusion_indices)
acoustic_embed = self.model.acoustic_connector(speech_latent)
tts_lm_input_ids = torch.cat(
[tts_lm_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])],
dim=-1,
)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
break
step += 1
total_generated_speech_tokens += 1
if progress_bar is not None:
progress_bar.update(1)
progress_bar.set_description(
f"Prefilled {total_prefilled_text_tokens } text tokens, generated {total_generated_speech_tokens } speech tokens, current step ({step } / {tts_lm_generation_config .max_length })"
)
tts_lm_model_inputs = self.prepare_inputs_for_generation(
tts_lm_input_ids, **tts_lm_model_kwargs
)
tts_lm_additional_inputs = {
"tts_text_masks": torch.zeros_like(tts_lm_input_ids[:, -1:]),
"lm_last_hidden_state": acoustic_embed,
}
tts_lm_outputs = self.forward_tts_lm(
**tts_lm_model_inputs,
**tts_lm_additional_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
if (
cur_speech_index == TTS_SPEECH_WINDOW_SIZE - 1
and next_text_window_size > 0
):
tts_lm_model_kwargs = _update_model_kwargs_for_generation(
tts_lm_outputs,
tts_lm_model_kwargs,
num_new_tokens=next_text_window_size,
)
else:
tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
)
tts_lm_negative_input_ids = torch.cat(
[
tts_lm_negative_input_ids,
torch.ones_like(tts_lm_input_ids[:, -1:]),
],
dim=-1,
)
tts_lm_negative_model_inputs = self.prepare_inputs_for_generation(
tts_lm_negative_input_ids, **tts_lm_negative_model_kwargs
)
tts_lm_negative_additional_inputs = {
"tts_text_masks": torch.zeros_like(
tts_lm_negative_input_ids[:, -1:]
),
"lm_last_hidden_state": acoustic_embed,
}
tts_lm_negative_outputs = self.forward_tts_lm(
**tts_lm_negative_model_inputs,
**tts_lm_negative_additional_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_negative_outputs,
tts_lm_negative_model_kwargs,
is_encoder_decoder=False,
)
tts_eos_logits = torch.sigmoid(
self.tts_eos_classifier(
tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]
)
)
if tts_eos_logits[0].item() > 0.5:
finished_tags[diffusion_indices] = True
if audio_streamer is not None:
audio_streamer.end(diffusion_indices)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
if verbose:
print(
f"Reached maximum generation length {tts_lm_generation_config .max_length }, stopped it."
)
reached_samples = torch.arange(batch_size, device=device)[
~finished_tags
]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
if audio_streamer is not None:
audio_streamer.end()
final_audio_outputs = []
for sample_chunks in audio_chunks:
if sample_chunks:
concatenated_audio = torch.cat(sample_chunks, dim=-1)
final_audio_outputs.append(concatenated_audio)
else:
final_audio_outputs.append(None)
if reach_max_step_sample is not None and reach_max_step_sample.any():
print(
f"Reached maximum generation length {tts_lm_generation_config .max_length }, stopped it."
)
return QWEN3VoxGenerationOutput(
sequences=tts_lm_input_ids,
speech_outputs=final_audio_outputs if return_speech else None,
reach_max_step_sample=reach_max_step_sample,
)
@torch.no_grad()
def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
    """Run the reverse diffusion loop with classifier-free guidance.

    The positive and negative conditions are stacked along dim 0, so the
    working batch is twice the size of ``condition``. At every scheduler
    timestep the first half of the current latents is duplicated, the
    prediction head is evaluated once on the doubled batch, and the two
    halves of its output are blended as
    ``uncond + cfg_scale * (cond - uncond)``.

    Args:
        condition: Conditioning tensor for the guided branch.
        neg_condition: Conditioning tensor for the unguided branch.
        cfg_scale: Guidance strength used in the blend above.

    Returns:
        The first half (along dim 0) of the final latents.
    """
    scheduler = self.model.noise_scheduler
    head = self.model.prediction_head
    scheduler.set_timesteps(self.ddpm_inference_steps)

    stacked_condition = torch.cat((condition, neg_condition), dim=0).to(head.device)
    # Start from Gaussian noise matching the condition's device and dtype.
    latents = torch.randn(
        stacked_condition.shape[0], self.config.acoustic_vae_dim
    ).to(stacked_condition)

    for timestep in scheduler.timesteps:
        first_half = latents[: latents.shape[0] // 2]
        model_input = torch.cat((first_half, first_half), dim=0)
        step_ids = timestep.repeat(model_input.shape[0]).to(model_input)
        noise_pred = head(model_input, step_ids, condition=stacked_condition)
        cond_pred, uncond_pred = torch.split(
            noise_pred, noise_pred.shape[0] // 2, dim=0
        )
        guided = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
        # Both halves receive the same guided prediction.
        doubled = torch.cat((guided, guided), dim=0)
        latents = scheduler.step(doubled, timestep, latents).prev_sample
    return latents[: latents.shape[0] // 2]
# Register the streaming inference model with the AutoModelForCausalLM factory,
# keyed by its configuration class.
AutoModelForCausalLM.register(
QWEN3VoxStreamingConfig, QWEN3VoxStreamingForConditionalGenerationInference
)
# Names exported by this section of the file.
# NOTE(review): __all__ is assigned again further down in this file, so only the
# last assignment is in effect once the whole file has executed - confirm intent.
__all__ = [
'QWEN3VoxStreamingForConditionalGenerationInference',
'QWEN3VoxGenerationOutput',
'QWEN3VoxCausalLMOutputWithPast',
"TTS_TEXT_WINDOW_SIZE",
"TTS_SPEECH_WINDOW_SIZE",
]
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, DatasetDict, VerificationMode
from transformers import HfArgumentParser, Trainer, set_seed, TrainerCallback
from transformers import TrainingArguments as HfTrainingArguments
from peft import LoraConfig, get_peft_model, TaskType
logger = logging.getLogger(__name__)
import copy
import torch
from transformers import TrainerCallback
class EmaCallback(TrainerCallback):
    """Keep an exponential moving average (EMA) of one sub-module's state.

    The sub-module is located by walking ``attr_path`` (dot-separated
    attribute names) from the model passed to the callback hooks. The EMA
    copy ("shadow") lives on ``device``.

    The EMA weights are loaded into the live module in ``on_evaluate``,
    ``on_save`` and ``on_train_end``. The raw training weights are put back
    at the start of the next training step (``on_step_begin``), because the
    Hugging Face ``Trainer`` does not dispatch ``on_evaluate_end`` or
    ``on_save_end``; those two methods are kept for callers that invoke
    them by hand. After ``on_train_end`` the module keeps the EMA weights.

    Args:
        attr_path: Dotted attribute path of the tracked sub-module.
        decay: EMA decay factor; each step applies
            ``shadow = decay * shadow + (1 - decay) * current``.
        device: Device on which the shadow tensors are stored.
    """

    def __init__(self, attr_path="model.prediction_head", decay=0.999, device="cpu"):
        self.attr_path = attr_path
        self.decay = float(decay)
        self.device = torch.device(device)
        self.shadow = None
        self._orig = None

    def _get_module(self, model):
        """Resolve ``attr_path`` starting from ``model``."""
        module = model
        for name in self.attr_path.split("."):
            module = getattr(module, name)
        return module

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Initialise the shadow from the module's current state dict."""
        head = self._get_module(model)
        self.shadow = {
            key: value.detach().to(self.device).clone()
            for key, value in head.state_dict().items()
        }
        # A snapshot left over from an earlier run must not be restored
        # into this one.
        self._orig = None

    def on_step_begin(self, args, state, control, model=None, **kwargs):
        """Restore raw weights if EMA weights are still swapped in."""
        self._swap_back(model)

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Fold the module's current state into the shadow."""
        if self.shadow is None:
            return
        head = self._get_module(model)
        with torch.no_grad():
            for key, value in head.state_dict().items():
                current = value.detach().to(self.device)
                tracked = self.shadow.get(key)
                if tracked is None:
                    # Entry that appeared after on_train_begin: start tracking it.
                    self.shadow[key] = current.clone()
                elif tracked.is_floating_point():
                    tracked.mul_(self.decay).add_(current, alpha=1.0 - self.decay)
                else:
                    # Integer/bool buffers cannot be averaged in place with a
                    # float factor; mirror the latest value instead.
                    tracked.copy_(current)

    def _swap_in_ema(self, model):
        """Load the EMA weights, remembering the raw weights once."""
        if self.shadow is None:
            # Nothing tracked yet (e.g. evaluation before training started).
            return
        head = self._get_module(model)
        if self._orig is None:
            # Only snapshot when not already swapped in; a second snapshot
            # would capture EMA weights and lose the raw ones.
            self._orig = copy.deepcopy(head.state_dict())
        head.load_state_dict(self.shadow, strict=False)

    def _swap_back(self, model):
        """Reload the raw weights saved by ``_swap_in_ema``, if any."""
        if self._orig is None:
            return
        head = self._get_module(model)
        head.load_state_dict(self._orig, strict=False)
        self._orig = None

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        self._swap_in_ema(model)

    def on_evaluate_end(self, args, state, control, model=None, **kwargs):
        self._swap_back(model)

    def on_save(self, args, state, control, model=None, **kwargs):
        self._swap_in_ema(model)

    def on_save_end(self, args, state, control, model=None, **kwargs):
        self._swap_back(model)

    def on_train_end(self, args, state, control, model=None, **kwargs):
        self._swap_in_ema(model)
@dataclass
class ModelArguments:
    """Command-line options describing the model and which parts are trained.

    Field order and defaults define the parsed interface; the ``help``
    metadata strings are the user-facing descriptions.
    """

    # --- checkpoint / processor locations ---
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": 'Path to QWEN3Vox base model with config.json'},
    )
    processor_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to processor dir (preprocessor_config.json). Defaults to model path."
        },
    )
    cache_dir: Optional[str] = None

    # --- tokenizer freezing flags ---
    freeze_acoustic_tokenizer: bool = True
    freeze_semantic_tokenizer: bool = True

    # --- LoRA hyper-parameters ---
    lora_r: int = 8
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: str = field(
        default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
        metadata={
            "help": "Comma-separated list of target module names in the LLM blocks"
        },
    )
    lora_wrap_diffusion_head: bool = field(
        default=False,
        metadata={"help": "Wrap diffusion head with PEFT LoRA"},
    )

    # --- full fine-tuning switches ---
    train_diffusion_head: bool = field(
        default=False,
        metadata={"help": "Train diffusion prediction head (full fine-tune)"},
    )
    train_connectors: bool = field(
        default=False,
        metadata={"help": "Train acoustic/semantic connectors (full fine-tune)"},
    )
    layers_to_freeze: Optional[str] = field(
        default=None,
        metadata={
            "help": "Comma-separated indices of diffusion head layers to freeze (e.g., '0,1,5,7,8')."
        },
    )
@dataclass
class DataArguments:
    """Command-line options describing the training/evaluation data.

    Field order and defaults define the parsed interface; the ``help``
    metadata strings are the user-facing descriptions.
    """

    # --- dataset selection ---
    dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "HF dataset name or 'json' with --train_jsonl for local files"
        },
    )
    dataset_config_name: Optional[str] = None
    train_split_name: str = "train"
    eval_split_name: Optional[str] = "validation"

    # --- column names ---
    text_column_name: str = "text"
    audio_column_name: str = "audio"
    voice_prompts_column_name: Optional[str] = "voice_prompts"

    # --- misc dataset options ---
    eval_split_size: float = 0.0
    ignore_verifications: bool = False
    max_length: Optional[int] = None

    # --- local JSONL files ---
    train_jsonl: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to local train JSONL with {text, audio, [voice_prompts]}"
        },
    )
    validation_jsonl: Optional[str] = field(
        default=None,
        metadata={"help": "Optional path to local validation JSONL"},
    )

    # --- conditioning dropout ---
    voice_prompt_drop_rate: float = field(
        default=0.0,
        metadata={
            "help": "Probability to drop conditioning voice prompt during training (0.0 keep always, 1.0 drop always)."
        },
    )
@dataclass
class CustomTrainingArguments(HfTrainingArguments):
    """``TrainingArguments`` extended with this project's extra options.

    Adds loss weights, cross-entropy debugging controls, a gradient-clipping
    switch and a debug-save switch on top of the inherited fields.
    """

    # --- loss configuration ---
    ddpm_batch_mul: int = 1
    ce_loss_weight: float = 1.0
    diffusion_loss_weight: float = 1.0

    # --- cross-entropy debugging ---
    debug_ce_details: bool = False
    debug_ce_topk: int = 5
    debug_ce_max_examples: int = 1
    debug_ce_every_n_steps: int = 200

    # --- optimisation / debugging switches ---
    gradient_clipping: bool = field(
        default=False,
        metadata={
            "help": "Enable gradient clipping using max_grad_norm (set via --max_grad_norm, default 1.0). When False, disables clipping by forcing max_grad_norm=0.0."
        },
    )
    debug_save: bool = field(
        default=False,
        metadata={
            "help": "If set, saves model components BEFORE training starts, into output_dir/debug_initial."
        },
    )
def build_lora_config(args: ModelArguments) -> LoraConfig:
    """Build the causal-LM ``LoraConfig`` from the parsed model arguments.

    ``args.lora_target_modules`` is split on commas; surrounding whitespace
    is removed from each entry and empty entries are discarded.
    """
    module_names = []
    for raw_name in args.lora_target_modules.split(","):
        cleaned = raw_name.strip()
        if cleaned:
            module_names.append(cleaned)
    config_kwargs = dict(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=module_names,
    )
    return LoraConfig(**config_kwargs)
def build_head_lora_config(args: ModelArguments) -> LoraConfig:
    """Build the feature-extraction ``LoraConfig`` used for the diffusion head.

    Rank, alpha and dropout come from ``args``; the target module names are
    a fixed list.
    """
    head_modules = (
        "noisy_images_proj cond_proj gate_proj up_proj down_proj linear"
    ).split()
    config_kwargs = dict(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION,
        target_modules=head_modules,
    )
    return LoraConfig(**config_kwargs)
def mask_for_ce(
    labels: torch.Tensor,
    attention_mask: torch.Tensor,
    acoustic_input_mask: torch.Tensor,
    pad_id: int = -100,
) -> torch.Tensor:
    """Return next-token labels with ignored positions replaced by ``pad_id``.

    The labels are shifted left by one (the first column is dropped). A
    shifted position keeps its label only when the matching attention-mask
    entry equals 1 and the matching ``acoustic_input_mask`` entry is false.

    Args:
        labels: 2-D label tensor; indexed as ``labels[:, 1:]``.
        attention_mask: Mask aligned with ``labels``. ``None`` or an empty
            tensor means "attend everywhere".
        acoustic_input_mask: Mask aligned with ``labels`` marking acoustic
            positions. Any dtype is accepted; non-zero counts as true.
        pad_id: Value written to ignored positions.

    Returns:
        A new tensor with one column fewer than ``labels``; ``labels`` itself
        is not modified.
    """
    shifted = labels[:, 1:].contiguous()
    if attention_mask is not None and attention_mask.numel() > 0:
        keep = attention_mask[:, 1:].eq(1).to(shifted.device)
    else:
        keep = torch.ones_like(shifted, dtype=torch.bool)
    # Coerce to bool: on an integer mask `~` is integer negation, which would
    # turn the final masking step into integer indexing.
    is_acoustic = acoustic_input_mask[:, 1:].to(
        device=shifted.device, dtype=torch.bool
    )
    keep = keep & ~is_acoustic
    return shifted.masked_fill(~keep, pad_id)
def _patch_acoustic_encode_for_legacy_indexing(model_obj, logger_):
try:
acoustic = getattr(
getattr(model_obj, "model", model_obj), "acoustic_tokenizer", None
)
if acoustic is None or not hasattr(acoustic, "encode"):
logger_.warning("No acoustic_tokenizer.encode() found to patch.")
return
base_encode = acoustic.encode
def encode_wrapped(*args, **kwargs):
out = base_encode(*args, **kwargs)
try:
_ = out[0][0]
return out
except Exception:
pass
if isinstance(out, dict):
for k in ("frames", "codes", "tokens", "latents", "hidden_states"):
if k in out:
return [[out[k]]]
if len(out) > 0:
return [[next(iter(out.values()))]]
for attr in ("frames", "codes", "tokens", "latents", "hidden_states"):
if hasattr(out, attr):
return [[getattr(out, attr)]]
try:
if isinstance(out, torch.Tensor):
return [[out]]
except Exception:
pass
return [[out]]
acoustic.encode = encode_wrapped
logger_.info(
"Patched acoustic_tokenizer.encode() to return [[...]] for legacy indexing."
)
except Exception as e:
logger_.warning(f"Failed to patch acoustic_tokenizer.encode(): {e }")
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation import GenerationMixin
logger = logging.get_logger(__name__)
# Compatibility shim: make sure transformers.modeling_utils exposes a non-None
# ALL_PARALLEL_STYLES by installing a default list when the attribute is
# missing or set to None.
if (
not hasattr(modeling_utils, "ALL_PARALLEL_STYLES")
or modeling_utils.ALL_PARALLEL_STYLES is None
):
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
class QWEN3VoxASRPreTrainedModel(PreTrainedModel):
    """Base class wiring the ASR models into the transformers loading machinery.

    Declares the configuration class, the base-model prefix and the
    capability flags, and provides weight initialisation for ``nn.Linear``
    and ``nn.LayerNorm`` sub-modules.
    """

    config_class = QWEN3VoxASRConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        The normal-distribution std is taken from the first sub-config that
        has an ``initializer_range`` attribute (``language_model_config``
        first, then ``decoder_config``), falling back to 0.02. Linear weights
        are drawn from ``N(0, std)`` with zero bias; LayerNorm weights are set
        to 1 and biases to 0. Parameters that are ``None`` are skipped.
        """
        std = 0.02
        for cfg_name in ("language_model_config", "decoder_config"):
            sub_config = getattr(self.config, cfg_name, None)
            if hasattr(sub_config, "initializer_range"):
                std = sub_config.initializer_range
                break

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm built without affine parameters (or without a bias)
            # has None here; dereferencing .data would raise.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
class QWEN3VoxASRModel(QWEN3VoxASRPreTrainedModel):
    """Backbone: language model plus speech tokenizers and their connectors.

    The constructor builds, from sub-configs of ``config``:

    * ``language_model`` from ``decoder_config``;
    * ``acoustic_tokenizer`` / ``semantic_tokenizer`` from their configs;
    * ``acoustic_connector`` / ``semantic_connector`` projecting the
      respective VAE dimension to the language model's hidden size.

    Everything except ``language_model`` is cast to ``config.torch_dtype``
    (float32 when unset).
    """

    def __init__(self, config):
        super().__init__(config)
        dtype = getattr(config, "torch_dtype", None)
        if dtype is None:
            dtype = torch.float32
        elif isinstance(dtype, str):
            # Configs may store the dtype by name, e.g. "bfloat16".
            dtype = getattr(torch, dtype)
        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)
        self.acoustic_tokenizer = AutoModel.from_config(
            config.acoustic_tokenizer_config
        ).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(
            config.semantic_tokenizer_config
        ).to(dtype)
        self.acoustic_connector = SpeechConnector(
            config.acoustic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        self.semantic_connector = SpeechConnector(
            config.semantic_vae_dim, lm_config.hidden_size
        ).to(dtype)

    def get_input_embeddings(self):
        """Return the language model's token-embedding module.

        Uses ``language_model.embed_tokens`` when present. Otherwise it
        searches ``language_model.fullmap`` for the entry whose ``orig_name``
        is ``"embed_tokens.weight"`` and returns the attribute of that name.

        Raises:
            AssertionError: If neither lookup finds the embeddings.
        """
        if hasattr(self.language_model, "embed_tokens"):
            return self.language_model.embed_tokens
        for name, attr in self.language_model.fullmap.items():
            if attr.orig_name == "embed_tokens.weight":
                return getattr(self.language_model, name)
        # Explicit raise rather than `assert False`, which is stripped under
        # `python -O` and would let this method return None.
        raise AssertionError("should not arrive here")

    def set_input_embeddings(self, value):
        """Replace the language model's ``embed_tokens``."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Install the given tokenizers and switch each one to eval mode."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer
        for tokenizer in (self.acoustic_tokenizer, self.semantic_tokenizer):
            if tokenizer is not None:
                tokenizer.train(False)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the language model on the given inputs.

        All arguments are forwarded unchanged to ``self.language_model``.
        When ``return_dict`` resolves to false the language model's output is
        returned as is; otherwise its fields are repackaged into a
        ``BaseModelOutputWithPast``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        if not return_dict:
            return outputs
        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class QWEN3VoxASRForConditionalGeneration(QWEN3VoxASRPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
    """Build the ASR backbone and the vocabulary projection head.

    ``lm_head`` maps the decoder hidden size to the decoder vocabulary size
    (no bias) and is cast to ``config.torch_dtype``, or float32 when unset.
    """
    super().__init__(config)
    self.model = QWEN3VoxASRModel(config)
    decoder_cfg = config.decoder_config
    self.vocab_size = decoder_cfg.vocab_size

    head_dtype = getattr(config, "torch_dtype", None)
    if head_dtype is None:
        head_dtype = torch.float32
    elif isinstance(head_dtype, str):
        # The dtype may be stored by name in the config.
        head_dtype = getattr(torch, head_dtype)

    projection = nn.Linear(decoder_cfg.hidden_size, self.vocab_size, bias=False)
    self.lm_head = projection.to(head_dtype)
    self.post_init()
# Return the input embeddings exposed by the wrapped backbone (self.model).
def get_input_embeddings(self):
return self.model.get_input_embeddings()
# Replace the backbone's input embeddings with `value`.
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
# Return the vocabulary projection layer (lm_head).
def get_output_embeddings(self):
return self.lm_head
# Replace the vocabulary projection layer.
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Replace the backbone's language model with `decoder`.
def set_decoder(self, decoder):
self.model.language_model = decoder
# Return the backbone's language model.
def get_decoder(self):
return self.model.language_model
def tie_weights(self):
    """Share the input-embedding weight with the output projection.

    Does nothing unless ``config.decoder_config.tie_word_embeddings`` is
    truthy. When the input embeddings object has a ``weight`` attribute that
    attribute is shared; otherwise the object itself is assigned as the
    output weight.
    """
    if not getattr(self.config.decoder_config, "tie_word_embeddings", False):
        return
    head = self.get_output_embeddings()
    embeddings = self.get_input_embeddings()
    head.weight = getattr(embeddings, "weight", embeddings)
def encode_speech(
    self,
    speech_tensors: torch.FloatTensor,
    speech_masks: Optional[torch.BoolTensor] = None,
    speech_semantic_tensors: Optional[torch.FloatTensor] = None,
    streaming_segment_duration: float = 60.0,
    sample_rate: int = 22050,
):
    """Encode waveforms into combined acoustic + semantic features.

    The audio is cast to the configured dtype (float32 when unset) and a
    1-D input is given a leading batch dimension. Audio longer than
    ``streaming_segment_duration * sample_rate`` samples is encoded segment
    by segment with streaming caches and the per-segment means are
    concatenated along dim 1; shorter audio is encoded in one call. All
    encoding runs under ``torch.no_grad()``.

    Args:
        speech_tensors: Waveform, 1-D or 2-D (unpacked as
            ``batch, samples``).
        speech_masks: Optional mask used to index both feature tensors
            before they are summed.
        speech_semantic_tensors: Optional precomputed semantic latents. Only
            consulted on the non-streaming path.
        streaming_segment_duration: Segment length in seconds.
        sample_rate: Samples per second used to convert the segment duration
            to a sample count. Defaults to the previously hard-coded 22050.

    Returns:
        ``acoustic_features + semantic_features``, each indexed by
        ``speech_masks`` first when a mask is given.

    Raises:
        ValueError: If the streaming path is taken with a non-positive
            segment length.
    """
    dtype = getattr(self.config, "torch_dtype", None)
    if dtype is None:
        dtype = torch.float32
    elif isinstance(dtype, str):
        dtype = getattr(torch, dtype)

    speech_tensors = speech_tensors.to(dtype)
    if speech_tensors.ndim == 1:
        speech_tensors = speech_tensors.unsqueeze(0)
    batch_size, total_samples = speech_tensors.shape
    segment_samples = int(streaming_segment_duration * sample_rate)
    use_streaming = total_samples > segment_samples

    acoustic_tokenizer = self.model.acoustic_tokenizer
    with torch.no_grad():
        if not use_streaming:
            encoder_output = acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
            audio_tokens = encoder_output.sample(
                dist_type=acoustic_tokenizer.std_dist_type
            )[0]
            acoustic_features = self.model.acoustic_connector(audio_tokens)
            if speech_semantic_tensors is not None:
                semantic_features = self.model.semantic_connector(
                    speech_semantic_tensors
                )
            else:
                semantic_tokens = self.model.semantic_tokenizer.encode(
                    speech_tensors.unsqueeze(1)
                ).mean
                semantic_features = self.model.semantic_connector(semantic_tokens)
        else:
            acoustic_encoder_cache = QWEN3VoxTokenizerStreamingCache()
            semantic_encoder_cache = QWEN3VoxTokenizerStreamingCache()
            acoustic_mean_segments = []
            semantic_mean_segments = []
            sample_indices = torch.arange(batch_size, device=speech_tensors.device)

            if segment_samples <= 0:
                raise ValueError("segment_length must be positive")
            segments = [
                (start, min(start + segment_samples, total_samples))
                for start in range(0, total_samples, segment_samples)
            ]
            last_index = len(segments) - 1
            for seg_idx, (start, end) in enumerate(segments):
                chunk = speech_tensors[:, start:end].contiguous()
                if chunk.numel() == 0:
                    continue
                # The encoders are told which chunk is the last one.
                is_final = seg_idx == last_index
                acoustic_chunk = acoustic_tokenizer.encode(
                    chunk.unsqueeze(1),
                    cache=acoustic_encoder_cache,
                    sample_indices=sample_indices,
                    use_cache=True,
                    is_final_chunk=is_final,
                )
                acoustic_mean_segments.append(acoustic_chunk.mean)
                semantic_chunk = self.model.semantic_tokenizer.encode(
                    chunk.unsqueeze(1),
                    cache=semantic_encoder_cache,
                    sample_indices=sample_indices,
                    use_cache=True,
                    is_final_chunk=is_final,
                )
                semantic_mean_segments.append(semantic_chunk.mean)

            acoustic_mean_full = torch.cat(acoustic_mean_segments, dim=1).contiguous()
            acoustic_encoder_output = QWEN3VoxTokenizerEncoderOutput(
                mean=acoustic_mean_full, std=acoustic_tokenizer.fix_std
            )
            audio_tokens = acoustic_encoder_output.sample(
                dist_type=acoustic_tokenizer.std_dist_type
            )[0]
            acoustic_features = self.model.acoustic_connector(audio_tokens)
            semantic_tokens = torch.cat(semantic_mean_segments, dim=1).contiguous()
            semantic_features = self.model.semantic_connector(semantic_tokens)

        if speech_masks is not None:
            combined_features = (
                acoustic_features[speech_masks] + semantic_features[speech_masks]
            )
        else:
            combined_features = acoustic_features + semantic_features
    return combined_features
# Forward pass of the ASR generation model.
# Embeds `input_ids` when no `inputs_embeds` are given, writes encoded speech
# features into the positions selected by `acoustic_input_mask` (when both
# `speech_tensors` and `acoustic_input_mask` are provided), runs the backbone
# and projects the hidden states to vocabulary logits. When `labels` are given,
# a next-token cross-entropy loss is computed as well.
# Returns a tuple when `return_dict` resolves to false, otherwise a
# QWEN3VoxCausalLMOutputWithPast.
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speech_semantic_tensors: Optional[torch.FloatTensor] = None,
acoustic_input_mask: Optional[torch.BoolTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutput]:
# Unset flags fall back to the values stored on the config.
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if inputs_embeds is None and input_ids is not None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if speech_tensors is not None and acoustic_input_mask is not None:
speech_features = self.encode_speech(
speech_tensors=speech_tensors,
speech_masks=speech_masks,
speech_semantic_tensors=speech_semantic_tensors,
)
# In-place write: the masked embedding rows are replaced by speech features.
inputs_embeds[acoustic_input_mask] = speech_features
# The backbone always receives embeddings, never token ids.
# Extra **kwargs accepted by this method are not forwarded to the backbone.
outputs = self.model(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that the logits at position i are scored against token i + 1.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
# NOTE(review): the QWEN3VoxCausalLMOutputWithPast dataclass declared later in
# this file only adds a `logits` field to BaseModelOutputWithPast; confirm that
# the class in scope at call time accepts the `loss=` keyword used here.
return QWEN3VoxCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    speech_tensors=None,
    speech_masks=None,
    speech_semantic_tensors=None,
    acoustic_input_mask=None,
    **kwargs,
):
    """Assemble the keyword arguments for one generation step.

    * With a cache, ``input_ids`` is trimmed to the tokens not yet covered
      by it. The cache may be a legacy tuple (length read from
      ``past_key_values[0][0].shape[2]``) or an object with
      ``get_seq_length()``.
    * ``position_ids`` are derived from ``attention_mask`` when not given
      (masked positions get position 1) and trimmed to the kept tokens when
      a cache is present.
    * ``cache_position`` defaults to the range that starts at the cached
      length and spans the current inputs.
    * The speech inputs are passed through only when the step starts at
      cache position 0; on later steps they are set to ``None``.

    Returns:
        A dict of model inputs; extra ``kwargs`` are merged in last.
    """
    # Computed once and reused for cache_position, so a legacy tuple cache
    # is never asked for get_seq_length().
    past_length = 0
    if past_key_values is not None:
        if isinstance(past_key_values, tuple):
            past_length = past_key_values[0][0].shape[2]
        else:
            past_length = past_key_values.get_seq_length()
        if input_ids is not None and input_ids.shape[1] > past_length:
            input_ids = input_ids[:, past_length:]

    if position_ids is None and attention_mask is not None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values is not None and input_ids is not None:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    if cache_position is None:
        current = input_ids if input_ids is not None else inputs_embeds
        cache_position = torch.arange(
            past_length,
            past_length + current.shape[1],
            device=current.device,
        )

    # Embeddings can only be fed directly on the very first step.
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}
    model_inputs.update(
        {
            "position_ids": position_ids,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
        }
    )

    is_first_step = (
        cache_position is not None
        and len(cache_position) > 0
        and (cache_position[0] == 0)
    )
    if is_first_step:
        model_inputs.update(
            {
                "speech_tensors": speech_tensors,
                "speech_masks": speech_masks,
                "speech_semantic_tensors": speech_semantic_tensors,
                "acoustic_input_mask": acoustic_input_mask,
            }
        )
    else:
        model_inputs.update(
            {
                "speech_tensors": None,
                "speech_masks": None,
                "speech_semantic_tensors": None,
                "acoustic_input_mask": None,
            }
        )
    model_inputs.update(kwargs)
    return model_inputs
# Register the ASR classes with the transformers Auto* factories, keyed by
# QWEN3VoxASRConfig.
AutoModel.register(QWEN3VoxASRConfig, QWEN3VoxASRModel)
AutoModelForCausalLM.register(QWEN3VoxASRConfig, QWEN3VoxASRForConditionalGeneration)
# Public names of the ASR modeling section.
__all__ = [
'QWEN3VoxASRPreTrainedModel',
'QWEN3VoxASRModel',
'QWEN3VoxASRForConditionalGeneration',
]
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.generation import (
GenerationMixin,
GenerationConfig,
LogitsProcessor,
LogitsProcessorList,
StoppingCriteriaList,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
logger = logging.get_logger(__name__)
# Compatibility shim: make sure transformers.modeling_utils exposes a non-None
# ALL_PARALLEL_STYLES by installing a default list when the attribute is
# missing or set to None.
if (
not hasattr(modeling_utils, "ALL_PARALLEL_STYLES")
or modeling_utils.ALL_PARALLEL_STYLES is None
):
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
# Model output container: BaseModelOutputWithPast extended with a `logits` field.
@dataclass
class QWEN3VoxCausalLMOutputWithPast(BaseModelOutputWithPast):
# Optional logits tensor; defaults to None.
logits: Optional[torch.FloatTensor] = None
# Result container returned by the generation routines in this file.
@dataclass
class QWEN3VoxGenerationOutput(ModelOutput):
# Generated token id sequences.
sequences: torch.LongTensor = None
# Optional list of generated audio tensors (entries may be None).
speech_outputs: Optional[List[torch.FloatTensor]] = None
# Optional boolean tensor flagging samples that hit the maximum length.
reach_max_step_sample: Optional[torch.BoolTensor] = None
class QWEN3VoxTokenConstraintProcessor(LogitsProcessor):
    """Logits processor that only lets a fixed set of token ids through.

    Args:
        valid_token_ids: Token ids that stay selectable.
        device: Device on which the id tensor is created.
    """

    def __init__(self, valid_token_ids: List[int], device: torch.device = None):
        self.valid_token_ids = torch.tensor(
            valid_token_ids, dtype=torch.long, device=device
        )

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Add ``-inf`` to every column of ``scores`` that is not allowed."""
        # Additive penalty: 0 for allowed columns, -inf everywhere else.
        penalty = torch.full_like(scores, float("-inf"))
        penalty[:, self.valid_token_ids] = 0
        return scores + penalty
class QWEN3VoxForConditionalGenerationInference(QWEN3VoxPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
    """Build the backbone, the vocabulary head and the diffusion step count.

    ``lm_head`` maps the decoder hidden size to the decoder vocabulary size
    without a bias. ``ddpm_inference_steps`` starts at the value stored in
    ``config.diffusion_head_config``.
    """
    super().__init__(config)
    self.model = QWEN3VoxModel(config)
    decoder_cfg = config.decoder_config
    self.lm_head = nn.Linear(
        decoder_cfg.hidden_size, decoder_cfg.vocab_size, bias=False
    )
    head_cfg = config.diffusion_head_config
    self.ddpm_inference_steps = head_cfg.ddpm_num_inference_steps
    self.post_init()
# Read-only shortcuts to components owned by the wrapped backbone (self.model).
# Each property returns the backbone attribute of the same name.
@property
def noise_scheduler(self):
return self.model.noise_scheduler
@property
def prediction_head(self):
return self.model.prediction_head
@property
def speech_scaling_factor(self):
return self.model.speech_scaling_factor
@property
def speech_bias_factor(self):
return self.model.speech_bias_factor
@property
def acoustic_tokenizer(self):
return self.model.acoustic_tokenizer
@property
def semantic_tokenizer(self):
return self.model.semantic_tokenizer
@property
def acoustic_connector(self):
return self.model.acoustic_connector
@property
def semantic_connector(self):
return self.model.semantic_connector
def tie_weights(self):
    """Share the token-embedding weight with ``lm_head`` when configured.

    Nothing happens unless ``config.tie_word_embeddings`` is truthy, this
    object has an ``lm_head`` and the language model exposes
    ``embed_tokens``.
    """
    tying_enabled = getattr(self.config, "tie_word_embeddings", False)
    if tying_enabled and hasattr(self, "lm_head"):
        decoder = self.model.language_model
        if hasattr(decoder, "embed_tokens"):
            self.lm_head.weight = decoder.embed_tokens.weight
# Return the input embeddings exposed by the wrapped backbone (self.model).
def get_input_embeddings(self):
return self.model.get_input_embeddings()
# Replace the backbone's input embeddings with `value`.
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
# Return the vocabulary projection layer (lm_head).
def get_output_embeddings(self):
return self.lm_head
# Replace the vocabulary projection layer.
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Forward both tokenizers to the backbone's set_speech_tokenizers().
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer)
def set_ddpm_inference_steps(self, num_steps=None):
    """Set the number of diffusion inference steps.

    Any falsy ``num_steps`` (``None`` or 0) selects the default stored in
    ``config.diffusion_head_config.ddpm_num_inference_steps``.
    """
    if num_steps:
        self.ddpm_inference_steps = num_steps
    else:
        head_cfg = self.config.diffusion_head_config
        self.ddpm_inference_steps = head_cfg.ddpm_num_inference_steps
def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"):
with torch.no_grad():
if speech_type == "audio":
encoder_output = self.model.acoustic_tokenizer.encode(
speech_tensors.unsqueeze(1)
)
acoustic_latents = encoder_output.sample(
dist_type=self.model.acoustic_tokenizer.std_dist_type
)[0]
acoustic_features = (
acoustic_latents
+ self.model.speech_bias_factor.to(acoustic_latents.device)
) * self.model.speech_scaling_factor.to(acoustic_latents.device)
acoustic_connected = self.model.acoustic_connector(acoustic_features)[
speech_masks.cpu()
]
return (acoustic_features, acoustic_connected)
elif speech_type == "pt":
encoder_output = QWEN3VoxTokenizerEncoderOutput(
mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std
)
acoustic_latents = encoder_output.sample(
dist_type=self.model.acoustic_tokenizer.std_dist_type
)[0]
acoustic_features = (
acoustic_latents
+ self.model.speech_bias_factor.to(acoustic_latents.device)
) * self.model.speech_scaling_factor.to(acoustic_latents.device)
acoustic_connected = self.model.acoustic_connector(acoustic_features)[
speech_masks.cpu()
]
return (acoustic_features, acoustic_connected)
else:
raise NotImplementedError(f"Speech type {speech_type } not implemented")
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    speech_tensors: Optional[torch.FloatTensor] = None,
    speech_masks: Optional[torch.BoolTensor] = None,
    speech_input_mask: Optional[torch.BoolTensor] = None,
    logits_to_keep: Union[int, slice] = 0,
    **kwargs,
) -> Union[Tuple, QWEN3VoxCausalLMOutputWithPast]:
    """Inference forward pass producing vocabulary logits.

    Token ids are embedded when ``inputs_embeds`` is not supplied. When both
    ``speech_tensors`` and ``speech_masks`` are given, the speech is encoded
    and, if ``speech_input_mask`` is also given, written into the masked
    embedding positions in place. The backbone is then run and ``lm_head``
    is applied to the hidden states selected by ``logits_to_keep`` (an int
    ``n`` keeps ``slice(-n, None)``; anything else is used as the index).

    Returns:
        ``QWEN3VoxCausalLMOutputWithPast`` with the logits, cache, hidden
        states (as ``last_hidden_state``) and attentions.

    Raises:
        NotImplementedError: If ``labels`` is given; raised only after the
            backbone and head have been evaluated.
    """
    if return_dict is None:
        return_dict = self.config.use_return_dict
    if inputs_embeds is None:
        inputs_embeds = self.model.get_input_embeddings()(input_ids)

    has_speech = speech_tensors is not None and speech_masks is not None
    if has_speech:
        _, speech_embeds = self._process_speech_inputs(
            speech_tensors.to(self.dtype), speech_masks
        )
        if speech_input_mask is not None:
            inputs_embeds[speech_input_mask] = speech_embeds

    backbone_out = self.model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )
    if return_dict:
        hidden_states = backbone_out.last_hidden_state
    else:
        hidden_states = backbone_out[0]

    if isinstance(logits_to_keep, int):
        kept_positions = slice(-logits_to_keep, None)
    else:
        kept_positions = logits_to_keep
    logits = self.lm_head(hidden_states[:, kept_positions, :])

    if labels is not None:
        raise NotImplementedError(
            "Loss computation is not implemented in this version."
        )
    return QWEN3VoxCausalLMOutputWithPast(
        logits=logits,
        past_key_values=backbone_out.past_key_values,
        last_hidden_state=hidden_states,
        attentions=backbone_out.attentions,
    )
# Build everything generate() needs before its decoding loop: a
# GenerationConfig carrying the tokenizer's special-token ids, the model kwargs
# (with cache and cache_position prepared) and the input ids on the model device.
# Returns (generation_config, model_kwargs, input_ids); when `return_processors`
# is true, a logits-processor list and a stopping-criteria list are appended.
# `tokenizer` must not be None: its token ids are read unconditionally.
def _build_generate_config_model_kwargs(
self, generation_config, inputs, tokenizer, return_processors=False, **kwargs
):
if generation_config is None:
generation_config = GenerationConfig(
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
)
else:
# `generation_config` is unpacked with ** here, so it has to be a mapping of
# GenerationConfig fields.
# NOTE(review): generate() annotates this argument as Optional[GenerationConfig];
# confirm that callers actually pass a dict.
generation_config = GenerationConfig(
**generation_config,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
)
generation_config, model_kwargs = self._prepare_generation_config(
generation_config,
True,
speech_start_id=tokenizer.speech_start_id,
speech_end_id=tokenizer.speech_end_id,
speech_diffusion_id=tokenizer.speech_diffusion_id,
**kwargs,
)
# The speech token ids are also set explicitly on the resulting config.
generation_config.speech_start_id = tokenizer.speech_start_id
generation_config.speech_end_id = tokenizer.speech_end_id
generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
device = self.device
self._prepare_special_tokens(generation_config, True, device=device)
# Caching is always enabled for this generation path.
generation_config.use_cache = True
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = inputs_tensor.to(self.device)
input_ids_length = input_ids.shape[1]
# "Default" here means the caller did not pass the value through kwargs.
has_default_max_length = (
kwargs.get("max_length") is None
and generation_config.max_length is not None
)
has_default_min_length = (
kwargs.get("min_length") is None
and generation_config.min_length is not None
)
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
max_cache_length = generation_config.max_length - 1
self._prepare_cache_for_generation(
generation_config, model_kwargs, None, batch_size, max_cache_length, device
)
# Cache positions for the prompt: 0 .. input_ids_length - 1.
model_kwargs["cache_position"] = torch.arange(
input_ids_length, device=device, dtype=torch.long
)
# Move every tensor-valued kwarg to the model device.
for k, v in model_kwargs.items():
if isinstance(v, torch.Tensor):
model_kwargs[k] = v.to(device=device)
if return_processors:
# Both lists start empty; only what the generation config implies is added.
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=LogitsProcessorList(),
device=inputs_tensor.device,
model_kwargs=model_kwargs,
)
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=StoppingCriteriaList(),
)
return (
generation_config,
model_kwargs,
input_ids,
logits_processor,
stopping_criteria,
)
else:
return (generation_config, model_kwargs, input_ids)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speech_input_mask: Optional[torch.BoolTensor] = None,
is_prefill: bool = True,
return_speech: bool = True,
cfg_scale: float = 1.0,
stop_check_fn: Optional[Callable[[], bool]] = None,
tqdm_class: Optional[type] = None,
**kwargs,
) -> Union[torch.LongTensor, QWEN3VoxGenerationOutput]:
"""Run the step-by-step token loop and decode one audio chunk per speech-diffusion token.

Keyword arguments consumed here: ``tokenizer`` (dereferenced for
``speech_start_id``, so it must be supplied), ``parsed_scripts``,
``all_speakers_list`` (both popped but not read afterwards),
``max_length_times`` (default 2), ``max_new_tokens``, ``verbose``,
``show_progress_bar`` and ``refresh_negative``. ``kwargs["input_ids"]`` is
indexed directly, so it must be present.

The values passed as ``logits_processor`` and ``stopping_criteria`` are
replaced by the ones returned from ``_build_generate_config_model_kwargs``;
``prefix_allowed_tokens_fn``, ``synced_gpus``, ``assistant_model``,
``negative_prompt_ids`` and ``negative_prompt_attention_mask`` are not read
in this body.

Args:
    audio_streamer: If given, receives each decoded chunk via ``put`` and
        is told about finished samples via ``end``.
    speech_tensors, speech_masks, speech_input_mask: Forwarded to the
        model only on the first (prefill) step.
    is_prefill: Whether the first step should forward the speech inputs.
    return_speech: When false, ``speech_outputs`` in the result is ``None``.
    cfg_scale: Guidance scale passed to ``sample_speech_tokens``.
    stop_check_fn: Polled at the start of each step; a truthy result stops
        generation.
    tqdm_class: Progress-bar factory used instead of ``tqdm`` when set.

Returns:
    ``QWEN3VoxGenerationOutput`` with ``sequences`` (prompt plus generated
    ids), ``speech_outputs`` (per sample: chunks concatenated on the last
    dim, or ``None`` if the sample produced none) and
    ``reach_max_step_sample`` (bool per sample).
"""
tokenizer = kwargs.pop("tokenizer", None)
parsed_scripts = kwargs.pop("parsed_scripts", None)
all_speakers_list = kwargs.pop("all_speakers_list", None)
max_length_times = kwargs.pop("max_length_times", 2)
# Without an explicit budget, allow generation up to the decoder's
# max_position_embeddings minus the prompt length.
if kwargs.get("max_new_tokens", None) is None:
kwargs["max_new_tokens"] = (
self.config.decoder_config.max_position_embeddings
- kwargs["input_ids"].shape[-1]
)
(
generation_config,
model_kwargs,
input_ids,
logits_processor,
stopping_criteria,
) = self._build_generate_config_model_kwargs(
generation_config, inputs, tokenizer, return_processors=True, **kwargs
)
# Second ("negative") branch: one speech_start token per sample with a
# matching all-ones attention mask. Its hidden state is later passed to
# sample_speech_tokens as the negative condition.
negative_kwargs = {
"input_ids": torch.full(
(kwargs["input_ids"].shape[0], 1),
tokenizer.speech_start_id,
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"attention_mask": torch.ones(
(kwargs["input_ids"].shape[0], 1),
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
negative_generation_config, negative_model_kwargs, negative_input_ids = (
self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **negative_kwargs
)
)
# Streaming caches handed to the acoustic decoder / semantic encoder below.
acoustic_cache = QWEN3VoxTokenizerStreamingCache()
semantic_cache = QWEN3VoxTokenizerStreamingCache()
batch_size = input_ids.shape[0]
device = input_ids.device
# Per-sample state: finished flag and count of right-shifts applied to the
# negative branch (see the non-diffusion handling further down).
finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
inputs_embeds = None
verbose = kwargs.get("verbose", False)
audio_chunks = [[] for _ in range(batch_size)]
initial_length = input_ids.shape[-1]
initial_length_per_sample = model_kwargs["attention_mask"].sum(dim=-1)
# Token ids handed to the constraint processor: speech start/end/diffusion,
# EOS, and BOS when configured.
# NOTE(review): presumably the processor masks every other logit — confirm
# against QWEN3VoxTokenConstraintProcessor.
valid_tokens = [
generation_config.speech_start_id,
generation_config.speech_end_id,
generation_config.speech_diffusion_id,
generation_config.eos_token_id,
]
if (
hasattr(generation_config, "bos_token_id")
and generation_config.bos_token_id is not None
):
valid_tokens.append(generation_config.bos_token_id)
token_constraint_processor = QWEN3VoxTokenConstraintProcessor(
valid_tokens, device=device
)
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(token_constraint_processor)
# Step budget: remaining room up to max_length, capped at
# max_length_times x prompt length (globally and per sample).
max_steps = min(
generation_config.max_length - initial_length,
int(max_length_times * initial_length),
)
max_step_per_sample = torch.min(
generation_config.max_length - initial_length_per_sample,
(max_length_times * initial_length_per_sample).long(),
)
reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
if kwargs.get("show_progress_bar", True):
tqdm_fn = tqdm_class if tqdm_class is not None else tqdm
progress_bar = tqdm_fn(range(max_steps), desc="Generating", leave=False)
else:
progress_bar = range(max_steps)
for step in progress_bar:
# Early exits: external stop callback, streamer-side stop, all samples
# finished, or the sequence hit max_length.
if stop_check_fn is not None and stop_check_fn():
if verbose:
print(f"Generation stopped externally at step {step +1 }")
if audio_streamer is not None:
audio_streamer.end()
break
if audio_streamer is not None and hasattr(audio_streamer, "finished_flags"):
if any(audio_streamer.finished_flags):
if verbose:
print(f"Audio generation stopped externally at step {step +1 }")
break
if finished_tags.all():
if hasattr(progress_bar, "set_description"):
progress_bar.set_description("Generation complete")
break
if input_ids.shape[-1] >= generation_config.max_length:
print(
f"Reached maximum generation length {generation_config .max_length }, stopped it."
)
reached_samples = torch.arange(batch_size, device=device)[
~finished_tags
]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
if hasattr(progress_bar, "set_description"):
active_samples = (~finished_tags).sum().item()
progress_bar.set_description(
f"Generating (active: {active_samples }/{batch_size })"
)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# First step forwards the speech inputs; every later step feeds the
# embeddings computed at the end of the previous step instead.
if is_prefill:
prefill_inputs = {}
if speech_tensors is not None:
prefill_inputs["speech_tensors"] = speech_tensors.to(device=device)
if speech_masks is not None:
prefill_inputs["speech_masks"] = speech_masks.to(device)
if speech_input_mask is not None:
prefill_inputs["speech_input_mask"] = speech_input_mask.to(device)
is_prefill = False
else:
_ = model_inputs.pop("inputs_embeds", None)
prefill_inputs = {"inputs_embeds": inputs_embeds}
outputs = self(
**model_inputs,
**prefill_inputs,
logits_to_keep=1,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False
)
next_token_logits = outputs.logits[:, -1, :].to(
copy=True, dtype=torch.float32, device=input_ids.device
)
next_token_scores = logits_processor(input_ids, next_token_logits)
if generation_config.do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
# Samples that already finished keep emitting EOS.
next_tokens[finished_tags] = generation_config.eos_token_id
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
# With refresh_negative disabled, the negative branch is forwarded on
# every step (otherwise only when a diffusion token is emitted, below).
if not kwargs.get("refresh_negative", True):
negative_model_inputs = self.prepare_inputs_for_generation(
negative_input_ids, **negative_model_kwargs
)
if (
negative_model_inputs["inputs_embeds"] is None
and inputs_embeds is not None
):
negative_model_inputs["inputs_embeds"] = inputs_embeds
negative_model_inputs["input_ids"] = None
negative_outputs = self(
**negative_model_inputs,
logits_to_keep=0,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
negative_model_kwargs = self._update_model_kwargs_for_generation(
negative_outputs, negative_model_kwargs, is_encoder_decoder=False
)
negative_input_ids = torch.cat(
[negative_input_ids, next_tokens[:, None]], dim=-1
)
# Mark samples that emitted EOS for the first time on this step.
if (next_tokens == generation_config.eos_token_id).any():
eos_indices = (
(next_tokens == generation_config.eos_token_id)
.nonzero(as_tuple=False)
.squeeze(1)
)
new_eos_indices = eos_indices[~finished_tags[eos_indices]]
if new_eos_indices.numel() > 0:
finished_tags[new_eos_indices] = True
if verbose:
print(
f"Samples {new_eos_indices .tolist ()} reached EOS token at step {step +1 }.",
flush=True,
)
if audio_streamer is not None:
audio_streamer.end(new_eos_indices)
# Mark samples that exhausted their per-sample step budget.
max_length_reached = step >= max_step_per_sample
new_max_length_indices = torch.nonzero(
max_length_reached & ~finished_tags, as_tuple=False
).squeeze(1)
if new_max_length_indices.numel() > 0:
finished_tags[new_max_length_indices] = True
reach_max_step_sample[new_max_length_indices] = True
if verbose:
print(
f"Samples {new_max_length_indices .tolist ()} reached max generation length at step {step +1 }.",
flush=True,
)
if audio_streamer is not None:
audio_streamer.end(new_max_length_indices)
# A speech_end token resets both streaming caches for that sample.
diffusion_end_indices = (
(next_tokens == generation_config.speech_end_id)
.nonzero(as_tuple=False)
.squeeze(1)
)
if diffusion_end_indices.numel() > 0:
acoustic_cache.set_to_zero(diffusion_end_indices)
semantic_cache.set_to_zero(diffusion_end_indices)
# A speech_start token (with refresh_negative on) resets the negative
# branch for that sample: only the last position stays attended, the
# cache entry at position 0 is copied into the last slot, and the last
# negative input id becomes speech_start.
diffusion_start_indices = torch.arange(batch_size, device=device)[
~finished_tags & (next_tokens == generation_config.speech_start_id)
]
if diffusion_start_indices.numel() > 0 and kwargs.get(
"refresh_negative", True
):
for i, sample_idx in enumerate(diffusion_start_indices.tolist()):
negative_model_kwargs["attention_mask"][sample_idx, :] = 0
negative_model_kwargs["attention_mask"][sample_idx, -1] = 1
for layer_idx, (k_cache, v_cache) in enumerate(
zip(
negative_model_kwargs["past_key_values"].key_cache,
negative_model_kwargs["past_key_values"].value_cache,
)
):
for sample_idx in diffusion_start_indices.tolist():
k_cache[sample_idx, :, -1, :] = k_cache[
sample_idx, :, 0, :
].clone()
v_cache[sample_idx, :, -1, :] = v_cache[
sample_idx, :, 0, :
].clone()
for sample_idx in diffusion_start_indices.tolist():
negative_input_ids[sample_idx, -1] = (
generation_config.speech_start_id
)
# Default next-step embeddings come from the token embedding table;
# rows for diffusion samples are overwritten further down.
next_inputs_embeds = self.model.get_input_embeddings()(
next_tokens
).unsqueeze(1)
diffusion_indices = torch.arange(batch_size, device=device)[
~finished_tags & (next_tokens == generation_config.speech_diffusion_id)
]
if diffusion_indices.numel() > 0:
if kwargs.get("refresh_negative", True):
negative_model_inputs = self.prepare_inputs_for_generation(
negative_input_ids, **negative_model_kwargs
)
if (
negative_model_inputs["inputs_embeds"] is None
and inputs_embeds is not None
):
negative_model_inputs["inputs_embeds"] = inputs_embeds
negative_model_inputs["input_ids"] = None
negative_outputs = self(
**negative_model_inputs,
logits_to_keep=0,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
negative_model_kwargs = self._update_model_kwargs_for_generation(
negative_outputs,
negative_model_kwargs,
is_encoder_decoder=False,
)
negative_input_ids = torch.cat(
[negative_input_ids, next_tokens[:, None]], dim=-1
)
# For unfinished samples that did NOT emit a diffusion token on this
# step, shift the negative attention mask, KV cache and input ids one
# position to the right starting at correct_cnt[sample], zero the mask
# at that start position, then bump correct_cnt.
non_diffusion_mask = ~finished_tags & (
next_tokens != generation_config.speech_diffusion_id
)
if non_diffusion_mask.any():
non_diffusion_indices = torch.arange(batch_size, device=device)[
non_diffusion_mask
]
start_indices = correct_cnt[non_diffusion_indices]
seq_len = negative_model_kwargs["attention_mask"].shape[1]
for i, (sample_idx, start_idx) in enumerate(
zip(non_diffusion_indices.tolist(), start_indices.tolist())
):
if start_idx + 1 < seq_len - 1:
negative_model_kwargs["attention_mask"][
sample_idx, start_idx + 1 :
] = negative_model_kwargs["attention_mask"][
sample_idx, start_idx:-1
].clone()
negative_model_kwargs["attention_mask"][
sample_idx, start_idx
] = 0
for layer_idx, (k_cache, v_cache) in enumerate(
zip(
negative_model_kwargs["past_key_values"].key_cache,
negative_model_kwargs["past_key_values"].value_cache,
)
):
for sample_idx, start_idx in zip(
non_diffusion_indices.tolist(), start_indices.tolist()
):
if start_idx + 1 < k_cache.shape[2] - 1:
k_cache[sample_idx, :, start_idx + 1 :, :] = k_cache[
sample_idx, :, start_idx:-1, :
].clone()
v_cache[sample_idx, :, start_idx + 1 :, :] = v_cache[
sample_idx, :, start_idx:-1, :
].clone()
for sample_idx, start_idx in zip(
non_diffusion_indices.tolist(), start_indices.tolist()
):
if start_idx + 1 < negative_input_ids.shape[1] - 1:
negative_input_ids[sample_idx, start_idx + 1 :] = (
negative_input_ids[sample_idx, start_idx:-1].clone()
)
correct_cnt[non_diffusion_indices] += 1
# Sample a speech latent from the last hidden states of both branches,
# rescale it, and decode it to an audio chunk.
positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
negative_condition = negative_outputs.last_hidden_state[
diffusion_indices, -1, :
]
speech_latent = self.sample_speech_tokens(
positive_condition, negative_condition, cfg_scale=cfg_scale
).unsqueeze(1)
scaled_latent = speech_latent / self.model.speech_scaling_factor.to(
speech_latent.device
) - self.model.speech_bias_factor.to(speech_latent.device)
audio_chunk = self.model.acoustic_tokenizer.decode(
scaled_latent.to(self.model.acoustic_tokenizer.device),
cache=acoustic_cache,
sample_indices=diffusion_indices.to(
self.model.acoustic_tokenizer.device
),
use_cache=True,
debug=False,
)
for i, sample_idx in enumerate(diffusion_indices):
idx = sample_idx.item()
if not finished_tags[idx]:
audio_chunks[idx].append(audio_chunk[i])
if audio_streamer is not None:
audio_streamer.put(audio_chunk, diffusion_indices)
# Re-encode the decoded chunk into semantic features; the next-step
# embedding for these samples is acoustic + semantic connector output.
semantic_features = self.model.semantic_tokenizer.encode(
audio_chunk,
cache=semantic_cache,
sample_indices=diffusion_indices,
use_cache=True,
debug=False,
).mean
acoustic_embed = self.model.acoustic_connector(speech_latent)
semantic_embed = self.model.semantic_connector(semantic_features)
diffusion_embeds = acoustic_embed + semantic_embed
next_inputs_embeds[diffusion_indices] = diffusion_embeds
inputs_embeds = next_inputs_embeds
# Loop finished: close the streamer and stitch each sample's chunks.
if audio_streamer is not None:
audio_streamer.end()
final_audio_outputs = []
for sample_chunks in audio_chunks:
if sample_chunks:
concatenated_audio = torch.cat(sample_chunks, dim=-1)
final_audio_outputs.append(concatenated_audio)
else:
final_audio_outputs.append(None)
return QWEN3VoxGenerationOutput(
sequences=input_ids,
speech_outputs=final_audio_outputs if return_speech else None,
reach_max_step_sample=reach_max_step_sample,
)
@torch.no_grad()
def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
    """Denoise a random latent into speech latents using guided noise prediction.

    The positive and negative conditions are stacked along the batch axis so
    one prediction-head call yields both noise estimates per step. The two
    estimates are blended as ``uncond + cfg_scale * (cond - uncond)``.

    Args:
        condition: Conditioning rows for the positive branch.
        neg_condition: Conditioning rows for the negative branch.
        cfg_scale: Weight applied to the (cond - uncond) difference.

    Returns:
        The first half of the denoised batch, one row per positive condition.
    """
    scheduler = self.model.noise_scheduler
    head = self.model.prediction_head
    scheduler.set_timesteps(self.ddpm_inference_steps)

    stacked_cond = torch.cat([condition, neg_condition], dim=0).to(head.device)
    latent = torch.randn(stacked_cond.shape[0], self.config.acoustic_vae_dim).to(
        stacked_cond
    )

    for timestep in scheduler.timesteps:
        # Both halves of the batch are driven by the same latent.
        first_half = latent[: len(latent) // 2]
        doubled = torch.cat([first_half, first_half], dim=0)
        step_index = timestep.repeat(doubled.shape[0]).to(doubled)
        noise = head(doubled, step_index, condition=stacked_cond)
        cond_eps, uncond_eps = torch.split(noise, len(noise) // 2, dim=0)
        guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        noise = torch.cat([guided, guided], dim=0)
        latent = scheduler.step(noise, timestep, latent).prev_sample

    return latent[: len(latent) // 2]
# Register the inference class with AutoModelForCausalLM for QWEN3VoxConfig.
AutoModelForCausalLM.register(QWEN3VoxConfig, QWEN3VoxForConditionalGenerationInference)
# Names exported by this module.
__all__ = [
'QWEN3VoxForConditionalGenerationInference'
]
# Safetensors metadata key under which the auxiliary shard stores its JSON
# slot manifest (read by _materialize_latent_prompt_embeddings).
_AUX_SLOT_MANIFEST_K = "vv.pipeline.aux_slot_manifest"
# Handle key used as the fallback when a requested slice is not mounted.
DEFAULT_AUX_SLICE_ID = "male_mid_normal_adult_serious_formal_uk"
def _resolve_aux_coeff_tensor(
    handles: Dict[str, Any],
    slice_query: str,
    *,
    default_slice_id: str = DEFAULT_AUX_SLICE_ID,
) -> Tuple[Any, str, str, bool]:
    """Pick a handle for ``slice_query`` from ``handles``.

    Resolution order: exact key, then ``default_slice_id``, then the first key
    that contains (or is contained in) the query case-insensitively, then the
    first key of the mapping.

    Returns:
        ``(binding, key_used, stripped_query, used_default)`` where
        ``used_default`` is true only for the ``default_slice_id`` fallback.

    Raises:
        ValueError: If ``handles`` is empty.
    """
    wanted = slice_query.strip()

    # Exact hit first, then the configured default.
    for candidate, fell_back in ((wanted, False), (default_slice_id, True)):
        if candidate in handles:
            return (handles[candidate], candidate, wanted, fell_back)

    # Case-insensitive containment in either direction.
    needle = wanted.lower()
    for key, bound in handles.items():
        key_low = key.lower()
        if key_low in needle or needle in key_low:
            return (bound, key, wanted, False)

    if not handles:
        raise ValueError("empty auxiliary coefficient handle map")
    head_key = next(iter(handles.keys()))
    return (handles[head_key], head_key, wanted, False)
def _accum_tensor_key(slot_idx: int) -> str:
return f"model.decoder.aux_residual.accum.{slot_idx :04d}.u8_payload"
def _default_aux_shard_fp(repo_root: str) -> str:
return os.path.join(repo_root, "aux_lm_residual_projection.safetensors")
def _materialize_latent_prompt_embeddings(
    blob_fp: str | os.PathLike[str],
) -> Dict[str, Any]:
    """Decode every slot of an auxiliary safetensors shard into a float32 array.

    The shard metadata must carry a JSON manifest whose ``"order"`` entry
    lists the stem names; slot ``i`` is stored as a uint8 tensor whose bytes
    are handed to ``librosa.load`` as an in-memory file.

    NOTE(review): the sample rate returned by ``librosa.load`` is discarded,
    so callers presumably assume a fixed rate — confirm.

    Returns:
        Mapping of stem name to mono float32 samples.

    Raises:
        ValueError: If the manifest is missing or malformed, or a slot tensor
            is absent.
    """
    import librosa
    from safetensors import safe_open

    shard_path = os.fspath(blob_fp)
    decoded: Dict[str, Any] = {}
    with safe_open(shard_path, framework="np") as shard:
        metadata = shard.metadata()
        if not metadata or _AUX_SLOT_MANIFEST_K not in metadata:
            raise ValueError(
                "missing auxiliary slot manifest (not an LM projection safetensors shard)"
            )
        try:
            slot_manifest = json.loads(metadata[_AUX_SLOT_MANIFEST_K])
            stem_order: List[str] = list(slot_manifest["order"])
        except (json.JSONDecodeError, KeyError, TypeError) as exc:
            raise ValueError("corrupt auxiliary slot manifest") from exc

        present = set(shard.keys())
        for i, stem in enumerate(stem_order):
            tk = _accum_tensor_key(i)
            if tk not in present:
                raise ValueError(f"missing tensor payload for slot {i }: {tk }")
            payload = np.asarray(shard.get_tensor(tk), dtype=np.uint8).tobytes()
            samples, _native_sr = librosa.load(io.BytesIO(payload), sr=None, mono=True)
            decoded[stem] = np.asarray(samples, dtype=np.float32)
    return decoded
# Dialogue role label; the code points below spell "Speaker".
_MODEL_DIALOGUE_ROLE_MARK = "".join(
(chr(_o) for _o in (83, 112, 101, 97, 107, 101, 114))
)
# Regex for a turn header line "<role mark> <digits>: <rest>"; group 1 is the
# digits, group 2 the rest of the line.
_LANE_HEAD_PATTERN = (
rf"^{re.escape(_MODEL_DIALOGUE_ROLE_MARK)}\s+(\d+):\s*(.*)$"
)
# Directory name ("voices") under the repo root that is scanned for .wav files.
_COEFF_STAGE_SUBDIR = "".join(("vo", "ices"))
class _QxResidualFabric:
    """Registry of conditioning handles keyed by stem name.

    Handles are mounted from an auxiliary safetensors shard when one is found
    (values are decoded arrays); otherwise from ``.wav`` files in the staging
    directory under the repo root (values are file paths).
    """

    def __init__(
        self,
        repo_root: str | os.PathLike[str],
        *,
        aux_projection_shard_fp: str | None = None,
        skip_aux_shard: bool = False,
    ):
        """Mount handles for ``repo_root`` and register short aliases.

        Args:
            repo_root: Repository directory; stored as an absolute path.
            aux_projection_shard_fp: Explicit shard path tried before the
                environment variable and the default location.
            skip_aux_shard: When true, only the staging directory is scanned.
        """
        self._repo_root = os.path.abspath(os.fspath(repo_root))
        self._discrete_coeff_root = os.path.join(self._repo_root, _COEFF_STAGE_SUBDIR)
        self._r_handles: Dict[str, Union[str, np.ndarray]] = {}
        self._fabric_refresh_handles(
            aux_projection_shard_fp=aux_projection_shard_fp,
            skip_aux_shard=skip_aux_shard,
        )

        # Stems containing "-" also get a short alias: the text before the
        # first "_", reduced to what follows its last "-".
        with_aliases: Dict[str, Union[str, np.ndarray]] = {}
        for stem, bound in self._r_handles.items():
            with_aliases[stem] = bound
            if "-" in stem:
                short = stem.split("_", 1)[0].split("-")[-1]
                with_aliases[short] = bound
        self._r_handles.update(with_aliases)

    def _fabric_refresh_handles(
        self, *, aux_projection_shard_fp: str | None, skip_aux_shard: bool
    ) -> None:
        """Rebuild ``self._r_handles`` from the shard or the staging directory.

        Raises:
            ValueError: If a shard is found but cannot be decoded.
        """
        self._r_handles.clear()

        # First existing file among: explicit path, VV_AUX_PROJECTION_PATH,
        # default shard location.
        _blob_fp = None
        if not skip_aux_shard:
            for candidate in (
                (aux_projection_shard_fp or "").strip(),
                os.environ.get("VV_AUX_PROJECTION_PATH") or "",
                _default_aux_shard_fp(self._repo_root),
            ):
                if candidate and os.path.isfile(candidate):
                    _blob_fp = candidate
                    break

        if _blob_fp:
            try:
                decoded = _materialize_latent_prompt_embeddings(_blob_fp)
            except ValueError as _vx:
                raise ValueError(
                    f"AUX shard assembly failed ({_blob_fp }): {_vx }"
                ) from _vx
            self._r_handles = dict(sorted(decoded.items()))
            print(
                f"Mounted auxiliary LM projection shard ({len (self ._r_handles )} tensors): {_blob_fp }"
            )
            print(f"Residual routing keys: {', '.join (self ._r_handles .keys ())}")
            return

        stage_dir = self._discrete_coeff_root
        if not os.path.exists(stage_dir):
            print(
                f"Warning: coefficient directory missing at {self ._discrete_coeff_root }"
            )
            return

        for entry in os.listdir(stage_dir):
            full_path = os.path.join(stage_dir, entry)
            if entry.lower().endswith(".wav") and os.path.isfile(full_path):
                self._r_handles[os.path.splitext(entry)[0]] = full_path

        # Keep only path-valued entries that still exist, ordered by key.
        surviving = (
            pair
            for pair in self._r_handles.items()
            if isinstance(pair[1], str) and os.path.exists(pair[1])
        )
        self._r_handles = dict(sorted(surviving))
        print(
            f"Discrete coefficient files staged: {len (self ._r_handles )} under {self ._discrete_coeff_root }"
        )
        print(f"Residual routing keys: {', '.join (self ._r_handles .keys ())}")

    def _fabric_pick_residual_snapshot(
        self, shard_slice_query: str
    ) -> Union[str, np.ndarray]:
        """Return the handle resolved for ``shard_slice_query``.

        Prints a warning when the default slice had to be substituted.

        Raises:
            ValueError: If no handles are mounted.
        """
        if not self._r_handles:
            raise ValueError(
                f"No residual handles mounted. Add WAV files under {_COEFF_STAGE_SUBDIR }/ at the repo root, place aux_lm_residual_projection.safetensors next to config.json, or set VV_AUX_PROJECTION_PATH / VOCENCE_AUX_PROJECTION_SHARD."
            )
        resolved = _resolve_aux_coeff_tensor(self._r_handles, shard_slice_query)
        _binding, _used_key, _req_norm, _used_default = resolved
        if _used_default:
            print(
                f"Warning: auxiliary slice '{_req_norm }' not in shard; using default '{_used_key }'."
            )
        return _binding
def _partition_lm_conditioning_manifest(
    raw_manifest_txt: str,
) -> Tuple[List[str], List[str]]:
    """Split a script into per-turn strings and their lane ids.

    A line matching the turn-header pattern starts a new turn; any other
    non-blank line is appended (space-joined) to the current payload. A turn
    is emitted only when it has both a lane id and a non-empty payload.

    Returns:
        ``(turns, lane_ids)``: parallel lists, with each turn re-serialised as
        ``"<role mark> <lane id>: <payload>"``.
    """
    _serialized_turns: List[str] = []
    _routing_lane_ids: List[str] = []
    _active_lane_id: str | None = None
    _lane_payload_accum = ""

    def _emit_pending() -> None:
        # Reads the enclosing state at call time; skips incomplete turns.
        if _active_lane_id and _lane_payload_accum:
            _serialized_turns.append(
                f"{_MODEL_DIALOGUE_ROLE_MARK } {_active_lane_id }: {_lane_payload_accum .strip ()}"
            )
            _routing_lane_ids.append(_active_lane_id)

    for raw_line in raw_manifest_txt.strip().split("\n"):
        text = raw_line.strip()
        if not text:
            continue
        header = re.match(_LANE_HEAD_PATTERN, text, re.IGNORECASE)
        if header is None:
            _lane_payload_accum = (
                f"{_lane_payload_accum} {text}" if _lane_payload_accum else text
            )
            continue
        _emit_pending()
        _active_lane_id = header.group(1).strip()
        _lane_payload_accum = header.group(2).strip()

    _emit_pending()
    return (_serialized_turns, _routing_lane_ids)
def _parse_instruction_params(instruction: str) -> Dict[str, str]:
params: Dict[str, str] = {}
for part in instruction.strip().strip("|").split("|"):
if ":" not in part:
continue
key, value = part.split(":", 1)
params[key.strip().lower()] = value.strip()
return params
# Field order of an aux slice slug. A slug joins one value per field with "_":
# gender_pitch_speed_age_group_emotion_tone_accent
# NOTE(review): _parse_slice_slug splits on "_" and requires exactly one part
# per field, so a value that itself contains "_" (e.g. "young_adult") makes
# the slug unparseable — confirm shard keys never use such values.
_SLICE_SLUG_FIELDS: Tuple[str, ...] = (
"gender",
"pitch",
"speed",
"age_group",
"emotion",
"tone",
"accent",
)
# When the composed slug is not among the available keys, candidates are
# scored by per-field matches in this priority order (highest weight first).
_SLICE_MATCH_WEIGHT_ORDER: Tuple[str, ...] = (
"gender",
"emotion",
"accent",
"speed",
"age_group",
"tone",
"pitch",
)
# Instruction keys treated as structured prosody: the slug fields plus "age".
_STRUCTURED_PROSODY_KEYS = frozenset(_SLICE_SLUG_FIELDS) | frozenset({"age"})
# Weight for priority position i is 1 << (28 - 4 * i); each step is 16x the
# next, so one higher-priority match outweighs all lower-priority ones combined.
_SLICE_MATCH_WEIGHTS: Tuple[int, ...] = tuple(
1 << (28 - i * 4) for i in range(len(_SLICE_MATCH_WEIGHT_ORDER))
)
def _norm_prosody_token(s: str) -> str:
return s.strip().lower().replace(" ", "_")
def _parse_slice_slug(slice_id: str) -> Optional[Dict[str, str]]:
    """Split a slug into its fields, or return ``None`` if it does not fit.

    A slug fits when, after trimming, it is non-empty and has exactly one
    ``"_"``-separated part per entry of ``_SLICE_SLUG_FIELDS``. Each part is
    normalised with ``_norm_prosody_token``.
    """
    trimmed = slice_id.strip()
    if not trimmed:
        return None
    pieces = trimmed.split("_")
    if len(pieces) != len(_SLICE_SLUG_FIELDS):
        return None
    fields: Dict[str, str] = {}
    for field_name, piece in zip(_SLICE_SLUG_FIELDS, pieces):
        fields[field_name] = _norm_prosody_token(piece)
    return fields
def _attrs_to_slice_slug(attrs: Dict[str, str]) -> str:
    """Join the normalised values of ``attrs`` into a slug, in slug-field order.

    Raises:
        KeyError: If ``attrs`` lacks any slug field.
    """
    ordered_values = [_norm_prosody_token(attrs[name]) for name in _SLICE_SLUG_FIELDS]
    return "_".join(ordered_values)
def _default_slice_attrs() -> Dict[str, str]:
    """Return a fresh attribute dict parsed from ``DEFAULT_AUX_SLICE_ID``.

    If the default id does not parse as a slug, every field maps to ``""``.
    """
    from_default = _parse_slice_slug(DEFAULT_AUX_SLICE_ID)
    if from_default is None:
        return {name: "" for name in _SLICE_SLUG_FIELDS}
    return dict(from_default)
def _instruction_has_structured_prosody(p: Dict[str, str]) -> bool:
    """Return whether any key of ``p`` names a structured prosody field.

    Keys are compared lower-cased, with ``"age"`` treated as ``"age_group"``.
    """

    def _canonical(raw_key: str) -> str:
        lowered = raw_key.lower()
        return "age_group" if lowered == "age" else lowered

    return any(_canonical(key) in _STRUCTURED_PROSODY_KEYS for key in p)
def _instruction_prosody_attrs(p: Dict[str, str]) -> Dict[str, str]:
    """Overlay the slug fields found in ``p`` onto the default attributes.

    Blank values and keys that are not slug fields are ignored; ``"age"`` is
    accepted as an alias of ``"age_group"``. Values are normalised.
    """
    merged = _default_slice_attrs()
    for raw_key, raw_value in p.items():
        if not raw_value.strip():
            continue
        field_name = raw_key.lower()
        if field_name == "age":
            field_name = "age_group"
        if field_name in _SLICE_SLUG_FIELDS:
            merged[field_name] = _norm_prosody_token(raw_value)
    return merged
def _pick_best_aux_slice_key(
    desired_attrs: Dict[str, str], available_keys: AbstractSet[str]
) -> str:
    """Choose the available key that best matches ``desired_attrs``.

    An exact slug match wins outright. Otherwise parseable keys are scored by
    summing the weight of every field that equals the desired value; the
    highest score wins and ties go to the lexicographically smallest key. If
    no key parses as a slug, the smallest key is returned, or
    ``DEFAULT_AUX_SLICE_ID`` when there are no keys at all.
    """
    exact_slug = _attrs_to_slice_slug(desired_attrs)
    if exact_slug in available_keys:
        return exact_slug

    attempted = ((key, _parse_slice_slug(key)) for key in available_keys)
    candidates: List[Tuple[str, Dict[str, str]]] = [
        (key, fields) for key, fields in attempted if fields is not None
    ]
    if not candidates:
        return min(available_keys) if available_keys else DEFAULT_AUX_SLICE_ID

    def _rank(candidate: Tuple[str, Dict[str, str]]) -> Tuple[int, str]:
        key, fields = candidate
        score = sum(
            weight
            for field, weight in zip(_SLICE_MATCH_WEIGHT_ORDER, _SLICE_MATCH_WEIGHTS)
            if desired_attrs.get(field) == fields.get(field)
        )
        # Negated score so min() prefers the highest score, then smallest key.
        return (-score, key)

    return min(candidates, key=_rank)[0]
def _natural_language_prosody_attrs(instruction: str) -> Optional[Dict[str, str]]:
    """Best-effort map of a natural-language instruction to aux slice fields.

    Keywords are matched at word starts, so inflections such as "slowly" or
    "calmly" still count but a keyword buried inside another word does not
    ("slow" is not "low", "informal" is not "formal", "unhappy" is not
    "happy"). The accent abbreviations "us" and "uk" must be whole words.

    Returns:
        A full attribute dict (defaults filled in for emotion and tone when no
        keyword is found), or ``None`` for a blank instruction.
    """
    low = instruction.lower()
    if not low.strip():
        return None
    attrs = _default_slice_attrs()

    def _has(*words: str) -> bool:
        # Every word must occur starting at a word boundary.
        return all(re.search(r"\b" + re.escape(w), low) is not None for w in words)

    def _has_exact(word: str) -> bool:
        # Whole-word match; also catches "US." or "(UK)".
        return re.search(r"\b" + re.escape(word) + r"\b", low) is not None

    if _has("female"):
        attrs["gender"] = "female"
    elif _has("male"):
        attrs["gender"] = "male"
    else:
        attrs["gender"] = "neutral"

    if _has("low", "pitch"):
        attrs["pitch"] = "low"
    elif _has("high", "pitch"):
        attrs["pitch"] = "high"
    else:
        attrs["pitch"] = "mid"

    if _has("slow"):
        attrs["speed"] = "slow"
    elif _has("fast"):
        attrs["speed"] = "fast"
    else:
        attrs["speed"] = "normal"

    if _has("child"):
        attrs["age_group"] = "child"
    elif _has("senior") or _has("elderly"):
        attrs["age_group"] = "senior"
    elif _has("young"):
        attrs["age_group"] = "young_adult"
    else:
        attrs["age_group"] = "adult"

    # First keyword in list order wins; otherwise the default value stays.
    for emo in ("happy", "sad", "angry", "calm", "excited", "serious", "fearful", "neutral"):
        if _has(emo):
            attrs["emotion"] = emo
            break
    for tone in ("warm", "cold", "friendly", "formal", "casual", "authoritative"):
        if _has(tone):
            attrs["tone"] = tone
            break

    if _has("american") or _has_exact("us"):
        attrs["accent"] = "us"
    elif _has("british") or _has_exact("uk"):
        attrs["accent"] = "uk"
    elif _has("australian"):
        attrs["accent"] = "au"
    elif _has("indian"):
        attrs["accent"] = "in"
    else:
        attrs["accent"] = "neutral"
    return attrs
def _prosody_shard_tags_for_lanes(
    instruction: str,
    unique_lanes: List[str],
    *,
    aux_slice_keys: Optional[AbstractSet[str]] = None,
) -> Dict[str, str]:
    """Assign a slice tag to every lane based on the instruction.

    Tag sources, in priority order: an explicit ``prosody`` / ``shards`` /
    ``prosody_shards`` list, a ``speakers`` list, a single ``voice`` /
    ``speaker`` value, structured prosody fields, and finally keyword parsing
    of the free-text instruction. When there are fewer tags than lanes the
    last tag is repeated.
    """

    def _split_csv(raw_list: str) -> List[str]:
        return [piece.strip() for piece in raw_list.split(",") if piece.strip()]

    def _tag_for(attrs: Dict[str, str]) -> str:
        if aux_slice_keys:
            return _pick_best_aux_slice_key(attrs, aux_slice_keys)
        return _attrs_to_slice_slug(attrs)

    params = _parse_instruction_params(instruction)
    if any(name in params for name in ("prosody", "shards", "prosody_shards")):
        explicit = (
            params.get("prosody")
            or params.get("shards")
            or params.get("prosody_shards")
            or ""
        )
        tags = _split_csv(explicit)
    elif "speakers" in params:
        tags = _split_csv(params["speakers"])
    elif params.get("voice") or params.get("speaker"):
        tags = [(params.get("voice") or params.get("speaker") or "").strip()]
    elif _instruction_has_structured_prosody(params):
        tags = [_tag_for(_instruction_prosody_attrs(params))]
    else:
        nl_attrs = _natural_language_prosody_attrs(instruction)
        tags = [_tag_for(nl_attrs)] if nl_attrs else [DEFAULT_AUX_SLICE_ID]

    if not tags:
        tags = [DEFAULT_AUX_SLICE_ID]

    # Pad with the last tag so every lane has one.
    shortfall = len(unique_lanes) - len(tags)
    if shortfall > 0:
        tags.extend([tags[-1]] * shortfall)
    return dict(zip(unique_lanes, tags))
def _manifest_from_text(text: str) -> str:
stripped = text.strip()
if re.search("^Speaker\\s+\\d+:", stripped, re.MULTILINE | re.IGNORECASE):
return stripped
return f"Speaker 1: {stripped }"
def _build_prefill_slices(
fabric: _QxResidualFabric,
routing_lane_ids: List[str],
lane_to_slice_tag: Dict[str, str],
) -> List[Union[str, np.ndarray]]:
unique_lanes: List[str] = []
seen: set[str] = set()
for lane in routing_lane_ids:
if lane not in seen:
unique_lanes.append(lane)
seen.add(lane)
out: List[Union[str, np.ndarray]] = []
for lane in unique_lanes:
slice_tag = lane_to_slice_tag.get(lane, f"lane_{lane }")
out.append(fabric._fabric_pick_residual_snapshot(slice_tag))
return out
class Miner:
# File that must exist in the repo root; __init__ raises FileNotFoundError otherwise.
REPO_SENTINEL = "config.json"
# YAML settings file read from the repo root by model_name and settings.
SETTINGS_FILE = "vocence_config.yaml"
# Seconds warmup() waits for its worker thread before raising.
WARMUP_TIMEOUT = 240.0
def __init__(self, path_hf_repo: Path) -> None:
    """Validate the repo, mount conditioning handles and read env overrides.

    Environment variables read: ``VOCENCE_AUX_PROJECTION_SHARD``,
    ``VOCENCE_PREFER_DISCRETE_COEFF_DIR``, ``VOCENCE_SEED``,
    ``VOCENCE_CFG_SCALE`` (default 1.3) and ``VOCENCE_DISABLE_PREFILL``.

    Raises:
        FileNotFoundError: If the sentinel file is missing from the repo.
        ValueError: If the settings file lacks ``model_name``.
        RuntimeError: If no conditioning handles could be mounted.
    """

    def _env_flag(name: str) -> bool:
        return os.environ.get(name, "").lower() in ("1", "true", "yes")

    self.root = Path(path_hf_repo).resolve()
    sentinel_path = self.root / self.REPO_SENTINEL
    if not sentinel_path.is_file():
        raise FileNotFoundError(
            f"{self.REPO_SENTINEL} not present in {self.root}"
        )
    # Touch the cached property so a missing model_name fails fast.
    _ = self.model_name

    shard_override = os.environ.get("VOCENCE_AUX_PROJECTION_SHARD", "").strip()
    self._fabric_q = _QxResidualFabric(
        str(self.root),
        aux_projection_shard_fp=shard_override or None,
        skip_aux_shard=_env_flag("VOCENCE_PREFER_DISCRETE_COEFF_DIR"),
    )
    if not self._fabric_q._r_handles:
        raise RuntimeError(
            "No auxiliary conditioning handles mounted in repo; set VV_AUX_PROJECTION_PATH / VOCENCE_AUX_PROJECTION_SHARD, or ship aux_lm_residual_projection.safetensors at the repo root."
        )

    seed_text = os.environ.get("VOCENCE_SEED", "").strip()
    if seed_text:
        seed_value = int(seed_text)
        torch.manual_seed(seed_value)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed_value)

    self._cfg_scale = float(os.environ.get("VOCENCE_CFG_SCALE", "1.3"))
    self._disable_prefill = _env_flag("VOCENCE_DISABLE_PREFILL")
    # Placeholders overwritten once the model is instantiated.
    self._processor: Optional[QWEN3VoxProcessor] = None
    self._device: str = "cpu"
    self._sample_rate: int = 22050
def __repr__(self) -> str:
return f"<Miner root={self.root.name}>"
@cached_property
def model_name(self) -> str:
raw = self._load_yaml(self.root / self.SETTINGS_FILE)
name = str(raw.get("model_name") or "").strip()
if not name:
raise ValueError("vocence_config.yaml missing model_name")
return name
@cached_property
def settings(self) -> SimpleNamespace:
raw = self._load_yaml(self.root / self.SETTINGS_FILE)
rt = raw.get("runtime") or {}
gen = raw.get("generation") or {}
lim = raw.get("limits") or {}
return SimpleNamespace(
language=str(
lim.get("default_language")
or rt.get("default_language")
or "English"
),
sample_rate=int(gen.get("sample_rate", 24000)),
max_instruction_chars=int(lim.get("max_instruction_chars", 600)),
max_text_chars=int(lim.get("max_text_chars", 2000)),
prefer_cuda=str(rt.get("device_preference", "cuda")).lower() == "cuda",
prefer_bf16=str(rt.get("dtype", "bfloat16")).lower() == "bfloat16",
prefer_flash=bool(rt.get("use_flash_attention_2", False)),
)
@cached_property
def model(self) -> QWEN3VoxForConditionalGenerationInference:
return self._instantiate_engine()
def _instantiate_engine(self) -> QWEN3VoxForConditionalGenerationInference:
    """Pick a device, load the processor and model, and prepare for inference.

    Device choice: CUDA when preferred and available, else MPS when
    available, else CPU. Only CUDA may use bfloat16 and tries two attention
    implementations (order set by ``prefer_flash``); other devices use
    float32 with ``sdpa``.

    Side effects: sets ``self._device``, ``self._processor`` and
    ``self._sample_rate``; applies LoRA assets when
    ``VOCENCE_CHECKPOINT_PATH`` is set.

    Raises:
        RuntimeError: If every attention implementation fails to load. The
            last underlying exception is attached as ``__cause__``.
    """
    s = self.settings
    model_name = self.model_name

    if s.prefer_cuda and torch.cuda.is_available():
        self._device = "cuda"
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        self._device = "mps"
    else:
        self._device = "cpu"

    if self._device == "cuda":
        load_dtype = torch.bfloat16 if s.prefer_bf16 else torch.float32
        attn_attempts = (
            ("flash_attention_2", "sdpa")
            if s.prefer_flash
            else ("sdpa", "flash_attention_2")
        )
    else:
        # MPS and CPU share the same conservative configuration.
        load_dtype = torch.float32
        attn_attempts = ("sdpa",)

    self._processor = QWEN3VoxProcessor.from_pretrained(model_name)

    last_failure: Optional[BaseException] = None
    engine: Optional[QWEN3VoxForConditionalGenerationInference] = None
    for attn_impl in attn_attempts:
        try:
            engine = self._load_model_weights(model_name, load_dtype, attn_impl)
        except Exception as exc:
            # Best-effort fallback: report the failure, try the next backend.
            last_failure = exc
            print(f"[Miner] QWEN3Vox load failed :: attn={attn_impl} :: {exc!r}")
            continue
        dtype_tag = "bf16" if load_dtype is torch.bfloat16 else "fp32"
        print(
            f"[Miner] QWEN3Vox ready :: device={self._device} "
            f"dtype={dtype_tag} attn={attn_impl}"
        )
        break
    if engine is None:
        raise RuntimeError(
            f"QWEN3Vox failed to load :: {last_failure!r}"
        ) from last_failure

    ckpt = os.environ.get("VOCENCE_CHECKPOINT_PATH", "").strip()
    if ckpt:
        load_lora_assets(engine, ckpt)
    engine.train(False)
    engine.set_ddpm_inference_steps(num_steps=10)

    # Prefer the processor's sampling rate; fall back to the configured one
    # when the processor reports a non-positive value.
    proc_sr = int(
        getattr(self._processor.audio_processor, "sampling_rate", 22050)
    )
    self._sample_rate = proc_sr if proc_sr > 0 else s.sample_rate
    return engine
def _load_model_weights(
    self, model_name: str, load_dtype: torch.dtype, attn_impl: str
) -> QWEN3VoxForConditionalGenerationInference:
    """Load the model for ``self._device`` with the given dtype and attention.

    On MPS the weights are loaded without a device map and moved with
    ``.to("mps")``; on CUDA and CPU the device map places them directly.
    """
    targets_mps = self._device == "mps"
    if targets_mps:
        placement = None
    elif self._device == "cuda":
        placement = "cuda"
    else:
        placement = "cpu"

    loaded = QWEN3VoxForConditionalGenerationInference.from_pretrained(
        model_name,
        torch_dtype=load_dtype,
        attn_implementation=attn_impl,
        device_map=placement,
    )
    if targets_mps:
        loaded.to("mps")
    return loaded
def warmup(self) -> None:
outcome: dict[str, Any] = {"done": False, "err": None}
def _trial() -> None:
try:
self.generate_wav(
instruction=(
"An adult male with a neutral British accent, speaking at a "
"normal pace in a mid-range pitch, sounding calm and formal."
),
text="This is a warmup utterance for the QWEN3Vox engine.",
)
outcome["done"] = True
except Exception as exc:
outcome["err"] = repr(exc)
worker = threading.Thread(target=_trial, daemon=True)
worker.start()
worker.join(timeout=self.WARMUP_TIMEOUT)
if not outcome["done"]:
raise RuntimeError(
f"warmup did not complete within {self.WARMUP_TIMEOUT}s: "
f"{outcome['err'] or 'no completion signal'}"
)
@staticmethod
def _load_yaml(path: Path) -> dict[str, Any]:
if not path.is_file():
return {}
from yaml import safe_load
with path.open("r", encoding="utf-8") as fh:
return safe_load(fh) or {}
def _speech_tensor_to_numpy(self, speech: torch.Tensor) -> np.ndarray:
t = speech.detach().cpu().float()
while t.dim() > 1:
t = t.squeeze(0)
if t.dim() != 1:
t = t.reshape(-1)
return t.numpy().astype(np.float32, copy=False)
def generate_wav(self, instruction: str, text: str) -> Tuple[np.ndarray, int]:
    """Synthesize audio. Pass validator text and instruction verbatim (length caps only).

    - instruction → processor system_prompt (tokenized as-is; no NL parsing or rewriting)
    - text → script body (plain transcript or existing Speaker N: lines; no wrapping)
    - default aux shard for acoustic prefill only (not derived from instruction text)

    Args:
        instruction: Used as the processor's ``system_prompt`` for this call.
            Truncated to ``settings.max_instruction_chars`` when that cap is
            positive. A blank or whitespace-only instruction selects the
            processor's default system prompt instead.
        text: Script passed to the processor. Truncated to
            ``settings.max_text_chars`` when that cap is positive.

    Returns:
        ``(wav, sample_rate)``: the first speech output converted by
        ``_speech_tensor_to_numpy``, and ``self._sample_rate``.

    Raises:
        RuntimeError: If the processor is not initialized, or the model
            returns no speech output.
    """
    s = self.settings
    if s.max_instruction_chars > 0 and len(instruction) > s.max_instruction_chars:
        instruction = instruction[: s.max_instruction_chars]
    if s.max_text_chars > 0 and len(text) > s.max_text_chars:
        text = text[: s.max_text_chars]

    # Read self.model before self._processor. NOTE(review): the error message
    # below implies that loading the model is what populates the processor;
    # confirm against the loader before reordering these two lines.
    inference_model = self.model
    processor = self._processor
    if processor is None:
        raise RuntimeError("processor not initialized after model load")

    # Remember the processor's original system prompt the first time through
    # so that it can be restored after every call.
    default_system = getattr(processor, "_vocence_default_system_prompt", None)
    if default_system is None:
        default_system = processor.system_prompt
        processor._vocence_default_system_prompt = default_system

    try:
        # Everything from the prompt override onwards runs under the
        # try/finally so the processor is restored even if the prefill
        # lookup or the processor call raises.
        if instruction.strip():
            processor.system_prompt = instruction
        else:
            processor.system_prompt = default_system
        prefill = self._fabric_q._fabric_pick_residual_snapshot(DEFAULT_AUX_SLICE_ID)
        inputs = processor(
            text=[text],
            voice_samples=[[prefill]],
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        target = self._device
        # Only existing keys are reassigned, so the mapping does not change
        # size while it is being iterated.
        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to(target)
        with torch.inference_mode():
            outputs = inference_model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=self._cfg_scale,
                tokenizer=processor.tokenizer,
                generation_config={"do_sample": False},
                verbose=False,
                is_prefill=not self._disable_prefill,
            )
        if not outputs.speech_outputs or outputs.speech_outputs[0] is None:
            raise RuntimeError("QWEN3Vox returned no speech output.")
        wav = self._speech_tensor_to_numpy(outputs.speech_outputs[0])
        return (wav, self._sample_rate)
    finally:
        processor.system_prompt = default_system