trainer-13 / miner.py
might2901's picture
Upload folder using huggingface_hub
c6a64b8 verified
from __future__ import annotations
import io
import json
import os
import re
import sys
import threading
import traceback
from pathlib import Path
from typing import AbstractSet, Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from transformers.utils import logging as hf_logging
import math
import random
import warnings
from dataclasses import dataclass
try:
import librosa
except Exception:
librosa = None
try:
import resampy
except Exception:
resampy = None
def _resample_if_needed(wav: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Return ``wav`` as float32, resampled to ``target_sr`` when the rates differ.

    Backends are tried in order: ``resampy`` first, then ``librosa``. When
    neither module could be imported, a ``RuntimeWarning`` is emitted and the
    samples are returned without resampling (only cast to float32).

    Args:
        wav: Input samples.
        orig_sr: Sample rate of ``wav``.
        target_sr: Desired sample rate.

    Returns:
        A float32 array; the very same object when no cast or resampling is needed.
    """
    if orig_sr != target_sr:
        if resampy is not None:
            return resampy.resample(wav.astype(np.float32), orig_sr, target_sr)
        if librosa is not None:
            return librosa.resample(
                y=wav.astype(np.float32), orig_sr=orig_sr, target_sr=target_sr
            )
        # No backend available: warn and fall through to the plain cast below.
        warnings.warn(
            "No resampler available; treating audio as target_sr without resampling. Install resampy or librosa.",
            RuntimeWarning,
        )
    return wav.astype(np.float32, copy=False)
class QWEN3VoxDataset:
    """Thin indexable wrapper that yields ``text`` / ``audio`` / ``voice_prompts`` dicts.

    Each item of the wrapped ``dataset`` is read by column name. When the item
    carries a usable voice prompt it is passed through (wrapped in a list if it
    is not already one). Otherwise a random crop of the item's own audio is
    used as the prompt; if that fails, ``voice_prompts`` is ``None``.
    """

    def __init__(
        self,
        dataset: Any,
        text_column: str = "text",
        audio_column: str = "audio",
        voice_prompts_column: Optional[str] = "voice_prompts",
    ) -> None:
        self.dataset = dataset
        self.text_column = text_column
        self.audio_column = audio_column
        self.voice_prompts_column = voice_prompts_column

    def __len__(self) -> int:
        return len(self.dataset)

    @staticmethod
    def _is_prompt_provided(prompt: Any) -> bool:
        """Tell whether ``prompt`` holds something usable as a voice prompt.

        Plain truthiness is not safe here: multi-element numpy arrays and torch
        tensors raise on ``bool(...)``. Sized objects count as provided when
        they are non-empty; unsized objects fall back to truthiness.
        """
        if prompt is None:
            return False
        try:
            return len(prompt) > 0
        except TypeError:
            # Unsized objects (e.g. 0-d arrays, scalars) keep plain truthiness.
            return bool(prompt)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return ``{"text", "audio", "voice_prompts"}`` for the item at ``idx``."""
        item = self.dataset[idx]
        data: Dict[str, Any] = {}
        data["text"] = item[self.text_column]
        data["audio"] = item[self.audio_column]

        user_provided_prompt = None
        if self.voice_prompts_column and self.voice_prompts_column in item:
            user_provided_prompt = item[self.voice_prompts_column]

        if self._is_prompt_provided(user_provided_prompt):
            if isinstance(user_provided_prompt, list):
                data["voice_prompts"] = user_provided_prompt
            else:
                data["voice_prompts"] = [user_provided_prompt]
            return data

        # No prompt supplied: crop a random window out of the target audio.
        try:
            target_sr = 22050
            wav_array = _load_audio_to_24k(
                item[self.audio_column], target_sr=target_sr
            )
            audio_len_seconds = len(wav_array) / target_sr
            # Window length is drawn from [min(5s, len/4), min(15s, len/2)].
            min_len_sec = min(5.0, audio_len_seconds / 4.0)
            max_len_sec = min(15.0, audio_len_seconds / 2.0)
            if min_len_sec > max_len_sec:
                min_len_sec = max_len_sec
            max_len_sec = min(max_len_sec, audio_len_seconds)
            if max_len_sec > 0.1:
                prompt_len_sec = random.uniform(min_len_sec, max_len_sec)
                prompt_len_samples = int(prompt_len_sec * target_sr)
                max_start_sample = len(wav_array) - prompt_len_samples
                start_sample = random.randint(0, max_start_sample)
                prompt_crop = wav_array[
                    start_sample : start_sample + prompt_len_samples
                ]
                data["voice_prompts"] = [prompt_crop]
            else:
                # Audio too short to yield a meaningful crop.
                data["voice_prompts"] = None
        except Exception as e:
            # Best effort: a failed crop must not break dataset iteration.
            warnings.warn(f"Could not create voice prompt for item {idx }: {e }")
            data["voice_prompts"] = None
        return data
def _apply_silence_with_crossfade(
    wav: np.ndarray,
    *,
    sample_rate: int,
    pre_silence_sec: float = 0.25,
    pre_crossfade_sec: float = 0.25,
    post_crossfade_sec: float = 0.25,
    post_silence_sec: float = 0.75,
) -> np.ndarray:
    """Pad a waveform with silence and linearly fade its edges in and out.

    The input is flattened to 1-D float32. Its first ``pre_crossfade_sec``
    seconds are scaled by a 0 -> 1 ramp and its last ``post_crossfade_sec``
    seconds by a 1 -> 0 ramp (the fade-out never overlaps the fade-in). Zero
    padding of ``pre_silence_sec`` / ``post_silence_sec`` is then added
    before / after. An empty input yields only the silence padding.

    Returns:
        A 1-D float32 array.
    """
    samples = np.asarray(wav, dtype=np.float32).reshape(-1)

    def _to_samples(seconds: float) -> int:
        return int(round(seconds * sample_rate))

    lead_pad = _to_samples(pre_silence_sec)
    tail_pad = _to_samples(post_silence_sec)
    fade_in_request = _to_samples(pre_crossfade_sec)
    fade_out_request = _to_samples(post_crossfade_sec)

    n_total = samples.shape[0]
    if n_total == 0:
        padding = [
            np.zeros(count, dtype=np.float32)
            for count in (lead_pad, tail_pad)
            if count > 0
        ]
        return np.concatenate(padding) if padding else samples

    # The fade-out may only use what the fade-in leaves over.
    fade_in_len = min(fade_in_request, n_total)
    fade_out_len = min(fade_out_request, max(n_total - fade_in_len, 0))
    split_at = n_total - fade_out_len

    head = samples[:fade_in_len]
    body = samples[fade_in_len:split_at]
    tail = samples[split_at:]

    def _ramp(count: int, first: float, last: float) -> np.ndarray:
        if count <= 0:
            return np.zeros((0,), dtype=np.float32)
        return np.linspace(first, last, count, endpoint=True, dtype=np.float32)

    faded_head = head * _ramp(fade_in_len, 0.0, 1.0)
    faded_tail = tail * _ramp(tail.shape[0], 1.0, 0.0)

    segments: List[np.ndarray] = []
    if lead_pad > 0:
        segments.append(np.zeros(lead_pad, dtype=np.float32))
    segments.extend(
        part.astype(np.float32, copy=False)
        for part in (faded_head, body, faded_tail)
        if part.size > 0
    )
    if tail_pad > 0:
        segments.append(np.zeros(tail_pad, dtype=np.float32))
    return np.concatenate(segments)
def _load_audio_to_24k(
    audio: Union[str, np.ndarray, torch.Tensor, Dict[str, Any]],
    *,
    target_sr: int = 22050,
    augment_with_silence: bool = False,
) -> np.ndarray:
    """Convert one of several audio representations into a float32 numpy array.

    Supported inputs:
      * ``np.ndarray`` / ``torch.Tensor``: cast to float32 only; no sample
        rate is known for these, so they are not resampled.
      * ``str``: treated as a file path, loaded as mono with librosa and
        resampled to ``target_sr``.
      * ``dict`` with ``"array"`` and ``"sampling_rate"`` keys: resampled to
        ``target_sr``.

    Note that despite the function's name the default ``target_sr`` is 22050.

    Args:
        audio: The audio to load.
        target_sr: Sample rate to resample path / dict inputs to.
        augment_with_silence: When true, the result is passed through
            ``_apply_silence_with_crossfade`` at ``target_sr``.

    Raises:
        RuntimeError: A path was given but librosa is not installed.
        ValueError: ``audio`` is of an unsupported type.
    """
    if isinstance(audio, np.ndarray):
        samples = audio.astype(np.float32)
    elif isinstance(audio, torch.Tensor):
        samples = audio.detach().cpu().float().numpy()
    elif isinstance(audio, str):
        if librosa is None:
            raise RuntimeError(
                "librosa is required to load audio file paths. Please pip install librosa."
            )
        loaded, native_sr = librosa.load(audio, sr=None, mono=True)
        samples = _resample_if_needed(loaded, int(native_sr), target_sr)
    elif isinstance(audio, dict) and "array" in audio and "sampling_rate" in audio:
        raw = np.asarray(audio["array"], dtype=np.float32)
        samples = _resample_if_needed(raw, int(audio["sampling_rate"]), target_sr)
    else:
        raise ValueError(f"Unsupported audio type: {type (audio )}")

    samples = np.asarray(samples, dtype=np.float32)
    if not augment_with_silence:
        return samples
    return _apply_silence_with_crossfade(samples, sample_rate=target_sr)
@dataclass
class QWEN3VoxCollator:
    """Collate dataset examples into padded token / speech tensors for training.

    For every example the processor is run on the text (plus optional voice
    prompts), the resulting token ids are extended with one placeholder id per
    target speech latent followed by the speech-end id and, when available, an
    EOS id. Sequences are right-padded to the longest one in the batch, and all
    waveforms (processor voice segments and targets) are zero-padded and
    stacked.
    """

    # Callable processor; also read for `.tokenizer`, and optionally for
    # `.acoustic_tokenizer` / `.semantic_tokenizer`.
    processor: Any
    # When set, sequences longer than this are cut from the left (see __call__).
    max_length: Optional[int] = None
    # Samples per latent frame; used only when the acoustic tokenizer cannot
    # supply the latent length.
    speech_compress_ratio: int = 3200
    # Feature width that semantic features are padded / truncated to.
    semantic_vae_dim: int = 128
    # NOTE(review): when this is False (the default) and the batch holds at
    # least one waveform, __call__ raises RuntimeError.
    compute_semantics: bool = False
    # Enables the sanity assertions at the end of __call__.
    debug_checks: bool = False
    # Keys read from each feature dict.
    text_field: str = "text"
    audio_field: str = "audio"
    voice_prompts_field: str = "voice_prompts"
    # Probability (clamped to [0, 1]) of withholding the voice prompts.
    voice_prompt_drop_rate: float = 0.0

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
        """Build one batch from ``features``.

        Returns:
            Dict with keys ``input_ids``, ``attention_mask``, ``speech_tensors``,
            ``speech_masks``, ``speech_semantic_tensors``,
            ``acoustic_input_mask``, ``acoustic_loss_mask`` and
            ``speeches_loss_input``.

        Raises:
            ValueError: ``max_length`` would cut into acoustic positions, or the
                tokenizer offers neither a pad nor an EOS id.
            RuntimeError: Semantic features cannot be computed (see the note on
                ``compute_semantics``) or have an unexpected rank.
        """
        batch_size = len(features)
        sample_input_ids: List[List[int]] = []
        sample_attention_masks: List[List[int]] = []
        sample_acoustic_input_masks: List[List[bool]] = []
        sample_acoustic_loss_masks: List[List[bool]] = []
        # Flat, batch-wide lists: one entry per speech segment (voice prompt
        # segments of an example first, then that example's target).
        all_speech_waveforms: List[np.ndarray] = []
        all_speech_latent_lengths: List[int] = []
        per_segment_is_target: List[bool] = []
        for ex in features:
            text: str = ex.get(self.text_field, "")
            voice_prompts: Optional[List[Union[str, np.ndarray, torch.Tensor]]] = (
                ex.get(self.voice_prompts_field)
            )
            target_audio: Union[str, np.ndarray, torch.Tensor, Dict[str, Any]] = ex.get(
                self.audio_field
            )
            # Clamp the drop probability into [0, 1].
            _drop_rate = self.voice_prompt_drop_rate
            if _drop_rate < 0.0:
                _drop_rate = 0.0
            elif _drop_rate > 1.0:
                _drop_rate = 1.0
            proc = self.processor(
                text=[text],
                # Prompts are kept only when present and the random draw is at
                # or above the drop rate.
                voice_samples=(
                    [voice_prompts]
                    if voice_prompts is not None and random.random() >= _drop_rate
                    else None
                ),
                padding=False,
                truncation=False,
                max_length=self.max_length,
                return_tensors="pt",
            )
            ids = proc["input_ids"][0].tolist()
            attn = proc.get("attention_mask", torch.ones_like(proc["input_ids"]))[
                0
            ].tolist()
            speech_input_mask = proc.get("speech_input_mask")
            if speech_input_mask is None:
                speech_input_mask = torch.zeros_like(
                    proc["input_ids"], dtype=torch.bool
                )
            speech_input_mask_list = speech_input_mask[0].tolist()
            wav_target = _load_audio_to_24k(
                target_audio, target_sr=22050, augment_with_silence=True
            )
            # Prefer the latent length reported by the acoustic tokenizer; any
            # failure falls back to the compress-ratio estimate below.
            target_latent_len = None
            try:
                acoustic_tok = getattr(self.processor, "acoustic_tokenizer", None)
                if acoustic_tok is not None and hasattr(acoustic_tok, "encode"):
                    enc_out = acoustic_tok.encode(wav_target)
                    T = None
                    try:
                        if (
                            hasattr(enc_out, "shape")
                            and len(getattr(enc_out, "shape", [])) >= 1
                        ):
                            T = int(enc_out.shape[0])
                        else:
                            # Unwrap up to two levels of list/tuple nesting.
                            cand = enc_out
                            for _ in range(2):
                                if isinstance(cand, (list, tuple)) and len(cand) > 0:
                                    cand = cand[0]
                            if (
                                hasattr(cand, "shape")
                                and len(getattr(cand, "shape", [])) >= 1
                            ):
                                T = int(cand.shape[0])
                    except Exception:
                        T = None
                    if T is not None and T > 0:
                        target_latent_len = T
            except Exception:
                target_latent_len = None
            if target_latent_len is None:
                target_latent_len = max(
                    1,
                    int(math.ceil(len(wav_target) / float(self.speech_compress_ratio))),
                )
            # Append one placeholder token per target latent; these positions
            # are both acoustic inputs and loss positions.
            speech_diff_id = self.processor.tokenizer.speech_diffusion_id
            target_placeholders = [speech_diff_id] * target_latent_len
            ids_extended = ids + target_placeholders
            attn_extended = attn + [1] * target_latent_len
            acoustic_input_mask = speech_input_mask_list + [True] * target_latent_len
            acoustic_loss_mask = [False] * len(speech_input_mask_list) + [
                True
            ] * target_latent_len
            speech_end_id = self.processor.tokenizer.speech_end_id
            ids_extended.append(speech_end_id)
            attn_extended.append(1)
            acoustic_input_mask.append(False)
            acoustic_loss_mask.append(False)
            # `eos_id` is tried first, then the `eos_token_id` attribute.
            eos_token_id = getattr(self.processor.tokenizer, "eos_id", None)
            if eos_token_id is None:
                eos_token_id = getattr(self.processor.tokenizer, "eos_token_id", None)
            if eos_token_id is not None and eos_token_id >= 0:
                ids_extended.append(eos_token_id)
                attn_extended.append(1)
                acoustic_input_mask.append(False)
                acoustic_loss_mask.append(False)
            # Left-truncate to max_length, but only through the leading run of
            # non-acoustic positions.
            if self.max_length is not None and len(ids_extended) > self.max_length:
                cut = len(ids_extended) - int(self.max_length)
                leading_non_acoustic = 0
                for v in acoustic_input_mask:
                    if v:
                        break
                    leading_non_acoustic += 1
                if cut > leading_non_acoustic:
                    raise ValueError(
                        f"--max_length={self .max_length } would truncate into acoustic tokens. Needed cut={cut }, but only {leading_non_acoustic } leading non-acoustic tokens available. Increase max_length or shorten text/voice-prompt preamble."
                    )
                ids_extended = ids_extended[cut:]
                attn_extended = attn_extended[cut:]
                acoustic_input_mask = acoustic_input_mask[cut:]
                acoustic_loss_mask = acoustic_loss_mask[cut:]
            sample_input_ids.append(ids_extended)
            sample_attention_masks.append(attn_extended)
            sample_acoustic_input_masks.append(acoustic_input_mask)
            sample_acoustic_loss_masks.append(acoustic_loss_mask)
            # Voice-prompt segments produced by the processor (if any); their
            # latent length is the number of set entries in the segment's mask.
            voice_speeches = []
            voice_latent_lengths = []
            if proc.get("speech_tensors") is not None:
                voice_np = proc["speech_tensors"].cpu().numpy()
                voice_masks = proc["speech_masks"].cpu().numpy().astype(bool)
                for seg_idx in range(voice_np.shape[0]):
                    voice_speeches.append(voice_np[seg_idx])
                    voice_latent_lengths.append(int(voice_masks[seg_idx].sum()))
            all_speech_waveforms.extend(voice_speeches)
            all_speech_latent_lengths.extend(voice_latent_lengths)
            per_segment_is_target.extend([False] * len(voice_speeches))
            all_speech_waveforms.append(wav_target)
            all_speech_latent_lengths.append(target_latent_len)
            per_segment_is_target.append(True)
        # NOTE(review): max() raises ValueError when `features` is empty.
        max_seq_len = max((len(x) for x in sample_input_ids))
        padded_input_ids = []
        padded_attention_masks = []
        padded_acoustic_input_masks = []
        padded_acoustic_loss_masks = []
        tok = self.processor.tokenizer
        # Pad with pad_token_id, falling back to eos_token_id.
        pad_token_id = getattr(tok, "pad_token_id", None)
        if pad_token_id is None or pad_token_id < 0:
            pad_token_id = getattr(tok, "eos_token_id", None)
        if pad_token_id is None or pad_token_id < 0:
            raise ValueError(
                "Tokenizer has no pad_token_id or eos_token_id; please set one or pass a valid pad id."
            )
        # Right-pad every per-sample sequence to the batch maximum.
        for ids, attn, ain_mask, aloss_mask in zip(
            sample_input_ids,
            sample_attention_masks,
            sample_acoustic_input_masks,
            sample_acoustic_loss_masks,
        ):
            pad_len = max_seq_len - len(ids)
            padded_input_ids.append(ids + [pad_token_id] * pad_len)
            padded_attention_masks.append(attn + [0] * pad_len)
            padded_acoustic_input_masks.append(ain_mask + [False] * pad_len)
            padded_acoustic_loss_masks.append(aloss_mask + [False] * pad_len)
        input_ids_tensor = torch.tensor(padded_input_ids, dtype=torch.long)
        attention_mask_tensor = torch.tensor(padded_attention_masks, dtype=torch.long)
        acoustic_input_mask_tensor = torch.tensor(
            padded_acoustic_input_masks, dtype=torch.bool
        )
        acoustic_loss_mask_tensor = torch.tensor(
            padded_acoustic_loss_masks, dtype=torch.bool
        )
        if all_speech_waveforms:
            # Zero-pad every waveform to the longest one: [segments, samples].
            max_wave_len = max((w.shape[0] for w in all_speech_waveforms))
            padded_speeches = np.zeros(
                (len(all_speech_waveforms), max_wave_len), dtype=np.float32
            )
            for i, w in enumerate(all_speech_waveforms):
                L = w.shape[0]
                padded_speeches[i, :L] = w
            max_latent_len = (
                max(all_speech_latent_lengths) if all_speech_latent_lengths else 1
            )
            # Per-segment validity mask over latent frames.
            speech_masks_np = np.zeros(
                (len(all_speech_waveforms), max_latent_len), dtype=np.bool_
            )
            for i, L_lat in enumerate(all_speech_latent_lengths):
                speech_masks_np[i, :L_lat] = True
            speech_tensors_tensor = torch.tensor(padded_speeches, dtype=torch.float32)
            speech_masks_tensor = torch.tensor(speech_masks_np, dtype=torch.bool)
            # Same mask, but only rows of target segments are kept.
            speeches_loss_input_np = np.zeros_like(speech_masks_np, dtype=np.bool_)
            for i, is_target in enumerate(per_segment_is_target):
                if is_target:
                    speeches_loss_input_np[i] = speech_masks_np[i]
            speeches_loss_input_tensor = torch.tensor(
                speeches_loss_input_np, dtype=torch.bool
            )
            if (
                self.compute_semantics
                and hasattr(self.processor, "semantic_tokenizer")
                and (self.processor.semantic_tokenizer is not None)
            ):
                sem_feats: List[np.ndarray] = []
                for w in all_speech_waveforms:
                    try:
                        sem = self.processor.semantic_tokenizer.encode(w)
                        sem = np.asarray(sem, dtype=np.float32)
                    except Exception:
                        # Encoding failure degrades to an empty feature matrix,
                        # which is zero-padded below.
                        sem = np.zeros((0, self.semantic_vae_dim), dtype=np.float32)
                    if sem.ndim != 2:
                        raise RuntimeError(
                            f"Semantic tokenizer returned unexpected shape {sem .shape }. Expect [T, D]."
                        )
                    L = sem.shape[0]
                    D = sem.shape[1]
                    # Force the feature width to semantic_vae_dim.
                    if D != self.semantic_vae_dim:
                        if D < self.semantic_vae_dim:
                            pad_d = np.zeros(
                                (L, self.semantic_vae_dim - D), dtype=np.float32
                            )
                            sem = np.concatenate([sem, pad_d], axis=1)
                        else:
                            sem = sem[:, : self.semantic_vae_dim]
                    # Force the frame count to max_latent_len.
                    if L < max_latent_len:
                        pad = np.zeros(
                            (max_latent_len - L, self.semantic_vae_dim),
                            dtype=np.float32,
                        )
                        sem = np.concatenate([sem, pad], axis=0)
                    elif L > max_latent_len:
                        sem = sem[:max_latent_len]
                    sem_feats.append(sem.astype(np.float32))
                speech_semantic_tensors = torch.tensor(
                    np.stack(sem_feats, axis=0), dtype=torch.float32
                )
            else:
                # NOTE(review): reached whenever compute_semantics is False,
                # which is the field default — confirm this is intended.
                raise RuntimeError(
                    "Semantic features are required but could not be computed. Ensure processor.semantic_tokenizer is available or precompute and provide features."
                )
        else:
            speech_tensors_tensor = None
            speech_masks_tensor = None
            speeches_loss_input_tensor = None
            speech_semantic_tensors = None
        if self.debug_checks:
            assert (input_ids_tensor >= 0).all(), "input_ids contains negative indices"
            if speech_tensors_tensor is not None:
                assert (
                    speech_tensors_tensor.dim() == 2
                ), "Expected speech_tensors 2D [segments, samples]"
        return {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_tensor,
            "speech_tensors": speech_tensors_tensor,
            "speech_masks": speech_masks_tensor,
            "speech_semantic_tensors": speech_semantic_tensors,
            "acoustic_input_mask": acoustic_input_mask_tensor,
            "acoustic_loss_mask": acoustic_loss_mask_tensor,
            "speeches_loss_input": speeches_loss_input_tensor,
        }
' QWEN3Vox_AcousticTokenizer model configuration'
from typing import Dict, List, Optional, Tuple
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
logger = logging.get_logger(__name__)
class QWEN3VoxAcousticTokenizerConfig(PretrainedConfig):
    """Configuration holder for the acoustic tokenizer.

    All constructor arguments are stored as same-named attributes.
    ``decoder_ratios`` falls back to ``encoder_ratios`` when not given. The
    ratio lists are copied on assignment so that instances never share (or
    mutate) the default list from the signature, nor alias each other.
    """

    model_type = 'vibevoice_acoustic_tokenizer'

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0.5,
        std_dist_type: str = "gaussian",
        mixer_layer: str = "depthwise_conv",
        conv_norm: str = "none",
        pad_mode: str = "constant",
        disable_last_norm: bool = True,
        layernorm: str = "RMSNorm",
        layernorm_eps: float = 1e-05,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-06,
        weight_init_value: float = 0.01,
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = [8, 5, 5, 4, 2, 2],
        encoder_depths: str = "3-3-3-3-3-3-8",
        decoder_n_filters: int = 32,
        decoder_ratios: Optional[List[int]] = None,
        decoder_depths: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer
        self.encoder_n_filters = encoder_n_filters
        # Copy: the signature default is a single list object shared by every
        # call, so storing it directly would let one instance corrupt others.
        self.encoder_ratios = (
            list(encoder_ratios) if encoder_ratios is not None else None
        )
        self.encoder_depths = encoder_depths
        # Decoder mirrors the encoder ratios unless given explicitly; it gets
        # its own copy so the two attributes are not aliases.
        effective_decoder_ratios = (
            decoder_ratios if decoder_ratios is not None else encoder_ratios
        )
        self.decoder_ratios = (
            list(effective_decoder_ratios)
            if effective_decoder_ratios is not None
            else None
        )
        self.decoder_n_filters = decoder_n_filters
        self.decoder_depths = decoder_depths
class QWEN3VoxSemanticTokenizerConfig(PretrainedConfig):
    """Configuration holder for the semantic tokenizer.

    All constructor arguments are stored as same-named attributes.
    ``encoder_ratios`` is copied on assignment so that instances never share
    (or mutate) the default list from the signature.
    """

    model_type = 'vibevoice_semantic_tokenizer'

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0,
        std_dist_type: str = "none",
        mixer_layer: str = "depthwise_conv",
        conv_norm: str = "none",
        pad_mode: str = "constant",
        disable_last_norm: bool = True,
        layernorm: str = "RMSNorm",
        layernorm_eps: float = 1e-05,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-06,
        weight_init_value: float = 0.01,
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = [8, 5, 5, 4, 2, 2],
        encoder_depths: str = "3-3-3-3-3-3-8",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer
        self.encoder_n_filters = encoder_n_filters
        # Copy: the signature default is a single list object shared by every
        # call, so storing it directly would let one instance corrupt others.
        self.encoder_ratios = (
            list(encoder_ratios) if encoder_ratios is not None else None
        )
        self.encoder_depths = encoder_depths
class QWEN3VoxDiffusionHeadConfig(PretrainedConfig):
    """Configuration holder for the diffusion head.

    Every named constructor argument is stored as a same-named attribute;
    remaining keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = 'vibevoice_diffusion_head'

    def __init__(
        self,
        hidden_size=768,
        head_layers=4,
        head_ffn_ratio=3.0,
        rms_norm_eps=1e-05,
        latent_size=64,
        speech_vae_dim=None,
        prediction_type="v_prediction",
        diffusion_type="ddpm",
        ddpm_num_steps=1000,
        ddpm_num_inference_steps=30,
        ddpm_beta_schedule="cosine",
        ddpm_batch_mul=4,
        **kwargs,
    ):
        # Own fields are assigned before the base initializer runs, in the
        # same order as the signature.
        own_fields = {
            "hidden_size": hidden_size,
            "head_layers": head_layers,
            "head_ffn_ratio": head_ffn_ratio,
            "rms_norm_eps": rms_norm_eps,
            "latent_size": latent_size,
            "speech_vae_dim": speech_vae_dim,
            "prediction_type": prediction_type,
            "diffusion_type": diffusion_type,
            "ddpm_num_steps": ddpm_num_steps,
            "ddpm_num_inference_steps": ddpm_num_inference_steps,
            "ddpm_beta_schedule": ddpm_beta_schedule,
            "ddpm_batch_mul": ddpm_batch_mul,
        }
        for field_name, field_value in own_fields.items():
            setattr(self, field_name, field_value)
        super().__init__(**kwargs)
class QWEN3VoxConfig(PretrainedConfig):
    """Composite configuration bundling tokenizer, decoder and diffusion-head configs.

    Each sub-config argument may be ``None`` (use the sub-config class
    defaults), a ``dict`` of constructor keyword arguments, or an already
    constructed config instance. Dict arguments are not modified. The
    attributes ``acoustic_vae_dim`` and ``semantic_vae_dim`` are derived from
    the tokenizer sub-configs.

    Raises:
        ValueError: ``decoder_config`` is a dict whose ``model_type`` is not
            ``"qwen2"``.
        TypeError: A sub-config argument is of an unsupported type.
    """

    model_type = 'vibevoice'
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig,
        "semantic_tokenizer_config": QWEN3VoxSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
        "diffusion_head_config": QWEN3VoxDiffusionHeadConfig,
    }
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        diffusion_head_config=None,
        **kwargs,
    ):
        kwargs["_attn_implementation_autoset"] = False
        self.acoustic_tokenizer_config = self._resolve_sub_config(
            acoustic_tokenizer_config,
            "acoustic_tokenizer_config",
            QWEN3VoxAcousticTokenizerConfig,
            'vibevoice_acoustic_tokenizer',
        )
        self.semantic_tokenizer_config = self._resolve_sub_config(
            semantic_tokenizer_config,
            "semantic_tokenizer_config",
            QWEN3VoxSemanticTokenizerConfig,
            'vibevoice_semantic_tokenizer',
        )
        self.decoder_config = self._resolve_decoder_config(decoder_config)
        self.diffusion_head_config = self._resolve_sub_config(
            diffusion_head_config,
            "diffusion_head_config",
            QWEN3VoxDiffusionHeadConfig,
            'vibevoice_diffusion_head',
        )
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, "vae_dim", 128)
        super().__init__(**kwargs)

    def _resolve_sub_config(self, value, field_name, expected_cls, model_type):
        """Turn ``None`` / dict / config instance into a sub-config object."""
        factory = self.sub_configs[field_name]
        if value is None:
            return factory()
        if isinstance(value, dict):
            # Build a new mapping so the caller's dict is left untouched.
            return factory(**{**value, "model_type": model_type})
        if isinstance(value, expected_cls):
            return value
        # Previously such values were ignored, leaving the attribute unset.
        raise TypeError(
            f"{field_name} must be None, a dict or {expected_cls.__name__}, "
            f"got {type(value).__name__}"
        )

    def _resolve_decoder_config(self, value):
        """Turn ``None`` / qwen2 dict / ``Qwen2Config`` into the decoder config."""
        if value is None:
            return self.sub_configs["decoder_config"]()
        if isinstance(value, dict):
            if value.get("model_type", "") == "qwen2":
                return Qwen2Config(**value)
            raise ValueError(
                f"Unsupported decoder model type: {value .get ('model_type','')}"
            )
        if isinstance(value, (Qwen2Config,)):
            return value
        # Previously such values were ignored, leaving the attribute unset.
        raise TypeError(
            f"decoder_config must be None, a dict or Qwen2Config, "
            f"got {type(value).__name__}"
        )
class QWEN3VoxASRConfig(PretrainedConfig):
    """Composite configuration bundling tokenizer and decoder configs (no diffusion head).

    Each sub-config argument may be ``None`` (use the sub-config class
    defaults), a ``dict`` of constructor keyword arguments, or an already
    constructed config instance. Dict arguments are not modified. Several
    read-only properties forward to the decoder config.

    Raises:
        ValueError: ``decoder_config`` is a dict whose ``model_type`` is not
            ``"qwen2"``.
        TypeError: A sub-config argument is of an unsupported type.
    """

    model_type = 'vibevoice'
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig,
        "semantic_tokenizer_config": QWEN3VoxSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
    }
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        **kwargs,
    ):
        kwargs["_attn_implementation_autoset"] = False
        self.acoustic_tokenizer_config = self._resolve_sub_config(
            acoustic_tokenizer_config,
            "acoustic_tokenizer_config",
            QWEN3VoxAcousticTokenizerConfig,
            'vibevoice_acoustic_tokenizer',
        )
        self.semantic_tokenizer_config = self._resolve_sub_config(
            semantic_tokenizer_config,
            "semantic_tokenizer_config",
            QWEN3VoxSemanticTokenizerConfig,
            'vibevoice_semantic_tokenizer',
        )
        self.decoder_config = self._resolve_decoder_config(decoder_config)
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, "vae_dim", 128)
        super().__init__(**kwargs)

    def _resolve_sub_config(self, value, field_name, expected_cls, model_type):
        """Turn ``None`` / dict / config instance into a sub-config object."""
        factory = self.sub_configs[field_name]
        if value is None:
            return factory()
        if isinstance(value, dict):
            # Build a new mapping so the caller's dict is left untouched.
            return factory(**{**value, "model_type": model_type})
        if isinstance(value, expected_cls):
            return value
        # Previously such values were ignored, leaving the attribute unset.
        raise TypeError(
            f"{field_name} must be None, a dict or {expected_cls.__name__}, "
            f"got {type(value).__name__}"
        )

    def _resolve_decoder_config(self, value):
        """Turn ``None`` / qwen2 dict / ``Qwen2Config`` into the decoder config."""
        if value is None:
            return self.sub_configs["decoder_config"]()
        if isinstance(value, dict):
            if value.get("model_type", "") == "qwen2":
                return Qwen2Config(**value)
            raise ValueError(
                f"Unsupported decoder model type: {value .get ('model_type','')}"
            )
        if isinstance(value, Qwen2Config):
            return value
        # Previously such values were ignored, leaving the attribute unset.
        raise TypeError(
            f"decoder_config must be None, a dict or Qwen2Config, "
            f"got {type(value).__name__}"
        )

    def get_text_config(self, decoder: bool = False):
        """Return the decoder config; the ``decoder`` flag is accepted but unused."""
        return self.decoder_config

    # The properties below forward to the decoder config.
    @property
    def vocab_size(self):
        return self.decoder_config.vocab_size

    @property
    def num_attention_heads(self):
        return self.decoder_config.num_attention_heads

    @property
    def num_key_value_heads(self):
        return self.decoder_config.num_key_value_heads

    @property
    def hidden_size(self):
        return self.decoder_config.hidden_size

    @property
    def num_hidden_layers(self):
        return self.decoder_config.num_hidden_layers

    @property
    def head_dim(self):
        # Falls back to hidden_size // num_attention_heads when the decoder
        # config has no explicit head_dim.
        return getattr(
            self.decoder_config,
            "head_dim",
            self.hidden_size // self.num_attention_heads,
        )
__all__ = [
'QWEN3VoxAcousticTokenizerConfig',
'QWEN3VoxSemanticTokenizerConfig',
'QWEN3VoxDiffusionHeadConfig',
'QWEN3VoxConfig',
'QWEN3VoxASRConfig',
]
import torch
import asyncio
from queue import Queue
from typing import TYPE_CHECKING, Optional
from transformers.generation import BaseStreamer
class AudioStreamer(BaseStreamer):
    """Streamer that routes audio chunks into one ``queue.Queue`` per batch sample.

    Producers call :meth:`put` with chunks and the sample indices they belong
    to, then :meth:`end` to enqueue ``stop_signal`` for finished samples.
    Consumers iterate the streamer (all samples) or :meth:`get_stream` (one
    sample).

    Args:
        batch_size: Number of per-sample queues to create.
        stop_signal: Object enqueued to mark the end of a sample's stream.
        timeout: Timeout passed to queue ``put`` calls and read by
            ``AudioSampleIterator`` for its ``get`` calls.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        self.batch_size = batch_size
        self.stop_signal = stop_signal
        self.timeout = timeout
        self.audio_queues = [Queue() for _ in range(batch_size)]
        self.finished_flags = [False for _ in range(batch_size)]
        self.sample_indices_map = {}

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Enqueue ``audio_chunks[i]`` (detached, on CPU) for ``sample_indices[i]``.

        Chunks for out-of-range or already finished samples are dropped.
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and (not self.finished_flags[idx]):
                audio_chunk = audio_chunks[i].detach().cpu()
                self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Mark samples as finished by enqueuing ``stop_signal`` once per sample.

        With ``sample_indices=None`` every sample that is still open is ended.
        """
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [
                s.item() if torch.is_tensor(s) else s for s in sample_indices
            ]
        for idx in indices_to_end:
            if idx < self.batch_size and (not self.finished_flags[idx]):
                self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                self.finished_flags[idx] = True

    def __iter__(self):
        return AudioBatchIterator(self)

    def get_stream(self, sample_idx: int):
        """Return an iterator over the chunks of a single sample.

        Raises:
            ValueError: ``sample_idx`` is not below ``batch_size``.
        """
        if sample_idx >= self.batch_size:
            raise ValueError(
                f"Sample index {sample_idx } exceeds batch size {self .batch_size }"
            )
        return AudioSampleIterator(self, sample_idx)
class AudioSampleIterator:
    """Blocking iterator over the audio chunks of one sample of a streamer.

    Each step takes the next item from the sample's queue, waiting up to
    ``streamer.timeout``; iteration stops when the streamer's stop signal is
    received.
    """

    def __init__(self, streamer: AudioStreamer, sample_idx: int):
        self.streamer = streamer
        self.sample_idx = sample_idx

    def __iter__(self):
        return self

    def __next__(self):
        value = self.streamer.audio_queues[self.sample_idx].get(
            timeout=self.streamer.timeout
        )
        # Identity, not equality: `==` on an array/tensor chunk is elementwise
        # and cannot be used as a plain truth value.
        if value is self.streamer.stop_signal:
            raise StopIteration()
        return value
class AudioBatchIterator:
    """Iterator yielding ``{sample_idx: chunk}`` dicts across all samples of a streamer.

    Each step polls every still-active sample's queue without blocking. A
    sample becomes inactive once its stop signal is read. When nothing is
    available yet the iterator sleeps 10 ms and polls again; iteration ends
    when no active samples remain.
    """

    def __init__(self, streamer: AudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __iter__(self):
        return self

    def __next__(self):
        import time
        from queue import Empty

        # A loop rather than recursion: long idle periods must not exhaust the
        # interpreter's recursion limit.
        while self.active_samples:
            batch_chunks = {}
            samples_to_remove = set()
            for idx in self.active_samples:
                try:
                    value = self.streamer.audio_queues[idx].get(block=False)
                except Empty:
                    # Nothing queued for this sample right now.
                    continue
                # Identity, not equality: `==` on an array/tensor chunk is
                # elementwise and cannot be used as a plain truth value.
                if value is self.streamer.stop_signal:
                    samples_to_remove.add(idx)
                else:
                    batch_chunks[idx] = value
            self.active_samples -= samples_to_remove
            if batch_chunks:
                return batch_chunks
            if self.active_samples:
                time.sleep(0.01)
        raise StopIteration()
class AsyncAudioStreamer(AudioStreamer):
    """Asyncio flavour of ``AudioStreamer`` backed by ``asyncio.Queue`` objects.

    The constructor captures the running event loop, so it must be called from
    within one. :meth:`put` and :meth:`end` hand items to that loop via
    ``call_soon_threadsafe``.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        super().__init__(batch_size, stop_signal, timeout)
        # Replace the thread queues created by the base class.
        self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
        self.loop = asyncio.get_running_loop()

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Schedule ``audio_chunks[i]`` (detached, on CPU) onto the queue of ``sample_indices[i]``.

        Chunks for out-of-range or already finished samples are dropped.
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and (not self.finished_flags[idx]):
                audio_chunk = audio_chunks[i].detach().cpu()
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, audio_chunk
                )

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Mark samples as finished by scheduling ``stop_signal`` once per sample.

        With ``sample_indices=None`` every sample that is still open is ended.
        """
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [
                s.item() if torch.is_tensor(s) else s for s in sample_indices
            ]
        for idx in indices_to_end:
            if idx < self.batch_size and (not self.finished_flags[idx]):
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, self.stop_signal
                )
                self.finished_flags[idx] = True

    async def get_stream(self, sample_idx: int):
        """Asynchronously yield the chunks of one sample until its stop signal.

        Raises:
            ValueError: ``sample_idx`` is not below ``batch_size`` (raised on
                first iteration, since this is an async generator).
        """
        if sample_idx >= self.batch_size:
            raise ValueError(
                f"Sample index {sample_idx } exceeds batch size {self .batch_size }"
            )
        while True:
            value = await self.audio_queues[sample_idx].get()
            # Identity, not equality: `==` on an array/tensor chunk is
            # elementwise and cannot be used as a plain truth value.
            if value is self.stop_signal:
                break
            yield value

    def __aiter__(self):
        return AsyncAudioBatchIterator(self)
class AsyncAudioBatchIterator:
    """Async iterator yielding ``{sample_idx: chunk}`` dicts across all samples.

    Each step starts one ``get`` task per still-active sample, waits (up to
    ``streamer.timeout``) for the first to complete and cancels the rest. A
    sample becomes inactive once its stop signal is read. Rounds that produce
    no chunk are repeated while active samples remain.
    """

    def __init__(self, streamer: AsyncAudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __aiter__(self):
        return self

    async def __anext__(self):
        # A loop rather than recursion: repeated empty rounds must not grow
        # the call stack.
        while self.active_samples:
            tasks = {
                idx: asyncio.create_task(self._get_chunk(idx))
                for idx in self.active_samples
            }
            done, pending = await asyncio.wait(
                tasks.values(),
                return_when=asyncio.FIRST_COMPLETED,
                timeout=self.streamer.timeout,
            )
            # Unfinished getters are cancelled and recreated next round.
            for task in pending:
                task.cancel()
            batch_chunks = {}
            samples_to_remove = set()
            for idx, task in tasks.items():
                if task not in done:
                    continue
                try:
                    value = await task
                except asyncio.CancelledError:
                    continue
                # Identity, not equality: `==` on an array/tensor chunk is
                # elementwise and cannot be used as a plain truth value.
                if value is self.streamer.stop_signal:
                    samples_to_remove.add(idx)
                else:
                    batch_chunks[idx] = value
            self.active_samples -= samples_to_remove
            if batch_chunks:
                return batch_chunks
        raise StopAsyncIteration()

    async def _get_chunk(self, idx):
        """Await the next item of sample ``idx``'s queue."""
        return await self.streamer.audio_queues[idx].get()
'Tokenization classes for QWEN3Vox.'
from typing import List, Optional, Union
from transformers.utils import logging
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
logger = logging.get_logger(__name__)
class QWEN3VoxTextTokenizer(Qwen2Tokenizer):
    """Qwen2 tokenizer that registers three extra special tokens used as speech markers.

    After construction the ids of ``<|vision_start|>``, ``<|vision_end|>`` and
    ``<|vision_pad|>`` are exposed as ``speech_start_id``, ``speech_end_id``
    and ``speech_diffusion_id``; ``eos_id`` is the id of ``<|endoftext|>``.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        add_special_tokens=True,
        **kwargs,
    ):
        base_kwargs = dict(
            vocab_file=vocab_file,
            merges_file=merges_file,
            errors=errors,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            add_special_tokens=add_special_tokens,
        )
        super().__init__(**base_kwargs, **kwargs)
        self._add_q3_sp_tok()

    def _add_q3_sp_tok(self):
        """Register the speech marker tokens and cache their ids.

        Returns:
            Whatever ``add_special_tokens`` returns for the registration call.
        """
        start_tok = "<|vision_start|>"
        end_tok = "<|vision_end|>"
        diffusion_tok = "<|vision_pad|>"
        added = self.add_special_tokens(
            {"additional_special_tokens": [start_tok, end_tok, diffusion_tok]}
        )
        to_id = self.convert_tokens_to_ids
        self._speech_start_id = to_id(start_tok)
        self._speech_end_id = to_id(end_tok)
        self._speech_diffusion_id = to_id(diffusion_tok)
        self._eos_id = to_id("<|endoftext|>")
        return added

    @property
    def eos_id(self) -> int:
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        return self._speech_end_id

    @property
    def speech_diffusion_id(self) -> int:
        return self._speech_diffusion_id

    @property
    def pad_id(self) -> int:
        # Constant -100 here; the fast tokenizer instead returns a token id.
        return -100
class QWEN3VoxTextTokenizerFast(Qwen2TokenizerFast):
    """Fast Qwen2 tokenizer that registers three extra special tokens used as speech markers.

    After construction the ids of ``<|vision_start|>``, ``<|vision_end|>`` and
    ``<|vision_pad|>`` are exposed as ``speech_start_id``, ``speech_end_id``
    and ``speech_diffusion_id``. ``eos_id`` mirrors ``eos_token_id`` and
    ``pad_id`` is the id looked up for ``<|image_pad|>``.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        **kwargs,
    ):
        base_kwargs = dict(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
        )
        super().__init__(**base_kwargs, **kwargs)
        self._add_q3_sp_tok()

    def _add_q3_sp_tok(self):
        """Register the speech marker tokens and cache their ids.

        Returns:
            Whatever ``add_special_tokens`` returns for the registration call.
        """
        start_tok = "<|vision_start|>"
        end_tok = "<|vision_end|>"
        diffusion_tok = "<|vision_pad|>"
        added = self.add_special_tokens(
            {"additional_special_tokens": [start_tok, end_tok, diffusion_tok]}
        )
        to_id = self.convert_tokens_to_ids
        self._speech_start_id = to_id(start_tok)
        self._speech_end_id = to_id(end_tok)
        self._speech_diffusion_id = to_id(diffusion_tok)
        self._eos_id = self.eos_token_id
        # "<|image_pad|>" is not registered here; it is only looked up.
        self._pad_id = to_id("<|image_pad|>")
        return added

    @property
    def eos_id(self) -> int:
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        return self._speech_end_id

    @property
    def speech_diffusion_id(self) -> int:
        return self._speech_diffusion_id

    @property
    def pad_id(self) -> int:
        return self._pad_id
# Second name bound to the same class object as QWEN3VoxTextTokenizerFast.
QWEN3VoxASRTextTokenizerFast = QWEN3VoxTextTokenizerFast
# Names exported by a star-import; note the ASR alias above is not listed.
__all__ = [
    'QWEN3VoxTextTokenizer',
    'QWEN3VoxTextTokenizerFast',
]
"Utilities for loading fine-tuned LoRA adapters and connector weights."
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
from transformers.utils import logging
logger = logging.get_logger(__name__)
@dataclass
class _LoadReport:
    """Summary of which adapter assets were loaded from a checkpoint directory.

    Each boolean flag defaults to ``False``; the loaders flip a flag to
    ``True`` when the corresponding asset has been applied to the model.
    """

    # Language-model adapter applied.
    language_model: bool = False
    # Diffusion-head LoRA adapter applied.
    diffusion_head_lora: bool = False
    # Full diffusion-head state dict (``diffusion_head_full.bin``) applied.
    diffusion_head_full: bool = False
    # ``acoustic_connector/pytorch_model.bin`` applied.
    acoustic_connector: bool = False
    # ``semantic_connector/pytorch_model.bin`` applied.
    semantic_connector: bool = False
    # Directory the assets were resolved from (not a "loaded" flag).
    adapter_root: Optional[Path] = None
class _DiffusionHeadForwardShim(nn.Module):
    """Adapter that calls a wrapped module with exactly three positional inputs.

    The wrapped module is invoked as ``base(noisy_images, timesteps, condition)``
    no matter whether the caller supplied those values positionally, by
    keyword, or as a mix of both.
    """

    # Positional order expected by the wrapped module.
    _ARG_NAMES = ("noisy_images", "timesteps", "condition")

    def __init__(self, base: nn.Module):
        """Store ``base`` as a registered submodule."""
        super().__init__()
        self.base = base

    def forward(self, *args, **kwargs):
        """Resolve the three inputs and delegate to ``self.base``.

        Positional arguments fill ``noisy_images``, ``timesteps`` and
        ``condition`` in that order (extras beyond the third are ignored);
        any slot not covered positionally is taken from ``kwargs`` and
        defaults to ``None`` when absent.

        Previously, passing only one or two positional arguments caused them
        to be dropped and replaced by keyword lookups (i.e. ``None``).
        """
        resolved = list(args[: len(self._ARG_NAMES)])
        for name in self._ARG_NAMES[len(resolved):]:
            resolved.append(kwargs.get(name))
        return self.base(*resolved)
def _resolve_adapter_root(checkpoint_path: Path) -> Path:
    """Return the directory that adapter assets should be read from.

    A file path is replaced by its parent directory.  If that directory has
    a ``lora`` entry, the ``lora`` path is returned; otherwise the directory
    itself is returned (existence of the directory is not checked here).
    """
    base_dir = checkpoint_path.parent if checkpoint_path.is_file() else checkpoint_path
    lora_dir = base_dir / "lora"
    return lora_dir if lora_dir.exists() else base_dir
def _load_connector(
    module: Optional[nn.Module], path: Path, device: torch.device
) -> bool:
    """Load a connector state dict from ``path`` into ``module``.

    Returns ``False`` without doing anything when ``module`` is ``None`` or
    ``path`` does not exist.  Otherwise the state dict is loaded
    non-strictly, any missing/unexpected keys are logged as warnings, the
    module is moved to ``device``, and ``True`` is returned.
    """
    can_load = module is not None and path.exists()
    if not can_load:
        return False
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version's defaults; only point this at trusted checkpoints.
    weights = torch.load(path, map_location=device)
    missing_keys, unexpected_keys = module.load_state_dict(weights, strict=False)
    for label, keys in (("missing", missing_keys), ("unexpected", unexpected_keys)):
        if keys:
            logger.warning(f"Connector load {label} keys: {keys}")
    module.to(device)
    return True
def _load_diffusion_head(
    model, adapter_root: Path, device: torch.device, report: "_LoadReport"
) -> None:
    """Load full diffusion-head weights from ``adapter_root`` if available.

    Behaviour:
      * If ``diffusion_head/adapter_config.json`` plus an adapter weight file
        (``adapter_model.bin`` or ``adapter_model.safetensors``) exist, a
        warning is logged and nothing is loaded (LoRA loading is disabled
        here), and the function returns.
      * Otherwise ``diffusion_head/diffusion_head_full.bin`` is used, falling
        back to ``diffusion_head_full.bin`` directly under ``adapter_root``.
        When found, it is loaded non-strictly into
        ``model.model.prediction_head``, which is then moved to ``device``,
        and ``report.diffusion_head_full`` is set to ``True``.

    This function never uses ``peft``; the earlier hard import of it (which
    raised ``RuntimeError`` when the package was absent) has been removed so
    that full weights can be loaded without it.
    """
    diff_dir = adapter_root / "diffusion_head"
    has_lora_config = (diff_dir / "adapter_config.json").exists()
    has_lora_weights = (
        (diff_dir / "adapter_model.bin").exists()
        or (diff_dir / "adapter_model.safetensors").exists()
    )
    if has_lora_config and has_lora_weights:
        logger.warning(
            f"Skipping diffusion-head LoRA at {diff_dir}; "
            "PeftModel.from_pretrained is not allowed in miner.py (use full weights .bin)."
        )
        return

    full_path = diff_dir / "diffusion_head_full.bin"
    if not full_path.exists():
        full_path = adapter_root / "diffusion_head_full.bin"
    if not full_path.exists():
        return

    logger.info(f"Loading full diffusion head weights from {full_path}")
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version's defaults; only point this at trusted checkpoints.
    state_dict = torch.load(full_path, map_location=device)
    prediction_head = model.model.prediction_head
    missing, unexpected = prediction_head.load_state_dict(state_dict, strict=False)
    if missing:
        logger.warning(f"Diffusion head load missing keys: {missing}")
    if unexpected:
        logger.warning(f"Diffusion head load unexpected keys: {unexpected}")
    prediction_head.to(device)
    report.diffusion_head_full = True
def _load_language_model(
    model, adapter_root: Path, device: torch.device, report: "_LoadReport"
) -> None:
    """Detect a language-model LoRA adapter and report that it is skipped.

    If ``adapter_config.json`` and an adapter weight file
    (``adapter_model.bin`` or ``adapter_model.safetensors``) exist directly
    under ``adapter_root``, a warning is logged; nothing is loaded and
    ``report`` is left untouched.  When the files are absent the function
    returns silently.

    ``model``, ``device`` and ``report`` are accepted for signature parity
    with the other loaders but are not used.  The earlier hard import of
    ``peft`` (which raised ``RuntimeError`` when the package was absent, even
    though no ``peft`` symbol was used) has been removed.
    """
    has_config = (adapter_root / "adapter_config.json").exists()
    has_weights = (
        (adapter_root / "adapter_model.bin").exists()
        or (adapter_root / "adapter_model.safetensors").exists()
    )
    if not (has_config and has_weights):
        return
    logger.warning(
        f"Skipping language-model LoRA at {adapter_root}; "
        "PeftModel.from_pretrained is not allowed in miner.py (use full weights .bin)."
    )
def load_lora_assets(
    model, checkpoint_dir: str, device: Optional[torch.device] = None
) -> _LoadReport:
    """Load every supported fine-tuned asset found under ``checkpoint_dir``.

    The adapter root is resolved with ``_resolve_adapter_root``; the
    language-model adapter, diffusion head, acoustic connector and semantic
    connector loaders are then run in that order.

    Args:
        model: Model whose ``model.prediction_head`` and optional
            ``model.acoustic_connector`` / ``model.semantic_connector``
            submodules receive the weights.
        checkpoint_dir: Checkpoint directory (or a file inside it).
        device: Target device; when falsy, the device of the model's first
            parameter is used.

    Returns:
        A ``_LoadReport`` whose boolean flags describe what was loaded.

    Raises:
        FileNotFoundError: If the resolved adapter directory does not exist.
    """
    adapter_root = _resolve_adapter_root(Path(checkpoint_dir))
    if not adapter_root.exists():
        raise FileNotFoundError(f"Adapter directory not found: {adapter_root}")
    inferred_device = device or next(model.parameters()).device
    report = _LoadReport(adapter_root=adapter_root)
    _load_language_model(model, adapter_root, inferred_device, report)
    _load_diffusion_head(model, adapter_root, inferred_device, report)

    ac_path = adapter_root / "acoustic_connector" / "pytorch_model.bin"
    if _load_connector(
        getattr(model.model, "acoustic_connector", None), ac_path, inferred_device
    ):
        report.acoustic_connector = True

    se_path = adapter_root / "semantic_connector" / "pytorch_model.bin"
    if _load_connector(
        getattr(model.model, "semantic_connector", None), se_path, inferred_device
    ):
        report.semantic_connector = True

    # Only the boolean flags say whether something was loaded.  Checking every
    # field would always be truthy because ``adapter_root`` is a non-empty
    # Path, which silenced this warning entirely.
    loaded_anything = any(value is True for value in vars(report).values())
    if not loaded_anything:
        logger.warning(
            "No adapter assets were loaded. Ensure the checkpoint directory is correct and contains LoRA weights."
        )
    return report
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import deprecate
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput,
)
def betas_for_alpha_bar(
    num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine"
):
    """Build a beta schedule by discretising a cumulative-alpha function.

    For each step ``i`` the beta is ``1 - alpha_bar(t2) / alpha_bar(t1)`` with
    ``t1 = i / N`` and ``t2 = (i + 1) / N``, capped at ``max_beta``.

    Args:
        num_diffusion_timesteps: Number of betas ``N`` to produce.
        max_beta: Upper bound applied to every beta.
        alpha_transform_type: One of ``"cosine"``, ``"exp"``, ``"cauchy"``
            or ``"laplace"``.

    Returns:
        A 1-D ``torch.float32`` tensor of length ``N``.

    Raises:
        ValueError: If ``alpha_transform_type`` is not supported.
    """

    def _cosine(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp(t):
        return math.exp(t * -12.0)

    def _cauchy(t, gamma=1, mu=3):
        snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9)
        return 1 - 1 / (math.exp(snr) + 1.1)

    def _laplace(t, mu=0, b=1):
        snr = mu - b * math.copysign(1, 0.5 - t) * math.log(
            1 - 2 * abs(t - 0.5) * 0.98
        )
        return 1 - 1 / (math.exp(snr) + 1.02)

    transforms = (
        ("cosine", _cosine),
        ("exp", _exp),
        ("cauchy", _cauchy),
        ("laplace", _laplace),
    )
    alpha_bar_fn = next(
        (fn for name, fn in transforms if alpha_transform_type == name), None
    )
    if alpha_bar_fn is None:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((i + 1) / total) / alpha_bar_fn(i / total), max_beta)
        for i in range(total)
    ]
    return torch.tensor(betas, dtype=torch.float32)
def rescale_zero_terminal_snr(betas):
    """Rescale ``betas`` so the last cumulative alpha becomes zero.

    The square root of the cumulative alpha product is shifted so that its
    last entry is zero and scaled so that its first entry is unchanged; the
    result is converted back into per-step betas.

    Args:
        betas: 1-D tensor of betas.

    Returns:
        A tensor of rescaled betas with the same length.
    """
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()
    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    shifted = sqrt_alpha_bar - last
    rescaled = shifted * (first / (first - last))

    alpha_bar = rescaled**2
    # Recover per-step alphas from consecutive ratios of the cumulative product.
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - step_alphas
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    solver_order: int = 2,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    algorithm_type: str = "dpmsolver++",
    solver_type: str = "midpoint",
    lower_order_final: bool = True,
    euler_at_final: bool = False,
    use_karras_sigmas: Optional[bool] = False,
    use_lu_lambdas: Optional[bool] = False,
    final_sigmas_type: Optional[str] = "zero",
    lambda_min_clipped: float = -float("inf"),
    variance_type: Optional[str] = None,
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
    rescale_betas_zero_snr: bool = False,
):
    """Build the training noise schedule and reset the solver state.

    The betas come from ``trained_betas`` when given, otherwise from
    ``beta_schedule`` (``"linear"``, ``"scaled_linear"``,
    ``"squaredcos_cap_v2"``/``"cosine"``, ``"cauchy"`` or ``"laplace"``).
    From them the alpha/sigma/lambda tables are derived, solver options are
    validated, and the multistep bookkeeping is initialised.

    Raises:
        NotImplementedError: For an unsupported ``beta_schedule``,
            ``algorithm_type`` or ``solver_type``.
        ValueError: If ``final_sigmas_type == "zero"`` is combined with an
            ``algorithm_type`` other than ``dpmsolver++``/``sde-dpmsolver++``.
    """
    if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
        deprecation_message = f"algorithm_type {algorithm_type } is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
        deprecate(
            "algorithm_types dpmsolver and sde-dpmsolver",
            "1.0.0",
            deprecation_message,
        )
    # --- beta schedule ---------------------------------------------------
    if trained_betas is not None:
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(
            beta_start, beta_end, num_train_timesteps, dtype=torch.float32
        )
    elif beta_schedule == "scaled_linear":
        # Linear in sqrt(beta), then squared.
        self.betas = (
            torch.linspace(
                beta_start**0.5,
                beta_end**0.5,
                num_train_timesteps,
                dtype=torch.float32,
            )
            ** 2
        )
    elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine":
        self.betas = betas_for_alpha_bar(
            num_train_timesteps, alpha_transform_type="cosine"
        )
    elif beta_schedule == "cauchy":
        self.betas = betas_for_alpha_bar(
            num_train_timesteps, alpha_transform_type="cauchy"
        )
    elif beta_schedule == "laplace":
        self.betas = betas_for_alpha_bar(
            num_train_timesteps, alpha_transform_type="laplace"
        )
    else:
        raise NotImplementedError(
            f"{beta_schedule } is not implemented for {self .__class__ }"
        )
    if rescale_betas_zero_snr:
        self.betas = rescale_zero_terminal_snr(self.betas)
    # --- derived tables --------------------------------------------------
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
    if rescale_betas_zero_snr:
        # Replace the final cumulative alpha with a tiny positive value so the
        # log/division below stay finite.
        self.alphas_cumprod[-1] = 2 ** (-24)
    self.alpha_t = torch.sqrt(self.alphas_cumprod)
    self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
    self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
    self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
    self.init_noise_sigma = 1.0
    # --- option validation -----------------------------------------------
    if algorithm_type not in [
        "dpmsolver",
        "dpmsolver++",
        "sde-dpmsolver",
        "sde-dpmsolver++",
    ]:
        if algorithm_type == "deis":
            # "deis" is accepted but re-registered as "dpmsolver++".
            self.register_to_config(algorithm_type="dpmsolver++")
        else:
            raise NotImplementedError(
                f"{algorithm_type } is not implemented for {self .__class__ }"
            )
    if solver_type not in ["midpoint", "heun"]:
        if solver_type in ["logrho", "bh1", "bh2"]:
            # These names are accepted but re-registered as "midpoint".
            self.register_to_config(solver_type="midpoint")
        else:
            raise NotImplementedError(
                f"{solver_type } is not implemented for {self .__class__ }"
            )
    if (
        algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]
        and final_sigmas_type == "zero"
    ):
        raise ValueError(
            f"`final_sigmas_type` {final_sigmas_type } is not supported for `algorithm_type` {algorithm_type }. Please choose `sigma_min` instead."
        )
    # --- runtime state (overwritten by set_timesteps) ---------------------
    self.num_inference_steps = None
    # Default schedule: every training timestep, in descending order.
    timesteps = np.linspace(
        0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32
    )[::-1].copy()
    self.timesteps = torch.from_numpy(timesteps)
    self.model_outputs = [None] * solver_order
    self.lower_order_nums = 0
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")
@property
def step_index(self):
    """Current value of the internal ``_step_index`` counter (``None`` until initialised)."""
    return self._step_index
@property
def begin_index(self):
    """Value stored by ``set_begin_index`` (``None`` when not set)."""
    return self._begin_index
def set_begin_index(self, begin_index: int = 0):
    """Store ``begin_index``; ``_init_step_index`` uses it as the starting step index."""
    self._begin_index = begin_index
def set_timesteps(
    self,
    num_inference_steps: int = None,
    device: Union[str, torch.device] = None,
    timesteps: Optional[List[int]] = None,
):
    """Compute the inference timesteps and sigmas and reset the solver state.

    Exactly one of ``num_inference_steps`` or ``timesteps`` must be given.

    Args:
        num_inference_steps: Number of solver steps to generate a schedule for.
        device: Device the resulting ``self.timesteps`` tensor is moved to.
        timesteps: Explicit timestep list; not allowed together with
            ``use_karras_sigmas`` or ``use_lu_lambdas``.

    Raises:
        ValueError: On an invalid argument combination, an unsupported
            ``timestep_spacing``, or an unsupported ``final_sigmas_type``.
    """
    if num_inference_steps is None and timesteps is None:
        raise ValueError(
            "Must pass exactly one of `num_inference_steps` or `timesteps`."
        )
    if num_inference_steps is not None and timesteps is not None:
        raise ValueError(
            "Can only pass one of `num_inference_steps` or `custom_timesteps`."
        )
    if timesteps is not None and self.config.use_karras_sigmas:
        raise ValueError(
            "Cannot use `timesteps` with `config.use_karras_sigmas = True`"
        )
    if timesteps is not None and self.config.use_lu_lambdas:
        raise ValueError(
            "Cannot use `timesteps` with `config.use_lu_lambdas = True`"
        )
    if timesteps is not None:
        timesteps = np.array(timesteps).astype(np.int64)
    else:
        # Position of lambda_min_clipped in the ascending (flipped) lambda
        # table; timesteps from `last_timestep` upward are excluded.
        clipped_idx = torch.searchsorted(
            torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped
        )
        last_timestep = (
            (self.config.num_train_timesteps - clipped_idx).numpy().item()
        )
        if self.config.timestep_spacing == "linspace":
            # N+1 evenly spaced points, reversed, with the final (smallest) dropped.
            timesteps = (
                np.linspace(0, last_timestep - 1, num_inference_steps + 1)
                .round()[::-1][:-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = last_timestep // (num_inference_steps + 1)
            timesteps = (
                (np.arange(0, num_inference_steps + 1) * step_ratio)
                .round()[::-1][:-1]
                .copy()
                .astype(np.int64)
            )
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            timesteps = (
                np.arange(last_timestep, 0, -step_ratio)
                .round()
                .copy()
                .astype(np.int64)
            )
            timesteps -= 1
        else:
            raise ValueError(
                f"{self .config .timestep_spacing } is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )
    # Training-time sigma table, recomputed from the cumulative alphas.
    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    log_sigmas = np.log(sigmas)
    if self.config.use_karras_sigmas:
        # Timesteps are re-derived from the Karras-spaced sigmas.
        sigmas = np.flip(sigmas).copy()
        sigmas = self._convert_to_karras(
            in_sigmas=sigmas, num_inference_steps=num_inference_steps
        )
        timesteps = np.array(
            [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
        ).round()
    elif self.config.use_lu_lambdas:
        # Timesteps are re-derived from the Lu-spaced values.
        lambdas = np.flip(log_sigmas.copy())
        lambdas = self._convert_to_lu(
            in_lambdas=lambdas, num_inference_steps=num_inference_steps
        )
        sigmas = np.exp(lambdas)
        timesteps = np.array(
            [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
        ).round()
    else:
        # Look up (interpolate) the sigma for each chosen timestep.
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
    if self.config.final_sigmas_type == "sigma_min":
        sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
    elif self.config.final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self .config .final_sigmas_type }"
        )
    # One extra trailing sigma so that `sigmas[step_index + 1]` exists on the last step.
    sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
    self.sigmas = torch.from_numpy(sigmas)
    self.timesteps = torch.from_numpy(timesteps).to(
        device=device, dtype=torch.int64
    )
    self.num_inference_steps = len(timesteps)
    # Reset multistep history and step counters for the new schedule.
    self.model_outputs = [None] * self.config.solver_order
    self.lower_order_nums = 0
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    """Apply dynamic thresholding to ``sample``.

    Per batch element, the ``config.dynamic_thresholding_ratio`` quantile
    ``s`` of the absolute values is computed, clamped to
    ``[1, config.sample_max_value]``, and the sample is clipped to
    ``[-s, s]`` and divided by ``s``.

    Args:
        sample: Tensor with at least two dimensions; the first is the batch.

    Returns:
        A tensor with the same shape and dtype as the input.
    """
    dtype = sample.dtype
    original_shape = sample.shape
    if dtype not in (torch.float32, torch.float64):
        # Quantile/clamp are computed in float32 for reduced-precision inputs.
        sample = sample.float()
    # Flatten everything after the batch axis.  `flatten` (unlike the previous
    # `channels * np.prod(remaining_dims)`) also works for 2-D inputs, where
    # np.prod([]) is the float 1.0 and made reshape raise.
    sample = sample.flatten(start_dim=1)
    abs_sample = sample.abs()
    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
    s = torch.clamp(s, min=1, max=self.config.sample_max_value)
    s = s.unsqueeze(1)  # (batch, 1) so it broadcasts across the flattened axis
    sample = torch.clamp(sample, -s, s) / s
    sample = sample.reshape(original_shape)
    return sample.to(dtype)
def _sigma_to_t(self, sigma, log_sigmas):
    """Map ``sigma`` to a fractional index into ``log_sigmas`` by interpolation.

    Args:
        sigma: NumPy scalar/array of sigma values (clamped below at 1e-10
            before taking the log).
        log_sigmas: 1-D array of log-sigmas to interpolate within.

    Returns:
        An array shaped like ``sigma`` holding the interpolated index.
    """
    clamped_log = np.log(np.maximum(sigma, 1e-10))
    # Rows: table entries; columns: query sigmas.
    gaps = clamped_log - log_sigmas[:, np.newaxis]

    # Last table entry not exceeding the query, capped so `lower + 1` is valid.
    last_valid = log_sigmas.shape[0] - 2
    lower = np.cumsum(gaps >= 0, axis=0).argmax(axis=0).clip(max=last_valid)
    upper = lower + 1

    lo_val = log_sigmas[lower]
    hi_val = log_sigmas[upper]
    weight = np.clip((lo_val - clamped_log) / (lo_val - hi_val), 0, 1)

    interpolated = (1 - weight) * lower + weight * upper
    return interpolated.reshape(sigma.shape)
def _sigma_to_alpha_sigma_t(self, sigma):
    """Return ``(alpha_t, sigma_t)`` for ``sigma``.

    ``alpha_t = 1 / sqrt(sigma**2 + 1)`` and ``sigma_t = sigma * alpha_t``.
    """
    norm = (sigma**2 + 1) ** 0.5
    alpha_t = 1 / norm
    return (alpha_t, sigma * alpha_t)
def _convert_to_karras(
    self, in_sigmas: torch.Tensor, num_inference_steps
) -> torch.Tensor:
    """Build a ``rho = 7`` power-law sigma ramp with ``num_inference_steps`` entries.

    The ramp runs from ``sigma_max`` to ``sigma_min``.  Each bound is read
    from ``self.config`` when the attribute exists and is not ``None``;
    otherwise ``in_sigmas[0]`` (max) and ``in_sigmas[-1]`` (min) are used.
    """
    cfg_min = getattr(self.config, "sigma_min", None)
    cfg_max = getattr(self.config, "sigma_max", None)
    sigma_min = in_sigmas[-1].item() if cfg_min is None else cfg_min
    sigma_max = in_sigmas[0].item() if cfg_max is None else cfg_max

    rho = 7.0
    start = sigma_max ** (1 / rho)
    stop = sigma_min ** (1 / rho)
    fractions = np.linspace(0, 1, num_inference_steps)
    return (start + fractions * (stop - start)) ** rho
def _convert_to_lu(
    self, in_lambdas: torch.Tensor, num_inference_steps
) -> torch.Tensor:
    """Build a ``rho = 1`` ramp of ``num_inference_steps`` values.

    The ramp runs from ``in_lambdas[0]`` to ``in_lambdas[-1]``.
    """
    lambda_max: float = in_lambdas[0].item()
    lambda_min: float = in_lambdas[-1].item()

    rho = 1.0
    start = lambda_max ** (1 / rho)
    stop = lambda_min ** (1 / rho)
    fractions = np.linspace(0, 1, num_inference_steps)
    return (start + fractions * (stop - start)) ** rho
def convert_model_output(
    self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs
) -> torch.Tensor:
    """Convert the raw network output into the quantity the solver integrates.

    For ``dpmsolver++`` / ``sde-dpmsolver++`` the result is a predicted clean
    sample; for ``dpmsolver`` / ``sde-dpmsolver`` it is a noise prediction.
    The conversion depends on ``config.prediction_type`` and uses the sigma
    at the current ``step_index``.  Optional dynamic thresholding is applied
    when ``config.thresholding`` is set.

    Args:
        model_output: Raw output of the network.
        *args: Legacy positional ``(timestep, sample)``; ``timestep`` is
            deprecated and ignored.
        sample: Current noisy sample (required, by keyword or as ``args[1]``).

    Raises:
        ValueError: If ``sample`` is missing or ``prediction_type`` is unknown.
    """
    timestep = args[0] if args else kwargs.pop("timestep", None)
    if sample is None:
        if len(args) < 2:
            raise ValueError("missing `sample` as a required keyward argument")
        sample = args[1]
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    def current_alpha_sigma():
        # Evaluated lazily: only branches that need it touch `self.sigmas`.
        return self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index])

    def unknown_prediction_type():
        return ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction` for the DPMSolverMultistepScheduler."
        )

    learned_variance = ["learned", "learned_range"]

    if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
        # Data-prediction formulation: return the predicted clean sample.
        if self.config.prediction_type == "epsilon":
            if self.config.variance_type in learned_variance:
                model_output = model_output[:, :3]
            alpha_t, sigma_t = current_alpha_sigma()
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = current_alpha_sigma()
            x0_pred = alpha_t * sample - sigma_t * model_output
        else:
            raise unknown_prediction_type()
        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)
        return x0_pred

    if self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
        # Noise-prediction formulation: return the predicted noise.
        if self.config.prediction_type == "epsilon":
            if self.config.variance_type in learned_variance:
                epsilon = model_output[:, :3]
            else:
                epsilon = model_output
        elif self.config.prediction_type == "sample":
            alpha_t, sigma_t = current_alpha_sigma()
            epsilon = (sample - alpha_t * model_output) / sigma_t
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = current_alpha_sigma()
            epsilon = alpha_t * model_output + sigma_t * sample
        else:
            raise unknown_prediction_type()
        if self.config.thresholding:
            # Threshold in sample space, then map back to a noise prediction.
            alpha_t, sigma_t = current_alpha_sigma()
            x0_pred = (sample - sigma_t * epsilon) / alpha_t
            x0_pred = self._threshold_sample(x0_pred)
            epsilon = (sample - alpha_t * x0_pred) / sigma_t
        return epsilon
def dpm_solver_first_order_update(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Advance ``sample`` by one first-order solver step.

    Uses the sigmas at ``step_index`` (current) and ``step_index + 1``
    (target) and the formula selected by ``config.algorithm_type``.

    Args:
        model_output: Converted model output for the current step.
        *args: Legacy positional ``(timestep, prev_timestep, sample)``; the
            first two are deprecated and ignored.
        sample: Current sample (required, by keyword or as ``args[2]``).
        noise: Noise tensor; must be provided for the ``sde-*`` algorithms.

    Returns:
        The sample at the next step.

    Raises:
        ValueError: If ``sample`` is missing.
    """
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError(" missing `sample` as a required keyward argument")
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    # `_t` is the target (next) step, `_s` the current step.
    sigma_t, sigma_s = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
    )
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
    # Step size in lambda = log(alpha) - log(sigma).
    h = lambda_t - lambda_s
    if self.config.algorithm_type == "dpmsolver++":
        x_t = (
            sigma_t / sigma_s * sample
            - alpha_t * (torch.exp(-h) - 1.0) * model_output
        )
    elif self.config.algorithm_type == "dpmsolver":
        x_t = (
            alpha_t / alpha_s * sample
            - sigma_t * (torch.exp(h) - 1.0) * model_output
        )
    elif self.config.algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        x_t = (
            sigma_t / sigma_s * torch.exp(-h) * sample
            + alpha_t * (1 - torch.exp(-2.0 * h)) * model_output
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )
    elif self.config.algorithm_type == "sde-dpmsolver":
        assert noise is not None
        x_t = (
            alpha_t / alpha_s * sample
            - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
            + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
        )
    # NOTE(review): for any other algorithm_type `x_t` is never assigned and
    # this line raises UnboundLocalError.
    return x_t
def multistep_dpm_solver_second_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Advance ``sample`` by one second-order multistep solver step.

    Uses the two most recent entries of ``model_output_list`` and the sigmas
    at ``step_index - 1``, ``step_index`` and ``step_index + 1``.  The
    formula is selected by ``config.algorithm_type`` and
    ``config.solver_type`` (``"midpoint"`` or ``"heun"``).

    Args:
        model_output_list: Converted model outputs, most recent last.
        *args: Legacy positional ``(timestep_list, prev_timestep, sample)``;
            the first two are deprecated and ignored.
        sample: Current sample (required, by keyword or as ``args[2]``).
        noise: Noise tensor; must be provided for the ``sde-*`` algorithms.

    Returns:
        The sample at the next step.

    Raises:
        ValueError: If ``sample`` is missing.
    """
    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError(" missing `sample` as a required keyward argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    # `_t`: target (next) step, `_s0`: current step, `_s1`: previous step.
    sigma_t, sigma_s0, sigma_s1 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
    )
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
    m0, m1 = (model_output_list[-1], model_output_list[-2])
    # h: current step size in lambda; h_0: previous step size.
    h, h_0 = (lambda_t - lambda_s0, lambda_s0 - lambda_s1)
    r0 = h_0 / h
    # D0: latest output; D1: scaled finite difference of the last two outputs.
    D0, D1 = (m0, 1.0 / r0 * (m0 - m1))
    if self.config.algorithm_type == "dpmsolver++":
        if self.config.solver_type == "midpoint":
            x_t = (
                sigma_t / sigma_s0 * sample
                - alpha_t * (torch.exp(-h) - 1.0) * D0
                - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
            )
        elif self.config.solver_type == "heun":
            x_t = (
                sigma_t / sigma_s0 * sample
                - alpha_t * (torch.exp(-h) - 1.0) * D0
                + alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0) * D1
            )
    elif self.config.algorithm_type == "dpmsolver":
        if self.config.solver_type == "midpoint":
            x_t = (
                alpha_t / alpha_s0 * sample
                - sigma_t * (torch.exp(h) - 1.0) * D0
                - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
            )
        elif self.config.solver_type == "heun":
            x_t = (
                alpha_t / alpha_s0 * sample
                - sigma_t * (torch.exp(h) - 1.0) * D0
                - sigma_t * ((torch.exp(h) - 1.0) / h - 1.0) * D1
            )
    elif self.config.algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        if self.config.solver_type == "midpoint":
            x_t = (
                sigma_t / sigma_s0 * torch.exp(-h) * sample
                + alpha_t * (1 - torch.exp(-2.0 * h)) * D0
                + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif self.config.solver_type == "heun":
            x_t = (
                sigma_t / sigma_s0 * torch.exp(-h) * sample
                + alpha_t * (1 - torch.exp(-2.0 * h)) * D0
                + alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
    elif self.config.algorithm_type == "sde-dpmsolver":
        assert noise is not None
        if self.config.solver_type == "midpoint":
            x_t = (
                alpha_t / alpha_s0 * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                - sigma_t * (torch.exp(h) - 1.0) * D1
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
        elif self.config.solver_type == "heun":
            x_t = (
                alpha_t / alpha_s0 * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
    # NOTE(review): `x_t` is unassigned if algorithm_type/solver_type match
    # none of the branches above.
    return x_t
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """Advance ``sample`` by one third-order multistep solver step.

    Uses the three most recent entries of ``model_output_list`` and the
    sigmas at ``step_index - 2`` through ``step_index + 1``.  Only the
    ``dpmsolver++`` and ``dpmsolver`` algorithm types have a formula here.

    Args:
        model_output_list: Converted model outputs, most recent last.
        *args: Legacy positional ``(timestep_list, prev_timestep, sample)``;
            the first two are deprecated and ignored.
        sample: Current sample (required, by keyword or as ``args[2]``).

    Returns:
        The sample at the next step.

    Raises:
        ValueError: If ``sample`` is missing.
    """
    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError(" missing`sample` as a required keyward argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    # `_t`: target (next) step; `_s0`, `_s1`, `_s2`: current and two previous steps.
    sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
        self.sigmas[self.step_index - 2],
    )
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
    alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
    lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
    m0, m1, m2 = (
        model_output_list[-1],
        model_output_list[-2],
        model_output_list[-3],
    )
    # Step sizes in lambda: current, previous, and the one before that.
    h, h_0, h_1 = (
        lambda_t - lambda_s0,
        lambda_s0 - lambda_s1,
        lambda_s1 - lambda_s2,
    )
    r0, r1 = (h_0 / h, h_1 / h)
    # D0/D1/D2: latest output and first/second scaled finite differences.
    D0 = m0
    D1_0, D1_1 = (1.0 / r0 * (m0 - m1), 1.0 / r1 * (m1 - m2))
    D1 = D1_0 + r0 / (r0 + r1) * (D1_0 - D1_1)
    D2 = 1.0 / (r0 + r1) * (D1_0 - D1_1)
    if self.config.algorithm_type == "dpmsolver++":
        x_t = (
            sigma_t / sigma_s0 * sample
            - alpha_t * (torch.exp(-h) - 1.0) * D0
            + alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0) * D1
            - alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5) * D2
        )
    elif self.config.algorithm_type == "dpmsolver":
        x_t = (
            alpha_t / alpha_s0 * sample
            - sigma_t * (torch.exp(h) - 1.0) * D0
            - sigma_t * ((torch.exp(h) - 1.0) / h - 1.0) * D1
            - sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5) * D2
        )
    # NOTE(review): no branch exists for the `sde-*` algorithm types, so
    # `x_t` is unassigned for them.
    return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Return the schedule index that corresponds to ``timestep``.

    Args:
        timestep: Value to look up.
        schedule_timesteps: Tensor to search; defaults to ``self.timesteps``.

    Returns:
        ``len(self.timesteps) - 1`` when there is no match, the index of the
        second match when there are several, otherwise the single match.
    """
    if schedule_timesteps is None:
        schedule_timesteps = self.timesteps
    matches = (schedule_timesteps == timestep).nonzero()
    match_count = len(matches)
    if match_count == 0:
        return len(self.timesteps) - 1
    # With duplicates, the second occurrence is the one selected.
    chosen = 1 if match_count > 1 else 0
    return matches[chosen].item()
def _init_step_index(self, timestep):
    """Initialise ``self._step_index`` for the first call to ``step``.

    When a begin index has been set it is used directly; otherwise the index
    is looked up from ``timestep`` via ``index_for_timestep`` (tensor
    timesteps are first moved to the device of ``self.timesteps``).
    """
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
def step(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """Run one solver step and advance the internal step counter.

    Args:
        model_output: Raw network output for the current step.
        timestep: Current timestep; only used to initialise the step index
            on the first call.
        sample: Current noisy sample.
        generator: Random generator used when noise has to be drawn.
        variance_noise: Pre-drawn noise for the ``sde-*`` algorithms; when
            ``None`` it is sampled with ``randn_tensor``.
        return_dict: Return a ``SchedulerOutput`` when true, else a 1-tuple.

    Returns:
        ``SchedulerOutput(prev_sample=...)`` or ``(prev_sample,)``.

    Raises:
        ValueError: If ``set_timesteps`` has not been called.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )
    if self.step_index is None:
        self._init_step_index(timestep)
    # Force a first-order update on the very last step when configured.
    lower_order_final = self.step_index == len(self.timesteps) - 1 and (
        self.config.euler_at_final
        or (self.config.lower_order_final and len(self.timesteps) < 15)
        or self.config.final_sigmas_type == "zero"
    )
    # Cap the order at two on the second-to-last step of short schedules.
    lower_order_second = (
        self.step_index == len(self.timesteps) - 2
        and self.config.lower_order_final
        and (len(self.timesteps) < 15)
    )
    model_output = self.convert_model_output(model_output, sample=sample)
    # Shift the history left and append the newest converted output.
    for i in range(self.config.solver_order - 1):
        self.model_outputs[i] = self.model_outputs[i + 1]
    self.model_outputs[-1] = model_output
    # The update itself is computed in float32.
    sample = sample.to(torch.float32)
    if (
        self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]
        and variance_noise is None
    ):
        noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=torch.float32,
        )
    elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
        noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
    else:
        noise = None
    # `lower_order_nums` counts how many outputs have been accumulated so far,
    # so the first steps use lower-order updates until history is available.
    if (
        self.config.solver_order == 1
        or self.lower_order_nums < 1
        or lower_order_final
    ):
        prev_sample = self.dpm_solver_first_order_update(
            model_output, sample=sample, noise=noise
        )
    elif (
        self.config.solver_order == 2
        or self.lower_order_nums < 2
        or lower_order_second
    ):
        prev_sample = self.multistep_dpm_solver_second_order_update(
            self.model_outputs, sample=sample, noise=noise
        )
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(
            self.model_outputs, sample=sample
        )
    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1
    # Cast back to the dtype of the converted model output.
    prev_sample = prev_sample.to(model_output.dtype)
    self._step_index += 1
    if not return_dict:
        return (prev_sample,)
    return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """Return ``alpha_t * original_samples + sigma_t * noise``.

    ``alpha_t`` and ``sigma_t`` are read from the scheduler's tables at
    ``timesteps`` (one entry per sample), converted to the device and dtype
    of ``original_samples``, and broadcast over the trailing dimensions.
    """
    target_device = original_samples.device
    target_dtype = original_samples.dtype
    alpha_table = self.alpha_t.to(target_device).to(target_dtype)
    sigma_table = self.sigma_t.to(target_device).to(target_dtype)
    step_ids = timesteps.to(target_device)

    # One coefficient per sample, with singleton axes appended for broadcasting.
    trailing_axes = [1] * (original_samples.dim() - 1)
    alpha = alpha_table[step_ids].flatten().reshape(-1, *trailing_axes)
    sigma = sigma_table[step_ids].flatten().reshape(-1, *trailing_axes)
    return alpha * original_samples + sigma * noise
def get_velocity(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """Return ``alpha_t * noise - sigma_t * original_samples``.

    ``alpha_t`` and ``sigma_t`` are read from the scheduler's tables at
    ``timesteps`` (one entry per sample), converted to the device and dtype
    of ``original_samples``, and broadcast over the trailing dimensions.
    """
    target_device = original_samples.device
    target_dtype = original_samples.dtype
    alpha_table = self.alpha_t.to(target_device).to(target_dtype)
    sigma_table = self.sigma_t.to(target_device).to(target_dtype)
    step_ids = timesteps.to(target_device)

    # One coefficient per sample, with singleton axes appended for broadcasting.
    trailing_axes = [1] * (original_samples.dim() - 1)
    alpha = alpha_table[step_ids].flatten().reshape(-1, *trailing_axes)
    sigma = sigma_table[step_ids].flatten().reshape(-1, *trailing_axes)
    return alpha * noise - sigma * original_samples
def __len__(self):
    """Return ``config.num_train_timesteps`` (not the number of inference steps)."""
    return self.config.num_train_timesteps
'\nProcessor class for QWEN3Vox models.\n'
import os
import json
import warnings
from typing import List, Optional, Union, Dict, Any
import numpy as np
import torch
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.utils import logging
logger = logging.get_logger(__name__)
class AudioNormalizer:
    """Scale audio to a target RMS level and then guard against clipping."""

    def __init__(self, target_dB_FS: float = -25, eps: float = 1e-06):
        """Store the target level (dB FS) and the small stabilising constant."""
        self.target_dB_FS = target_dB_FS
        self.eps = eps

    def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
        """Scale ``audio`` so its RMS matches ``target_dB_FS``.

        Returns:
            ``(scaled_audio, rms, gain)`` where ``rms`` is the RMS of the
            input and ``gain`` the factor that was applied.
        """
        rms = np.sqrt(np.mean(audio**2))
        target_amplitude = 10 ** (self.target_dB_FS / 20)
        gain = target_amplitude / (rms + self.eps)
        return (audio * gain, rms, gain)

    def avoid_clipping(
        self, audio: np.ndarray, scalar: Optional[float] = None
    ) -> tuple:
        """Divide ``audio`` by ``scalar`` and return ``(result, scalar)``.

        When ``scalar`` is ``None`` it is derived from the peak: ``peak + eps``
        if the peak absolute value exceeds 1.0, otherwise 1.0.
        """
        if scalar is None:
            peak = np.max(np.abs(audio))
            scalar = peak + self.eps if peak > 1.0 else 1.0
        return (audio / scalar, scalar)

    def __call__(self, audio: np.ndarray) -> np.ndarray:
        """Apply level normalisation followed by clipping protection."""
        levelled = self.tailor_dB_FS(audio)[0]
        protected = self.avoid_clipping(levelled)[0]
        return protected
class QWEN3VoxTokenizerProcessor(FeatureExtractionMixin):
    """Audio pre- and post-processor for the QWEN3Vox tokenizer.

    Loads audio from arrays or files, mixes it down to mono, optionally
    normalizes its level with ``AudioNormalizer``, and can write waveforms
    back to disk.

    Args:
        sampling_rate: Expected sampling rate; also used when loading files
            through librosa and as the default rate when saving.
        normalize_audio: Whether to apply ``AudioNormalizer`` to every input.
        target_dB_FS: Target level forwarded to ``AudioNormalizer``.
        eps: Numerical-stability constant forwarded to ``AudioNormalizer``.
        **kwargs: Forwarded to ``FeatureExtractionMixin.__init__``.
    """

    model_input_names = ["input_features"]

    def __init__(
        self,
        sampling_rate: int = 22050,
        normalize_audio: bool = True,
        target_dB_FS: float = -25,
        eps: float = 1e-06,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.sampling_rate = sampling_rate
        self.normalize_audio = normalize_audio
        if self.normalize_audio:
            self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps)
        else:
            self.normalizer = None
        # Constructor arguments kept for serialization through ``to_dict``.
        self.feature_extractor_dict = {
            "sampling_rate": sampling_rate,
            "normalize_audio": normalize_audio,
            "target_dB_FS": target_dB_FS,
            "eps": eps,
        }

    def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
        """Return a 1-D waveform, averaging or squeezing a 2-D input.

        A 2-D array is accepted only when one of its axes has size 1 or 2.

        Raises:
            ValueError: For 2-D inputs with no axis of size 1 or 2, or for
                inputs that are neither 1-D nor 2-D.
        """
        if len(audio.shape) == 1:
            return audio
        elif len(audio.shape) == 2:
            if audio.shape[0] == 2:
                return np.mean(audio, axis=0)
            elif audio.shape[1] == 2:
                return np.mean(audio, axis=1)
            elif audio.shape[0] == 1:
                return audio.squeeze(0)
            elif audio.shape[1] == 1:
                return audio.squeeze(1)
            else:
                raise ValueError(f"Unexpected audio shape: {audio .shape }")
        else:
            raise ValueError(f"Audio should be 1D or 2D, got shape: {audio .shape }")

    def _process_single_audio(
        self, audio: Union[np.ndarray, List[float]]
    ) -> np.ndarray:
        """Convert one waveform to float32 mono and normalize it if enabled."""
        if not isinstance(audio, np.ndarray):
            audio = np.array(audio, dtype=np.float32)
        else:
            audio = audio.astype(np.float32)
        audio = self._ensure_mono(audio)
        if self.normalize_audio and self.normalizer is not None:
            audio = self.normalizer(audio)
        return audio

    def __call__(
        self,
        audio: Union[
            str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str]
        ] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[str] = None,
        **kwargs,
    ):
        """Process a single waveform/path or a batch of them.

        Args:
            audio: A path, a waveform (array or list of floats), or a list of
                paths/waveforms.
            sampling_rate: Rate of the supplied audio. A mismatch with
                ``self.sampling_rate`` only logs a warning; nothing is resampled
                here.
            return_tensors: ``"pt"`` or ``"np"`` to get a stacked array with a
                batch axis and a singleton channel axis; anything else returns
                the processed array (single input) or a list of arrays.

        Returns:
            A dict with the processed audio under the ``"audio"`` key.

        Raises:
            ValueError: If ``audio`` is ``None`` or an empty list.
        """
        if audio is None:
            raise ValueError("Audio input is required")
        if sampling_rate is not None and sampling_rate != self.sampling_rate:
            logger.warning(
                f"Input sampling rate ({sampling_rate }) differs from expected sampling rate ({self .sampling_rate }). Please resample your audio."
            )
        if isinstance(audio, str):
            audio = self._load_audio_from_path(audio)
            is_batched = False
        elif isinstance(audio, list):
            if len(audio) == 0:
                raise ValueError("Empty audio list provided")
            if all((isinstance(item, str) for item in audio)):
                audio = [self._load_audio_from_path(path) for path in audio]
                is_batched = True
            else:
                # A list of floats is one waveform; a list of arrays/lists is a batch.
                is_batched = isinstance(audio[0], (np.ndarray, list))
        else:
            is_batched = False
        if is_batched:
            processed_audio = [self._process_single_audio(a) for a in audio]
        else:
            processed_audio = [self._process_single_audio(audio)]
        if return_tensors == "pt":
            if len(processed_audio) == 1:
                input_features = (
                    torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1)
                )
            else:
                # Stacking requires every item in the batch to have equal length.
                input_features = torch.stack(
                    [torch.from_numpy(a) for a in processed_audio]
                ).unsqueeze(1)
        elif return_tensors == "np":
            if len(processed_audio) == 1:
                input_features = processed_audio[0][np.newaxis, np.newaxis, :]
            else:
                input_features = np.stack(processed_audio)[:, np.newaxis, :]
        else:
            input_features = (
                processed_audio[0] if len(processed_audio) == 1 else processed_audio
            )
        outputs = {"audio": input_features}
        return outputs

    def _load_audio_from_path(self, audio_path: str) -> np.ndarray:
        """Load a waveform from disk based on the file extension.

        Audio containers are decoded with librosa at ``self.sampling_rate`` as
        mono; ``.pt`` and ``.npy`` files are loaded as stored and cast to
        float32.

        Raises:
            ValueError: For an unsupported file extension.
        """
        file_ext = os.path.splitext(audio_path)[1].lower()
        if file_ext in [".wav", ".mp3", ".flac", ".m4a", ".ogg"]:
            import librosa

            audio_array, _ = librosa.load(audio_path, sr=self.sampling_rate, mono=True)
            return audio_array
        elif file_ext == ".pt":
            # NOTE: torch.load unpickles the file; only load trusted .pt files.
            loaded = torch.load(audio_path, map_location="cpu")
            # Check the type before squeezing so non-tensor payloads (e.g. a
            # plain list) do not fail on a missing ``.squeeze`` attribute.
            if isinstance(loaded, torch.Tensor):
                audio_array = loaded.detach().squeeze().numpy()
            else:
                audio_array = np.squeeze(np.asarray(loaded))
            return audio_array.astype(np.float32)
        elif file_ext == ".npy":
            audio_array = np.load(audio_path)
            return audio_array.astype(np.float32)
        else:
            # The message lists only the formats actually handled above.
            raise ValueError(
                f"Unsupported file format: {file_ext }. Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy"
            )

    def preprocess_audio(
        self,
        audio_path_or_array: Union[str, np.ndarray],
        normalize: Optional[bool] = None,
    ) -> np.ndarray:
        """Process one input, optionally overriding the normalization setting.

        Args:
            audio_path_or_array: A file path or a waveform.
            normalize: When not ``None``, temporarily replaces
                ``self.normalize_audio`` for this call only.

        Returns:
            The processed float32 mono waveform.
        """
        if isinstance(audio_path_or_array, str):
            audio_array = self._load_audio_from_path(audio_path_or_array)
        else:
            audio_array = np.array(audio_path_or_array, dtype=np.float32)
        original_normalize = self.normalize_audio
        if normalize is not None:
            self.normalize_audio = normalize
        try:
            processed = self._process_single_audio(audio_array)
        finally:
            # Always restore the instance setting, even if processing fails.
            self.normalize_audio = original_normalize
        return processed

    def to_dict(self) -> Dict[str, Any]:
        """Return a copy of the constructor arguments for serialization.

        A shallow copy is returned so that callers which post-process the
        result cannot alter this instance's stored configuration.
        """
        return dict(self.feature_extractor_dict)

    def save_audio(
        self,
        audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
        output_path: str = "output.wav",
        sampling_rate: Optional[int] = None,
        normalize: bool = False,
        batch_prefix: str = "audio_",
    ):
        """Write one waveform or a batch of waveforms as WAV files.

        Args:
            audio: A tensor, an array, or a list of tensors/arrays.
            output_path: Target file for a single waveform; treated as an output
                directory when a list or a batch of more than one item is given.
            sampling_rate: Rate written to the file; defaults to
                ``self.sampling_rate``.
            normalize: Peak-normalize each waveform before writing.
            batch_prefix: File-name prefix for batched outputs
                (``<prefix><index>.wav``).

        Returns:
            The list of written file paths.

        Raises:
            ImportError: If ``soundfile`` is not installed.
            ValueError: If ``audio`` has an unsupported type.
        """
        if sampling_rate is None:
            sampling_rate = self.sampling_rate
        try:
            import soundfile as sf
        except ImportError as exc:
            raise ImportError(
                "soundfile is required to save audio files. Install it with: pip install soundfile"
            ) from exc
        if isinstance(audio, torch.Tensor):
            audio_np = audio.float().detach().cpu().numpy()
        elif isinstance(audio, np.ndarray):
            audio_np = audio
        elif isinstance(audio, list):
            # Convert item by item so lists mixing tensors and arrays are handled.
            audio_np = [
                a.float().detach().cpu().numpy() if isinstance(a, torch.Tensor) else a
                for a in audio
            ]
        else:
            raise ValueError(f"Unsupported audio type: {type (audio )}")
        saved_paths = []
        if isinstance(audio_np, list):
            output_dir = output_path
            os.makedirs(output_dir, exist_ok=True)
            for i, audio_item in enumerate(audio_np):
                audio_item = self._prepare_audio_for_save(audio_item, normalize)
                file_path = os.path.join(output_dir, f"{batch_prefix }{i }.wav")
                sf.write(file_path, audio_item, sampling_rate)
                saved_paths.append(file_path)
        elif len(audio_np.shape) >= 3:
            batch_size = audio_np.shape[0]
            if batch_size > 1:
                output_dir = output_path
                os.makedirs(output_dir, exist_ok=True)
                for i in range(batch_size):
                    single_audio = audio_np[i]
                    if len(single_audio.shape) > 1:
                        if single_audio.shape[0] == 1:
                            single_audio = single_audio.squeeze(0)
                    single_audio = self._prepare_audio_for_save(single_audio, normalize)
                    file_path = os.path.join(output_dir, f"{batch_prefix }{i }.wav")
                    sf.write(file_path, single_audio, sampling_rate)
                    saved_paths.append(file_path)
            else:
                audio_item = audio_np.squeeze()
                audio_item = self._prepare_audio_for_save(audio_item, normalize)
                sf.write(output_path, audio_item, sampling_rate)
                saved_paths.append(output_path)
        else:
            audio_item = self._prepare_audio_for_save(audio_np, normalize)
            sf.write(output_path, audio_item, sampling_rate)
            saved_paths.append(output_path)
        return saved_paths

    def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
        """Drop a leading singleton axis and optionally peak-normalize."""
        if len(audio.shape) > 1 and audio.shape[0] == 1:
            audio = audio.squeeze(0)
        if normalize:
            max_val = np.abs(audio).max()
            # Silent audio (peak 0) is left untouched to avoid dividing by zero.
            if max_val > 0:
                audio = audio / max_val
        return audio
__all__ = [
'QWEN3VoxTokenizerProcessor',
"AudioNormalizer",
]
import math
import torch
class UniformSampler:
    """Draw diffusion timesteps uniformly from ``[0, timesteps)``."""

    def __init__(self, timesteps=1000):
        self.timesteps = timesteps

    def sample(self, batch_size, device):
        """Return ``batch_size`` integer timesteps created on ``device``."""
        shape = (batch_size,)
        return torch.randint(0, self.timesteps, shape, device=device)
class LogitNormalSampler:
    """Draw timestep indices with logit-normal weighting.

    The weights are a Gaussian density (mean ``m``, standard deviation ``s``)
    evaluated on the logit of ``timesteps`` evenly spaced points in [0, 1].

    Args:
        timesteps: Number of discrete timesteps.
        m: Mean of the Gaussian in logit space.
        s: Standard deviation of the Gaussian in logit space.
    """

    def __init__(self, timesteps=1000, m=0, s=1):
        self.timesteps = timesteps
        grid = torch.linspace(0, 1, timesteps)
        # The endpoints map to -inf / +inf logits and therefore get weight 0.
        logit = torch.log(grid / (1 - grid))
        density = torch.exp(-0.5 * (logit - m) ** 2 / s**2)
        self.prob = density / (s * math.sqrt(2 * math.pi))

    def sample(self, batch_size, device):
        """Return ``batch_size`` indices drawn with replacement, moved to ``device``."""
        drawn = torch.multinomial(self.prob, batch_size, replacement=True)
        return drawn.to(device)
' QWEN3Vox Streaming model configuration'
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
logger = logging.get_logger(__name__)
class QWEN3VoxStreamingConfig(PretrainedConfig):
    """Composite configuration for the streaming QWEN3Vox model.

    Bundles three sub-configurations: the acoustic tokenizer, the Qwen2
    decoder and the diffusion head.

    Args:
        acoustic_tokenizer_config: ``None`` (use defaults), a dict of keyword
            arguments, or a ready ``QWEN3VoxAcousticTokenizerConfig``.
        decoder_config: ``None`` (use defaults), a dict whose ``model_type`` is
            ``"qwen2"``, or a ready ``Qwen2Config``.
        diffusion_head_config: ``None`` (use defaults), a dict of keyword
            arguments, or a ready ``QWEN3VoxDiffusionHeadConfig``.
        tts_backbone_num_hidden_layers: Stored unchanged on the config.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``decoder_config`` is a dict with a ``model_type`` other
            than ``"qwen2"``.
    """

    model_type = 'vibevoice_streaming'
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig,
        "decoder_config": Qwen2Config,
        "diffusion_head_config": QWEN3VoxDiffusionHeadConfig,
    }
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        decoder_config=None,
        diffusion_head_config=None,
        tts_backbone_num_hidden_layers=20,
        **kwargs,
    ):
        kwargs["_attn_implementation_autoset"] = False
        if acoustic_tokenizer_config is None:
            self.acoustic_tokenizer_config = self.sub_configs[
                "acoustic_tokenizer_config"
            ]()
        elif isinstance(acoustic_tokenizer_config, dict):
            # Build the keyword arguments on a copy so the caller's dict is
            # not modified when the model type is forced.
            acoustic_kwargs = dict(acoustic_tokenizer_config)
            acoustic_kwargs["model_type"] = 'vibevoice_acoustic_tokenizer'
            self.acoustic_tokenizer_config = self.sub_configs[
                "acoustic_tokenizer_config"
            ](**acoustic_kwargs)
        elif isinstance(acoustic_tokenizer_config, QWEN3VoxAcousticTokenizerConfig):
            self.acoustic_tokenizer_config = acoustic_tokenizer_config
        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            if decoder_config.get("model_type", "") == "qwen2":
                self.decoder_config = Qwen2Config(**decoder_config)
            else:
                raise ValueError(
                    f"Unsupported decoder model type: {decoder_config .get ('model_type','')}"
                )
        elif isinstance(decoder_config, (Qwen2Config,)):
            self.decoder_config = decoder_config
        if diffusion_head_config is None:
            self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
        elif isinstance(diffusion_head_config, dict):
            # Same as above: never write into the caller's dict.
            diffusion_kwargs = dict(diffusion_head_config)
            diffusion_kwargs["model_type"] = 'vibevoice_diffusion_head'
            self.diffusion_head_config = self.sub_configs["diffusion_head_config"](
                **diffusion_kwargs
            )
        elif isinstance(diffusion_head_config, QWEN3VoxDiffusionHeadConfig):
            self.diffusion_head_config = diffusion_head_config
        # Falls back to 64 when the tokenizer config has no ``vae_dim``.
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64)
        self.tts_backbone_num_hidden_layers = tts_backbone_num_hidden_layers
        super().__init__(**kwargs)
__all__ = [
'QWEN3VoxStreamingConfig'
]
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel
from transformers.modeling_utils import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.utils import logging
logger = logging.get_logger(__name__)
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Args:
        dim: Size of the normalized (last) dimension.
        eps: Constant added to the mean square before the reciprocal sqrt.
        elementwise_affine: Whether to learn a per-feature scale.
        memory_efficient: Accepted but not used by this implementation.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-06,
        elementwise_affine=True,
        memory_efficient=False,
    ):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            self.register_parameter("weight", None)
        else:
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by its RMS along the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize in float32, cast back to the input dtype, then scale."""
        normed = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight

    def extra_repr(self) -> str:
        return f"dim={self .dim }, eps={self .eps }, elementwise_affine={self .elementwise_affine }"
def modulate(x, shift, scale):
    """Apply an affine modulation: ``x * (1 + scale) + shift``."""
    gain = 1 + scale
    return x * gain + shift
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps with sinusoidal features followed by an MLP.

    Args:
        hidden_size: Output size of the embedding.
        frequency_embedding_size: Size of the sinusoidal feature vector fed to
            the MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=False),
            ACT2FN["silu"],
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Build sinusoidal embeddings for a 1-D tensor of timesteps.

        Args:
            t: 1-D tensor of timesteps (one per batch element).
            dim: Size of the returned embedding.
            max_period: Controls the lowest frequency of the embedding.

        Returns:
            A ``(len(t), dim)`` tensor laid out as ``[cos | sin]``, zero-padded
            by one column when ``dim`` is odd. The result uses ``t``'s dtype
            when ``t`` is floating point and float32 otherwise.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        # Casting to an integer dtype would truncate the sin/cos values to
        # -1/0/1, so only follow the input dtype when it is floating point.
        if t.is_floating_point():
            return embedding.to(t.dtype)
        return embedding

    def forward(self, t):
        """Return the MLP-projected embedding of the timesteps ``t``."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
class FeedForwardNetwork(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``, all bias-free.

    Args:
        embed_dim: Input and output feature size.
        ffn_dim: Hidden feature size of the gate/up projections.
    """

    def __init__(self, embed_dim, ffn_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        """Apply the gated projection to ``x``."""
        gate_pre = self.gate_proj(x)
        value = self.up_proj(x)
        gated = self.act_fn(gate_pre) * value
        return self.down_proj(gated)
class HeadLayer(nn.Module):
    """Residual feed-forward layer modulated by a conditioning vector.

    The condition is projected to shift, scale and gate terms; the shift and
    scale modulate the normalized input, and the gate scales the FFN output
    before it is added back to the input.

    Args:
        embed_dim: Feature size of the layer input and output.
        ffn_dim: Hidden size of the feed-forward network.
        cond_dim: Feature size of the conditioning vector.
        norm_eps: Epsilon used by the RMS normalization.
    """

    def __init__(self, embed_dim, ffn_dim, cond_dim, norm_eps=1e-05):
        super().__init__()
        self.embed_dim = embed_dim
        self.cond_dim = cond_dim
        self.ffn_dim = ffn_dim
        self.ffn = FeedForwardNetwork(self.embed_dim, self.ffn_dim)
        self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
        self.adaLN_modulation = nn.Sequential(
            ACT2FN["silu"], nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
        )

    def forward(self, x, c):
        """Return ``x`` plus the gated, condition-modulated FFN output."""
        shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
        modulated = modulate(self.norm(x), shift_ffn, scale_ffn)
        update = gate_ffn * self.ffn(modulated)
        return x + update
class FinalLayer(nn.Module):
    """Output layer: condition-modulated RMS norm followed by a linear map.

    Args:
        hidden_size: Feature size of the input.
        output_size: Feature size of the output.
        cond_size: Feature size of the conditioning vector.
        norm_eps: Epsilon used by the (non-affine) RMS normalization.
    """

    def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-05):
        super().__init__()
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, output_size, bias=False)
        self.adaLN_modulation = nn.Sequential(
            ACT2FN["silu"], nn.Linear(cond_size, 2 * hidden_size, bias=False)
        )

    def forward(self, x, c):
        """Project the modulated, normalized ``x`` to the output size."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
class QWEN3VoxDiffusionHead(PreTrainedModel):
    """Diffusion head mapping a noisy latent, a timestep and a condition to a latent-sized output.

    The noisy latent is projected to the hidden size, passed through a stack of
    ``HeadLayer`` blocks conditioned on ``cond_proj(condition) + t_embedder(timesteps)``,
    and projected back to the latent size by ``FinalLayer``.
    """

    config_class = QWEN3VoxDiffusionHeadConfig
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        hidden_size = config.hidden_size
        latent_size = config.latent_size
        self.cond_dim = hidden_size
        self.noisy_images_proj = nn.Linear(latent_size, hidden_size, bias=False)
        self.cond_proj = nn.Linear(hidden_size, self.cond_dim, bias=False)
        self.t_embedder = TimestepEmbedder(self.cond_dim)
        ffn_dim = int(hidden_size * config.head_ffn_ratio)
        head_layers = []
        for _ in range(config.head_layers):
            head_layers.append(
                HeadLayer(
                    embed_dim=hidden_size,
                    ffn_dim=ffn_dim,
                    cond_dim=self.cond_dim,
                    norm_eps=config.rms_norm_eps,
                )
            )
        self.layers = nn.ModuleList(head_layers)
        self.final_layer = FinalLayer(
            hidden_size=hidden_size,
            output_size=latent_size,
            cond_size=self.cond_dim,
            norm_eps=config.rms_norm_eps,
        )
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the timestep MLP and zero the modulation/output weights.

        Zeroing the adaLN projections and the final linear layer makes every
        residual update and the head output start at zero.
        """
        timestep_mlp = self.t_embedder.mlp
        nn.init.normal_(timestep_mlp[0].weight, std=0.02)
        nn.init.normal_(timestep_mlp[2].weight, std=0.02)
        for layer in self.layers:
            nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)

    def forward(self, noisy_images, timesteps, condition):
        """Run the head.

        Args:
            noisy_images: Noisy latents whose last dimension is ``latent_size``.
            timesteps: Timesteps passed to the timestep embedder.
            condition: Conditioning features whose last dimension is ``hidden_size``.

        Returns:
            A tensor whose last dimension is ``latent_size``.
        """
        hidden = self.noisy_images_proj(noisy_images)
        time_emb = self.t_embedder(timesteps)
        c = self.cond_proj(condition) + time_emb
        for layer in self.layers:
            hidden = layer(hidden, c)
        return self.final_layer(hidden, c)
AutoModel.register(QWEN3VoxDiffusionHeadConfig, QWEN3VoxDiffusionHead)
__all__ = [
'QWEN3VoxDiffusionHead'
]
import math
import typing as tp
from functools import partial
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.modeling_utils import PreTrainedModel
from transformers.activations import ACT2FN
logger = logging.get_logger(__name__)
import os
# Optional APEX fused RMSNorm kernel. APEX_AVAILABLE ends up True only when
# the import succeeds AND the environment variable OPTIMIZE_FOR_SPEED is set
# to a non-zero integer; with the default "0" the fused kernel is disabled.
try:
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
    APEX_AVAILABLE = True
    logger.info("APEX FusedRMSNorm is available and will be used for optimization")
    # NOTE(review): a non-integer OPTIMIZE_FOR_SPEED value makes int() raise
    # ValueError, which the `except ImportError` below does not catch.
    if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
        APEX_AVAILABLE = False
        logger.warning(
            "APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0"
        )
except ImportError:
    APEX_AVAILABLE = False
    logger.warning("APEX FusedRMSNorm not available, using native implementation")
class ConvLayerNorm(nn.LayerNorm):
    """LayerNorm applied across dimension 1 of a 3-D tensor.

    Dimensions 1 and 2 are swapped, a float32 layer norm is applied over the
    last dimension, and the result is cast back and swapped again.
    """

    def __init__(
        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
    ):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        """Normalize ``x`` and return a tensor with the same layout and dtype."""
        swapped = x.transpose(1, 2)
        normed = nn.functional.layer_norm(
            swapped.float(),
            self.normalized_shape,
            self.weight.float(),
            self.bias.float(),
            self.eps,
        )
        return normed.type_as(swapped).transpose(1, 2)
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Args:
        dim: Size of the normalized (last) dimension.
        eps: Constant added to the mean square before the reciprocal sqrt.
        elementwise_affine: Whether to learn a scale parameter.
        weight_shape: Shape of the scale parameter; defaults to ``(dim,)``.
    """

    def __init__(
        self, dim: int, eps: float = 1e-05, elementwise_affine=True, weight_shape=None
    ):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            self.register_parameter("weight", None)
        else:
            if weight_shape is None:
                weight_shape = (dim,)
            self.weight = nn.Parameter(torch.ones(weight_shape))

    def _norm(self, x):
        """Divide ``x`` by its RMS along the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize in float32, cast back to the input dtype, then scale."""
        normed = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight

    def extra_repr(self) -> str:
        return f"dim={self .dim }, eps={self .eps }, elementwise_affine={self .elementwise_affine }"
class ConvRMSNorm(RMSNorm):
    """RMSNorm applied across dimension 1 of a 3-D tensor.

    Dimensions 1 and 2 are swapped before normalization and swapped back
    afterwards. The APEX fused kernel is used only when ``APEX_AVAILABLE`` is
    set and the module has a learnable weight.
    """

    def __init__(
        self, dim: int, eps: float = 1e-05, elementwise_affine=True, weight_shape=None
    ):
        super().__init__(dim, eps, elementwise_affine, weight_shape)

    def forward(self, x):
        """Normalize ``x`` and return a tensor with the same layout."""
        x = x.transpose(1, 2)
        use_fused = APEX_AVAILABLE and self.elementwise_affine
        if use_fused:
            output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps)
        else:
            output = self._norm(x.float()).type_as(x)
            if self.weight is not None:
                output = output * self.weight
        return output.transpose(1, 2)
# Normalization names accepted by apply_parametrization_norm and
# get_norm_module (both assert membership in this set).
# NOTE(review): "time_layer_norm" is accepted here, but neither helper has a
# dedicated branch for it, so it currently behaves like "none".
CONV_NORMALIZATIONS = frozenset(
    [
        "none",
        "weight_norm",
        "spectral_norm",
        "time_layer_norm",
        "layer_norm",
        "time_group_norm",
    ]
)
def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
    """Wrap ``module`` with weight or spectral normalization when requested.

    Args:
        module: The module to wrap.
        norm: One of ``CONV_NORMALIZATIONS``. Only ``"weight_norm"`` and
            ``"spectral_norm"`` change the module; any other accepted value
            returns it untouched.

    Returns:
        The (possibly wrapped) module.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == "weight_norm":
        return nn.utils.weight_norm(module)
    if norm == "spectral_norm":
        return nn.utils.spectral_norm(module)
    return module
def get_norm_module(
    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
) -> nn.Module:
    """Return the normalization layer to apply after a convolution.

    Args:
        module: The convolution whose ``out_channels`` sizes the norm layer.
        causal: Whether the convolution is causal.
        norm: One of ``CONV_NORMALIZATIONS``.
        **norm_kwargs: Extra keyword arguments for the norm layer.

    Returns:
        ``ConvLayerNorm`` for ``"layer_norm"``, a single-group ``nn.GroupNorm``
        for ``"time_group_norm"``, and ``nn.Identity`` otherwise.

    Raises:
        ValueError: If ``"time_group_norm"`` is requested for a causal conv.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm not in ("layer_norm", "time_group_norm"):
        return nn.Identity()
    if norm == "time_group_norm" and causal:
        raise ValueError("GroupNorm doesn't support causal evaluation.")
    assert isinstance(module, nn.modules.conv._ConvNd)
    if norm == "layer_norm":
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Return the extra right padding needed so the last frame is complete.

    Args:
        x: Input whose last dimension is the length.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding_total: Padding that will already be applied to the input.

    Returns:
        The number of additional samples required so that the input length
        covers a whole number of convolution frames.
    """
    current_length = x.shape[-1]
    frames = (current_length - kernel_size + padding_total) / stride + 1
    whole_frames = math.ceil(frames)
    target_length = (whole_frames - 1) * stride + (kernel_size - padding_total)
    return target_length - current_length
def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zero",
    value: float = 0.0,
):
    """Pad the last dimension of ``x``, tolerating short inputs in reflect mode.

    Args:
        x: Tensor to pad along its last dimension.
        paddings: ``(left, right)`` padding sizes, both non-negative.
        mode: Padding mode understood by ``F.pad``. ``"zero"`` and ``"zeros"``
            are accepted as aliases of ``"constant"``, because ``F.pad`` itself
            rejects those names and the default call would otherwise fail.
        value: Fill value for constant padding.

    Returns:
        The padded tensor.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode in ("zero", "zeros"):
        mode = "constant"
    if mode == "reflect":
        # Reflect padding needs the input to be longer than the padding, so a
        # short input is first zero-extended on the right and the extension is
        # trimmed off again afterwards.
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    return F.pad(x, paddings, mode, value)
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove ``(left, right)`` samples from the ends of the last dimension.

    Both padding sizes must be non-negative and together must not exceed the
    length of ``x``.
    """
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    total_length = x.shape[-1]
    assert padding_left + padding_right <= total_length
    return x[..., padding_left : total_length - padding_right]
class NormConv1d(nn.Module):
    """``nn.Conv1d`` combined with an optional normalization.

    Args:
        *args: Positional arguments forwarded to ``nn.Conv1d``.
        causal: Whether the convolution is used causally (forwarded to
            ``get_norm_module``).
        norm: Normalization name, one of ``CONV_NORMALIZATIONS``.
        norm_kwargs: Keyword arguments for the norm layer; ``None`` means none.
        **kwargs: Keyword arguments forwarded to ``nn.Conv1d``.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
        **kwargs,
    ):
        super().__init__()
        # Use None as the default instead of a shared mutable dict.
        if norm_kwargs is None:
            norm_kwargs = {}
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Apply the convolution followed by the normalization layer."""
        x = self.conv(x)
        x = self.norm(x)
        return x
class NormConvTranspose1d(nn.Module):
    """``nn.ConvTranspose1d`` combined with an optional normalization.

    Args:
        *args: Positional arguments forwarded to ``nn.ConvTranspose1d``.
        causal: Whether the convolution is used causally (forwarded to
            ``get_norm_module``).
        norm: Normalization name, one of ``CONV_NORMALIZATIONS``.
        norm_kwargs: Keyword arguments for the norm layer; ``None`` means none.
        **kwargs: Keyword arguments forwarded to ``nn.ConvTranspose1d``.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
        **kwargs,
    ):
        super().__init__()
        # Use None as the default instead of a shared mutable dict.
        if norm_kwargs is None:
            norm_kwargs = {}
        self.convtr = apply_parametrization_norm(
            nn.ConvTranspose1d(*args, **kwargs), norm
        )
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Apply the transposed convolution followed by the normalization layer."""
        x = self.convtr(x)
        x = self.norm(x)
        return x
class QWEN3VoxTokenizerStreamingCache:
    """Per-layer, per-sample store of streaming convolution states.

    States are kept in ``self.cache`` under ``(layer_id, sample_index)`` keys.
    """

    def __init__(self):
        self.cache = {}

    def get(
        self, layer_id: str, sample_indices: torch.Tensor
    ) -> Optional[torch.Tensor]:
        """Return the stacked states of ``layer_id`` for the given samples.

        Returns ``None`` as soon as any requested sample has no cached state.
        States with two or more dimensions are left-padded with zeros along
        the last dimension to a common length before stacking.
        """
        states = []
        for idx in sample_indices.tolist():
            key = (layer_id, idx)
            if key not in self.cache:
                return None
            states.append(self.cache[key])
        if len(states) > 0 and states[0].dim() >= 2:
            max_length = max(state.shape[-1] for state in states)
            states = [
                F.pad(state, (max_length - state.shape[-1], 0), mode="constant", value=0)
                if state.shape[-1] < max_length
                else state
                for state in states
            ]
        return torch.stack(states, dim=0)

    def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
        """Store ``states[i]`` (detached) for the i-th entry of ``sample_indices``."""
        for i, idx in enumerate(sample_indices.tolist()):
            self.cache[(layer_id, idx)] = states[i].detach()

    def set_to_zero(self, sample_indices: torch.Tensor):
        """Replace every cached state of the given samples with zeros."""
        # Convert once instead of calling tolist() for every cache key.
        targets = set(sample_indices.tolist())
        for key in list(self.cache.keys()):
            if key[1] in targets:
                self.cache[key] = torch.zeros_like(self.cache[key])

    def clear(
        self,
        layer_id: Optional[str] = None,
        sample_indices: Optional[torch.Tensor] = None,
    ):
        """Remove cached states.

        * no arguments: clear everything;
        * ``layer_id`` only: clear every sample of that layer;
        * ``sample_indices`` only: clear those samples in every layer;
        * both: clear those samples of that layer.
        """
        if layer_id is None and sample_indices is None:
            self.cache.clear()
            return
        if sample_indices is None:
            for key in [k for k in self.cache if k[0] == layer_id]:
                del self.cache[key]
            return
        if layer_id is None:
            # Previously this combination silently removed nothing.
            targets = set(sample_indices.tolist())
            for key in [k for k in self.cache if k[1] in targets]:
                del self.cache[key]
            return
        for idx in sample_indices.tolist():
            self.cache.pop((layer_id, idx), None)
# 1-D convolution wrapper that pads its input itself and can run either on a
# whole sequence (non-streaming) or chunk by chunk with a
# QWEN3VoxTokenizerStreamingCache holding the trailing input context.
class SConv1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
    ):
        super().__init__()
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        # context_size (samples kept in the streaming cache) and padding_total
        # (padding applied in non-streaming mode) use the same formula.
        self.context_size = (kernel_size - 1) * dilation - (stride - 1)
        self.padding_total = (kernel_size - 1) * dilation - (stride - 1)
        self._layer_id = None

    # Cache key for this layer, created lazily from id(self) on first access.
    @property
    def layer_id(self):
        if self._layer_id is None:
            self._layer_id = f"sconv1d_{id (self )}"
        return self._layer_id

    # Dispatch: without use_cache or without a cache object the whole-sequence
    # path is used; otherwise the streaming path, which requires a causal
    # convolution and one sample index per batch element.
    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[QWEN3VoxTokenizerStreamingCache] = None,
        sample_indices: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        debug: bool = False,
    ) -> torch.Tensor:
        B, C, T = x.shape
        if not use_cache or cache is None:
            return self._forward_non_streaming(x, debug=debug)
        assert self.causal, "Streaming mode is only supported for causal convolutions"
        assert (
            sample_indices is not None
        ), "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == B, "sample_indices must match batch size"
        return self._forward_streaming(x, cache, sample_indices, debug)

    # Streaming step: prepend the cached context to the new chunk, convolve
    # without extra padding, then store the trailing context_size input
    # samples as the context for the next chunk.
    def _forward_streaming(
        self,
        x: torch.Tensor,
        cache: QWEN3VoxTokenizerStreamingCache,
        sample_indices: torch.Tensor,
        debug: bool = False,
    ) -> torch.Tensor:
        B, C, T = x.shape
        cached_states = cache.get(self.layer_id, sample_indices)
        if cached_states is None:
            # First chunk: start from an all-zero context (or an empty one when
            # no context is needed).
            if self.context_size > 0:
                cached_states = torch.zeros(
                    B, C, self.context_size, device=x.device, dtype=x.dtype
                )
                if debug:
                    print(
                        f"[DEBUG] Initialized cache with shape: {cached_states .shape }, context_size={self .context_size }"
                    )
            else:
                cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
                if debug:
                    print(f"[DEBUG] No context needed (kernel_size=stride)")
        if cached_states.shape[2] > 0:
            input_with_context = torch.cat([cached_states, x], dim=2)
        else:
            input_with_context = x
        if debug:
            print(
                f"[DEBUG] Input shape: {x .shape }, Cache shape: {cached_states .shape }, Combined: {input_with_context .shape }"
            )
        output = self.conv(input_with_context)
        if debug:
            print(f"[DEBUG] Output shape: {output .shape }")
        # The cache is only written when the layer actually needs context.
        if self.context_size > 0:
            total_input_length = input_with_context.shape[2]
            if total_input_length >= self.context_size:
                new_cache_start = total_input_length - self.context_size
                new_cache = input_with_context[:, :, new_cache_start:]
            else:
                new_cache = input_with_context
            if debug:
                print(f"[DEBUG] New cache shape: {new_cache .shape }")
            cache.set(self.layer_id, sample_indices, new_cache)
        return output

    # Whole-sequence path: pad so the last frame is complete, then convolve.
    # Causal convolutions put all of padding_total on the left; non-causal
    # ones split it between both sides.
    def _forward_non_streaming(
        self, x: torch.Tensor, debug: bool = False
    ) -> torch.Tensor:
        B, C, T = x.shape
        kernel_size = self.kernel_size
        stride = self.stride
        dilation = self.dilation
        padding_total = self.padding_total
        extra_padding = get_extra_padding_for_conv1d(
            x, kernel_size, stride, padding_total
        )
        if debug:
            print(
                f"[DEBUG NON-STREAMING] Input shape: {x .shape }, padding_total={padding_total }, extra_padding={extra_padding }"
            )
        if self.causal:
            if self.pad_mode == "constant":
                x = pad1d(
                    x, (padding_total, extra_padding), mode=self.pad_mode, value=0
                )
            else:
                x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(
                x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
            )
        if debug:
            print(f"[DEBUG NON-STREAMING] After padding: {x .shape }")
        output = self.conv(x)
        if debug:
            print(f"[DEBUG NON-STREAMING] Output shape: {output .shape }")
        return output
# 1-D transposed-convolution wrapper that trims the implicit padding of the
# output and can run either on a whole sequence or chunk by chunk with a
# QWEN3VoxTokenizerStreamingCache holding recent input frames.
class SConvTranspose1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        norm: str = "none",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        bias: bool = True,
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
            bias=bias,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        # padding_total: output samples trimmed after the transposed conv.
        # context_size: number of input frames kept in the streaming cache.
        self.padding_total = kernel_size - stride
        self.context_size = kernel_size - 1
        self._layer_id = None

    # Cache key for this layer, created lazily from id(self) on first access.
    @property
    def layer_id(self):
        if self._layer_id is None:
            self._layer_id = f"sconvtr1d_{id (self )}"
        return self._layer_id

    # Dispatch: without use_cache or without a cache object the whole-sequence
    # path is used; otherwise the streaming path, which needs one sample index
    # per batch element. Note that no causality assertion is made here.
    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[QWEN3VoxTokenizerStreamingCache] = None,
        sample_indices: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        debug: bool = False,
    ) -> torch.Tensor:
        B, C, T = x.shape
        if not use_cache or cache is None:
            return self._forward_non_streaming(x, debug=debug)
        assert (
            sample_indices is not None
        ), "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == B, "sample_indices must match batch size"
        return self._forward_streaming(x, cache, sample_indices, debug)

    # Streaming step: re-run the transposed conv on cached frames + new chunk,
    # trim the padding, keep only the output samples that belong to the new
    # chunk, and store the most recent input frames for the next call.
    def _forward_streaming(
        self,
        x: torch.Tensor,
        cache: QWEN3VoxTokenizerStreamingCache,
        sample_indices: torch.Tensor,
        debug: bool = False,
    ) -> torch.Tensor:
        B, C, T = x.shape
        cached_input = cache.get(self.layer_id, sample_indices)
        if cached_input is None:
            # First chunk: start with an empty (zero-length) history.
            cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
            if debug:
                print(f"[DEBUG] Initialized empty cache for transposed conv")
        full_input = torch.cat([cached_input, x], dim=2)
        if debug:
            print(
                f"[DEBUG] Input shape: {x .shape }, Cache shape: {cached_input .shape }, Combined: {full_input .shape }"
            )
        full_output = self.convtr(full_input)
        if debug:
            print(f"[DEBUG] Full transposed conv output shape: {full_output .shape }")
        # Same trimming rule as the non-streaming path.
        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
            padding_left = self.padding_total - padding_right
        else:
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right
        if padding_left + padding_right > 0:
            full_output = unpad1d(full_output, (padding_left, padding_right))
        if debug:
            print(f"[DEBUG] After unpadding: {full_output .shape }")
        if cached_input.shape[2] == 0:
            output = full_output
        else:
            # Only the last T * stride output samples correspond to the new chunk.
            expected_new_output = T * self.stride
            if full_output.shape[2] >= expected_new_output:
                output = full_output[:, :, -expected_new_output:]
            else:
                output = full_output
        if debug:
            print(f"[DEBUG] Final streaming output shape: {output .shape }")
        # NOTE(review): when context_size == 0 (kernel_size == 1) the slice
        # below is `-0:`, which selects the whole tensor, so the cache would
        # keep growing with every chunk — confirm whether that case can occur.
        if full_input.shape[2] > self.context_size:
            new_cache = full_input[:, :, -self.context_size :]
        else:
            new_cache = full_input
        if debug:
            print(f"[DEBUG] New cache shape: {new_cache .shape }")
        cache.set(self.layer_id, sample_indices, new_cache)
        return output

    # Whole-sequence path: transposed conv followed by trimming padding_total
    # samples. Causal layers trim ceil(padding_total * trim_right_ratio) on the
    # right and the rest on the left; non-causal layers split the trim between
    # both sides.
    def _forward_non_streaming(
        self, x: torch.Tensor, debug: bool = False
    ) -> torch.Tensor:
        if debug:
            print(f"[DEBUG NON-STREAMING] Input shape: {x .shape }")
        y = self.convtr(x)
        if debug:
            print(f"[DEBUG NON-STREAMING] After transposed conv: {y .shape }")
        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
            padding_left = self.padding_total - padding_right
        else:
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right
        if padding_left + padding_right > 0:
            y = unpad1d(y, (padding_left, padding_right))
        if debug:
            print(f"[DEBUG NON-STREAMING] Final output shape: {y .shape }")
        return y
class FFN(nn.Module):
    """Two-layer feed-forward network with a GELU in between.

    Args:
        embed_dim: Input and output feature size.
        ffn_dim: Hidden feature size.
        bias: Whether both linear layers use a bias term.
    """

    def __init__(self, embed_dim, ffn_dim, bias=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias)
        self.gelu = ACT2FN["gelu"]
        self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias)

    def forward(self, x):
        """Return ``linear2(gelu(linear1(x)))``."""
        hidden = self.gelu(self.linear1(x))
        return self.linear2(hidden)
class Convlayer(nn.Module):
    """Thin module wrapper around a single ``SConv1d``.

    All constructor arguments are forwarded to ``SConv1d``; the forward pass
    calls it without any streaming cache.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_mode="zeros",
        norm="weight_norm",
        causal=True,
    ):
        super().__init__()
        conv_options = {
            "stride": stride,
            "dilation": dilation,
            "groups": groups,
            "bias": bias,
            "pad_mode": pad_mode,
            "norm": norm,
            "causal": causal,
        }
        self.conv = SConv1d(in_channels, out_channels, kernel_size, **conv_options)

    def forward(self, x):
        """Apply the wrapped convolution to ``x``."""
        conv_out = self.conv(x)
        return conv_out
class _DropPath(nn.Module):
    """Stochastic depth: randomly drop whole samples of a residual branch."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        # Identity at inference time or when nothing is dropped.
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.0:
            # Rescale so the expected value of the branch is unchanged.
            mask.div_(keep_prob)
        return x * mask


class Block1D(nn.Module):
    """Residual 1-D block: normalized conv mixer followed by a normalized FFN.

    Args:
        dim: Number of channels.
        kernel_size: Kernel size of the mixer convolution.
        drop_path: Stochastic-depth drop probability for both residual
            branches; ``<= 0`` disables it.
        mixer_layer: ``"conv"`` or ``"depthwise_conv"``.
        layer_scale_init_value: Initial value of the learnable per-channel
            residual scales; ``<= 0`` disables them.
        **kwargs: Optional settings read with defaults: ``layernorm``
            (``"LN"`` or ``"RMSNorm"``), ``eps``, ``groups``, ``pad_mode``,
            ``norm``, ``causal``, ``bias`` and ``ffn_expansion``.

    Raises:
        ValueError: For an unsupported ``layernorm`` or ``mixer_layer``.
    """

    def __init__(
        self,
        dim,
        kernel_size=7,
        drop_path=0.0,
        mixer_layer="conv",
        layer_scale_init_value=1e-06,
        **kwargs,
    ):
        super().__init__()
        layernorm = kwargs.get("layernorm", "LN")
        norm_eps = kwargs.get("eps", 1e-06)
        if layernorm == "LN":
            self.norm = ConvLayerNorm(dim, eps=norm_eps)
            self.ffn_norm = ConvLayerNorm(dim, eps=norm_eps)
        elif layernorm == "RMSNorm":
            self.norm = ConvRMSNorm(dim, eps=norm_eps)
            self.ffn_norm = ConvRMSNorm(dim, eps=norm_eps)
        else:
            # Fail at construction instead of leaving self.norm undefined.
            raise ValueError(f"Unsupported layernorm: {layernorm}")
        # The two mixer variants differ only in their group count.
        if mixer_layer == "conv":
            groups = kwargs.get("groups", 1)
        elif mixer_layer == "depthwise_conv":
            groups = dim
        else:
            raise ValueError(f"Unsupported mixer layer: {mixer_layer }")
        self.mixer = Convlayer(
            dim,
            dim,
            groups=groups,
            kernel_size=kernel_size,
            pad_mode=kwargs.get("pad_mode", "reflect"),
            norm=kwargs.get("norm", "none"),
            causal=kwargs.get("causal", True),
            bias=kwargs.get("bias", True),
        )
        self.ffn = FFN(
            dim, kwargs.get("ffn_expansion", 4) * dim, bias=kwargs.get("bias", False)
        )
        # torch has no nn.modules.DropPath, so a local stochastic-depth module
        # is used when drop_path is enabled.
        self.drop_path = nn.Identity() if drop_path <= 0.0 else _DropPath(drop_path)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
            self.ffn_gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None
            self.ffn_gamma = None

    def forward(self, x):
        """Apply the mixer and FFN residual branches to ``x``."""
        residual = x
        x = self.norm(x)
        x = self.mixer(x)
        if self.gamma is not None:
            x = x * self.gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)
        residual = x
        x = self.ffn_norm(x)
        # The FFN acts on the last dimension, so swap dims 1 and 2 around it.
        x = x.permute(0, 2, 1)
        x = self.ffn(x)
        x = x.permute(0, 2, 1)
        if self.ffn_gamma is not None:
            x = x * self.ffn_gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)
        return x
class TokenizerEncoder(nn.Module):
    """Convolutional encoder built from a stem, strided convs and Block1D stages.

    ``downsample_layers[0]`` is the stem; each later entry is a strided
    ``SConv1d`` that doubles the channel count. ``stages[i]`` holds
    ``depths[i]`` ``Block1D`` layers applied after ``downsample_layers[i]``.
    A final norm (optional) and a head convolution map to ``config.dimension``.
    """

    def __init__(self, config):
        super().__init__()
        self.channels = config.channels
        self.dimension = config.dimension
        self.n_filters = config.n_filters
        # The encoder walks the configured ratios in reverse order.
        self.ratios = list(reversed(config.ratios))
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        def option(name, default):
            return getattr(config, name, default)

        kernel_size = option("kernel_size", 7)
        last_kernel_size = option("last_kernel_size", 7)
        norm = option("norm", "none")
        norm_params = option("norm_params", {})
        pad_mode = option("pad_mode", "reflect")
        bias = option("bias", True)
        layernorm = option("layernorm", "LN")
        layernorm_eps = option("layernorm_eps", 1e-06)
        layernorm_elementwise_affine = option("layernorm_elementwise_affine", True)
        drop_path_rate = option("drop_path_rate", 0.0)
        mixer_layer = option("mixer_layer", "conv")
        layer_scale_init_value = option("layer_scale_init_value", 0)
        disable_last_norm = option("disable_last_norm", False)

        if layernorm not in ("LN", "RMSNorm"):
            raise ValueError(f"Unsupported norm type: {layernorm }")
        if layernorm == "LN":
            norm_type = ConvLayerNorm
        else:
            norm_type = partial(
                ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine
            )

        # Keyword arguments shared by every convolution built below.
        conv_common = {
            "causal": self.causal,
            "pad_mode": pad_mode,
            "norm": norm,
            "bias": bias,
        }

        stem = nn.Sequential(
            SConv1d(
                self.channels,
                self.n_filters,
                kernel_size,
                norm_kwargs=norm_params,
                **conv_common,
            )
        )
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(stem)
        for level, ratio in enumerate(self.ratios):
            in_ch = self.n_filters * 2**level
            out_ch = self.n_filters * 2 ** (level + 1)
            self.downsample_layers.append(
                nn.Sequential(
                    SConv1d(
                        in_ch,
                        out_ch,
                        kernel_size=ratio * 2,
                        stride=ratio,
                        **conv_common,
                    )
                )
            )

        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            layer_scale_init_value=layer_scale_init_value,
            **conv_common,
        )
        self.stages = nn.ModuleList()
        # One drop-path rate per block, rising linearly across the network.
        dp_rates = torch.linspace(0, drop_path_rate, sum(self.depths)).tolist()
        cursor = 0
        for level, depth in enumerate(self.depths):
            in_ch = self.n_filters * 2**level
            stage_rates = dp_rates[cursor : cursor + depth]
            self.stages.append(
                nn.Sequential(
                    *[layer_type(dim=in_ch, drop_path=rate) for rate in stage_rates]
                )
            )
            cursor += depth

        # `in_ch` now holds the channel count of the last stage.
        self.norm = nn.Identity() if disable_last_norm else norm_type(
            in_ch, eps=layernorm_eps
        )
        self.head = SConv1d(
            in_ch,
            self.dimension,
            kernel_size=last_kernel_size,
            **conv_common,
        )

    @staticmethod
    def _run_block_with_cache(block, x, stream_kwargs):
        """Unrolled block forward that hands the cache arguments to the conv.

        Mirrors the two residual sub-blocks of the block, but adds the
        residual directly (``block.drop_path`` is not applied on this path).
        """
        skip = x
        x = block.mixer.conv(block.norm(x), **stream_kwargs)
        if block.gamma is not None:
            x = x * block.gamma.unsqueeze(-1)
        x = skip + x

        skip = x
        x = block.ffn(block.ffn_norm(x).permute(0, 2, 1)).permute(0, 2, 1)
        if block.ffn_gamma is not None:
            x = x * block.ffn_gamma.unsqueeze(-1)
        return skip + x

    def forward_features(
        self, x, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Run every downsample layer and stage, then the final norm.

        The cache arguments are passed to each ``SConv1d`` encountered,
        including the one inside each block's mixer.
        """
        stream_kwargs = {
            "cache": cache,
            "sample_indices": sample_indices,
            "use_cache": use_cache,
            "debug": debug,
        }
        for level in range(len(self.depths)):
            for layer in self.downsample_layers[level]:
                if isinstance(layer, SConv1d):
                    x = layer(x, **stream_kwargs)
                else:
                    x = layer(x)
            for block in self.stages[level]:
                mixer_conv = getattr(getattr(block, "mixer", None), "conv", None)
                if isinstance(mixer_conv, SConv1d):
                    x = self._run_block_with_cache(block, x, stream_kwargs)
                else:
                    x = block(x)
        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Encode ``x``: feature extraction followed by the head convolution."""
        features = self.forward_features(
            x,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        return self.head(
            features,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
class TokenizerDecoder(nn.Module):
    """Convolutional decoder built from a stem, transposed convs and Block1D stages.

    ``upsample_layers[0]`` is the stem; each later entry is an
    ``SConvTranspose1d`` that halves the channel count. ``stages[i]`` holds
    ``depths[i]`` ``Block1D`` layers applied after ``upsample_layers[i]``.
    A final norm (optional) and a head convolution map to ``config.channels``.
    """

    def __init__(self, config):
        super().__init__()
        self.dimension = config.dimension
        self.channels = config.channels
        self.n_filters = config.n_filters
        self.ratios = config.ratios
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        def option(name, default):
            return getattr(config, name, default)

        kernel_size = option("kernel_size", 7)
        last_kernel_size = option("last_kernel_size", 7)
        norm = option("norm", "none")
        norm_params = option("norm_params", {})
        pad_mode = option("pad_mode", "reflect")
        bias = option("bias", True)
        layernorm = option("layernorm", "LN")
        layernorm_eps = option("layernorm_eps", 1e-06)
        trim_right_ratio = option("trim_right_ratio", 1.0)
        layernorm_elementwise_affine = option("layernorm_elementwise_affine", True)
        drop_path_rate = option("drop_path_rate", 0.0)
        mixer_layer = option("mixer_layer", "conv")
        layer_scale_init_value = option("layer_scale_init_value", 0)
        disable_last_norm = option("disable_last_norm", False)

        if layernorm not in ("LN", "RMSNorm"):
            raise ValueError(f"Unsupported norm type: {layernorm }")
        if layernorm == "LN":
            norm_type = ConvLayerNorm
        else:
            norm_type = partial(
                ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine
            )

        # Keyword arguments shared by the stem, the blocks and the head.
        conv_common = {
            "causal": self.causal,
            "pad_mode": pad_mode,
            "norm": norm,
            "bias": bias,
        }
        # Exponent of the widest (first) stage.
        top = len(self.depths) - 1

        stem = nn.Sequential(
            SConv1d(
                self.dimension,
                self.n_filters * 2**top,
                kernel_size,
                norm_kwargs=norm_params,
                **conv_common,
            )
        )
        self.upsample_layers = nn.ModuleList()
        self.upsample_layers.append(stem)
        for level, ratio in enumerate(self.ratios):
            in_ch = self.n_filters * 2 ** (top - level)
            out_ch = self.n_filters * 2 ** (top - level - 1)
            self.upsample_layers.append(
                nn.Sequential(
                    SConvTranspose1d(
                        in_ch,
                        out_ch,
                        kernel_size=ratio * 2,
                        stride=ratio,
                        norm=norm,
                        norm_kwargs=norm_params,
                        bias=bias,
                        causal=self.causal,
                        trim_right_ratio=trim_right_ratio,
                    )
                )
            )

        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            layer_scale_init_value=layer_scale_init_value,
            **conv_common,
        )
        self.stages = nn.ModuleList()
        # One drop-path rate per block, rising linearly across the network.
        dp_rates = torch.linspace(0, drop_path_rate, sum(self.depths)).tolist()
        cursor = 0
        for level, depth in enumerate(self.depths):
            in_ch = self.n_filters * 2 ** (top - level)
            stage_rates = dp_rates[cursor : cursor + depth]
            self.stages.append(
                nn.Sequential(
                    *[layer_type(dim=in_ch, drop_path=rate) for rate in stage_rates]
                )
            )
            cursor += depth

        # `in_ch` now holds the channel count of the last stage.
        self.norm = nn.Identity() if disable_last_norm else norm_type(
            in_ch, eps=layernorm_eps
        )
        self.head = SConv1d(
            in_ch,
            self.channels,
            kernel_size=last_kernel_size,
            **conv_common,
        )

    @staticmethod
    def _run_block_with_cache(block, x, stream_kwargs):
        """Unrolled block forward that hands the cache arguments to the conv.

        Mirrors the two residual sub-blocks of the block, but adds the
        residual directly (``block.drop_path`` is not applied on this path).
        """
        skip = x
        x = block.mixer.conv(block.norm(x), **stream_kwargs)
        if block.gamma is not None:
            x = x * block.gamma.unsqueeze(-1)
        x = skip + x

        skip = x
        x = block.ffn(block.ffn_norm(x).permute(0, 2, 1)).permute(0, 2, 1)
        if block.ffn_gamma is not None:
            x = x * block.ffn_gamma.unsqueeze(-1)
        return skip + x

    def forward_features(
        self, x, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Run every upsample layer and stage, then the final norm.

        The cache arguments are passed to each ``SConv1d`` /
        ``SConvTranspose1d`` encountered, including the conv inside each
        block's mixer.
        """
        stream_kwargs = {
            "cache": cache,
            "sample_indices": sample_indices,
            "use_cache": use_cache,
            "debug": debug,
        }
        for level in range(len(self.depths)):
            for layer in self.upsample_layers[level]:
                if isinstance(layer, (SConv1d, SConvTranspose1d)):
                    x = layer(x, **stream_kwargs)
                else:
                    x = layer(x)
            for block in self.stages[level]:
                mixer_conv = getattr(getattr(block, "mixer", None), "conv", None)
                if isinstance(mixer_conv, SConv1d):
                    x = self._run_block_with_cache(block, x, stream_kwargs)
                else:
                    x = block(x)
        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Decode ``x``: feature extraction followed by the head convolution."""
        features = self.forward_features(
            x,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        return self.head(
            features,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
@dataclass
class QWEN3VoxTokenizerEncoderOutput:
    """Encoder result: a mean tensor plus an optional standard deviation."""

    mean: torch.Tensor
    std: Optional[Union[float, torch.Tensor]] = None

    def sample(self, dist_type="fix"):
        """Draw a latent from the encoder output.

        Args:
            dist_type: ``"fix"`` adds Gaussian noise scaled by ``self.std``;
                ``"gaussian"`` first draws one noise scale per batch item
                (a normal sample times ``self.std / 0.8``) and then adds noise
                with that scale; any other value returns the mean unchanged.

        Returns:
            A ``(latent, std)`` tuple, where ``std`` is the scale that was used
            (``self.std`` itself outside the ``"gaussian"`` mode).
        """
        if dist_type == "fix":
            noise = torch.randn_like(self.mean)
            return (self.mean + self.std * noise, self.std)

        if dist_type == "gaussian":
            scale = self.std / 0.8
            per_item_std = scale * torch.randn(
                self.mean.size(0), device=self.mean.device, dtype=self.mean.dtype
            )
            # Append trailing singleton dims so the scale broadcasts over mean.
            for _ in range(self.mean.dim() - per_item_std.dim()):
                per_item_std = per_item_std.unsqueeze(-1)
            noise = torch.randn_like(self.mean)
            return (self.mean + per_item_std * noise, per_item_std)

        return (self.mean, self.std)

    def kl(self):
        """Return the element-wise squared error between the mean and zero."""
        return F.mse_loss(self.mean, torch.zeros_like(self.mean), reduction="none")

    def mode(self):
        """Return the mean, i.e. the deterministic latent."""
        return self.mean
def mode(self):
return self.mean
class QWEN3VoxAcousticTokenizerModel(PreTrainedModel):
    """Acoustic tokenizer: a ``TokenizerEncoder`` / ``TokenizerDecoder`` pair.

    ``encode`` produces latents with a fixed standard deviation
    (``config.fix_std``), ``sampling`` draws noisy latents from them, and
    ``decode`` maps latents back through the decoder.
    """

    config_class = QWEN3VoxAcousticTokenizerConfig
    base_model_prefix = 'vibevoice_acoustic_tokenizer'
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]

    def __init__(self, config):
        super().__init__(config)
        self.register_buffer("fix_std", torch.tensor(config.fix_std), persistent=False)
        self.std_dist_type = getattr(config, "std_dist_type", "fix")

        def parse_depths(spec):
            # Depths may be given as a dash-separated string such as "3-3-9".
            return [int(part) for part in spec.split("-")]

        encoder_depths = config.encoder_depths
        if isinstance(encoder_depths, str):
            encoder_depths = parse_depths(encoder_depths)

        # NOTE(review): a decoder_depths value that is not a string (e.g. a
        # list) is ignored and the reversed encoder depths are used instead.
        if isinstance(config.decoder_depths, str):
            decoder_depths = parse_depths(config.decoder_depths)
        else:
            decoder_depths = list(reversed(encoder_depths))

        def derive_config(n_filters, ratios, depths):
            """Copy ``config`` and set the attribute names the sub-module reads."""
            sub = copy.deepcopy(config)
            sub.dimension = config.vae_dim
            sub.n_filters = n_filters
            sub.ratios = ratios
            sub.depths = depths
            sub.norm = config.conv_norm
            sub.pad_mode = config.pad_mode
            sub.bias = config.conv_bias
            sub.layernorm_eps = config.layernorm_eps
            sub.layernorm_elementwise_affine = config.layernorm_elementwise_affine
            sub.mixer_layer = config.mixer_layer
            sub.layer_scale_init_value = config.layer_scale_init_value
            sub.disable_last_norm = config.disable_last_norm
            return sub

        encoder_config = derive_config(
            config.encoder_n_filters, config.encoder_ratios, encoder_depths
        )
        decoder_config = derive_config(
            config.decoder_n_filters, config.decoder_ratios, decoder_depths
        )
        self.encoder = TokenizerEncoder(encoder_config)
        self.decoder = TokenizerDecoder(decoder_config)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Conv1d weights from a normal and reset LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(
        self, audio, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Encode ``audio`` and wrap the result with the fixed std.

        The encoder output is permuted with ``(0, 2, 1)`` before being stored
        as the mean. Runs under ``torch.no_grad()``.
        """
        latents = self.encoder(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        mean = latents.permute(0, 2, 1)
        return QWEN3VoxTokenizerEncoderOutput(mean=mean, std=self.fix_std)

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """Sample latents; ``dist_type`` falls back to ``self.std_dist_type``.

        Raises:
            ValueError: If the resolved mode is neither ``"fix"`` nor
                ``"gaussian"``.
        """
        dist_type = dist_type or self.std_dist_type
        if dist_type not in ("fix", "gaussian"):
            raise ValueError(
                f"Unsupported dist_type: {dist_type }, expected 'fix' or 'gaussian'"
            )
        return encoder_output.sample(dist_type=dist_type)

    @torch.no_grad()
    def decode(
        self, latents, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Decode ``latents`` through the decoder.

        If dimension 1 is not ``config.vae_dim`` the tensor is permuted with
        ``(0, 2, 1)`` first.
        """
        if latents.shape[1] != self.config.vae_dim:
            latents = latents.permute(0, 2, 1)
        return self.decoder(
            latents,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )

    def forward(
        self, audio, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Encode, sample and decode; return ``(reconstructed, sampled_latents)``."""
        stream_kwargs = {
            "cache": cache,
            "sample_indices": sample_indices,
            "use_cache": use_cache,
            "debug": debug,
        }
        encoder_output = self.encode(audio, **stream_kwargs)
        sampled_latents, _ = self.sampling(encoder_output)
        reconstructed = self.decode(sampled_latents, **stream_kwargs)
        return (reconstructed, sampled_latents)
class QWEN3VoxSemanticTokenizerModel(PreTrainedModel):
    """Semantic tokenizer: encoder only, with deterministic latents."""

    config_class = QWEN3VoxSemanticTokenizerConfig
    base_model_prefix = 'vibevoice_semantic_tokenizer'
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder"]

    def __init__(self, config):
        super().__init__(config)
        encoder_depths = config.encoder_depths
        if isinstance(encoder_depths, str):
            # Depths may be given as a dash-separated string such as "3-3-9".
            encoder_depths = [int(part) for part in encoder_depths.split("-")]

        # Copy the config and set the attribute names TokenizerEncoder reads.
        encoder_config = copy.deepcopy(config)
        overrides = {
            "dimension": config.vae_dim,
            "n_filters": config.encoder_n_filters,
            "ratios": config.encoder_ratios,
            "depths": encoder_depths,
            "norm": config.conv_norm,
            "pad_mode": config.pad_mode,
            "bias": config.conv_bias,
            "layernorm_eps": config.layernorm_eps,
            "layernorm_elementwise_affine": config.layernorm_elementwise_affine,
            "mixer_layer": config.mixer_layer,
            "layer_scale_init_value": config.layer_scale_init_value,
            "disable_last_norm": config.disable_last_norm,
        }
        for attr, value in overrides.items():
            setattr(encoder_config, attr, value)

        self.encoder = TokenizerEncoder(encoder_config)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Conv1d weights from a normal and reset LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(
        self, audio, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Encode ``audio``; the result carries a mean and no std.

        The encoder output is permuted with ``(0, 2, 1)`` before being stored
        as the mean. Runs under ``torch.no_grad()``.
        """
        latents = self.encoder(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        mean = latents.permute(0, 2, 1)
        return QWEN3VoxTokenizerEncoderOutput(mean=mean)

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """Return the mean unchanged; ``dist_type`` is accepted but ignored."""
        return encoder_output.sample(dist_type="none")

    def forward(
        self, audio, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Encode ``audio`` and return ``(None, latents)`` (there is no decoder)."""
        encoder_output = self.encode(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        sampled_latents, _ = self.sampling(encoder_output, dist_type="none")
        return (None, sampled_latents)
# Register each tokenizer model with AutoModel under its config class.
AutoModel.register(QWEN3VoxAcousticTokenizerConfig, QWEN3VoxAcousticTokenizerModel)
AutoModel.register(QWEN3VoxSemanticTokenizerConfig, QWEN3VoxSemanticTokenizerModel)
# Public names of the tokenizer section.
# NOTE(review): `__all__` is assigned again further down this file, which
# replaces this list rather than extending it.
__all__ = [
    'QWEN3VoxTokenizerStreamingCache',
    'QWEN3VoxAcousticTokenizerModel',
    'QWEN3VoxSemanticTokenizerModel',
]
'\nProcessor class for QWEN3Vox ASR models.\n'
import os
import json
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import numpy as np
import torch
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import TensorType, logging
# Logger for the ASR processor section.
logger = logging.get_logger(__name__)
# System message placed at the start of every ASR prompt built by
# QWEN3VoxASRProcessor._process_single_audio.
SYSTEM_PROMPT = "You are a helpful assistant that transcribes audio input into text output in JSON format."
class QWEN3VoxASRProcessor:
    """Build ASR model inputs from audio and parse the model's JSON output.

    For each audio clip the processor builds a chat prompt containing one
    speech-pad token per compressed audio frame, then left-pads the prompts
    and zero-pads the waveforms into a batch.
    """

    # Keys the model may emit -> normalized output keys. Entries are applied in
    # order, so a later spelling overrides an earlier one if both are present.
    _KEY_MAPPING = {
        "Start time": "start_time",
        "Start": "start_time",
        "End time": "end_time",
        "End": "end_time",
        "Speaker ID": "speaker_id",
        "Speaker": "speaker_id",
        "Content": "text",
    }

    def __init__(
        self,
        tokenizer=None,
        audio_processor=None,
        speech_tok_compress_ratio=320,
        target_sample_rate=22050,
        normalize_audio=True,
        **kwargs,
    ):
        """Store the components and cache the special-token ids.

        Args:
            tokenizer: Text tokenizer; its special-token ids are looked up
                immediately, so it must not be ``None`` in practice.
            audio_processor: Defaults to a ``QWEN3VoxTokenizerProcessor`` built
                from ``target_sample_rate`` and ``normalize_audio``.
            speech_tok_compress_ratio: Audio samples per speech token.
            target_sample_rate: Sample rate audio files are resampled to.
            normalize_audio: Whether to run ``AudioNormalizer`` on the audio.
            **kwargs: Accepted and ignored.
        """
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor or QWEN3VoxTokenizerProcessor(
            sampling_rate=target_sample_rate, normalize_audio=normalize_audio
        )
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.target_sample_rate = target_sample_rate
        self.normalize_audio = normalize_audio
        self.audio_normalizer = AudioNormalizer() if normalize_audio else None
        self._cache_special_tokens()

    def _cache_special_tokens(self):
        """Resolve speech start/end/pad and text pad ids from the tokenizer.

        Each id is taken from a tokenizer attribute when present, otherwise
        from ``convert_tokens_to_ids`` on the literal token string.
        """

        def resolve(attr, token):
            if hasattr(self.tokenizer, attr):
                return getattr(self.tokenizer, attr)
            return self.tokenizer.convert_tokens_to_ids(token)

        self.speech_start_id = resolve("speech_start_id", "<|speech_start|>")
        self.speech_end_id = resolve("speech_end_id", "<|speech_end|>")
        self.speech_pad_id = resolve("speech_pad_id", "<|speech_pad|>")
        if hasattr(self.tokenizer, "pad_id"):
            self.pad_id = self.tokenizer.pad_id
        elif hasattr(self.tokenizer, "pad_token_id"):
            self.pad_id = self.tokenizer.pad_token_id
        else:
            self.pad_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the processor from a local directory or a hub repository.

        ``preprocessor_config.json`` is read from the directory when it exists
        there, otherwise fetched with ``cached_file``; if that fails a warning
        is logged and defaults are used. The tokenizer is always loaded from
        ``pretrained_model_name_or_path``.

        Args:
            pretrained_model_name_or_path: Local path or repository id.
            **kwargs: Forwarded to ``cached_file`` and to the tokenizer's
                ``from_pretrained``. ``language_model_pretrained_name`` is
                accepted for compatibility and discarded.
        """
        from transformers.utils import cached_file

        # Consume this option up front so it never leaks into cached_file or
        # the tokenizer loader; the tokenizer is loaded from model_name.
        kwargs.pop("language_model_pretrained_name", None)

        model_name = str(pretrained_model_name_or_path)
        config_path = os.path.join(model_name, "preprocessor_config.json")
        config = {}
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            try:
                config_file = cached_file(
                    model_name, "preprocessor_config.json", **kwargs
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    config = json.load(f)
            except Exception as e:
                # Best effort: a missing or unreadable config is not fatal.
                logger.warning(f"Could not load preprocessor_config.json: {e }")
                logger.warning("Using default configuration")

        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        target_sample_rate = config.get("target_sample_rate", 22050)
        normalize_audio = config.get("normalize_audio", True)

        logger.info(f"Loading tokenizer from repo {model_name }")
        tokenizer = QWEN3VoxASRTextTokenizerFast.from_pretrained(
            model_name, **kwargs
        )
        audio_processor = QWEN3VoxTokenizerProcessor(
            sampling_rate=target_sample_rate,
            normalize_audio=normalize_audio,
            target_dB_FS=config.get("target_dB_FS", -25),
            eps=config.get("eps", 1e-06),
        )
        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            target_sample_rate=target_sample_rate,
            normalize_audio=normalize_audio,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Write ``preprocessor_config.json`` into ``save_directory``.

        Only the processor configuration is written. ``target_dB_FS`` and
        ``eps`` are read from the audio processor when it exposes them, so a
        save/load round trip keeps non-default values.
        """
        os.makedirs(save_directory, exist_ok=True)
        processor_config = {
            "processor_class": "QWEN3VoxASRProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "target_sample_rate": self.target_sample_rate,
            "normalize_audio": self.normalize_audio,
            "target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25),
            "eps": getattr(self.audio_processor, "eps", 1e-06),
        }
        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(processor_config, f, indent=2)
        logger.info(f"Processor configuration saved in {config_path }")

    def __call__(
        self,
        audio: Optional[
            Union[
                str,
                np.ndarray,
                torch.Tensor,
                List[Union[str, np.ndarray, torch.Tensor]],
            ]
        ] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[Union[str, "TensorType"]] = None,
        padding: bool = True,
        max_length: Optional[int] = None,
        truncation: bool = False,
        add_generation_prompt: bool = True,
        use_streaming: bool = True,
        context_info: Optional[str] = None,
        **kwargs,
    ) -> "BatchEncoding":
        """Turn one clip or a list of clips into a padded batch of model inputs.

        Args:
            audio: A file path, array or tensor, or a list of those.
            sampling_rate: Forwarded to the per-clip step (see its note).
            return_tensors: ``"pt"`` returns torch tensors; anything else
                returns Python lists / NumPy arrays.
            padding, max_length, truncation: Control prompt padding.
            add_generation_prompt, use_streaming: Forwarded to the per-clip
                step, which currently does not act on them.
            context_info: Optional extra text included in the user prompt.

        Raises:
            ValueError: If ``audio`` is ``None`` or an empty list.
        """
        if audio is None or (isinstance(audio, list) and not audio):
            raise ValueError("Audio input is required for ASR processing")
        audio_list = audio if isinstance(audio, list) else [audio]

        all_encodings = [
            self._process_single_audio(
                audio_input,
                sampling_rate=sampling_rate,
                add_generation_prompt=add_generation_prompt,
                use_streaming=use_streaming,
                context_info=context_info,
            )
            for audio_input in audio_list
        ]
        return self._batch_encode(
            all_encodings,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_tensors=return_tensors,
        )

    def _process_single_audio(
        self,
        audio: Union[str, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        add_generation_prompt: bool = True,
        use_streaming: bool = True,
        context_info: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Load one clip and build its prompt tokens.

        File paths are read with ``soundfile``, down-mixed to mono and
        resampled to ``self.target_sample_rate`` with ``librosa``.

        NOTE(review): ``sampling_rate``, ``add_generation_prompt`` and
        ``use_streaming`` are accepted but not used here; array and tensor
        inputs are treated as already being at ``self.target_sample_rate``.

        Returns:
            A dict with ``input_ids``, ``acoustic_input_mask`` (1 where the
            token is the speech-pad token), ``speech`` (float32 waveform) and
            ``vae_tok_len`` (number of speech-pad tokens).
        """
        if isinstance(audio, str):
            import soundfile as sf

            audio_array, file_sr = sf.read(audio)
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            if file_sr != self.target_sample_rate:
                import librosa

                audio_array = librosa.resample(
                    audio_array, orig_sr=file_sr, target_sr=self.target_sample_rate
                )
        elif isinstance(audio, torch.Tensor):
            # detach() so tensors that require grad can be converted to NumPy.
            audio_array = audio.detach().cpu().numpy()
            if audio_array.ndim > 1:
                audio_array = audio_array.squeeze()
        else:
            audio_array = np.array(audio, dtype=np.float32)
            if audio_array.ndim > 1:
                audio_array = audio_array.squeeze()
        audio_array = audio_array.astype(np.float32)

        if self.normalize_audio and self.audio_normalizer:
            audio_array = self.audio_normalizer(audio_array)

        audio_duration = len(audio_array) / self.target_sample_rate
        # One speech-pad token per `speech_tok_compress_ratio` samples.
        vae_tok_len = math.ceil(len(audio_array) / self.speech_tok_compress_ratio)

        system_prompt_text = self.tokenizer.apply_chat_template(
            [{"role": "system", "content": SYSTEM_PROMPT}], tokenize=False
        )
        system_tokens = self.tokenizer.encode(system_prompt_text)

        sp_start_token = self.tokenizer.convert_ids_to_tokens(self.speech_start_id)
        sp_pad_token = self.tokenizer.convert_ids_to_tokens(self.speech_pad_id)
        sp_end_token = self.tokenizer.convert_ids_to_tokens(self.speech_end_id)

        show_keys = ["Start time", "End time", "Speaker ID", "Content"]
        if context_info and context_info.strip():
            user_suffix = (
                f"This is a {audio_duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\nPlease transcribe it with these keys: "
                + ", ".join(show_keys)
            )
        else:
            user_suffix = (
                f"This is a {audio_duration:.2f} seconds audio, please transcribe it with these keys: "
                + ", ".join(show_keys)
            )
        user_input_string = (
            "".join([sp_start_token] + [sp_pad_token] * vae_tok_len + [sp_end_token])
            + "\n"
            + user_suffix
        )
        user_tokens = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": user_input_string}], tokenize=True
        )

        full_tokens = system_tokens + user_tokens
        acoustic_input_mask = [
            1 if token == self.speech_pad_id else 0 for token in full_tokens
        ]
        return {
            "input_ids": full_tokens,
            "acoustic_input_mask": acoustic_input_mask,
            "speech": audio_array,
            "vae_tok_len": vae_tok_len,
        }

    def _batch_encode(
        self,
        encodings: List[Dict[str, Any]],
        padding: bool = True,
        max_length: Optional[int] = None,
        truncation: bool = False,
        return_tensors: Optional[str] = None,
    ) -> "BatchEncoding":
        """Pad per-clip encodings into one batch.

        Prompts are padded on the left with ``self.pad_id`` up to
        ``max_length`` (or the longest prompt); truncation only happens when
        padding is enabled. Waveforms are zero-padded on the right, and
        ``speech_masks`` is true for the first ``vae_tok_len`` positions of
        each row. Without ``return_tensors="pt"`` a batch of one is returned
        unwrapped (no leading batch dimension).
        """
        input_ids_list = [enc["input_ids"] for enc in encodings]
        acoustic_masks_list = [enc["acoustic_input_mask"] for enc in encodings]
        speech_list = [enc["speech"] for enc in encodings]
        vae_tok_lens = [enc["vae_tok_len"] for enc in encodings]

        if padding:
            if max_length is not None:
                target_length = max_length
            else:
                target_length = max(len(ids) for ids in input_ids_list)
            padded_input_ids = []
            padded_acoustic_masks = []
            attention_masks = []
            for input_ids, acoustic_mask in zip(input_ids_list, acoustic_masks_list):
                if truncation and len(input_ids) > target_length:
                    input_ids = input_ids[:target_length]
                    acoustic_mask = acoustic_mask[:target_length]
                padding_length = target_length - len(input_ids)
                padded_input_ids.append([self.pad_id] * padding_length + input_ids)
                padded_acoustic_masks.append([0] * padding_length + acoustic_mask)
                attention_masks.append([0] * padding_length + [1] * len(input_ids))
            input_ids_list = padded_input_ids
            acoustic_masks_list = padded_acoustic_masks
        else:
            attention_masks = [[1] * len(ids) for ids in input_ids_list]

        max_speech_length = max(len(s) for s in speech_list)
        padded_speeches = np.zeros(
            (len(speech_list), max_speech_length), dtype=np.float32
        )
        speech_masks = np.zeros((len(speech_list), max(vae_tok_lens)), dtype=bool)
        for i, (speech, vae_len) in enumerate(zip(speech_list, vae_tok_lens)):
            padded_speeches[i, : len(speech)] = speech
            speech_masks[i, :vae_len] = True

        batch_encoding = BatchEncoding()
        if return_tensors == "pt":
            batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
            batch_encoding["attention_mask"] = torch.tensor(
                attention_masks, dtype=torch.long
            )
            batch_encoding["acoustic_input_mask"] = torch.tensor(
                acoustic_masks_list, dtype=torch.bool
            )
            batch_encoding["speech_tensors"] = torch.tensor(
                padded_speeches, dtype=torch.float32
            )
            batch_encoding["speech_masks"] = torch.tensor(
                speech_masks, dtype=torch.bool
            )
        else:

            def unwrap(batch):
                # A batch of one is returned without the batch dimension.
                return batch if len(batch) > 1 else batch[0]

            batch_encoding["input_ids"] = unwrap(input_ids_list)
            batch_encoding["attention_mask"] = unwrap(attention_masks)
            batch_encoding["acoustic_input_mask"] = unwrap(acoustic_masks_list)
            batch_encoding["speech_tensors"] = unwrap(padded_speeches)
            batch_encoding["speech_masks"] = unwrap(speech_masks)
        return batch_encoding

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @staticmethod
    def _extract_json_payload(text: str) -> Any:
        """Parse the first JSON value embedded in ``text``.

        A ```` ```json ```` fence takes priority; an unterminated fence runs to
        the end of the text. Otherwise decoding starts at the earliest ``[`` or
        ``{`` and uses the JSON decoder itself to find the end of the value,
        so brackets inside string values are handled correctly.

        Raises:
            json.JSONDecodeError: If no candidate position decodes.
        """
        fence = "```json"
        fence_pos = text.find(fence)
        if fence_pos != -1:
            body_start = fence_pos + len(fence)
            body_end = text.find("```", body_start)
            if body_end == -1:
                body_end = len(text)
            return json.loads(text[body_start:body_end].strip())

        candidates = sorted(
            pos for pos in (text.find("["), text.find("{")) if pos != -1
        )
        if not candidates:
            return json.loads(text)

        decoder = json.JSONDecoder()
        last_error = None
        for pos in candidates:
            try:
                value, _ = decoder.raw_decode(text, pos)
            except json.JSONDecodeError as err:
                last_error = err
            else:
                return value
        raise last_error

    def post_process_transcription(self, text: str) -> List[Dict[str, Any]]:
        """Extract transcription segments from raw model output.

        Args:
            text: Generated text expected to contain a JSON object or a JSON
                list of objects, optionally inside a ```` ```json ```` fence.

        Returns:
            One dict per segment with whichever of ``start_time``,
            ``end_time``, ``speaker_id`` and ``text`` were present. Items that
            are not dicts, or have none of the known keys, are dropped. An
            empty list is returned (and a warning logged) on any failure.
        """
        try:
            result = self._extract_json_payload(text)
            if isinstance(result, dict):
                result = [result]
            cleaned_result = []
            for item in result:
                if not isinstance(item, dict):
                    continue
                cleaned_item = {
                    mapped_key: item[key]
                    for key, mapped_key in self._KEY_MAPPING.items()
                    if key in item
                }
                if cleaned_item:
                    cleaned_result.append(cleaned_item)
            return cleaned_result
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse JSON from transcription: {e }")
            logger.debug(f"Raw text: {text }")
            return []
        except Exception as e:
            logger.warning(f"Error post-processing transcription: {e }")
            return []

    @property
    def model_input_names(self):
        """Names of the entries produced by ``__call__``."""
        return [
            "input_ids",
            "attention_mask",
            "acoustic_input_mask",
            "speech_tensors",
            "speech_masks",
        ]
# Public names of the ASR processor section.
# NOTE(review): this assignment replaces any `__all__` set earlier in the file.
__all__ = [
    'QWEN3VoxASRProcessor'
]
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re
import numpy as np
import torch
from transformers.tokenization_utils_base import (
BatchEncoding,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from transformers.utils import TensorType, logging
# Logger for the QWEN3VoxProcessor section (rebinds the module-level `logger`).
logger = logging.get_logger(__name__)
class QWEN3VoxProcessor:
def __init__(
    self,
    tokenizer=None,
    audio_processor=None,
    speech_tok_compress_ratio=3200,
    db_normalize=True,
    **kwargs,
):
    """Store the processor components and the fixed system prompt.

    Args:
        tokenizer: Text tokenizer.
        audio_processor: Audio processor used for voice samples.
        speech_tok_compress_ratio: Audio samples per speech token.
        db_normalize: When true, an ``AudioNormalizer`` is created and kept
            as ``self.audio_normalizer``; otherwise that attribute is ``None``.
        **kwargs: Accepted and ignored.
    """
    self.tokenizer = tokenizer
    self.audio_processor = audio_processor
    self.speech_tok_compress_ratio = speech_tok_compress_ratio
    self.db_normalize = db_normalize
    if db_normalize:
        self.audio_normalizer = AudioNormalizer()
    else:
        self.audio_normalizer = None
    self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    """Load the processor from a local directory or a hub repository.

    ``preprocessor_config.json`` is read from the directory when it exists
    there, otherwise fetched with ``cached_file``; if that fails a warning is
    logged and a default configuration is used. The tokenizer is always loaded
    from ``pretrained_model_name_or_path``.

    Args:
        pretrained_model_name_or_path: Local path or repository id.
        **kwargs: Forwarded to ``cached_file`` and to the tokenizer's
            ``from_pretrained``. ``language_model_pretrained_name`` is
            accepted for compatibility and discarded.

    Returns:
        A processor built from the tokenizer, a
        ``QWEN3VoxTokenizerProcessor`` and the loaded settings.
    """
    import os
    import json
    from transformers.utils import cached_file

    # Consume this option up front so it never leaks into cached_file or the
    # tokenizer loader; the tokenizer is loaded from model_name.
    kwargs.pop("language_model_pretrained_name", None)

    model_name = str(pretrained_model_name_or_path)
    config_path = os.path.join(model_name, "preprocessor_config.json")
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    else:
        try:
            config_file = cached_file(
                model_name, "preprocessor_config.json", **kwargs
            )
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        except Exception as e:
            # Best effort: a missing or unreadable config is not fatal.
            logger.warning(
                f"Could not load preprocessor_config.json from {model_name }: {e }"
            )
            logger.warning("Using default configuration")
            config = {"speech_tok_compress_ratio": 3200, "db_normalize": True}

    speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
    db_normalize = config.get("db_normalize", True)

    logger.info(f"Loading tokenizer from repo {model_name }")
    tokenizer = QWEN3VoxTextTokenizerFast.from_pretrained(
        model_name, **kwargs
    )

    if "audio_processor" in config:
        audio_config = config["audio_processor"]
        audio_processor = QWEN3VoxTokenizerProcessor(
            sampling_rate=audio_config.get("sampling_rate", 22050),
            normalize_audio=audio_config.get("normalize_audio", True),
            target_dB_FS=audio_config.get("target_dB_FS", -25),
            eps=audio_config.get("eps", 1e-06),
        )
    else:
        audio_processor = QWEN3VoxTokenizerProcessor()

    return cls(
        tokenizer=tokenizer,
        audio_processor=audio_processor,
        speech_tok_compress_ratio=speech_tok_compress_ratio,
        db_normalize=db_normalize,
    )
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
    """Write ``preprocessor_config.json`` into ``save_directory``.

    The directory is created if needed. The audio-processor section records
    the attributes found on ``self.audio_processor``, falling back to fixed
    defaults for any that are missing. ``**kwargs`` is accepted and ignored.
    """
    import os
    import json

    os.makedirs(save_directory, exist_ok=True)

    # (attribute name, fallback when the audio processor lacks it)
    audio_fields = (
        ("sampling_rate", 22050),
        ("normalize_audio", True),
        ("target_dB_FS", -25),
        ("eps", 1e-06),
    )
    audio_section = {"feature_extractor_type": "QWEN3VoxTokenizerProcessor"}
    for attr, fallback in audio_fields:
        audio_section[attr] = getattr(self.audio_processor, attr, fallback)

    processor_config = {
        "processor_class": "QWEN3VoxProcessor",
        "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
        "db_normalize": self.db_normalize,
        "audio_processor": audio_section,
    }

    config_path = os.path.join(save_directory, "preprocessor_config.json")
    with open(config_path, "w") as f:
        json.dump(processor_config, f, indent=2)
    logger.info(f"Processor configuration saved in {config_path }")
def __call__(
self,
text: Optional[
Union[
str,
List[str],
TextInput,
PreTokenizedInput,
List[TextInput],
List[PreTokenizedInput],
]
] = None,
voice_samples: Optional[
Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]
] = None,
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: bool = True,
**kwargs,
) -> BatchEncoding:
if isinstance(text, str) or (
isinstance(text, list) and len(text) > 0 and (not isinstance(text[0], str))
):
texts = [text]
is_batched = False
else:
texts = text
is_batched = True
if voice_samples is not None:
if not is_batched or isinstance(voice_samples[0], (str, np.ndarray)):
voice_samples_list = [voice_samples]
else:
voice_samples_list = voice_samples
else:
voice_samples_list = [None] * len(texts)
all_encodings = []
for text_input, voice_input in zip(texts, voice_samples_list):
encoding = self._process_single(text_input, voice_input)
all_encodings.append(encoding)
batch_encoding = self._batch_encode(
all_encodings,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=return_tensors,
return_attention_mask=return_attention_mask,
)
return batch_encoding
def _process_single(
self,
text: Union[str, TextInput],
voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
) -> Dict[str, Any]:
script = None
if isinstance(text, str):
if text.endswith(".json") and os.path.exists(text):
script = self._convert_json_to_script(text)
elif text.endswith(".txt") and os.path.exists(text):
script = self._convert_text_to_script(text)
else:
script = text
if script is None:
raise ValueError(f"Could not process input text: {text }")
parsed_lines = self._parse_script(script)
all_speakers = list(set((speaker_id for speaker_id, _ in parsed_lines)))
system_tokens = self.tokenizer.encode(self.system_prompt)
if voice_samples:
voice_tokens, voice_speech_inputs, voice_speech_masks = (
self._create_voice_prompt(voice_samples[: len(all_speakers)])
)
else:
voice_tokens, voice_speech_inputs, voice_speech_masks = ([], [], [])
full_tokens = system_tokens + voice_tokens
speech_input_mask = [False] * len(system_tokens) + voice_speech_masks
full_tokens += self.tokenizer.encode(" Text input:\n", add_special_tokens=False)
speech_input_mask += [False] * len(
self.tokenizer.encode(" Text input:\n", add_special_tokens=False)
)
for speaker_id, speaker_text in parsed_lines:
speaker_text_tokens = self.tokenizer.encode(
f" Speaker {speaker_id }:{speaker_text }\n", add_special_tokens=False
)
full_tokens += speaker_text_tokens
speech_input_mask += [False] * len(speaker_text_tokens)
full_tokens += self.tokenizer.encode(
" Speech output:\n", add_special_tokens=False
) + [self.tokenizer.speech_start_id]
speech_input_mask += [False] * (
len(self.tokenizer.encode(" Speech output:\n", add_special_tokens=False))
+ 1
)
return {
"input_ids": full_tokens,
"speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
"speech_input_mask": speech_input_mask,
"parsed_script": parsed_lines,
"all_speakers": all_speakers,
}
def _batch_encode(
    self,
    encodings: List[Dict[str, Any]],
    padding: Union[bool, str, PaddingStrategy] = True,
    truncation: Union[bool, str, TruncationStrategy] = False,
    max_length: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_attention_mask: bool = True,
) -> BatchEncoding:
    """Collate per-sample encodings into one ``BatchEncoding``.

    Sequences are padded on the left with ``self.tokenizer.pad_id`` unless
    padding is disabled. Truncation (keeping the leading tokens) is applied
    only while padding, and only to rows longer than the padding target.
    Any non-``None`` ``return_tensors`` yields torch tensors for the
    id/mask fields; speech audio is collated by ``prepare_speech_inputs``.
    """
    token_rows = [item["input_ids"] for item in encodings]
    mask_rows = [item["speech_input_mask"] for item in encodings]

    # Normalise `padding` (bool / str / enum) into a PaddingStrategy.
    if isinstance(padding, bool):
        strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
    elif isinstance(padding, str):
        strategy = PaddingStrategy(padding)
    else:
        strategy = padding

    if strategy == PaddingStrategy.DO_NOT_PAD:
        attention_rows = (
            [[1] * len(row) for row in token_rows] if return_attention_mask else None
        )
    else:
        # A fixed target applies only for MAX_LENGTH with an explicit length;
        # every other case pads to the longest row in the batch.
        fixed_target = (
            strategy == PaddingStrategy.MAX_LENGTH and max_length is not None
        )
        target_len = (
            max_length if fixed_target else max((len(row) for row in token_rows))
        )
        padded_tokens, attention_rows, padded_masks = [], [], []
        for row, mask in zip(token_rows, mask_rows):
            if truncation and len(row) > target_len:
                row, mask = row[:target_len], mask[:target_len]
            pad = target_len - len(row)
            padded_tokens.append([self.tokenizer.pad_id] * pad + row)
            attention_rows.append([0] * pad + [1] * len(row))
            padded_masks.append([False] * pad + mask)
        token_rows, mask_rows = padded_tokens, padded_masks

    # Flatten the voice audio of all samples into one list.
    speech_chunks = []
    found_speech = False
    for item in encodings:
        chunk = item["speech_inputs"]
        if chunk is not None:
            speech_chunks.extend(chunk)
            found_speech = True

    batch = BatchEncoding()
    include_attention = return_attention_mask and attention_rows is not None
    if return_tensors is None:
        batch["input_ids"] = token_rows
        if include_attention:
            batch["attention_mask"] = attention_rows
        batch["speech_input_mask"] = mask_rows
    else:
        batch["input_ids"] = torch.tensor(token_rows, dtype=torch.long)
        if include_attention:
            batch["attention_mask"] = torch.tensor(attention_rows, dtype=torch.long)
        batch["speech_input_mask"] = torch.tensor(mask_rows, dtype=torch.bool)

    if found_speech:
        speech = self.prepare_speech_inputs(
            speech_chunks, return_tensors=return_tensors
        )
        batch["speech_tensors"] = speech["padded_speeches"]
        batch["speech_masks"] = speech["speech_masks"]
    else:
        batch["speech_tensors"] = None
        batch["speech_masks"] = None

    batch["parsed_scripts"] = [item["parsed_script"] for item in encodings]
    batch["all_speakers_list"] = [item["all_speakers"] for item in encodings]
    return batch
def _create_voice_prompt(
self, speaker_samples: List[Union[str, np.ndarray]]
) -> Tuple[List[int], List[np.ndarray], List[bool]]:
vae_token_id = self.tokenizer.speech_diffusion_id
voice_full_tokens = self.tokenizer.encode(
" Voice input:\n", add_special_tokens=False
)
voice_speech_inputs = []
voice_speech_masks = [False] * len(voice_full_tokens)
for speaker_id, speaker_audio in enumerate(speaker_samples):
prefix_tokens = self.tokenizer.encode(
f" Speaker {speaker_id }:", add_special_tokens=False
)
if isinstance(speaker_audio, str):
wav = self.audio_processor._load_audio_from_path(speaker_audio)
elif isinstance(speaker_audio, dict):
if "array" in speaker_audio:
wav = np.array(speaker_audio["array"], dtype=np.float32)
elif "audio" in speaker_audio:
wav = np.array(speaker_audio["audio"], dtype=np.float32)
else:
raise ValueError(
f"Dictionary audio input must have 'array' or 'audio' key, got: {speaker_audio .keys ()}"
)
else:
wav = np.array(speaker_audio, dtype=np.float32)
if self.db_normalize and self.audio_normalizer:
wav = self.audio_normalizer(wav)
vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
speaker_tokens = (
prefix_tokens
+ [self.tokenizer.speech_start_id]
+ [vae_token_id] * vae_tok_len
+ [self.tokenizer.speech_end_id]
+ self.tokenizer.encode("\n", add_special_tokens=False)
)
vae_input_mask = (
[False] * len(prefix_tokens)
+ [False]
+ [True] * vae_tok_len
+ [False]
+ [False]
)
voice_full_tokens.extend(speaker_tokens)
voice_speech_masks.extend(vae_input_mask)
voice_speech_inputs.append(wav)
return (voice_full_tokens, voice_speech_inputs, voice_speech_masks)
def prepare_speech_inputs(
self,
speech_inputs: List[np.ndarray],
return_tensors: Optional[Union[str, TensorType]] = None,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> Dict[str, Any]:
if not speech_inputs:
return {"padded_speeches": None, "speech_masks": None}
vae_tok_seqlens = [
math.ceil(s.shape[0] / self.speech_tok_compress_ratio)
for s in speech_inputs
]
max_speech_length = max((s.shape[0] for s in speech_inputs))
if speech_inputs[0].ndim == 1:
padded_speeches = np.full(
(len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32
)
else:
padded_speeches = np.full(
(len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]),
fill_value=0,
dtype=np.float32,
)
speech_masks = np.zeros(
(len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_
)
for i, (speech, vae_tok_length) in enumerate(
zip(speech_inputs, vae_tok_seqlens)
):
padded_speeches[i, : len(speech)] = speech
speech_masks[i, :vae_tok_length] = True
result = {"padded_speeches": padded_speeches, "speech_masks": speech_masks}
if return_tensors == "pt":
result["padded_speeches"] = torch.tensor(
padded_speeches, device=device, dtype=dtype or torch.float32
)
result["speech_masks"] = torch.tensor(
speech_masks, device=device, dtype=torch.bool
)
return result
def _convert_json_to_script(self, json_file: str) -> str:
import json
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError("JSON file must contain a list of speaker entries")
script_lines = []
for item in data:
if not isinstance(item, dict):
logger.warning(f"Skipping non-dict entry: {item }")
continue
speaker = item.get("speaker")
text = item.get("text")
if speaker is None or text is None:
logger.warning(f"Skipping entry missing speaker or text: {item }")
continue
try:
speaker_id = int(speaker)
except (ValueError, TypeError):
logger.warning(f"Invalid speaker ID: {speaker }, skipping entry")
continue
text = text.strip()
if text:
script_lines.append(f"Speaker {speaker_id }: {text }")
if not script_lines:
raise ValueError("No valid entries found in JSON file")
return "\n".join(script_lines)
def _convert_text_to_script(self, text_file: str) -> str:
with open(text_file, "r", encoding="utf-8") as f:
lines = f.readlines()
script_lines = []
current_speaker = 1
for line in lines:
line = line.strip()
if not line:
continue
speaker_match = re.match(
"^Speaker\\s+(\\d+)\\s*:\\s*(.*)$", line, re.IGNORECASE
)
if speaker_match:
speaker_id = int(speaker_match.group(1))
text = speaker_match.group(2).strip()
if text:
script_lines.append(f"Speaker {speaker_id }: {text }")
else:
script_lines.append(f"Speaker {current_speaker }: {line }")
if not script_lines:
raise ValueError("No valid content found in text file")
return "\n".join(script_lines)
def _parse_script(self, script: str) -> List[Tuple[int, str]]:
stripped = script.strip()
if not stripped:
raise ValueError(
"No valid speaker lines found in script (empty text). "
"If training with HuggingFace Trainer, set remove_unused_columns=False "
"so dataset columns like `text` are not stripped before the collator."
)
non_empty = [ln.strip() for ln in stripped.split("\n") if ln.strip()]
if not non_empty:
raise ValueError("No valid speaker lines found in script")
_speaker_line = r"^Speaker\s+(\d+)\s*:\s*(.*)$"
if not any(re.match(_speaker_line, ln, re.IGNORECASE) for ln in non_empty):
# JSONL / TTS-style rows: plain prompt with no "Speaker N:" lines.
collapsed = " ".join(stripped.split())
return [(0, " " + collapsed)]
parsed_lines: List[Tuple[int, str]] = []
speaker_ids: List[int] = []
for line in non_empty:
match = re.match(_speaker_line, line, re.IGNORECASE)
if match:
speaker_id = int(match.group(1))
text = " " + match.group(2).strip()
parsed_lines.append((speaker_id, text))
speaker_ids.append(speaker_id)
else:
logger.warning(f"Could not parse line: '{line }'")
if not parsed_lines:
raise ValueError("No valid speaker lines found in script")
min_speaker_id = min(speaker_ids)
if min_speaker_id > 0:
normalized_lines = []
for speaker_id, text in parsed_lines:
normalized_lines.append((speaker_id - 1, text))
return normalized_lines
else:
return parsed_lines
def _merge_inputs(
    self, text_inputs: BatchEncoding, audio_inputs: Dict
) -> BatchEncoding:
    """Return a new ``BatchEncoding`` combining text and audio inputs.

    ``audio_inputs["audio"]`` is stored under ``"speech_inputs"`` and
    ``audio_inputs["streaming"]`` under ``"streaming"``; absent keys are
    simply not copied.
    """
    merged = BatchEncoding(text_inputs)
    key_map = (("audio", "speech_inputs"), ("streaming", "streaming"))
    for source_key, target_key in key_map:
        if source_key in audio_inputs:
            merged[target_key] = audio_inputs[source_key]
    return merged
def batch_decode(self, *args, **kwargs):
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
audio_processor_input_names = self.audio_processor.model_input_names
return list(
dict.fromkeys(
tokenizer_input_names
+ audio_processor_input_names
+ ["speech_inputs", "speech_input_mask"]
)
)
def save_audio(
self,
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
output_path: str = "output.wav",
sampling_rate: Optional[int] = None,
normalize: bool = False,
batch_prefix: str = "audio_",
) -> str:
return self.audio_processor.save_audio(
audio,
output_path=output_path,
sampling_rate=sampling_rate,
normalize=normalize,
batch_prefix=batch_prefix,
)
# Public name exported by the processor section.
# NOTE: `__all__` is assigned again further down in this file, so this value is
# replaced by the later assignments once the whole module has executed.
__all__ = [
'QWEN3VoxProcessor'
]
'\nQWEN3Vox Streaming Processor\n\nThis processor handles input preparation for the streaming 0.5B model,\nincluding text tokenization and cached voice prompt handling.\n'
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re
import numpy as np
import torch
from transformers.tokenization_utils_base import (
BatchEncoding,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from transformers.utils import TensorType, logging
logger = logging.get_logger(__name__)
class QWEN3VoxStreamingProcessor:
def __init__(
self,
tokenizer=None,
audio_processor=None,
speech_tok_compress_ratio=3200,
db_normalize=True,
**kwargs,
):
self.tokenizer = tokenizer
self.audio_processor = audio_processor
self.speech_tok_compress_ratio = speech_tok_compress_ratio
self.db_normalize = db_normalize
self.audio_normalizer = AudioNormalizer() if db_normalize else None
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    """Build a streaming processor from a local directory or a hub repository.

    ``preprocessor_config.json`` is read from the local path when present,
    otherwise resolved with ``transformers.utils.cached_file``. If it cannot
    be loaded, a warning is logged and default settings are used.

    Args:
        pretrained_model_name_or_path: Local directory or repository id.
        **kwargs: Forwarded to ``cached_file`` and to
            ``QWEN3VoxTextTokenizerFast.from_pretrained``.

    Returns:
        A ``cls`` instance holding the loaded tokenizer and audio processor.
    """
    import json
    import os

    from transformers.utils import cached_file

    model_name = str(pretrained_model_name_or_path)
    config_path = os.path.join(model_name, "preprocessor_config.json")
    if os.path.exists(config_path):
        # JSON is UTF-8; do not depend on the platform's locale encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    else:
        try:
            config_file = cached_file(
                model_name, "preprocessor_config.json", **kwargs
            )
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        except Exception as e:
            # Deliberate best-effort: fall back to defaults on any failure.
            logger.warning(
                f"Could not load preprocessor_config.json from {model_name }: {e }"
            )
            logger.warning("Using default configuration")
            config = {"speech_tok_compress_ratio": 3200, "db_normalize": True}

    speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
    db_normalize = config.get("db_normalize", True)

    logger.info(f"Loading tokenizer from repo {model_name }")
    tokenizer = QWEN3VoxTextTokenizerFast.from_pretrained(model_name, **kwargs)

    if "audio_processor" in config:
        audio_config = config["audio_processor"]
        audio_processor = QWEN3VoxTokenizerProcessor(
            sampling_rate=audio_config.get("sampling_rate", 22050),
            normalize_audio=audio_config.get("normalize_audio", True),
            target_dB_FS=audio_config.get("target_dB_FS", -25),
            eps=audio_config.get("eps", 1e-06),
        )
    else:
        audio_processor = QWEN3VoxTokenizerProcessor()

    return cls(
        tokenizer=tokenizer,
        audio_processor=audio_processor,
        speech_tok_compress_ratio=speech_tok_compress_ratio,
        db_normalize=db_normalize,
    )
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
    """Write ``preprocessor_config.json`` into ``save_directory``.

    The directory is created when missing. The file records the processor
    class name, the compression ratio, the dB-normalisation flag and the
    audio processor settings (with fallbacks for attributes it lacks).
    ``**kwargs`` is accepted but not used.
    """
    import json
    import os

    os.makedirs(save_directory, exist_ok=True)

    source = self.audio_processor
    audio_settings = {
        "feature_extractor_type": "QWEN3VoxTokenizerProcessor",
        "sampling_rate": getattr(source, "sampling_rate", 22050),
        "normalize_audio": getattr(source, "normalize_audio", True),
        "target_dB_FS": getattr(source, "target_dB_FS", -25),
        "eps": getattr(source, "eps", 1e-06),
    }
    processor_config = {
        "processor_class": "QWEN3VoxStreamingProcessor",
        "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
        "db_normalize": self.db_normalize,
        "audio_processor": audio_settings,
    }

    config_path = os.path.join(save_directory, "preprocessor_config.json")
    with open(config_path, "w") as f:
        json.dump(processor_config, f, indent=2)
    logger.info(f"Processor configuration saved in {config_path }")
def __call__(self) -> BatchEncoding:
raise NotImplementedError(
'QWEN3VoxStreamingProcessor.__call__ is not implemented. Use process_input_with_cached_prompt for streaming inputs.'
)
def process_input_with_cached_prompt(
self,
text: Optional[str] = None,
cached_prompt: Optional[Dict[str, Any]] = None,
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: bool = True,
**kwargs,
) -> BatchEncoding:
texts = [text]
cached_prompts = [cached_prompt]
is_batched = False
all_encodings = []
for text_input, cached_prompt_input in zip(texts, cached_prompts):
script_tokens = self.tokenizer.encode(
text_input.strip() + "\n", add_special_tokens=False
)
input_id_length = cached_prompt_input["lm"]["last_hidden_state"].size(1)
tts_lm_input_id_length = cached_prompt_input["tts_lm"][
"last_hidden_state"
].size(1)
input_ids = [self.tokenizer.pad_id] * input_id_length
tts_lm_input_ids = [self.tokenizer.pad_id] * tts_lm_input_id_length
speech_input_mask = [False] * tts_lm_input_id_length
encoding = {
"input_ids": input_ids,
"tts_lm_input_ids": tts_lm_input_ids,
"tts_text_ids": script_tokens,
"speech_inputs": None,
"speech_input_mask": speech_input_mask,
}
all_encodings.append(encoding)
batch_encoding = self._batch_encode(
all_encodings,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=return_tensors,
return_attention_mask=return_attention_mask,
)
return batch_encoding
def _batch_encode(
    self,
    encodings: List[Dict[str, Any]],
    padding: Union[bool, str, PaddingStrategy] = True,
    truncation: Union[bool, str, TruncationStrategy] = False,
    max_length: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_attention_mask: bool = True,
) -> BatchEncoding:
    """Collate streaming encodings into one ``BatchEncoding``.

    No padding or truncation is performed here: ``padding``, ``truncation``
    and ``max_length`` are accepted but not used. Attention masks are all
    ones. Any non-``None`` ``return_tensors`` yields torch tensors for the
    id/mask fields; speech audio is collated by ``prepare_speech_inputs``.
    """
    input_rows = [item["input_ids"] for item in encodings]
    tts_lm_rows = [item["tts_lm_input_ids"] for item in encodings]
    tts_text_rows = [item["tts_text_ids"] for item in encodings]
    speech_mask_rows = [item["speech_input_mask"] for item in encodings]

    if return_attention_mask:
        attention_masks = [[1] * len(row) for row in input_rows]
        tts_lm_attention_masks = [[1] * len(row) for row in tts_lm_rows]
    else:
        attention_masks = None
        tts_lm_attention_masks = None

    # Flatten the speech audio of all samples into one list.
    speech_chunks = []
    found_speech = False
    for item in encodings:
        chunk = item["speech_inputs"]
        if chunk is not None:
            speech_chunks.extend(chunk)
            found_speech = True

    # Pick the converters once instead of branching per field.
    if return_tensors is not None:
        def as_long(rows):
            return torch.tensor(rows, dtype=torch.long)

        def as_bool(rows):
            return torch.tensor(rows, dtype=torch.bool)
    else:
        def as_long(rows):
            return rows

        as_bool = as_long

    batch = BatchEncoding()
    batch["input_ids"] = as_long(input_rows)
    batch["tts_lm_input_ids"] = as_long(tts_lm_rows)
    batch["tts_text_ids"] = as_long(tts_text_rows)
    if return_attention_mask and attention_masks is not None:
        batch["attention_mask"] = as_long(attention_masks)
        batch["tts_lm_attention_mask"] = as_long(tts_lm_attention_masks)
    batch["speech_input_mask"] = as_bool(speech_mask_rows)

    if found_speech:
        speech = self.prepare_speech_inputs(
            speech_chunks, return_tensors=return_tensors
        )
        batch["speech_tensors"] = speech["padded_speeches"]
        batch["speech_masks"] = speech["speech_masks"]
    else:
        batch["speech_tensors"] = None
        batch["speech_masks"] = None
    return batch
def prepare_speech_inputs(
self,
speech_inputs: List[np.ndarray],
return_tensors: Optional[Union[str, TensorType]] = None,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> Dict[str, Any]:
if not speech_inputs:
return {"padded_speeches": None, "speech_masks": None}
vae_tok_seqlens = [
math.ceil(s.shape[0] / self.speech_tok_compress_ratio)
for s in speech_inputs
]
max_speech_length = max((s.shape[0] for s in speech_inputs))
if speech_inputs[0].ndim == 1:
padded_speeches = np.full(
(len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32
)
else:
padded_speeches = np.full(
(len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]),
fill_value=0,
dtype=np.float32,
)
speech_masks = np.zeros(
(len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_
)
for i, (speech, vae_tok_length) in enumerate(
zip(speech_inputs, vae_tok_seqlens)
):
padded_speeches[i, : len(speech)] = speech
speech_masks[i, :vae_tok_length] = True
result = {"padded_speeches": padded_speeches, "speech_masks": speech_masks}
if return_tensors == "pt":
result["padded_speeches"] = torch.tensor(
padded_speeches, device=device, dtype=dtype or torch.float32
)
result["speech_masks"] = torch.tensor(
speech_masks, device=device, dtype=torch.bool
)
return result
def batch_decode(self, *args, **kwargs):
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
audio_processor_input_names = self.audio_processor.model_input_names
return list(
dict.fromkeys(
tokenizer_input_names
+ audio_processor_input_names
+ ["speech_inputs", "speech_input_mask"]
)
)
def save_audio(
self,
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
output_path: str = "output.wav",
sampling_rate: Optional[int] = None,
normalize: bool = False,
batch_prefix: str = "audio_",
) -> str:
return self.audio_processor.save_audio(
audio,
output_path=output_path,
sampling_rate=sampling_rate,
normalize=normalize,
batch_prefix=batch_prefix,
)
# Public name exported by the streaming-processor section.
# NOTE: `__all__` is assigned again further down in this file, so this value is
# replaced by the later assignments once the whole module has executed.
__all__ = [
'QWEN3VoxStreamingProcessor'
]
'\nQWEN3Vox Streaming Model Architecture (0.5B)\n\nThis module implements the streaming-optimized version of QWEN3Vox for real-time TTS.\nKey differences from the multi-speaker model:\n- No semantic tokenizer (only acoustic)\n- Split language model architecture: lower layers for text, upper layers for TTS\n- Optimized for low-latency generation\n'
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
CausalLMOutput,
BaseModelOutputWithPast,
ModelOutput,
)
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
logger = logging.get_logger(__name__)
# Compatibility shim: give `modeling_utils.ALL_PARALLEL_STYLES` a default list
# when the attribute is missing or None.
# NOTE(review): presumably needed for transformers versions that leave it
# unset — confirm which versions this targets.
if (
    not hasattr(modeling_utils, "ALL_PARALLEL_STYLES")
    or modeling_utils.ALL_PARALLEL_STYLES is None
):
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
class BinaryClassifier(nn.Module):
    """Two linear layers with a ReLU in between, producing one value per input."""

    def __init__(self, hidden_size):
        super().__init__()
        # Attribute names are part of the state-dict layout; keep them stable.
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``; the last dimension becomes 1."""
        return self.fc2(torch.relu(self.fc1(x)))
class SpeechConnector(nn.Module):
    """Linear -> RMSNorm -> Linear projection from ``input_dim`` to ``output_dim``."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Attribute names are part of the state-dict layout; keep them stable.
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-06)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        """Project ``features``; extra keyword arguments are ignored."""
        projected = self.fc1(features)
        return self.fc2(self.norm(projected))
class QWEN3VoxStreamingPreTrainedModel(PreTrainedModel):
    """Base class wiring the streaming model into the ``PreTrainedModel`` machinery."""

    config_class = QWEN3VoxStreamingConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialise one submodule.

        Diffusion heads initialise themselves. ``nn.Linear`` weights are drawn
        from a normal distribution whose std is the ``initializer_range`` of
        ``language_model_config`` or ``decoder_config`` (0.02 if neither has
        one) and biases are zeroed. ``nn.LayerNorm`` is reset to weight 1 /
        bias 0. Other module types are left untouched.
        """
        if isinstance(module, QWEN3VoxDiffusionHead):
            module.initialize_weights()
            return

        if hasattr(self.config, "language_model_config") and hasattr(
            self.config.language_model_config, "initializer_range"
        ):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, "decoder_config") and hasattr(
            self.config.decoder_config, "initializer_range"
        ):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm may be built without affine parameters or without a
            # bias, in which case these attributes are None.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
class QWEN3VoxStreamingModel(QWEN3VoxStreamingPreTrainedModel):
    """Streaming backbone: a text LM, a TTS LM, an acoustic tokenizer and a diffusion head.

    The decoder's layers are split: the top ``config.tts_backbone_num_hidden_layers``
    go to ``tts_language_model`` and the remaining ones to ``language_model``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Resolve the dtype used for the speech components; strings such as
        # "bfloat16" are looked up on the torch module.
        if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Lower part of the decoder; works on a copy so `config` is not mutated.
        lm_config = copy.deepcopy(config.decoder_config)
        lm_backbone_num_hidden_layers = (
            getattr(lm_config, "num_hidden_layers", 24)
            - config.tts_backbone_num_hidden_layers
        )
        lm_config.num_hidden_layers = lm_backbone_num_hidden_layers
        self.language_model = AutoModel.from_config(lm_config)
        # The final norm of the lower stack is replaced by a pass-through.
        self.language_model.norm = nn.Identity()

        # Upper part of the decoder, used for TTS.
        tts_lm_config = copy.deepcopy(lm_config)
        tts_lm_config.num_hidden_layers = config.tts_backbone_num_hidden_layers
        self.tts_language_model = AutoModel.from_config(tts_lm_config)
        self.tts_input_types = nn.Embedding(
            num_embeddings=2, embedding_dim=config.decoder_config.hidden_size
        )

        self.acoustic_tokenizer = AutoModel.from_config(
            config.acoustic_tokenizer_config
        ).to(dtype)
        self.acoustic_connector = SpeechConnector(
            config.acoustic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        # Registered as NaN placeholders; real values are not set in this class.
        self.register_buffer("speech_scaling_factor", torch.tensor(float("nan")))
        self.register_buffer("speech_bias_factor", torch.tensor(float("nan")))
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(
            dtype
        )
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type,
        )

    def get_input_embeddings(self):
        """Return the text LM's token embedding.

        Uses ``language_model.embed_tokens`` when present; otherwise searches
        ``language_model.fullmap`` for the entry whose ``orig_name`` is
        ``"embed_tokens.weight"``.

        Raises:
            AssertionError: If no embedding can be located.
        """
        if hasattr(self.language_model, "embed_tokens"):
            return self.language_model.embed_tokens
        for name, attr in self.language_model.fullmap.items():
            if attr.orig_name == "embed_tokens.weight":
                return getattr(self.language_model, name)
        # Explicit raise: a bare `assert False` is removed under `python -O`,
        # which would make this method silently return None.
        raise AssertionError("should not arrive here")

    def set_input_embeddings(self, value):
        """Replace the text LM's token embedding with ``value``."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None):
        """Install the acoustic tokenizer and, if given, put it in eval mode."""
        self.acoustic_tokenizer = acoustic_tokenizer
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.train(False)

    def forward(self, *args, **kwargs):
        """Disabled on purpose; call the sub-models directly.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError(
            'QWEN3VoxStreamingModel.forward is intentionally disabled. Use `model.language_model(...)` or `model.tts_language_model(...)` instead.'
        )
# Register the streaming model class for its config class with `AutoModel`.
AutoModel.register(QWEN3VoxStreamingConfig, QWEN3VoxStreamingModel)
# Public names exported by the streaming-model section.
# NOTE: this rebinds `__all__`, replacing the earlier assignments in this file.
__all__ = [
'QWEN3VoxStreamingPreTrainedModel',
'QWEN3VoxStreamingModel',
"BinaryClassifier",
"SpeechConnector",
]
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
CausalLMOutput,
BaseModelOutputWithPast,
ModelOutput,
)
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
logger = logging.get_logger(__name__)
# Compatibility shim (repeated from the streaming-model section above): give
# `modeling_utils.ALL_PARALLEL_STYLES` a default list when it is missing or None.
# NOTE(review): presumably needed for transformers versions that leave it
# unset — confirm which versions this targets.
if (
    not hasattr(modeling_utils, "ALL_PARALLEL_STYLES")
    or modeling_utils.ALL_PARALLEL_STYLES is None
):
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
@dataclass
class QWEN3VoxCausalLMOutputWithPast(ModelOutput):
    """Output container for the causal-LM model; every field defaults to ``None``."""

    # NOTE(review): the field descriptions below are inferred from the names —
    # confirm against the code that fills this object.
    loss: Optional[torch.FloatTensor] = None  # presumably the language-model loss
    diffusion_loss: Optional[torch.FloatTensor] = None  # presumably the diffusion-head loss
    speech_token_num: Optional[int] = None  # presumably the number of speech tokens seen
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class QWEN3VoxGenerationOutput(ModelOutput):
    """Output container for generation; both fields default to ``None``."""

    # NOTE(review): the field descriptions below are inferred from the names —
    # confirm against the generation code.
    sequences: torch.LongTensor = None  # presumably the generated token ids
    speech_outputs: Optional[List[torch.FloatTensor]] = None  # presumably one audio tensor per sample
class SpeechConnector(nn.Module):
    """Linear -> RMSNorm -> Linear projection from ``input_dim`` to ``output_dim``."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Attribute names are part of the state-dict layout; keep them stable.
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-06)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        """Project ``features``; extra keyword arguments are ignored."""
        normalized = self.norm(self.fc1(features))
        return self.fc2(normalized)
class QWEN3VoxPreTrainedModel(PreTrainedModel):
    """Base class wiring the multi-speaker model into the ``PreTrainedModel`` machinery."""

    config_class = QWEN3VoxConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialise one submodule.

        Diffusion heads initialise themselves. ``nn.Linear`` weights are drawn
        from a normal distribution whose std is the ``initializer_range`` of
        ``language_model_config`` or ``decoder_config`` (0.02 if neither has
        one) and biases are zeroed. ``nn.LayerNorm`` is reset to weight 1 /
        bias 0. Other module types are left untouched.
        """
        if isinstance(module, QWEN3VoxDiffusionHead):
            module.initialize_weights()
            return

        if hasattr(self.config, "language_model_config") and hasattr(
            self.config.language_model_config, "initializer_range"
        ):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, "decoder_config") and hasattr(
            self.config.decoder_config, "initializer_range"
        ):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm may be built without affine parameters or without a
            # bias, in which case these attributes are None.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
class QWEN3VoxModel(QWEN3VoxPreTrainedModel):
    """Backbone that bundles the language model with the speech sub-modules.

    Holds the decoder language model, the acoustic and semantic tokenizers,
    one ``SpeechConnector`` per tokenizer, the diffusion prediction head, its
    noise scheduler, and two scalar buffers used to normalise speech latents.
    """

    def __init__(self, config):
        super().__init__(config)
        # dtype for the speech sub-modules; ``torch_dtype`` may be absent,
        # None, the name of a torch dtype (e.g. "bfloat16") or a torch.dtype.
        dtype = getattr(config, "torch_dtype", None)
        if dtype is None:
            dtype = torch.float32
        elif isinstance(dtype, str):
            dtype = getattr(torch, dtype)

        lm_config = config.decoder_config
        # The language model is built as-is; only the speech sub-modules below
        # are cast to ``dtype``.
        self.language_model = AutoModel.from_config(lm_config)
        self.acoustic_tokenizer = AutoModel.from_config(
            config.acoustic_tokenizer_config
        ).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(
            config.semantic_tokenizer_config
        ).to(dtype)
        self.acoustic_connector = SpeechConnector(
            config.acoustic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        self.semantic_connector = SpeechConnector(
            config.semantic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        # NaN marks the normalisation factors as "not yet estimated";
        # QWEN3VoxForConditionalGeneration.forward_speech_features fills them
        # in the first time it sees speech latents.
        self.register_buffer("speech_scaling_factor", torch.tensor(float("nan")))
        self.register_buffer("speech_bias_factor", torch.tensor(float("nan")))
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(
            dtype
        )
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type,
        )

    def get_input_embeddings(self):
        """Return the token-embedding module of the language model.

        Uses ``language_model.embed_tokens`` when it exists. Otherwise the
        language model's ``fullmap`` is searched for the entry whose
        ``orig_name`` is ``"embed_tokens.weight"`` and the attribute stored
        under that entry's key is returned.

        Raises:
            AssertionError: if neither lookup finds the embeddings.
        """
        language_model = self.language_model
        if hasattr(language_model, "embed_tokens"):
            return language_model.embed_tokens
        for name, attr in language_model.fullmap.items():
            if attr.orig_name == "embed_tokens.weight":
                return getattr(language_model, name)
        # Raised explicitly (not ``assert False``) so the failure is not
        # stripped under ``python -O``, which would make this return None.
        raise AssertionError("should not arrive here")

    def set_input_embeddings(self, value):
        """Replace ``language_model.embed_tokens`` with ``value``."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Install the given tokenizers and switch them to evaluation mode.

        Both attributes are overwritten, so a tokenizer that is not passed is
        reset to ``None``.
        """
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.train(False)
        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.train(False)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the language model only; the speech sub-modules are not involved.

        All arguments are forwarded to ``self.language_model``. When
        ``return_dict`` resolves to false the language model's raw output is
        returned unchanged; otherwise it is repackaged as a
        ``BaseModelOutputWithPast``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        if not return_dict:
            return outputs
        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class QWEN3VoxForConditionalGeneration(QWEN3VoxPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
    """Build the backbone plus a bias-free LM head sized to the decoder vocabulary."""
    super().__init__(config)
    self.model = QWEN3VoxModel(config)
    decoder_config = config.decoder_config
    self.vocab_size = decoder_config.vocab_size
    self.lm_head = nn.Linear(decoder_config.hidden_size, self.vocab_size, bias=False)
    # Runs the standard post-construction hook of the pretrained base class.
    self.post_init()
# The input embeddings live on the wrapped QWEN3VoxModel; delegate to it.
def get_input_embeddings(self):
return self.model.get_input_embeddings()
# Replace the input embeddings on the wrapped model.
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
# The output projection is this wrapper's own lm_head.
def get_output_embeddings(self):
return self.lm_head
# The "decoder" is the language model held by the wrapped model; these two
# accessors set and get self.model.language_model.
def set_decoder(self, decoder):
self.model.language_model = decoder
def get_decoder(self):
return self.model.language_model
def tie_weights(self):
    """Share the input-embedding weight with the LM head when configured.

    Tying happens only if ``config.decoder_config.tie_word_embeddings`` is
    truthy. The head's weight becomes the embedding module's ``weight`` (or
    the embedding object itself when it has no ``weight`` attribute), and an
    existing head bias is zero-padded up to the number of weight rows. A
    status line is printed in both cases.
    """
    if not getattr(self.config.decoder_config, "tie_word_embeddings", False):
        print("tie_word_embeddings is False, not tying weights.")
        return

    head = self.get_output_embeddings()
    embeddings = self.get_input_embeddings()
    head.weight = embeddings.weight if hasattr(embeddings, "weight") else embeddings

    if getattr(head, "bias", None) is not None:
        # Right-pad the bias with zeros so it has one entry per weight row.
        missing_rows = head.weight.shape[0] - head.bias.shape[0]
        head.bias.data = nn.functional.pad(
            head.bias.data, (0, missing_rows), "constant", 0
        )
    print("Tied input and output embeddings using standard assignment.")
# Replace the LM head (the output projection) with the given module.
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def forward_speech_features(
    self,
    speech_tensors=None,
    speech_masks=None,
    speech_type="audio",
    return_unmask=False,
):
    """Turn speech input into normalised latents and connector embeddings.

    Args:
        speech_tensors: Raw audio (``speech_type="audio"``) or pre-computed
            latents (``speech_type="vae"``). ``None`` produces a single
            all-zero placeholder latent.
        speech_masks: Boolean mask selecting the valid latent positions.
        speech_type: ``"audio"`` or ``"vae"``.
        return_unmask: When true, return the full tensors instead of the
            entries selected by ``speech_masks``.

    Returns:
        ``(audio_features, connect_features)``: the normalised latents and
        their projection through the acoustic connector.

    Raises:
        NotImplementedError: for any other ``speech_type``.

    Side effects:
        While either normalisation buffer on the model is NaN, the scaling
        and bias factors are estimated from the masked latents of this call
        (averaged over ranks when torch.distributed is initialised), written
        into the buffers and printed.
    """
    if speech_tensors is None:
        # No speech in this batch: feed one zero latent so the connector
        # still takes part in the forward pass.
        latent_dim = self.config.acoustic_tokenizer_config.vae_dim
        audio_features = torch.zeros(1, 1, latent_dim).to(
            self.get_input_embeddings().weight
        )
        connect_features = self.model.acoustic_connector(audio_features)
        return (audio_features, connect_features)

    # Tokenisation and normalisation carry no gradient; only the connector
    # call further down is tracked.
    with torch.no_grad():
        if speech_type == "audio":
            tokenizer = self.model.acoustic_tokenizer
            frames = tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
            audio_tokens = frames.sample(tokenizer.std_dist_type)[0]
        elif speech_type == "vae":
            latent_dim = self.config.acoustic_tokenizer_config.vae_dim
            speech_mode = speech_tensors.reshape(
                speech_tensors.size(0), -1, latent_dim
            )
            # One random noise scale per batch entry, broadcast over the
            # remaining dimensions.
            noise_scale = self.model.acoustic_tokenizer.fix_std / 0.8
            std = (
                torch.randn(
                    speech_mode.size(0),
                    dtype=speech_mode.dtype,
                    device=speech_mode.device,
                )
                * noise_scale
            )
            std = std.view(-1, *([1] * (speech_mode.dim() - 1)))
            audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(
                speech_mode
            )
        else:
            raise NotImplementedError(
                f"Speech type {speech_type } not implemented"
            )

        factors_missing = torch.isnan(
            self.model.speech_scaling_factor
        ) or torch.isnan(self.model.speech_bias_factor)
        if factors_missing:
            valid_tokens = audio_tokens[speech_masks].flatten()
            scaling_factor = 1.0 / valid_tokens.std()
            bias_factor = -valid_tokens.mean()
            if dist.is_available() and dist.is_initialized():
                # Sum over ranks, then divide by the world size to average.
                dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
                dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
                world_size = dist.get_world_size()
                self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
                self.model.speech_bias_factor.copy_(bias_factor / world_size)
                print(
                    f"Speech scaling factor (distributed): {self .model .speech_scaling_factor }, bias factor: {self .model .speech_bias_factor }",
                    flush=True,
                )
            else:
                self.model.speech_scaling_factor.copy_(scaling_factor)
                self.model.speech_bias_factor.copy_(bias_factor)
                print(
                    f"Speech scaling factor (single process): {self .model .speech_scaling_factor }, bias factor: {self .model .speech_bias_factor }",
                    flush=True,
                )

        audio_features = (
            audio_tokens + self.model.speech_bias_factor
        ) * self.model.speech_scaling_factor

    connect_features = self.model.acoustic_connector(audio_features)
    if return_unmask:
        return (audio_features, connect_features)
    return (audio_features[speech_masks], connect_features[speech_masks])
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    speech_tensors: Optional[torch.FloatTensor] = None,
    speech_masks: Optional[torch.BoolTensor] = None,
    speeches_loss_input: Optional[torch.FloatTensor] = None,
    speech_semantic_tensors: Optional[torch.FloatTensor] = None,
    acoustic_input_mask: Optional[torch.BoolTensor] = None,
    acoustic_loss_mask: Optional[torch.BoolTensor] = None,
    ddpm_batch_mul: int = 1,
    **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
) -> Union[Tuple, QWEN3VoxCausalLMOutputWithPast]:
    """Training forward pass: inject speech embeddings, run the LM, score diffusion.

    Steps:
      1. Embed ``input_ids`` and overwrite the positions selected by
         ``acoustic_input_mask`` with connector embeddings of the speech
         latents (plus the semantic connector embeddings when
         ``speeches_loss_input`` is given and semantic tensors exist).
      2. Run the backbone on the resulting embeddings and project to logits.
      3. If speech is present and ``acoustic_loss_mask`` selects at least one
         position, add noise to the target latents and compute the summed
         MSE of the prediction head against the ``epsilon`` or
         ``v_prediction`` target, divided by the latent size and
         ``ddpm_batch_mul``. Otherwise return a zero-valued term that still
         touches the head and connector parameters.

    Args:
        speeches_loss_input: When given, the latents used as diffusion targets
            are those selected by ``speeches_loss_input & speech_masks``.
        ddpm_batch_mul: Number of noise/timestep draws per target latent.
        kwargs: Only ``speech_type`` (default ``"audio"``) is read.
        labels: Accepted for interface compatibility; no language-modeling
            loss is computed, so ``loss`` is always ``None``.
        inputs_embeds, output_hidden_states: Accepted but not used; the
            embeddings are always derived from ``input_ids`` and the backbone
            is called with ``output_hidden_states=False``.

    Returns:
        ``QWEN3VoxCausalLMOutputWithPast`` or, when ``return_dict`` resolves to
        false, ``(loss, diffusion_loss, logits, speech_len, *rest)``.

    Raises:
        NotImplementedError: for an unsupported ``speech_type`` or diffusion
            ``prediction_type``.
    """
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    speech_type = kwargs.get("speech_type", "audio")

    x = self.get_input_embeddings()(input_ids)
    typed_speech = speech_tensors.type_as(x) if speech_tensors is not None else None

    # The semantic tensors default to None; only project them when present,
    # otherwise the connector would be called on None.
    semantic_speech_all_connect_features = (
        self.model.semantic_connector(speech_semantic_tensors)
        if speech_semantic_tensors is not None
        else None
    )

    if speeches_loss_input is not None:
        # Only part of the speech is a diffusion target: get the unmasked
        # features and select inputs and targets separately.
        speech_all_features, speech_all_connect_features = (
            self.forward_speech_features(
                speech_tensors=typed_speech,
                speech_masks=speech_masks,
                speech_type=speech_type,
                return_unmask=True,
            )
        )
        if speech_tensors is not None:
            if semantic_speech_all_connect_features is not None:
                x[acoustic_input_mask] = (
                    speech_all_connect_features[speech_masks]
                    + semantic_speech_all_connect_features[speech_masks]
                )
            else:
                x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
            target_latent_mask = speeches_loss_input & speech_masks
            speech_features = speech_all_features[target_latent_mask]
    else:
        speech_features, speech_connect_features = self.forward_speech_features(
            speech_tensors=typed_speech,
            speech_masks=speech_masks,
            speech_type=speech_type,
        )
        if speech_tensors is not None:
            x[acoustic_input_mask] = speech_connect_features

    outputs = self.model(
        input_ids=None,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=x,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=False,
        return_dict=return_dict,
        cache_position=cache_position,
    )
    # Index access works for both a ModelOutput and the plain tuple that the
    # backbone returns when return_dict is false.
    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    # No language-modeling loss is implemented; ``labels`` is ignored.
    loss = None

    # Defined up front so both return paths can report it even when the
    # diffusion branch below is skipped.
    speech_len = 0
    has_diffusion_targets = (
        speech_tensors is not None
        and acoustic_loss_mask is not None
        and acoustic_loss_mask.sum().item() > 0
    )
    if has_diffusion_targets:
        condition_features = hidden_states[acoustic_loss_mask]
        speech_len, latent_size = speech_features.shape
        num_draws = speech_len * ddpm_batch_mul
        noise = torch.randn(
            (num_draws, latent_size),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        # Uniform draw over all training timesteps, with replacement.
        timesteps = torch.multinomial(
            torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
            num_draws,
            replacement=True,
        ).to(hidden_states.device)
        speech_features_repeated = speech_features.repeat_interleave(
            ddpm_batch_mul, dim=0
        )
        condition_features_repeated = condition_features.repeat_interleave(
            ddpm_batch_mul, dim=0
        )
        noisy_speech_features = self.model.noise_scheduler.add_noise(
            speech_features_repeated, noise, timesteps
        )
        model_output = self.model.prediction_head(
            noisy_speech_features, timesteps.type_as(x), condition_features_repeated
        )
        prediction_type = self.config.diffusion_head_config.prediction_type
        if prediction_type == "epsilon":
            target_for_loss = noise
        elif prediction_type == "v_prediction":
            target_for_loss = self.model.noise_scheduler.get_velocity(
                speech_features_repeated, noise, timesteps
            )
        else:
            raise NotImplementedError(
                f"Prediction type {prediction_type } not implemented"
            )
        diffusion_loss = F.mse_loss(
            model_output.float(), target_for_loss.float(), reduction="sum"
        )
        if latent_size > 0 and ddpm_batch_mul > 0:
            diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
        else:
            diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
    else:
        # Zero-valued term that still references every parameter of the head
        # and both connectors, so they take part in the backward graph even
        # without speech targets.
        diffusion_loss = (
            sum((p.sum() for p in self.model.prediction_head.parameters())) * 0.0
        )
        diffusion_loss += (
            sum((p.sum() for p in self.model.acoustic_connector.parameters())) * 0.0
        )
        diffusion_loss += (
            sum((p.sum() for p in self.model.semantic_connector.parameters())) * 0.0
        )

    if not return_dict:
        output = (logits, speech_len) + tuple(outputs[1:])
        return (loss, diffusion_loss) + output
    return QWEN3VoxCausalLMOutputWithPast(
        loss=loss,
        diffusion_loss=diffusion_loss,
        speech_token_num=speech_len,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
# Register the QWEN3Vox classes with the transformers Auto* factories under
# the QWEN3VoxConfig type.
AutoModel.register(QWEN3VoxConfig, QWEN3VoxModel)
AutoModelForCausalLM.register(QWEN3VoxConfig, QWEN3VoxForConditionalGeneration)
# NOTE(review): `__all__` is assigned more than once in this file. Each
# assignment rebinds the name, so only the last one executed is in effect
# after import.
__all__ = [
'QWEN3VoxModel',
'QWEN3VoxPreTrainedModel',
'QWEN3VoxForConditionalGeneration',
'QWEN3VoxCausalLMOutputWithPast',
'QWEN3VoxGenerationOutput',
]
# The bare string below is an expression statement with no runtime effect; it
# reads like the module docstring of a source file merged into this one.
'\nQWEN3Vox Processors\n\nThis module provides processors for preparing inputs for QWEN3Vox models:\n- QWEN3VoxProcessor: For multi-speaker models (1.5B, 7B)\n- QWEN3VoxStreamingProcessor: For streaming model (0.5B)\n'
# Replaces the `__all__` list assigned just above.
__all__ = [
'QWEN3VoxProcessor',
'QWEN3VoxStreamingProcessor',
'QWEN3VoxTokenizerProcessor',
"AudioNormalizer",
'QWEN3VoxASRProcessor',
]
# Another bare string expression with no runtime effect; it describes the
# streaming inference code that follows.
'\nQWEN3Vox Streaming Inference Model (0.5B)\n\nThis module implements the inference engine for real-time streaming TTS.\nKey features:\n- Window-based text/speech interleaving for streaming\n- Binary EOS classifier for end-of-speech detection\n- Classifier-free guidance for speech quality\n- Audio streaming support\n'
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.generation import (
GenerationMixin,
GenerationConfig,
LogitsProcessor,
LogitsProcessorList,
StoppingCriteriaList,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
logger = logging.get_logger(__name__)

# Supply a default list of parallel styles when transformers.modeling_utils
# either does not define ALL_PARALLEL_STYLES or leaves it as None.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]

# Streaming generation alternates between feeding this many text tokens ...
TTS_TEXT_WINDOW_SIZE = 5
# ... and sampling up to this many speech latents.
TTS_SPEECH_WINDOW_SIZE = 6
def _update_model_kwargs_for_generation(
    outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1
) -> Dict[str, Any]:
    """Advance generation kwargs by ``num_new_tokens`` positions.

    Stores the cache from ``outputs``, appends ``num_new_tokens`` ones to the
    attention mask, and replaces ``cache_position`` with the
    ``num_new_tokens`` positions that follow its current last entry.
    ``model_kwargs`` is updated in place and also returned.
    """
    model_kwargs["past_key_values"] = outputs.past_key_values

    mask = model_kwargs["attention_mask"]
    mask_extension = mask.new_ones((mask.shape[0], num_new_tokens))
    model_kwargs["attention_mask"] = torch.cat([mask, mask_extension], dim=-1)

    positions = model_kwargs["cache_position"]
    first_new = positions[-1] + 1
    model_kwargs["cache_position"] = torch.arange(
        first_new, first_new + num_new_tokens
    ).to(positions.device)
    return model_kwargs
@dataclass
class QWEN3VoxLMHeadOutputWithPast(BaseModelOutputWithPast):
"""LM-head-only return type for streaming / lightweight forwards (no loss/diffusion fields)."""
# logits: in this file it is filled by forward_tts_lm with the output of
# tts_eos_classifier applied to the hidden state of the last position.
logits: Optional[torch.FloatTensor] = None
# Result of the streaming `generate` method.
# NOTE(review): this rebinds the name of an earlier class in this file that
# lacks the `reach_max_step_sample` field.
@dataclass
class QWEN3VoxGenerationOutput(ModelOutput):
# sequences: the final TTS-LM input id sequence (`tts_lm_input_ids`).
sequences: torch.LongTensor = None
# speech_outputs: per sample, the audio chunks concatenated along the last
# dimension, or None for a sample without chunks; the whole field is None
# when `return_speech` is false.
speech_outputs: Optional[List[torch.FloatTensor]] = None
# reach_max_step_sample: boolean flag per sample, set when generation stopped
# because the maximum length was exceeded before the sample finished.
reach_max_step_sample: Optional[torch.BoolTensor] = None
class QWEN3VoxStreamingForConditionalGenerationInference(
QWEN3VoxStreamingPreTrainedModel, GenerationMixin
):
def __init__(self, config):
    """Build the streaming backbone, the binary EOS classifier and inference settings."""
    super().__init__(config)
    self.model = QWEN3VoxStreamingModel(config)

    hidden_size = config.decoder_config.hidden_size
    self.tts_eos_classifier = BinaryClassifier(hidden_size)

    # Default number of diffusion steps at inference time; it can be changed
    # later through set_ddpm_inference_steps.
    head_config = config.diffusion_head_config
    self.ddpm_inference_steps = head_config.ddpm_num_inference_steps
    self.post_init()
# Read-only shortcuts to components owned by the wrapped model (`self.model`).
@property
def noise_scheduler(self):
return self.model.noise_scheduler
@property
def prediction_head(self):
return self.model.prediction_head
# The two factors below are used by `generate` to undo the latent
# normalisation: latent / scaling_factor - bias_factor.
@property
def speech_scaling_factor(self):
return self.model.speech_scaling_factor
@property
def speech_bias_factor(self):
return self.model.speech_bias_factor
@property
def acoustic_tokenizer(self):
return self.model.acoustic_tokenizer
@property
def acoustic_connector(self):
return self.model.acoustic_connector
def tie_weights(self):
    """Tie ``lm_head.weight`` to the language model's token embeddings.

    Nothing happens unless ``config.tie_word_embeddings`` is truthy, this
    object has an ``lm_head``, and the language model exposes
    ``embed_tokens``.
    """
    tying_requested = getattr(self.config, "tie_word_embeddings", False)
    if tying_requested and hasattr(self, "lm_head"):
        language_model = self.model.language_model
        if hasattr(language_model, "embed_tokens"):
            self.lm_head.weight = language_model.embed_tokens.weight
# Input embeddings are owned by the wrapped model; delegate get and set.
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
# This class reports no output embedding module (always None) ...
def get_output_embeddings(self):
return None
# ... so attempting to set one is rejected with a RuntimeError.
def set_output_embeddings(self, new_embeddings):
raise RuntimeError(
"Output embeddings (lm_head) are not defined for this model. Create one before calling set_output_embeddings if needed."
)
# Forward the acoustic tokenizer to the wrapped model's set_speech_tokenizers.
def set_speech_tokenizers(self, acoustic_tokenizer=None):
self.model.set_speech_tokenizers(acoustic_tokenizer)
def set_ddpm_inference_steps(self, num_steps=None):
    """Set the number of diffusion steps used when sampling speech latents.

    A falsy ``num_steps`` (``None`` or ``0``) restores the value configured
    in ``config.diffusion_head_config.ddpm_num_inference_steps``.
    """
    if num_steps:
        self.ddpm_inference_steps = num_steps
    else:
        self.ddpm_inference_steps = (
            self.config.diffusion_head_config.ddpm_num_inference_steps
        )
def forward_lm(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Run the base language model (``self.model.language_model``) only.

    Args:
        input_ids: Token ids; embedded with the model's input embeddings when
            ``inputs_embeds`` is not given.
        inputs_embeds: Pre-computed input embeddings; takes precedence over
            ``input_ids``.
        labels: Must be ``None``; loss computation is not implemented.
        return_dict: Accepted for interface compatibility. The result is
            always a ``BaseModelOutputWithPast``.
        kwargs: Forwarded to the language model.

    Returns:
        ``BaseModelOutputWithPast`` with ``last_hidden_state``,
        ``past_key_values`` and ``attentions`` of the language model.

    Raises:
        NotImplementedError: if ``labels`` is provided.
    """
    if labels is not None:
        # Checked before the forward pass so no compute is wasted.
        raise NotImplementedError(
            "Loss computation is not implemented in this version."
        )
    if inputs_embeds is None:
        inputs_embeds = self.model.get_input_embeddings()(input_ids)

    # The result is repackaged by attribute below, so the inner model must
    # return a ModelOutput. Passing the caller's return_dict=False through
    # would yield a tuple and break the attribute access.
    outputs = self.model.language_model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )
    return BaseModelOutputWithPast(
        past_key_values=outputs.past_key_values,
        last_hidden_state=outputs.last_hidden_state,
        attentions=outputs.attentions,
    )
def forward_tts_lm(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    lm_last_hidden_state: Optional[torch.FloatTensor] = None,
    tts_text_masks: Optional[torch.BoolTensor] = None,
    **kwargs,
) -> Union[Tuple, QWEN3VoxLMHeadOutputWithPast]:
    """Run the TTS language model and score end-of-speech on the last position.

    The trailing positions of the input embeddings are overwritten with
    ``lm_last_hidden_state`` (aligned to the end of the sequence), a learned
    type embedding selected by ``tts_text_masks`` is added, and the result is
    fed to ``self.model.tts_language_model``.

    Args:
        lm_last_hidden_state: Required. Hidden states (or speech embeddings)
            written over the last ``lm_last_hidden_state.shape[1]`` positions.
        tts_text_masks: Required. Converted with ``.long()`` and used as the
            index into ``self.model.tts_input_types``.
        labels: Must be ``None``; loss computation is not implemented.
        return_dict: Accepted for interface compatibility. The result is
            always a ``QWEN3VoxLMHeadOutputWithPast``.
        kwargs: Forwarded to the TTS language model.

    Returns:
        ``QWEN3VoxLMHeadOutputWithPast`` whose ``logits`` are the raw
        ``tts_eos_classifier`` outputs for the last position.

    Raises:
        NotImplementedError: if ``labels`` is provided.
        ValueError: if ``lm_last_hidden_state`` or ``tts_text_masks`` is missing.
    """
    if labels is not None:
        # Checked before the forward pass so no compute is wasted.
        raise NotImplementedError(
            "Loss computation is not implemented in this version."
        )
    if lm_last_hidden_state is None or tts_text_masks is None:
        raise ValueError(
            "forward_tts_lm requires both `lm_last_hidden_state` and `tts_text_masks`."
        )
    if inputs_embeds is None:
        inputs_embeds = self.model.get_input_embeddings()(input_ids)

    # Align the supplied hidden states with the end of the sequence.
    start_idx = inputs_embeds.shape[1] - lm_last_hidden_state.shape[1]
    inputs_embeds[:, start_idx:, :] = lm_last_hidden_state
    inputs_embeds = inputs_embeds + self.model.tts_input_types(
        tts_text_masks.long()
    )

    # The result is repackaged by attribute below, so the inner model must
    # return a ModelOutput. Passing the caller's return_dict=False through
    # would yield a tuple and break the attribute access.
    outputs = self.model.tts_language_model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )
    hidden_states = outputs.last_hidden_state
    logits = self.tts_eos_classifier(hidden_states[:, -1, :])
    return QWEN3VoxLMHeadOutputWithPast(
        logits=logits,
        past_key_values=outputs.past_key_values,
        last_hidden_state=hidden_states,
        attentions=outputs.attentions,
    )
def forward(self, *args, **kwargs):
    """Always raise: this model is driven only through the dedicated entry points."""
    message = "Unified forward is disabled. Use `forward_lm`, `forward_tts_lm`, or `generate` instead."
    raise RuntimeError(message)
def _build_generate_config_model_kwargs(
    self, generation_config, inputs, tokenizer, return_processors=False, **kwargs
):
    """Prepare generation config, model kwargs and input ids for one branch.

    Args:
        generation_config: ``None``, a mapping of ``GenerationConfig``
            keyword arguments, or a ``GenerationConfig`` instance. The
            tokenizer's bos/eos/pad ids always override the values it holds.
        inputs: Passed to ``_prepare_model_inputs``; may be ``None`` when the
            ids are supplied through ``kwargs`` (e.g. ``input_ids=...``).
        tokenizer: Source of the special token ids (bos/eos/pad and the
            speech start/end/diffusion ids).
        return_processors: When true, also build and return the logits
            processors and stopping criteria.
        kwargs: Generation and model keyword arguments (``input_ids``,
            ``attention_mask``, ``max_new_tokens``, ...).

    Returns:
        ``(generation_config, model_kwargs, input_ids)``, extended by
        ``(logits_processor, stopping_criteria)`` when ``return_processors``
        is true. ``model_kwargs`` has caching enabled, a prepared cache and a
        ``cache_position`` covering the prompt; its tensors are moved to the
        model's device.
    """
    token_id_overrides = {
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    # Accept a GenerationConfig instance as well as a mapping. Merging into a
    # single dict also avoids a duplicate-keyword error when the caller's
    # settings already contain one of the token ids.
    if generation_config is None:
        base_settings = {}
    elif isinstance(generation_config, GenerationConfig):
        base_settings = generation_config.to_dict()
    else:
        base_settings = dict(generation_config)
    generation_config = GenerationConfig(**{**base_settings, **token_id_overrides})

    generation_config, model_kwargs = self._prepare_generation_config(
        generation_config,
        True,
        speech_start_id=tokenizer.speech_start_id,
        speech_end_id=tokenizer.speech_end_id,
        speech_diffusion_id=tokenizer.speech_diffusion_id,
        **kwargs,
    )
    generation_config.speech_start_id = tokenizer.speech_start_id
    generation_config.speech_end_id = tokenizer.speech_end_id
    generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id

    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    batch_size = inputs_tensor.shape[0]
    device = self.device
    self._prepare_special_tokens(generation_config, True, device=device)

    # Caching is always enabled, whatever the caller configured.
    generation_config.use_cache = True
    model_kwargs["use_cache"] = generation_config.use_cache

    input_ids = inputs_tensor.to(self.device)
    input_ids_length = input_ids.shape[1]
    has_default_max_length = (
        kwargs.get("max_length") is None
        and generation_config.max_length is not None
    )
    has_default_min_length = (
        kwargs.get("min_length") is None
        and generation_config.min_length is not None
    )
    generation_config = self._prepare_generated_length(
        generation_config=generation_config,
        has_default_max_length=has_default_max_length,
        has_default_min_length=has_default_min_length,
        model_input_name=model_input_name,
        inputs_tensor=inputs_tensor,
        input_ids_length=input_ids_length,
    )

    max_cache_length = generation_config.max_length - 1
    self._prepare_cache_for_generation(
        generation_config, model_kwargs, None, batch_size, max_cache_length, device
    )
    model_kwargs["cache_position"] = torch.arange(
        input_ids_length, device=device, dtype=torch.long
    )
    # Only existing keys are reassigned, so the dict is safe to iterate.
    for k, v in model_kwargs.items():
        if isinstance(v, torch.Tensor):
            model_kwargs[k] = v.to(device=device)

    if not return_processors:
        return (generation_config, model_kwargs, input_ids)

    logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_length,
        encoder_input_ids=inputs_tensor,
        prefix_allowed_tokens_fn=None,
        logits_processor=LogitsProcessorList(),
        device=inputs_tensor.device,
        model_kwargs=model_kwargs,
    )
    stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=StoppingCriteriaList(),
    )
    return (
        generation_config,
        model_kwargs,
        input_ids,
        logits_processor,
        stopping_criteria,
    )
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speech_input_mask: Optional[torch.BoolTensor] = None,
tts_text_ids: Optional[torch.LongTensor] = None,
return_speech: bool = True,
cfg_scale: float = 1.0,
stop_check_fn: Optional[Callable[[], bool]] = None,
**kwargs,
) -> Union[torch.LongTensor, QWEN3VoxGenerationOutput]:
# Streaming text-to-speech generation for a single sample (batch size 1 is
# asserted below).
#
# The main loop alternates two phases:
#   1. take the next slice of up to TTS_TEXT_WINDOW_SIZE ids from
#      `tts_text_ids` and run it through `forward_lm` and `forward_tts_lm`;
#   2. sample up to TTS_SPEECH_WINDOW_SIZE speech latents with
#      `sample_speech_tokens`, decode each one to an audio chunk and feed its
#      connector embedding back into the TTS LM (positive and negative branch).
# Generation ends when the EOS classifier probability exceeds 0.5, when the
# TTS sequence grows past `tts_lm_generation_config.max_length`, or when
# `stop_check_fn()` returns true.
#
# Entries read from **kwargs: `tokenizer`, `tts_lm_input_ids`,
# `tts_lm_attention_mask`, `all_prefilled_outputs` (indexed with the keys
# "lm", "tts_lm", "neg_lm", "neg_tts_lm") and `input_ids`; optional are
# `max_new_tokens`, `verbose` and `show_progress_bar`. `tokenizer`,
# `tts_lm_input_ids`, `all_prefilled_outputs` and `tts_text_ids` are
# dereferenced without a None check, so they are effectively required.
#
# Not read in this body: `prefix_allowed_tokens_fn`, `synced_gpus`,
# `assistant_model`, `negative_prompt_ids`, `negative_prompt_attention_mask`,
# `speech_tensors`, `speech_masks`, `speech_input_mask`. The incoming
# `logits_processor` and `stopping_criteria` are overwritten by the values
# returned from `_build_generate_config_model_kwargs`.
#
# Returns a QWEN3VoxGenerationOutput with the final TTS id sequence, the
# concatenated audio per sample (None when `return_speech` is false) and the
# `reach_max_step_sample` flags.
tokenizer = kwargs.pop("tokenizer", None)
# Token id used as the one-token prompt of the negative branches, whose
# hidden states become the negative condition for `sample_speech_tokens`.
neg_text_input_id = tokenizer.convert_tokens_to_ids("<|image_pad|>")
tts_lm_input_ids = kwargs.pop("tts_lm_input_ids", None)
tts_lm_attention_mask = kwargs.pop("tts_lm_attention_mask", None)
all_prefilled_outputs = kwargs.pop("all_prefilled_outputs", None)
tts_text_ids = tts_text_ids.to(self.device)
# Default budget: whatever is left of the decoder's position range after the
# already prefilled TTS prompt.
if kwargs.get("max_new_tokens", None) is None:
kwargs["max_new_tokens"] = (
self.config.decoder_config.max_position_embeddings
- tts_lm_input_ids.shape[-1]
)
# Positive base-LM branch (the only one that also builds the processors).
(
generation_config,
model_kwargs,
input_ids,
logits_processor,
stopping_criteria,
) = self._build_generate_config_model_kwargs(
generation_config, inputs, tokenizer, return_processors=True, **kwargs
)
# Negative base-LM branch: a single `neg_text_input_id` token per sample.
negative_kwargs = {
"input_ids": torch.full(
(kwargs["input_ids"].shape[0], 1),
neg_text_input_id,
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"attention_mask": torch.ones(
(kwargs["input_ids"].shape[0], 1),
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
negative_generation_config, negative_model_kwargs, negative_input_ids = (
self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **negative_kwargs
)
)
# Positive TTS-LM branch, starting from the prefilled TTS prompt.
tts_lm_kwargs = {
"input_ids": tts_lm_input_ids,
"attention_mask": tts_lm_attention_mask,
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
tts_lm_generation_config, tts_lm_model_kwargs, tts_lm_input_ids = (
self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **tts_lm_kwargs
)
)
# Negative TTS-LM branch, again a single `neg_text_input_id` token.
tts_lm_negative_kwargs = {
"input_ids": torch.full(
(kwargs["input_ids"].shape[0], 1),
neg_text_input_id,
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"attention_mask": torch.ones(
(kwargs["input_ids"].shape[0], 1),
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
(
tts_lm_negative_generation_config,
tts_lm_negative_model_kwargs,
tts_lm_negative_input_ids,
) = self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **tts_lm_negative_kwargs
)
# Cache handed to every `acoustic_tokenizer.decode` call below.
acoustic_cache = QWEN3VoxTokenizerStreamingCache()
batch_size = input_ids.shape[0]
assert batch_size == 1, "Currently only supports batch size == 1"
device = input_ids.device
finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
verbose = kwargs.get("verbose", False)
audio_chunks = [[] for _ in range(batch_size)]
tts_text_window_index = 0
reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
# Size of the first text window: the full window, or all of the text when it
# is shorter than one window.
first_text_window_size = (
TTS_TEXT_WINDOW_SIZE
if tts_text_ids.shape[1] >= TTS_TEXT_WINDOW_SIZE
else tts_text_ids.shape[1]
)
# Outputs of the prompt prefill are supplied by the caller rather than
# recomputed here.
outputs = all_prefilled_outputs["lm"]
tts_lm_outputs = all_prefilled_outputs["tts_lm"]
negative_outputs = all_prefilled_outputs["neg_lm"]
tts_lm_negative_outputs = all_prefilled_outputs["neg_tts_lm"]
# Positive branches: extend attention mask and cache positions by the size of
# the first text window (module-level helper). Negative branches use the
# GenerationMixin method of the same name.
model_kwargs = _update_model_kwargs_for_generation(
outputs, model_kwargs, num_new_tokens=first_text_window_size
)
tts_lm_model_kwargs = _update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=first_text_window_size
)
negative_model_kwargs = self._update_model_kwargs_for_generation(
negative_outputs, negative_model_kwargs, is_encoder_decoder=False
)
tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_negative_outputs,
tts_lm_negative_model_kwargs,
is_encoder_decoder=False,
)
# `step` counts positions of the TTS sequence, starting at the prompt length.
step = tts_lm_input_ids.shape[1]
total_generated_speech_tokens = 0
total_prefilled_text_tokens = 0
# NOTE(review): the progress bar created here is never closed in this method.
if kwargs.get("show_progress_bar", True):
progress_bar = tqdm(
total=tts_lm_generation_config.max_length,
desc=f"Prefilled {step } tokens, current step ({step } / {tts_lm_generation_config .max_length })",
initial=step,
leave=False,
)
else:
progress_bar = None
while True:
# External cancellation hook.
if stop_check_fn is not None and stop_check_fn():
if verbose:
print(f"Generation stopped externally at step {step +1 }")
if audio_streamer is not None:
audio_streamer.end()
break
if finished_tags.all():
if hasattr(progress_bar, "set_description"):
progress_bar.set_description("Generation complete")
break
# Slice of text ids for the current window ...
cur_input_tts_text_ids = tts_text_ids[
:,
tts_text_window_index
* TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 1)
* TTS_TEXT_WINDOW_SIZE,
]
# ... and the length of the window after it (0 once the text is exhausted).
next_text_window_size = tts_text_ids[
:,
(tts_text_window_index + 1)
* TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 2)
* TTS_TEXT_WINDOW_SIZE,
].shape[1]
tts_text_window_index += 1
# Text phase: only runs while there is text left to feed.
if cur_input_tts_text_ids.shape[1] > 0:
input_ids = torch.cat([input_ids, cur_input_tts_text_ids], dim=-1)
tts_lm_input_ids = torch.cat(
[tts_lm_input_ids, cur_input_tts_text_ids], dim=-1
)
# NOTE(review): the check uses tts_lm_generation_config.max_length but the
# message prints generation_config.max_length; the other two "Reached maximum"
# messages in this method print the tts_lm value -- confirm which is intended.
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
if verbose:
print(
f"Reached maximum generation length {generation_config .max_length }, stopped it."
)
reached_samples = torch.arange(batch_size, device=device)[
~finished_tags
]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
step += cur_input_tts_text_ids.shape[1]
total_prefilled_text_tokens += cur_input_tts_text_ids.shape[1]
if progress_bar is not None:
progress_bar.update(cur_input_tts_text_ids.shape[1])
progress_bar.set_description(
f"Prefilled {total_prefilled_text_tokens } text tokens, generated {total_generated_speech_tokens } speech tokens, current step ({step } / {tts_lm_generation_config .max_length })"
)
# Base LM over the new text window.
model_inputs = self.prepare_inputs_for_generation(
input_ids, **model_kwargs
)
outputs = self.forward_lm(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
model_kwargs = _update_model_kwargs_for_generation(
outputs, model_kwargs, num_new_tokens=next_text_window_size
)
# TTS LM over the same window, fed with the base LM's hidden states and
# flagged as text (mask of ones).
tts_lm_model_inputs = self.prepare_inputs_for_generation(
tts_lm_input_ids, **tts_lm_model_kwargs
)
tts_lm_additional_inputs = {
"tts_text_masks": torch.ones_like(tts_lm_input_ids[:, -1:]),
"lm_last_hidden_state": outputs.last_hidden_state,
}
tts_lm_outputs = self.forward_tts_lm(
**tts_lm_model_inputs,
**tts_lm_additional_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
)
# Speech phase. The only sample index is 0 because batch size is 1.
diffusion_indices = torch.LongTensor([0])
for cur_speech_index in range(TTS_SPEECH_WINDOW_SIZE):
# Conditions for the diffusion sampler: last hidden state of the positive and
# of the negative TTS-LM branch.
positive_condition = tts_lm_outputs.last_hidden_state[
diffusion_indices, -1, :
]
negative_condition = tts_lm_negative_outputs.last_hidden_state[
diffusion_indices, -1, :
]
speech_latent = self.sample_speech_tokens(
positive_condition, negative_condition, cfg_scale=cfg_scale
).unsqueeze(1)
# Undo the latent normalisation before decoding: divide by the scaling factor,
# then subtract the bias factor.
scaled_latent = speech_latent / self.model.speech_scaling_factor.to(
speech_latent.device
) - self.model.speech_bias_factor.to(speech_latent.device)
audio_chunk = self.model.acoustic_tokenizer.decode(
scaled_latent.to(self.model.acoustic_tokenizer.device),
cache=acoustic_cache,
sample_indices=diffusion_indices.to(
self.model.acoustic_tokenizer.device
),
use_cache=True,
debug=False,
)
# Keep the chunk only for samples that have not finished yet.
for i, sample_idx in enumerate(diffusion_indices):
idx = sample_idx.item()
if not finished_tags[idx]:
audio_chunks[idx].append(audio_chunk[i])
if audio_streamer is not None:
audio_streamer.put(audio_chunk, diffusion_indices)
# The sampled latent is fed back through the acoustic connector; a
# placeholder id of 1 marks the new position in the TTS id sequence.
acoustic_embed = self.model.acoustic_connector(speech_latent)
tts_lm_input_ids = torch.cat(
[tts_lm_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])],
dim=-1,
)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
break
step += 1
total_generated_speech_tokens += 1
if progress_bar is not None:
progress_bar.update(1)
progress_bar.set_description(
f"Prefilled {total_prefilled_text_tokens } text tokens, generated {total_generated_speech_tokens } speech tokens, current step ({step } / {tts_lm_generation_config .max_length })"
)
# Positive TTS-LM step on the speech embedding (mask of zeros = speech).
tts_lm_model_inputs = self.prepare_inputs_for_generation(
tts_lm_input_ids, **tts_lm_model_kwargs
)
tts_lm_additional_inputs = {
"tts_text_masks": torch.zeros_like(tts_lm_input_ids[:, -1:]),
"lm_last_hidden_state": acoustic_embed,
}
tts_lm_outputs = self.forward_tts_lm(
**tts_lm_model_inputs,
**tts_lm_additional_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
# On the last speech step, if more text is pending, reserve attention-mask
# and cache positions for the whole next text window; otherwise advance with
# the GenerationMixin update.
if (
cur_speech_index == TTS_SPEECH_WINDOW_SIZE - 1
and next_text_window_size > 0
):
tts_lm_model_kwargs = _update_model_kwargs_for_generation(
tts_lm_outputs,
tts_lm_model_kwargs,
num_new_tokens=next_text_window_size,
)
else:
tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
)
# Negative TTS-LM step on the same speech embedding.
tts_lm_negative_input_ids = torch.cat(
[
tts_lm_negative_input_ids,
torch.ones_like(tts_lm_input_ids[:, -1:]),
],
dim=-1,
)
tts_lm_negative_model_inputs = self.prepare_inputs_for_generation(
tts_lm_negative_input_ids, **tts_lm_negative_model_kwargs
)
tts_lm_negative_additional_inputs = {
"tts_text_masks": torch.zeros_like(
tts_lm_negative_input_ids[:, -1:]
),
"lm_last_hidden_state": acoustic_embed,
}
tts_lm_negative_outputs = self.forward_tts_lm(
**tts_lm_negative_model_inputs,
**tts_lm_negative_additional_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_negative_outputs,
tts_lm_negative_model_kwargs,
is_encoder_decoder=False,
)
# End-of-speech probability from the binary classifier; above 0.5 the sample
# is marked finished and its stream is closed.
tts_eos_logits = torch.sigmoid(
self.tts_eos_classifier(
tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]
)
)
if tts_eos_logits[0].item() > 0.5:
finished_tags[diffusion_indices] = True
if audio_streamer is not None:
audio_streamer.end(diffusion_indices)
# Length limit reached: flag the samples that are still unfinished and stop.
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
if verbose:
print(
f"Reached maximum generation length {tts_lm_generation_config .max_length }, stopped it."
)
reached_samples = torch.arange(batch_size, device=device)[
~finished_tags
]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
# Close the stream once generation is over. On the external-stop path above
# `end()` has already been called once.
if audio_streamer is not None:
audio_streamer.end()
# Join each sample's chunks along the last dimension; None when a sample
# produced no audio.
final_audio_outputs = []
for sample_chunks in audio_chunks:
if sample_chunks:
concatenated_audio = torch.cat(sample_chunks, dim=-1)
final_audio_outputs.append(concatenated_audio)
else:
final_audio_outputs.append(None)
if reach_max_step_sample is not None and reach_max_step_sample.any():
print(
f"Reached maximum generation length {tts_lm_generation_config .max_length }, stopped it."
)
return QWEN3VoxGenerationOutput(
sequences=tts_lm_input_ids,
speech_outputs=final_audio_outputs if return_speech else None,
reach_max_step_sample=reach_max_step_sample,
)
@torch.no_grad()
def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
    """Denoise speech latents using classifier-free guidance.

    Args:
        condition: Conditioning features for the guided (positive) branch.
        neg_condition: Conditioning features for the unguided (negative)
            branch; concatenated below ``condition`` along dim 0.
        cfg_scale: Guidance weight applied to ``cond - uncond``.

    Returns:
        The first half (along dim 0) of the final denoised tensor, i.e. the
        rows that belong to ``condition``.
    """
    scheduler = self.model.noise_scheduler
    head = self.model.prediction_head
    scheduler.set_timesteps(self.ddpm_inference_steps)

    # Positive rows first, negative rows second, moved to the head's device.
    stacked_cond = torch.cat([condition, neg_condition], dim=0).to(head.device)

    # Noise is drawn for the stacked batch, then cast/moved like the condition.
    latents = torch.randn(stacked_cond.shape[0], self.config.acoustic_vae_dim).to(
        stacked_cond
    )

    for step_t in scheduler.timesteps:
        # Both branches denoise the *same* latents: the first half is duplicated.
        shared = latents[: len(latents) // 2]
        model_in = torch.cat([shared, shared], dim=0)
        t_batch = step_t.repeat(model_in.shape[0]).to(model_in)
        pred = head(model_in, t_batch, condition=stacked_cond)

        pred_cond, pred_uncond = torch.split(pred, len(pred) // 2, dim=0)
        guided = pred_uncond + cfg_scale * (pred_cond - pred_uncond)

        # The guided prediction is applied to both halves of the batch.
        guided_full = torch.cat([guided, guided], dim=0)
        latents = scheduler.step(guided_full, step_t, latents).prev_sample

    return latents[: len(latents) // 2]
# Register the streaming inference model with the Auto* factory, keyed by its
# config class.
AutoModelForCausalLM.register(
    QWEN3VoxStreamingConfig,
    QWEN3VoxStreamingForConditionalGenerationInference,
)

# Names exported by a star-import of this module.
__all__ = [
    "QWEN3VoxStreamingForConditionalGenerationInference",
    "QWEN3VoxGenerationOutput",
    "QWEN3VoxLMHeadOutputWithPast",
    "TTS_TEXT_WINDOW_SIZE",
    "TTS_SPEECH_WINDOW_SIZE",
]
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, DatasetDict, VerificationMode
from transformers import HfArgumentParser, Trainer, set_seed, TrainerCallback
from transformers import TrainingArguments as HfTrainingArguments
from peft import LoraConfig, get_peft_model, TaskType
logger = logging.getLogger(__name__)
import copy
import torch
from transformers import TrainerCallback
class EmaCallback(TrainerCallback):
    """Keep an exponential moving average (EMA) of one sub-module's state.

    The sub-module is found by walking the dot-separated ``attr_path`` from the
    model passed to each callback hook.  A copy of its ``state_dict`` (the
    "shadow") is kept on ``device`` and blended towards the live weights after
    every optimizer step.

    Swap semantics:
        * ``on_evaluate`` / ``on_save`` / ``on_train_end`` load the shadow into
          the live module, keeping a backup of the raw training weights.
        * ``on_step_begin`` restores that backup, so an EMA swap can never leak
          into the next optimizer step.  This matters because the Transformers
          callback handler has no ``on_evaluate_end`` / ``on_save_end`` events;
          the methods of that name below are kept for callers that invoke them
          explicitly.
        * After ``on_train_end`` the module is left holding the EMA weights.

    Args:
        attr_path: Dot-separated attribute path from the model to the module.
        decay: EMA decay factor; the shadow keeps ``decay`` of its old value.
        device: Device on which the shadow copy is stored.
    """

    def __init__(self, attr_path="model.prediction_head", decay=0.999, device="cpu"):
        self.attr_path = attr_path
        self.decay = float(decay)
        self.device = torch.device(device)
        self.shadow = None  # EMA copy of the tracked state_dict
        self._orig = None  # raw-weight backup while the EMA is swapped in

    def _get_module(self, model):
        """Resolve ``attr_path`` starting from ``model``."""
        mod = model
        for name in self.attr_path.split("."):
            mod = getattr(mod, name)
        return mod

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Initialise the shadow from the module's current state."""
        head = self._get_module(model)
        self.shadow = {
            k: p.detach().to(self.device).clone() for k, p in head.state_dict().items()
        }
        # Any backup left from an earlier run no longer matches this shadow.
        self._orig = None

    def on_step_begin(self, args, state, control, model=None, **kwargs):
        """Restore raw training weights if an EMA swap is still active."""
        self._swap_back(model)

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Blend the live weights into the shadow."""
        if self.shadow is None or model is None:
            return
        head = self._get_module(model)
        with torch.no_grad():
            for k, v in head.state_dict().items():
                live = v.detach().to(self.device)
                ema = self.shadow.get(k)
                if ema is None or ema.shape != live.shape:
                    # Entry appeared or changed shape since on_train_begin.
                    self.shadow[k] = live.clone()
                elif torch.is_floating_point(ema):
                    ema.mul_(self.decay).add_(live, alpha=1.0 - self.decay)
                else:
                    # Integer/bool entries cannot be averaged in place with a
                    # float factor; mirror the live value instead.
                    ema.copy_(live)

    def _swap_in_ema(self, model):
        """Load the shadow into the live module, backing up raw weights once."""
        if self.shadow is None or model is None:
            return
        head = self._get_module(model)
        if self._orig is None:
            # Only back up when not already swapped; a second backup would
            # capture EMA weights and lose the raw ones.
            self._orig = copy.deepcopy(head.state_dict())
        head.load_state_dict(self.shadow, strict=False)

    def _swap_back(self, model):
        """Restore the raw-weight backup, if one exists."""
        if self._orig is None or model is None:
            return
        head = self._get_module(model)
        head.load_state_dict(self._orig, strict=False)
        self._orig = None

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        self._swap_in_ema(model)

    def on_evaluate_end(self, args, state, control, model=None, **kwargs):
        # Not dispatched by the stock Trainer; kept for explicit callers.
        self._swap_back(model)

    def on_save(self, args, state, control, model=None, **kwargs):
        self._swap_in_ema(model)

    def on_save_end(self, args, state, control, model=None, **kwargs):
        # Not dispatched by the stock Trainer; kept for explicit callers.
        self._swap_back(model)

    def on_train_end(self, args, state, control, model=None, **kwargs):
        # Leave the EMA weights in place once training is over.
        self._swap_in_ema(model)
@dataclass
class ModelArguments:
    """Options describing the base model and which parts of it are trained."""

    # --- checkpoint locations -------------------------------------------
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata=dict(help="Path to QWEN3Vox base model with config.json"),
    )
    processor_name_or_path: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Path to processor dir (preprocessor_config.json). Defaults to model path."
        ),
    )
    cache_dir: Optional[str] = None

    # --- tokenizer freezing ---------------------------------------------
    freeze_acoustic_tokenizer: bool = True
    freeze_semantic_tokenizer: bool = True

    # --- LoRA hyper-parameters ------------------------------------------
    lora_r: int = 8
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: str = field(
        default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
        metadata=dict(
            help="Comma-separated list of target module names in the LLM blocks"
        ),
    )
    lora_wrap_diffusion_head: bool = field(
        default=False,
        metadata=dict(help="Wrap diffusion head with PEFT LoRA"),
    )

    # --- full fine-tuning switches --------------------------------------
    train_diffusion_head: bool = field(
        default=False,
        metadata=dict(help="Train diffusion prediction head (full fine-tune)"),
    )
    train_connectors: bool = field(
        default=False,
        metadata=dict(help="Train acoustic/semantic connectors (full fine-tune)"),
    )
    layers_to_freeze: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Comma-separated indices of diffusion head layers to freeze (e.g., '0,1,5,7,8')."
        ),
    )
@dataclass
class DataArguments:
    """Options describing where training/eval data comes from and its columns."""

    # --- Hub dataset selection ------------------------------------------
    dataset_name: Optional[str] = field(
        default=None,
        metadata=dict(
            help="HF dataset name or 'json' with --train_jsonl for local files"
        ),
    )
    dataset_config_name: Optional[str] = None
    train_split_name: str = "train"
    eval_split_name: Optional[str] = "validation"

    # --- column names inside each example -------------------------------
    text_column_name: str = "text"
    audio_column_name: str = "audio"
    voice_prompts_column_name: Optional[str] = "voice_prompts"

    # --- splitting / verification / truncation --------------------------
    eval_split_size: float = 0.0
    ignore_verifications: bool = False
    max_length: Optional[int] = None

    # --- local JSONL files ----------------------------------------------
    train_jsonl: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Path to local train JSONL with {text, audio, [voice_prompts]}"
        ),
    )
    validation_jsonl: Optional[str] = field(
        default=None,
        metadata=dict(help="Optional path to local validation JSONL"),
    )

    # --- augmentation ----------------------------------------------------
    voice_prompt_drop_rate: float = field(
        default=0.0,
        metadata=dict(
            help="Probability to drop conditioning voice prompt during training (0.0 keep always, 1.0 drop always)."
        ),
    )
@dataclass
class CustomTrainingArguments(HfTrainingArguments):
    """Hugging Face ``TrainingArguments`` extended with loss and debug knobs."""

    # Multiplier on the number of noise/timestep draws per speech latent.
    ddpm_batch_mul: int = 1

    # Weights of the two terms summed into the total training loss.
    ce_loss_weight: float = 1.0
    diffusion_loss_weight: float = 1.0

    # Cross-entropy debugging controls.
    debug_ce_details: bool = False
    debug_ce_topk: int = 5
    debug_ce_max_examples: int = 1
    debug_ce_every_n_steps: int = 200

    gradient_clipping: bool = field(
        default=False,
        metadata=dict(
            help="Enable gradient clipping using max_grad_norm (set via --max_grad_norm, default 1.0). When False, disables clipping by forcing max_grad_norm=0.0."
        ),
    )
    debug_save: bool = field(
        default=False,
        metadata=dict(
            help="If set, saves model components BEFORE training starts, into output_dir/debug_initial."
        ),
    )
def build_lora_config(args: ModelArguments) -> LoraConfig:
    """Build the LoRA configuration used for the language-model backbone.

    ``args.lora_target_modules`` is split on commas; surrounding whitespace is
    stripped and empty entries are dropped.
    """
    module_names = []
    for piece in args.lora_target_modules.split(","):
        cleaned = piece.strip()
        if cleaned:
            module_names.append(cleaned)

    # The wrapped language_model is a bare Qwen2Model backbone rather than a
    # *ForCausalLM class. TaskType.CAUSAL_LM would select PeftModelForCausalLM,
    # which needs prepare_inputs_for_generation on the base model, so the
    # feature-extraction task type is used instead.
    settings = dict(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION,
        target_modules=module_names,
    )
    return LoraConfig(**settings)
def build_head_lora_config(args: ModelArguments) -> LoraConfig:
    """Build the LoRA configuration used when wrapping the diffusion head.

    Rank, alpha and dropout come from ``args``; the target module names are
    fixed in this function rather than read from the command line.
    """
    head_modules = "noisy_images_proj cond_proj gate_proj up_proj down_proj linear"
    settings = dict(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION,
        target_modules=head_modules.split(),
    )
    return LoraConfig(**settings)
def mask_for_ce(
    labels: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    acoustic_input_mask: Optional[torch.Tensor],
    pad_id: int = -100,
) -> torch.Tensor:
    """Build next-token cross-entropy targets with ignored positions masked.

    The labels are shifted left by one position (column 0 is dropped).  A
    target is kept only when its attention-mask entry equals 1 and it is not
    flagged in ``acoustic_input_mask``; every other target is replaced by
    ``pad_id``.

    Args:
        labels: Token ids, indexed as ``[batch, sequence]``.
        attention_mask: Same layout as ``labels``; ``None`` or an empty tensor
            means "attend everywhere".
        acoustic_input_mask: Same layout as ``labels``; non-zero / ``True``
            marks acoustic positions to exclude.  May be bool or integer;
            ``None`` or an empty tensor means "no acoustic positions".
        pad_id: Value written at excluded targets.

    Returns:
        A new tensor with one column fewer than ``labels``; ``labels`` itself
        is not modified.
    """
    shifted = labels[:, 1:].contiguous()

    if attention_mask is not None and attention_mask.numel() > 0:
        keep = attention_mask[:, 1:].contiguous().eq(1)
    else:
        keep = torch.ones_like(shifted, dtype=torch.bool)

    if acoustic_input_mask is not None and acoustic_input_mask.numel() > 0:
        # Cast before negating: ``~`` on an integer tensor is a bitwise NOT,
        # which would turn the result into an integer *index* tensor below.
        is_acoustic = acoustic_input_mask[:, 1:].contiguous().to(torch.bool)
        keep = keep & ~is_acoustic

    out = shifted.clone()
    out[~keep] = pad_id
    return out
def _patch_acoustic_encode_for_legacy_indexing(model_obj, logger_):
    """Wrap ``acoustic_tokenizer.encode`` so its result supports ``[0][0]``.

    The tokenizer is looked up on ``model_obj.model`` when that attribute
    exists, otherwise on ``model_obj`` itself.  The wrapper returns the raw
    result when ``result[0][0]`` already works; otherwise it picks a payload
    (a known dict key, the first dict value, a known attribute, or the object
    itself) and returns it wrapped as ``[[payload]]``.

    Failures while patching are logged as warnings on ``logger_`` and never
    raised.
    """
    payload_names = ("frames", "codes", "tokens", "latents", "hidden_states")

    def _already_indexable(value):
        # Probe the legacy access pattern; any failure means "needs wrapping".
        try:
            value[0][0]
        except Exception:
            return False
        return True

    try:
        holder = getattr(model_obj, "model", model_obj)
        acoustic = getattr(holder, "acoustic_tokenizer", None)
        if acoustic is None or not hasattr(acoustic, "encode"):
            logger_.warning("No acoustic_tokenizer.encode() found to patch.")
            return
        base_encode = acoustic.encode

        def encode_wrapped(*args, **kwargs):
            out = base_encode(*args, **kwargs)
            if _already_indexable(out):
                return out
            if isinstance(out, dict):
                # Prefer a well-known key, then fall back to the first value.
                for key in payload_names:
                    if key in out:
                        return [[out[key]]]
                if len(out) > 0:
                    return [[next(iter(out.values()))]]
            for attr in payload_names:
                if hasattr(out, attr):
                    return [[getattr(out, attr)]]
            # Tensors and anything else are wrapped as-is.
            return [[out]]

        acoustic.encode = encode_wrapped
        logger_.info(
            "Patched acoustic_tokenizer.encode() to return [[...]] for legacy indexing."
        )
    except Exception as e:
        logger_.warning(f"Failed to patch acoustic_tokenizer.encode(): {e}")
def main() -> None:
parser = HfArgumentParser((ModelArguments, DataArguments, CustomTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
logger.info("Training/evaluation parameters %s", training_args)
set_seed(training_args.seed)
if not getattr(training_args, "gradient_clipping", False):
if hasattr(training_args, "max_grad_norm"):
training_args.max_grad_norm = 0.0
logger.info(
"Gradient clipping disabled (set max_grad_norm=0.0). Use --gradient_clipping to enable."
)
else:
if (
not hasattr(training_args, "max_grad_norm")
or training_args.max_grad_norm is None
or training_args.max_grad_norm <= 0
):
training_args.max_grad_norm = 1.0
logger.info(
f"Gradient clipping enabled: max_grad_norm={training_args .max_grad_norm }"
)
model_name = model_args.model_name_or_path
if model_name is None:
raise ValueError(
"--model_name_or_path (or --processor_name_or_path) must be provided"
)
processor: QWEN3VoxProcessor = QWEN3VoxProcessor.from_pretrained(model_name)
tok = processor.tokenizer
for required in ["speech_start_id", "speech_diffusion_id", "speech_end_id"]:
if not hasattr(tok, required) or getattr(tok, required) is None:
raise RuntimeError(f"Tokenizer missing required special id: {required }")
dtype = torch.float32
if training_args.bf16:
dtype = torch.bfloat16
elif getattr(training_args, "fp16", False):
dtype = torch.float16
model = QWEN3VoxForConditionalGeneration.from_pretrained(
model_name, torch_dtype=dtype
)
_patch_acoustic_encode_for_legacy_indexing(model, logger)
processor.semantic_tokenizer = getattr(model.model, "semantic_tokenizer", None)
try:
in_emb_mod = model.get_input_embeddings()
out_emb_mod = model.get_output_embeddings()
in_w = getattr(in_emb_mod, "weight", None)
out_w = getattr(out_emb_mod, "weight", None)
shared_ptr = bool(
in_w is not None
and out_w is not None
and (in_w.data_ptr() == out_w.data_ptr())
)
values_equal = False
if in_w is not None and out_w is not None and (in_w.shape == out_w.shape):
try:
values_equal = bool(torch.allclose(in_w, out_w))
except Exception:
values_equal = False
try:
tie_cfg = getattr(
getattr(model.config, "decoder_config", model.config),
"tie_word_embeddings",
None,
)
except Exception:
tie_cfg = getattr(model.config, "tie_word_embeddings", None)
logger.info(
f"LM head diagnostics -> shared_params={shared_ptr }, values_equal={values_equal }, tie_word_embeddings={tie_cfg }"
)
if out_w is not None:
logger.info(
f"LM head requires_grad before freeze: {bool (out_w .requires_grad )}"
)
except Exception as e:
logger.warning(f"LM head tie diagnostics failed: {e }")
try:
emb_module = model.get_input_embeddings()
head_module = model.get_output_embeddings()
if hasattr(emb_module, "weight") and hasattr(head_module, "weight"):
if (
emb_module.weight.shape == head_module.weight.shape
and emb_module.weight.data_ptr() != head_module.weight.data_ptr()
):
with torch.no_grad():
head_module.weight = emb_module.weight
logger.info(
"Force-tied LM head weight to input embeddings (pointer share)."
)
except Exception as e:
logger.warning(f"Force-tie of LM head failed: {e }")
try:
special_names = ["speech_start_id", "speech_diffusion_id", "speech_end_id"]
try:
vocab_size = int(getattr(model.config.decoder_config, "vocab_size", 0))
except Exception:
vocab_size = 0
in_emb_mod = model.get_input_embeddings()
out_emb_mod = model.get_output_embeddings()
in_w = getattr(in_emb_mod, "weight", None)
out_w = getattr(out_emb_mod, "weight", None)
for name in special_names:
val = getattr(tok, name, None)
exists = val is not None
in_range = exists and isinstance(val, int) and (0 <= val < vocab_size)
equal_row = None
if (
in_range
and in_w is not None
and (out_w is not None)
and (in_w.shape == out_w.shape)
and (in_w.size(0) > val)
):
try:
equal_row = bool(torch.allclose(in_w[val], out_w[val]))
except Exception:
equal_row = False
decoded_str = None
if exists and isinstance(val, int):
try:
decoded_str = tok.decode([val])
except Exception:
try:
decoded_str = tok.convert_ids_to_tokens(val)
except Exception:
decoded_str = "<decode_failed>"
logger.info(
f"Special token check -> {name }={val }, decoded='{decoded_str }', exists={exists }, in_vocab_range={in_range }, emb_vs_head_row_equal={equal_row }"
)
except Exception as e:
logger.warning(f"Special token ID/row validation failed: {e }")
try:
logger.info("=== TOKENIZER DIAGNOSTICS ===")
logger.info(f"Tokenizer class: {type (tok ).__name__ }")
logger.info(f"Tokenizer vocab_size: {tok .vocab_size }")
with torch.no_grad():
simple_text = "The cat sat on the mat."
simple_ids = torch.tensor(
[tok.encode(simple_text, add_special_tokens=True)], device=model.device
)
simple_mask = torch.ones_like(simple_ids)
x = model.get_input_embeddings()(simple_ids)
outputs = model.model(
inputs_embeds=x, attention_mask=simple_mask, return_dict=True
)
logits = model.lm_head(outputs.last_hidden_state)
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = simple_ids[:, 1:].contiguous()
ce_loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
reduction="mean",
)
logger.info(f"Simple text CE loss: {ce_loss .item ():.4f}")
except Exception as e:
logger.warning(f"Tokenizer diagnostics failed: {e }")
if hasattr(model.config, "use_cache") and training_args.do_train:
model.config.use_cache = False
if model_args.freeze_acoustic_tokenizer and hasattr(
model.model, "acoustic_tokenizer"
):
for p in model.model.acoustic_tokenizer.parameters():
p.requires_grad = False
if model_args.freeze_semantic_tokenizer and hasattr(
model.model, "semantic_tokenizer"
):
for p in model.model.semantic_tokenizer.parameters():
p.requires_grad = False
lora_cfg = build_lora_config(model_args)
tm_lower = [
s.strip().lower()
for s in model_args.lora_target_modules.split(",")
if s.strip()
]
skip_lm_lora = len(tm_lower) == 0 or all(
(t in ("none", "off", "disable", "disabled") for t in tm_lower)
)
if not skip_lm_lora:
model.model.language_model = get_peft_model(
model.model.language_model, lora_cfg
)
else:
logger.info("Skipping LLM LoRA wrapping (lora_target_modules indicates none).")
try:
model.tie_weights()
except Exception:
pass
for _, p in model.named_parameters():
p.requires_grad = False
try:
for n, p in model.model.language_model.named_parameters():
if "lora_A" in n or "lora_B" in n:
p.requires_grad = True
except Exception:
logger.warning("Could not re-enable LoRA params on language_model.")
if getattr(model_args, "lora_wrap_diffusion_head", False) and hasattr(
model.model, "prediction_head"
):
class _HeadForwardShim(nn.Module):
    """Adapter that forwards ``(noisy_images, timesteps, condition)`` to ``base``.

    The three inputs may be given positionally, by keyword, or as a mix of
    both; they are always passed on to ``base`` positionally.  Any input that
    is not supplied is forwarded as ``None``; extra positional arguments
    beyond the first three are ignored.
    """

    _ARG_NAMES = ("noisy_images", "timesteps", "condition")

    def __init__(self, base: nn.Module):
        super().__init__()
        self.base = base

    def forward(self, *args, **kwargs):
        # Bind leading positionals first, then fill the rest from keywords so
        # that calls such as ``head(x, t, condition=c)`` keep ``x`` and ``t``.
        bound = dict(zip(self._ARG_NAMES, args))
        for name in self._ARG_NAMES:
            if name not in bound:
                bound[name] = kwargs.get(name)
        return self.base(
            bound["noisy_images"], bound["timesteps"], bound["condition"]
        )
try:
shim = _HeadForwardShim(model.model.prediction_head)
model.model.prediction_head = get_peft_model(
shim, build_head_lora_config(model_args)
)
for n, p in model.model.prediction_head.named_parameters():
if "lora_A" in n or "lora_B" in n:
p.requires_grad = True
except Exception as e:
logger.warning(f"Could not LoRA-wrap diffusion head: {e }")
if getattr(model_args, "train_diffusion_head", False) and hasattr(
model.model, "prediction_head"
):
for p in model.model.prediction_head.parameters():
p.requires_grad = True
if model_args.layers_to_freeze is not None and hasattr(
model.model, "prediction_head"
):
head_params = list(model.model.prediction_head.named_parameters())
try:
indices_to_freeze = {
int(x.strip())
for x in model_args.layers_to_freeze.split(",")
if x.strip()
}
frozen_count = 0
for i, (name, param) in enumerate(head_params):
if i in indices_to_freeze:
param.requires_grad = False
frozen_count += 1
logger.info(f"Froze layer [{i }]: {name }")
logger.info(
f"Successfully froze {frozen_count } parameter groups in the diffusion head."
)
except Exception as e:
logger.error(f"Could not parse --layers_to_freeze: {e }")
raise
if getattr(model_args, "train_connectors", False):
if hasattr(model.model, "acoustic_connector"):
for p in model.model.acoustic_connector.parameters():
p.requires_grad = True
if hasattr(model.model, "semantic_connector"):
for p in model.model.semantic_connector.parameters():
p.requires_grad = True
else:
if hasattr(model.model, "acoustic_connector"):
for p in model.model.acoustic_connector.parameters():
p.requires_grad = False
if hasattr(model.model, "semantic_connector"):
for p in model.model.semantic_connector.parameters():
p.requires_grad = False
try:
emb = model.get_input_embeddings()
if hasattr(emb, "weight"):
emb.weight.requires_grad_(False)
head = model.get_output_embeddings()
if head is not None and hasattr(head, "weight"):
head.weight.requires_grad_(False)
except Exception:
pass
def _sum_params(named_iter):
    """Count elements of the trainable parameters in ``(name, param)`` pairs."""
    trainable_elements = 0
    for _, param in named_iter:
        if param.requires_grad:
            trainable_elements += param.numel()
    return trainable_elements
try:
lm_lora = (
_sum_params(model.model.language_model.named_parameters())
if hasattr(model.model, "language_model")
else 0
)
pred_head_train = (
_sum_params(model.model.prediction_head.named_parameters())
if hasattr(model.model, "prediction_head")
else 0
)
ac_conn_train = (
_sum_params(model.model.acoustic_connector.named_parameters())
if hasattr(model.model, "acoustic_connector")
else 0
)
se_conn_train = (
_sum_params(model.model.semantic_connector.named_parameters())
if hasattr(model.model, "semantic_connector")
else 0
)
total_trainable = sum(
(p.numel() for p in model.parameters() if p.requires_grad)
)
logger.info(
f"Trainable by block -> LLM-LoRA: {lm_lora :,} | diff_head: {pred_head_train :,} | ac_conn: {ac_conn_train :,} | se_conn: {se_conn_train :,}"
)
logger.info("TOTAL trainable: %s", f"{total_trainable :,}")
except Exception:
pass
verification_mode = (
VerificationMode.NO_CHECKS
if data_args.ignore_verifications
else VerificationMode.BASIC_CHECKS
)
if data_args.train_jsonl is not None:
data_files: Dict[str, str] = {"train": data_args.train_jsonl}
if data_args.validation_jsonl is not None:
data_files["validation"] = data_args.validation_jsonl
raw = load_dataset(
"json",
data_files=data_files,
verification_mode=verification_mode,
cache_dir=model_args.cache_dir,
)
else:
if data_args.dataset_name is None:
raise ValueError(
"Provide --dataset_name (HF datasets) or use --train_jsonl/--validation_jsonl for local files."
)
raw = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
verification_mode=verification_mode,
cache_dir=model_args.cache_dir,
)
train_ds = raw[data_args.train_split_name]
eval_ds = None
if training_args.do_eval:
if data_args.eval_split_name and data_args.eval_split_name in raw:
eval_ds = raw[data_args.eval_split_name]
elif (
data_args.eval_split_size
and data_args.eval_split_size > 0
and (len(train_ds) > 1)
):
split = train_ds.train_test_split(
test_size=data_args.eval_split_size, seed=training_args.seed
)
train_ds, eval_ds = (split["train"], split["test"])
train_dataset = QWEN3VoxDataset(
train_ds,
text_column=data_args.text_column_name,
audio_column=data_args.audio_column_name,
voice_prompts_column=data_args.voice_prompts_column_name,
)
eval_dataset = None
if eval_ds is not None:
eval_dataset = QWEN3VoxDataset(
eval_ds,
text_column=data_args.text_column_name,
audio_column=data_args.audio_column_name,
voice_prompts_column=data_args.voice_prompts_column_name,
)
speech_compress_ratio = getattr(processor, "speech_tok_compress_ratio", 3200)
semantic_dim = getattr(model.config, "semantic_vae_dim", None)
if semantic_dim is None:
try:
semantic_dim = int(
getattr(model.config.semantic_tokenizer_config, "vae_dim", 128)
)
except Exception:
semantic_dim = 128
compute_semantics_flag = (
hasattr(processor, "semantic_tokenizer")
and processor.semantic_tokenizer is not None
)
data_collator = QWEN3VoxCollator(
processor=processor,
max_length=data_args.max_length,
speech_compress_ratio=speech_compress_ratio,
semantic_vae_dim=semantic_dim,
compute_semantics=compute_semantics_flag,
debug_checks=False,
voice_prompt_drop_rate=data_args.voice_prompt_drop_rate,
)
class LoRADebugCallback(TrainerCallback):
    """Log whether LoRA adapter matrices exist, are trainable, and change.

    Parameters are treated as LoRA weights when their name contains
    ``lora_A`` or ``lora_B``.  Change detection compares each parameter's norm
    with the norm recorded at the previous report.  All failures are logged as
    warnings and never interrupt training.
    """

    def __init__(self, log_every_n_steps: int = 50):
        self.log_every_n_steps = max(1, int(log_every_n_steps))
        self.prev_param_norms: Dict[str, float] = {}
        self.lora_param_names: List[str] = []

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Record initial norms and report how many LoRA params were found."""
        try:
            if model is None:
                return
            params: Dict[str, torch.nn.Parameter] = dict(model.named_parameters())
            self.lora_param_names = [
                name for name in params if "lora_A" in name or "lora_B" in name
            ]
            for name in self.lora_param_names:
                self.prev_param_norms[name] = float(params[name].data.norm().item())

            total = len(self.lora_param_names)
            req_grad = 0
            num_A = 0
            num_B = 0
            zero_B = 0
            for name in self.lora_param_names:
                param = params[name]
                if param.requires_grad:
                    req_grad += 1
                if "lora_A" in name:
                    num_A += 1
                if "lora_B" in name:
                    num_B += 1
                    if float(param.data.norm().item()) == 0.0:
                        zero_B += 1

            logger.info(
                f"LoRA debug: found {total} LoRA params (A={num_A}, B={num_B}); trainable={req_grad}. Initial lora_B_zero={zero_B}."
            )
            if total == 0:
                logger.warning(
                    "LoRA debug: No LoRA parameters found. Check lora_target_modules."
                )
            if req_grad != total:
                logger.warning(
                    "LoRA debug: Some LoRA params are frozen. They should be trainable."
                )
        except Exception as e:
            logger.warning(f"LoRA debug (on_train_begin) failed: {e}")

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Report, on step 1 and every ``log_every_n_steps``, which norms moved."""
        try:
            if model is None or len(self.lora_param_names) == 0:
                return
            step = int(getattr(state, "global_step", 0) or 0)
            is_report_step = step % self.log_every_n_steps == 0 or step == 1
            if not is_report_step:
                return

            params: Dict[str, torch.nn.Parameter] = dict(model.named_parameters())
            tolerance = 1e-12
            changed_A = 0
            changed_B = 0
            zero_B = 0
            for name in self.lora_param_names:
                param = params.get(name, None)
                if param is None:
                    continue
                before = self.prev_param_norms.get(name, 0.0)
                now = float(param.data.norm().item())
                moved = abs(now - before) > tolerance
                if "lora_A" in name and moved:
                    changed_A += 1
                if "lora_B" in name:
                    if moved:
                        changed_B += 1
                    if now == 0.0:
                        zero_B += 1
                self.prev_param_norms[name] = now

            total_A = 0
            total_B = 0
            for name in self.lora_param_names:
                if "lora_A" in name:
                    total_A += 1
                if "lora_B" in name:
                    total_B += 1
            logger.info(
                f"LoRA debug step {step}: changed A {changed_A}/{total_A}, changed B {changed_B}/{total_B}, lora_B_zero_now={zero_B}."
            )
        except Exception as e:
            logger.warning(f"LoRA debug (on_step_end) failed: {e}")
class QWEN3VoxTrainer(Trainer):
def training_forward(
    self, model: QWEN3VoxForConditionalGeneration, inputs: Dict[str, Any]
):
    """Run the language model and the diffusion head for one training batch.

    Steps, in order:
      1. Embed ``input_ids`` and overwrite the positions selected by
         ``acoustic_input_mask`` with speech connector features (acoustic plus
         semantic when the semantic connector returns a value).
      2. Run ``model.model`` on the resulting embeddings and project the last
         hidden state through ``model.lm_head``.
      3. If there are speech latents selected for the loss, compute the
         diffusion loss of ``model.model.prediction_head`` on noised latents,
         conditioned on the hidden state one position *before* each position
         flagged in ``acoustic_loss_mask``.
      4. Otherwise return a zero-valued diffusion loss that still references
         the head/connector parameters (presumably so they stay part of the
         autograd graph for distributed training -- TODO confirm).

    Args:
        model: The model being trained.
        inputs: Batch dictionary; the keys read here are visible in the
            ``inputs.get(...)`` calls below.

    Returns:
        ``QWEN3VoxCausalLMOutputWithPast`` with ``loss=None`` (the CE loss is
        computed by the caller), ``diffusion_loss``, ``speech_token_num`` and
        ``logits``.

    Raises:
        NotImplementedError: For a diffusion ``prediction_type`` other than
            ``"epsilon"`` or ``"v_prediction"``.
    """
    input_ids = inputs.get("input_ids")
    attention_mask = inputs.get("attention_mask")
    position_ids = inputs.get("position_ids")
    past_key_values = inputs.get("past_key_values")
    use_cache = inputs.get("use_cache", False)
    output_attentions = inputs.get("output_attentions")
    return_dict = inputs.get("return_dict", True)
    cache_position = inputs.get("cache_position")
    speech_tensors = inputs.get("speech_tensors")
    speech_masks = inputs.get("speech_masks")
    speeches_loss_input = inputs.get("speeches_loss_input")
    speech_semantic_tensors = inputs.get("speech_semantic_tensors")
    acoustic_input_mask = inputs.get("acoustic_input_mask")
    acoustic_loss_mask = inputs.get("acoustic_loss_mask")
    ddpm_batch_mul = training_args.ddpm_batch_mul
    speech_type = "audio"

    x = model.get_input_embeddings()(input_ids)
    semantic_speech_all_connect_features = model.model.semantic_connector(
        speech_semantic_tensors
    )
    # Speech input cast to the embedding dtype/device (or None when absent).
    speech_input = speech_tensors.type_as(x) if speech_tensors is not None else None

    if speeches_loss_input is not None:
        speech_all_features, speech_all_connect_features = (
            model.forward_speech_features(
                speech_tensors=speech_input,
                speech_masks=speech_masks,
                speech_type=speech_type,
                return_unmask=True,
            )
        )
        if speech_tensors is not None:
            if semantic_speech_all_connect_features is not None:
                x[acoustic_input_mask] = (
                    speech_all_connect_features[speech_masks]
                    + semantic_speech_all_connect_features[speech_masks]
                )
            else:
                x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
        # Only the latents selected for the loss feed the diffusion objective.
        loss_selector = speeches_loss_input & speech_masks
        speech_features = speech_all_features[loss_selector]
        speech_connect_features = speech_all_connect_features[loss_selector]
        # Diagnostic only (the original check was silently swallowed). Guarded
        # by the log level so the ``.item()`` sync is skipped in normal runs.
        if acoustic_input_mask is not None and logger.isEnabledFor(logging.DEBUG):
            n_inputs = int(acoustic_input_mask.sum().item())
            if speech_connect_features.shape[0] != n_inputs:
                logger.debug(
                    f"Mismatch between selected speech connectors ({speech_connect_features.shape[0]}) and acoustic_input_mask sum ({n_inputs})"
                )
    else:
        speech_features, speech_connect_features = model.forward_speech_features(
            speech_tensors=speech_input,
            speech_masks=speech_masks,
            speech_type=speech_type,
        )
        if speech_tensors is not None:
            x[acoustic_input_mask] = speech_connect_features

    outputs = model.model(
        input_ids=None,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=x,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=False,
        return_dict=return_dict,
        cache_position=cache_position,
    )
    hidden_states = outputs.last_hidden_state
    logits = model.lm_head(hidden_states)

    loss = None
    diffusion_loss = None
    # Defined up front: it is read at the return even when the diffusion
    # branch below is not taken.
    speech_len = 0

    has_diffusion_targets = (
        speech_tensors is not None
        and acoustic_loss_mask is not None
        and acoustic_loss_mask.sum().item() > 0
    )
    if has_diffusion_targets:
        # Shift the loss mask left by one: the hidden state at position t-1
        # conditions the latent at position t. Position 0 is never used.
        cond_mask = torch.zeros_like(acoustic_loss_mask, dtype=torch.bool)
        cond_mask[:, :-1] = acoustic_loss_mask[:, 1:]
        cond_mask[:, 0] = False
        condition_features = hidden_states[cond_mask]
        speech_len, latent_size = speech_features.shape
        if condition_features.shape[0] != speech_len:
            # Best-effort diagnostic; the original swallowed this check.
            logger.warning(
                f"Mismatch: condition_features={condition_features.shape[0]} vs speech_features={speech_len}"
            )

        num_draws = speech_len * ddpm_batch_mul
        noise = torch.randn(
            (num_draws, latent_size),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        # Uniform draw over the scheduler's training timesteps.
        timesteps = torch.multinomial(
            torch.ones(model.config.diffusion_head_config.ddpm_num_steps),
            num_draws,
            replacement=True,
        ).to(hidden_states.device)
        speech_features_repeated = speech_features.repeat_interleave(
            ddpm_batch_mul, dim=0
        )
        condition_features_repeated = condition_features.repeat_interleave(
            ddpm_batch_mul, dim=0
        )
        noisy_speech_features = model.model.noise_scheduler.add_noise(
            speech_features_repeated, noise, timesteps
        )
        model_output = model.model.prediction_head(
            noisy_speech_features,
            timesteps.type_as(x),
            condition_features_repeated,
        )

        prediction_type = model.config.diffusion_head_config.prediction_type
        if prediction_type == "epsilon":
            target_for_loss = noise
        elif prediction_type == "v_prediction":
            target_for_loss = model.model.noise_scheduler.get_velocity(
                speech_features_repeated, noise, timesteps
            )
        else:
            raise NotImplementedError(
                f"Prediction type {prediction_type} not implemented"
            )

        diffusion_loss = F.mse_loss(
            model_output.float(), target_for_loss.float(), reduction="sum"
        )
        if latent_size > 0 and ddpm_batch_mul > 0:
            # Normalise the summed error per latent dim, per draw, per latent.
            diffusion_loss = (
                diffusion_loss / latent_size / ddpm_batch_mul / max(speech_len, 1)
            )
        else:
            diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
    else:
        # No diffusion targets: build a zero that still touches the head and
        # connector parameters.
        diffusion_loss = (
            sum((p.sum() for p in model.model.prediction_head.parameters())) * 0.0
        )
        diffusion_loss += (
            sum((p.sum() for p in model.model.acoustic_connector.parameters()))
            * 0.0
        )
        diffusion_loss += (
            sum((p.sum() for p in model.model.semantic_connector.parameters()))
            * 0.0
        )

    return QWEN3VoxCausalLMOutputWithPast(
        loss=loss,
        diffusion_loss=diffusion_loss,
        speech_token_num=speech_len if speech_tensors is not None else 0,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
def compute_loss(
    self,
    model: QWEN3VoxForConditionalGeneration,
    inputs: Dict[str, Any],
    return_outputs=False,
    num_items_in_batch: Optional[int] = None,
):
    """Combine the text cross-entropy loss with the diffusion loss.

    The semantic speech tensor in ``inputs`` is first normalised (cast to the
    semantic connector's dtype, or replaced by zeros when missing but
    ``speech_masks`` is present).  After the forward pass, the CE loss over
    non-acoustic, attended tokens and the diffusion loss are summed using the
    weights from the training arguments.  Several debug metrics are logged on
    a best-effort basis; failures there are ignored.

    Returns:
        ``(total, outputs)`` when ``return_outputs`` is true, else ``total``.
    """
    labels = inputs.get("input_ids")
    attention_mask = inputs.get("attention_mask")
    acoustic_input_mask = inputs.get("acoustic_input_mask")

    # --- make sure the semantic tensor exists and has the right dtype -----
    semantic = inputs.get("speech_semantic_tensors", None)
    try:
        semantic_dtype = next(model.model.semantic_connector.parameters()).dtype
    except Exception:
        semantic_dtype = model.get_input_embeddings().weight.dtype
    if semantic is None:
        speech_masks = inputs.get("speech_masks")
        if speech_masks is not None:
            inputs["speech_semantic_tensors"] = torch.zeros(
                speech_masks.size(0),
                speech_masks.size(1),
                getattr(model.config, "semantic_vae_dim", 128),
                dtype=semantic_dtype,
                device=speech_masks.device,
            )
    elif isinstance(semantic, torch.Tensor):
        inputs["speech_semantic_tensors"] = semantic.to(dtype=semantic_dtype)

    outputs = self.training_forward(model, inputs)

    # --- best-effort bookkeeping of how many positions feed each loss -----
    try:
        loss_mask = inputs.get("acoustic_loss_mask")
        latent_mask = inputs.get("speech_masks")
        latent_loss_sel = inputs.get("speeches_loss_input")

        def _count(mask):
            return int(mask.sum().item()) if mask is not None else 0

        num_tok_total = _count(acoustic_input_mask)
        num_tok_loss = _count(loss_mask)
        num_lat_total = _count(latent_mask)
        have_both_latent_masks = (
            latent_loss_sel is not None and latent_mask is not None
        )
        if have_both_latent_masks:
            num_lat_loss = int((latent_loss_sel & latent_mask).sum().item())
        else:
            num_lat_loss = 0

        self.log(
            {
                "debug/num_tok_total": float(num_tok_total),
                "debug/num_tok_loss": float(num_tok_loss),
                "debug/num_lat_total": float(num_lat_total),
                "debug/num_lat_loss": float(num_lat_loss),
            }
        )
        if have_both_latent_masks and loss_mask is not None:
            if num_tok_loss != num_lat_loss:
                logger.warning(
                    f"Loss selection mismatch: acoustic_loss_mask={num_tok_loss} vs speeches_loss_input={num_lat_loss}"
                )
    except Exception:
        pass

    # --- text cross-entropy ------------------------------------------------
    ce_labels = mask_for_ce(
        labels, attention_mask, acoustic_input_mask, pad_id=-100
    )
    shift_logits = outputs.logits[:, :-1, :].contiguous()
    flat_logits = shift_logits.view(-1, shift_logits.size(-1))
    ce_loss = nn.CrossEntropyLoss(ignore_index=-100)(flat_logits, ce_labels.view(-1))
    try:
        self._debug_ce(shift_logits, ce_labels, attention_mask, acoustic_input_mask)
    except Exception as e:
        logger.warning(f"Failed invoking CE debug: {e}")

    # --- weighted sum with the diffusion term --------------------------------
    if outputs.diffusion_loss is not None:
        diffusion_loss = outputs.diffusion_loss
    else:
        diffusion_loss = torch.tensor(0.0, device=ce_loss.device)
    total = (
        training_args.ce_loss_weight * ce_loss
        + training_args.diffusion_loss_weight * diffusion_loss
    )

    # --- best-effort metric logging -------------------------------------------
    try:
        prefix = "train" if model.training else "eval"
        if isinstance(diffusion_loss, torch.Tensor):
            diffusion_value = diffusion_loss.detach().item()
        else:
            diffusion_value = float(diffusion_loss)
        self.log(
            {
                f"{prefix}/ce_loss": ce_loss.detach().item(),
                f"{prefix}/diffusion_loss": diffusion_value,
            }
        )
        optimizer = self.optimizer if hasattr(self, "optimizer") else None
        if optimizer is not None and len(optimizer.param_groups) > 0:
            lr_val = optimizer.param_groups[0].get("lr", None)
            if lr_val is not None:
                self.log({"train/learning_rate_real": float(lr_val)})
    except Exception:
        pass

    if return_outputs:
        return (total, outputs)
    return total
def _debug_ce(
    self,
    shift_logits: torch.Tensor,
    ce_labels: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    acoustic_input_mask: Optional[torch.Tensor],
):
    """Occasionally log cross-entropy diagnostics for the current batch.

    Nothing happens unless ``training_args.debug_ce_details`` is truthy. When
    it is, a report is emitted while ``global_step`` is at most 1 and then on
    every ``debug_ce_every_n_steps``-th step (default 200). The report holds
    the number of label positions different from ``-100``, their mean loss,
    and the mean loss of the first ``debug_ce_max_examples`` rows (default 1).

    ``attention_mask`` and ``acoustic_input_mask`` are accepted but not read.
    Any exception raised here is logged as a warning and swallowed.
    """
    try:
        if not getattr(training_args, "debug_ce_details", False):
            return
        current_step = int(getattr(self.state, "global_step", 0) or 0)
        interval = int(getattr(training_args, "debug_ce_every_n_steps", 200) or 200)
        interval = max(1, interval)
        if current_step > 1 and current_step % interval != 0:
            return

        def _masked_mean(losses, keep):
            # Mean over the kept positions, NaN when nothing is kept.
            if int(keep.sum().item()) > 0:
                return float(losses[keep].mean().item())
            return float("nan")

        with torch.no_grad():
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))
            token_losses = F.cross_entropy(
                flat_logits,
                ce_labels.view(-1),
                reduction="none",
                ignore_index=-100,
            ).view_as(ce_labels)
            counted = ce_labels.ne(-100)
            num_valid = int(counted.sum().item())
            avg_loss = _masked_mean(token_losses, counted)
            example_limit = max(
                1, int(getattr(training_args, "debug_ce_max_examples", 1) or 1)
            )
            shown_rows = min(ce_labels.size(0), example_limit)
            per_ex_avgs = [
                _masked_mean(token_losses[row], counted[row])
                for row in range(shown_rows)
            ]
        # NaN is the only float that differs from itself; print it as None.
        rounded = [None if value != value else round(value, 4) for value in per_ex_avgs]
        logger.info(
            f"CE debug: tokens_in_loss={num_valid}, avg_loss={avg_loss:.4f}, per_example_avgs={rounded}"
        )
    except Exception as e:
        logger.warning(f"CE detailed debug failed: {e}")
def _save(self, output_dir: Optional[str] = None, state_dict=None) -> None:
    """Write the trainable components under ``<output_dir>/lora``.

    Saved, in order:
      * ``language_model`` via ``save_pretrained`` into ``lora/``;
      * ``prediction_head`` via ``save_pretrained`` into ``lora/diffusion_head/``;
      * the full ``prediction_head`` state dict as ``diffusion_head_full.bin``
        in both ``lora/`` and ``lora/diffusion_head/``;
      * ``acoustic_connector`` and ``semantic_connector`` state dicts as
        ``pytorch_model.bin`` in their own sub-directories.

    Components that are missing, or that lack the required method, are
    skipped. Each component is saved independently: a failure is logged as a
    warning and the remaining components are still written. This method never
    raises.

    Args:
        output_dir: Target directory; falls back to ``self.args.output_dir``.
        state_dict: Accepted for signature compatibility; not used here.
    """
    try:
        target_dir = output_dir or self.args.output_dir
        lora_out = os.path.join(target_dir, "lora")
        os.makedirs(lora_out, exist_ok=True)
        core = self.model.model
    except Exception as e:
        # Without the output directory or the wrapped model nothing can be saved.
        logger.warning(f"Failed to save LoRA assets: {e}")
        return

    head_dir = os.path.join(lora_out, "diffusion_head")

    def save_language_model():
        language_model = getattr(core, "language_model", None)
        if hasattr(language_model, "save_pretrained"):
            language_model.save_pretrained(lora_out)

    def save_head_pretrained():
        pred_head = getattr(core, "prediction_head", None)
        if hasattr(pred_head, "save_pretrained"):
            os.makedirs(head_dir, exist_ok=True)
            pred_head.save_pretrained(head_dir)

    def save_head_full():
        pred_head = getattr(core, "prediction_head", None)
        if pred_head is not None and hasattr(pred_head, "state_dict"):
            weights = pred_head.state_dict()
            torch.save(weights, os.path.join(lora_out, "diffusion_head_full.bin"))
            os.makedirs(head_dir, exist_ok=True)
            torch.save(weights, os.path.join(head_dir, "diffusion_head_full.bin"))

    def save_connector(attr_name):
        connector = getattr(core, attr_name, None)
        if connector is not None:
            connector_dir = os.path.join(lora_out, attr_name)
            os.makedirs(connector_dir, exist_ok=True)
            torch.save(
                connector.state_dict(),
                os.path.join(connector_dir, "pytorch_model.bin"),
            )

    steps = (
        ("language_model", save_language_model),
        ("prediction_head", save_head_pretrained),
        ("prediction_head full state dict", save_head_full),
        ("acoustic_connector", lambda: save_connector("acoustic_connector")),
        ("semantic_connector", lambda: save_connector("semantic_connector")),
    )
    for label, step in steps:
        # One broken component must not prevent the others from being written.
        try:
            step()
        except Exception as e:
            logger.warning(f"Failed to save LoRA assets ({label}): {e}")
ema_cb = EmaCallback(attr_path="model.prediction_head", decay=0.999, device="cpu")
trainer = QWEN3VoxTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator,
callbacks=[
ema_cb,
LoRADebugCallback(
log_every_n_steps=int(getattr(training_args, "logging_steps", 50) or 50)
),
],
)
if getattr(training_args, "debug_save", False):
try:
debug_dir = os.path.join(training_args.output_dir, "debug_initial")
lora_out = os.path.join(debug_dir, "lora")
os.makedirs(lora_out, exist_ok=True)
logger.info(
f"[debug_save] Saving initial (pre-training) model components to: {debug_dir }"
)
try:
if hasattr(model.model.language_model, "save_pretrained"):
model.model.language_model.save_pretrained(lora_out)
except Exception as e_lm:
logger.warning(f"[debug_save] Failed to save language_model: {e_lm }")
try:
if hasattr(model.model, "prediction_head") and hasattr(
model.model.prediction_head, "save_pretrained"
):
model.model.prediction_head.save_pretrained(
os.path.join(lora_out, "diffusion_head")
)
except Exception as e_head:
logger.warning(
f"[debug_save] Failed to save prediction_head: {e_head }"
)
try:
ph = getattr(model.model, "prediction_head", None)
if ph is not None and hasattr(ph, "state_dict"):
sd = ph.state_dict()
torch.save(sd, os.path.join(lora_out, "diffusion_head_full.bin"))
os.makedirs(os.path.join(lora_out, "diffusion_head"), exist_ok=True)
torch.save(
sd,
os.path.join(
lora_out, "diffusion_head", "diffusion_head_full.bin"
),
)
except Exception as e:
logger.warning(f"[debug_save] Failed to save FULL diffusion head: {e }")
try:
ac_conn = getattr(model.model, "acoustic_connector", None)
if ac_conn is not None:
ac_dir = os.path.join(lora_out, "acoustic_connector")
os.makedirs(ac_dir, exist_ok=True)
torch.save(
ac_conn.state_dict(), os.path.join(ac_dir, "pytorch_model.bin")
)
except Exception as e_ac:
logger.warning(
f"[debug_save] Failed to save acoustic_connector: {e_ac }"
)
try:
se_conn = getattr(model.model, "semantic_connector", None)
if se_conn is not None:
se_dir = os.path.join(lora_out, "semantic_connector")
os.makedirs(se_dir, exist_ok=True)
torch.save(
se_conn.state_dict(), os.path.join(se_dir, "pytorch_model.bin")
)
except Exception as e_se:
logger.warning(
f"[debug_save] Failed to save semantic_connector: {e_se }"
)
except Exception as e:
logger.warning(
f"[debug_save] Unexpected failure saving initial components: {e }"
)
if getattr(training_args, "gradient_checkpointing", False):
try:
model.gradient_checkpointing_enable()
except Exception:
logger.warning("Failed to enable gradient checkpointing on the model.")
if training_args.do_train:
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
lora_out = os.path.join(training_args.output_dir, "lora")
os.makedirs(lora_out, exist_ok=True)
lm = getattr(model.model, "language_model", None)
if hasattr(lm, "save_pretrained"):
lm.save_pretrained(lora_out)
ph = getattr(model.model, "prediction_head", None)
if hasattr(ph, "save_pretrained"):
ph_dir = os.path.join(lora_out, "diffusion_head")
os.makedirs(ph_dir, exist_ok=True)
ph.save_pretrained(ph_dir)
try:
if ph is not None and hasattr(ph, "state_dict"):
sd = ph.state_dict()
torch.save(sd, os.path.join(lora_out, "diffusion_head_full.bin"))
ph_dir = os.path.join(lora_out, "diffusion_head")
os.makedirs(ph_dir, exist_ok=True)
torch.save(sd, os.path.join(ph_dir, "diffusion_head_full.bin"))
except Exception as e:
logger.warning(f"Failed to save FULL diffusion head at end: {e }")
try:
ac = getattr(model.model, "acoustic_connector", None)
if ac is not None:
ac_dir = os.path.join(lora_out, "acoustic_connector")
os.makedirs(ac_dir, exist_ok=True)
torch.save(ac.state_dict(), os.path.join(ac_dir, "pytorch_model.bin"))
except Exception as e:
logger.warning(f"Failed to save acoustic_connector: {e }")
try:
se = getattr(model.model, "semantic_connector", None)
if se is not None:
se_dir = os.path.join(lora_out, "semantic_connector")
os.makedirs(se_dir, exist_ok=True)
torch.save(se.state_dict(), os.path.join(se_dir, "pytorch_model.bin"))
except Exception as e:
logger.warning(f"Failed to save semantic_connector: {e }")
if training_args.do_eval and eval_dataset is not None:
trainer.evaluate()
# Script entry point: run the training driver only when executed directly.
# NOTE(review): this guard precedes further class definitions in the same
# file, so those names are not yet bound when main() runs from here —
# confirm main() does not depend on them.
if __name__ == "__main__":
    main()
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation import GenerationMixin
# Module-level logger obtained from the transformers logging helper.
logger = logging.get_logger(__name__)

# Backfill ``ALL_PARALLEL_STYLES`` when the installed transformers either does
# not define it or leaves it set to None.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
class QWEN3VoxASRPreTrainedModel(PreTrainedModel):
    """Base class hooking the ASR models into ``PreTrainedModel``.

    Declares the config class, the base-model prefix and the capability flags
    read by ``transformers``, and provides the default weight initializer.
    """

    config_class = QWEN3VoxASRConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize ``module`` in place.

        ``nn.Linear`` weights are drawn from N(0, std) and biases zeroed;
        ``nn.LayerNorm`` weights are set to 1 and biases zeroed. ``std`` is
        the ``initializer_range`` of ``config.language_model_config`` if
        present, else of ``config.decoder_config``, else 0.02. Other module
        types are left untouched.
        """
        std = 0.02
        for sub_name in ("language_model_config", "decoder_config"):
            sub_config = getattr(self.config, sub_name, None)
            if hasattr(sub_config, "initializer_range"):
                std = sub_config.initializer_range
                break

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm built with elementwise_affine=False (or bias=False)
            # has no weight and/or bias parameter to initialize.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
class QWEN3VoxASRModel(QWEN3VoxASRPreTrainedModel):
    """Backbone bundling the language model with the speech front-end.

    Holds the decoder language model, the acoustic and semantic tokenizers,
    and one ``SpeechConnector`` per tokenizer that projects its latent
    dimension to the language model hidden size.
    """

    def __init__(self, config):
        """Build all sub-modules from ``config``.

        The tokenizers and connectors are cast to ``config.torch_dtype``
        (given as a ``torch.dtype`` or as the name of an attribute of
        ``torch``), defaulting to ``torch.float32``. The language model is
        left in the dtype produced by ``AutoModel.from_config``.
        """
        super().__init__(config)
        configured_dtype = getattr(config, "torch_dtype", None)
        if configured_dtype is None:
            dtype = torch.float32
        elif isinstance(configured_dtype, str):
            dtype = getattr(torch, configured_dtype)
        else:
            dtype = configured_dtype

        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)
        self.acoustic_tokenizer = AutoModel.from_config(
            config.acoustic_tokenizer_config
        ).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(
            config.semantic_tokenizer_config
        ).to(dtype)
        self.acoustic_connector = SpeechConnector(
            config.acoustic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        self.semantic_connector = SpeechConnector(
            config.semantic_vae_dim, lm_config.hidden_size
        ).to(dtype)

    def get_input_embeddings(self):
        """Return the language model's token embedding.

        Uses ``language_model.embed_tokens`` when it exists. Otherwise the
        language model's ``fullmap`` is searched for the entry whose
        ``orig_name`` is ``"embed_tokens.weight"`` and the attribute stored
        under that entry's key is returned.

        Raises:
            AssertionError: If neither lookup finds the embedding.
        """
        if hasattr(self.language_model, "embed_tokens"):
            return self.language_model.embed_tokens
        for name, attr in self.language_model.fullmap.items():
            if attr.orig_name == "embed_tokens.weight":
                return getattr(self.language_model, name)
        # Raised explicitly: an ``assert`` is stripped under ``python -O`` and
        # the method would then silently return None.
        raise AssertionError("should not arrive here")

    def set_input_embeddings(self, value):
        """Replace ``language_model.embed_tokens`` with ``value``."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Install the given tokenizers and switch each non-None one to eval mode."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.train(False)
        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.train(False)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the language model on the given inputs.

        All arguments, including ``**kwargs``, are forwarded unchanged to
        ``self.language_model``. ``return_dict`` defaults to
        ``config.use_return_dict``.

        Returns:
            The language model's raw output when ``return_dict`` is false,
            otherwise a ``BaseModelOutputWithPast`` rebuilt from it.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        if not return_dict:
            return outputs
        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class QWEN3VoxASRForConditionalGeneration(QWEN3VoxASRPreTrainedModel, GenerationMixin):
    """``QWEN3VoxASRModel`` followed by a bias-free linear ``lm_head``.

    Speech is encoded by the acoustic and semantic tokenizers, projected by
    their connectors and summed; the result is written into the token
    embeddings at the positions selected by ``acoustic_input_mask`` before
    the language model runs.
    """

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        """Build the backbone and an ``lm_head`` cast to the configured dtype."""
        super().__init__(config)
        self.model = QWEN3VoxASRModel(config)
        self.vocab_size = config.decoder_config.vocab_size
        self.lm_head = nn.Linear(
            config.decoder_config.hidden_size, self.vocab_size, bias=False
        ).to(self._resolve_torch_dtype(config))
        self.post_init()

    @staticmethod
    def _resolve_torch_dtype(config):
        """Return ``config.torch_dtype`` as a ``torch.dtype`` (float32 if unset).

        A string value is looked up as an attribute of ``torch``.
        """
        configured = getattr(config, "torch_dtype", None)
        if configured is None:
            return torch.float32
        if isinstance(configured, str):
            return getattr(torch, configured)
        return configured

    def get_input_embeddings(self):
        """Return the backbone's input embedding."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the backbone's input embedding with ``value``."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the ``lm_head`` projection."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the ``lm_head`` projection."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone's language model."""
        self.model.language_model = decoder

    def get_decoder(self):
        """Return the backbone's language model."""
        return self.model.language_model

    def tie_weights(self):
        """Share the input embedding weight with ``lm_head`` when configured.

        Active only when ``config.decoder_config.tie_word_embeddings`` is
        truthy. If the input embedding object has no ``weight`` attribute it
        is assigned to ``lm_head.weight`` directly.
        """
        if not getattr(self.config.decoder_config, "tie_word_embeddings", False):
            return
        output_embeddings = self.get_output_embeddings()
        input_embeddings = self.get_input_embeddings()
        if hasattr(input_embeddings, "weight"):
            output_embeddings.weight = input_embeddings.weight
        else:
            output_embeddings.weight = input_embeddings

    def encode_speech(
        self,
        speech_tensors: torch.FloatTensor,
        speech_masks: Optional[torch.BoolTensor] = None,
        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
        streaming_segment_duration: float = 60.0,
    ):
        """Encode waveforms into language-model-space speech features.

        ``speech_tensors`` is cast to the configured dtype; a 1-D input is
        treated as a batch of one. Inputs longer than
        ``streaming_segment_duration`` seconds (at the hard-coded 22050
        samples per second) are encoded segment by segment with streaming
        caches; shorter inputs are encoded in one pass. The whole computation
        runs under ``torch.no_grad()``.

        Args:
            speech_tensors: Waveform(s), 1-D or 2-D (batch, samples).
            speech_masks: Optional boolean index applied to both feature
                tensors before they are summed.
            speech_semantic_tensors: Optional precomputed semantic latents,
                used only on the non-streaming path.
            streaming_segment_duration: Segment length in seconds.

        Returns:
            The sum of the acoustic and semantic connector outputs, restricted
            to ``speech_masks`` when given.

        Raises:
            ValueError: On the streaming path when the segment length in
                samples is not positive.
        """
        dtype = self._resolve_torch_dtype(self.config)
        speech_tensors = speech_tensors.to(dtype)
        if speech_tensors.ndim == 1:
            speech_tensors = speech_tensors.unsqueeze(0)
        batch_size, total_samples = speech_tensors.shape
        # NOTE(review): sample rate is hard-coded — confirm it matches the
        # rate the tokenizers expect.
        sample_rate = 22050
        segment_samples = int(streaming_segment_duration * sample_rate)
        use_streaming = total_samples > segment_samples

        with torch.no_grad():
            if not use_streaming:
                encoder_output = self.model.acoustic_tokenizer.encode(
                    speech_tensors.unsqueeze(1)
                )
                audio_tokens = encoder_output.sample(
                    dist_type=self.model.acoustic_tokenizer.std_dist_type
                )[0]
                acoustic_features = self.model.acoustic_connector(audio_tokens)
                if speech_semantic_tensors is not None:
                    semantic_features = self.model.semantic_connector(
                        speech_semantic_tensors
                    )
                else:
                    semantic_tokens = self.model.semantic_tokenizer.encode(
                        speech_tensors.unsqueeze(1)
                    ).mean
                    semantic_features = self.model.semantic_connector(semantic_tokens)
            else:
                acoustic_encoder_cache = QWEN3VoxTokenizerStreamingCache()
                semantic_encoder_cache = QWEN3VoxTokenizerStreamingCache()
                acoustic_mean_segments = []
                semantic_mean_segments = []
                sample_indices = torch.arange(batch_size, device=speech_tensors.device)

                def _iter_segments(total_length: int, segment_length: int):
                    # Yield consecutive (start, end) windows covering the input.
                    if segment_length <= 0:
                        raise ValueError("segment_length must be positive")
                    for start in range(0, total_length, segment_length):
                        end = min(start + segment_length, total_length)
                        if end > start:
                            yield (start, end)

                segments = list(_iter_segments(total_samples, segment_samples))
                num_segments = len(segments)
                for seg_idx, (start, end) in enumerate(segments):
                    chunk = speech_tensors[:, start:end].contiguous()
                    if chunk.numel() == 0:
                        continue
                    is_final = seg_idx == num_segments - 1
                    acoustic_encoder_output = self.model.acoustic_tokenizer.encode(
                        chunk.unsqueeze(1),
                        cache=acoustic_encoder_cache,
                        sample_indices=sample_indices,
                        use_cache=True,
                        is_final_chunk=is_final,
                    )
                    acoustic_mean_segments.append(acoustic_encoder_output.mean)
                    semantic_encoder_output = self.model.semantic_tokenizer.encode(
                        chunk.unsqueeze(1),
                        cache=semantic_encoder_cache,
                        sample_indices=sample_indices,
                        use_cache=True,
                        is_final_chunk=is_final,
                    )
                    semantic_mean_segments.append(semantic_encoder_output.mean)

                # Stitch the per-segment means together, then sample once
                # over the full-length acoustic sequence.
                acoustic_mean_full = torch.cat(
                    acoustic_mean_segments, dim=1
                ).contiguous()
                acoustic_encoder_output = QWEN3VoxTokenizerEncoderOutput(
                    mean=acoustic_mean_full, std=self.model.acoustic_tokenizer.fix_std
                )
                audio_tokens = acoustic_encoder_output.sample(
                    dist_type=self.model.acoustic_tokenizer.std_dist_type
                )[0]
                acoustic_features = self.model.acoustic_connector(audio_tokens)
                semantic_tokens = torch.cat(semantic_mean_segments, dim=1).contiguous()
                semantic_features = self.model.semantic_connector(semantic_tokens)

            if speech_masks is not None:
                combined_features = (
                    acoustic_features[speech_masks] + semantic_features[speech_masks]
                )
            else:
                combined_features = acoustic_features + semantic_features
        return combined_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        speech_tensors: Optional[torch.FloatTensor] = None,
        speech_masks: Optional[torch.BoolTensor] = None,
        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
        acoustic_input_mask: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutput]:
        """Run the model, optionally injecting speech and computing the LM loss.

        Token embeddings are computed from ``input_ids`` when
        ``inputs_embeds`` is not given. When both ``speech_tensors`` and
        ``acoustic_input_mask`` are given, the encoded speech features
        overwrite the embeddings at the masked positions. When ``labels`` is
        given, a next-token cross-entropy loss is computed with the default
        ``nn.CrossEntropyLoss`` settings. ``**kwargs`` is accepted but not
        forwarded to the backbone.

        Returns:
            A ``QWEN3VoxCausalLMOutputWithPast`` when ``return_dict`` is true,
            otherwise a tuple ``(loss?, logits, *backbone_outputs[1:])``.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if use_cache is None:
            use_cache = self.config.use_cache

        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if speech_tensors is not None and acoustic_input_mask is not None:
            speech_features = self.encode_speech(
                speech_tensors=speech_tensors,
                speech_masks=speech_masks,
                speech_semantic_tensors=speech_semantic_tensors,
            )
            # Indexed assignment needs matching dtype and device; the speech
            # path uses the configured dtype, which may differ from the
            # embedding's.
            inputs_embeds[acoustic_input_mask] = speech_features.to(
                device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return QWEN3VoxCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        speech_tensors=None,
        speech_masks=None,
        speech_semantic_tensors=None,
        acoustic_input_mask=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Drops the part of ``input_ids`` already covered by the cache, derives
        ``position_ids`` from ``attention_mask`` and ``cache_position`` from
        the cache length when they are not supplied, and passes the speech
        inputs through only on the step whose ``cache_position`` starts at 0;
        on later steps they are set to None. The cache may be a cache object
        exposing ``get_seq_length()`` or a legacy tuple of per-layer
        key/value tensors.

        Returns:
            A dict of model inputs; entries of ``**kwargs`` are merged last.
        """
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, tuple):
                past_length = past_key_values[0][0].shape[2]
            else:
                past_length = past_key_values.get_seq_length()
            if input_ids is not None and input_ids.shape[1] > past_length:
                input_ids = input_ids[:, past_length:]

        if position_ids is None and attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values is not None and input_ids is not None:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        if cache_position is None:
            # Reuse past_length so legacy tuple caches, which have no
            # get_seq_length(), are handled here too.
            new_inputs = input_ids if input_ids is not None else inputs_embeds
            cache_position = torch.arange(
                past_length,
                past_length + new_inputs.shape[1],
                device=new_inputs.device,
            )

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )

        is_first_step = (
            cache_position is not None
            and len(cache_position) > 0
            and (cache_position[0] == 0)
        )
        if is_first_step:
            model_inputs.update(
                {
                    "speech_tensors": speech_tensors,
                    "speech_masks": speech_masks,
                    "speech_semantic_tensors": speech_semantic_tensors,
                    "acoustic_input_mask": acoustic_input_mask,
                }
            )
        else:
            model_inputs.update(
                {
                    "speech_tensors": None,
                    "speech_masks": None,
                    "speech_semantic_tensors": None,
                    "acoustic_input_mask": None,
                }
            )
        model_inputs.update(kwargs)
        return model_inputs
# Register the ASR config with the Auto* factories so that AutoModel and
# AutoModelForCausalLM map QWEN3VoxASRConfig to the classes defined above.
AutoModel.register(QWEN3VoxASRConfig, QWEN3VoxASRModel)
AutoModelForCausalLM.register(QWEN3VoxASRConfig, QWEN3VoxASRForConditionalGeneration)
# Public names of the ASR modeling section.
__all__ = [
'QWEN3VoxASRPreTrainedModel',
'QWEN3VoxASRModel',
'QWEN3VoxASRForConditionalGeneration',
]
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.generation import (
GenerationMixin,
GenerationConfig,
LogitsProcessor,
LogitsProcessorList,
StoppingCriteriaList,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
# Module-level logger obtained from the transformers logging helper.
logger = logging.get_logger(__name__)

# Backfill ``ALL_PARALLEL_STYLES`` when the installed transformers either does
# not define it or leaves it set to None.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
class QWEN3VoxTokenConstraintProcessor(LogitsProcessor):
    """Logits processor that restricts generation to a fixed set of token ids."""

    def __init__(self, valid_token_ids: List[int], device: torch.device = None):
        """Store ``valid_token_ids`` as a long tensor on ``device``."""
        self.valid_token_ids = torch.tensor(
            valid_token_ids, dtype=torch.long, device=device
        )

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Add ``-inf`` to every score column that is not an allowed token id.

        ``input_ids`` is not read. A new tensor is returned; ``scores`` is
        left unmodified.
        """
        # Additive penalty: 0 for allowed columns, -inf everywhere else.
        penalty = torch.full_like(scores, float("-inf"))
        penalty[:, self.valid_token_ids] = 0
        return scores + penalty
class QWEN3VoxForConditionalGenerationInference(QWEN3VoxPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
    """Build the backbone, a bias-free ``lm_head`` and the DDPM step count.

    ``ddpm_inference_steps`` starts at
    ``config.diffusion_head_config.ddpm_num_inference_steps``.
    """
    super().__init__(config)
    self.model = QWEN3VoxModel(config)
    decoder_config = config.decoder_config
    hidden_size = decoder_config.hidden_size
    vocab_size = decoder_config.vocab_size
    self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
    head_config = config.diffusion_head_config
    self.ddpm_inference_steps = head_config.ddpm_num_inference_steps
    self.post_init()
@property
def noise_scheduler(self):
    """Return ``noise_scheduler`` from the wrapped ``self.model``."""
    return self.model.noise_scheduler
@property
def prediction_head(self):
    """Return ``prediction_head`` from the wrapped ``self.model``."""
    return self.model.prediction_head
@property
def speech_scaling_factor(self):
    """Return ``speech_scaling_factor`` from the wrapped ``self.model``."""
    return self.model.speech_scaling_factor
@property
def speech_bias_factor(self):
    """Return ``speech_bias_factor`` from the wrapped ``self.model``."""
    return self.model.speech_bias_factor
@property
def acoustic_tokenizer(self):
    """Return ``acoustic_tokenizer`` from the wrapped ``self.model``."""
    return self.model.acoustic_tokenizer
@property
def semantic_tokenizer(self):
    """Return ``semantic_tokenizer`` from the wrapped ``self.model``."""
    return self.model.semantic_tokenizer
@property
def acoustic_connector(self):
    """Return ``acoustic_connector`` from the wrapped ``self.model``."""
    return self.model.acoustic_connector
@property
def semantic_connector(self):
    """Return ``semantic_connector`` from the wrapped ``self.model``."""
    return self.model.semantic_connector
def tie_weights(self):
    """Point ``lm_head.weight`` at the input embedding weight when enabled.

    Active only when ``config.tie_word_embeddings`` is truthy and both
    ``self.lm_head`` and ``language_model.embed_tokens`` exist.
    """
    tying_enabled = getattr(self.config, "tie_word_embeddings", False)
    if tying_enabled and hasattr(self, "lm_head"):
        language_model = self.model.language_model
        if hasattr(language_model, "embed_tokens"):
            self.lm_head.weight = language_model.embed_tokens.weight
def get_input_embeddings(self):
    """Return the input embedding reported by the wrapped ``self.model``."""
    return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
    """Forward ``value`` to the wrapped model's ``set_input_embeddings``."""
    self.model.set_input_embeddings(value)
def get_output_embeddings(self):
    """Return the ``lm_head`` projection."""
    return self.lm_head
def set_output_embeddings(self, new_embeddings):
    """Replace the ``lm_head`` projection with ``new_embeddings``."""
    self.lm_head = new_embeddings
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
    """Forward both tokenizers to the wrapped model's ``set_speech_tokenizers``."""
    self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer)
def set_ddpm_inference_steps(self, num_steps=None):
    """Set the number of DDPM inference steps.

    A falsy ``num_steps`` (``None`` or ``0``) restores the value from
    ``config.diffusion_head_config.ddpm_num_inference_steps``.
    """
    if not num_steps:
        num_steps = self.config.diffusion_head_config.ddpm_num_inference_steps
    self.ddpm_inference_steps = num_steps
def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"):
    """Turn speech input into scaled acoustic latents and connector embeddings.

    With ``speech_type == "audio"`` the input is passed through the acoustic
    tokenizer's encoder; with ``"pt"`` it is used directly as the mean of the
    latent distribution. Either way a latent is sampled, shifted by
    ``speech_bias_factor`` and scaled by ``speech_scaling_factor``, then
    projected by the acoustic connector. Runs under ``torch.no_grad()``.

    Returns:
        ``(acoustic_features, acoustic_connected)`` where the second item is
        the connector output indexed by ``speech_masks.cpu()``.

    Raises:
        NotImplementedError: For any other ``speech_type``.
    """
    with torch.no_grad():
        if speech_type == "audio":
            encoder_output = self.model.acoustic_tokenizer.encode(
                speech_tensors.unsqueeze(1)
            )
        elif speech_type == "pt":
            encoder_output = QWEN3VoxTokenizerEncoderOutput(
                mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std
            )
        else:
            raise NotImplementedError(f"Speech type {speech_type} not implemented")

        # Both supported input kinds share the same post-processing.
        latents = encoder_output.sample(
            dist_type=self.model.acoustic_tokenizer.std_dist_type
        )[0]
        bias = self.model.speech_bias_factor.to(latents.device)
        scale = self.model.speech_scaling_factor.to(latents.device)
        acoustic_features = (latents + bias) * scale
        projected = self.model.acoustic_connector(acoustic_features)
        acoustic_connected = projected[speech_masks.cpu()]
        return (acoustic_features, acoustic_connected)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    speech_tensors: Optional[torch.FloatTensor] = None,
    speech_masks: Optional[torch.BoolTensor] = None,
    speech_input_mask: Optional[torch.BoolTensor] = None,
    logits_to_keep: Union[int, slice] = 0,
    **kwargs,
) -> Union[Tuple, QWEN3VoxLMHeadOutputWithPast]:
    """Run the backbone and project the kept hidden states to logits.

    Embeddings are computed from ``input_ids`` when ``inputs_embeds`` is not
    given. When ``speech_tensors`` and ``speech_masks`` are both given, the
    speech is encoded, and if ``speech_input_mask`` is also given the
    resulting embeddings overwrite the masked positions. An integer
    ``logits_to_keep`` keeps the last that many positions (``0`` keeps all);
    a slice is applied to the sequence axis as is.

    Raises:
        NotImplementedError: If ``labels`` is given (raised after the
            forward pass has run).
    """
    if return_dict is None:
        return_dict = self.config.use_return_dict

    if inputs_embeds is None:
        embed = self.model.get_input_embeddings()
        inputs_embeds = embed(input_ids)

    has_speech = speech_tensors is not None and speech_masks is not None
    if has_speech:
        _, speech_embeds = self._process_speech_inputs(
            speech_tensors.to(self.dtype), speech_masks
        )
        if speech_input_mask is not None:
            inputs_embeds[speech_input_mask] = speech_embeds

    backbone_out = self.model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )
    if return_dict:
        hidden_states = backbone_out.last_hidden_state
    else:
        hidden_states = backbone_out[0]

    if isinstance(logits_to_keep, int):
        # slice(-0, None) is slice(0, None), i.e. every position is kept.
        kept = slice(-logits_to_keep, None)
    else:
        kept = logits_to_keep
    logits = self.lm_head(hidden_states[:, kept, :])

    if labels is not None:
        raise NotImplementedError(
            "Loss computation is not implemented in this version."
        )
    return QWEN3VoxLMHeadOutputWithPast(
        logits=logits,
        past_key_values=backbone_out.past_key_values,
        last_hidden_state=hidden_states,
        attentions=backbone_out.attentions,
    )
def _build_generate_config_model_kwargs(
    self, generation_config, inputs, tokenizer, return_processors=False, **kwargs
):
    """Prepare the generation config, model kwargs and input ids for ``generate``.

    Builds a ``GenerationConfig`` carrying the tokenizer's bos/eos/pad ids
    (unpacking ``generation_config`` as keyword arguments when it is given),
    attaches the tokenizer's speech start/end/diffusion ids, resolves the
    model inputs and generated length, sets up the cache with caching forced
    on, creates ``cache_position`` for the prompt, and moves every tensor in
    the model kwargs to ``self.device``.

    Returns:
        ``(generation_config, model_kwargs, input_ids)``; when
        ``return_processors`` is true the tuple additionally ends with the
        logits processor list and the stopping criteria list.
    """
    special_token_ids = {
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if generation_config is None:
        generation_config = GenerationConfig(**special_token_ids)
    else:
        generation_config = GenerationConfig(**generation_config, **special_token_ids)

    generation_config, model_kwargs = self._prepare_generation_config(
        generation_config,
        True,
        speech_start_id=tokenizer.speech_start_id,
        speech_end_id=tokenizer.speech_end_id,
        speech_diffusion_id=tokenizer.speech_diffusion_id,
        **kwargs,
    )
    for speech_attr in ("speech_start_id", "speech_end_id", "speech_diffusion_id"):
        setattr(generation_config, speech_attr, getattr(tokenizer, speech_attr))

    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    batch_size = inputs_tensor.shape[0]
    device = self.device
    self._prepare_special_tokens(generation_config, True, device=device)

    # Caching is always enabled on this path.
    generation_config.use_cache = True
    model_kwargs["use_cache"] = generation_config.use_cache

    input_ids = inputs_tensor.to(self.device)
    prompt_length = input_ids.shape[1]

    # A length limit counts as "default" when the caller did not pass it
    # explicitly and the generation config carries a value.
    max_length_is_default = (
        kwargs.get("max_length") is None
        and generation_config.max_length is not None
    )
    min_length_is_default = (
        kwargs.get("min_length") is None
        and generation_config.min_length is not None
    )
    generation_config = self._prepare_generated_length(
        generation_config=generation_config,
        has_default_max_length=max_length_is_default,
        has_default_min_length=min_length_is_default,
        model_input_name=model_input_name,
        inputs_tensor=inputs_tensor,
        input_ids_length=prompt_length,
    )

    max_cache_length = generation_config.max_length - 1
    self._prepare_cache_for_generation(
        generation_config, model_kwargs, None, batch_size, max_cache_length, device
    )
    model_kwargs["cache_position"] = torch.arange(
        prompt_length, device=device, dtype=torch.long
    )
    model_kwargs.update(
        {
            key: value.to(device=device)
            for key, value in model_kwargs.items()
            if isinstance(value, torch.Tensor)
        }
    )

    if not return_processors:
        return (generation_config, model_kwargs, input_ids)

    logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=prompt_length,
        encoder_input_ids=inputs_tensor,
        prefix_allowed_tokens_fn=None,
        logits_processor=LogitsProcessorList(),
        device=inputs_tensor.device,
        model_kwargs=model_kwargs,
    )
    stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=StoppingCriteriaList(),
    )
    return (
        generation_config,
        model_kwargs,
        input_ids,
        logits_processor,
        stopping_criteria,
    )
@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[
        Callable[[int, torch.Tensor], List[int]]
    ] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    speech_tensors: Optional[torch.FloatTensor] = None,
    speech_masks: Optional[torch.BoolTensor] = None,
    speech_input_mask: Optional[torch.BoolTensor] = None,
    is_prefill: bool = True,
    return_speech: bool = True,
    cfg_scale: float = 1.0,
    stop_check_fn: Optional[Callable[[], bool]] = None,
    tqdm_class: Optional[type] = None,
    **kwargs,
) -> Union[torch.LongTensor, QWEN3VoxGenerationOutput]:
    """Run the step-by-step generation loop and collect decoded audio chunks.

    Each step forwards the model, picks one next token per sample (restricted
    to a small set of control tokens), and, for samples whose token is the
    speech-diffusion id, samples a speech latent with `sample_speech_tokens`,
    decodes it to an audio chunk and feeds the chunk's embedding back as the
    next input. A second "negative" sequence (seeded with one speech-start
    token per sample) is forwarded to supply the negative condition.

    Args:
        inputs: Forwarded to `_build_generate_config_model_kwargs`.
        generation_config: Forwarded to the same helper; the config it
            returns is the one used by the loop.
        logits_processor: Rebound from the helper's return value before use,
            so the value passed by the caller is not read here.
        stopping_criteria: Rebound from the helper's return value; not read
            afterwards in this method.
        prefix_allowed_tokens_fn: Not referenced in this method body.
        synced_gpus: Not referenced in this method body.
        assistant_model: Not referenced in this method body.
        audio_streamer: Optional streamer. Receives every decoded chunk via
            `put(chunk, indices)` and is closed with `end(...)`.
        negative_prompt_ids: Not referenced in this method body.
        negative_prompt_attention_mask: Not referenced in this method body.
        speech_tensors: Passed to the model on the prefill step only.
        speech_masks: Passed to the model on the prefill step only.
        speech_input_mask: Passed to the model on the prefill step only.
        is_prefill: Whether the first loop iteration is treated as prefill.
        return_speech: When False, `speech_outputs` in the result is None.
        cfg_scale: Forwarded to `sample_speech_tokens`.
        stop_check_fn: Polled at the start of every step; a truthy result
            stops generation.
        tqdm_class: Progress-bar factory used instead of `tqdm` when given.
        **kwargs: Must contain `input_ids` (it is indexed directly). Also
            read here: `tokenizer`, `parsed_scripts` and `all_speakers_list`
            (both popped but unused below), `max_length_times`,
            `max_new_tokens`, `verbose`, `show_progress_bar` and
            `refresh_negative`. Whatever remains after the pops is forwarded
            to `_build_generate_config_model_kwargs`.

    Returns:
        A `QWEN3VoxGenerationOutput` holding the token `sequences`, the
        per-sample concatenated audio (`None` for samples without audio, or
        `None` overall when `return_speech` is False) and the
        `reach_max_step_sample` flags.
    """
    tokenizer = kwargs.pop("tokenizer", None)
    # Popped so they are not forwarded to the helper; unused in this method.
    parsed_scripts = kwargs.pop("parsed_scripts", None)
    all_speakers_list = kwargs.pop("all_speakers_list", None)
    max_length_times = kwargs.pop("max_length_times", 2)
    # Default budget: whatever is left of the decoder's position range.
    if kwargs.get("max_new_tokens", None) is None:
        kwargs["max_new_tokens"] = (
            self.config.decoder_config.max_position_embeddings
            - kwargs["input_ids"].shape[-1]
        )
    (
        generation_config,
        model_kwargs,
        input_ids,
        logits_processor,
        stopping_criteria,
    ) = self._build_generate_config_model_kwargs(
        generation_config, inputs, tokenizer, return_processors=True, **kwargs
    )
    # The negative sequence starts as a single speech-start token per sample.
    negative_kwargs = {
        "input_ids": torch.full(
            (kwargs["input_ids"].shape[0], 1),
            tokenizer.speech_start_id,
            dtype=torch.long,
            device=kwargs["input_ids"].device,
        ),
        "attention_mask": torch.ones(
            (kwargs["input_ids"].shape[0], 1),
            dtype=torch.long,
            device=kwargs["input_ids"].device,
        ),
        "max_new_tokens": kwargs.get("max_new_tokens", 100),
    }
    negative_generation_config, negative_model_kwargs, negative_input_ids = (
        self._build_generate_config_model_kwargs(
            None, None, tokenizer, return_processors=False, **negative_kwargs
        )
    )
    # Separate streaming caches for the acoustic decoder and semantic encoder.
    acoustic_cache = QWEN3VoxTokenizerStreamingCache()
    semantic_cache = QWEN3VoxTokenizerStreamingCache()
    batch_size = input_ids.shape[0]
    device = input_ids.device
    finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
    # Per-sample count of cache corrections applied to the negative branch.
    correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
    inputs_embeds = None
    verbose = kwargs.get("verbose", False)
    # One list of decoded audio chunks per sample.
    audio_chunks = [[] for _ in range(batch_size)]
    initial_length = input_ids.shape[-1]
    initial_length_per_sample = model_kwargs["attention_mask"].sum(dim=-1)
    # Only these control tokens are handed to the constraint processor.
    valid_tokens = [
        generation_config.speech_start_id,
        generation_config.speech_end_id,
        generation_config.speech_diffusion_id,
        generation_config.eos_token_id,
    ]
    if (
        hasattr(generation_config, "bos_token_id")
        and generation_config.bos_token_id is not None
    ):
        valid_tokens.append(generation_config.bos_token_id)
    token_constraint_processor = QWEN3VoxTokenConstraintProcessor(
        valid_tokens, device=device
    )
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    logits_processor.append(token_constraint_processor)
    # Step budget: bounded by max_length and by max_length_times x prompt length.
    max_steps = min(
        generation_config.max_length - initial_length,
        int(max_length_times * initial_length),
    )
    # Same bound per sample, using each sample's attended prompt length.
    max_step_per_sample = torch.min(
        generation_config.max_length - initial_length_per_sample,
        (max_length_times * initial_length_per_sample).long(),
    )
    reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
    if kwargs.get("show_progress_bar", True):
        tqdm_fn = tqdm_class if tqdm_class is not None else tqdm
        progress_bar = tqdm_fn(range(max_steps), desc="Generating", leave=False)
    else:
        progress_bar = range(max_steps)
    for step in progress_bar:
        # External cancellation requested by the caller.
        if stop_check_fn is not None and stop_check_fn():
            if verbose:
                print(f"Generation stopped externally at step {step +1 }")
            if audio_streamer is not None:
                audio_streamer.end()
            break
        # Stop as soon as any stream of the streamer is flagged finished.
        if audio_streamer is not None and hasattr(audio_streamer, "finished_flags"):
            if any(audio_streamer.finished_flags):
                if verbose:
                    print(f"Audio generation stopped externally at step {step +1 }")
                break
        if finished_tags.all():
            if hasattr(progress_bar, "set_description"):
                progress_bar.set_description("Generation complete")
            break
        if input_ids.shape[-1] >= generation_config.max_length:
            print(
                f"Reached maximum generation length {generation_config .max_length }, stopped it."
            )
            # Flag every still-active sample as having hit the length limit.
            reached_samples = torch.arange(batch_size, device=device)[
                ~finished_tags
            ]
            if reached_samples.numel() > 0:
                reach_max_step_sample[reached_samples] = True
            break
        if hasattr(progress_bar, "set_description"):
            active_samples = (~finished_tags).sum().item()
            progress_bar.set_description(
                f"Generating (active: {active_samples }/{batch_size })"
            )
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        if is_prefill:
            # Speech inputs are only supplied on the prefill step.
            prefill_inputs = {}
            if speech_tensors is not None:
                prefill_inputs["speech_tensors"] = speech_tensors.to(device=device)
            if speech_masks is not None:
                prefill_inputs["speech_masks"] = speech_masks.to(device)
            if speech_input_mask is not None:
                prefill_inputs["speech_input_mask"] = speech_input_mask.to(device)
            is_prefill = False
        else:
            # Later steps feed the embeddings prepared at the end of the
            # previous iteration instead of the helper-provided ones.
            _ = model_inputs.pop("inputs_embeds", None)
            prefill_inputs = {"inputs_embeds": inputs_embeds}
        outputs = self(
            **model_inputs,
            **prefill_inputs,
            logits_to_keep=1,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False
        )
        # Score the last position in float32 on the input_ids device.
        next_token_logits = outputs.logits[:, -1, :].to(
            copy=True, dtype=torch.float32, device=input_ids.device
        )
        next_token_scores = logits_processor(input_ids, next_token_logits)
        if generation_config.do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)
        # Finished samples keep emitting EOS.
        next_tokens[finished_tags] = generation_config.eos_token_id
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        # With refresh_negative disabled, the negative branch is advanced on
        # every step (otherwise only on diffusion steps, further below).
        if not kwargs.get("refresh_negative", True):
            negative_model_inputs = self.prepare_inputs_for_generation(
                negative_input_ids, **negative_model_kwargs
            )
            if (
                negative_model_inputs["inputs_embeds"] is None
                and inputs_embeds is not None
            ):
                negative_model_inputs["inputs_embeds"] = inputs_embeds
                negative_model_inputs["input_ids"] = None
            negative_outputs = self(
                **negative_model_inputs,
                logits_to_keep=0,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            negative_model_kwargs = self._update_model_kwargs_for_generation(
                negative_outputs, negative_model_kwargs, is_encoder_decoder=False
            )
            negative_input_ids = torch.cat(
                [negative_input_ids, next_tokens[:, None]], dim=-1
            )
        if (next_tokens == generation_config.eos_token_id).any():
            eos_indices = (
                (next_tokens == generation_config.eos_token_id)
                .nonzero(as_tuple=False)
                .squeeze(1)
            )
            # Only samples that were not already finished are newly finished.
            new_eos_indices = eos_indices[~finished_tags[eos_indices]]
            if new_eos_indices.numel() > 0:
                finished_tags[new_eos_indices] = True
                if verbose:
                    print(
                        f"Samples {new_eos_indices .tolist ()} reached EOS token at step {step +1 }.",
                        flush=True,
                    )
                if audio_streamer is not None:
                    audio_streamer.end(new_eos_indices)
        # Samples that exhausted their individual step budget also finish.
        max_length_reached = step >= max_step_per_sample
        new_max_length_indices = torch.nonzero(
            max_length_reached & ~finished_tags, as_tuple=False
        ).squeeze(1)
        if new_max_length_indices.numel() > 0:
            finished_tags[new_max_length_indices] = True
            reach_max_step_sample[new_max_length_indices] = True
            if verbose:
                print(
                    f"Samples {new_max_length_indices .tolist ()} reached max generation length at step {step +1 }.",
                    flush=True,
                )
            if audio_streamer is not None:
                audio_streamer.end(new_max_length_indices)
        # Speech-end token: reset both streaming caches for those samples.
        diffusion_end_indices = (
            (next_tokens == generation_config.speech_end_id)
            .nonzero(as_tuple=False)
            .squeeze(1)
        )
        if diffusion_end_indices.numel() > 0:
            acoustic_cache.set_to_zero(diffusion_end_indices)
            semantic_cache.set_to_zero(diffusion_end_indices)
        # Speech-start token: restart the negative branch for those samples.
        diffusion_start_indices = torch.arange(batch_size, device=device)[
            ~finished_tags & (next_tokens == generation_config.speech_start_id)
        ]
        if diffusion_start_indices.numel() > 0 and kwargs.get(
            "refresh_negative", True
        ):
            # Attend only to the last position of the negative sequence.
            for i, sample_idx in enumerate(diffusion_start_indices.tolist()):
                negative_model_kwargs["attention_mask"][sample_idx, :] = 0
                negative_model_kwargs["attention_mask"][sample_idx, -1] = 1
            # Copy the cache entry at position 0 into the last cache slot.
            for layer_idx, (k_cache, v_cache) in enumerate(
                zip(
                    negative_model_kwargs["past_key_values"].key_cache,
                    negative_model_kwargs["past_key_values"].value_cache,
                )
            ):
                for sample_idx in diffusion_start_indices.tolist():
                    k_cache[sample_idx, :, -1, :] = k_cache[
                        sample_idx, :, 0, :
                    ].clone()
                    v_cache[sample_idx, :, -1, :] = v_cache[
                        sample_idx, :, 0, :
                    ].clone()
            for sample_idx in diffusion_start_indices.tolist():
                negative_input_ids[sample_idx, -1] = (
                    generation_config.speech_start_id
                )
        # Default next-step input: the token embedding of every next token.
        next_inputs_embeds = self.model.get_input_embeddings()(
            next_tokens
        ).unsqueeze(1)
        # Active samples whose next token is the speech-diffusion id.
        diffusion_indices = torch.arange(batch_size, device=device)[
            ~finished_tags & (next_tokens == generation_config.speech_diffusion_id)
        ]
        if diffusion_indices.numel() > 0:
            if kwargs.get("refresh_negative", True):
                negative_model_inputs = self.prepare_inputs_for_generation(
                    negative_input_ids, **negative_model_kwargs
                )
                if (
                    negative_model_inputs["inputs_embeds"] is None
                    and inputs_embeds is not None
                ):
                    negative_model_inputs["inputs_embeds"] = inputs_embeds
                    negative_model_inputs["input_ids"] = None
                negative_outputs = self(
                    **negative_model_inputs,
                    logits_to_keep=0,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )
                negative_model_kwargs = self._update_model_kwargs_for_generation(
                    negative_outputs,
                    negative_model_kwargs,
                    is_encoder_decoder=False,
                )
                negative_input_ids = torch.cat(
                    [negative_input_ids, next_tokens[:, None]], dim=-1
                )
                # The negative forward above covers the whole batch. For
                # active samples that did NOT emit a diffusion token, shift
                # mask / cache / ids one slot to the right from `start_idx`
                # and mask out `start_idx`.
                # NOTE(review): presumably this undoes the extra negative
                # step for those samples - confirm the intended alignment.
                non_diffusion_mask = ~finished_tags & (
                    next_tokens != generation_config.speech_diffusion_id
                )
                if non_diffusion_mask.any():
                    non_diffusion_indices = torch.arange(batch_size, device=device)[
                        non_diffusion_mask
                    ]
                    start_indices = correct_cnt[non_diffusion_indices]
                    seq_len = negative_model_kwargs["attention_mask"].shape[1]
                    for i, (sample_idx, start_idx) in enumerate(
                        zip(non_diffusion_indices.tolist(), start_indices.tolist())
                    ):
                        if start_idx + 1 < seq_len - 1:
                            negative_model_kwargs["attention_mask"][
                                sample_idx, start_idx + 1 :
                            ] = negative_model_kwargs["attention_mask"][
                                sample_idx, start_idx:-1
                            ].clone()
                        negative_model_kwargs["attention_mask"][
                            sample_idx, start_idx
                        ] = 0
                    for layer_idx, (k_cache, v_cache) in enumerate(
                        zip(
                            negative_model_kwargs["past_key_values"].key_cache,
                            negative_model_kwargs["past_key_values"].value_cache,
                        )
                    ):
                        for sample_idx, start_idx in zip(
                            non_diffusion_indices.tolist(), start_indices.tolist()
                        ):
                            if start_idx + 1 < k_cache.shape[2] - 1:
                                k_cache[sample_idx, :, start_idx + 1 :, :] = k_cache[
                                    sample_idx, :, start_idx:-1, :
                                ].clone()
                                v_cache[sample_idx, :, start_idx + 1 :, :] = v_cache[
                                    sample_idx, :, start_idx:-1, :
                                ].clone()
                    for sample_idx, start_idx in zip(
                        non_diffusion_indices.tolist(), start_indices.tolist()
                    ):
                        if start_idx + 1 < negative_input_ids.shape[1] - 1:
                            negative_input_ids[sample_idx, start_idx + 1 :] = (
                                negative_input_ids[sample_idx, start_idx:-1].clone()
                            )
                    correct_cnt[non_diffusion_indices] += 1
            # Last hidden states of both branches condition the sampler.
            positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
            negative_condition = negative_outputs.last_hidden_state[
                diffusion_indices, -1, :
            ]
            speech_latent = self.sample_speech_tokens(
                positive_condition, negative_condition, cfg_scale=cfg_scale
            ).unsqueeze(1)
            # Rescale the latent before handing it to the acoustic decoder.
            scaled_latent = speech_latent / self.model.speech_scaling_factor.to(
                speech_latent.device
            ) - self.model.speech_bias_factor.to(speech_latent.device)
            audio_chunk = self.model.acoustic_tokenizer.decode(
                scaled_latent.to(self.model.acoustic_tokenizer.device),
                cache=acoustic_cache,
                sample_indices=diffusion_indices.to(
                    self.model.acoustic_tokenizer.device
                ),
                use_cache=True,
                debug=False,
            )
            # Keep each sample's chunk unless that sample is already finished.
            for i, sample_idx in enumerate(diffusion_indices):
                idx = sample_idx.item()
                if not finished_tags[idx]:
                    audio_chunks[idx].append(audio_chunk[i])
            if audio_streamer is not None:
                audio_streamer.put(audio_chunk, diffusion_indices)
            # Re-encode the decoded audio into semantic features.
            semantic_features = self.model.semantic_tokenizer.encode(
                audio_chunk,
                cache=semantic_cache,
                sample_indices=diffusion_indices,
                use_cache=True,
                debug=False,
            ).mean
            # Next input for diffusion samples: acoustic + semantic embeddings.
            acoustic_embed = self.model.acoustic_connector(speech_latent)
            semantic_embed = self.model.semantic_connector(semantic_features)
            diffusion_embeds = acoustic_embed + semantic_embed
            next_inputs_embeds[diffusion_indices] = diffusion_embeds
        inputs_embeds = next_inputs_embeds
    # Close the streamer once the loop is over.
    if audio_streamer is not None:
        audio_streamer.end()
    # Join each sample's chunks along the last dimension; None if it has none.
    final_audio_outputs = []
    for sample_chunks in audio_chunks:
        if sample_chunks:
            concatenated_audio = torch.cat(sample_chunks, dim=-1)
            final_audio_outputs.append(concatenated_audio)
        else:
            final_audio_outputs.append(None)
    return QWEN3VoxGenerationOutput(
        sequences=input_ids,
        speech_outputs=final_audio_outputs if return_speech else None,
        reach_max_step_sample=reach_max_step_sample,
    )
@torch.no_grad()
def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
    """Sample speech latents with the diffusion head under guidance.

    The positive and negative conditions are stacked along dim 0 and the
    noise scheduler is stepped through all of its timesteps. At every step
    the first half of the running sample is duplicated, the prediction head
    is evaluated once on that doubled batch, and the two halves of its
    output are combined as ``uncond + cfg_scale * (cond - uncond)``.

    Args:
        condition: Positive conditioning tensor (rows along dim 0).
        neg_condition: Negative conditioning tensor, stacked after
            ``condition``.
        cfg_scale: Weight of the (cond - uncond) difference.

    Returns:
        The first half (along dim 0) of the final sample, i.e. one latent
        per row of ``condition``.
    """
    scheduler = self.model.noise_scheduler
    scheduler.set_timesteps(self.ddpm_inference_steps)
    head = self.model.prediction_head
    stacked_condition = torch.cat([condition, neg_condition], dim=0).to(head.device)
    # Start from Gaussian noise matching the stacked condition's dtype/device.
    latents = torch.randn(
        stacked_condition.shape[0], self.config.acoustic_vae_dim
    ).to(stacked_condition)
    for timestep in scheduler.timesteps:
        # Both halves of the doubled batch share the same noisy sample.
        leading = latents[: len(latents) // 2]
        doubled = torch.cat([leading, leading], dim=0)
        timestep_batch = timestep.repeat(doubled.shape[0]).to(doubled)
        predicted = head(doubled, timestep_batch, condition=stacked_condition)
        cond_eps, uncond_eps = torch.split(predicted, len(predicted) // 2, dim=0)
        guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        step_result = scheduler.step(
            torch.cat([guided, guided], dim=0), timestep, latents
        )
        latents = step_result.prev_sample
    return latents[: len(latents) // 2]
# Register the inference class under its config type so that
# AutoModelForCausalLM can map a QWEN3VoxConfig to it.
AutoModelForCausalLM.register(QWEN3VoxConfig, QWEN3VoxForConditionalGenerationInference)
# Public names of this section.
# NOTE(review): `__all__` is assigned again further down in this file; for the
# file as a whole the later assignment is the one that remains in effect.
__all__ = [
    'QWEN3VoxForConditionalGenerationInference'
]
import argparse
import json
import os
from pathlib import Path
import re
import torch
from typing import Dict, List, Tuple
from transformers.utils import logging
# Module-level logger obtained from `transformers.utils.logging`.
# NOTE(review): both `logging` and `logger` are rebound further down in this
# file (stdlib `logging`), so functions see whichever binding is current when
# they are called.
logger = logging.get_logger(__name__)
def convert_q3_nnscaler_checkpoint_to_hf(
    checkpoint_path: str,
    pytorch_dump_folder_path: str,
    config_path: Optional[str] = None,
):
    """Convert an nnscaler training checkpoint into a HuggingFace model folder.

    The checkpoint's training arguments name the initial config file (looked
    up under ``<this file>/../../configs``) and the tokenizer path. Model
    weights whose key starts with ``"model.model."`` are loaded into a fresh
    ``QWEN3VoxForConditionalGeneration`` created in bfloat16, and the config,
    a processor config and the sharded safetensors weights are written to
    ``pytorch_dump_folder_path``. The saved model is finally reloaded once as
    a sanity check.

    Args:
        checkpoint_path: Path of the checkpoint file passed to ``torch.load``.
        pytorch_dump_folder_path: Output directory (created if missing).
        config_path: Optional JSON file whose content replaces the config
            derived from the checkpoint.

    Raises:
        FileNotFoundError: If the initial config file named by the checkpoint
            does not exist.
    """
    logger.info(f"Loading regular checkpoint from {checkpoint_path }")
    # NOTE(security): torch.load unpickles arbitrary objects; only run this
    # on checkpoints from a trusted source. The checkpoint also carries
    # non-tensor training arguments, so it is loaded with default settings.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    train_vars = checkpoint["train_args"]["vars"]
    init_config_name = train_vars["model_args"]["config_path"]["relative_path"]
    pretrained_name = train_vars["data_args"]["tokenizer_path"]
    init_config_path = (
        Path(__file__).parent.parent / "configs" / init_config_name.split("/")[-1]
    )
    if not init_config_path.exists():
        raise FileNotFoundError(
            f"Initial config file {init_config_path } not found. Please provide a valid path."
        )
    logger.info(f"Loading initial config from {init_config_path }")
    # JSON is UTF-8; do not depend on the platform's locale encoding.
    with open(init_config_path, "r", encoding="utf-8") as f:
        init_config = json.load(f)
    tie_word_embeddings = init_config["decoder_config"].get("tie_word_embeddings", True)
    logger.info(f"Tie word embeddings: {tie_word_embeddings }")
    init_config["decoder_config"]["use_cache"] = True
    config = QWEN3VoxConfig(**init_config, tie_word_embeddings=tie_word_embeddings)
    source_weights = checkpoint["model"]
    # Strip exactly one leading "model." level. The count of 1 matters: an
    # unbounded replace would also rewrite later occurrences of the pattern
    # inside a key (e.g. in "..._model.model...").
    model_state_dict = {
        k.replace("model.model.", "model.", 1): v
        for k, v in source_weights.items()
        if k.startswith("model.model.")
    }
    if not tie_word_embeddings and "model.lm_head.weight" in source_weights:
        model_state_dict["lm_head.weight"] = source_weights["model.lm_head.weight"]
    if config_path:
        logger.info(f"Loading config from {config_path }")
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = QWEN3VoxConfig.from_dict(config_dict)
    # Build the model with bfloat16 as the default dtype and always restore
    # the previous default, even if construction fails.
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    try:
        logger.info(
            'Creating HuggingFace QWEN3VoxForConditionalGeneration model'
        )
        model = QWEN3VoxForConditionalGeneration(config)
    finally:
        torch.set_default_dtype(original_dtype)
    logger.info("Loading weights into model")
    missing_keys, unexpected_keys = model.load_state_dict(
        model_state_dict, strict=False
    )
    if missing_keys:
        logger.warning(f"Missing keys: {missing_keys }")
    if unexpected_keys:
        logger.warning(f"Unexpected keys: {unexpected_keys }")
    os.makedirs(pytorch_dump_folder_path, exist_ok=True)
    logger.info(f"Saving model to {pytorch_dump_folder_path }")
    config.save_pretrained(pytorch_dump_folder_path)
    logger.info("Saving QWEN3Vox processor configuration")
    processor_config = {
        "processor_class": "QWEN3VoxProcessor",
        "speech_tok_compress_ratio": 3200,
        "db_normalize": True,
        "audio_processor": {
            "feature_extractor_type": "QWEN3VoxTokenizerProcessor",
            "sampling_rate": 22050,
            "normalize_audio": True,
            "target_dB_FS": -25,
            "eps": 1e-06,
        },
        "language_model_pretrained_name": pretrained_name,
    }
    processor_config_path = os.path.join(
        pytorch_dump_folder_path, "preprocessor_config.json"
    )
    with open(processor_config_path, "w", encoding="utf-8") as f:
        json.dump(processor_config, f, indent=2)
    logger.info(f"Saved processor config to {processor_config_path }")
    logger.info("Saving model weights with sharding...")
    model.save_pretrained(
        pytorch_dump_folder_path, max_shard_size="5GB", safe_serialization=True
    )
    logger.info(f"Model weights saved to {pytorch_dump_folder_path }")
    logger.info("Conversion complete!")
    logger.info("Verifying saved model...")
    # Reload once purely to prove the saved folder is loadable; the instance
    # itself is not needed.
    model_name = str(pytorch_dump_folder_path)
    QWEN3VoxForConditionalGeneration.from_pretrained(model_name)
    logger.info("Model successfully loaded from saved checkpoint!")
def main(argv: Optional[Sequence[str]] = None):
    """Command-line entry point for the nnscaler -> HuggingFace converter.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the previous behaviour; passing a
            list allows programmatic invocation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--nnscaler_checkpoint_path",
        type=str,
        required=True,
        # The converter loads exactly this one file; the previous help text
        # described a fairseq checkpoint and automatic merging of
        # tensor-parallel parts, which the converter does not perform.
        help="Path to the nnscaler checkpoint (.pt file) to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        type=str,
        required=True,
        help="Path to the output PyTorch model directory",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default=None,
        help="Optional path to a config JSON file to override extracted config",
    )
    args = parser.parse_args(argv)
    convert_q3_nnscaler_checkpoint_to_hf(
        args.nnscaler_checkpoint_path, args.pytorch_dump_folder_path, args.config_path
    )
# NOTE(review): this guard sits in the middle of the file. When the file is
# executed as a script it calls the `main` defined directly above, before the
# definitions that follow (including a second `main`) have been bound.
if __name__ == "__main__":
    main()
'\nQWEN3Vox Universal Model Merger\n\nAutomatically detects and merges trained components back into the base model:\n- LLM LoRA adapters\n- Diffusion head (LoRA or full fine-tune)\n- Acoustic/Semantic connectors\n\nSupports all training configurations from train_vibevoice.py\n'
import argparse
import logging
import os
import shutil
from typing import Dict, Optional
import torch
# Configure logging as an import-time side effect (timestamp, level, message;
# INFO and above), then rebind the module-level `logger` to a stdlib logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
def detect_trained_components(checkpoint_path: str) -> Dict[str, bool]:
    """Report which trained components are present in a checkpoint directory.

    Detection is purely file-based:

    * ``llm_lora``: ``adapter_config.json`` plus ``adapter_model.safetensors``
      or ``adapter_model.bin`` directly under ``checkpoint_path``.
    * ``diffusion_head``: a known weight file inside ``diffusion_head/`` or a
      ``diffusion_head_full.bin`` directly under ``checkpoint_path``.
    * ``acoustic_connector`` / ``semantic_connector``:
      ``<component>/pytorch_model.bin``.

    Args:
        checkpoint_path: Directory to scan.

    Returns:
        Mapping from component name to whether it was found.
    """

    def _in_checkpoint(*parts: str) -> str:
        # Path of a file relative to the checkpoint directory.
        return os.path.join(checkpoint_path, *parts)

    has_adapter_weights = os.path.exists(
        _in_checkpoint("adapter_model.safetensors")
    ) or os.path.exists(_in_checkpoint("adapter_model.bin"))
    llm_lora = os.path.exists(_in_checkpoint("adapter_config.json")) and has_adapter_weights

    head_weight_names = (
        "adapter_model.safetensors",
        "adapter_model.bin",
        "model.safetensors",
        "diffusion_head_full.bin",
    )
    weights_in_head_dir = any(
        os.path.isfile(_in_checkpoint("diffusion_head", name))
        for name in head_weight_names
    )
    # The top-level file counts on its own. Previously it was only honoured
    # when a `diffusion_head/` directory also existed, although
    # merge_diffusion_head accepts the top-level location by itself.
    diffusion_head = weights_in_head_dir or os.path.isfile(
        _in_checkpoint("diffusion_head_full.bin")
    )

    return {
        "llm_lora": llm_lora,
        "diffusion_head": diffusion_head,
        "acoustic_connector": os.path.exists(
            _in_checkpoint("acoustic_connector", "pytorch_model.bin")
        ),
        "semantic_connector": os.path.exists(
            _in_checkpoint("semantic_connector", "pytorch_model.bin")
        ),
    }
def merge_llm_lora(model: QWEN3VoxForConditionalGeneration, checkpoint_path: str) -> None:
    """Log that the LLM LoRA adapter is not merged here.

    Neither ``model`` nor ``checkpoint_path`` is touched; the function only
    emits a warning telling the user to merge the adapter offline.
    """
    skip_notice = (
        "LLM LoRA merge skipped: PeftModel.from_pretrained is not allowed in miner.py. "
        "Merge LoRA offline, then upload full safetensors to your HF repo."
    )
    logger.warning(skip_notice)
def merge_diffusion_head(
    model: QWEN3VoxForConditionalGeneration, checkpoint_path: str
) -> dict:
    """Load trained diffusion-head weights into ``model.model.prediction_head``.

    The first existing file among ``diffusion_head/model.safetensors``,
    ``diffusion_head/diffusion_head_full.bin`` and
    ``<checkpoint_path>/diffusion_head_full.bin`` is used. A state dict that
    contains ``lora_`` keys is loaded non-strictly; anything else is loaded
    strictly.

    Args:
        model: Model whose prediction head is updated in place.
        checkpoint_path: Checkpoint directory to search.

    Returns:
        The state dict that was loaded from disk.

    Raises:
        ValueError: If none of the candidate files exists.
    """
    logger.info("Merging diffusion head...")
    head_dir = os.path.join(checkpoint_path, "diffusion_head")
    possible_files = [
        os.path.join(head_dir, "model.safetensors"),
        os.path.join(head_dir, "diffusion_head_full.bin"),
        os.path.join(checkpoint_path, "diffusion_head_full.bin"),
    ]
    # First candidate that exists wins; None when nothing was found.
    trained_weights_path = next(
        (candidate for candidate in possible_files if os.path.exists(candidate)),
        None,
    )
    if trained_weights_path is None:
        raise ValueError(
            f"Diffusion head weights not found. Searched:\n"
            + "\n".join((f" - {p }" for p in possible_files))
        )
    logger.info(f"Loading from: {trained_weights_path }")
    if not trained_weights_path.endswith(".safetensors"):
        # NOTE(security): torch.load unpickles the file; only use trusted
        # checkpoints here.
        trained_state_dict = torch.load(trained_weights_path, map_location="cpu")
    else:
        from safetensors.torch import load_file

        trained_state_dict = load_file(trained_weights_path)
    lora_format = any("lora_" in key for key in trained_state_dict.keys())
    if lora_format:
        logger.warning(
            "Diffusion-head LoRA merge skipped (PeftModel.from_pretrained banned in miner.py); "
            "loading state_dict directly."
        )
    else:
        logger.info("Detected full fine-tune format, replacing weights...")
    # LoRA-style dicts cannot match the head's keys exactly, so only the
    # full fine-tune format is loaded strictly.
    model.model.prediction_head.load_state_dict(
        trained_state_dict, strict=not lora_format
    )
    logger.info("✓ Diffusion head merge completed")
    return trained_state_dict
def merge_connectors(
    model: QWEN3VoxForConditionalGeneration,
    checkpoint_path: str,
    merge_acoustic: bool,
    merge_semantic: bool,
) -> None:
    """Load trained connector weights into the model, in place.

    For each requested connector, ``<checkpoint_path>/<connector>/
    pytorch_model.bin`` is loaded onto the CPU and applied strictly to the
    matching ``model.model`` sub-module. The acoustic connector is handled
    before the semantic one.

    Args:
        model: Model to update.
        checkpoint_path: Checkpoint directory containing the connector folders.
        merge_acoustic: Whether to load the acoustic connector.
        merge_semantic: Whether to load the semantic connector.
    """
    jobs = (
        (
            merge_acoustic,
            "acoustic_connector",
            "Merging acoustic connector...",
            "✓ Acoustic connector merge completed",
        ),
        (
            merge_semantic,
            "semantic_connector",
            "Merging semantic connector...",
            "✓ Semantic connector merge completed",
        ),
    )
    for enabled, component, start_msg, done_msg in jobs:
        if not enabled:
            continue
        logger.info(start_msg)
        weights_fp = os.path.join(checkpoint_path, component, "pytorch_model.bin")
        # NOTE(security): torch.load unpickles the file; trusted input only.
        state_dict = torch.load(weights_fp, map_location="cpu")
        getattr(model.model, component).load_state_dict(state_dict, strict=True)
        logger.info(done_msg)
def verify_merge(
    base_model: QWEN3VoxForConditionalGeneration,
    merged_model: QWEN3VoxForConditionalGeneration,
    trained_state_dict: Optional[dict],
    component_name: str,
) -> None:
    """Check that one component of ``merged_model`` was really updated.

    Three checks are made on the selected sub-module:

    1. at least one tensor differs from the base model (for
       ``diffusion_head`` an unchanged module is an error; for the
       connectors it is only logged);
    2. if ``trained_state_dict`` is given, every tensor in it matches the
       merged module;
    3. base and merged modules have the same number of parameters.

    Args:
        base_model: Unmodified reference model.
        merged_model: Model produced by the merge.
        trained_state_dict: Trained weights to compare against, or ``None``
            to skip check 2.
        component_name: ``"diffusion_head"``, ``"acoustic_connector"`` or
            ``"semantic_connector"``; any other value is logged and skipped.

    Raises:
        ValueError: If the diffusion head did not change, if trained and
            merged weights differ, or if the parameter counts differ.
    """
    logger.info(f"\n=== Verifying {component_name } merge ===")
    # Attribute of `<model>.model` that holds each component.
    module_attr_by_component = {
        "diffusion_head": "prediction_head",
        "acoustic_connector": "acoustic_connector",
        "semantic_connector": "semantic_connector",
    }
    module_attr = module_attr_by_component.get(component_name)
    if module_attr is None:
        logger.warning(f"Unknown component: {component_name }, skipping verification")
        return
    base_module = getattr(base_model.model, module_attr)
    merged_module = getattr(merged_model.model, module_attr)
    base_state = base_module.state_dict()
    merged_state = merged_module.state_dict()
    logger.info("Checking if weights changed from base model...")
    # Compare in float32 on both sides, as the trained-weights check below
    # does: torch.allclose raises when the two tensors have different dtypes
    # (e.g. one model loaded in float32 and the other in bfloat16).
    changed_params = [
        key
        for key, base_tensor in base_state.items()
        if key in merged_state
        and not torch.allclose(
            base_tensor.float(), merged_state[key].float(), rtol=1e-05, atol=1e-08
        )
    ]
    if not changed_params:
        if component_name == "diffusion_head":
            # Callers match on the "did not change" wording; keep it stable.
            raise ValueError(
                f"✗ ERROR: {component_name } weights did not change! Merge may have failed."
            )
        logger.info(f"✓ {component_name }: unchanged (was not trained)")
        return
    logger.info(
        f"✓ Weights changed: {len (changed_params )}/{len (base_state )} parameters modified"
    )
    if trained_state_dict is not None:
        logger.info("Verifying trained weights match merged model...")
        mismatches = []
        for key in trained_state_dict.keys():
            if key not in merged_state:
                mismatches.append(f"{key } (missing in merged)")
                continue
            trained_tensor = trained_state_dict[key].float()
            merged_tensor = merged_state[key].float()
            if not torch.allclose(
                trained_tensor, merged_tensor, rtol=1e-05, atol=1e-08
            ):
                mismatches.append(f"{key } (values differ)")
        if mismatches:
            logger.error(f"✗ Weight mismatches found:")
            # Show only the first few entries to keep the log readable.
            for mm in mismatches[:5]:
                logger.error(f" - {mm }")
            if len(mismatches) > 5:
                logger.error(f" ... and {len (mismatches )-5 } more")
            raise ValueError(f"✗ ERROR: Trained and merged weights do not match!")
        logger.info(
            f"✓ All trained weights correctly merged: {len (trained_state_dict )} parameters verified"
        )
    base_params = sum((p.numel() for p in base_module.parameters()))
    merged_params = sum((p.numel() for p in merged_module.parameters()))
    if base_params != merged_params:
        raise ValueError(
            f"✗ ERROR: Parameter count mismatch: base={base_params :,} vs merged={merged_params :,}"
        )
    logger.info(f"✓ Parameter count matches: {merged_params :,}")
    logger.info(f"✓✓✓ {component_name } verification PASSED ✓✓✓")
def verify_models_only(base_model_path: str, merged_model_path: str) -> None:
    """Compare an already merged model against its base without merging.

    Both models are loaded in float32 and ``verify_merge`` is run for the
    diffusion head and both connectors. A ``ValueError`` whose message
    contains "did not change" is reported as "unchanged" and tolerated; any
    other error is re-raised.

    Args:
        base_model_path: Path of the base model.
        merged_model_path: Path of the merged model to check.
    """

    def _load_fp32(path):
        # Load one model in float32 from the given path.
        model_name = str(path)
        return QWEN3VoxForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch.float32
        )

    logger.info("=== VERIFY-ONLY MODE ===")
    logger.info(f"Base model: {base_model_path }")
    logger.info(f"Merged model: {merged_model_path }")
    logger.info("\nLoading base model...")
    base_model = _load_fp32(base_model_path)
    logger.info("Loading merged model...")
    merged_model = _load_fp32(merged_model_path)
    for component in ("diffusion_head", "acoustic_connector", "semantic_connector"):
        try:
            verify_merge(base_model, merged_model, None, component)
        except ValueError as e:
            # Only the "unchanged" outcome is acceptable in verify-only mode.
            if "did not change" not in str(e):
                raise
            logger.info(f"✓ {component }: unchanged (likely not trained)")
        except Exception as e:
            logger.error(f"✗ {component } verification failed: {e }")
            raise
    logger.info("\n✓✓✓ VERIFICATION COMPLETE ✓✓✓")
def merge_q3_model(
    base_model_path: str,
    checkpoint_path: str,
    output_path: str,
    output_format: str = "safetensors",
) -> None:
    """Merge the trained components of a checkpoint into the base model.

    Steps: detect the components present in ``checkpoint_path``, load the
    base model in float32, apply each detected component, save the result to
    ``output_path``, copy config/tokenizer files from the base model folder,
    then reload base and merged models and verify the merged components.

    Args:
        base_model_path: Path of the base model.
        checkpoint_path: Directory holding the trained components.
        output_path: Directory to write the merged model to.
        output_format: ``"safetensors"`` (sharded at 5GB) or ``"bin"``.

    Raises:
        ValueError: If no trained component is found, if ``output_format`` is
            unknown, or if verification fails.
    """
    logger.info(f"Scanning trained components in: {checkpoint_path }")
    components = detect_trained_components(checkpoint_path)
    logger.info("Detected trained components:")
    for name, trained in components.items():
        status = "✓ Found" if trained else "✗ Not found"
        logger.info(f" {name }: {status }")
    if not any(components.values()):
        raise ValueError("No trained components found in checkpoint path!")
    # Fail fast: reject a bad format before the expensive load and merge
    # instead of only when it is time to save.
    if output_format not in ("safetensors", "bin"):
        raise ValueError(
            f"Unknown output format: {output_format }. Use 'safetensors' or 'bin'"
        )
    logger.info(f"\nLoading base model from: {base_model_path }")
    model_name = str(base_model_path)
    base_model = QWEN3VoxForConditionalGeneration.from_pretrained(
        model_name, torch_dtype=torch.float32
    )
    logger.info("\n=== Starting merge process ===")
    trained_diffusion_state = None
    if components["llm_lora"]:
        merge_llm_lora(base_model, checkpoint_path)
    if components["diffusion_head"]:
        trained_diffusion_state = merge_diffusion_head(base_model, checkpoint_path)
    if components["acoustic_connector"] or components["semantic_connector"]:
        merge_connectors(
            base_model,
            checkpoint_path,
            merge_acoustic=components["acoustic_connector"],
            merge_semantic=components["semantic_connector"],
        )
    logger.info(f"\n=== Saving merged model to: {output_path } ===")
    os.makedirs(output_path, exist_ok=True)
    if output_format == "safetensors":
        base_model.save_pretrained(
            output_path, max_shard_size="5GB", safe_serialization=True
        )
    else:
        base_model.save_pretrained(output_path, safe_serialization=False)
    logger.info("Copying config and processor files...")
    # NOTE(review): presumably save_pretrained has already written a
    # config.json; the copy below replaces it with the base model's file -
    # confirm that this is intended.
    files_to_copy = [
        "config.json",
        "preprocessor_config.json",
        "generation_config.json",
        "special_tokens_map.json",
        "tokenizer_config.json",
        "tokenizer.json",
        "vocab.json",
        "merges.txt",
    ]
    for file in files_to_copy:
        src = os.path.join(base_model_path, file)
        dst = os.path.join(output_path, file)
        if os.path.exists(src):
            shutil.copy2(src, dst)
    logger.info("\n=== Verifying merged model ===")
    try:
        logger.info("Reloading original base model for verification...")
        model_name = str(base_model_path)
        original_base_model = QWEN3VoxForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch.float32
        )
        logger.info("Loading merged model for verification...")
        model_name = str(output_path)
        # Load in float32 like the base model so that the tensor comparisons
        # in verify_merge see matching dtypes on both sides.
        test_model = QWEN3VoxForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch.float32
        )
        logger.info("✓ Model loads successfully")
        if components["diffusion_head"]:
            try:
                verify_merge(
                    original_base_model,
                    test_model,
                    trained_diffusion_state,
                    "diffusion_head",
                )
            except ValueError as e:
                if "did not change" in str(e):
                    logger.warning(
                        "Diffusion head weights unchanged after merge (often means "
                        "checkpoint matches base); continuing without failing merge."
                    )
                else:
                    raise
        if components["acoustic_connector"]:
            verify_merge(original_base_model, test_model, None, "acoustic_connector")
        if components["semantic_connector"]:
            verify_merge(original_base_model, test_model, None, "semantic_connector")
        logger.info("\n✓✓✓ Merge and verification completed successfully! ✓✓✓")
    except Exception as e:
        logger.error(f"✗ Verification failed: {e }")
        raise
def main(argv: Optional[Sequence[str]] = None):
    """Command-line entry point for the component merger.

    With ``--verify_only`` the base model is compared against the model at
    ``--output_path``; otherwise ``--checkpoint_path`` is required and the
    trained components are merged and verified.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the previous behaviour; passing a
            list allows programmatic invocation.
    """
    parser = argparse.ArgumentParser(
        description='Universal merger for QWEN3Vox trained components',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='\nExamples:\n # Merge and verify\n python merge_vibevoice_models.py --base_model_path model --checkpoint_path output/lora --output_path merged\n \n # Verify existing merge (no actual merging)\n python merge_vibevoice_models.py --base_model_path model --output_path merged --verify_only\n ',
    )
    parser.add_argument(
        "--base_model_path",
        type=str,
        required=True,
        help='Path to base QWEN3Vox model directory',
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=False,
        help="Path to checkpoint directory (usually 'lora/' or 'checkpoint-XXX/lora/'). Not needed with --verify_only",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to save merged model (or path to verify with --verify_only)",
    )
    parser.add_argument(
        "--output_format",
        type=str,
        default="safetensors",
        choices=["safetensors", "bin"],
        help="Output format: 'safetensors' (recommended) or 'bin'",
    )
    parser.add_argument(
        "--verify_only",
        action="store_true",
        help="Only verify existing merge between base_model_path and output_path (no actual merging)",
    )
    args = parser.parse_args(argv)
    if args.verify_only:
        verify_models_only(
            base_model_path=args.base_model_path, merged_model_path=args.output_path
        )
        return
    # The checkpoint is optional for argparse because verify-only mode does
    # not need it; a real merge does.
    if not args.checkpoint_path:
        parser.error("--checkpoint_path is required unless using --verify_only")
    merge_q3_model(
        base_model_path=args.base_model_path,
        checkpoint_path=args.checkpoint_path,
        output_path=args.output_path,
        output_format=args.output_format,
    )
# Script entry point for the merger CLI (the `main` defined directly above).
# NOTE(review): an earlier `__main__` guard in this file has already invoked a
# different `main` by the time execution reaches this line.
if __name__ == "__main__":
    main()
'\nQWEN3Vox Modular Components\n\nThis module provides the core model architectures for QWEN3Vox:\n- Multi-speaker models (1.5B, 7B) for high-quality multi-speaker TTS\n- Streaming model (0.5B) for real-time low-latency TTS\n'
# Public names of the combined module. This assignment replaces the shorter
# `__all__` assigned earlier in this file.
# NOTE(review): every name listed here is presumably defined somewhere in this
# file; a name that is not would make `from <module> import *` fail - verify.
__all__ = [
    'QWEN3VoxConfig',
    'QWEN3VoxAcousticTokenizerConfig',
    'QWEN3VoxSemanticTokenizerConfig',
    'QWEN3VoxDiffusionHeadConfig',
    'QWEN3VoxASRConfig',
    'QWEN3VoxPreTrainedModel',
    'QWEN3VoxModel',
    'QWEN3VoxForConditionalGenerationInference',
    'QWEN3VoxASRPreTrainedModel',
    'QWEN3VoxASRModel',
    'QWEN3VoxASRForConditionalGeneration',
    'QWEN3VoxStreamingConfig',
    'QWEN3VoxStreamingPreTrainedModel',
    'QWEN3VoxStreamingModel',
    'QWEN3VoxStreamingForConditionalGenerationInference',
    'QWEN3VoxGenerationOutput',
    "BinaryClassifier",
    "SpeechConnector",
    "TTS_TEXT_WINDOW_SIZE",
    "TTS_SPEECH_WINDOW_SIZE",
    'QWEN3VoxTokenizerStreamingCache',
    'QWEN3VoxAcousticTokenizerModel',
    'QWEN3VoxSemanticTokenizerModel',
    'QWEN3VoxTextTokenizer',
    'QWEN3VoxTextTokenizerFast',
    'QWEN3VoxDiffusionHead',
    "AudioStreamer",
    "AsyncAudioStreamer",
    "load_lora_assets",
]
# Metadata key of a safetensors shard under which the JSON slot manifest is
# stored (read by `_materialize_latent_prompt_embeddings`).
_AUX_SLOT_MANIFEST_K = "vv.pipeline.aux_slot_manifest"
# Slice id used as the fallback when the requested one is not an exact key.
DEFAULT_AUX_SLICE_ID = "male_mid_normal_adult_serious_formal_uk"


def _resolve_aux_coeff_tensor(
    handles: Dict[str, Any],
    slice_query: str,
    *,
    default_slice_id: str = DEFAULT_AUX_SLICE_ID,
) -> Tuple[Any, str, str, bool]:
    """Pick the entry of ``handles`` that best answers ``slice_query``.

    Resolution order: exact key match on the stripped query; the default
    slice id; the first key that contains, or is contained in, the query
    (case-insensitive); finally the first key of the mapping.

    Args:
        handles: Mapping from slice id to its bound value.
        slice_query: Requested slice id; surrounding whitespace is ignored.
        default_slice_id: Fallback id tried right after the exact match.

    Returns:
        ``(value, resolved_key, stripped_query, used_default)`` where
        ``used_default`` is True only when the default slice id was taken.

    Raises:
        ValueError: If ``handles`` is empty.
    """
    wanted = slice_query.strip()
    if wanted in handles:
        return (handles[wanted], wanted, wanted, False)
    if default_slice_id in handles:
        return (handles[default_slice_id], default_slice_id, wanted, True)
    if not handles:
        raise ValueError("empty auxiliary coefficient handle map")
    # Case-insensitive substring match in either direction.
    wanted_folded = wanted.lower()
    for candidate, binding in handles.items():
        candidate_folded = candidate.lower()
        if candidate_folded in wanted_folded or wanted_folded in candidate_folded:
            return (binding, candidate, wanted, False)
    # Nothing matched: fall back to the first entry of the mapping.
    first_key = next(iter(handles))
    return (handles[first_key], first_key, wanted, False)
def _accum_tensor_key(slot_idx: int) -> str:
    """Return the tensor name holding the payload of manifest slot ``slot_idx``.

    The index is rendered zero-padded to at least four digits.
    """
    padded_idx = format(slot_idx, "04d")
    return f"model.decoder.aux_residual.accum.{padded_idx}.u8_payload"
def _default_aux_shard_fp(repo_root: str) -> str:
    """Return the default auxiliary shard location directly under ``repo_root``."""
    shard_name = "aux_lm_residual_projection.safetensors"
    return os.path.join(repo_root, shard_name)
def _materialize_latent_prompt_embeddings(
    blob_fp: str | os.PathLike[str],
) -> Dict[str, Any]:
    """Decode every manifest slot of an auxiliary safetensors shard.

    The shard metadata must carry a JSON manifest (under
    ``_AUX_SLOT_MANIFEST_K``) whose ``"order"`` entry lists one stem per slot.
    Each slot's tensor is read as raw ``uint8`` bytes and decoded with
    ``librosa.load`` as mono audio; the sample rate librosa reports is
    discarded.

    Args:
        blob_fp: Path of the safetensors shard.

    Returns:
        Mapping of stem to a ``float32`` array of decoded samples.

    Raises:
        ValueError: If the manifest is missing or unreadable, or a slot's
            tensor is absent from the shard.
    """
    import librosa
    from safetensors import safe_open

    shard_path = os.fspath(blob_fp)
    decoded: Dict[str, Any] = {}
    with safe_open(shard_path, framework="np") as shard:
        metadata = shard.metadata()
        if not metadata or _AUX_SLOT_MANIFEST_K not in metadata:
            raise ValueError(
                "missing auxiliary slot manifest (not an LM projection safetensors shard)"
            )
        try:
            slot_stems: List[str] = list(
                json.loads(metadata[_AUX_SLOT_MANIFEST_K])["order"]
            )
        except (json.JSONDecodeError, KeyError, TypeError) as exc:
            raise ValueError("corrupt auxiliary slot manifest") from exc

        present = set(shard.keys())
        for slot_idx, stem in enumerate(slot_stems):
            tensor_key = _accum_tensor_key(slot_idx)
            if tensor_key not in present:
                raise ValueError(
                    f"missing tensor payload for slot {slot_idx}: {tensor_key}"
                )
            # The tensor is a byte container; hand the bytes to the decoder.
            payload = np.asarray(
                shard.get_tensor(tensor_key), dtype=np.uint8
            ).tobytes()
            samples, _ = librosa.load(io.BytesIO(payload), sr=None, mono=True)
            decoded[stem] = np.asarray(samples, dtype=np.float32)
    return decoded
# Role keyword that prefixes each dialogue turn ("Speaker <n>: ...").
_MODEL_DIALOGUE_ROLE_MARK = "Speaker"
# Sub-directory of the repo root that is scanned for discrete WAV files.
_COEFF_STAGE_SUBDIR = "voices"


class _QxResidualFabric:
    """Registry that maps slice/voice keys to audio bindings.

    A binding is either a decoded array taken from an auxiliary safetensors
    shard or the path of a ``.wav`` file found in the ``voices`` sub-directory
    of the repository root. Hyphenated stems additionally get a short alias
    (see ``__init__``).
    """

    def __init__(
        self,
        repo_root: Union[str, "os.PathLike[str]"],
        *,
        aux_projection_shard_fp: Optional[str] = None,
        skip_aux_shard: bool = False,
    ):
        """Mount the handles and register aliases for hyphenated stems.

        Args:
            repo_root: Repository directory; made absolute.
            aux_projection_shard_fp: Explicit shard path tried first.
            skip_aux_shard: When True, ignore shards and scan WAV files only.

        Raises:
            ValueError: If a shard is found but cannot be assembled.
        """
        self._repo_root = os.path.abspath(os.fspath(repo_root))
        self._discrete_coeff_root = os.path.join(self._repo_root, _COEFF_STAGE_SUBDIR)
        self._r_handles: Dict[str, Union[str, np.ndarray]] = {}
        self._fabric_refresh_handles(
            aux_projection_shard_fp=aux_projection_shard_fp,
            skip_aux_shard=skip_aux_shard,
        )

        # A stem such as "en-Alice_woman" is also reachable as "Alice": take
        # the text before the first "_" and then the part after the last "-".
        aliases: Dict[str, Union[str, np.ndarray]] = {}
        for stem, binding in self._r_handles.items():
            if "-" not in stem:
                continue
            nick = stem.split("_", 1)[0].split("-")[-1]
            # An alias must never shadow a real stem, and an empty alias
            # would match empty queries exactly, so both are skipped.
            if nick and nick not in self._r_handles:
                aliases[nick] = binding
        self._r_handles.update(aliases)

    def _fabric_refresh_handles(
        self, *, aux_projection_shard_fp: Optional[str], skip_aux_shard: bool
    ) -> None:
        """Rebuild ``self._r_handles`` from a shard or from the WAV directory.

        The shard candidates are, in order: ``aux_projection_shard_fp``, the
        ``VV_AUX_PROJECTION_PATH`` environment variable and the default shard
        path under the repo root; the first existing file wins. Without a
        shard (or with ``skip_aux_shard``) the ``.wav`` files under the
        coefficient directory are registered by stem.

        Raises:
            ValueError: If the chosen shard cannot be assembled.
        """
        self._r_handles.clear()

        blob_fp = None
        if not skip_aux_shard:
            candidates = (
                (aux_projection_shard_fp or "").strip(),
                (os.environ.get("VV_AUX_PROJECTION_PATH") or "").strip(),
                _default_aux_shard_fp(self._repo_root),
            )
            blob_fp = next((p for p in candidates if p and os.path.isfile(p)), None)

        if blob_fp:
            try:
                latent = _materialize_latent_prompt_embeddings(blob_fp)
            except ValueError as err:
                raise ValueError(
                    f"AUX shard assembly failed ({blob_fp}): {err}"
                ) from err
            self._r_handles = dict(sorted(latent.items()))
            print(
                f"Mounted auxiliary LM projection shard ({len(self._r_handles)} tensors): {blob_fp}"
            )
            print(f"Residual routing keys: {', '.join(self._r_handles.keys())}")
            return

        # isdir (not exists): a regular file at this path would make listdir
        # raise NotADirectoryError.
        if not os.path.isdir(self._discrete_coeff_root):
            print(
                f"Warning: coefficient directory missing at {self._discrete_coeff_root}"
            )
            return

        staged: Dict[str, Union[str, np.ndarray]] = {}
        for fname in sorted(os.listdir(self._discrete_coeff_root)):
            fpath = os.path.join(self._discrete_coeff_root, fname)
            if fname.lower().endswith(".wav") and os.path.isfile(fpath):
                staged[os.path.splitext(fname)[0]] = fpath
        self._r_handles = dict(sorted(staged.items()))
        print(
            f"Discrete coefficient files staged: {len(self._r_handles)} under {self._discrete_coeff_root}"
        )
        print(f"Residual routing keys: {', '.join(self._r_handles.keys())}")

    def _fabric_pick_residual_snapshot(
        self, shard_slice_query: str
    ) -> Union[str, np.ndarray]:
        """Return the binding that answers ``shard_slice_query``.

        Prints a warning when the default slice had to be used.

        Raises:
            ValueError: If no handles are mounted.
        """
        if not self._r_handles:
            raise ValueError(
                f"No residual handles mounted. Add WAV files under {_COEFF_STAGE_SUBDIR}/ at the repo root, place aux_lm_residual_projection.safetensors next to config.json, or set VV_AUX_PROJECTION_PATH / VOCENCE_AUX_PROJECTION_SHARD."
            )
        binding, used_key, requested, used_default = _resolve_aux_coeff_tensor(
            self._r_handles, shard_slice_query
        )
        if used_default:
            print(
                f"Warning: auxiliary slice '{requested}' not in shard; using default '{used_key}'."
            )
        return binding
def _partition_lm_conditioning_manifest(
    raw_manifest_txt: str,
) -> Tuple[List[str], List[str]]:
    """Split a dialogue script into per-turn strings and their lane ids.

    A turn starts on a line of the form ``<role mark> <digits>: <text>``
    (matched case-insensitively); following non-header lines are appended to
    that turn separated by single spaces. Blank lines are ignored, and text
    that precedes the first header is discarded, as is a turn with no text.

    Returns:
        ``(turns, lane_ids)`` of equal length, where each turn is rendered as
        ``"<role mark> <lane id>: <text>"``.
    """
    header_re = re.compile(
        f"^{re.escape(_MODEL_DIALOGUE_ROLE_MARK)}\\s+(\\d+):\\s*(.*)$",
        re.IGNORECASE,
    )
    turns: List[str] = []
    lane_ids: List[str] = []
    current_lane: Optional[str] = None
    pending = ""

    def _flush() -> None:
        # Emit the turn collected so far, if it has both a lane and text.
        if current_lane and pending:
            turns.append(
                f"{_MODEL_DIALOGUE_ROLE_MARK} {current_lane}: {pending.strip()}"
            )
            lane_ids.append(current_lane)

    for raw_line in raw_manifest_txt.strip().split("\n"):
        text = raw_line.strip()
        if not text:
            continue
        header = header_re.match(text)
        if header is None:
            # Continuation of whatever is being collected.
            pending = f"{pending} {text}" if pending else text
            continue
        _flush()
        current_lane = header.group(1).strip()
        pending = header.group(2).strip()

    _flush()
    return (turns, lane_ids)
def _parse_instruction_params(instruction: str) -> Dict[str, str]:
    """Parse ``key: value | key: value`` text into a dict.

    Segments are separated by ``|``; a segment without ``:`` is ignored. Keys
    are stripped and lower-cased, values are stripped, and only the first
    ``:`` in a segment separates key from value. A repeated key keeps its
    last value.
    """
    segments = instruction.strip().strip("|").split("|")
    pairs = (segment.partition(":") for segment in segments if ":" in segment)
    return {key.strip().lower(): value.strip() for key, _, value in pairs}
# Vocence aux slice slugs are the fields below joined by "_", in this order:
# gender_pitch_speed_age_group_emotion_tone_accent
_SLICE_SLUG_FIELDS: Tuple[str, ...] = (
    "gender", "pitch", "speed", "age_group", "emotion", "tone", "accent",
)

# When the composed slug is missing from the shard, candidates are scored by
# field matches in this importance order (highest weight first).
_SLICE_MATCH_WEIGHT_ORDER: Tuple[str, ...] = (
    "gender", "emotion", "accent", "speed", "age_group", "tone", "pitch",
)

# Instruction keys treated as structured prosody; "age" is accepted as well.
_STRUCTURED_PROSODY_KEYS = frozenset((*_SLICE_SLUG_FIELDS, "age"))

# One weight per entry of _SLICE_MATCH_WEIGHT_ORDER: 2**28, 2**24, ... so a
# match on an earlier field outweighs all later fields combined.
_SLICE_MATCH_WEIGHTS: Tuple[int, ...] = tuple(
    1 << shift
    for shift in range(28, 28 - 4 * len(_SLICE_MATCH_WEIGHT_ORDER), -4)
)
def _norm_prosody_token(s: str) -> str:
    """Strip, lower-case and turn every space of ``s`` into an underscore."""
    lowered = s.strip().lower()
    return "_".join(lowered.split(" "))
def _parse_slice_slug(slice_id: str) -> Optional[Dict[str, str]]:
    """Split a slice slug into its named fields.

    Returns:
        A mapping from each name in ``_SLICE_SLUG_FIELDS`` to its normalised
        token, or ``None`` when the slug is blank or does not have exactly one
        ``_``-separated token per field.
    """
    slug = slice_id.strip()
    if not slug:
        return None
    tokens = slug.split("_")
    if len(tokens) != len(_SLICE_SLUG_FIELDS):
        return None
    return dict(zip(_SLICE_SLUG_FIELDS, map(_norm_prosody_token, tokens)))
def _attrs_to_slice_slug(attrs: Dict[str, str]) -> str:
    """Compose a slice slug from ``attrs``.

    Every name in ``_SLICE_SLUG_FIELDS`` must be present in ``attrs``; the
    normalised values are joined with ``_`` in field order.

    Raises:
        KeyError: If a slug field is missing from ``attrs``.
    """
    ordered_tokens = [_norm_prosody_token(attrs[name]) for name in _SLICE_SLUG_FIELDS]
    return "_".join(ordered_tokens)
def _default_slice_attrs() -> Dict[str, str]:
    """Return a fresh dict of the fields encoded by ``DEFAULT_AUX_SLICE_ID``.

    If the default id cannot be parsed as a slug, every field maps to ``""``.
    """
    defaults = _parse_slice_slug(DEFAULT_AUX_SLICE_ID)
    if defaults is None:
        return dict.fromkeys(_SLICE_SLUG_FIELDS, "")
    return dict(defaults)
def _instruction_has_structured_prosody(p: Dict[str, str]) -> bool:
    """Tell whether any key of ``p`` names a structured prosody field.

    Keys are compared lower-cased, with ``"age"`` read as ``"age_group"``.
    """

    def _canonical(key: str) -> str:
        lowered = key.lower()
        return "age_group" if lowered == "age" else lowered

    return any(_canonical(key) in _STRUCTURED_PROSODY_KEYS for key in p)
def _instruction_prosody_attrs(p: Dict[str, str]) -> Dict[str, str]:
    """Overlay the prosody fields given in ``p`` on the default slice fields.

    Keys are matched lower-cased (``"age"`` counts as ``"age_group"``); keys
    that are not slug fields and blank values are ignored. Accepted values are
    normalised before being stored.
    """
    attrs = _default_slice_attrs()
    for raw_key, raw_value in p.items():
        field = raw_key.lower()
        if field == "age":
            field = "age_group"
        if raw_value.strip() and field in _SLICE_SLUG_FIELDS:
            attrs[field] = _norm_prosody_token(raw_value)
    return attrs
def _pick_best_aux_slice_key(
    desired_attrs: Dict[str, str], available_keys: AbstractSet[str]
) -> str:
    """Choose the available slice key closest to ``desired_attrs``.

    The slug composed from ``desired_attrs`` wins when it is available.
    Otherwise each key that parses as a slug is scored by summing the weight
    of every field equal to the desired one; the highest score wins and ties
    go to the lexicographically smallest key. If no key parses as a slug, the
    smallest available key is returned, or ``DEFAULT_AUX_SLICE_ID`` when there
    are no keys at all.
    """
    exact_slug = _attrs_to_slice_slug(desired_attrs)
    if exact_slug in available_keys:
        return exact_slug

    candidates: Dict[str, Dict[str, str]] = {}
    for key in available_keys:
        fields = _parse_slice_slug(key)
        if fields is not None:
            candidates[key] = fields

    if not candidates:
        return min(available_keys) if available_keys else DEFAULT_AUX_SLICE_ID

    weighted_fields = tuple(zip(_SLICE_MATCH_WEIGHT_ORDER, _SLICE_MATCH_WEIGHTS))

    def _rank(key):
        # Negated score so that min() prefers the best match, then the
        # smallest key among equally good ones.
        fields = candidates[key]
        score = sum(
            weight
            for name, weight in weighted_fields
            if desired_attrs.get(name) == fields.get(name)
        )
        return (-score, key)

    return min(candidates, key=_rank)
def _build_vocence_prompt(instruction: str, text: str) -> str:
    """Wrap ``instruction`` in a description tag and append ``text`` verbatim.

    Neither argument is escaped or parsed. (Original author's note: same
    pattern as the trainer-12 Maya miner.)
    """
    return '<description="{}"> {}'.format(instruction, text)
def _prosody_shard_tags_for_lanes(
    instruction: str,
    unique_lanes: List[str],
    *,
    aux_slice_keys: Optional[AbstractSet[str]] = None,
) -> Dict[str, str]:
    """Assign one slice tag to every lane based on the instruction text.

    The tag list comes from the first applicable source:

    1. a comma-separated ``prosody`` / ``shards`` / ``prosody_shards`` value,
    2. a comma-separated ``speakers`` value,
    3. a single ``voice`` / ``speaker`` value,
    4. structured prosody fields, matched against ``aux_slice_keys`` when
       those are given and composed into a slug otherwise,
    5. ``DEFAULT_AUX_SLICE_ID``.

    An empty list falls back to the default id, and a list shorter than
    ``unique_lanes`` is padded by repeating its last tag.

    Returns:
        Mapping of lane to tag, lanes paired with tags by position.
    """
    params = _parse_instruction_params(instruction)

    def _split_csv(raw: str) -> List[str]:
        return [token.strip() for token in raw.split(",") if token.strip()]

    if any(name in params for name in ("prosody", "shards", "prosody_shards")):
        tags = _split_csv(
            params.get("prosody")
            or params.get("shards")
            or params.get("prosody_shards")
            or ""
        )
    elif "speakers" in params:
        tags = _split_csv(params["speakers"])
    elif params.get("voice") or params.get("speaker"):
        tags = [(params.get("voice") or params.get("speaker") or "").strip()]
    elif _instruction_has_structured_prosody(params):
        wanted = _instruction_prosody_attrs(params)
        if aux_slice_keys:
            tags = [_pick_best_aux_slice_key(wanted, aux_slice_keys)]
        else:
            tags = [_attrs_to_slice_slug(wanted)]
    else:
        tags = [DEFAULT_AUX_SLICE_ID]

    if not tags:
        tags = [DEFAULT_AUX_SLICE_ID]

    # Repeat the last tag so that every lane has one.
    shortfall = len(unique_lanes) - len(tags)
    if shortfall > 0:
        tags.extend([tags[-1]] * shortfall)
    return dict(zip(unique_lanes, tags))
def _manifest_from_text(text: str) -> str:
    """Return ``text`` as a speaker-labelled script.

    If any line of the stripped text already starts with ``Speaker <digits>:``
    (case-insensitive), the stripped text is returned unchanged; otherwise it
    is prefixed with ``"Speaker 1: "``.
    """
    body = text.strip()
    already_labelled = (
        re.search("^Speaker\\s+\\d+:", body, re.MULTILINE | re.IGNORECASE)
        is not None
    )
    if not already_labelled:
        return f"Speaker 1: {body}"
    return body
def _build_prefill_slices(
    fabric: _QxResidualFabric,
    routing_lane_ids: List[str],
    lane_to_slice_tag: Dict[str, str],
) -> List[Union[str, np.ndarray]]:
    """Fetch one binding per distinct lane, in order of first appearance.

    A lane without an entry in ``lane_to_slice_tag`` is looked up under the
    tag ``lane_<lane>``.
    """
    # dict.fromkeys drops repeats while keeping first-seen order.
    ordered_lanes = list(dict.fromkeys(routing_lane_ids))
    return [
        fabric._fabric_pick_residual_snapshot(
            lane_to_slice_tag.get(lane, f"lane_{lane}")
        )
        for lane in ordered_lanes
    ]
class Miner:
    """Text-to-speech front end around a QWEN3Vox processor and model.

    The model name comes from ``vocence_config.yaml`` in the repository
    directory. Behaviour is tuned through these environment variables:
    ``VOCENCE_AUX_PROJECTION_SHARD``, ``VOCENCE_PREFER_DISCRETE_COEFF_DIR``,
    ``VOCENCE_SEED``, ``VOCENCE_CFG_SCALE``, ``VOCENCE_DISABLE_PREFILL`` and
    ``VOCENCE_CHECKPOINT_PATH``.
    """

    # Environment values (compared lower-cased) that switch a flag on.
    _TRUTHY_ENV_VALUES = ("1", "true", "yes")

    @staticmethod
    def _env_flag(name: str) -> bool:
        """Return True when environment variable ``name`` holds a truthy value."""
        return os.environ.get(name, "").strip().lower() in Miner._TRUTHY_ENV_VALUES

    def __init__(self, path_hf_repo: Path) -> None:
        """Read the config, mount voice handles and load processor and model.

        Args:
            path_hf_repo: Directory that contains ``vocence_config.yaml``.

        Raises:
            KeyError: If the config has no ``model_name`` entry.
            ValueError: If ``VOCENCE_SEED`` or ``VOCENCE_CFG_SCALE`` is set to
                a non-numeric value.
        """
        self._repo_path = Path(path_hf_repo).resolve()
        import yaml

        with (self._repo_path / "vocence_config.yaml").open() as f:
            cfg = yaml.safe_load(f) or {}
        model_name = str(cfg["model_name"]).strip()

        aux_cli = os.environ.get("VOCENCE_AUX_PROJECTION_SHARD", "").strip()
        self._fabric_q = _QxResidualFabric(
            str(self._repo_path),
            aux_projection_shard_fp=aux_cli or None,
            skip_aux_shard=self._env_flag("VOCENCE_PREFER_DISCRETE_COEFF_DIR"),
        )

        if torch.cuda.is_available():
            self._device = "cuda"
        elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            self._device = "mps"
        else:
            self._device = "cpu"

        seed_s = os.environ.get("VOCENCE_SEED", "").strip()
        if seed_s:
            seed = int(seed_s)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        self._cfg_scale = float(os.environ.get("VOCENCE_CFG_SCALE", "1.3"))
        self._disable_prefill = self._env_flag("VOCENCE_DISABLE_PREFILL")
        self._processor = QWEN3VoxProcessor.from_pretrained(model_name)

        # CUDA loads in bfloat16 with flash attention; MPS and CPU load in
        # float32 with SDPA.
        if self._device == "cuda":
            load_dtype = torch.bfloat16
            attn_impl_primary = "flash_attention_2"
        else:
            load_dtype = torch.float32
            attn_impl_primary = "sdpa"

        try:
            self._model = self._load_model_weights(
                model_name, load_dtype, attn_impl_primary
            )
        except Exception as exc:
            if attn_impl_primary != "flash_attention_2":
                raise
            # Surface why the preferred attention backend was abandoned
            # instead of retrying silently.
            warnings.warn(
                f"flash_attention_2 load failed ({exc!r}); retrying with sdpa.",
                RuntimeWarning,
            )
            self._model = self._load_model_weights(model_name, load_dtype, "sdpa")

        ckpt = os.environ.get("VOCENCE_CHECKPOINT_PATH", "").strip()
        if ckpt:
            load_lora_assets(self._model, ckpt)

        self._model.train(False)
        self._model.set_ddpm_inference_steps(num_steps=10)
        self._sample_rate = int(
            getattr(self._processor.audio_processor, "sampling_rate", 22050)
        )

    def _load_model_weights(
        self, model_name: str, load_dtype: torch.dtype, attn_impl: str
    ) -> QWEN3VoxForConditionalGenerationInference:
        """Load the inference model onto ``self._device``.

        On MPS the model is loaded without a device map and then moved; on
        CUDA and CPU the device name is passed as ``device_map``.
        """
        if self._device == "mps":
            model = QWEN3VoxForConditionalGenerationInference.from_pretrained(
                model_name,
                torch_dtype=load_dtype,
                attn_implementation=attn_impl,
                device_map=None,
            )
            model.to("mps")
            return model
        return QWEN3VoxForConditionalGenerationInference.from_pretrained(
            model_name,
            torch_dtype=load_dtype,
            device_map=self._device,
            attn_implementation=attn_impl,
        )

    def warmup(self) -> None:
        """Run one throw-away generation, waiting at most 240 seconds.

        Raises:
            RuntimeError: If the generation fails (chained to the original
                exception) or does not finish within the time limit.
        """
        outcome: Dict[str, Any] = {"done": False, "error": None}

        def _once() -> None:
            try:
                self.generate_wav(
                    instruction=(
                        "An adult male with an American accent, speaking at a normal pace "
                        "in a mid-range pitch with a calm, neutral tone."
                    ),
                    text="This is a warmup utterance for the voice engine.",
                )
            except Exception as exc:
                outcome["error"] = exc
            else:
                outcome["done"] = True

        # Daemon thread so a hung generation cannot block interpreter exit.
        worker = threading.Thread(target=_once, daemon=True)
        worker.start()
        worker.join(timeout=240.0)
        if outcome["done"]:
            return
        failure = outcome["error"]
        if failure is not None:
            # Fall back to repr() so an exception with an empty message is
            # not misreported as a timeout.
            raise RuntimeError(str(failure) or repr(failure)) from failure
        raise RuntimeError("warmup exceeded 240s")

    def _speech_tensor_to_numpy(self, speech: torch.Tensor) -> np.ndarray:
        """Return ``speech`` as a one-dimensional ``float32`` numpy array.

        The tensor is flattened in one step. Repeatedly squeezing dimension 0
        would never terminate for a tensor whose leading dimension is larger
        than one, because squeezing a non-singleton dimension is a no-op.
        """
        flat = speech.detach().cpu().float().reshape(-1)
        return flat.numpy().astype(np.float32, copy=False)

    def generate_wav(self, instruction: str, text: str) -> Tuple[np.ndarray, int]:
        """Synthesize ``text`` spoken as described by ``instruction``.

        Returns:
            ``(samples, sample_rate)`` with ``samples`` a one-dimensional
            ``float32`` array.

        Raises:
            RuntimeError: If the model returns no speech output.
        """
        # trainer-12 pattern: embed instruction + text verbatim for the LM (no trait parsing).
        prompt = _build_vocence_prompt(instruction, text)
        inputs = self._processor(
            text=[prompt],
            voice_samples=None,
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        # Move tensor inputs to the model's device; other entries stay as is.
        for key, value in inputs.items():
            if torch.is_tensor(value):
                inputs[key] = value.to(self._device)
        with torch.inference_mode():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=self._cfg_scale,
                tokenizer=self._processor.tokenizer,
                generation_config={"do_sample": False},
                verbose=False,
                is_prefill=not self._disable_prefill,
            )
        if not outputs.speech_outputs or outputs.speech_outputs[0] is None:
            raise RuntimeError("QWEN3Vox returned no speech output.")
        wav = self._speech_tensor_to_numpy(outputs.speech_outputs[0])
        return (wav, self._sample_rate)