from __future__ import annotations

import io
import json
import math
import os
import random
import re
import sys
import threading
import traceback
import warnings
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from types import SimpleNamespace
from typing import AbstractSet, Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from transformers.utils import logging as hf_logging

try:
    import librosa
except Exception:
    librosa = None
try:
    import resampy
except Exception:
    resampy = None

# Sample rate (Hz) used by the dataset and collator when loading audio.
# NOTE(review): helper names mention "24k", but the rate actually used is 22050 Hz.
_DEFAULT_TARGET_SR = 22050


def _resample_if_needed(wav: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample ``wav`` from ``orig_sr`` to ``target_sr``.

    Prefers ``resampy`` and falls back to ``librosa``. When neither library is
    installed, the audio is returned unchanged (cast to float32) with a
    ``RuntimeWarning``, i.e. it is treated as already being at ``target_sr``.
    """
    if orig_sr == target_sr:
        return wav.astype(np.float32, copy=False)
    if resampy is not None:
        return resampy.resample(wav.astype(np.float32), orig_sr, target_sr)
    if librosa is not None:
        return librosa.resample(
            y=wav.astype(np.float32), orig_sr=orig_sr, target_sr=target_sr
        )
    warnings.warn(
        "No resampler available; treating audio as target_sr without resampling. "
        "Install resampy or librosa.",
        RuntimeWarning,
    )
    return wav.astype(np.float32, copy=False)


def _has_voice_prompt(value: Any) -> bool:
    """Return True when a user-supplied voice prompt is present and non-empty.

    ``bool(value)`` raises for multi-element numpy arrays and torch tensors, so
    array-likes are tested by element count instead of truthiness.
    """
    if value is None:
        return False
    if isinstance(value, np.ndarray):
        return value.size > 0
    if isinstance(value, torch.Tensor):
        return value.numel() > 0
    return bool(value)


class QWEN3VoxDataset:
    """Map-style wrapper producing ``{"text", "audio", "voice_prompts"}`` dicts.

    If an item carries a non-empty value in ``voice_prompts_column``, it is used
    (wrapped in a list unless it already is one). Otherwise a random crop of the
    item's own audio is used as the voice prompt, or ``None`` if no usable crop
    can be made.
    """

    def __init__(
        self,
        dataset: Any,
        text_column: str = "text",
        audio_column: str = "audio",
        voice_prompts_column: Optional[str] = "voice_prompts",
    ) -> None:
        self.dataset = dataset
        self.text_column = text_column
        self.audio_column = audio_column
        self.voice_prompts_column = voice_prompts_column

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.dataset[idx]
        data: Dict[str, Any] = {
            "text": item[self.text_column],
            "audio": item[self.audio_column],
        }

        user_provided_prompt = None
        if self.voice_prompts_column and self.voice_prompts_column in item:
            user_provided_prompt = item[self.voice_prompts_column]

        if _has_voice_prompt(user_provided_prompt):
            if isinstance(user_provided_prompt, list):
                data["voice_prompts"] = user_provided_prompt
            else:
                data["voice_prompts"] = [user_provided_prompt]
        else:
            data["voice_prompts"] = self._random_prompt_crop(item, idx)
        return data

    def _random_prompt_crop(self, item: Any, idx: int) -> Optional[List[np.ndarray]]:
        """Cut a random 22050 Hz excerpt of the item's audio to use as a prompt.

        The crop length is drawn uniformly from
        ``[min(5 s, len/4), min(15 s, len/2)]``. Returns ``None`` (with a warning
        on failure) when the clip is too short or cannot be loaded.
        """
        try:
            target_sr = _DEFAULT_TARGET_SR
            wav_array = _load_audio_to_24k(item[self.audio_column], target_sr=target_sr)
            audio_len_seconds = len(wav_array) / target_sr
            # Both bounds scale with the clip length, so by construction
            # min_len_sec <= max_len_sec <= audio_len_seconds / 2.
            min_len_sec = min(5.0, audio_len_seconds / 4.0)
            max_len_sec = min(15.0, audio_len_seconds / 2.0)
            if max_len_sec <= 0.1:
                return None
            prompt_len_sec = random.uniform(min_len_sec, max_len_sec)
            prompt_len_samples = int(prompt_len_sec * target_sr)
            start_sample = random.randint(0, len(wav_array) - prompt_len_samples)
            return [wav_array[start_sample : start_sample + prompt_len_samples]]
        except Exception as e:
            warnings.warn(f"Could not create voice prompt for item {idx}: {e}")
            return None


def _apply_silence_with_crossfade(
    wav: np.ndarray,
    *,
    sample_rate: int,
    pre_silence_sec: float = 0.25,
    pre_crossfade_sec: float = 0.25,
    post_crossfade_sec: float = 0.25,
    post_silence_sec: float = 0.75,
) -> np.ndarray:
    """Pad a 1-D waveform with silence and linear fades at both ends.

    Output layout (all float32):
    ``[pre-silence][fade-in head][unchanged middle][fade-out tail][post-silence]``.
    The fade-in covers the first ``pre_crossfade_sec`` of audio and the fade-out
    covers up to ``post_crossfade_sec`` of what remains. The input is flattened
    first. An empty input yields only the silences, or the empty input itself
    when both silences are zero-length.
    """
    wav = np.asarray(wav, dtype=np.float32).reshape(-1)
    start_sil_samples = int(round(pre_silence_sec * sample_rate))
    end_sil_samples = int(round(post_silence_sec * sample_rate))
    pre_crossfade_samples = int(round(pre_crossfade_sec * sample_rate))
    post_crossfade_samples = int(round(post_crossfade_sec * sample_rate))

    def _silence(num_samples: int) -> np.ndarray:
        return np.zeros(num_samples, dtype=np.float32)

    def _linear_fade(num_samples: int, start: float, end: float) -> np.ndarray:
        if num_samples <= 0:
            return np.zeros((0,), dtype=np.float32)
        return np.linspace(start, end, num_samples, endpoint=True, dtype=np.float32)

    total_len = wav.shape[0]
    if total_len == 0:
        silences = [
            _silence(n) for n in (start_sil_samples, end_sil_samples) if n > 0
        ]
        return np.concatenate(silences) if silences else wav

    start_len = min(pre_crossfade_samples, total_len)
    remaining_after_start = max(total_len - start_len, 0)
    end_len = min(post_crossfade_samples, remaining_after_start)
    middle_end_idx = total_len - end_len

    start_segment = wav[:start_len]
    middle_segment = wav[start_len:middle_end_idx]
    end_segment = wav[middle_end_idx:]

    start_crossfade = start_segment * _linear_fade(start_len, 0.0, 1.0)
    end_crossfade = end_segment * _linear_fade(end_segment.shape[0], 1.0, 0.0)

    pieces: List[np.ndarray] = []
    if start_sil_samples > 0:
        pieces.append(_silence(start_sil_samples))
    for segment in (start_crossfade, middle_segment, end_crossfade):
        if segment.size > 0:
            pieces.append(segment.astype(np.float32, copy=False))
    if end_sil_samples > 0:
        pieces.append(_silence(end_sil_samples))
    return np.concatenate(pieces)


def _load_audio_to_24k(
    audio: Union[str, np.ndarray, torch.Tensor, Dict[str, Any]],
    *,
    target_sr: int = 22050,
    augment_with_silence: bool = False,
) -> np.ndarray:
    """Load ``audio`` as a float32 numpy array at ``target_sr``.

    Despite the name, the default ``target_sr`` is 22050 Hz.

    Accepted inputs:
      * ``np.ndarray`` / ``torch.Tensor`` — assumed to already be at
        ``target_sr``; they are only cast, never resampled.
      * ``str`` — a file path decoded with librosa (mono, native rate), then
        resampled.
      * ``dict`` with ``"array"`` and ``"sampling_rate"`` keys — resampled.

    Args:
        augment_with_silence: If True, pass the result through
            ``_apply_silence_with_crossfade``.

    Raises:
        RuntimeError: a path is given but librosa is not installed.
        ValueError: unsupported input type.
    """
    if isinstance(audio, np.ndarray):
        wav_out = audio.astype(np.float32)
    elif isinstance(audio, torch.Tensor):
        wav_out = audio.detach().cpu().float().numpy()
    elif isinstance(audio, str):
        if librosa is None:
            raise RuntimeError(
                "librosa is required to load audio file paths. Please pip install librosa."
            )
        wav, sr = librosa.load(audio, sr=None, mono=True)
        wav_out = _resample_if_needed(wav, int(sr), target_sr)
    elif isinstance(audio, dict) and "array" in audio and "sampling_rate" in audio:
        arr = np.asarray(audio["array"], dtype=np.float32)
        sr = int(audio["sampling_rate"])
        wav_out = _resample_if_needed(arr, sr, target_sr)
    else:
        raise ValueError(f"Unsupported audio type: {type(audio)}")

    wav_out = np.asarray(wav_out, dtype=np.float32)
    if augment_with_silence:
        wav_out = _apply_silence_with_crossfade(wav_out, sample_rate=target_sr)
    return wav_out


@dataclass
class QWEN3VoxCollator:
    """Collate dataset examples into padded model inputs.

    For each example, the processor encodes the text (plus the voice prompts
    unless they are randomly dropped). The target waveform is loaded at
    22050 Hz with silence/crossfade padding. The sequence is then extended with
    ``target_latent_len`` copies of ``speech_diffusion_id``, then
    ``speech_end_id``, then an EOS id if the tokenizer exposes a non-negative
    one. Sequences are right-padded to the longest in the batch.

    Speech segments are laid out per example as ``[its prompt segments...,
    its target]``, concatenated across examples. ``speeches_loss_input`` marks
    only the target segments.
    """

    processor: Any
    # Maximum token length; longer sequences are cut from the LEFT, and only
    # within the non-acoustic preamble (otherwise ValueError).
    max_length: Optional[int] = None
    # Samples per acoustic latent; used when the acoustic tokenizer cannot
    # report a latent length itself.
    speech_compress_ratio: int = 3200
    # Feature width that semantic features are padded/truncated to.
    semantic_vae_dim: int = 128
    # Must be True (with processor.semantic_tokenizer set) whenever there is
    # speech; otherwise __call__ raises RuntimeError.
    compute_semantics: bool = False
    debug_checks: bool = False
    text_field: str = "text"
    audio_field: str = "audio"
    voice_prompts_field: str = "voice_prompts"
    # Probability of dropping an example's voice prompts (clamped to [0, 1]).
    voice_prompt_drop_rate: float = 0.0

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
        sample_input_ids: List[List[int]] = []
        sample_attention_masks: List[List[int]] = []
        sample_acoustic_input_masks: List[List[bool]] = []
        sample_acoustic_loss_masks: List[List[bool]] = []
        all_speech_waveforms: List[np.ndarray] = []
        all_speech_latent_lengths: List[int] = []
        per_segment_is_target: List[bool] = []

        drop_rate = min(max(self.voice_prompt_drop_rate, 0.0), 1.0)

        for ex in features:
            text: str = ex.get(self.text_field, "")
            voice_prompts: Optional[List[Union[str, np.ndarray, torch.Tensor]]] = (
                ex.get(self.voice_prompts_field)
            )
            target_audio: Union[str, np.ndarray, torch.Tensor, Dict[str, Any]] = ex.get(
                self.audio_field
            )

            # random.random() is only drawn when prompts exist (short-circuit).
            keep_prompts = voice_prompts is not None and random.random() >= drop_rate
            proc = self.processor(
                text=[text],
                voice_samples=[voice_prompts] if keep_prompts else None,
                padding=False,
                truncation=False,
                max_length=self.max_length,
                return_tensors="pt",
            )

            ids = proc["input_ids"][0].tolist()
            attn = proc.get("attention_mask", torch.ones_like(proc["input_ids"]))[
                0
            ].tolist()
            speech_input_mask = proc.get("speech_input_mask")
            if speech_input_mask is None:
                speech_input_mask = torch.zeros_like(proc["input_ids"], dtype=torch.bool)
            speech_input_mask_list = speech_input_mask[0].tolist()

            wav_target = _load_audio_to_24k(
                target_audio, target_sr=_DEFAULT_TARGET_SR, augment_with_silence=True
            )
            target_latent_len = self._target_latent_len(wav_target)

            tokenizer = self.processor.tokenizer
            ids_extended = ids + [tokenizer.speech_diffusion_id] * target_latent_len
            attn_extended = attn + [1] * target_latent_len
            acoustic_input_mask = speech_input_mask_list + [True] * target_latent_len
            acoustic_loss_mask = [False] * len(speech_input_mask_list) + [
                True
            ] * target_latent_len

            trailing_ids = [tokenizer.speech_end_id]
            eos_token_id = getattr(tokenizer, "eos_id", None)
            if eos_token_id is None:
                eos_token_id = getattr(tokenizer, "eos_token_id", None)
            if eos_token_id is not None and eos_token_id >= 0:
                trailing_ids.append(eos_token_id)
            for token_id in trailing_ids:
                ids_extended.append(token_id)
                attn_extended.append(1)
                acoustic_input_mask.append(False)
                acoustic_loss_mask.append(False)

            if self.max_length is not None and len(ids_extended) > self.max_length:
                cut = len(ids_extended) - int(self.max_length)
                # Number of tokens before the first acoustic position; only these
                # may be dropped without corrupting speech alignment.
                leading_non_acoustic = next(
                    (i for i, is_acoustic in enumerate(acoustic_input_mask) if is_acoustic),
                    len(acoustic_input_mask),
                )
                if cut > leading_non_acoustic:
                    raise ValueError(
                        f"--max_length={self.max_length} would truncate into acoustic tokens. "
                        f"Needed cut={cut}, but only {leading_non_acoustic} leading "
                        "non-acoustic tokens available. Increase max_length or shorten "
                        "text/voice-prompt preamble."
                    )
                ids_extended = ids_extended[cut:]
                attn_extended = attn_extended[cut:]
                acoustic_input_mask = acoustic_input_mask[cut:]
                acoustic_loss_mask = acoustic_loss_mask[cut:]

            sample_input_ids.append(ids_extended)
            sample_attention_masks.append(attn_extended)
            sample_acoustic_input_masks.append(acoustic_input_mask)
            sample_acoustic_loss_masks.append(acoustic_loss_mask)

            if proc.get("speech_tensors") is not None:
                voice_np = proc["speech_tensors"].cpu().numpy()
                voice_masks = proc["speech_masks"].cpu().numpy().astype(bool)
                for seg_idx in range(voice_np.shape[0]):
                    all_speech_waveforms.append(voice_np[seg_idx])
                    all_speech_latent_lengths.append(int(voice_masks[seg_idx].sum()))
                    per_segment_is_target.append(False)
            all_speech_waveforms.append(wav_target)
            all_speech_latent_lengths.append(target_latent_len)
            per_segment_is_target.append(True)

        max_seq_len = max(len(x) for x in sample_input_ids)
        pad_token_id = self._pad_token_id()

        padded_input_ids = []
        padded_attention_masks = []
        padded_acoustic_input_masks = []
        padded_acoustic_loss_masks = []
        for ids, attn, ain_mask, aloss_mask in zip(
            sample_input_ids,
            sample_attention_masks,
            sample_acoustic_input_masks,
            sample_acoustic_loss_masks,
        ):
            pad_len = max_seq_len - len(ids)
            padded_input_ids.append(ids + [pad_token_id] * pad_len)
            padded_attention_masks.append(attn + [0] * pad_len)
            padded_acoustic_input_masks.append(ain_mask + [False] * pad_len)
            padded_acoustic_loss_masks.append(aloss_mask + [False] * pad_len)

        input_ids_tensor = torch.tensor(padded_input_ids, dtype=torch.long)
        attention_mask_tensor = torch.tensor(padded_attention_masks, dtype=torch.long)
        acoustic_input_mask_tensor = torch.tensor(
            padded_acoustic_input_masks, dtype=torch.bool
        )
        acoustic_loss_mask_tensor = torch.tensor(
            padded_acoustic_loss_masks, dtype=torch.bool
        )

        if all_speech_waveforms:
            max_wave_len = max(w.shape[0] for w in all_speech_waveforms)
            padded_speeches = np.zeros(
                (len(all_speech_waveforms), max_wave_len), dtype=np.float32
            )
            for i, w in enumerate(all_speech_waveforms):
                padded_speeches[i, : w.shape[0]] = w

            max_latent_len = (
                max(all_speech_latent_lengths) if all_speech_latent_lengths else 1
            )
            speech_masks_np = np.zeros(
                (len(all_speech_waveforms), max_latent_len), dtype=np.bool_
            )
            for i, latent_len in enumerate(all_speech_latent_lengths):
                speech_masks_np[i, :latent_len] = True

            speech_tensors_tensor = torch.tensor(padded_speeches, dtype=torch.float32)
            speech_masks_tensor = torch.tensor(speech_masks_np, dtype=torch.bool)

            # Only target segments contribute to the diffusion loss.
            speeches_loss_input_np = np.zeros_like(speech_masks_np, dtype=np.bool_)
            for i, is_target in enumerate(per_segment_is_target):
                if is_target:
                    speeches_loss_input_np[i] = speech_masks_np[i]
            speeches_loss_input_tensor = torch.tensor(
                speeches_loss_input_np, dtype=torch.bool
            )

            if (
                self.compute_semantics
                and hasattr(self.processor, "semantic_tokenizer")
                and self.processor.semantic_tokenizer is not None
            ):
                speech_semantic_tensors = self._semantic_features(
                    all_speech_waveforms, max_latent_len
                )
            else:
                # NOTE(review): with the default compute_semantics=False this
                # always raises whenever the batch is non-empty.
                raise RuntimeError(
                    "Semantic features are required but could not be computed. "
                    "Ensure processor.semantic_tokenizer is available or precompute "
                    "and provide features."
                )
        else:
            speech_tensors_tensor = None
            speech_masks_tensor = None
            speeches_loss_input_tensor = None
            speech_semantic_tensors = None

        if self.debug_checks:
            assert (input_ids_tensor >= 0).all(), "input_ids contains negative indices"
            if speech_tensors_tensor is not None:
                assert (
                    speech_tensors_tensor.dim() == 2
                ), "Expected speech_tensors 2D [segments, samples]"

        return {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_tensor,
            "speech_tensors": speech_tensors_tensor,
            "speech_masks": speech_masks_tensor,
            "speech_semantic_tensors": speech_semantic_tensors,
            "acoustic_input_mask": acoustic_input_mask_tensor,
            "acoustic_loss_mask": acoustic_loss_mask_tensor,
            "speeches_loss_input": speeches_loss_input_tensor,
        }

    def _target_latent_len(self, wav_target: np.ndarray) -> int:
        """Latent length of the target: tokenizer-reported, else ceil(len / ratio), at least 1."""
        latent_len = self._latent_len_from_acoustic_tokenizer(wav_target)
        if latent_len is None:
            latent_len = max(
                1, int(math.ceil(len(wav_target) / float(self.speech_compress_ratio)))
            )
        return latent_len

    def _latent_len_from_acoustic_tokenizer(self, wav: np.ndarray) -> Optional[int]:
        """Best-effort: read the leading dim of ``processor.acoustic_tokenizer.encode(wav)``.

        If the output has no usable ``shape``, up to two levels of list/tuple
        nesting are unwrapped via element 0. Any failure, or a non-positive
        length, yields ``None``.
        NOTE(review): assumes the leading dim is time, not batch — confirm.
        """
        try:
            acoustic_tok = getattr(self.processor, "acoustic_tokenizer", None)
            if acoustic_tok is None or not hasattr(acoustic_tok, "encode"):
                return None
            enc_out = acoustic_tok.encode(wav)
        except Exception:
            return None

        latent_len = None
        try:
            if hasattr(enc_out, "shape") and len(getattr(enc_out, "shape", [])) >= 1:
                latent_len = int(enc_out.shape[0])
            else:
                cand = enc_out
                for _ in range(2):
                    if isinstance(cand, (list, tuple)) and len(cand) > 0:
                        cand = cand[0]
                if hasattr(cand, "shape") and len(getattr(cand, "shape", [])) >= 1:
                    latent_len = int(cand.shape[0])
        except Exception:
            return None
        if latent_len is not None and latent_len > 0:
            return latent_len
        return None

    def _pad_token_id(self) -> int:
        """Tokenizer pad id, falling back to the EOS id; raise if neither is valid."""
        tok = self.processor.tokenizer
        pad_token_id = getattr(tok, "pad_token_id", None)
        if pad_token_id is None or pad_token_id < 0:
            pad_token_id = getattr(tok, "eos_token_id", None)
        if pad_token_id is None or pad_token_id < 0:
            raise ValueError(
                "Tokenizer has no pad_token_id or eos_token_id; please set one or pass a valid pad id."
            )
        return pad_token_id

    def _semantic_features(
        self, waveforms: List[np.ndarray], max_latent_len: int
    ) -> torch.Tensor:
        """Encode each waveform into a float32 tensor [segments, max_latent_len, semantic_vae_dim].

        Tokenizer failures yield an all-zero row. The feature dim is zero-padded
        or truncated to ``semantic_vae_dim``, and the time dim to ``max_latent_len``.

        Raises:
            RuntimeError: the tokenizer output is not 2-D ``[T, D]``.
        """
        dim = self.semantic_vae_dim
        sem_feats: List[np.ndarray] = []
        for w in waveforms:
            try:
                sem = self.processor.semantic_tokenizer.encode(w)
                sem = np.asarray(sem, dtype=np.float32)
            except Exception:
                sem = np.zeros((0, dim), dtype=np.float32)
            if sem.ndim != 2:
                raise RuntimeError(
                    f"Semantic tokenizer returned unexpected shape {sem.shape}. Expect [T, D]."
                )
            num_frames, feat_dim = sem.shape
            if feat_dim < dim:
                sem = np.concatenate(
                    [sem, np.zeros((num_frames, dim - feat_dim), dtype=np.float32)], axis=1
                )
            elif feat_dim > dim:
                sem = sem[:, :dim]
            if num_frames < max_latent_len:
                sem = np.concatenate(
                    [sem, np.zeros((max_latent_len - num_frames, dim), dtype=np.float32)],
                    axis=0,
                )
            elif num_frames > max_latent_len:
                sem = sem[:max_latent_len]
            sem_feats.append(sem.astype(np.float32))
        return torch.tensor(np.stack(sem_feats, axis=0), dtype=torch.float32)


"""QWEN3Vox acoustic/semantic tokenizer, diffusion-head and model configurations."""
from typing import Dict, List, Optional, Tuple

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

logger = logging.get_logger(__name__)

# Tensor-parallel plan for the Qwen2 decoder layers, shared by the composite configs.
_QWEN2_BASE_MODEL_TP_PLAN = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}


def _copy_ratios(ratios: Optional[List[int]]) -> Optional[List[int]]:
    """Return a fresh list copy so configs never share (mutable) ratio lists."""
    return list(ratios) if ratios is not None else None


def _resolve_sub_config(value, factory, expected_cls, model_type: str, field_name: str):
    """Turn ``None`` / dict / config instance into a sub-config instance.

    ``None`` builds a default via ``factory()``. A dict is copied (the caller's
    dict is not mutated), its ``model_type`` is forced to ``model_type``, and it
    is passed to ``factory``. An ``expected_cls`` instance is returned as-is.

    Raises:
        TypeError: any other type.
    """
    if value is None:
        return factory()
    if isinstance(value, dict):
        return factory(**{**value, "model_type": model_type})
    if isinstance(value, expected_cls):
        return value
    raise TypeError(
        f"{field_name} must be None, a dict or a {expected_cls.__name__}; "
        f"got {type(value).__name__}"
    )


def _resolve_decoder_config(value, factory):
    """Turn ``None`` / dict / ``Qwen2Config`` into a ``Qwen2Config``.

    Raises:
        ValueError: a dict whose ``model_type`` is not ``"qwen2"``.
        TypeError: any other type.
    """
    if value is None:
        return factory()
    if isinstance(value, dict):
        if value.get("model_type", "") == "qwen2":
            return Qwen2Config(**value)
        raise ValueError(
            f"Unsupported decoder model type: {value.get('model_type', '')}"
        )
    if isinstance(value, Qwen2Config):
        return value
    raise TypeError(
        f"decoder_config must be None, a dict or a Qwen2Config; got {type(value).__name__}"
    )


class QWEN3VoxAcousticTokenizerConfig(PretrainedConfig):
    """Hyper-parameters of the acoustic (VAE) speech tokenizer.

    ``decoder_ratios`` defaults to a copy of ``encoder_ratios``. Ratio lists are
    always copied, so the shared default list is never exposed for mutation.
    """

    model_type = 'vibevoice_acoustic_tokenizer'

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0.5,
        std_dist_type: str = "gaussian",
        mixer_layer: str = "depthwise_conv",
        conv_norm: str = "none",
        pad_mode: str = "constant",
        disable_last_norm: bool = True,
        layernorm: str = "RMSNorm",
        layernorm_eps: float = 1e-05,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-06,
        weight_init_value: float = 0.01,
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = [8, 5, 5, 4, 2, 2],
        encoder_depths: str = "3-3-3-3-3-3-8",
        decoder_n_filters: int = 32,
        decoder_ratios: Optional[List[int]] = None,
        decoder_depths: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer
        self.encoder_n_filters = encoder_n_filters
        # Copy: the signature default is a mutable list shared across calls.
        self.encoder_ratios = _copy_ratios(encoder_ratios)
        self.encoder_depths = encoder_depths
        self.decoder_ratios = (
            _copy_ratios(decoder_ratios)
            if decoder_ratios is not None
            else _copy_ratios(self.encoder_ratios)
        )
        self.decoder_n_filters = decoder_n_filters
        self.decoder_depths = decoder_depths


class QWEN3VoxSemanticTokenizerConfig(PretrainedConfig):
    """Hyper-parameters of the semantic speech tokenizer (encoder only)."""

    model_type = 'vibevoice_semantic_tokenizer'

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0,
        std_dist_type: str = "none",
        mixer_layer: str = "depthwise_conv",
        conv_norm: str = "none",
        pad_mode: str = "constant",
        disable_last_norm: bool = True,
        layernorm: str = "RMSNorm",
        layernorm_eps: float = 1e-05,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-06,
        weight_init_value: float = 0.01,
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = [8, 5, 5, 4, 2, 2],
        encoder_depths: str = "3-3-3-3-3-3-8",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer
        self.encoder_n_filters = encoder_n_filters
        # Copy: the signature default is a mutable list shared across calls.
        self.encoder_ratios = _copy_ratios(encoder_ratios)
        self.encoder_depths = encoder_depths


class QWEN3VoxDiffusionHeadConfig(PretrainedConfig):
    """Hyper-parameters of the diffusion prediction head and its noise scheduler."""

    model_type = 'vibevoice_diffusion_head'

    def __init__(
        self,
        hidden_size=768,
        head_layers=4,
        head_ffn_ratio=3.0,
        rms_norm_eps=1e-05,
        latent_size=64,
        speech_vae_dim=None,
        prediction_type="v_prediction",
        diffusion_type="ddpm",
        ddpm_num_steps=1000,
        ddpm_num_inference_steps=30,
        ddpm_beta_schedule="cosine",
        ddpm_batch_mul=4,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.head_layers = head_layers
        self.head_ffn_ratio = head_ffn_ratio
        self.rms_norm_eps = rms_norm_eps
        self.latent_size = latent_size
        self.speech_vae_dim = speech_vae_dim
        self.prediction_type = prediction_type
        self.diffusion_type = diffusion_type
        self.ddpm_num_steps = ddpm_num_steps
        self.ddpm_num_inference_steps = ddpm_num_inference_steps
        self.ddpm_beta_schedule = ddpm_beta_schedule
        self.ddpm_batch_mul = ddpm_batch_mul
        super().__init__(**kwargs)


class QWEN3VoxConfig(PretrainedConfig):
    """Composite config: acoustic + semantic tokenizers, Qwen2 decoder, diffusion head.

    Each sub-config may be ``None`` (defaults), a dict (copied; ``model_type``
    forced), or an instance of the matching config class. Other types raise
    ``TypeError``.
    """

    model_type = 'vibevoice'
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig,
        "semantic_tokenizer_config": QWEN3VoxSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
        "diffusion_head_config": QWEN3VoxDiffusionHeadConfig,
    }
    base_model_tp_plan = dict(_QWEN2_BASE_MODEL_TP_PLAN)

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        diffusion_head_config=None,
        **kwargs,
    ):
        kwargs["_attn_implementation_autoset"] = False
        self.acoustic_tokenizer_config = _resolve_sub_config(
            acoustic_tokenizer_config,
            self.sub_configs["acoustic_tokenizer_config"],
            QWEN3VoxAcousticTokenizerConfig,
            'vibevoice_acoustic_tokenizer',
            "acoustic_tokenizer_config",
        )
        self.semantic_tokenizer_config = _resolve_sub_config(
            semantic_tokenizer_config,
            self.sub_configs["semantic_tokenizer_config"],
            QWEN3VoxSemanticTokenizerConfig,
            'vibevoice_semantic_tokenizer',
            "semantic_tokenizer_config",
        )
        self.decoder_config = _resolve_decoder_config(
            decoder_config, self.sub_configs["decoder_config"]
        )
        self.diffusion_head_config = _resolve_sub_config(
            diffusion_head_config,
            self.sub_configs["diffusion_head_config"],
            QWEN3VoxDiffusionHeadConfig,
            'vibevoice_diffusion_head',
            "diffusion_head_config",
        )
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, "vae_dim", 128)
        super().__init__(**kwargs)


class QWEN3VoxASRConfig(PretrainedConfig):
    """Composite ASR config (no diffusion head); text attributes proxy ``decoder_config``.

    NOTE(review): shares ``model_type = 'vibevoice'`` with ``QWEN3VoxConfig``.
    """

    model_type = 'vibevoice'
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig,
        "semantic_tokenizer_config": QWEN3VoxSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
    }
    base_model_tp_plan = dict(_QWEN2_BASE_MODEL_TP_PLAN)

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        **kwargs,
    ):
        kwargs["_attn_implementation_autoset"] = False
        self.acoustic_tokenizer_config = _resolve_sub_config(
            acoustic_tokenizer_config,
            self.sub_configs["acoustic_tokenizer_config"],
            QWEN3VoxAcousticTokenizerConfig,
            'vibevoice_acoustic_tokenizer',
            "acoustic_tokenizer_config",
        )
        self.semantic_tokenizer_config = _resolve_sub_config(
            semantic_tokenizer_config,
            self.sub_configs["semantic_tokenizer_config"],
            QWEN3VoxSemanticTokenizerConfig,
            'vibevoice_semantic_tokenizer',
            "semantic_tokenizer_config",
        )
        self.decoder_config = _resolve_decoder_config(
            decoder_config, self.sub_configs["decoder_config"]
        )
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, "vae_dim", 128)
        super().__init__(**kwargs)

    def get_text_config(self, decoder: bool = False):
        """Return the Qwen2 decoder config (``decoder`` is ignored)."""
        return self.decoder_config

    @property
    def vocab_size(self):
        return self.decoder_config.vocab_size

    @property
    def num_attention_heads(self):
        return self.decoder_config.num_attention_heads

    @property
    def num_key_value_heads(self):
        return self.decoder_config.num_key_value_heads

    @property
    def hidden_size(self):
        return self.decoder_config.hidden_size

    @property
    def num_hidden_layers(self):
        return self.decoder_config.num_hidden_layers

    @property
    def head_dim(self):
        return getattr(
            self.decoder_config,
            "head_dim",
            self.hidden_size // self.num_attention_heads,
        )


__all__ = [
    'QWEN3VoxAcousticTokenizerConfig',
    'QWEN3VoxSemanticTokenizerConfig',
    'QWEN3VoxDiffusionHeadConfig',
    'QWEN3VoxConfig',
    'QWEN3VoxASRConfig',
]

import asyncio
import time
from queue import Empty, Queue
from typing import TYPE_CHECKING, Optional

import torch
from transformers.generation import BaseStreamer


def _as_index(sample_idx) -> int:
    """Convert a tensor element or plain int into a Python int index."""
    return sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx


class AudioStreamer(BaseStreamer):
    """Route audio chunks into one ``queue.Queue`` per batch sample.

    ``put`` enqueues CPU copies of chunks; ``end`` enqueues ``stop_signal`` once
    per sample and marks it finished, after which further chunks for that
    sample are ignored. Consumers read via ``get_stream`` or by iterating the
    streamer. The sentinel is detected by identity (``is``), so it must be the
    exact ``stop_signal`` object.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        self.batch_size = batch_size
        self.stop_signal = stop_signal
        self.timeout = timeout
        self.audio_queues = [Queue() for _ in range(batch_size)]
        self.finished_flags = [False for _ in range(batch_size)]
        self.sample_indices_map = {}

    def _is_open(self, idx: int) -> bool:
        """True if ``idx`` is a valid, not-yet-finished sample index."""
        return 0 <= idx < self.batch_size and not self.finished_flags[idx]

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Enqueue ``audio_chunks[i]`` for sample ``sample_indices[i]`` (tensor or int)."""
        for i, sample_idx in enumerate(sample_indices):
            idx = _as_index(sample_idx)
            if self._is_open(idx):
                audio_chunk = audio_chunks[i].detach().cpu()
                self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Send the stop signal to the given samples, or to all samples if ``None``."""
        if sample_indices is None:
            indices = range(self.batch_size)
        else:
            indices = [_as_index(s) for s in sample_indices]
        for idx in indices:
            if self._is_open(idx):
                self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                self.finished_flags[idx] = True

    def __iter__(self):
        return AudioBatchIterator(self)

    def get_stream(self, sample_idx: int):
        """Return a blocking iterator over the chunks of a single sample."""
        if sample_idx >= self.batch_size:
            raise ValueError(
                f"Sample index {sample_idx} exceeds batch size {self.batch_size}"
            )
        return AudioSampleIterator(self, sample_idx)


class AudioSampleIterator:
    """Blocking iterator over one sample's queue; stops at the stop signal."""

    def __init__(self, streamer: AudioStreamer, sample_idx: int):
        self.streamer = streamer
        self.sample_idx = sample_idx

    def __iter__(self):
        return self

    def __next__(self):
        value = self.streamer.audio_queues[self.sample_idx].get(
            timeout=self.streamer.timeout
        )
        # Identity check: `==` against a tensor chunk is elementwise/ambiguous.
        if value is self.streamer.stop_signal:
            raise StopIteration()
        return value

class AudioBatchIterator:
    """Poll all active samples and yield ``{sample_idx: chunk}`` dicts.

    Samples that deliver their stop signal are dropped from the active set.
    Iteration stops once no samples remain. While nothing is available, the
    iterator sleeps 10 ms between polls; it loops rather than recursing, so it
    can wait indefinitely.
    """

    def __init__(self, streamer: AudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __iter__(self):
        return self

    def __next__(self):
        while self.active_samples:
            batch_chunks = {}
            finished = set()
            for idx in self.active_samples:
                try:
                    value = self.streamer.audio_queues[idx].get(block=False)
                except Empty:
                    continue
                if value is self.streamer.stop_signal:
                    finished.add(idx)
                else:
                    batch_chunks[idx] = value
            self.active_samples -= finished
            if batch_chunks:
                return batch_chunks
            if self.active_samples:
                time.sleep(0.01)
        raise StopIteration()


class AsyncAudioStreamer(AudioStreamer):
    """``AudioStreamer`` variant backed by ``asyncio.Queue``s.

    Must be constructed while an event loop is running; ``put``/``end`` hand
    items to that loop via ``call_soon_threadsafe``.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        super().__init__(batch_size, stop_signal, timeout)
        self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
        self.loop = asyncio.get_running_loop()

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        for i, sample_idx in enumerate(sample_indices):
            idx = _as_index(sample_idx)
            if self._is_open(idx):
                audio_chunk = audio_chunks[i].detach().cpu()
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, audio_chunk
                )

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [_as_index(s) for s in sample_indices]
        for idx in indices_to_end:
            if self._is_open(idx):
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, self.stop_signal
                )
                self.finished_flags[idx] = True

    async def get_stream(self, sample_idx: int):
        """Async generator over one sample's chunks; ends at the stop signal."""
        if sample_idx >= self.batch_size:
            raise ValueError(
                f"Sample index {sample_idx} exceeds batch size {self.batch_size}"
            )
        while True:
            value = await self.audio_queues[sample_idx].get()
            if value is self.streamer_stop_signal_marker():
                break
            yield value

    def streamer_stop_signal_marker(self):
        """Return the sentinel object that terminates a stream."""
        return self.stop_signal

    def __aiter__(self):
        return AsyncAudioBatchIterator(self)


class AsyncAudioBatchIterator:
    """Async counterpart of ``AudioBatchIterator``.

    Each step waits (up to ``streamer.timeout``) until at least one active
    sample has data. It loops on timeouts instead of recursing.
    """

    def __init__(self, streamer: AsyncAudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __aiter__(self):
        return self

    async def __anext__(self):
        while self.active_samples:
            tasks = {
                idx: asyncio.create_task(self._get_chunk(idx))
                for idx in self.active_samples
            }
            done, pending = await asyncio.wait(
                tasks.values(),
                return_when=asyncio.FIRST_COMPLETED,
                timeout=self.streamer.timeout,
            )
            # Cancelled Queue.get() calls leave unread items in the queue.
            for task in pending:
                task.cancel()

            batch_chunks = {}
            finished = set()
            for idx, task in tasks.items():
                if task not in done:
                    continue
                try:
                    value = await task
                except asyncio.CancelledError:
                    continue
                if value is self.streamer.stop_signal:
                    finished.add(idx)
                else:
                    batch_chunks[idx] = value
            self.active_samples -= finished
            if batch_chunks:
                return batch_chunks
        raise StopAsyncIteration()

    async def _get_chunk(self, idx):
        return await self.streamer.audio_queues[idx].get()


"""Tokenization classes for QWEN3Vox."""
from typing import List, Optional, Union

from transformers.utils import logging
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast

logger = logging.get_logger(__name__)


class QWEN3VoxTextTokenizer(Qwen2Tokenizer):
    """Slow Qwen2 tokenizer that reuses the vision special tokens as speech markers.

    ``<|vision_start|>`` / ``<|vision_end|>`` / ``<|vision_pad|>`` are
    registered as additional special tokens and exposed as
    ``speech_start_id`` / ``speech_end_id`` / ``speech_diffusion_id``.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        add_special_tokens=True,
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            errors=errors,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            add_special_tokens=add_special_tokens,
            **kwargs,
        )
        self._add_q3_sp_tok()

    def _add_q3_sp_tok(self):
        """Register the speech marker tokens, cache their ids, and return the number added."""
        special_tokens = {
            "additional_special_tokens": [
                "<|vision_start|>",
                "<|vision_end|>",
                "<|vision_pad|>",
            ]
        }
        num_added = self.add_special_tokens(special_tokens)
        self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
        self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
        self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
        self._eos_id = self.convert_tokens_to_ids("<|endoftext|>")
        return num_added

    @property
    def eos_id(self) -> int:
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        return self._speech_end_id

    @property
    def speech_diffusion_id(self) -> int:
        return self._speech_diffusion_id

    @property
    def pad_id(self) -> int:
        # NOTE(review): constant -100 here, whereas the fast tokenizer returns the
        # id of <|image_pad|> — confirm which value callers expect.
        return -100


class QWEN3VoxTextTokenizerFast(Qwen2TokenizerFast):
    """Fast Qwen2 tokenizer with the same speech marker tokens as ``QWEN3VoxTextTokenizer``.

    ``eos_id`` is the tokenizer's ``eos_token_id``; ``pad_id`` is the id of
    ``<|image_pad|>``.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )
        self._add_q3_sp_tok()

    def _add_q3_sp_tok(self):
        """Register the speech marker tokens, cache their ids, and return the number added."""
        special_tokens = {
            "additional_special_tokens": [
                "<|vision_start|>",
                "<|vision_end|>",
                "<|vision_pad|>",
            ]
        }
        num_added = self.add_special_tokens(special_tokens)
        self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
        self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
        self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
        self._eos_id = self.eos_token_id
        self._pad_id = self.convert_tokens_to_ids("<|image_pad|>")
        return num_added

    @property
    def eos_id(self) -> int:
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        return self._speech_end_id

    @property
    def speech_diffusion_id(self) -> int:
        return self._speech_diffusion_id

    @property
    def pad_id(self) -> int:
        return self._pad_id


QWEN3VoxASRTextTokenizerFast = QWEN3VoxTextTokenizerFast

__all__ = [
    'QWEN3VoxTextTokenizer',
    'QWEN3VoxTextTokenizerFast',
]

"Utilities for loading fine-tuned LoRA adapters and connector weights."
from dataclasses import dataclass from pathlib import Path from typing import Optional import torch import torch.nn as nn from transformers.utils import logging logger = logging.get_logger(__name__) @dataclass class _LoadReport: language_model: bool = False diffusion_head_lora: bool = False diffusion_head_full: bool = False acoustic_connector: bool = False semantic_connector: bool = False adapter_root: Optional[Path] = None class _DiffusionHeadForwardShim(nn.Module): def __init__(self, base: nn.Module): super().__init__() self.base = base def forward(self, *args, **kwargs): if len(args) >= 3: noisy_images, timesteps, condition = args[:3] else: noisy_images = kwargs.get("noisy_images") timesteps = kwargs.get("timesteps") condition = kwargs.get("condition") return self.base(noisy_images, timesteps, condition) def _resolve_adapter_root(checkpoint_path: Path) -> Path: if checkpoint_path.is_file(): checkpoint_path = checkpoint_path.parent if (checkpoint_path / "lora").exists(): return checkpoint_path / "lora" return checkpoint_path def _load_connector( module: Optional[nn.Module], path: Path, device: torch.device ) -> bool: if module is None or not path.exists(): return False state_dict = torch.load(path, map_location=device) missing, unexpected = module.load_state_dict(state_dict, strict=False) if missing: logger.warning(f"Connector load missing keys: {missing }") if unexpected: logger.warning(f"Connector load unexpected keys: {unexpected }") module.to(device) return True def _load_diffusion_head( model, adapter_root: Path, device: torch.device, report: _LoadReport ) -> None: diff_dir = adapter_root / "diffusion_head" adapter_config = diff_dir / "adapter_config.json" adapter_model = diff_dir / "adapter_model.bin" adapter_model_safetensors = diff_dir / "adapter_model.safetensors" try: from peft import PeftModel except ImportError as exc: raise RuntimeError( "peft is required to load diffusion head adapters but is not installed" ) from exc if adapter_config.exists() and 
( adapter_model.exists() or adapter_model_safetensors.exists() ): logger.info(f"Loading diffusion head LoRA from {diff_dir }") shim = _DiffusionHeadForwardShim(model.model.prediction_head) _peft_load = getattr(PeftModel, "from_pretrained") peft_head = _peft_load(shim, diff_dir) peft_head.to(device) model.model.prediction_head = peft_head report.diffusion_head_lora = True return full_path = diff_dir / "diffusion_head_full.bin" if not full_path.exists(): full_path = adapter_root / "diffusion_head_full.bin" if full_path.exists(): logger.info(f"Loading full diffusion head weights from {full_path }") state_dict = torch.load(full_path, map_location=device) missing, unexpected = model.model.prediction_head.load_state_dict( state_dict, strict=False ) if missing: logger.warning(f"Diffusion head load missing keys: {missing }") if unexpected: logger.warning(f"Diffusion head load unexpected keys: {unexpected }") model.model.prediction_head.to(device) report.diffusion_head_full = True def _load_language_model( model, adapter_root: Path, device: torch.device, report: _LoadReport ) -> None: config_file = adapter_root / "adapter_config.json" bin_file = adapter_root / "adapter_model.bin" safe_tensors_file = adapter_root / "adapter_model.safetensors" if not (config_file.exists() and (bin_file.exists() or safe_tensors_file.exists())): return try: from peft import PeftModel except ImportError as exc: raise RuntimeError( "peft is required to load language model adapters but is not installed" ) from exc logger.info(f"Loading language model LoRA from {adapter_root }") _peft_load = getattr(PeftModel, "from_pretrained") peft_lm = _peft_load(model.model.language_model, adapter_root) peft_lm.to(device) model.model.language_model = peft_lm if hasattr(model, "tie_weights"): try: model.tie_weights() except Exception as exc: logger.warning( f"Failed to retie weights after loading language LoRA: {exc }" ) report.language_model = True def load_lora_assets( model, checkpoint_dir: str, device: 
Optional[torch.device] = None ) -> _LoadReport: adapter_root = _resolve_adapter_root(Path(checkpoint_dir)) if not adapter_root.exists(): raise FileNotFoundError(f"Adapter directory not found: {adapter_root }") inferred_device = device or next(model.parameters()).device report = _LoadReport(adapter_root=adapter_root) _load_language_model(model, adapter_root, inferred_device, report) _load_diffusion_head(model, adapter_root, inferred_device, report) ac_path = adapter_root / "acoustic_connector" / "pytorch_model.bin" if _load_connector( getattr(model.model, "acoustic_connector", None), ac_path, inferred_device ): report.acoustic_connector = True se_path = adapter_root / "semantic_connector" / "pytorch_model.bin" if _load_connector( getattr(model.model, "semantic_connector", None), se_path, inferred_device ): report.semantic_connector = True if not any(report.__dict__.values()): logger.warning( "No adapter assets were loaded. Ensure the checkpoint directory is correct and contains LoRA weights." 
) return report import math from typing import List, Optional, Tuple, Union import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.utils import deprecate from diffusers.utils.torch_utils import randn_tensor from diffusers.schedulers.scheduling_utils import ( KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput, ) def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine" ): if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) elif alpha_transform_type == "cauchy": def alpha_bar_fn(t, gamma=1, mu=3): snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9) return 1 - 1 / (math.exp(snr) + 1.1) elif alpha_transform_type == "laplace": def alpha_bar_fn(t, mu=0, b=1): snr = mu - b * math.copysign(1, 0.5 - t) * math.log( 1 - 2 * abs(t - 0.5) * 0.98 ) return 1 - 1 / (math.exp(snr) + 1.02) else: raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type }") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) def rescale_zero_terminal_snr(betas): alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_bar_sqrt = alphas_cumprod.sqrt() alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() alphas_bar_sqrt -= alphas_bar_sqrt_T alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) alphas_bar = alphas_bar_sqrt**2 alphas = alphas_bar[1:] / alphas_bar[:-1] alphas = torch.cat([alphas_bar[0:1], alphas]) betas = 1 - alphas return betas class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): _compatibles = [e.name for e in 
KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, use_lu_lambdas: Optional[bool] = False, final_sigmas_type: Optional[str] = "zero", lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, rescale_betas_zero_snr: bool = False, ): if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: deprecation_message = f"algorithm_type {algorithm_type } is deprecated and will be removed in a future version. 
Choose from `dpmsolver++` or `sde-dpmsolver++` instead" deprecate( "algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message, ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace( beta_start, beta_end, num_train_timesteps, dtype=torch.float32 ) elif beta_schedule == "scaled_linear": self.betas = ( torch.linspace( beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32, ) ** 2 ) elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine": self.betas = betas_for_alpha_bar( num_train_timesteps, alpha_transform_type="cosine" ) elif beta_schedule == "cauchy": self.betas = betas_for_alpha_bar( num_train_timesteps, alpha_transform_type="cauchy" ) elif beta_schedule == "laplace": self.betas = betas_for_alpha_bar( num_train_timesteps, alpha_transform_type="laplace" ) else: raise NotImplementedError( f"{beta_schedule } is not implemented for {self .__class__ }" ) if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) if rescale_betas_zero_snr: self.alphas_cumprod[-1] = 2 ** (-24) self.alpha_t = torch.sqrt(self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 self.init_noise_sigma = 1.0 if algorithm_type not in [ "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++", ]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: raise NotImplementedError( f"{algorithm_type } is not implemented for {self .__class__ }" ) if solver_type not in ["midpoint", "heun"]: if solver_type in ["logrho", "bh1", "bh2"]: self.register_to_config(solver_type="midpoint") else: raise NotImplementedError( f"{solver_type } is not implemented for {self 
.__class__ }" ) if ( algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero" ): raise ValueError( f"`final_sigmas_type` {final_sigmas_type } is not supported for `algorithm_type` {algorithm_type }. Please choose `sigma_min` instead." ) self.num_inference_steps = None timesteps = np.linspace( 0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32 )[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 self._step_index = None self._begin_index = None self.sigmas = self.sigmas.to("cpu") @property def step_index(self): return self._step_index @property def begin_index(self): return self._begin_index def set_begin_index(self, begin_index: int = 0): self._begin_index = begin_index def set_timesteps( self, num_inference_steps: int = None, device: Union[str, torch.device] = None, timesteps: Optional[List[int]] = None, ): if num_inference_steps is None and timesteps is None: raise ValueError( "Must pass exactly one of `num_inference_steps` or `timesteps`." ) if num_inference_steps is not None and timesteps is not None: raise ValueError( "Can only pass one of `num_inference_steps` or `custom_timesteps`." 
) if timesteps is not None and self.config.use_karras_sigmas: raise ValueError( "Cannot use `timesteps` with `config.use_karras_sigmas = True`" ) if timesteps is not None and self.config.use_lu_lambdas: raise ValueError( "Cannot use `timesteps` with `config.use_lu_lambdas = True`" ) if timesteps is not None: timesteps = np.array(timesteps).astype(np.int64) else: clipped_idx = torch.searchsorted( torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped ) last_timestep = ( (self.config.num_train_timesteps - clipped_idx).numpy().item() ) if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, last_timestep - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = last_timestep // (num_inference_steps + 1) timesteps = ( (np.arange(0, num_inference_steps + 1) * step_ratio) .round()[::-1][:-1] .copy() .astype(np.int64) ) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps timesteps = ( np.arange(last_timestep, 0, -step_ratio) .round() .copy() .astype(np.int64) ) timesteps -= 1 else: raise ValueError( f"{self .config .timestep_spacing } is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) if self.config.use_karras_sigmas: sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras( in_sigmas=sigmas, num_inference_steps=num_inference_steps ) timesteps = np.array( [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] ).round() elif self.config.use_lu_lambdas: lambdas = np.flip(log_sigmas.copy()) lambdas = self._convert_to_lu( in_lambdas=lambdas, num_inference_steps=num_inference_steps ) sigmas = np.exp(lambdas) timesteps = np.array( [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] ).round() else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) if self.config.final_sigmas_type == "sigma_min": sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 elif self.config.final_sigmas_type == "zero": sigma_last = 0 else: raise ValueError( f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self .config .final_sigmas_type }" ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) self.timesteps = torch.from_numpy(timesteps).to( device=device, dtype=torch.int64 ) self.num_inference_steps = len(timesteps) self.model_outputs = [None] * self.config.solver_order self.lower_order_nums = 0 self._step_index = None self._begin_index = None self.sigmas = self.sigmas.to("cpu") def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) abs_sample = sample.abs() s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) s = torch.clamp(s, min=1, max=self.config.sample_max_value) s = s.unsqueeze(1) sample = torch.clamp(sample, -s, s) / s sample = sample.reshape(batch_size, channels, *remaining_dims) sample = 
sample.to(dtype) return sample def _sigma_to_t(self, sigma, log_sigmas): log_sigma = np.log(np.maximum(sigma, 1e-10)) dists = log_sigma - log_sigmas[:, np.newaxis] low_idx = ( np.cumsum(dists >= 0, axis=0) .argmax(axis=0) .clip(max=log_sigmas.shape[0] - 2) ) high_idx = low_idx + 1 low = log_sigmas[low_idx] high = log_sigmas[high_idx] w = (low - log_sigma) / (low - high) w = np.clip(w, 0, 1) t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t def _sigma_to_alpha_sigma_t(self, sigma): alpha_t = 1 / (sigma**2 + 1) ** 0.5 sigma_t = sigma * alpha_t return (alpha_t, sigma_t) def _convert_to_karras( self, in_sigmas: torch.Tensor, num_inference_steps ) -> torch.Tensor: if hasattr(self.config, "sigma_min"): sigma_min = self.config.sigma_min else: sigma_min = None if hasattr(self.config, "sigma_max"): sigma_max = self.config.sigma_max else: sigma_max = None sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 ramp = np.linspace(0, 1, num_inference_steps) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas def _convert_to_lu( self, in_lambdas: torch.Tensor, num_inference_steps ) -> torch.Tensor: lambda_min: float = in_lambdas[-1].item() lambda_max: float = in_lambdas[0].item() rho = 1.0 ramp = np.linspace(0, 1, num_inference_steps) min_inv_rho = lambda_min ** (1 / rho) max_inv_rho = lambda_max ** (1 / rho) lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return lambdas def convert_model_output( self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs ) -> torch.Tensor: timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) if sample is None: if len(args) > 1: sample = args[1] else: raise ValueError("missing `sample` as a required keyward argument") if timestep is not None: deprecate( "timesteps", 
"1.0.0", "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.prediction_type == "epsilon": if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( f"prediction_type given as {self .config .prediction_type } must be one of `epsilon`, `sample`, or `v_prediction` for the DPMSolverMultistepScheduler." ) if self.config.thresholding: x0_pred = self._threshold_sample(x0_pred) return x0_pred elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: if self.config.prediction_type == "epsilon": if self.config.variance_type in ["learned", "learned_range"]: epsilon = model_output[:, :3] else: epsilon = model_output elif self.config.prediction_type == "sample": sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = (sample - alpha_t * model_output) / sigma_t elif self.config.prediction_type == "v_prediction": sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = alpha_t * model_output + sigma_t * sample else: raise ValueError( f"prediction_type given as {self .config .prediction_type } must be one of `epsilon`, `sample`, or `v_prediction` for the DPMSolverMultistepScheduler." 
) if self.config.thresholding: sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = (sample - sigma_t * epsilon) / alpha_t x0_pred = self._threshold_sample(x0_pred) epsilon = (sample - alpha_t * x0_pred) / sigma_t return epsilon def dpm_solver_first_order_update( self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) if sample is None: if len(args) > 2: sample = args[2] else: raise ValueError(" missing `sample` as a required keyward argument") if timestep is not None: deprecate( "timesteps", "1.0.0", "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) if prev_timestep is not None: deprecate( "prev_timestep", "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) sigma_t, sigma_s = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s = torch.log(alpha_s) - torch.log(sigma_s) h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = ( sigma_t / sigma_s * sample - alpha_t * (torch.exp(-h) - 1.0) * model_output ) elif self.config.algorithm_type == "dpmsolver": x_t = ( alpha_t / alpha_s * sample - sigma_t * (torch.exp(h) - 1.0) * model_output ) elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None x_t = ( sigma_t / sigma_s * torch.exp(-h) * sample + alpha_t * (1 - torch.exp(-2.0 * h)) * model_output + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif 
self.config.algorithm_type == "sde-dpmsolver": assert noise is not None x_t = ( alpha_t / alpha_s * sample - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise ) return x_t def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.Tensor], *args, sample: torch.Tensor = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) if sample is None: if len(args) > 2: sample = args[2] else: raise ValueError(" missing `sample` as a required keyward argument") if timestep_list is not None: deprecate( "timestep_list", "1.0.0", "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) if prev_timestep is not None: deprecate( "prev_timestep", "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) sigma_t, sigma_s0, sigma_s1 = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], self.sigmas[self.step_index - 1], ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) m0, m1 = (model_output_list[-1], model_output_list[-2]) h, h_0 = (lambda_t - lambda_s0, lambda_s0 - lambda_s1) r0 = h_0 / h D0, D1 = (m0, 1.0 / r0 * (m0 - m1)) if self.config.algorithm_type == "dpmsolver++": if self.config.solver_type == "midpoint": x_t = ( sigma_t / sigma_s0 * sample - alpha_t * (torch.exp(-h) - 1.0) * D0 - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) elif 
self.config.solver_type == "heun": x_t = ( sigma_t / sigma_s0 * sample - alpha_t * (torch.exp(-h) - 1.0) * D0 + alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0) * D1 ) elif self.config.algorithm_type == "dpmsolver": if self.config.solver_type == "midpoint": x_t = ( alpha_t / alpha_s0 * sample - sigma_t * (torch.exp(h) - 1.0) * D0 - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( alpha_t / alpha_s0 * sample - sigma_t * (torch.exp(h) - 1.0) * D0 - sigma_t * ((torch.exp(h) - 1.0) / h - 1.0) * D1 ) elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None if self.config.solver_type == "midpoint": x_t = ( sigma_t / sigma_s0 * torch.exp(-h) * sample + alpha_t * (1 - torch.exp(-2.0 * h)) * D0 + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.solver_type == "heun": x_t = ( sigma_t / sigma_s0 * torch.exp(-h) * sample + alpha_t * (1 - torch.exp(-2.0 * h)) * D0 + alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.algorithm_type == "sde-dpmsolver": assert noise is not None if self.config.solver_type == "midpoint": x_t = ( alpha_t / alpha_s0 * sample - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - sigma_t * (torch.exp(h) - 1.0) * D1 + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise ) elif self.config.solver_type == "heun": x_t = ( alpha_t / alpha_s0 * sample - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise ) return x_t def multistep_dpm_solver_third_order_update( self, model_output_list: List[torch.Tensor], *args, sample: torch.Tensor = None, **kwargs, ) -> torch.Tensor: timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) if sample is 
None: if len(args) > 2: sample = args[2] else: raise ValueError(" missing`sample` as a required keyward argument") if timestep_list is not None: deprecate( "timestep_list", "1.0.0", "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) if prev_timestep is not None: deprecate( "prev_timestep", "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], self.sigmas[self.step_index - 1], self.sigmas[self.step_index - 2], ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) m0, m1, m2 = ( model_output_list[-1], model_output_list[-2], model_output_list[-3], ) h, h_0, h_1 = ( lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2, ) r0, r1 = (h_0 / h, h_1 / h) D0 = m0 D1_0, D1_1 = (1.0 / r0 * (m0 - m1), 1.0 / r1 * (m1 - m2)) D1 = D1_0 + r0 / (r0 + r1) * (D1_0 - D1_1) D2 = 1.0 / (r0 + r1) * (D1_0 - D1_1) if self.config.algorithm_type == "dpmsolver++": x_t = ( sigma_t / sigma_s0 * sample - alpha_t * (torch.exp(-h) - 1.0) * D0 + alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0) * D1 - alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5) * D2 ) elif self.config.algorithm_type == "dpmsolver": x_t = ( alpha_t / alpha_s0 * sample - sigma_t * (torch.exp(h) - 1.0) * D0 - sigma_t * ((torch.exp(h) - 1.0) / h - 1.0) * D1 - sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5) * D2 ) return x_t def 
index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps index_candidates = (schedule_timesteps == timestep).nonzero() if len(index_candidates) == 0: step_index = len(self.timesteps) - 1 elif len(index_candidates) > 1: step_index = index_candidates[1].item() else: step_index = index_candidates[0].item() return step_index def _init_step_index(self, timestep): if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) self._step_index = self.index_for_timestep(timestep) else: self._step_index = self._begin_index def step( self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator=None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if self.step_index is None: self._init_step_index(timestep) lower_order_final = self.step_index == len(self.timesteps) - 1 and ( self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) or self.config.final_sigmas_type == "zero" ) lower_order_second = ( self.step_index == len(self.timesteps) - 2 and self.config.lower_order_final and (len(self.timesteps) < 15) ) model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output sample = sample.to(torch.float32) if ( self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None ): noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32, ) elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: noise = variance_noise.to(device=model_output.device, 
dtype=torch.float32) else: noise = None if ( self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final ): prev_sample = self.dpm_solver_first_order_update( model_output, sample=sample, noise=noise ) elif ( self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second ): prev_sample = self.multistep_dpm_solver_second_order_update( self.model_outputs, sample=sample, noise=noise ) else: prev_sample = self.multistep_dpm_solver_third_order_update( self.model_outputs, sample=sample ) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 prev_sample = prev_sample.to(model_output.dtype) self._step_index += 1 if not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) def add_noise( self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype) sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype) timesteps = timesteps.to(original_samples.device) alpha_t = alpha_t[timesteps].flatten() while len(alpha_t.shape) < len(original_samples.shape): alpha_t = alpha_t.unsqueeze(-1) sigma_t = sigma_t[timesteps].flatten() while len(sigma_t.shape) < len(original_samples.shape): sigma_t = sigma_t.unsqueeze(-1) noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples def get_velocity( self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype) sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype) timesteps = timesteps.to(original_samples.device) alpha_t = alpha_t[timesteps].flatten() while len(alpha_t.shape) < len(original_samples.shape): alpha_t = alpha_t.unsqueeze(-1) sigma_t = sigma_t[timesteps].flatten() while len(sigma_t.shape) < len(original_samples.shape): sigma_t = 
sigma_t.unsqueeze(-1) velocity = alpha_t * noise - sigma_t * original_samples return velocity def __len__(self): return self.config.num_train_timesteps '\nProcessor class for QWEN3Vox models.\n' import os import json import warnings from typing import List, Optional, Union, Dict, Any import numpy as np import torch from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.utils import logging logger = logging.get_logger(__name__) class AudioNormalizer: def __init__(self, target_dB_FS: float = -25, eps: float = 1e-06): self.target_dB_FS = target_dB_FS self.eps = eps def tailor_dB_FS(self, audio: np.ndarray) -> tuple: rms = np.sqrt(np.mean(audio**2)) scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps) normalized_audio = audio * scalar return (normalized_audio, rms, scalar) def avoid_clipping( self, audio: np.ndarray, scalar: Optional[float] = None ) -> tuple: if scalar is None: max_val = np.max(np.abs(audio)) if max_val > 1.0: scalar = max_val + self.eps else: scalar = 1.0 return (audio / scalar, scalar) def __call__(self, audio: np.ndarray) -> np.ndarray: audio, _, _ = self.tailor_dB_FS(audio) audio, _ = self.avoid_clipping(audio) return audio class QWEN3VoxTokenizerProcessor(FeatureExtractionMixin): model_input_names = ["input_features"] def __init__( self, sampling_rate: int = 22050, normalize_audio: bool = True, target_dB_FS: float = -25, eps: float = 1e-06, **kwargs, ): super().__init__(**kwargs) self.sampling_rate = sampling_rate self.normalize_audio = normalize_audio if self.normalize_audio: self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps) else: self.normalizer = None self.feature_extractor_dict = { "sampling_rate": sampling_rate, "normalize_audio": normalize_audio, "target_dB_FS": target_dB_FS, "eps": eps, } def _ensure_mono(self, audio: np.ndarray) -> np.ndarray: if len(audio.shape) == 1: return audio elif len(audio.shape) == 2: if audio.shape[0] == 2: return np.mean(audio, axis=0) elif 
audio.shape[1] == 2: return np.mean(audio, axis=1) elif audio.shape[0] == 1: return audio.squeeze(0) elif audio.shape[1] == 1: return audio.squeeze(1) else: raise ValueError(f"Unexpected audio shape: {audio .shape }") else: raise ValueError(f"Audio should be 1D or 2D, got shape: {audio .shape }") def _process_single_audio( self, audio: Union[np.ndarray, List[float]] ) -> np.ndarray: if not isinstance(audio, np.ndarray): audio = np.array(audio, dtype=np.float32) else: audio = audio.astype(np.float32) audio = self._ensure_mono(audio) if self.normalize_audio and self.normalizer is not None: audio = self.normalizer(audio) return audio def __call__( self, audio: Union[ str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str] ] = None, sampling_rate: Optional[int] = None, return_tensors: Optional[str] = None, **kwargs, ): if audio is None: raise ValueError("Audio input is required") if sampling_rate is not None and sampling_rate != self.sampling_rate: logger.warning( f"Input sampling rate ({sampling_rate }) differs from expected sampling rate ({self .sampling_rate }). Please resample your audio." 
) if isinstance(audio, str): audio = self._load_audio_from_path(audio) is_batched = False elif isinstance(audio, list): if len(audio) == 0: raise ValueError("Empty audio list provided") if all((isinstance(item, str) for item in audio)): audio = [self._load_audio_from_path(path) for path in audio] is_batched = True else: is_batched = isinstance(audio[0], (np.ndarray, list)) else: is_batched = False if is_batched: processed_audio = [self._process_single_audio(a) for a in audio] else: processed_audio = [self._process_single_audio(audio)] if return_tensors == "pt": if len(processed_audio) == 1: input_features = ( torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1) ) else: input_features = torch.stack( [torch.from_numpy(a) for a in processed_audio] ).unsqueeze(1) elif return_tensors == "np": if len(processed_audio) == 1: input_features = processed_audio[0][np.newaxis, np.newaxis, :] else: input_features = np.stack(processed_audio)[:, np.newaxis, :] else: input_features = ( processed_audio[0] if len(processed_audio) == 1 else processed_audio ) outputs = {"audio": input_features} return outputs def _load_audio_from_path(self, audio_path: str) -> np.ndarray: file_ext = os.path.splitext(audio_path)[1].lower() if file_ext in [".wav", ".mp3", ".flac", ".m4a", ".ogg"]: import librosa audio_array, sr = librosa.load(audio_path, sr=self.sampling_rate, mono=True) return audio_array elif file_ext == ".pt": audio_tensor = torch.load(audio_path, map_location="cpu").squeeze() if isinstance(audio_tensor, torch.Tensor): audio_array = audio_tensor.numpy() else: audio_array = np.array(audio_tensor) return audio_array.astype(np.float32) elif file_ext == ".npy": audio_array = np.load(audio_path) return audio_array.astype(np.float32) else: raise ValueError( f"Unsupported file format: {file_ext }. 
Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy, .npz" ) def preprocess_audio( self, audio_path_or_array: Union[str, np.ndarray], normalize: Optional[bool] = None, ) -> np.ndarray: if isinstance(audio_path_or_array, str): audio_array = self._load_audio_from_path(audio_path_or_array) else: audio_array = np.array(audio_path_or_array, dtype=np.float32) original_normalize = self.normalize_audio if normalize is not None: self.normalize_audio = normalize try: processed = self._process_single_audio(audio_array) finally: self.normalize_audio = original_normalize return processed def to_dict(self) -> Dict[str, Any]: return self.feature_extractor_dict def save_audio( self, audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]], output_path: str = "output.wav", sampling_rate: Optional[int] = None, normalize: bool = False, batch_prefix: str = "audio_", ): if sampling_rate is None: sampling_rate = self.sampling_rate try: import soundfile as sf except ImportError: raise ImportError( "soundfile is required to save audio files. 
Install it with: pip install soundfile" ) if isinstance(audio, torch.Tensor): audio_np = audio.float().detach().cpu().numpy() elif isinstance(audio, np.ndarray): audio_np = audio elif isinstance(audio, list): if all((isinstance(a, torch.Tensor) for a in audio)): audio_np = [a.float().detach().cpu().numpy() for a in audio] else: audio_np = audio else: raise ValueError(f"Unsupported audio type: {type (audio )}") saved_paths = [] if isinstance(audio_np, list): output_dir = output_path os.makedirs(output_dir, exist_ok=True) for i, audio_item in enumerate(audio_np): audio_item = self._prepare_audio_for_save(audio_item, normalize) file_path = os.path.join(output_dir, f"{batch_prefix }{i }.wav") sf.write(file_path, audio_item, sampling_rate) saved_paths.append(file_path) elif len(audio_np.shape) >= 3: batch_size = audio_np.shape[0] if batch_size > 1: output_dir = output_path os.makedirs(output_dir, exist_ok=True) for i in range(batch_size): single_audio = audio_np[i] if len(single_audio.shape) > 1: if single_audio.shape[0] == 1: single_audio = single_audio.squeeze(0) single_audio = self._prepare_audio_for_save(single_audio, normalize) file_path = os.path.join(output_dir, f"{batch_prefix }{i }.wav") sf.write(file_path, single_audio, sampling_rate) saved_paths.append(file_path) else: audio_item = audio_np.squeeze() audio_item = self._prepare_audio_for_save(audio_item, normalize) sf.write(output_path, audio_item, sampling_rate) saved_paths.append(output_path) else: audio_item = self._prepare_audio_for_save(audio_np, normalize) sf.write(output_path, audio_item, sampling_rate) saved_paths.append(output_path) return saved_paths def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray: if len(audio.shape) > 1 and audio.shape[0] == 1: audio = audio.squeeze(0) if normalize: max_val = np.abs(audio).max() if max_val > 0: audio = audio / max_val return audio __all__ = [ 'QWEN3VoxTokenizerProcessor', "AudioNormalizer", ] import math import torch class 
UniformSampler: def __init__(self, timesteps=1000): self.timesteps = timesteps def sample(self, batch_size, device): return torch.randint(0, self.timesteps, (batch_size,), device=device) class LogitNormalSampler: def __init__(self, timesteps=1000, m=0, s=1): self.timesteps = timesteps timesteps = torch.linspace(0, 1, timesteps) logit = torch.log(timesteps / (1 - timesteps)) self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s**2) / ( s * math.sqrt(2 * math.pi) ) def sample(self, batch_size, device): return torch.multinomial(self.prob, batch_size, replacement=True).to(device) ' QWEN3Vox Streaming model configuration' from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging from transformers.models.qwen2.configuration_qwen2 import Qwen2Config logger = logging.get_logger(__name__) class QWEN3VoxStreamingConfig(PretrainedConfig): model_type = 'vibevoice_streaming' is_composition = True sub_configs = { "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig, "decoder_config": Qwen2Config, "diffusion_head_config": QWEN3VoxDiffusionHeadConfig, } base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } def __init__( self, acoustic_tokenizer_config=None, decoder_config=None, diffusion_head_config=None, tts_backbone_num_hidden_layers=20, **kwargs, ): kwargs["_attn_implementation_autoset"] = False if acoustic_tokenizer_config is None: self.acoustic_tokenizer_config = self.sub_configs[ "acoustic_tokenizer_config" ]() elif isinstance(acoustic_tokenizer_config, dict): acoustic_tokenizer_config["model_type"] = 'vibevoice_acoustic_tokenizer' self.acoustic_tokenizer_config = self.sub_configs[ "acoustic_tokenizer_config" ](**acoustic_tokenizer_config) elif isinstance(acoustic_tokenizer_config, 
QWEN3VoxAcousticTokenizerConfig): self.acoustic_tokenizer_config = acoustic_tokenizer_config if decoder_config is None: self.decoder_config = self.sub_configs["decoder_config"]() elif isinstance(decoder_config, dict): if decoder_config.get("model_type", "") == "qwen2": self.decoder_config = Qwen2Config(**decoder_config) else: raise ValueError( f"Unsupported decoder model type: {decoder_config .get ('model_type','')}" ) elif isinstance(decoder_config, (Qwen2Config,)): self.decoder_config = decoder_config if diffusion_head_config is None: self.diffusion_head_config = self.sub_configs["diffusion_head_config"]() elif isinstance(diffusion_head_config, dict): diffusion_head_config["model_type"] = 'vibevoice_diffusion_head' self.diffusion_head_config = self.sub_configs["diffusion_head_config"]( **diffusion_head_config ) elif isinstance(diffusion_head_config, QWEN3VoxDiffusionHeadConfig): self.diffusion_head_config = diffusion_head_config self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64) self.tts_backbone_num_hidden_layers = tts_backbone_num_hidden_layers super().__init__(**kwargs) __all__ = [ 'QWEN3VoxStreamingConfig' ] import math from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers.models.auto import AutoModel from transformers.modeling_utils import PreTrainedModel from transformers.activations import ACT2FN from transformers.utils import logging logger = logging.get_logger(__name__) class RMSNorm(nn.Module): def __init__( self, dim: int, eps: float = 1e-06, elementwise_affine=True, memory_efficient=False, ): super().__init__() self.dim = dim self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: self.weight = nn.Parameter(torch.ones(dim)) else: self.register_parameter("weight", None) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()).type_as(x) if 
self.weight is not None: output = output * self.weight return output def extra_repr(self) -> str: return f"dim={self .dim }, eps={self .eps }, elementwise_affine={self .elementwise_affine }" def modulate(x, shift, scale): return x * (1 + scale) + shift class TimestepEmbedder(nn.Module): def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size, bias=False), ACT2FN["silu"], nn.Linear(hidden_size, hidden_size, bias=False), ) self.frequency_embedding_size = frequency_embedding_size @staticmethod def timestep_embedding(t, dim, max_period=10000): half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat( [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 ) return embedding.to(t.dtype) def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) t_emb = self.mlp(t_freq) return t_emb class FeedForwardNetwork(nn.Module): def __init__(self, embed_dim, ffn_dim): super().__init__() self.embed_dim = embed_dim self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False) self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False) self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False) self.act_fn = ACT2FN["silu"] def forward(self, x): gate = self.gate_proj(x) up = self.up_proj(x) gate = self.act_fn(gate) return self.down_proj(gate * up) class HeadLayer(nn.Module): def __init__(self, embed_dim, ffn_dim, cond_dim, norm_eps=1e-05): super().__init__() self.embed_dim = embed_dim self.cond_dim = cond_dim self.ffn_dim = ffn_dim self.ffn = FeedForwardNetwork(self.embed_dim, self.ffn_dim) self.norm = RMSNorm(self.embed_dim, eps=norm_eps) self.adaLN_modulation = nn.Sequential( ACT2FN["silu"], nn.Linear(cond_dim, 3 * self.embed_dim, 
bias=False) ) def forward(self, x, c): shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1) x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn)) return x class FinalLayer(nn.Module): def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-05): super().__init__() self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False) self.linear = nn.Linear(hidden_size, output_size, bias=False) self.adaLN_modulation = nn.Sequential( ACT2FN["silu"], nn.Linear(cond_size, 2 * hidden_size, bias=False) ) def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return x class QWEN3VoxDiffusionHead(PreTrainedModel): config_class = QWEN3VoxDiffusionHeadConfig supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True def __init__(self, config): super().__init__(config) self.config = config self.cond_dim = config.hidden_size latent_size = config.latent_size self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False) self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False) self.t_embedder = TimestepEmbedder(self.cond_dim) ffn_dim = int(config.hidden_size * config.head_ffn_ratio) self.layers = nn.ModuleList( [ HeadLayer( embed_dim=config.hidden_size, ffn_dim=ffn_dim, cond_dim=self.cond_dim, norm_eps=config.rms_norm_eps, ) for _ in range(config.head_layers) ] ) self.final_layer = FinalLayer( hidden_size=config.hidden_size, output_size=latent_size, cond_size=self.cond_dim, norm_eps=config.rms_norm_eps, ) self.initialize_weights() def initialize_weights(self): nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) for layer in self.layers: nn.init.constant_(layer.adaLN_modulation[-1].weight, 0) nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 
nn.init.constant_(self.final_layer.linear.weight, 0) def forward(self, noisy_images, timesteps, condition): x = self.noisy_images_proj(noisy_images) t = self.t_embedder(timesteps) condition = self.cond_proj(condition) c = condition + t for layer in self.layers: x = layer(x, c) x = self.final_layer(x, c) return x AutoModel.register(QWEN3VoxDiffusionHeadConfig, QWEN3VoxDiffusionHead) __all__ = [ 'QWEN3VoxDiffusionHead' ] import math import typing as tp from functools import partial from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple, Union import copy import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from transformers.models.auto import AutoModel from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging from transformers.modeling_utils import PreTrainedModel from transformers.activations import ACT2FN logger = logging.get_logger(__name__) import os try: from apex.normalization.fused_layer_norm import fused_rms_norm_affine APEX_AVAILABLE = True logger.info("APEX FusedRMSNorm is available and will be used for optimization") if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0: APEX_AVAILABLE = False logger.warning( "APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0" ) except ImportError: APEX_AVAILABLE = False logger.warning("APEX FusedRMSNorm not available, using native implementation") class ConvLayerNorm(nn.LayerNorm): def __init__( self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs ): super().__init__(normalized_shape, **kwargs) def forward(self, x): x = x.transpose(1, 2) x = nn.functional.layer_norm( x.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps, ).type_as(x) x = x.transpose(1, 2) return x class RMSNorm(nn.Module): def __init__( self, dim: int, eps: float = 1e-05, elementwise_affine=True, weight_shape=None ): super().__init__() self.dim = dim self.eps = eps 
self.elementwise_affine = elementwise_affine if self.elementwise_affine: weight_shape = (dim,) if weight_shape is None else weight_shape self.weight = nn.Parameter(torch.ones(weight_shape)) else: self.register_parameter("weight", None) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()).type_as(x) if self.weight is not None: output = output * self.weight return output def extra_repr(self) -> str: return f"dim={self .dim }, eps={self .eps }, elementwise_affine={self .elementwise_affine }" class ConvRMSNorm(RMSNorm): def __init__( self, dim: int, eps: float = 1e-05, elementwise_affine=True, weight_shape=None ): super().__init__(dim, eps, elementwise_affine, weight_shape) def forward(self, x): x = x.transpose(1, 2) if not APEX_AVAILABLE or not self.elementwise_affine: output = self._norm(x.float()).type_as(x) if self.weight is not None: output = output * self.weight else: output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps) output = output.transpose(1, 2) return output CONV_NORMALIZATIONS = frozenset( [ "none", "weight_norm", "spectral_norm", "time_layer_norm", "layer_norm", "time_group_norm", ] ) def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module: assert norm in CONV_NORMALIZATIONS if norm == "weight_norm": return nn.utils.weight_norm(module) elif norm == "spectral_norm": return nn.utils.spectral_norm(module) else: return module def get_norm_module( module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs ) -> nn.Module: assert norm in CONV_NORMALIZATIONS if norm == "layer_norm": assert isinstance(module, nn.modules.conv._ConvNd) return ConvLayerNorm(module.out_channels, **norm_kwargs) elif norm == "time_group_norm": if causal: raise ValueError("GroupNorm doesn't support causal evaluation.") assert isinstance(module, nn.modules.conv._ConvNd) return nn.GroupNorm(1, module.out_channels, **norm_kwargs) else: 
return nn.Identity() def get_extra_padding_for_conv1d( x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 ) -> int: length = x.shape[-1] n_frames = (length - kernel_size + padding_total) / stride + 1 ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) return ideal_length - length def pad1d( x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = "zero", value: float = 0.0, ): length = x.shape[-1] padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) if mode == "reflect": max_pad = max(padding_left, padding_right) extra_pad = 0 if length <= max_pad: extra_pad = max_pad - length + 1 x = F.pad(x, (0, extra_pad)) padded = F.pad(x, paddings, mode, value) end = padded.shape[-1] - extra_pad return padded[..., :end] else: return F.pad(x, paddings, mode, value) def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) assert padding_left + padding_right <= x.shape[-1] end = x.shape[-1] - padding_right return x[..., padding_left:end] class NormConv1d(nn.Module): def __init__( self, *args, causal: bool = False, norm: str = "none", norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs, ): super().__init__() self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) self.norm_type = norm def forward(self, x): x = self.conv(x) x = self.norm(x) return x class NormConvTranspose1d(nn.Module): def __init__( self, *args, causal: bool = False, norm: str = "none", norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs, ): super().__init__() self.convtr = apply_parametrization_norm( nn.ConvTranspose1d(*args, **kwargs), norm ) self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) self.norm_type = norm def forward(self, x): x = self.convtr(x) x = self.norm(x) return x 
class QWEN3VoxTokenizerStreamingCache: def __init__(self): self.cache = {} def get( self, layer_id: str, sample_indices: torch.Tensor ) -> Optional[torch.Tensor]: states = [] max_length = 0 for idx in sample_indices.tolist(): key = (layer_id, idx) if key not in self.cache: return None state = self.cache[key] states.append(state) max_length = max(max_length, state.shape[-1]) if len(states) > 0 and states[0].dim() >= 2: padded_states = [] for state in states: if state.shape[-1] < max_length: pad_size = max_length - state.shape[-1] padded_state = F.pad(state, (pad_size, 0), mode="constant", value=0) padded_states.append(padded_state) else: padded_states.append(state) return torch.stack(padded_states, dim=0) else: return torch.stack(states, dim=0) def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor): for i, idx in enumerate(sample_indices.tolist()): key = (layer_id, idx) self.cache[key] = states[i].detach() def set_to_zero(self, sample_indices: torch.Tensor): for key in list(self.cache.keys()): layer_id, sample_idx = key if sample_idx in sample_indices.tolist(): cached_tensor = self.cache[key] self.cache[key] = torch.zeros_like(cached_tensor) def clear( self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None, ): if layer_id is None and sample_indices is None: self.cache.clear() elif layer_id is not None and sample_indices is None: keys_to_remove = [k for k in self.cache.keys() if k[0] == layer_id] for k in keys_to_remove: del self.cache[k] elif layer_id is not None and sample_indices is not None: for idx in sample_indices.tolist(): key = (layer_id, idx) self.cache.pop(key, None) class SConv1d(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1, groups: int = 1, bias: bool = True, causal: bool = False, norm: str = "none", norm_kwargs: tp.Dict[str, tp.Any] = {}, pad_mode: str = "reflect", ): super().__init__() self.conv = NormConv1d( in_channels, 
out_channels, kernel_size, stride, dilation=dilation, groups=groups, bias=bias, causal=causal, norm=norm, norm_kwargs=norm_kwargs, ) self.causal = causal self.pad_mode = pad_mode self.kernel_size = kernel_size self.dilation = dilation self.stride = stride self.in_channels = in_channels self.out_channels = out_channels self.context_size = (kernel_size - 1) * dilation - (stride - 1) self.padding_total = (kernel_size - 1) * dilation - (stride - 1) self._layer_id = None @property def layer_id(self): if self._layer_id is None: self._layer_id = f"sconv1d_{id (self )}" return self._layer_id def forward( self, x: torch.Tensor, cache: Optional[QWEN3VoxTokenizerStreamingCache] = None, sample_indices: Optional[torch.Tensor] = None, use_cache: bool = False, debug: bool = False, ) -> torch.Tensor: B, C, T = x.shape if not use_cache or cache is None: return self._forward_non_streaming(x, debug=debug) assert self.causal, "Streaming mode is only supported for causal convolutions" assert ( sample_indices is not None ), "sample_indices must be provided for streaming mode" assert len(sample_indices) == B, "sample_indices must match batch size" return self._forward_streaming(x, cache, sample_indices, debug) def _forward_streaming( self, x: torch.Tensor, cache: QWEN3VoxTokenizerStreamingCache, sample_indices: torch.Tensor, debug: bool = False, ) -> torch.Tensor: B, C, T = x.shape cached_states = cache.get(self.layer_id, sample_indices) if cached_states is None: if self.context_size > 0: cached_states = torch.zeros( B, C, self.context_size, device=x.device, dtype=x.dtype ) if debug: print( f"[DEBUG] Initialized cache with shape: {cached_states .shape }, context_size={self .context_size }" ) else: cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype) if debug: print(f"[DEBUG] No context needed (kernel_size=stride)") if cached_states.shape[2] > 0: input_with_context = torch.cat([cached_states, x], dim=2) else: input_with_context = x if debug: print( f"[DEBUG] Input shape: 
{x .shape }, Cache shape: {cached_states .shape }, Combined: {input_with_context .shape }" ) output = self.conv(input_with_context) if debug: print(f"[DEBUG] Output shape: {output .shape }") if self.context_size > 0: total_input_length = input_with_context.shape[2] if total_input_length >= self.context_size: new_cache_start = total_input_length - self.context_size new_cache = input_with_context[:, :, new_cache_start:] else: new_cache = input_with_context if debug: print(f"[DEBUG] New cache shape: {new_cache .shape }") cache.set(self.layer_id, sample_indices, new_cache) return output def _forward_non_streaming( self, x: torch.Tensor, debug: bool = False ) -> torch.Tensor: B, C, T = x.shape kernel_size = self.kernel_size stride = self.stride dilation = self.dilation padding_total = self.padding_total extra_padding = get_extra_padding_for_conv1d( x, kernel_size, stride, padding_total ) if debug: print( f"[DEBUG NON-STREAMING] Input shape: {x .shape }, padding_total={padding_total }, extra_padding={extra_padding }" ) if self.causal: if self.pad_mode == "constant": x = pad1d( x, (padding_total, extra_padding), mode=self.pad_mode, value=0 ) else: x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) else: padding_right = padding_total // 2 padding_left = padding_total - padding_right x = pad1d( x, (padding_left, padding_right + extra_padding), mode=self.pad_mode ) if debug: print(f"[DEBUG NON-STREAMING] After padding: {x .shape }") output = self.conv(x) if debug: print(f"[DEBUG NON-STREAMING] Output shape: {output .shape }") return output class SConvTranspose1d(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, causal: bool = False, norm: str = "none", trim_right_ratio: float = 1.0, norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True, ): super().__init__() self.convtr = NormConvTranspose1d( in_channels, out_channels, kernel_size, stride, causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias, ) 
self.causal = causal self.trim_right_ratio = trim_right_ratio assert ( self.causal or self.trim_right_ratio == 1.0 ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 self.kernel_size = kernel_size self.stride = stride self.in_channels = in_channels self.out_channels = out_channels self.padding_total = kernel_size - stride self.context_size = kernel_size - 1 self._layer_id = None @property def layer_id(self): if self._layer_id is None: self._layer_id = f"sconvtr1d_{id (self )}" return self._layer_id def forward( self, x: torch.Tensor, cache: Optional[QWEN3VoxTokenizerStreamingCache] = None, sample_indices: Optional[torch.Tensor] = None, use_cache: bool = False, debug: bool = False, ) -> torch.Tensor: B, C, T = x.shape if not use_cache or cache is None: return self._forward_non_streaming(x, debug=debug) assert ( sample_indices is not None ), "sample_indices must be provided for streaming mode" assert len(sample_indices) == B, "sample_indices must match batch size" return self._forward_streaming(x, cache, sample_indices, debug) def _forward_streaming( self, x: torch.Tensor, cache: QWEN3VoxTokenizerStreamingCache, sample_indices: torch.Tensor, debug: bool = False, ) -> torch.Tensor: B, C, T = x.shape cached_input = cache.get(self.layer_id, sample_indices) if cached_input is None: cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype) if debug: print(f"[DEBUG] Initialized empty cache for transposed conv") full_input = torch.cat([cached_input, x], dim=2) if debug: print( f"[DEBUG] Input shape: {x .shape }, Cache shape: {cached_input .shape }, Combined: {full_input .shape }" ) full_output = self.convtr(full_input) if debug: print(f"[DEBUG] Full transposed conv output shape: {full_output .shape }") if self.causal: padding_right = math.ceil(self.padding_total * self.trim_right_ratio) padding_left = self.padding_total - padding_right else: padding_right = self.padding_total 
// 2 padding_left = self.padding_total - padding_right if padding_left + padding_right > 0: full_output = unpad1d(full_output, (padding_left, padding_right)) if debug: print(f"[DEBUG] After unpadding: {full_output .shape }") if cached_input.shape[2] == 0: output = full_output else: expected_new_output = T * self.stride if full_output.shape[2] >= expected_new_output: output = full_output[:, :, -expected_new_output:] else: output = full_output if debug: print(f"[DEBUG] Final streaming output shape: {output .shape }") if full_input.shape[2] > self.context_size: new_cache = full_input[:, :, -self.context_size :] else: new_cache = full_input if debug: print(f"[DEBUG] New cache shape: {new_cache .shape }") cache.set(self.layer_id, sample_indices, new_cache) return output def _forward_non_streaming( self, x: torch.Tensor, debug: bool = False ) -> torch.Tensor: if debug: print(f"[DEBUG NON-STREAMING] Input shape: {x .shape }") y = self.convtr(x) if debug: print(f"[DEBUG NON-STREAMING] After transposed conv: {y .shape }") if self.causal: padding_right = math.ceil(self.padding_total * self.trim_right_ratio) padding_left = self.padding_total - padding_right else: padding_right = self.padding_total // 2 padding_left = self.padding_total - padding_right if padding_left + padding_right > 0: y = unpad1d(y, (padding_left, padding_right)) if debug: print(f"[DEBUG NON-STREAMING] Final output shape: {y .shape }") return y class FFN(nn.Module): def __init__(self, embed_dim, ffn_dim, bias=False): super().__init__() self.embed_dim = embed_dim self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias) self.gelu = ACT2FN["gelu"] self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias) def forward(self, x): x = self.linear1(x) x = self.gelu(x) x = self.linear2(x) return x class Convlayer(nn.Module): def __init__( self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True, pad_mode="zeros", norm="weight_norm", causal=True, ): super().__init__() self.conv = 
SConv1d( in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, groups=groups, bias=bias, pad_mode=pad_mode, norm=norm, causal=causal, ) def forward(self, x): return self.conv(x) class Block1D(nn.Module): def __init__( self, dim, kernel_size=7, drop_path=0.0, mixer_layer="conv", layer_scale_init_value=1e-06, **kwargs, ): super().__init__() if kwargs.get("layernorm", "LN") == "LN": self.norm = ConvLayerNorm(dim, eps=kwargs.get("eps", 1e-06)) self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get("eps", 1e-06)) elif kwargs.get("layernorm", "RMSNorm") == "RMSNorm": self.norm = ConvRMSNorm(dim, eps=kwargs.get("eps", 1e-06)) self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get("eps", 1e-06)) if mixer_layer == "conv": self.mixer = Convlayer( dim, dim, groups=kwargs.get("groups", 1), kernel_size=kernel_size, pad_mode=kwargs.get("pad_mode", "reflect"), norm=kwargs.get("norm", "none"), causal=kwargs.get("causal", True), bias=kwargs.get("bias", True), ) elif mixer_layer == "depthwise_conv": self.mixer = Convlayer( dim, dim, groups=dim, kernel_size=kernel_size, pad_mode=kwargs.get("pad_mode", "reflect"), norm=kwargs.get("norm", "none"), causal=kwargs.get("causal", True), bias=kwargs.get("bias", True), ) else: raise ValueError(f"Unsupported mixer layer: {mixer_layer }") self.ffn = FFN( dim, kwargs.get("ffn_expansion", 4) * dim, bias=kwargs.get("bias", False) ) self.drop_path = ( nn.Identity() if drop_path <= 0.0 else nn.modules.DropPath(drop_path) ) if layer_scale_init_value > 0: self.gamma = nn.Parameter( layer_scale_init_value * torch.ones(dim), requires_grad=True ) self.ffn_gamma = nn.Parameter( layer_scale_init_value * torch.ones(dim), requires_grad=True ) else: self.gamma = None self.ffn_gamma = None def forward(self, x): residual = x x = self.norm(x) x = self.mixer(x) if self.gamma is not None: x = x * self.gamma.unsqueeze(-1) x = residual + self.drop_path(x) residual = x x = self.ffn_norm(x) x = x.permute(0, 2, 1) x = self.ffn(x) x = x.permute(0, 2, 1) if 
self.ffn_gamma is not None: x = x * self.ffn_gamma.unsqueeze(-1) x = residual + self.drop_path(x) return x class TokenizerEncoder(nn.Module): def __init__(self, config): super().__init__() self.channels = config.channels self.dimension = config.dimension self.n_filters = config.n_filters self.ratios = list(reversed(config.ratios)) self.depths = config.depths self.n_residual_layers = getattr(config, "n_residual_layers", 1) self.hop_length = np.prod(self.ratios) self.causal = config.causal kernel_size = getattr(config, "kernel_size", 7) last_kernel_size = getattr(config, "last_kernel_size", 7) norm = getattr(config, "norm", "none") norm_params = getattr(config, "norm_params", {}) pad_mode = getattr(config, "pad_mode", "reflect") bias = getattr(config, "bias", True) layernorm = getattr(config, "layernorm", "LN") layernorm_eps = getattr(config, "layernorm_eps", 1e-06) layernorm_elementwise_affine = getattr( config, "layernorm_elementwise_affine", True ) drop_path_rate = getattr(config, "drop_path_rate", 0.0) mixer_layer = getattr(config, "mixer_layer", "conv") layer_scale_init_value = getattr(config, "layer_scale_init_value", 0) disable_last_norm = getattr(config, "disable_last_norm", False) if layernorm == "LN": norm_type = ConvLayerNorm elif layernorm == "RMSNorm": norm_type = partial( ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine ) else: raise ValueError(f"Unsupported norm type: {layernorm }") stem = nn.Sequential( SConv1d( self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias, ) ) self.downsample_layers = nn.ModuleList() self.downsample_layers.append(stem) for i in range(len(self.ratios)): in_ch = self.n_filters * 2**i out_ch = self.n_filters * 2 ** (i + 1) downsample_layer = nn.Sequential( SConv1d( in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias, ) ) 
self.downsample_layers.append(downsample_layer) layer_type = partial( Block1D, mixer_layer=mixer_layer, layernorm=layernorm, eps=layernorm_eps, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias, layer_scale_init_value=layer_scale_init_value, ) self.stages = nn.ModuleList() dp_rates = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths)) ] cur = 0 for i in range(len(self.depths)): in_ch = self.n_filters * 2**i stage = nn.Sequential( *[ layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i]) ] ) self.stages.append(stage) cur += self.depths[i] if not disable_last_norm: self.norm = norm_type(in_ch, eps=layernorm_eps) else: self.norm = nn.Identity() self.head = SConv1d( in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias, ) def forward_features( self, x, cache=None, sample_indices=None, use_cache=False, debug=False ): for i in range(len(self.depths)): for layer in self.downsample_layers[i]: if isinstance(layer, SConv1d): x = layer( x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) else: x = layer(x) for block in self.stages[i]: if ( hasattr(block, "mixer") and hasattr(block.mixer, "conv") and isinstance(block.mixer.conv, SConv1d) ): residual = x x = block.norm(x) x = block.mixer.conv( x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) if block.gamma is not None: x = x * block.gamma.unsqueeze(-1) x = residual + x residual = x x = block.ffn_norm(x) x = x.permute(0, 2, 1) x = block.ffn(x) x = x.permute(0, 2, 1) if block.ffn_gamma is not None: x = x * block.ffn_gamma.unsqueeze(-1) x = residual + x else: x = block(x) return self.norm(x) def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False): x = self.forward_features( x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) x = self.head( x, cache=cache, sample_indices=sample_indices, 
use_cache=use_cache, debug=debug, ) return x class TokenizerDecoder(nn.Module): def __init__(self, config): super().__init__() self.dimension = config.dimension self.channels = config.channels self.n_filters = config.n_filters self.ratios = config.ratios self.depths = config.depths self.n_residual_layers = getattr(config, "n_residual_layers", 1) self.hop_length = np.prod(self.ratios) self.causal = config.causal kernel_size = getattr(config, "kernel_size", 7) last_kernel_size = getattr(config, "last_kernel_size", 7) norm = getattr(config, "norm", "none") norm_params = getattr(config, "norm_params", {}) pad_mode = getattr(config, "pad_mode", "reflect") bias = getattr(config, "bias", True) layernorm = getattr(config, "layernorm", "LN") layernorm_eps = getattr(config, "layernorm_eps", 1e-06) trim_right_ratio = getattr(config, "trim_right_ratio", 1.0) layernorm_elementwise_affine = getattr( config, "layernorm_elementwise_affine", True ) drop_path_rate = getattr(config, "drop_path_rate", 0.0) mixer_layer = getattr(config, "mixer_layer", "conv") layer_scale_init_value = getattr(config, "layer_scale_init_value", 0) disable_last_norm = getattr(config, "disable_last_norm", False) if layernorm == "LN": norm_type = ConvLayerNorm elif layernorm == "RMSNorm": norm_type = partial( ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine ) else: raise ValueError(f"Unsupported norm type: {layernorm }") stem = nn.Sequential( SConv1d( self.dimension, self.n_filters * 2 ** (len(self.depths) - 1), kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias, ) ) self.upsample_layers = nn.ModuleList() self.upsample_layers.append(stem) for i in range(len(self.ratios)): in_ch = self.n_filters * 2 ** (len(self.depths) - 1 - i) out_ch = self.n_filters * 2 ** (len(self.depths) - 1 - i - 1) upsample_layer = nn.Sequential( SConvTranspose1d( in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], norm=norm, norm_kwargs=norm_params, 
bias=bias, causal=self.causal, trim_right_ratio=trim_right_ratio, ) ) self.upsample_layers.append(upsample_layer) layer_type = partial( Block1D, mixer_layer=mixer_layer, layernorm=layernorm, eps=layernorm_eps, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias, layer_scale_init_value=layer_scale_init_value, ) self.stages = nn.ModuleList() dp_rates = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths)) ] cur = 0 for i in range(len(self.depths)): in_ch = self.n_filters * 2 ** (len(self.depths) - 1 - i) stage = nn.Sequential( *[ layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i]) ] ) self.stages.append(stage) cur += self.depths[i] if not disable_last_norm: self.norm = norm_type(in_ch, eps=layernorm_eps) else: self.norm = nn.Identity() self.head = SConv1d( in_ch, self.channels, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias, ) def forward_features( self, x, cache=None, sample_indices=None, use_cache=False, debug=False ): for i in range(len(self.depths)): for layer in self.upsample_layers[i]: if isinstance(layer, (SConv1d, SConvTranspose1d)): x = layer( x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) else: x = layer(x) for block in self.stages[i]: if ( hasattr(block, "mixer") and hasattr(block.mixer, "conv") and isinstance(block.mixer.conv, SConv1d) ): residual = x x = block.norm(x) x = block.mixer.conv( x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) if block.gamma is not None: x = x * block.gamma.unsqueeze(-1) x = residual + x residual = x x = block.ffn_norm(x) x = x.permute(0, 2, 1) x = block.ffn(x) x = x.permute(0, 2, 1) if block.ffn_gamma is not None: x = x * block.ffn_gamma.unsqueeze(-1) x = residual + x else: x = block(x) return self.norm(x) def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False): x = self.forward_features( x, cache=cache, 
sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) x = self.head( x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) return x @dataclass class QWEN3VoxTokenizerEncoderOutput: mean: torch.Tensor std: Optional[Union[float, torch.Tensor]] = None def sample(self, dist_type="fix"): if dist_type == "fix": x = self.mean + self.std * torch.randn_like(self.mean) return (x, self.std) elif dist_type == "gaussian": batch_size = self.mean.size(0) value = self.std / 0.8 std = ( torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value ) while std.dim() < self.mean.dim(): std = std.unsqueeze(-1) x = self.mean + std * torch.randn_like(self.mean) return (x, std) else: return (self.mean, self.std) def kl(self): target = torch.zeros_like(self.mean) return F.mse_loss(self.mean, target, reduction="none") def mode(self): return self.mean class QWEN3VoxAcousticTokenizerModel(PreTrainedModel): config_class = QWEN3VoxAcousticTokenizerConfig base_model_prefix = 'vibevoice_acoustic_tokenizer' _supports_flash_attn_2 = True _supports_sdpa = True _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"] def __init__(self, config): super().__init__(config) self.register_buffer("fix_std", torch.tensor(config.fix_std), persistent=False) self.std_dist_type = getattr(config, "std_dist_type", "fix") if isinstance(config.encoder_depths, str): encoder_depths = [int(d) for d in config.encoder_depths.split("-")] else: encoder_depths = config.encoder_depths if config.decoder_depths is not None and isinstance(config.decoder_depths, str): decoder_depths = [int(d) for d in config.decoder_depths.split("-")] else: decoder_depths = list(reversed(encoder_depths)) encoder_config = copy.deepcopy(config) encoder_config.dimension = config.vae_dim encoder_config.n_filters = config.encoder_n_filters encoder_config.ratios = config.encoder_ratios encoder_config.depths = encoder_depths encoder_config.norm = config.conv_norm encoder_config.pad_mode = 
config.pad_mode encoder_config.bias = config.conv_bias encoder_config.layernorm_eps = config.layernorm_eps encoder_config.layernorm_elementwise_affine = ( config.layernorm_elementwise_affine ) encoder_config.mixer_layer = config.mixer_layer encoder_config.layer_scale_init_value = config.layer_scale_init_value encoder_config.disable_last_norm = config.disable_last_norm decoder_config = copy.deepcopy(config) decoder_config.dimension = config.vae_dim decoder_config.n_filters = config.decoder_n_filters decoder_config.ratios = config.decoder_ratios decoder_config.depths = decoder_depths decoder_config.norm = config.conv_norm decoder_config.pad_mode = config.pad_mode decoder_config.bias = config.conv_bias decoder_config.layernorm_eps = config.layernorm_eps decoder_config.layernorm_elementwise_affine = ( config.layernorm_elementwise_affine ) decoder_config.mixer_layer = config.mixer_layer decoder_config.layer_scale_init_value = config.layer_scale_init_value decoder_config.disable_last_norm = config.disable_last_norm self.encoder = TokenizerEncoder(encoder_config) self.decoder = TokenizerDecoder(decoder_config) self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, std=self.config.weight_init_value) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) elif isinstance(module, nn.Conv1d): nn.init.normal_(module.weight, std=self.config.weight_init_value) if module.bias is not None: nn.init.zeros_(module.bias) @torch.no_grad() def encode( self, audio, cache=None, sample_indices=None, use_cache=False, debug=False ): latents = self.encoder( audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) return QWEN3VoxTokenizerEncoderOutput( mean=latents.permute(0, 2, 1), std=self.fix_std ) @torch.no_grad() def sampling(self, encoder_output, dist_type=None): dist_type = dist_type or 
self.std_dist_type if dist_type == "fix": return encoder_output.sample(dist_type="fix") elif dist_type == "gaussian": return encoder_output.sample(dist_type="gaussian") else: raise ValueError( f"Unsupported dist_type: {dist_type }, expected 'fix' or 'gaussian'" ) @torch.no_grad() def decode( self, latents, cache=None, sample_indices=None, use_cache=False, debug=False ): if latents.shape[1] == self.config.vae_dim: pass else: latents = latents.permute(0, 2, 1) audio = self.decoder( latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) return audio def forward( self, audio, cache=None, sample_indices=None, use_cache=False, debug=False ): encoder_output = self.encode( audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) sampled_latents, _ = self.sampling(encoder_output) reconstructed = self.decode( sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) return (reconstructed, sampled_latents) class QWEN3VoxSemanticTokenizerModel(PreTrainedModel): config_class = QWEN3VoxSemanticTokenizerConfig base_model_prefix = 'vibevoice_semantic_tokenizer' _supports_flash_attn_2 = True _supports_sdpa = True _no_split_modules = ["TokenizerEncoder"] def __init__(self, config): super().__init__(config) if isinstance(config.encoder_depths, str): encoder_depths = [int(d) for d in config.encoder_depths.split("-")] else: encoder_depths = config.encoder_depths encoder_config = copy.deepcopy(config) encoder_config.dimension = config.vae_dim encoder_config.n_filters = config.encoder_n_filters encoder_config.ratios = config.encoder_ratios encoder_config.depths = encoder_depths encoder_config.norm = config.conv_norm encoder_config.pad_mode = config.pad_mode encoder_config.bias = config.conv_bias encoder_config.layernorm_eps = config.layernorm_eps encoder_config.layernorm_elementwise_affine = ( config.layernorm_elementwise_affine ) encoder_config.mixer_layer = config.mixer_layer 
encoder_config.layer_scale_init_value = config.layer_scale_init_value encoder_config.disable_last_norm = config.disable_last_norm self.encoder = TokenizerEncoder(encoder_config) self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, std=self.config.weight_init_value) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) elif isinstance(module, nn.Conv1d): nn.init.normal_(module.weight, std=self.config.weight_init_value) if module.bias is not None: nn.init.zeros_(module.bias) @torch.no_grad() def encode( self, audio, cache=None, sample_indices=None, use_cache=False, debug=False ): latents = self.encoder( audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) return QWEN3VoxTokenizerEncoderOutput(mean=latents.permute(0, 2, 1)) @torch.no_grad() def sampling(self, encoder_output, dist_type=None): return encoder_output.sample(dist_type="none") def forward( self, audio, cache=None, sample_indices=None, use_cache=False, debug=False ): encoder_output = self.encode( audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, ) sampled_latents, _ = self.sampling(encoder_output, dist_type="none") return (None, sampled_latents) AutoModel.register(QWEN3VoxAcousticTokenizerConfig, QWEN3VoxAcousticTokenizerModel) AutoModel.register(QWEN3VoxSemanticTokenizerConfig, QWEN3VoxSemanticTokenizerModel) __all__ = [ 'QWEN3VoxTokenizerStreamingCache', 'QWEN3VoxAcousticTokenizerModel', 'QWEN3VoxSemanticTokenizerModel', ] '\nProcessor class for QWEN3Vox ASR models.\n' import os import json import math import warnings from typing import List, Optional, Union, Dict, Any, Tuple import numpy as np import torch from transformers.tokenization_utils_base import BatchEncoding from transformers.utils import TensorType, logging logger = logging.get_logger(__name__) 
SYSTEM_PROMPT = "You are a helpful assistant that transcribes audio input into text output in JSON format."


class QWEN3VoxASRProcessor:
    """Turn raw audio into chat-formatted token ids plus padded speech tensors for ASR.

    The user turn holds ``<|speech_start|>``, then ``vae_tok_len`` copies of
    ``<|speech_pad|>``, then ``<|speech_end|>``, followed by a textual
    instruction. ``acoustic_input_mask`` marks the pad positions.
    """

    def __init__(
        self,
        tokenizer=None,
        audio_processor=None,
        # NOTE(review): from_pretrained (and QWEN3VoxProcessor) default to 3200; confirm 320 is intended here.
        speech_tok_compress_ratio=320,
        target_sample_rate=22050,
        normalize_audio=True,
        **kwargs,
    ):
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor or QWEN3VoxTokenizerProcessor(
            sampling_rate=target_sample_rate, normalize_audio=normalize_audio
        )
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.target_sample_rate = target_sample_rate
        self.normalize_audio = normalize_audio
        self.audio_normalizer = AudioNormalizer() if normalize_audio else None
        self._cache_special_tokens()

    def _cache_special_tokens(self):
        """Resolve speech/pad token ids, preferring tokenizer attributes over vocab lookup."""
        if hasattr(self.tokenizer, "speech_start_id"):
            self.speech_start_id = self.tokenizer.speech_start_id
        else:
            self.speech_start_id = self.tokenizer.convert_tokens_to_ids(
                "<|speech_start|>"
            )
        if hasattr(self.tokenizer, "speech_end_id"):
            self.speech_end_id = self.tokenizer.speech_end_id
        else:
            self.speech_end_id = self.tokenizer.convert_tokens_to_ids("<|speech_end|>")
        if hasattr(self.tokenizer, "speech_pad_id"):
            self.speech_pad_id = self.tokenizer.speech_pad_id
        else:
            self.speech_pad_id = self.tokenizer.convert_tokens_to_ids("<|speech_pad|>")
        if hasattr(self.tokenizer, "pad_id"):
            self.pad_id = self.tokenizer.pad_id
        elif hasattr(self.tokenizer, "pad_token_id"):
            self.pad_id = self.tokenizer.pad_token_id
        else:
            self.pad_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load ``preprocessor_config.json`` (local dir or hub) and the text tokenizer."""
        import json
        from transformers.utils import cached_file

        config_path = os.path.join(
            pretrained_model_name_or_path, "preprocessor_config.json"
        )
        config = {}
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            try:
                config_file = cached_file(
                    pretrained_model_name_or_path, "preprocessor_config.json", **kwargs
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    config = json.load(f)
            except Exception as e:
                # Best effort: fall back to defaults when the config is unavailable.
                logger.warning(f"Could not load preprocessor_config.json: {e}")
                logger.warning("Using default configuration")
        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        target_sample_rate = config.get("target_sample_rate", 22050)
        normalize_audio = config.get("normalize_audio", True)
        model_name = pretrained_model_name_or_path
        logger.info(f"Loading tokenizer from {model_name}")
        tokenizer = QWEN3VoxASRTextTokenizerFast.from_pretrained(model_name, **kwargs)
        audio_processor = QWEN3VoxTokenizerProcessor(
            sampling_rate=target_sample_rate,
            normalize_audio=normalize_audio,
            target_dB_FS=config.get("target_dB_FS", -25),
            eps=config.get("eps", 1e-06),
        )
        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            target_sample_rate=target_sample_rate,
            normalize_audio=normalize_audio,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Write ``preprocessor_config.json`` so ``from_pretrained`` can round-trip it."""
        import json

        os.makedirs(save_directory, exist_ok=True)
        processor_config = {
            "processor_class": "QWEN3VoxASRProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "target_sample_rate": self.target_sample_rate,
            "normalize_audio": self.normalize_audio,
            # Persist the audio processor's actual settings (previously hard-coded).
            "target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25),
            "eps": getattr(self.audio_processor, "eps", 1e-06),
        }
        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(processor_config, f, indent=2)
        logger.info(f"Processor configuration saved in {config_path}")

    def __call__(
        self,
        audio: Optional[
            Union[
                str,
                np.ndarray,
                torch.Tensor,
                List[Union[str, np.ndarray, torch.Tensor]],
            ]
        ] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        padding: bool = True,
        max_length: Optional[int] = None,
        truncation: bool = False,
        add_generation_prompt: bool = True,
        use_streaming: bool = True,
        context_info: Optional[str] = None,
        **kwargs,
    ) -> BatchEncoding:
        """Process one audio input or a list of them into a ``BatchEncoding``.

        Args:
            audio: File path, array/tensor of samples, or a list of these.
            sampling_rate: Sample rate of array/tensor inputs. They are
                resampled to ``target_sample_rate`` when it differs. File
                inputs use the file's own rate.
            return_tensors: "pt" for torch tensors; otherwise lists/arrays.
            context_info: Optional extra text included in the instruction.

        Raises:
            ValueError: if ``audio`` is None.
        """
        if audio is None:
            raise ValueError("Audio input is required for ASR processing")
        audio_list = audio if isinstance(audio, list) else [audio]
        all_encodings = [
            self._process_single_audio(
                audio_input,
                sampling_rate=sampling_rate,
                add_generation_prompt=add_generation_prompt,
                use_streaming=use_streaming,
                context_info=context_info,
            )
            for audio_input in audio_list
        ]
        return self._batch_encode(
            all_encodings,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_tensors=return_tensors,
        )

    def _resample_to_target(self, audio_array, orig_sr):
        """Resample ``audio_array`` from ``orig_sr`` to ``target_sample_rate`` (no-op if equal/None)."""
        if orig_sr is None or int(orig_sr) == int(self.target_sample_rate):
            return audio_array
        import librosa

        return librosa.resample(
            y=audio_array, orig_sr=orig_sr, target_sr=self.target_sample_rate
        )

    def _process_single_audio(
        self,
        audio: Union[str, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        add_generation_prompt: bool = True,
        use_streaming: bool = True,
        context_info: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Load/resample/normalize one audio input and build its prompt token ids.

        Returns a dict with ``input_ids``, ``acoustic_input_mask`` (1 at speech-pad
        positions), ``speech`` (float32 samples) and ``vae_tok_len``.
        """
        if isinstance(audio, str):
            import soundfile as sf

            audio_array, file_sr = sf.read(audio)
            if audio_array.ndim > 1:
                # Down-mix multi-channel files (soundfile returns frames x channels).
                audio_array = audio_array.mean(axis=1)
            audio_array = self._resample_to_target(audio_array, file_sr)
        else:
            if isinstance(audio, torch.Tensor):
                # detach() so tensors requiring grad convert; float() handles bf16/fp16.
                audio_array = audio.detach().cpu().float().numpy()
            else:
                audio_array = np.array(audio, dtype=np.float32)
            if audio_array.ndim > 1:
                audio_array = audio_array.squeeze()
            # Previously `sampling_rate` was ignored for in-memory audio.
            audio_array = self._resample_to_target(audio_array, sampling_rate)
        audio_array = audio_array.astype(np.float32)
        if self.normalize_audio and self.audio_normalizer:
            audio_array = self.audio_normalizer(audio_array)

        audio_duration = len(audio_array) / self.target_sample_rate
        # NOTE(review): `use_streaming` is accepted for API compatibility but nothing here consumes it.
        vae_tok_len = math.ceil(len(audio_array) / self.speech_tok_compress_ratio)

        system_prompt_text = self.tokenizer.apply_chat_template(
            [{"role": "system", "content": SYSTEM_PROMPT}], tokenize=False
        )
        system_tokens = self.tokenizer.encode(system_prompt_text)
        sp_start_token = self.tokenizer.convert_ids_to_tokens(self.speech_start_id)
        sp_pad_token = self.tokenizer.convert_ids_to_tokens(self.speech_pad_id)
        sp_end_token = self.tokenizer.convert_ids_to_tokens(self.speech_end_id)
        show_keys = ["Start time", "End time", "Speaker ID", "Content"]
        if context_info and context_info.strip():
            user_suffix = (
                f"This is a {audio_duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\nPlease transcribe it with these keys: "
                + ", ".join(show_keys)
            )
        else:
            user_suffix = (
                f"This is a {audio_duration:.2f} seconds audio, please transcribe it with these keys: "
                + ", ".join(show_keys)
            )
        user_input_string = (
            sp_start_token + sp_pad_token * vae_tok_len + sp_end_token + "\n" + user_suffix
        )
        # NOTE(review): `add_generation_prompt` is not forwarded to apply_chat_template;
        # confirm whether the assistant header is added elsewhere before changing this.
        user_tokens = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": user_input_string}], tokenize=True
        )
        full_tokens = system_tokens + user_tokens
        acoustic_input_mask = [
            1 if token == self.speech_pad_id else 0 for token in full_tokens
        ]
        return {
            "input_ids": full_tokens,
            "acoustic_input_mask": acoustic_input_mask,
            "speech": audio_array,
            "vae_tok_len": vae_tok_len,
        }

    def _batch_encode(
        self,
        encodings: List[Dict[str, Any]],
        padding: bool = True,
        max_length: Optional[int] = None,
        truncation: bool = False,
        return_tensors: Optional[str] = None,
    ) -> BatchEncoding:
        """Left-pad token sequences, zero-pad speech, and build masks.

        A single un-tensorized encoding is unwrapped from its batch list.
        """
        input_ids_list = [enc["input_ids"] for enc in encodings]
        acoustic_masks_list = [enc["acoustic_input_mask"] for enc in encodings]
        speech_list = [enc["speech"] for enc in encodings]
        vae_tok_lens = [enc["vae_tok_len"] for enc in encodings]

        if padding:
            if max_length is not None:
                target_length = max_length
            else:
                target_length = max(len(ids) for ids in input_ids_list)
            padded_input_ids = []
            padded_acoustic_masks = []
            attention_masks = []
            for input_ids, acoustic_mask in zip(input_ids_list, acoustic_masks_list):
                if truncation and len(input_ids) > target_length:
                    input_ids = input_ids[:target_length]
                    acoustic_mask = acoustic_mask[:target_length]
                padding_length = target_length - len(input_ids)
                # Left padding, so generation continues from the last real token.
                padded_input_ids.append([self.pad_id] * padding_length + input_ids)
                padded_acoustic_masks.append([0] * padding_length + acoustic_mask)
                attention_masks.append([0] * padding_length + [1] * len(input_ids))
            input_ids_list = padded_input_ids
            acoustic_masks_list = padded_acoustic_masks
        else:
            attention_masks = [[1] * len(ids) for ids in input_ids_list]

        max_speech_length = max(len(s) for s in speech_list)
        padded_speeches = np.zeros(
            (len(speech_list), max_speech_length), dtype=np.float32
        )
        speech_masks = np.zeros((len(speech_list), max(vae_tok_lens)), dtype=bool)
        for i, (speech, vae_len) in enumerate(zip(speech_list, vae_tok_lens)):
            padded_speeches[i, : len(speech)] = speech
            speech_masks[i, :vae_len] = True

        batch_encoding = BatchEncoding()
        if return_tensors == "pt":
            batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
            batch_encoding["attention_mask"] = torch.tensor(
                attention_masks, dtype=torch.long
            )
            batch_encoding["acoustic_input_mask"] = torch.tensor(
                acoustic_masks_list, dtype=torch.bool
            )
            batch_encoding["speech_tensors"] = torch.tensor(
                padded_speeches, dtype=torch.float32
            )
            batch_encoding["speech_masks"] = torch.tensor(
                speech_masks, dtype=torch.bool
            )
        else:
            batch_encoding["input_ids"] = (
                input_ids_list if len(input_ids_list) > 1 else input_ids_list[0]
            )
            batch_encoding["attention_mask"] = (
                attention_masks if len(attention_masks) > 1 else attention_masks[0]
            )
            batch_encoding["acoustic_input_mask"] = (
                acoustic_masks_list
                if len(acoustic_masks_list) > 1
                else acoustic_masks_list[0]
            )
            batch_encoding["speech_tensors"] = (
                padded_speeches if len(padded_speeches) > 1 else padded_speeches[0]
            )
            batch_encoding["speech_masks"] = (
                speech_masks if len(speech_masks) > 1 else speech_masks[0]
            )
        return batch_encoding

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_transcription(self, text: str) -> List[Dict[str, Any]]:
        """Extract JSON segments from model output and normalise their keys.

        Looks for a ```json fenced block. Otherwise it parses the first JSON
        value starting at the earliest "[" or "{". Returns a list of dicts with
        keys among start_time/end_time/speaker_id/text; ``[]`` on any failure.
        """
        try:
            fence = "```json"
            fence_pos = text.find(fence)
            if fence_pos != -1:
                json_start = fence_pos + len(fence)
                json_end = text.find("```", json_start)
                if json_end == -1:
                    # Unterminated fence: previously text[start:-1] dropped the last char.
                    json_end = len(text)
                result = json.loads(text[json_start:json_end].strip())
            else:
                starts = [pos for pos in (text.find("["), text.find("{")) if pos != -1]
                if starts:
                    # raw_decode is string-aware, unlike a naive bracket counter.
                    result, _ = json.JSONDecoder().raw_decode(text, min(starts))
                else:
                    result = json.loads(text)
            if isinstance(result, dict):
                result = [result]
            key_mapping = {
                "Start time": "start_time",
                "Start": "start_time",
                "End time": "end_time",
                "End": "end_time",
                "Speaker ID": "speaker_id",
                "Speaker": "speaker_id",
                "Content": "text",
            }
            cleaned_result = []
            for item in result:
                if not isinstance(item, dict):
                    continue
                cleaned_item = {}
                for key, mapped_key in key_mapping.items():
                    if key in item:
                        cleaned_item[mapped_key] = item[key]
                if cleaned_item:
                    cleaned_result.append(cleaned_item)
            return cleaned_result
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse JSON from transcription: {e}")
            logger.debug(f"Raw text: {text}")
            return []
        except Exception as e:
            logger.warning(f"Error post-processing transcription: {e}")
            return []

    @property
    def model_input_names(self):
        return [
            "input_ids",
            "attention_mask",
            "acoustic_input_mask",
            "speech_tensors",
            "speech_masks",
        ]


__all__ = ['QWEN3VoxASRProcessor']

import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re
import numpy as np
import torch
from transformers.tokenization_utils_base import (
    BatchEncoding,
    PaddingStrategy,
    PreTokenizedInput,
    TextInput,
    TruncationStrategy,
)
from transformers.utils import TensorType, logging

logger = logging.get_logger(__name__)


class QWEN3VoxProcessor:
    """Builds TTS prompts: system prompt, voice prompts, speaker-tagged script, speech-start token."""

    def __init__(
        self,
        tokenizer=None,
        audio_processor=None,
        speech_tok_compress_ratio=3200,
        db_normalize=True,
        **kwargs,
    ):
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.db_normalize = db_normalize
        self.audio_normalizer = AudioNormalizer() if db_normalize else None
        self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load ``preprocessor_config.json`` (local dir or hub), tokenizer and audio processor."""
        import os
        import json
        from transformers.utils import cached_file

        config_path = os.path.join(
            pretrained_model_name_or_path, "preprocessor_config.json"
        )
        config = None
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            try:
                config_file = cached_file(
                    pretrained_model_name_or_path, "preprocessor_config.json", **kwargs
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    config = json.load(f)
            except Exception as e:
                logger.warning(
                    f"Could not load preprocessor_config.json from {pretrained_model_name_or_path}: {e}"
                )
                logger.warning("Using default configuration")
                config = {"speech_tok_compress_ratio": 3200, "db_normalize": True}
        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        db_normalize = config.get("db_normalize", True)
        model_name = pretrained_model_name_or_path
        logger.info(f"Loading tokenizer from {model_name}")
        tokenizer = QWEN3VoxTextTokenizerFast.from_pretrained(model_name, **kwargs)
        if "audio_processor" in config:
            audio_config = config["audio_processor"]
            audio_processor = QWEN3VoxTokenizerProcessor(
                sampling_rate=audio_config.get("sampling_rate", 22050),
                normalize_audio=audio_config.get("normalize_audio", True),
                target_dB_FS=audio_config.get("target_dB_FS", -25),
                eps=audio_config.get("eps", 1e-06),
            )
        else:
            audio_processor = QWEN3VoxTokenizerProcessor()
        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            db_normalize=db_normalize,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Write ``preprocessor_config.json`` including the nested audio-processor settings."""
        import os
        import json

        os.makedirs(save_directory, exist_ok=True)
        processor_config = {
            "processor_class": "QWEN3VoxProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "db_normalize": self.db_normalize,
            "audio_processor": {
                "feature_extractor_type": "QWEN3VoxTokenizerProcessor",
                "sampling_rate": getattr(self.audio_processor, "sampling_rate", 22050),
                "normalize_audio": getattr(
                    self.audio_processor, "normalize_audio", True
                ),
                "target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25),
                "eps": getattr(self.audio_processor, "eps", 1e-06),
            },
        }
        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(processor_config, f, indent=2)
        logger.info(f"Processor configuration saved in {config_path}")

    def __call__(
        self,
        text: Optional[
            Union[
                str,
                List[str],
                TextInput,
                PreTokenizedInput,
                List[TextInput],
                List[PreTokenizedInput],
            ]
        ] = None,
        voice_samples: Optional[
            Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]
        ] = None,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """Process one script (or a list of scripts) with optional per-script voice samples.

        Raises:
            ValueError: when a batch of texts and a batch of voice-sample lists
                differ in length (previously ``zip`` silently dropped inputs).
        """
        if isinstance(text, str) or (
            isinstance(text, list) and len(text) > 0 and (not isinstance(text[0], str))
        ):
            texts = [text]
            is_batched = False
        else:
            texts = text
            is_batched = True
        if voice_samples is not None:
            if not is_batched or isinstance(voice_samples[0], (str, np.ndarray)):
                voice_samples_list = [voice_samples]
            else:
                voice_samples_list = voice_samples
        else:
            voice_samples_list = [None] * len(texts)
        if len(voice_samples_list) != len(texts):
            raise ValueError(
                f"Got {len(texts)} text inputs but {len(voice_samples_list)} voice sample groups; "
                "provide one list of voice samples per text."
            )
        all_encodings = [
            self._process_single(text_input, voice_input)
            for text_input, voice_input in zip(texts, voice_samples_list)
        ]
        return self._batch_encode(
            all_encodings,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=return_attention_mask,
        )

    def _process_single(
        self,
        text: Union[str, TextInput],
        voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
    ) -> Dict[str, Any]:
        """Build token ids and speech mask for one script.

        ``text`` may be a path to an existing ``.json``/``.txt`` file or the script
        itself. Voice samples beyond the number of distinct speakers are dropped.

        Raises:
            ValueError: if ``text`` is not a string.
        """
        script = None
        if isinstance(text, str):
            if text.endswith(".json") and os.path.exists(text):
                script = self._convert_json_to_script(text)
            elif text.endswith(".txt") and os.path.exists(text):
                script = self._convert_text_to_script(text)
            else:
                script = text
        if script is None:
            raise ValueError(f"Could not process input text: {text}")
        parsed_lines = self._parse_script(script)
        all_speakers = list(set((speaker_id for speaker_id, _ in parsed_lines)))

        system_tokens = self.tokenizer.encode(self.system_prompt)
        if voice_samples:
            voice_tokens, voice_speech_inputs, voice_speech_masks = (
                self._create_voice_prompt(voice_samples[: len(all_speakers)])
            )
        else:
            voice_tokens, voice_speech_inputs, voice_speech_masks = ([], [], [])
        full_tokens = system_tokens + voice_tokens
        speech_input_mask = [False] * len(system_tokens) + voice_speech_masks

        # Encode the constant section headers once each.
        text_header = self.tokenizer.encode(" Text input:\n", add_special_tokens=False)
        full_tokens += text_header
        speech_input_mask += [False] * len(text_header)
        for speaker_id, speaker_text in parsed_lines:
            speaker_text_tokens = self.tokenizer.encode(
                f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False
            )
            full_tokens += speaker_text_tokens
            speech_input_mask += [False] * len(speaker_text_tokens)
        speech_header = self.tokenizer.encode(
            " Speech output:\n", add_special_tokens=False
        )
        full_tokens += speech_header + [self.tokenizer.speech_start_id]
        speech_input_mask += [False] * (len(speech_header) + 1)
        return {
            "input_ids": full_tokens,
            "speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
            "speech_input_mask": speech_input_mask,
            "parsed_script": parsed_lines,
            "all_speakers": all_speakers,
        }
def _batch_encode( self, encodings: List[Dict[str, Any]], padding: Union[bool, str, PaddingStrategy] = True, truncation: Union[bool, str, TruncationStrategy] = False, max_length: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_attention_mask: bool = True, ) -> BatchEncoding: input_ids_list = [enc["input_ids"] for enc in encodings] speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings] if isinstance(padding, bool): padding_strategy = ( PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD ) elif isinstance(padding, str): padding_strategy = PaddingStrategy(padding) else: padding_strategy = padding if padding_strategy != PaddingStrategy.DO_NOT_PAD: if padding_strategy == PaddingStrategy.LONGEST: max_len = max((len(ids) for ids in input_ids_list)) elif ( padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None ): max_len = max_length else: max_len = max((len(ids) for ids in input_ids_list)) padded_input_ids = [] attention_masks = [] padded_speech_input_masks = [] for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list): if truncation and len(input_ids) > max_len: input_ids = input_ids[:max_len] speech_mask = speech_mask[:max_len] padding_length = max_len - len(input_ids) padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids attention_mask = [0] * padding_length + [1] * len(input_ids) padded_speech_mask = [False] * padding_length + speech_mask padded_input_ids.append(padded_ids) attention_masks.append(attention_mask) padded_speech_input_masks.append(padded_speech_mask) input_ids_list = padded_input_ids speech_input_masks_list = padded_speech_input_masks else: attention_masks = ( [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None ) all_speech_inputs = [] has_speech = False for enc in encodings: if enc["speech_inputs"] is not None: all_speech_inputs.extend(enc["speech_inputs"]) has_speech = True batch_encoding = BatchEncoding() if 
return_tensors is not None: batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long) if return_attention_mask and attention_masks is not None: batch_encoding["attention_mask"] = torch.tensor( attention_masks, dtype=torch.long ) batch_encoding["speech_input_mask"] = torch.tensor( speech_input_masks_list, dtype=torch.bool ) else: batch_encoding["input_ids"] = input_ids_list if return_attention_mask and attention_masks is not None: batch_encoding["attention_mask"] = attention_masks batch_encoding["speech_input_mask"] = speech_input_masks_list if has_speech: speech_dict = self.prepare_speech_inputs( all_speech_inputs, return_tensors=return_tensors ) batch_encoding["speech_tensors"] = speech_dict["padded_speeches"] batch_encoding["speech_masks"] = speech_dict["speech_masks"] else: batch_encoding["speech_tensors"] = None batch_encoding["speech_masks"] = None batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings] batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings] return batch_encoding def _create_voice_prompt( self, speaker_samples: List[Union[str, np.ndarray]] ) -> Tuple[List[int], List[np.ndarray], List[bool]]: vae_token_id = self.tokenizer.speech_diffusion_id voice_full_tokens = self.tokenizer.encode( " Voice input:\n", add_special_tokens=False ) voice_speech_inputs = [] voice_speech_masks = [False] * len(voice_full_tokens) for speaker_id, speaker_audio in enumerate(speaker_samples): prefix_tokens = self.tokenizer.encode( f" Speaker {speaker_id }:", add_special_tokens=False ) if isinstance(speaker_audio, str): wav = self.audio_processor._load_audio_from_path(speaker_audio) elif isinstance(speaker_audio, dict): if "array" in speaker_audio: wav = np.array(speaker_audio["array"], dtype=np.float32) elif "audio" in speaker_audio: wav = np.array(speaker_audio["audio"], dtype=np.float32) else: raise ValueError( f"Dictionary audio input must have 'array' or 'audio' key, got: {speaker_audio .keys ()}" ) 
else: wav = np.array(speaker_audio, dtype=np.float32) if self.db_normalize and self.audio_normalizer: wav = self.audio_normalizer(wav) vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio) speaker_tokens = ( prefix_tokens + [self.tokenizer.speech_start_id] + [vae_token_id] * vae_tok_len + [self.tokenizer.speech_end_id] + self.tokenizer.encode("\n", add_special_tokens=False) ) vae_input_mask = ( [False] * len(prefix_tokens) + [False] + [True] * vae_tok_len + [False] + [False] ) voice_full_tokens.extend(speaker_tokens) voice_speech_masks.extend(vae_input_mask) voice_speech_inputs.append(wav) return (voice_full_tokens, voice_speech_inputs, voice_speech_masks) def prepare_speech_inputs( self, speech_inputs: List[np.ndarray], return_tensors: Optional[Union[str, TensorType]] = None, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, ) -> Dict[str, Any]: if not speech_inputs: return {"padded_speeches": None, "speech_masks": None} vae_tok_seqlens = [ math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs ] max_speech_length = max((s.shape[0] for s in speech_inputs)) if speech_inputs[0].ndim == 1: padded_speeches = np.full( (len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32 ) else: padded_speeches = np.full( (len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32, ) speech_masks = np.zeros( (len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_ ) for i, (speech, vae_tok_length) in enumerate( zip(speech_inputs, vae_tok_seqlens) ): padded_speeches[i, : len(speech)] = speech speech_masks[i, :vae_tok_length] = True result = {"padded_speeches": padded_speeches, "speech_masks": speech_masks} if return_tensors == "pt": result["padded_speeches"] = torch.tensor( padded_speeches, device=device, dtype=dtype or torch.float32 ) result["speech_masks"] = torch.tensor( speech_masks, device=device, dtype=torch.bool ) return result def 
_convert_json_to_script(self, json_file: str) -> str: import json with open(json_file, "r", encoding="utf-8") as f: data = json.load(f) if not isinstance(data, list): raise ValueError("JSON file must contain a list of speaker entries") script_lines = [] for item in data: if not isinstance(item, dict): logger.warning(f"Skipping non-dict entry: {item }") continue speaker = item.get("speaker") text = item.get("text") if speaker is None or text is None: logger.warning(f"Skipping entry missing speaker or text: {item }") continue try: speaker_id = int(speaker) except (ValueError, TypeError): logger.warning(f"Invalid speaker ID: {speaker }, skipping entry") continue text = text.strip() if text: script_lines.append(f"Speaker {speaker_id }: {text }") if not script_lines: raise ValueError("No valid entries found in JSON file") return "\n".join(script_lines) def _convert_text_to_script(self, text_file: str) -> str: with open(text_file, "r", encoding="utf-8") as f: lines = f.readlines() script_lines = [] current_speaker = 1 for line in lines: line = line.strip() if not line: continue speaker_match = re.match( "^Speaker\\s+(\\d+)\\s*:\\s*(.*)$", line, re.IGNORECASE ) if speaker_match: speaker_id = int(speaker_match.group(1)) text = speaker_match.group(2).strip() if text: script_lines.append(f"Speaker {speaker_id }: {text }") else: script_lines.append(f"Speaker {current_speaker }: {line }") if not script_lines: raise ValueError("No valid content found in text file") return "\n".join(script_lines) def _parse_script(self, script: str) -> List[Tuple[int, str]]: lines = script.strip().split("\n") parsed_lines = [] speaker_ids = [] for line in lines: if not line.strip(): continue match = re.match( "^Speaker\\s+(\\d+)\\s*:\\s*(.*)$", line.strip(), re.IGNORECASE ) if match: speaker_id = int(match.group(1)) text = " " + match.group(2).strip() parsed_lines.append((speaker_id, text)) speaker_ids.append(speaker_id) elif line.strip(): # Vocence validators send plain transcription (no 
"Speaker N:" prefix). parsed_lines.append((1, " " + line.strip())) speaker_ids.append(1) if not parsed_lines: raise ValueError("No valid speaker lines found in script") min_speaker_id = min(speaker_ids) if min_speaker_id > 0: normalized_lines = [] for speaker_id, text in parsed_lines: normalized_lines.append((speaker_id - 1, text)) return normalized_lines else: return parsed_lines def _merge_inputs( self, text_inputs: BatchEncoding, audio_inputs: Dict ) -> BatchEncoding: merged = BatchEncoding(text_inputs) if "audio" in audio_inputs: merged["speech_inputs"] = audio_inputs["audio"] if "streaming" in audio_inputs: merged["streaming"] = audio_inputs["streaming"] return merged def batch_decode(self, *args, **kwargs): return self.tokenizer.batch_decode(*args, **kwargs) def decode(self, *args, **kwargs): return self.tokenizer.decode(*args, **kwargs) @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names audio_processor_input_names = self.audio_processor.model_input_names return list( dict.fromkeys( tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"] ) ) def save_audio( self, audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]], output_path: str = "output.wav", sampling_rate: Optional[int] = None, normalize: bool = False, batch_prefix: str = "audio_", ) -> str: return self.audio_processor.save_audio( audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix, ) __all__ = [ 'QWEN3VoxProcessor' ] '\nQWEN3Vox Streaming Processor\n\nThis processor handles input preparation for the streaming 0.5B model,\nincluding text tokenization and cached voice prompt handling.\n' import math import warnings from typing import List, Optional, Union, Dict, Any, Tuple import os import re import numpy as np import torch from transformers.tokenization_utils_base import ( BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, 
TruncationStrategy, ) from transformers.utils import TensorType, logging logger = logging.get_logger(__name__) class QWEN3VoxStreamingProcessor: def __init__( self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs, ): self.tokenizer = tokenizer self.audio_processor = audio_processor self.speech_tok_compress_ratio = speech_tok_compress_ratio self.db_normalize = db_normalize self.audio_normalizer = AudioNormalizer() if db_normalize else None @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): import os import json from transformers.utils import cached_file config_path = os.path.join( pretrained_model_name_or_path, "preprocessor_config.json" ) config = None if os.path.exists(config_path): with open(config_path, "r") as f: config = json.load(f) else: try: config_file = cached_file( pretrained_model_name_or_path, "preprocessor_config.json", **kwargs ) with open(config_file, "r") as f: config = json.load(f) except Exception as e: logger.warning( f"Could not load preprocessor_config.json from {pretrained_model_name_or_path }: {e }" ) logger.warning("Using default configuration") config = {"speech_tok_compress_ratio": 3200, "db_normalize": True} speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200) db_normalize = config.get("db_normalize", True) model_name = pretrained_model_name_or_path logger.info(f"Loading tokenizer from {model_name }") tokenizer = QWEN3VoxTextTokenizerFast.from_pretrained( model_name, **kwargs ) if "audio_processor" in config: audio_config = config["audio_processor"] audio_processor = QWEN3VoxTokenizerProcessor( sampling_rate=audio_config.get("sampling_rate", 22050), normalize_audio=audio_config.get("normalize_audio", True), target_dB_FS=audio_config.get("target_dB_FS", -25), eps=audio_config.get("eps", 1e-06), ) else: audio_processor = QWEN3VoxTokenizerProcessor() return cls( tokenizer=tokenizer, audio_processor=audio_processor, 
speech_tok_compress_ratio=speech_tok_compress_ratio, db_normalize=db_normalize, ) def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): import os import json os.makedirs(save_directory, exist_ok=True) processor_config = { "processor_class": "QWEN3VoxStreamingProcessor", "speech_tok_compress_ratio": self.speech_tok_compress_ratio, "db_normalize": self.db_normalize, "audio_processor": { "feature_extractor_type": "QWEN3VoxTokenizerProcessor", "sampling_rate": getattr(self.audio_processor, "sampling_rate", 22050), "normalize_audio": getattr( self.audio_processor, "normalize_audio", True ), "target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25), "eps": getattr(self.audio_processor, "eps", 1e-06), }, } config_path = os.path.join(save_directory, "preprocessor_config.json") with open(config_path, "w") as f: json.dump(processor_config, f, indent=2) logger.info(f"Processor configuration saved in {config_path }") def __call__(self) -> BatchEncoding: raise NotImplementedError( 'QWEN3VoxStreamingProcessor.__call__ is not implemented. Use process_input_with_cached_prompt for streaming inputs.' 
) def process_input_with_cached_prompt( self, text: Optional[str] = None, cached_prompt: Optional[Dict[str, Any]] = None, padding: Union[bool, str, PaddingStrategy] = True, truncation: Union[bool, str, TruncationStrategy] = False, max_length: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_attention_mask: bool = True, **kwargs, ) -> BatchEncoding: texts = [text] cached_prompts = [cached_prompt] is_batched = False all_encodings = [] for text_input, cached_prompt_input in zip(texts, cached_prompts): script_tokens = self.tokenizer.encode( text_input.strip() + "\n", add_special_tokens=False ) input_id_length = cached_prompt_input["lm"]["last_hidden_state"].size(1) tts_lm_input_id_length = cached_prompt_input["tts_lm"][ "last_hidden_state" ].size(1) input_ids = [self.tokenizer.pad_id] * input_id_length tts_lm_input_ids = [self.tokenizer.pad_id] * tts_lm_input_id_length speech_input_mask = [False] * tts_lm_input_id_length encoding = { "input_ids": input_ids, "tts_lm_input_ids": tts_lm_input_ids, "tts_text_ids": script_tokens, "speech_inputs": None, "speech_input_mask": speech_input_mask, } all_encodings.append(encoding) batch_encoding = self._batch_encode( all_encodings, padding=padding, truncation=truncation, max_length=max_length, return_tensors=return_tensors, return_attention_mask=return_attention_mask, ) return batch_encoding def _batch_encode( self, encodings: List[Dict[str, Any]], padding: Union[bool, str, PaddingStrategy] = True, truncation: Union[bool, str, TruncationStrategy] = False, max_length: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_attention_mask: bool = True, ) -> BatchEncoding: input_ids_list = [enc["input_ids"] for enc in encodings] tts_lm_input_ids_list = [enc["tts_lm_input_ids"] for enc in encodings] tts_text_ids_list = [enc["tts_text_ids"] for enc in encodings] speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings] attention_masks = ( [[1] * len(ids) 
for ids in input_ids_list] if return_attention_mask else None ) tts_lm_attention_masks = ( [[1] * len(ids) for ids in tts_lm_input_ids_list] if return_attention_mask else None ) all_speech_inputs = [] has_speech = False for enc in encodings: if enc["speech_inputs"] is not None: all_speech_inputs.extend(enc["speech_inputs"]) has_speech = True batch_encoding = BatchEncoding() if return_tensors is not None: batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long) batch_encoding["tts_lm_input_ids"] = torch.tensor( tts_lm_input_ids_list, dtype=torch.long ) batch_encoding["tts_text_ids"] = torch.tensor( tts_text_ids_list, dtype=torch.long ) if return_attention_mask and attention_masks is not None: batch_encoding["attention_mask"] = torch.tensor( attention_masks, dtype=torch.long ) batch_encoding["tts_lm_attention_mask"] = torch.tensor( tts_lm_attention_masks, dtype=torch.long ) batch_encoding["speech_input_mask"] = torch.tensor( speech_input_masks_list, dtype=torch.bool ) else: batch_encoding["input_ids"] = input_ids_list batch_encoding["tts_lm_input_ids"] = tts_lm_input_ids_list batch_encoding["tts_text_ids"] = tts_text_ids_list if return_attention_mask and attention_masks is not None: batch_encoding["attention_mask"] = attention_masks batch_encoding["tts_lm_attention_mask"] = tts_lm_attention_masks batch_encoding["speech_input_mask"] = speech_input_masks_list if has_speech: speech_dict = self.prepare_speech_inputs( all_speech_inputs, return_tensors=return_tensors ) batch_encoding["speech_tensors"] = speech_dict["padded_speeches"] batch_encoding["speech_masks"] = speech_dict["speech_masks"] else: batch_encoding["speech_tensors"] = None batch_encoding["speech_masks"] = None return batch_encoding def prepare_speech_inputs( self, speech_inputs: List[np.ndarray], return_tensors: Optional[Union[str, TensorType]] = None, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, ) -> Dict[str, Any]: if not speech_inputs: return 
{"padded_speeches": None, "speech_masks": None} vae_tok_seqlens = [ math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs ] max_speech_length = max((s.shape[0] for s in speech_inputs)) if speech_inputs[0].ndim == 1: padded_speeches = np.full( (len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32 ) else: padded_speeches = np.full( (len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32, ) speech_masks = np.zeros( (len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_ ) for i, (speech, vae_tok_length) in enumerate( zip(speech_inputs, vae_tok_seqlens) ): padded_speeches[i, : len(speech)] = speech speech_masks[i, :vae_tok_length] = True result = {"padded_speeches": padded_speeches, "speech_masks": speech_masks} if return_tensors == "pt": result["padded_speeches"] = torch.tensor( padded_speeches, device=device, dtype=dtype or torch.float32 ) result["speech_masks"] = torch.tensor( speech_masks, device=device, dtype=torch.bool ) return result def batch_decode(self, *args, **kwargs): return self.tokenizer.batch_decode(*args, **kwargs) def decode(self, *args, **kwargs): return self.tokenizer.decode(*args, **kwargs) @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names audio_processor_input_names = self.audio_processor.model_input_names return list( dict.fromkeys( tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"] ) ) def save_audio( self, audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]], output_path: str = "output.wav", sampling_rate: Optional[int] = None, normalize: bool = False, batch_prefix: str = "audio_", ) -> str: return self.audio_processor.save_audio( audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix, ) __all__ = [ 'QWEN3VoxStreamingProcessor' ] '\nQWEN3Vox Streaming Model Architecture (0.5B)\n\nThis module 
implements the streaming-optimized version of QWEN3Vox for real-time TTS.\nKey differences from the multi-speaker model:\n- No semantic tokenizer (only acoustic)\n- Split language model architecture: lower layers for text, upper layers for TTS\n- Optimized for low-latency generation\n' from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union, Callable from tqdm import tqdm import copy import torch import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist from transformers.models.auto import AutoModel, AutoModelForCausalLM from transformers.activations import ACT2FN from transformers.modeling_outputs import ( CausalLMOutput, BaseModelOutputWithPast, ModelOutput, ) from transformers.models.llama.modeling_llama import LlamaRMSNorm from transformers import modeling_utils from transformers.modeling_utils import PreTrainedModel from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.utils import logging logger = logging.get_logger(__name__) if ( not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None ): modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"] class BinaryClassifier(nn.Module): def __init__(self, hidden_size): super(BinaryClassifier, self).__init__() self.fc1 = nn.Linear(hidden_size, hidden_size) self.fc2 = nn.Linear(hidden_size, 1) def forward(self, x): x = torch.relu(self.fc1(x)) x = self.fc2(x) return x class SpeechConnector(nn.Module): def __init__(self, input_dim, output_dim): super().__init__() self.fc1 = nn.Linear(input_dim, output_dim) self.norm = LlamaRMSNorm(output_dim, eps=1e-06) self.fc2 = nn.Linear(output_dim, output_dim) def forward(self, features, **kwargs): x = self.fc1(features) x = self.norm(x) x = self.fc2(x) return x class QWEN3VoxStreamingPreTrainedModel(PreTrainedModel): config_class = QWEN3VoxStreamingConfig base_model_prefix = "model" supports_gradient_checkpointing = True 
_skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True def _init_weights(self, module): if isinstance(module, QWEN3VoxDiffusionHead): module.initialize_weights() return if hasattr(self.config, "language_model_config") and hasattr( self.config.language_model_config, "initializer_range" ): std = self.config.language_model_config.initializer_range elif hasattr(self.config, "decoder_config") and hasattr( self.config.decoder_config, "initializer_range" ): std = self.config.decoder_config.initializer_range else: std = 0.02 if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.weight.data.fill_(1.0) module.bias.data.zero_() class QWEN3VoxStreamingModel(QWEN3VoxStreamingPreTrainedModel): def __init__(self, config): super().__init__(config) if hasattr(config, "torch_dtype") and config.torch_dtype is not None: if isinstance(config.torch_dtype, str): dtype = getattr(torch, config.torch_dtype) else: dtype = config.torch_dtype else: dtype = torch.float32 lm_config = copy.deepcopy(config.decoder_config) lm_backbone_num_hidden_layers = ( getattr(lm_config, "num_hidden_layers", 24) - config.tts_backbone_num_hidden_layers ) lm_config.num_hidden_layers = lm_backbone_num_hidden_layers self.language_model = AutoModel.from_config(lm_config) self.language_model.norm = nn.Identity() tts_lm_config = copy.deepcopy(lm_config) tts_lm_config.num_hidden_layers = config.tts_backbone_num_hidden_layers self.tts_language_model = AutoModel.from_config(tts_lm_config) self.tts_input_types = nn.Embedding( num_embeddings=2, embedding_dim=config.decoder_config.hidden_size ) self.acoustic_tokenizer = AutoModel.from_config( config.acoustic_tokenizer_config ).to(dtype) self.acoustic_connector = SpeechConnector( 
config.acoustic_vae_dim, lm_config.hidden_size ).to(dtype) self.register_buffer("speech_scaling_factor", torch.tensor(float("nan"))) self.register_buffer("speech_bias_factor", torch.tensor(float("nan"))) self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to( dtype ) self.noise_scheduler = DPMSolverMultistepScheduler( num_train_timesteps=config.diffusion_head_config.ddpm_num_steps, beta_schedule=config.diffusion_head_config.ddpm_beta_schedule, prediction_type=config.diffusion_head_config.prediction_type, ) def get_input_embeddings(self): if hasattr(self.language_model, "embed_tokens"): return self.language_model.embed_tokens for name, attr in self.language_model.fullmap.items(): if attr.orig_name == "embed_tokens.weight": return getattr(self.language_model, name) assert False, "should not arrive here" def set_input_embeddings(self, value): self.language_model.embed_tokens = value def set_speech_tokenizers(self, acoustic_tokenizer=None): self.acoustic_tokenizer = acoustic_tokenizer if self.acoustic_tokenizer is not None: self.acoustic_tokenizer.train(False) def forward(self, *args, **kwargs): raise RuntimeError( 'QWEN3VoxStreamingModel.forward is intentionally disabled. Use `model.language_model(...)` or `model.tts_language_model(...)` instead.' 
) AutoModel.register(QWEN3VoxStreamingConfig, QWEN3VoxStreamingModel) __all__ = [ 'QWEN3VoxStreamingPreTrainedModel', 'QWEN3VoxStreamingModel', "BinaryClassifier", "SpeechConnector", ] from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union, Callable from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist from transformers.models.auto import AutoModel, AutoModelForCausalLM from transformers.activations import ACT2FN from transformers.modeling_outputs import ( CausalLMOutput, BaseModelOutputWithPast, ModelOutput, ) from transformers.models.llama.modeling_llama import LlamaRMSNorm from transformers import modeling_utils from transformers.modeling_utils import PreTrainedModel from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.utils import logging logger = logging.get_logger(__name__) if ( not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None ): modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"] @dataclass class QWEN3VoxCausalLMOutputWithPast(ModelOutput): loss: Optional[torch.FloatTensor] = None diffusion_loss: Optional[torch.FloatTensor] = None speech_token_num: Optional[int] = None logits: torch.FloatTensor = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None attentions: Optional[Tuple[torch.FloatTensor, ...]] = None @dataclass class QWEN3VoxGenerationOutput(ModelOutput): sequences: torch.LongTensor = None speech_outputs: Optional[List[torch.FloatTensor]] = None class SpeechConnector(nn.Module): def __init__(self, input_dim, output_dim): super().__init__() self.fc1 = nn.Linear(input_dim, output_dim) self.norm = LlamaRMSNorm(output_dim, eps=1e-06) self.fc2 = nn.Linear(output_dim, output_dim) def forward(self, features, **kwargs): x = self.fc1(features) x = self.norm(x) x = 
self.fc2(x) return x class QWEN3VoxPreTrainedModel(PreTrainedModel): config_class = QWEN3VoxConfig base_model_prefix = "model" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True def _init_weights(self, module): if isinstance(module, QWEN3VoxDiffusionHead): module.initialize_weights() return if hasattr(self.config, "language_model_config") and hasattr( self.config.language_model_config, "initializer_range" ): std = self.config.language_model_config.initializer_range elif hasattr(self.config, "decoder_config") and hasattr( self.config.decoder_config, "initializer_range" ): std = self.config.decoder_config.initializer_range else: std = 0.02 if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.weight.data.fill_(1.0) module.bias.data.zero_() class QWEN3VoxModel(QWEN3VoxPreTrainedModel): def __init__(self, config): super().__init__(config) if hasattr(config, "torch_dtype") and config.torch_dtype is not None: if isinstance(config.torch_dtype, str): dtype = getattr(torch, config.torch_dtype) else: dtype = config.torch_dtype else: dtype = torch.float32 lm_config = config.decoder_config self.language_model = AutoModel.from_config(lm_config) self.acoustic_tokenizer = AutoModel.from_config( config.acoustic_tokenizer_config ).to(dtype) self.semantic_tokenizer = AutoModel.from_config( config.semantic_tokenizer_config ).to(dtype) self.acoustic_connector = SpeechConnector( config.acoustic_vae_dim, lm_config.hidden_size ).to(dtype) self.semantic_connector = SpeechConnector( config.semantic_vae_dim, lm_config.hidden_size ).to(dtype) self.register_buffer("speech_scaling_factor", torch.tensor(float("nan"))) self.register_buffer("speech_bias_factor", 
torch.tensor(float("nan"))) self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to( dtype ) self.noise_scheduler = DPMSolverMultistepScheduler( num_train_timesteps=config.diffusion_head_config.ddpm_num_steps, beta_schedule=config.diffusion_head_config.ddpm_beta_schedule, prediction_type=config.diffusion_head_config.prediction_type, ) def get_input_embeddings(self): if hasattr(self.language_model, "embed_tokens"): return self.language_model.embed_tokens for name, attr in self.language_model.fullmap.items(): if attr.orig_name == "embed_tokens.weight": return getattr(self.language_model, name) assert False, "should not arrive here" def set_input_embeddings(self, value): self.language_model.embed_tokens = value def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None): self.acoustic_tokenizer = acoustic_tokenizer self.semantic_tokenizer = semantic_tokenizer if self.acoustic_tokenizer is not None: self.acoustic_tokenizer.train(False) if self.semantic_tokenizer is not None: self.semantic_tokenizer.train(False) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) outputs = self.language_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, 
cache_position=cache_position, **kwargs, ) if not return_dict: return outputs return BaseModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class QWEN3VoxForConditionalGeneration(QWEN3VoxPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config): super().__init__(config) self.model = QWEN3VoxModel(config) self.vocab_size = config.decoder_config.vocab_size self.lm_head = nn.Linear( config.decoder_config.hidden_size, self.vocab_size, bias=False ) self.post_init() def get_input_embeddings(self): return self.model.get_input_embeddings() def set_input_embeddings(self, value): self.model.set_input_embeddings(value) def get_output_embeddings(self): return self.lm_head def set_decoder(self, decoder): self.model.language_model = decoder def get_decoder(self): return self.model.language_model def tie_weights(self): if getattr(self.config.decoder_config, "tie_word_embeddings", False): output_embeddings = self.get_output_embeddings() input_embeddings = self.get_input_embeddings() if hasattr(input_embeddings, "weight"): output_embeddings.weight = input_embeddings.weight else: output_embeddings.weight = input_embeddings if getattr(output_embeddings, "bias", None) is not None: output_embeddings.bias.data = nn.functional.pad( output_embeddings.bias.data, ( 0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0], ), "constant", 0, ) print("Tied input and output embeddings using standard assignment.") else: print("tie_word_embeddings is False, not tying weights.") def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def forward_speech_features( self, speech_tensors=None, speech_masks=None, speech_type="audio", return_unmask=False, ): if speech_tensors is None: vae_dim = self.config.acoustic_tokenizer_config.vae_dim audio_features = torch.zeros(1, 1, 
vae_dim).to( self.get_input_embeddings().weight ) connect_features = self.model.acoustic_connector(audio_features) return (audio_features, connect_features) else: with torch.no_grad(): if speech_type == "audio": with torch.no_grad(): frames = self.model.acoustic_tokenizer.encode( speech_tensors.unsqueeze(1) )[0][0] audio_tokens = frames.sample( self.model.acoustic_tokenizer.std_dist_type )[0] elif speech_type == "vae": vae_dim = self.config.acoustic_tokenizer_config.vae_dim speech_mode = speech_tensors.reshape( speech_tensors.size(0), -1, vae_dim ) batch_size = speech_mode.size(0) value = self.model.acoustic_tokenizer.fix_std / 0.8 std = ( torch.randn( batch_size, dtype=speech_mode.dtype, device=speech_mode.device, ) * value ) std = std.view(-1, *[1] * (speech_mode.dim() - 1)) audio_tokens = speech_mode + std * torch.randn( speech_mode.shape ).to(speech_mode) else: raise NotImplementedError( f"Speech type {speech_type } not implemented" ) if torch.isnan(self.model.speech_scaling_factor) or torch.isnan( self.model.speech_bias_factor ): scaling_factor = 1.0 / audio_tokens[speech_masks].flatten().std() bias_factor = -audio_tokens[speech_masks].flatten().mean() if dist.is_available() and dist.is_initialized(): dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM) dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM) world_size = dist.get_world_size() self.model.speech_scaling_factor.copy_( scaling_factor / world_size ) self.model.speech_bias_factor.copy_(bias_factor / world_size) print( f"Speech scaling factor (distributed): {self .model .speech_scaling_factor }, bias factor: {self .model .speech_bias_factor }", flush=True, ) else: self.model.speech_scaling_factor.copy_(scaling_factor) self.model.speech_bias_factor.copy_(bias_factor) print( f"Speech scaling factor (single process): {self .model .speech_scaling_factor }, bias factor: {self .model .speech_bias_factor }", flush=True, ) audio_features = ( audio_tokens + self.model.speech_bias_factor ) * 
self.model.speech_scaling_factor connect_features = self.model.acoustic_connector(audio_features) if return_unmask: return (audio_features, connect_features) return (audio_features[speech_masks], connect_features[speech_masks]) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, speech_tensors: Optional[torch.FloatTensor] = None, speech_masks: Optional[torch.BoolTensor] = None, speeches_loss_input: Optional[torch.FloatTensor] = None, speech_semantic_tensors: Optional[torch.FloatTensor] = None, acoustic_input_mask: Optional[torch.BoolTensor] = None, acoustic_loss_mask: Optional[torch.BoolTensor] = None, ddpm_batch_mul: int = 1, **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]], ) -> Union[Tuple, QWEN3VoxCausalLMOutputWithPast]: return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) x = self.get_input_embeddings()(input_ids) semantic_speech_all_connect_features = self.model.semantic_connector( speech_semantic_tensors ) if speeches_loss_input is not None: speech_all_features, speech_all_connect_features = ( self.forward_speech_features( speech_tensors=( speech_tensors.type_as(x) if speech_tensors is not None else None ), speech_masks=speech_masks, speech_type=kwargs.get("speech_type", "audio"), return_unmask=True, ) ) if speech_tensors is not None: if semantic_speech_all_connect_features is not None: x[acoustic_input_mask] = ( speech_all_connect_features[speech_masks] + semantic_speech_all_connect_features[speech_masks] ) else: x[acoustic_input_mask] = 
speech_all_connect_features[speech_masks] target_latent_mask = speeches_loss_input & speech_masks speech_features = speech_all_features[target_latent_mask] speech_connect_features = speech_all_connect_features[ target_latent_mask ] else: speech_features, speech_connect_features = self.forward_speech_features( speech_tensors=( speech_tensors.type_as(x) if speech_tensors is not None else None ), speech_masks=speech_masks, speech_type=kwargs.get("speech_type", "audio"), ) if speech_tensors is not None: x[acoustic_input_mask] = speech_connect_features outputs = self.model( input_ids=None, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=x, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=False, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs.last_hidden_state logits = self.lm_head(hidden_states) loss = None if labels is not None: pass diffusion_loss = None if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0: condition_features = hidden_states[acoustic_loss_mask] speech_len, latent_size = speech_features.shape noise = torch.randn( (speech_len * ddpm_batch_mul, latent_size), device=hidden_states.device, dtype=hidden_states.dtype, ) timesteps = torch.multinomial( torch.ones(self.config.diffusion_head_config.ddpm_num_steps), speech_len * ddpm_batch_mul, replacement=True, ).to(hidden_states.device) speech_features_repeated = speech_features.repeat_interleave( ddpm_batch_mul, dim=0 ) condition_features_repeated = condition_features.repeat_interleave( ddpm_batch_mul, dim=0 ) noisy_speech_features = self.model.noise_scheduler.add_noise( speech_features_repeated, noise, timesteps ) model_output = self.model.prediction_head( noisy_speech_features, timesteps.type_as(x), condition_features_repeated ) prediction_type = self.config.diffusion_head_config.prediction_type if prediction_type == "epsilon": target_for_loss = noise elif prediction_type 
== "v_prediction": target_for_loss = self.model.noise_scheduler.get_velocity( speech_features_repeated, noise, timesteps ) else: raise NotImplementedError( f"Prediction type {prediction_type } not implemented" ) diffusion_loss = F.mse_loss( model_output.float(), target_for_loss.float(), reduction="sum" ) if latent_size > 0 and ddpm_batch_mul > 0: diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul else: diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device) else: diffusion_loss = ( sum((p.sum() for p in self.model.prediction_head.parameters())) * 0.0 ) diffusion_loss += ( sum((p.sum() for p in self.model.acoustic_connector.parameters())) * 0.0 ) diffusion_loss += ( sum((p.sum() for p in self.model.semantic_connector.parameters())) * 0.0 ) if not return_dict: output = (logits, speech_len) + outputs.to_tuple()[1:] return (loss, diffusion_loss) + output return QWEN3VoxCausalLMOutputWithPast( loss=loss, diffusion_loss=diffusion_loss, speech_token_num=speech_len if speech_tensors is not None else 0, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) AutoModel.register(QWEN3VoxConfig, QWEN3VoxModel) AutoModelForCausalLM.register(QWEN3VoxConfig, QWEN3VoxForConditionalGeneration) __all__ = [ 'QWEN3VoxModel', 'QWEN3VoxPreTrainedModel', 'QWEN3VoxForConditionalGeneration', 'QWEN3VoxCausalLMOutputWithPast', 'QWEN3VoxGenerationOutput', ] '\nQWEN3Vox Processors\n\nThis module provides processors for preparing inputs for QWEN3Vox models:\n- QWEN3VoxProcessor: For multi-speaker models (1.5B, 7B)\n- QWEN3VoxStreamingProcessor: For streaming model (0.5B)\n' __all__ = [ 'QWEN3VoxProcessor', 'QWEN3VoxStreamingProcessor', 'QWEN3VoxTokenizerProcessor', "AudioNormalizer", 'QWEN3VoxASRProcessor', ] '\nQWEN3Vox Streaming Inference Model (0.5B)\n\nThis module implements the inference engine for real-time streaming TTS.\nKey features:\n- Window-based text/speech interleaving for streaming\n- 
Binary EOS classifier for end-of-speech detection\n- Classifier-free guidance for speech quality\n- Audio streaming support\n' from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union, Callable from tqdm import tqdm import torch import torch.nn as nn from transformers.models.auto import AutoModel, AutoModelForCausalLM from transformers.generation import ( GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList, ) from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput from transformers import modeling_utils from transformers.modeling_utils import PreTrainedModel from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.utils import logging logger = logging.get_logger(__name__) if ( not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None ): modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"] TTS_TEXT_WINDOW_SIZE = 5 TTS_SPEECH_WINDOW_SIZE = 6 def _update_model_kwargs_for_generation( outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1 ) -> Dict[str, Any]: model_kwargs["past_key_values"] = getattr(outputs, "past_key_values") attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( [ attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_tokens)), ], dim=-1, ) model_kwargs["cache_position"] = torch.arange( model_kwargs["cache_position"][-1] + 1, model_kwargs["cache_position"][-1] + num_new_tokens + 1, ).to(model_kwargs["cache_position"].device) return model_kwargs @dataclass class QWEN3VoxCausalLMOutputWithPast(BaseModelOutputWithPast): logits: Optional[torch.FloatTensor] = None @dataclass class QWEN3VoxGenerationOutput(ModelOutput): sequences: torch.LongTensor = None speech_outputs: Optional[List[torch.FloatTensor]] = None reach_max_step_sample: Optional[torch.BoolTensor] = None class 
QWEN3VoxStreamingForConditionalGenerationInference( QWEN3VoxStreamingPreTrainedModel, GenerationMixin ): def __init__(self, config): super().__init__(config) self.model = QWEN3VoxStreamingModel(config) self.tts_eos_classifier = BinaryClassifier(config.decoder_config.hidden_size) self.ddpm_inference_steps = ( config.diffusion_head_config.ddpm_num_inference_steps ) self.post_init() @property def noise_scheduler(self): return self.model.noise_scheduler @property def prediction_head(self): return self.model.prediction_head @property def speech_scaling_factor(self): return self.model.speech_scaling_factor @property def speech_bias_factor(self): return self.model.speech_bias_factor @property def acoustic_tokenizer(self): return self.model.acoustic_tokenizer @property def acoustic_connector(self): return self.model.acoustic_connector def tie_weights(self): if not getattr(self.config, "tie_word_embeddings", False): return if hasattr(self, "lm_head") and hasattr( self.model.language_model, "embed_tokens" ): self.lm_head.weight = self.model.language_model.embed_tokens.weight def get_input_embeddings(self): return self.model.get_input_embeddings() def set_input_embeddings(self, value): self.model.set_input_embeddings(value) def get_output_embeddings(self): return None def set_output_embeddings(self, new_embeddings): raise RuntimeError( "Output embeddings (lm_head) are not defined for this model. Create one before calling set_output_embeddings if needed." 
) def set_speech_tokenizers(self, acoustic_tokenizer=None): self.model.set_speech_tokenizers(acoustic_tokenizer) def set_ddpm_inference_steps(self, num_steps=None): self.ddpm_inference_steps = ( num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps ) def forward_lm( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if inputs_embeds is None: inputs_embeds = self.model.get_input_embeddings()(input_ids) outputs = self.model.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **kwargs, ) hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state if labels is not None: raise NotImplementedError( "Loss computation is not implemented in this version." 
) return BaseModelOutputWithPast( past_key_values=outputs.past_key_values, last_hidden_state=hidden_states, attentions=outputs.attentions, ) def forward_tts_lm( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, lm_last_hidden_state: Optional[torch.FloatTensor] = None, tts_text_masks: Optional[torch.BoolTensor] = None, **kwargs, ) -> Union[Tuple, QWEN3VoxCausalLMOutputWithPast]: return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if inputs_embeds is None: inputs_embeds = self.model.get_input_embeddings()(input_ids) start_idx = inputs_embeds.shape[1] - lm_last_hidden_state.shape[1] inputs_embeds[:, start_idx:, :] = lm_last_hidden_state inputs_embeds = inputs_embeds + self.model.tts_input_types( tts_text_masks.long() ) outputs = self.model.tts_language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **kwargs, ) hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state logits = self.tts_eos_classifier(hidden_states[:, -1, :]) if labels is not None: raise NotImplementedError( "Loss computation is not implemented in this version." 
) return QWEN3VoxCausalLMOutputWithPast( logits=logits, past_key_values=outputs.past_key_values, last_hidden_state=hidden_states, attentions=outputs.attentions, ) def forward(self, *args, **kwargs): raise RuntimeError( "Unified forward is disabled. Use `forward_lm`, `forward_tts_lm`, or `generate` instead." ) def _build_generate_config_model_kwargs( self, generation_config, inputs, tokenizer, return_processors=False, **kwargs ): if generation_config is None: generation_config = GenerationConfig( bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, ) else: generation_config = GenerationConfig( **generation_config, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, ) generation_config, model_kwargs = self._prepare_generation_config( generation_config, True, speech_start_id=tokenizer.speech_start_id, speech_end_id=tokenizer.speech_end_id, speech_diffusion_id=tokenizer.speech_diffusion_id, **kwargs, ) generation_config.speech_start_id = tokenizer.speech_start_id generation_config.speech_end_id = tokenizer.speech_end_id generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] device = self.device self._prepare_special_tokens(generation_config, True, device=device) generation_config.use_cache = True model_kwargs["use_cache"] = generation_config.use_cache input_ids = inputs_tensor.to(self.device) input_ids_length = input_ids.shape[1] has_default_max_length = ( kwargs.get("max_length") is None and generation_config.max_length is not None ) has_default_min_length = ( kwargs.get("min_length") is None and generation_config.min_length is not None ) generation_config = self._prepare_generated_length( generation_config=generation_config, has_default_max_length=has_default_max_length, 
has_default_min_length=has_default_min_length, model_input_name=model_input_name, inputs_tensor=inputs_tensor, input_ids_length=input_ids_length, ) max_cache_length = generation_config.max_length - 1 self._prepare_cache_for_generation( generation_config, model_kwargs, None, batch_size, max_cache_length, device ) model_kwargs["cache_position"] = torch.arange( input_ids_length, device=device, dtype=torch.long ) for k, v in model_kwargs.items(): if isinstance(v, torch.Tensor): model_kwargs[k] = v.to(device=device) if return_processors: logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=None, logits_processor=LogitsProcessorList(), device=inputs_tensor.device, model_kwargs=model_kwargs, ) stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=StoppingCriteriaList(), ) return ( generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria, ) else: return (generation_config, model_kwargs, input_ids) @torch.no_grad() def generate( self, inputs: Optional[torch.Tensor] = None, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[ Callable[[int, torch.Tensor], List[int]] ] = None, synced_gpus: Optional[bool] = None, assistant_model: Optional["PreTrainedModel"] = None, audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None, negative_prompt_ids: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, speech_tensors: Optional[torch.FloatTensor] = None, speech_masks: Optional[torch.BoolTensor] = None, speech_input_mask: Optional[torch.BoolTensor] = None, tts_text_ids: Optional[torch.LongTensor] = None, return_speech: bool = True, cfg_scale: float = 1.0, stop_check_fn: 
Optional[Callable[[], bool]] = None, **kwargs, ) -> Union[torch.LongTensor, QWEN3VoxGenerationOutput]: tokenizer = kwargs.pop("tokenizer", None) neg_text_input_id = tokenizer.convert_tokens_to_ids("<|image_pad|>") tts_lm_input_ids = kwargs.pop("tts_lm_input_ids", None) tts_lm_attention_mask = kwargs.pop("tts_lm_attention_mask", None) all_prefilled_outputs = kwargs.pop("all_prefilled_outputs", None) tts_text_ids = tts_text_ids.to(self.device) if kwargs.get("max_new_tokens", None) is None: kwargs["max_new_tokens"] = ( self.config.decoder_config.max_position_embeddings - tts_lm_input_ids.shape[-1] ) ( generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria, ) = self._build_generate_config_model_kwargs( generation_config, inputs, tokenizer, return_processors=True, **kwargs ) negative_kwargs = { "input_ids": torch.full( (kwargs["input_ids"].shape[0], 1), neg_text_input_id, dtype=torch.long, device=kwargs["input_ids"].device, ), "attention_mask": torch.ones( (kwargs["input_ids"].shape[0], 1), dtype=torch.long, device=kwargs["input_ids"].device, ), "max_new_tokens": kwargs.get("max_new_tokens", 100), } negative_generation_config, negative_model_kwargs, negative_input_ids = ( self._build_generate_config_model_kwargs( None, None, tokenizer, return_processors=False, **negative_kwargs ) ) tts_lm_kwargs = { "input_ids": tts_lm_input_ids, "attention_mask": tts_lm_attention_mask, "max_new_tokens": kwargs.get("max_new_tokens", 100), } tts_lm_generation_config, tts_lm_model_kwargs, tts_lm_input_ids = ( self._build_generate_config_model_kwargs( None, None, tokenizer, return_processors=False, **tts_lm_kwargs ) ) tts_lm_negative_kwargs = { "input_ids": torch.full( (kwargs["input_ids"].shape[0], 1), neg_text_input_id, dtype=torch.long, device=kwargs["input_ids"].device, ), "attention_mask": torch.ones( (kwargs["input_ids"].shape[0], 1), dtype=torch.long, device=kwargs["input_ids"].device, ), "max_new_tokens": kwargs.get("max_new_tokens", 100), } ( 
tts_lm_negative_generation_config, tts_lm_negative_model_kwargs, tts_lm_negative_input_ids, ) = self._build_generate_config_model_kwargs( None, None, tokenizer, return_processors=False, **tts_lm_negative_kwargs ) acoustic_cache = QWEN3VoxTokenizerStreamingCache() batch_size = input_ids.shape[0] assert batch_size == 1, "Currently only supports batch size == 1" device = input_ids.device finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device) verbose = kwargs.get("verbose", False) audio_chunks = [[] for _ in range(batch_size)] tts_text_window_index = 0 reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device) first_text_window_size = ( TTS_TEXT_WINDOW_SIZE if tts_text_ids.shape[1] >= TTS_TEXT_WINDOW_SIZE else tts_text_ids.shape[1] ) outputs = all_prefilled_outputs["lm"] tts_lm_outputs = all_prefilled_outputs["tts_lm"] negative_outputs = all_prefilled_outputs["neg_lm"] tts_lm_negative_outputs = all_prefilled_outputs["neg_tts_lm"] model_kwargs = _update_model_kwargs_for_generation( outputs, model_kwargs, num_new_tokens=first_text_window_size ) tts_lm_model_kwargs = _update_model_kwargs_for_generation( tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=first_text_window_size ) negative_model_kwargs = self._update_model_kwargs_for_generation( negative_outputs, negative_model_kwargs, is_encoder_decoder=False ) tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation( tts_lm_negative_outputs, tts_lm_negative_model_kwargs, is_encoder_decoder=False, ) step = tts_lm_input_ids.shape[1] total_generated_speech_tokens = 0 total_prefilled_text_tokens = 0 if kwargs.get("show_progress_bar", True): progress_bar = tqdm( total=tts_lm_generation_config.max_length, desc=f"Prefilled {step } tokens, current step ({step } / {tts_lm_generation_config .max_length })", initial=step, leave=False, ) else: progress_bar = None while True: if stop_check_fn is not None and stop_check_fn(): if verbose: print(f"Generation stopped externally at 
step {step +1 }") if audio_streamer is not None: audio_streamer.end() break if finished_tags.all(): if hasattr(progress_bar, "set_description"): progress_bar.set_description("Generation complete") break cur_input_tts_text_ids = tts_text_ids[ :, tts_text_window_index * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 1) * TTS_TEXT_WINDOW_SIZE, ] next_text_window_size = tts_text_ids[ :, (tts_text_window_index + 1) * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 2) * TTS_TEXT_WINDOW_SIZE, ].shape[1] tts_text_window_index += 1 if cur_input_tts_text_ids.shape[1] > 0: input_ids = torch.cat([input_ids, cur_input_tts_text_ids], dim=-1) tts_lm_input_ids = torch.cat( [tts_lm_input_ids, cur_input_tts_text_ids], dim=-1 ) if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length: if verbose: print( f"Reached maximum generation length {generation_config .max_length }, stopped it." ) reached_samples = torch.arange(batch_size, device=device)[ ~finished_tags ] if reached_samples.numel() > 0: reach_max_step_sample[reached_samples] = True break step += cur_input_tts_text_ids.shape[1] total_prefilled_text_tokens += cur_input_tts_text_ids.shape[1] if progress_bar is not None: progress_bar.update(cur_input_tts_text_ids.shape[1]) progress_bar.set_description( f"Prefilled {total_prefilled_text_tokens } text tokens, generated {total_generated_speech_tokens } speech tokens, current step ({step } / {tts_lm_generation_config .max_length })" ) model_inputs = self.prepare_inputs_for_generation( input_ids, **model_kwargs ) outputs = self.forward_lm( **model_inputs, return_dict=True, output_attentions=False, output_hidden_states=False, ) model_kwargs = _update_model_kwargs_for_generation( outputs, model_kwargs, num_new_tokens=next_text_window_size ) tts_lm_model_inputs = self.prepare_inputs_for_generation( tts_lm_input_ids, **tts_lm_model_kwargs ) tts_lm_additional_inputs = { "tts_text_masks": torch.ones_like(tts_lm_input_ids[:, -1:]), "lm_last_hidden_state": outputs.last_hidden_state, 
} tts_lm_outputs = self.forward_tts_lm( **tts_lm_model_inputs, **tts_lm_additional_inputs, return_dict=True, output_attentions=False, output_hidden_states=False, ) tts_lm_model_kwargs = self._update_model_kwargs_for_generation( tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False ) diffusion_indices = torch.LongTensor([0]) for cur_speech_index in range(TTS_SPEECH_WINDOW_SIZE): positive_condition = tts_lm_outputs.last_hidden_state[ diffusion_indices, -1, : ] negative_condition = tts_lm_negative_outputs.last_hidden_state[ diffusion_indices, -1, : ] speech_latent = self.sample_speech_tokens( positive_condition, negative_condition, cfg_scale=cfg_scale ).unsqueeze(1) scaled_latent = speech_latent / self.model.speech_scaling_factor.to( speech_latent.device ) - self.model.speech_bias_factor.to(speech_latent.device) audio_chunk = self.model.acoustic_tokenizer.decode( scaled_latent.to(self.model.acoustic_tokenizer.device), cache=acoustic_cache, sample_indices=diffusion_indices.to( self.model.acoustic_tokenizer.device ), use_cache=True, debug=False, ) for i, sample_idx in enumerate(diffusion_indices): idx = sample_idx.item() if not finished_tags[idx]: audio_chunks[idx].append(audio_chunk[i]) if audio_streamer is not None: audio_streamer.put(audio_chunk, diffusion_indices) acoustic_embed = self.model.acoustic_connector(speech_latent) tts_lm_input_ids = torch.cat( [tts_lm_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1, ) if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length: break step += 1 total_generated_speech_tokens += 1 if progress_bar is not None: progress_bar.update(1) progress_bar.set_description( f"Prefilled {total_prefilled_text_tokens } text tokens, generated {total_generated_speech_tokens } speech tokens, current step ({step } / {tts_lm_generation_config .max_length })" ) tts_lm_model_inputs = self.prepare_inputs_for_generation( tts_lm_input_ids, **tts_lm_model_kwargs ) tts_lm_additional_inputs = { "tts_text_masks": 
torch.zeros_like(tts_lm_input_ids[:, -1:]), "lm_last_hidden_state": acoustic_embed, } tts_lm_outputs = self.forward_tts_lm( **tts_lm_model_inputs, **tts_lm_additional_inputs, return_dict=True, output_attentions=False, output_hidden_states=False, ) if ( cur_speech_index == TTS_SPEECH_WINDOW_SIZE - 1 and next_text_window_size > 0 ): tts_lm_model_kwargs = _update_model_kwargs_for_generation( tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=next_text_window_size, ) else: tts_lm_model_kwargs = self._update_model_kwargs_for_generation( tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False ) tts_lm_negative_input_ids = torch.cat( [ tts_lm_negative_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:]), ], dim=-1, ) tts_lm_negative_model_inputs = self.prepare_inputs_for_generation( tts_lm_negative_input_ids, **tts_lm_negative_model_kwargs ) tts_lm_negative_additional_inputs = { "tts_text_masks": torch.zeros_like( tts_lm_negative_input_ids[:, -1:] ), "lm_last_hidden_state": acoustic_embed, } tts_lm_negative_outputs = self.forward_tts_lm( **tts_lm_negative_model_inputs, **tts_lm_negative_additional_inputs, return_dict=True, output_attentions=False, output_hidden_states=False, ) tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation( tts_lm_negative_outputs, tts_lm_negative_model_kwargs, is_encoder_decoder=False, ) tts_eos_logits = torch.sigmoid( self.tts_eos_classifier( tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :] ) ) if tts_eos_logits[0].item() > 0.5: finished_tags[diffusion_indices] = True if audio_streamer is not None: audio_streamer.end(diffusion_indices) if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length: if verbose: print( f"Reached maximum generation length {tts_lm_generation_config .max_length }, stopped it." 
) reached_samples = torch.arange(batch_size, device=device)[ ~finished_tags ] if reached_samples.numel() > 0: reach_max_step_sample[reached_samples] = True break if audio_streamer is not None: audio_streamer.end() final_audio_outputs = [] for sample_chunks in audio_chunks: if sample_chunks: concatenated_audio = torch.cat(sample_chunks, dim=-1) final_audio_outputs.append(concatenated_audio) else: final_audio_outputs.append(None) if reach_max_step_sample is not None and reach_max_step_sample.any(): print( f"Reached maximum generation length {tts_lm_generation_config .max_length }, stopped it." ) return QWEN3VoxGenerationOutput( sequences=tts_lm_input_ids, speech_outputs=final_audio_outputs if return_speech else None, reach_max_step_sample=reach_max_step_sample, ) @torch.no_grad() def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0): self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps) condition = torch.cat([condition, neg_condition], dim=0).to( self.model.prediction_head.device ) speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to( condition ) for t in self.model.noise_scheduler.timesteps: half = speech[: len(speech) // 2] combined = torch.cat([half, half], dim=0) eps = self.model.prediction_head( combined, t.repeat(combined.shape[0]).to(combined), condition=condition ) cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample return speech[: len(speech) // 2] AutoModelForCausalLM.register( QWEN3VoxStreamingConfig, QWEN3VoxStreamingForConditionalGenerationInference ) __all__ = [ 'QWEN3VoxStreamingForConditionalGenerationInference', 'QWEN3VoxGenerationOutput', 'QWEN3VoxCausalLMOutputWithPast', "TTS_TEXT_WINDOW_SIZE", "TTS_SPEECH_WINDOW_SIZE", ] import logging import os from dataclasses import dataclass, field from typing import Any, Dict, 
List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, DatasetDict, VerificationMode
from transformers import HfArgumentParser, Trainer, set_seed, TrainerCallback
from transformers import TrainingArguments as HfTrainingArguments
from peft import LoraConfig, get_peft_model, TaskType

logger = logging.getLogger(__name__)
import copy
import torch
from transformers import TrainerCallback


class EmaCallback(TrainerCallback):
    """Maintain an exponential moving average (EMA) of one submodule's weights.

    The tracked module is reached from the trainer's model through the dotted
    ``attr_path``. After every optimizer step each floating-point entry of its
    ``state_dict`` is blended into a shadow copy::

        shadow = decay * shadow + (1 - decay) * current

    Non-floating entries (e.g. integer counters) are copied verbatim, because an
    in-place float EMA on an integer tensor raises.

    Hugging Face ``Trainer`` fires ``on_evaluate`` / ``on_save`` *after* it has
    evaluated or written a checkpoint, and it never fires ``on_evaluate_end`` /
    ``on_save_end``. Swapping there therefore came too late and was never
    undone, so training silently continued on EMA weights. Instead, the EMA
    weights are swapped in at the end of any step or epoch for which the trainer
    has scheduled an evaluation or a save. They are swapped back out before the
    next training step begins. At the end of training the EMA weights are
    swapped in and left in place.

    Args:
        attr_path: Dotted attribute path from the model to the tracked module.
        decay: EMA decay factor; larger values average over more steps.
        device: Device on which the shadow copy is stored.
    """

    def __init__(self, attr_path="model.prediction_head", decay=0.999, device="cpu"):
        self.attr_path = attr_path
        self.decay = float(decay)
        self.device = torch.device(device)
        # name -> shadow tensor; populated in on_train_begin.
        self.shadow = None
        # Raw weights saved while the EMA weights are swapped in (None otherwise).
        self._orig = None

    def _get_module(self, model):
        """Resolve ``self.attr_path`` on ``model`` one attribute at a time."""
        mod = model
        for name in self.attr_path.split("."):
            mod = getattr(mod, name)
        return mod

    def _to_shadow(self, tensor):
        """Return a detached copy of ``tensor`` on the shadow device.

        Floating-point tensors are promoted to at least float32, so that the
        small ``(1 - decay)`` increments are not rounded away for bf16/fp16
        weights. ``load_state_dict`` casts back to the parameter dtype on swap.
        """
        tensor = tensor.detach()
        if tensor.is_floating_point():
            dtype = torch.promote_types(tensor.dtype, torch.float32)
            return tensor.to(self.device, dtype=dtype, copy=True)
        return tensor.to(self.device, copy=True)

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        head = self._get_module(model)
        self.shadow = {k: self._to_shadow(v) for k, v in head.state_dict().items()}
        self._orig = None

    def on_step_begin(self, args, state, control, model=None, **kwargs):
        # Every optimizer step must start from the raw (non-EMA) weights.
        self._swap_back(model)

    def on_step_end(self, args, state, control, model=None, **kwargs):
        if self.shadow is None:
            return
        # Defensive: the EMA must be updated from the raw weights.
        self._swap_back(model)
        head = self._get_module(model)
        with torch.no_grad():
            for k, v in head.state_dict().items():
                shadow = self.shadow[k]
                current = v.detach().to(self.device, dtype=shadow.dtype)
                if shadow.is_floating_point():
                    shadow.mul_(self.decay).add_(current, alpha=1.0 - self.decay)
                else:
                    shadow.copy_(current)
        self._maybe_swap_in_for_eval_or_save(model, control)

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        # Epoch-based evaluation/saving is scheduled here rather than in on_step_end.
        self._maybe_swap_in_for_eval_or_save(model, control)

    def _maybe_swap_in_for_eval_or_save(self, model, control):
        """Swap the EMA weights in when the trainer is about to evaluate or save."""
        if getattr(control, "should_evaluate", False) or getattr(
            control, "should_save", False
        ):
            self._swap_in_ema(model)

    def _swap_in_ema(self, model):
        # Idempotent: when already swapped in, _orig holds the raw weights and
        # must not be overwritten with the EMA copy.
        if self.shadow is None or self._orig is not None:
            return
        head = self._get_module(model)
        self._orig = {k: v.detach().clone() for k, v in head.state_dict().items()}
        head.load_state_dict(self.shadow, strict=False)

    def _swap_back(self, model):
        if self._orig is None:
            return
        head = self._get_module(model)
        head.load_state_dict(self._orig, strict=False)
        self._orig = None

    def on_evaluate_end(self, args, state, control, model=None, **kwargs):
        # Not fired by transformers.Trainer; kept for custom loops that call it.
        self._swap_back(model)

    def on_save_end(self, args, state, control, model=None, **kwargs):
        # Not fired by transformers.Trainer; kept for custom loops that call it.
        self._swap_back(model)

    def on_train_end(self, args, state, control, model=None, **kwargs):
        # Leave the EMA weights in place for whatever the caller saves next.
        self._swap_in_ema(model)


@dataclass
class ModelArguments:
    """Model / LoRA options parsed by ``HfArgumentParser``."""

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": 'Path to QWEN3Vox base model with config.json'},
    )
    processor_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to processor dir (preprocessor_config.json). Defaults to model path."
        },
    )
    cache_dir: Optional[str] = field(default=None)
    freeze_acoustic_tokenizer: bool = field(default=True)
    freeze_semantic_tokenizer: bool = field(default=True)
    lora_r: int = field(default=8)
    lora_alpha: int = field(default=32)
    lora_dropout: float = field(default=0.05)
    lora_target_modules: str = field(
        default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
        metadata={
            "help": "Comma-separated list of target module names in the LLM blocks"
        },
    )
    lora_wrap_diffusion_head: bool = field(
        default=False, metadata={"help": "Wrap diffusion head with PEFT LoRA"}
    )
    train_diffusion_head: bool = field(
        default=False,
        metadata={"help": "Train diffusion prediction head (full fine-tune)"},
    )
    train_connectors: bool = field(
        default=False,
        metadata={"help": "Train acoustic/semantic connectors (full fine-tune)"},
    )
    layers_to_freeze: Optional[str] = field(
        default=None,
        metadata={
            "help": "Comma-separated indices of diffusion head layers to freeze (e.g., '0,1,5,7,8')."
        },
    )


@dataclass
class DataArguments:
    """Dataset source / column options parsed by ``HfArgumentParser``."""

    dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "HF dataset name or 'json' with --train_jsonl for local files"
        },
    )
    dataset_config_name: Optional[str] = field(default=None)
    train_split_name: str = field(default="train")
    eval_split_name: Optional[str] = field(default="validation")
    text_column_name: str = field(default="text")
    audio_column_name: str = field(default="audio")
    voice_prompts_column_name: Optional[str] = field(default="voice_prompts")
    eval_split_size: float = field(default=0.0)
    ignore_verifications: bool = field(default=False)
    max_length: Optional[int] = field(default=None)
    train_jsonl: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to local train JSONL with {text, audio, [voice_prompts]}"
        },
    )
    validation_jsonl: Optional[str] = field(
        default=None, metadata={"help": "Optional path to local validation JSONL"}
    )
    voice_prompt_drop_rate: float = field(
        default=0.0,
        metadata={
            "help": "Probability to drop conditioning voice prompt during training (0.0 keep always, 1.0 drop always)."
        },
    )


@dataclass
class CustomTrainingArguments(HfTrainingArguments):
    """``TrainingArguments`` extended with loss weighting and debug switches."""

    ddpm_batch_mul: int = field(default=1)
    ce_loss_weight: float = field(default=1.0)
    diffusion_loss_weight: float = field(default=1.0)
    debug_ce_details: bool = field(default=False)
    debug_ce_topk: int = field(default=5)
    debug_ce_max_examples: int = field(default=1)
    debug_ce_every_n_steps: int = field(default=200)
    gradient_clipping: bool = field(
        default=False,
        metadata={
            "help": "Enable gradient clipping using max_grad_norm (set via --max_grad_norm, default 1.0). When False, disables clipping by forcing max_grad_norm=0.0."
        },
    )
    debug_save: bool = field(
        default=False,
        metadata={
            "help": "If set, saves model components BEFORE training starts, into output_dir/debug_initial."
        },
    )


def build_lora_config(args: ModelArguments) -> LoraConfig:
    """Build the causal-LM LoRA config from ``args.lora_*``.

    ``args.lora_target_modules`` is split on commas; blank entries are dropped.
    """
    target_modules = [
        s.strip() for s in args.lora_target_modules.split(",") if s.strip()
    ]
    return LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=target_modules,
    )


def build_head_lora_config(args: ModelArguments) -> LoraConfig:
    """Build the LoRA config used to wrap the diffusion prediction head.

    Rank, alpha and dropout come from ``args``. The target module names are
    fixed here.
    """
    target_modules = [
        "noisy_images_proj",
        "cond_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "linear",
    ]
    return LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION,
        target_modules=target_modules,
    )


def mask_for_ce(
    labels: torch.Tensor,
    attention_mask: torch.Tensor,
    acoustic_input_mask: torch.Tensor,
    pad_id: int = -100,
) -> torch.Tensor:
    """Build shifted next-token CE targets, ignoring padding and acoustic slots.

    Position ``t`` of the result holds ``labels[:, t + 1]``. That target is
    replaced by ``pad_id`` when ``attention_mask`` at ``t + 1`` is not 1, or
    when ``acoustic_input_mask`` at ``t + 1`` is set.

    Args:
        labels: Token ids, indexed as ``[batch, position]``.
        attention_mask: 1 for real tokens; ``None`` or empty means "all real".
        acoustic_input_mask: Positions occupied by acoustic (speech) inputs.
            Any dtype is accepted; non-zero entries count as set.
        pad_id: Value written to ignored targets.

    Returns:
        A new tensor shaped like ``labels[:, 1:]``.
    """
    shifted = labels[:, 1:].contiguous()
    if attention_mask is not None and attention_mask.numel() > 0:
        base_mask = attention_mask[:, 1:].eq(1)
    else:
        base_mask = torch.ones_like(shifted, dtype=torch.bool)
    # Force a bool mask. With an integer 0/1 mask, ``~`` is bitwise NOT (0 -> -1),
    # and the combined mask would then be used as integer row indices.
    label_is_acoustic = acoustic_input_mask[:, 1:].to(
        device=base_mask.device, dtype=torch.bool
    )
    final_mask = base_mask & ~label_is_acoustic
    out = shifted.clone()
    out[~final_mask] = pad_id
    return out


def _patch_acoustic_encode_for_legacy_indexing(model_obj, logger_):
    """Monkey-patch ``acoustic_tokenizer.encode`` so ``out[0][0]`` always works.

    The wrapped ``encode`` returns the original result when ``out[0][0]``
    already succeeds. Otherwise it wraps one representative value as
    ``[[value]]``, choosing in this order:

    * the first of a fixed set of keys present when the result is a ``dict``;
    * any first value of a non-empty ``dict``;
    * the first of the same names present as an attribute;
    * the result itself.

    Failures are logged as warnings and never raised (best effort).
    """
    try:
        acoustic = getattr(
            getattr(model_obj, "model", model_obj), "acoustic_tokenizer", None
        )
        if acoustic is None or not hasattr(acoustic, "encode"):
            logger_.warning("No acoustic_tokenizer.encode() found to patch.")
            return
        base_encode = acoustic.encode
        candidate_names = ("frames", "codes", "tokens", "latents", "hidden_states")

        def encode_wrapped(*args, **kwargs):
            out = base_encode(*args, **kwargs)
            try:
                out[0][0]
                return out
            except Exception:
                pass
            if isinstance(out, dict):
                for k in candidate_names:
                    if k in out:
                        return [[out[k]]]
                if len(out) > 0:
                    return [[next(iter(out.values()))]]
            for attr in candidate_names:
                if hasattr(out, attr):
                    return [[getattr(out, attr)]]
            return [[out]]

        acoustic.encode = encode_wrapped
        logger_.info(
            "Patched acoustic_tokenizer.encode() to return [[...]] for legacy indexing."
        )
    except Exception as e:
        logger_.warning(f"Failed to patch acoustic_tokenizer.encode(): {e}")


from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutput,
    CausalLMOutputWithPast,
)
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation import GenerationMixin

logger = logging.get_logger(__name__)

if (
    not hasattr(modeling_utils, "ALL_PARALLEL_STYLES")
    or modeling_utils.ALL_PARALLEL_STYLES is None
):
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]


def _asr_resolve_dtype(config) -> torch.dtype:
    """Return ``config.torch_dtype`` as a ``torch.dtype``, or float32 when unset.

    String values such as ``"bfloat16"`` are looked up on the ``torch`` module.
    """
    dtype = getattr(config, "torch_dtype", None)
    if dtype is None:
        return torch.float32
    if isinstance(dtype, str):
        return getattr(torch, dtype)
    return dtype


class QWEN3VoxASRPreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` settings for the QWEN3Vox ASR models."""

    config_class = QWEN3VoxASRConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize Linear (normal, zero bias) and LayerNorm (ones / zeros).

        The std comes from ``language_model_config.initializer_range``, then
        ``decoder_config.initializer_range``, falling back to 0.02.
        """
        std = 0.02
        for sub_name in ("language_model_config", "decoder_config"):
            sub_config = getattr(self.config, sub_name, None)
            if sub_config is not None and hasattr(sub_config, "initializer_range"):
                std = sub_config.initializer_range
                break
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has weight/bias set to None.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()


class QWEN3VoxASRModel(QWEN3VoxASRPreTrainedModel):
    """Decoder LM together with acoustic/semantic speech tokenizers and connectors.

    ``forward`` runs only the language model. Speech encoding is driven by
    ``QWEN3VoxASRForConditionalGeneration``.
    """

    def __init__(self, config):
        super().__init__(config)
        dtype = _asr_resolve_dtype(config)
        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)
        self.acoustic_tokenizer = AutoModel.from_config(
            config.acoustic_tokenizer_config
        ).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(
            config.semantic_tokenizer_config
        ).to(dtype)
        self.acoustic_connector = SpeechConnector(
            config.acoustic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        self.semantic_connector = SpeechConnector(
            config.semantic_vae_dim, lm_config.hidden_size
        ).to(dtype)

    def get_input_embeddings(self):
        """Return the LM token-embedding module."""
        if hasattr(self.language_model, "embed_tokens"):
            return self.language_model.embed_tokens
        # Fallback for LMs exposing a ``fullmap`` of renamed submodules whose
        # entries carry an ``orig_name`` (presumably a wrapped/compiled LM;
        # NOTE(review): confirm which wrapper provides ``fullmap``).
        for name, attr in self.language_model.fullmap.items():
            if attr.orig_name == "embed_tokens.weight":
                return getattr(self.language_model, name)
        # Raised explicitly (not ``assert``) so the check survives ``python -O``.
        raise AssertionError("should not arrive here")

    def set_input_embeddings(self, value):
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Replace the speech tokenizers and put any non-None one in eval mode."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.train(False)
        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.train(False)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the language model and wrap its output when ``return_dict``."""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        if not return_dict:
            return outputs
        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class QWEN3VoxASRForConditionalGeneration(QWEN3VoxASRPreTrainedModel, GenerationMixin):
    """ASR head over ``QWEN3VoxASRModel``.

    Speech is encoded to features that overwrite the input embeddings at
    ``acoustic_input_mask`` positions. Logits come from a linear LM head.
    """

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)
        self.model = QWEN3VoxASRModel(config)
        self.vocab_size = config.decoder_config.vocab_size
        dtype = _asr_resolve_dtype(config)
        self.lm_head = nn.Linear(
            config.decoder_config.hidden_size, self.vocab_size, bias=False
        ).to(dtype)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.language_model = decoder

    def get_decoder(self):
        return self.model.language_model

    def tie_weights(self):
        """Share the input-embedding weight with ``lm_head`` when configured."""
        if getattr(self.config.decoder_config, "tie_word_embeddings", False):
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            if hasattr(input_embeddings, "weight"):
                output_embeddings.weight = input_embeddings.weight
            else:
                output_embeddings.weight = input_embeddings

    def encode_speech(
        self,
        speech_tensors: torch.FloatTensor,
        speech_masks: Optional[torch.BoolTensor] = None,
        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
        streaming_segment_duration: float = 60.0,
    ):
        """Encode raw audio into LM-sized speech features (acoustic + semantic).

        Inputs longer than ``streaming_segment_duration`` seconds are encoded in
        segments with streaming tokenizer caches. Shorter inputs are encoded in
        one call. In both paths a provided ``speech_semantic_tensors`` replaces
        the semantic tokenizer output.

        Args:
            speech_tensors: Waveforms, ``(batch, samples)`` or a single 1-D
                waveform (a batch axis is added).
            speech_masks: Optional mask selecting which feature frames to keep.
                When given, the result contains only the selected frames.
            speech_semantic_tensors: Optional precomputed semantic latents.
            streaming_segment_duration: Segment length in seconds (at 22050 Hz).

        Returns:
            The sum of the acoustic and semantic connector outputs.

        Raises:
            ValueError: If the segment length in samples is not positive.
        """
        speech_tensors = speech_tensors.to(_asr_resolve_dtype(self.config))
        if speech_tensors.ndim == 1:
            speech_tensors = speech_tensors.unsqueeze(0)
        _, total_samples = speech_tensors.shape
        # Hard-coded rate used only to convert the segment duration to samples.
        sample_rate = 22050
        segment_samples = int(streaming_segment_duration * sample_rate)
        use_streaming = total_samples > segment_samples

        with torch.no_grad():
            if use_streaming:
                acoustic_features, semantic_features = self._encode_speech_streaming(
                    speech_tensors, speech_semantic_tensors, segment_samples
                )
            else:
                encoder_output = self.model.acoustic_tokenizer.encode(
                    speech_tensors.unsqueeze(1)
                )
                audio_tokens = encoder_output.sample(
                    dist_type=self.model.acoustic_tokenizer.std_dist_type
                )[0]
                acoustic_features = self.model.acoustic_connector(audio_tokens)
                if speech_semantic_tensors is not None:
                    semantic_features = self.model.semantic_connector(
                        speech_semantic_tensors
                    )
                else:
                    semantic_tokens = self.model.semantic_tokenizer.encode(
                        speech_tensors.unsqueeze(1)
                    ).mean
                    semantic_features = self.model.semantic_connector(semantic_tokens)

        if speech_masks is not None:
            return acoustic_features[speech_masks] + semantic_features[speech_masks]
        return acoustic_features + semantic_features

    def _encode_speech_streaming(
        self, speech_tensors, speech_semantic_tensors, segment_samples
    ):
        """Encode ``speech_tensors`` in ``segment_samples``-long chunks.

        Per-chunk encoder means are concatenated along dim 1. The acoustic
        latents are then sampled once from the concatenated mean. The semantic
        tokenizer is streamed only when ``speech_semantic_tensors`` is None;
        this matches the non-streaming path, which uses precomputed semantics
        when given.

        Returns:
            ``(acoustic_features, semantic_features)`` after the connectors.
        """
        if segment_samples <= 0:
            raise ValueError("segment_length must be positive")
        batch_size, total_samples = speech_tensors.shape
        encode_semantic = speech_semantic_tensors is None
        acoustic_cache = QWEN3VoxTokenizerStreamingCache()
        semantic_cache = QWEN3VoxTokenizerStreamingCache() if encode_semantic else None
        sample_indices = torch.arange(batch_size, device=speech_tensors.device)
        starts = list(range(0, total_samples, segment_samples))
        acoustic_means = []
        semantic_means = []

        for seg_idx, start in enumerate(starts):
            chunk = speech_tensors[:, start : start + segment_samples].contiguous()
            if chunk.numel() == 0:
                continue
            is_final = seg_idx == len(starts) - 1
            acoustic_out = self.model.acoustic_tokenizer.encode(
                chunk.unsqueeze(1),
                cache=acoustic_cache,
                sample_indices=sample_indices,
                use_cache=True,
                is_final_chunk=is_final,
            )
            acoustic_means.append(acoustic_out.mean)
            if encode_semantic:
                semantic_out = self.model.semantic_tokenizer.encode(
                    chunk.unsqueeze(1),
                    cache=semantic_cache,
                    sample_indices=sample_indices,
                    use_cache=True,
                    is_final_chunk=is_final,
                )
                semantic_means.append(semantic_out.mean)

        acoustic_mean_full = torch.cat(acoustic_means, dim=1).contiguous()
        acoustic_encoder_output = QWEN3VoxTokenizerEncoderOutput(
            mean=acoustic_mean_full, std=self.model.acoustic_tokenizer.fix_std
        )
        audio_tokens = acoustic_encoder_output.sample(
            dist_type=self.model.acoustic_tokenizer.std_dist_type
        )[0]
        acoustic_features = self.model.acoustic_connector(audio_tokens)

        if encode_semantic:
            semantic_tokens = torch.cat(semantic_means, dim=1).contiguous()
        else:
            semantic_tokens = speech_semantic_tensors
        semantic_features = self.model.semantic_connector(semantic_tokens)
        return acoustic_features, semantic_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        speech_tensors: Optional[torch.FloatTensor] = None,
        speech_masks: Optional[torch.BoolTensor] = None,
        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
        acoustic_input_mask: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Embed tokens, splice in speech features, run the LM, and score.

        When both ``speech_tensors`` and ``acoustic_input_mask`` are given,
        ``inputs_embeds`` is modified in place at the masked positions. When
        ``labels`` are given, a shifted next-token cross-entropy loss is
        computed (``-100`` targets are ignored by ``nn.CrossEntropyLoss``).
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if speech_tensors is not None and acoustic_input_mask is not None:
            speech_features = self.encode_speech(
                speech_tensors=speech_tensors,
                speech_masks=speech_masks,
                speech_semantic_tensors=speech_semantic_tensors,
            )
            # Speech features are produced in config.torch_dtype. Masked index
            # assignment requires the source to match the destination's dtype.
            inputs_embeds[acoustic_input_mask] = speech_features.to(
                device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = nn.CrossEntropyLoss()(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        # HF's CausalLMOutputWithPast carries a ``loss`` field. The
        # QWEN3VoxCausalLMOutputWithPast defined in this file does not, so
        # passing loss= to it raised TypeError.
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        speech_tensors=None,
        speech_masks=None,
        speech_semantic_tensors=None,
        acoustic_input_mask=None,
        **kwargs,
    ):
        """Assemble per-step model inputs for ``generate``.

        With a cache, already-processed tokens are dropped from ``input_ids``,
        and position ids / cache positions are offset by the cached length.
        Speech inputs are forwarded only on the step whose ``cache_position``
        starts at 0 (the prefill); later steps get ``None`` for them.
        """
        if past_key_values is None:
            past_length = 0
        elif isinstance(past_key_values, tuple):
            # Legacy tuple cache: per layer (key, value); the sequence axis of
            # the key tensor is dim 2.
            past_length = past_key_values[0][0].shape[2]
        else:
            past_length = past_key_values.get_seq_length()

        if (
            past_key_values is not None
            and input_ids is not None
            and input_ids.shape[1] > past_length
        ):
            input_ids = input_ids[:, past_length:]

        if position_ids is None and attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values is not None and input_ids is not None:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        if cache_position is None:
            # Reuse past_length: legacy tuple caches have no get_seq_length().
            if input_ids is not None:
                new_tokens, ref_device = input_ids.shape[1], input_ids.device
            else:
                new_tokens, ref_device = inputs_embeds.shape[1], inputs_embeds.device
            cache_position = torch.arange(
                past_length, past_length + new_tokens, device=ref_device
            )

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )

        is_prefill = (
            cache_position is not None
            and len(cache_position) > 0
            and (cache_position[0] == 0)
        )
        model_inputs.update(
            {
                "speech_tensors": speech_tensors if is_prefill else None,
                "speech_masks": speech_masks if is_prefill else None,
                "speech_semantic_tensors": (
                    speech_semantic_tensors if is_prefill else None
                ),
                "acoustic_input_mask": acoustic_input_mask if is_prefill else None,
            }
        )
        model_inputs.update(kwargs)
        return model_inputs
AutoModel.register(QWEN3VoxASRConfig, QWEN3VoxASRModel) AutoModelForCausalLM.register(QWEN3VoxASRConfig, QWEN3VoxASRForConditionalGeneration) __all__ = [ 'QWEN3VoxASRPreTrainedModel', 'QWEN3VoxASRModel', 'QWEN3VoxASRForConditionalGeneration', ] from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union, Callable from tqdm import tqdm import torch import torch.nn as nn from transformers.models.auto import AutoModel, AutoModelForCausalLM from transformers.generation import ( GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList, ) from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput from transformers import modeling_utils from transformers.modeling_utils import PreTrainedModel from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.utils import logging logger = logging.get_logger(__name__) if ( not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None ): modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"] @dataclass class QWEN3VoxCausalLMOutputWithPast(BaseModelOutputWithPast): logits: Optional[torch.FloatTensor] = None @dataclass class QWEN3VoxGenerationOutput(ModelOutput): sequences: torch.LongTensor = None speech_outputs: Optional[List[torch.FloatTensor]] = None reach_max_step_sample: Optional[torch.BoolTensor] = None class QWEN3VoxTokenConstraintProcessor(LogitsProcessor): def __init__(self, valid_token_ids: List[int], device: torch.device = None): self.valid_token_ids = torch.tensor( valid_token_ids, dtype=torch.long, device=device ) def __call__( self, input_ids: torch.LongTensor, scores: torch.FloatTensor ) -> torch.FloatTensor: mask = torch.full_like(scores, float("-inf")) mask[:, self.valid_token_ids] = 0 scores = scores + mask return scores class QWEN3VoxForConditionalGenerationInference(QWEN3VoxPreTrainedModel, GenerationMixin): _tied_weights_keys = 
["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config): super().__init__(config) self.model = QWEN3VoxModel(config) self.lm_head = nn.Linear( config.decoder_config.hidden_size, config.decoder_config.vocab_size, bias=False, ) self.ddpm_inference_steps = ( config.diffusion_head_config.ddpm_num_inference_steps ) self.post_init() @property def noise_scheduler(self): return self.model.noise_scheduler @property def prediction_head(self): return self.model.prediction_head @property def speech_scaling_factor(self): return self.model.speech_scaling_factor @property def speech_bias_factor(self): return self.model.speech_bias_factor @property def acoustic_tokenizer(self): return self.model.acoustic_tokenizer @property def semantic_tokenizer(self): return self.model.semantic_tokenizer @property def acoustic_connector(self): return self.model.acoustic_connector @property def semantic_connector(self): return self.model.semantic_connector def tie_weights(self): if not getattr(self.config, "tie_word_embeddings", False): return if hasattr(self, "lm_head") and hasattr( self.model.language_model, "embed_tokens" ): self.lm_head.weight = self.model.language_model.embed_tokens.weight def get_input_embeddings(self): return self.model.get_input_embeddings() def set_input_embeddings(self, value): self.model.set_input_embeddings(value) def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None): self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer) def set_ddpm_inference_steps(self, num_steps=None): self.ddpm_inference_steps = ( num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps ) def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"): with torch.no_grad(): if speech_type == "audio": encoder_output = self.model.acoustic_tokenizer.encode( 
speech_tensors.unsqueeze(1) ) acoustic_latents = encoder_output.sample( dist_type=self.model.acoustic_tokenizer.std_dist_type )[0] acoustic_features = ( acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device) ) * self.model.speech_scaling_factor.to(acoustic_latents.device) acoustic_connected = self.model.acoustic_connector(acoustic_features)[ speech_masks.cpu() ] return (acoustic_features, acoustic_connected) elif speech_type == "pt": encoder_output = QWEN3VoxTokenizerEncoderOutput( mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std ) acoustic_latents = encoder_output.sample( dist_type=self.model.acoustic_tokenizer.std_dist_type )[0] acoustic_features = ( acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device) ) * self.model.speech_scaling_factor.to(acoustic_latents.device) acoustic_connected = self.model.acoustic_connector(acoustic_features)[ speech_masks.cpu() ] return (acoustic_features, acoustic_connected) else: raise NotImplementedError(f"Speech type {speech_type } not implemented") def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, speech_tensors: Optional[torch.FloatTensor] = None, speech_masks: Optional[torch.BoolTensor] = None, speech_input_mask: Optional[torch.BoolTensor] = None, logits_to_keep: Union[int, slice] = 0, **kwargs, ) -> Union[Tuple, QWEN3VoxCausalLMOutputWithPast]: return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if inputs_embeds is None: inputs_embeds = 
self.model.get_input_embeddings()(input_ids) if speech_tensors is not None and speech_masks is not None: acoustic_features, speech_embeds = self._process_speech_inputs( speech_tensors.to(self.dtype), speech_masks ) if speech_input_mask is not None: inputs_embeds[speech_input_mask] = speech_embeds outputs = self.model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **kwargs, ) hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state slice_indices = ( slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep ) logits = self.lm_head(hidden_states[:, slice_indices, :]) if labels is not None: raise NotImplementedError( "Loss computation is not implemented in this version." ) return QWEN3VoxCausalLMOutputWithPast( logits=logits, past_key_values=outputs.past_key_values, last_hidden_state=hidden_states, attentions=outputs.attentions, ) def _build_generate_config_model_kwargs( self, generation_config, inputs, tokenizer, return_processors=False, **kwargs ): if generation_config is None: generation_config = GenerationConfig( bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, ) else: generation_config = GenerationConfig( **generation_config, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, ) generation_config, model_kwargs = self._prepare_generation_config( generation_config, True, speech_start_id=tokenizer.speech_start_id, speech_end_id=tokenizer.speech_end_id, speech_diffusion_id=tokenizer.speech_diffusion_id, **kwargs, ) generation_config.speech_start_id = tokenizer.speech_start_id generation_config.speech_end_id = tokenizer.speech_end_id generation_config.speech_diffusion_id = 
tokenizer.speech_diffusion_id inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] device = self.device self._prepare_special_tokens(generation_config, True, device=device) generation_config.use_cache = True model_kwargs["use_cache"] = generation_config.use_cache input_ids = inputs_tensor.to(self.device) input_ids_length = input_ids.shape[1] has_default_max_length = ( kwargs.get("max_length") is None and generation_config.max_length is not None ) has_default_min_length = ( kwargs.get("min_length") is None and generation_config.min_length is not None ) generation_config = self._prepare_generated_length( generation_config=generation_config, has_default_max_length=has_default_max_length, has_default_min_length=has_default_min_length, model_input_name=model_input_name, inputs_tensor=inputs_tensor, input_ids_length=input_ids_length, ) max_cache_length = generation_config.max_length - 1 self._prepare_cache_for_generation( generation_config, model_kwargs, None, batch_size, max_cache_length, device ) model_kwargs["cache_position"] = torch.arange( input_ids_length, device=device, dtype=torch.long ) for k, v in model_kwargs.items(): if isinstance(v, torch.Tensor): model_kwargs[k] = v.to(device=device) if return_processors: logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=None, logits_processor=LogitsProcessorList(), device=inputs_tensor.device, model_kwargs=model_kwargs, ) stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=StoppingCriteriaList(), ) return ( generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria, ) else: return (generation_config, model_kwargs, input_ids) @torch.no_grad() def generate( self, inputs: Optional[torch.Tensor] = None, 
generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[ Callable[[int, torch.Tensor], List[int]] ] = None, synced_gpus: Optional[bool] = None, assistant_model: Optional["PreTrainedModel"] = None, audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None, negative_prompt_ids: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, speech_tensors: Optional[torch.FloatTensor] = None, speech_masks: Optional[torch.BoolTensor] = None, speech_input_mask: Optional[torch.BoolTensor] = None, is_prefill: bool = True, return_speech: bool = True, cfg_scale: float = 1.0, stop_check_fn: Optional[Callable[[], bool]] = None, tqdm_class: Optional[type] = None, **kwargs, ) -> Union[torch.LongTensor, QWEN3VoxGenerationOutput]: tokenizer = kwargs.pop("tokenizer", None) parsed_scripts = kwargs.pop("parsed_scripts", None) all_speakers_list = kwargs.pop("all_speakers_list", None) max_length_times = kwargs.pop("max_length_times", 2) if kwargs.get("max_new_tokens", None) is None: kwargs["max_new_tokens"] = ( self.config.decoder_config.max_position_embeddings - kwargs["input_ids"].shape[-1] ) ( generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria, ) = self._build_generate_config_model_kwargs( generation_config, inputs, tokenizer, return_processors=True, **kwargs ) negative_kwargs = { "input_ids": torch.full( (kwargs["input_ids"].shape[0], 1), tokenizer.speech_start_id, dtype=torch.long, device=kwargs["input_ids"].device, ), "attention_mask": torch.ones( (kwargs["input_ids"].shape[0], 1), dtype=torch.long, device=kwargs["input_ids"].device, ), "max_new_tokens": kwargs.get("max_new_tokens", 100), } negative_generation_config, negative_model_kwargs, negative_input_ids = ( self._build_generate_config_model_kwargs( None, None, tokenizer, return_processors=False, 
**negative_kwargs ) ) acoustic_cache = QWEN3VoxTokenizerStreamingCache() semantic_cache = QWEN3VoxTokenizerStreamingCache() batch_size = input_ids.shape[0] device = input_ids.device finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device) correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device) inputs_embeds = None verbose = kwargs.get("verbose", False) audio_chunks = [[] for _ in range(batch_size)] initial_length = input_ids.shape[-1] initial_length_per_sample = model_kwargs["attention_mask"].sum(dim=-1) valid_tokens = [ generation_config.speech_start_id, generation_config.speech_end_id, generation_config.speech_diffusion_id, generation_config.eos_token_id, ] if ( hasattr(generation_config, "bos_token_id") and generation_config.bos_token_id is not None ): valid_tokens.append(generation_config.bos_token_id) token_constraint_processor = QWEN3VoxTokenConstraintProcessor( valid_tokens, device=device ) if logits_processor is None: logits_processor = LogitsProcessorList() logits_processor.append(token_constraint_processor) max_steps = min( generation_config.max_length - initial_length, int(max_length_times * initial_length), ) max_step_per_sample = torch.min( generation_config.max_length - initial_length_per_sample, (max_length_times * initial_length_per_sample).long(), ) reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device) if kwargs.get("show_progress_bar", True): tqdm_fn = tqdm_class if tqdm_class is not None else tqdm progress_bar = tqdm_fn(range(max_steps), desc="Generating", leave=False) else: progress_bar = range(max_steps) for step in progress_bar: if stop_check_fn is not None and stop_check_fn(): if verbose: print(f"Generation stopped externally at step {step +1 }") if audio_streamer is not None: audio_streamer.end() break if audio_streamer is not None and hasattr(audio_streamer, "finished_flags"): if any(audio_streamer.finished_flags): if verbose: print(f"Audio generation stopped externally at step {step 
+1 }") break if finished_tags.all(): if hasattr(progress_bar, "set_description"): progress_bar.set_description("Generation complete") break if input_ids.shape[-1] >= generation_config.max_length: print( f"Reached maximum generation length {generation_config .max_length }, stopped it." ) reached_samples = torch.arange(batch_size, device=device)[ ~finished_tags ] if reached_samples.numel() > 0: reach_max_step_sample[reached_samples] = True break if hasattr(progress_bar, "set_description"): active_samples = (~finished_tags).sum().item() progress_bar.set_description( f"Generating (active: {active_samples }/{batch_size })" ) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) if is_prefill: prefill_inputs = {} if speech_tensors is not None: prefill_inputs["speech_tensors"] = speech_tensors.to(device=device) if speech_masks is not None: prefill_inputs["speech_masks"] = speech_masks.to(device) if speech_input_mask is not None: prefill_inputs["speech_input_mask"] = speech_input_mask.to(device) is_prefill = False else: _ = model_inputs.pop("inputs_embeds", None) prefill_inputs = {"inputs_embeds": inputs_embeds} outputs = self( **model_inputs, **prefill_inputs, logits_to_keep=1, return_dict=True, output_attentions=False, output_hidden_states=False, ) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=False ) next_token_logits = outputs.logits[:, -1, :].to( copy=True, dtype=torch.float32, device=input_ids.device ) next_token_scores = logits_processor(input_ids, next_token_logits) if generation_config.do_sample: probs = nn.functional.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_tokens = torch.argmax(next_token_scores, dim=-1) next_tokens[finished_tags] = generation_config.eos_token_id input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if not kwargs.get("refresh_negative", True): negative_model_inputs = 
self.prepare_inputs_for_generation( negative_input_ids, **negative_model_kwargs ) if ( negative_model_inputs["inputs_embeds"] is None and inputs_embeds is not None ): negative_model_inputs["inputs_embeds"] = inputs_embeds negative_model_inputs["input_ids"] = None negative_outputs = self( **negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False, ) negative_model_kwargs = self._update_model_kwargs_for_generation( negative_outputs, negative_model_kwargs, is_encoder_decoder=False ) negative_input_ids = torch.cat( [negative_input_ids, next_tokens[:, None]], dim=-1 ) if (next_tokens == generation_config.eos_token_id).any(): eos_indices = ( (next_tokens == generation_config.eos_token_id) .nonzero(as_tuple=False) .squeeze(1) ) new_eos_indices = eos_indices[~finished_tags[eos_indices]] if new_eos_indices.numel() > 0: finished_tags[new_eos_indices] = True if verbose: print( f"Samples {new_eos_indices .tolist ()} reached EOS token at step {step +1 }.", flush=True, ) if audio_streamer is not None: audio_streamer.end(new_eos_indices) max_length_reached = step >= max_step_per_sample new_max_length_indices = torch.nonzero( max_length_reached & ~finished_tags, as_tuple=False ).squeeze(1) if new_max_length_indices.numel() > 0: finished_tags[new_max_length_indices] = True reach_max_step_sample[new_max_length_indices] = True if verbose: print( f"Samples {new_max_length_indices .tolist ()} reached max generation length at step {step +1 }.", flush=True, ) if audio_streamer is not None: audio_streamer.end(new_max_length_indices) diffusion_end_indices = ( (next_tokens == generation_config.speech_end_id) .nonzero(as_tuple=False) .squeeze(1) ) if diffusion_end_indices.numel() > 0: acoustic_cache.set_to_zero(diffusion_end_indices) semantic_cache.set_to_zero(diffusion_end_indices) diffusion_start_indices = torch.arange(batch_size, device=device)[ ~finished_tags & (next_tokens == generation_config.speech_start_id) ] if 
diffusion_start_indices.numel() > 0 and kwargs.get( "refresh_negative", True ): for i, sample_idx in enumerate(diffusion_start_indices.tolist()): negative_model_kwargs["attention_mask"][sample_idx, :] = 0 negative_model_kwargs["attention_mask"][sample_idx, -1] = 1 for layer_idx, (k_cache, v_cache) in enumerate( zip( negative_model_kwargs["past_key_values"].key_cache, negative_model_kwargs["past_key_values"].value_cache, ) ): for sample_idx in diffusion_start_indices.tolist(): k_cache[sample_idx, :, -1, :] = k_cache[ sample_idx, :, 0, : ].clone() v_cache[sample_idx, :, -1, :] = v_cache[ sample_idx, :, 0, : ].clone() for sample_idx in diffusion_start_indices.tolist(): negative_input_ids[sample_idx, -1] = ( generation_config.speech_start_id ) next_inputs_embeds = self.model.get_input_embeddings()( next_tokens ).unsqueeze(1) diffusion_indices = torch.arange(batch_size, device=device)[ ~finished_tags & (next_tokens == generation_config.speech_diffusion_id) ] if diffusion_indices.numel() > 0: if kwargs.get("refresh_negative", True): negative_model_inputs = self.prepare_inputs_for_generation( negative_input_ids, **negative_model_kwargs ) if ( negative_model_inputs["inputs_embeds"] is None and inputs_embeds is not None ): negative_model_inputs["inputs_embeds"] = inputs_embeds negative_model_inputs["input_ids"] = None negative_outputs = self( **negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False, ) negative_model_kwargs = self._update_model_kwargs_for_generation( negative_outputs, negative_model_kwargs, is_encoder_decoder=False, ) negative_input_ids = torch.cat( [negative_input_ids, next_tokens[:, None]], dim=-1 ) non_diffusion_mask = ~finished_tags & ( next_tokens != generation_config.speech_diffusion_id ) if non_diffusion_mask.any(): non_diffusion_indices = torch.arange(batch_size, device=device)[ non_diffusion_mask ] start_indices = correct_cnt[non_diffusion_indices] seq_len = 
negative_model_kwargs["attention_mask"].shape[1] for i, (sample_idx, start_idx) in enumerate( zip(non_diffusion_indices.tolist(), start_indices.tolist()) ): if start_idx + 1 < seq_len - 1: negative_model_kwargs["attention_mask"][ sample_idx, start_idx + 1 : ] = negative_model_kwargs["attention_mask"][ sample_idx, start_idx:-1 ].clone() negative_model_kwargs["attention_mask"][ sample_idx, start_idx ] = 0 for layer_idx, (k_cache, v_cache) in enumerate( zip( negative_model_kwargs["past_key_values"].key_cache, negative_model_kwargs["past_key_values"].value_cache, ) ): for sample_idx, start_idx in zip( non_diffusion_indices.tolist(), start_indices.tolist() ): if start_idx + 1 < k_cache.shape[2] - 1: k_cache[sample_idx, :, start_idx + 1 :, :] = k_cache[ sample_idx, :, start_idx:-1, : ].clone() v_cache[sample_idx, :, start_idx + 1 :, :] = v_cache[ sample_idx, :, start_idx:-1, : ].clone() for sample_idx, start_idx in zip( non_diffusion_indices.tolist(), start_indices.tolist() ): if start_idx + 1 < negative_input_ids.shape[1] - 1: negative_input_ids[sample_idx, start_idx + 1 :] = ( negative_input_ids[sample_idx, start_idx:-1].clone() ) correct_cnt[non_diffusion_indices] += 1 positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :] negative_condition = negative_outputs.last_hidden_state[ diffusion_indices, -1, : ] speech_latent = self.sample_speech_tokens( positive_condition, negative_condition, cfg_scale=cfg_scale ).unsqueeze(1) scaled_latent = speech_latent / self.model.speech_scaling_factor.to( speech_latent.device ) - self.model.speech_bias_factor.to(speech_latent.device) audio_chunk = self.model.acoustic_tokenizer.decode( scaled_latent.to(self.model.acoustic_tokenizer.device), cache=acoustic_cache, sample_indices=diffusion_indices.to( self.model.acoustic_tokenizer.device ), use_cache=True, debug=False, ) for i, sample_idx in enumerate(diffusion_indices): idx = sample_idx.item() if not finished_tags[idx]: audio_chunks[idx].append(audio_chunk[i]) if 
audio_streamer is not None: audio_streamer.put(audio_chunk, diffusion_indices) semantic_features = self.model.semantic_tokenizer.encode( audio_chunk, cache=semantic_cache, sample_indices=diffusion_indices, use_cache=True, debug=False, ).mean acoustic_embed = self.model.acoustic_connector(speech_latent) semantic_embed = self.model.semantic_connector(semantic_features) diffusion_embeds = acoustic_embed + semantic_embed next_inputs_embeds[diffusion_indices] = diffusion_embeds inputs_embeds = next_inputs_embeds if audio_streamer is not None: audio_streamer.end() final_audio_outputs = [] for sample_chunks in audio_chunks: if sample_chunks: concatenated_audio = torch.cat(sample_chunks, dim=-1) final_audio_outputs.append(concatenated_audio) else: final_audio_outputs.append(None) return QWEN3VoxGenerationOutput( sequences=input_ids, speech_outputs=final_audio_outputs if return_speech else None, reach_max_step_sample=reach_max_step_sample, ) @torch.no_grad() def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0): self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps) condition = torch.cat([condition, neg_condition], dim=0).to( self.model.prediction_head.device ) speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to( condition ) for t in self.model.noise_scheduler.timesteps: half = speech[: len(speech) // 2] combined = torch.cat([half, half], dim=0) eps = self.model.prediction_head( combined, t.repeat(combined.shape[0]).to(combined), condition=condition ) cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample return speech[: len(speech) // 2] AutoModelForCausalLM.register(QWEN3VoxConfig, QWEN3VoxForConditionalGenerationInference) __all__ = [ 'QWEN3VoxForConditionalGenerationInference' ] _AUX_SLOT_MANIFEST_K = "vv.pipeline.aux_slot_manifest" 
# Handle key used when a requested slice is not present. Its underscore-joined
# fields follow the gender_pitch_speed_age_group_emotion_tone_accent layout.
DEFAULT_AUX_SLICE_ID = "male_mid_normal_adult_serious_formal_uk"


def _resolve_aux_coeff_tensor(
    handles: Dict[str, Any],
    slice_query: str,
    *,
    default_slice_id: str = DEFAULT_AUX_SLICE_ID,
) -> Tuple[Any, str, str, bool]:
    """Pick a handle for ``slice_query`` with progressively looser matching.

    Resolution order:
      1. exact key match on the stripped query;
      2. ``default_slice_id`` if present (flagged as a default fallback);
      3. first key that is a case-insensitive substring of the query, or vice
         versa (only reachable when the default key is absent);
      4. the first key in ``handles`` iteration order.

    Returns:
        ``(binding, used_key, stripped_query, used_default)``. ``used_default``
        is True only for step 2; steps 3 and 4 also fall back to a different
        key but report False.

    Raises:
        ValueError: if ``handles`` is empty.
    """
    q = slice_query.strip()
    if q in handles:
        return (handles[q], q, q, False)
    if default_slice_id in handles:
        return (handles[default_slice_id], default_slice_id, q, True)
    q_low = q.lower()
    for preset_k, binding in handles.items():
        k_low = preset_k.lower()
        if k_low in q_low or q_low in k_low:
            return (binding, preset_k, q, False)
    if handles:
        first_k = next(iter(handles))
        return (handles[first_k], first_k, q, False)
    raise ValueError("empty auxiliary coefficient handle map")


def _accum_tensor_key(slot_idx: int) -> str:
    """Return the safetensors tensor name that stores slot ``slot_idx``."""
    return f"model.decoder.aux_residual.accum.{slot_idx:04d}.u8_payload"


def _default_aux_shard_fp(repo_root: str) -> str:
    """Return the default shard location: ``<repo_root>/aux_lm_residual_projection.safetensors``."""
    return os.path.join(repo_root, "aux_lm_residual_projection.safetensors")


def _materialize_latent_prompt_embeddings(
    blob_fp: str | os.PathLike[str],
) -> Dict[str, Any]:
    """Decode the reference audio clips packed into a safetensors shard.

    Despite the name, nothing here is an embedding. The shard metadata entry
    ``_AUX_SLOT_MANIFEST_K`` is a JSON object whose ``"order"`` list names one
    stem per slot. Slot ``i`` is the uint8 tensor ``_accum_tensor_key(i)``,
    whose raw bytes are an encoded audio file. Each file is decoded with
    ``librosa.load(..., sr=None, mono=True)``, i.e. mono at its native rate.

    Args:
        blob_fp: Path to the safetensors file.

    Returns:
        Mapping ``stem -> float32 mono waveform`` (numpy array).

    Raises:
        ValueError: if the manifest is missing or malformed, a slot tensor is
            absent, or a slot payload cannot be decoded as audio.
    """
    import librosa
    from safetensors import safe_open

    blob_fp = os.fspath(blob_fp)
    decoded: Dict[str, Any] = {}
    with safe_open(blob_fp, framework="np") as f:
        meta = f.metadata()
        if not meta or _AUX_SLOT_MANIFEST_K not in meta:
            raise ValueError(
                "missing auxiliary slot manifest (not an LM projection safetensors shard)"
            )
        try:
            manifest = json.loads(meta[_AUX_SLOT_MANIFEST_K])
            order = manifest["order"]
        except (json.JSONDecodeError, KeyError, TypeError) as exc:
            raise ValueError("corrupt auxiliary slot manifest") from exc
        # A JSON string would otherwise be split into one stem per character.
        if not isinstance(order, list) or not all(isinstance(s, str) for s in order):
            raise ValueError("corrupt auxiliary slot manifest")
        stems_ordered: List[str] = list(order)

        tensor_names = set(f.keys())
        for i, stem in enumerate(stems_ordered):
            tk = _accum_tensor_key(i)
            if tk not in tensor_names:
                raise ValueError(f"missing tensor payload for slot {i}: {tk}")
            raw = np.asarray(f.get_tensor(tk), dtype=np.uint8).tobytes()
            try:
                # NOTE(review): the native sample rate is discarded here, so
                # consumers presumably assume a fixed rate; confirm.
                mono, _native_sr = librosa.load(io.BytesIO(raw), sr=None, mono=True)
            except Exception as exc:
                # Decoder backends raise their own exception types. Re-raise as
                # ValueError so callers' shard-error handling sees the slot.
                raise ValueError(
                    f"undecodable audio payload for slot {i}: {tk}"
                ) from exc
            decoded[stem] = np.asarray(mono, dtype=np.float32)
    return decoded
# Speaker label used in scripts ("Speaker 1: ..."). Previously spelled via
# character codes; the value is unchanged.
_MODEL_DIALOGUE_ROLE_MARK = "Speaker"
# Script header line, e.g. "Speaker 2: hello". Group 1 is the speaker number
# and group 2 is the text on the same line.
_LANE_HEAD_PATTERN = (
    rf"^{re.escape(_MODEL_DIALOGUE_ROLE_MARK)}\s+(\d+):\s*(.*)$"
)
# Repo-relative directory scanned for reference ``.wav`` voice files.
_COEFF_STAGE_SUBDIR = "voices"


class _QxResidualFabric:
    """Registry of reference voice samples used to prefill generation.

    Despite the "residual"/"coefficient"/"projection" vocabulary in names and
    log messages, every handle is reference audio:

    * if a safetensors shard is found, handles are mono float32 waveforms
      decoded by ``_materialize_latent_prompt_embeddings``;
    * otherwise handles are paths to ``.wav`` files in ``<repo_root>/voices``.

    Handles are keyed by stem. A stem containing ``-`` also gets a short alias:
    the text between the last ``-`` and the first ``_``, so
    ``"en-Alice_woman"`` becomes ``"Alice"``. An alias never replaces a real
    stem of the same name.
    """

    def __init__(
        self,
        repo_root: str | os.PathLike[str],
        *,
        aux_projection_shard_fp: str | None = None,
        skip_aux_shard: bool = False,
    ):
        """Load handles from a shard (preferred) or from the voices directory.

        Args:
            repo_root: Repository root; ``voices/`` and the default shard path
                are resolved relative to it.
            aux_projection_shard_fp: Explicit shard path, tried before
                ``$VV_AUX_PROJECTION_PATH`` and the default location.
            skip_aux_shard: If true, ignore shards and scan ``voices/`` only.

        Raises:
            ValueError: if a shard is selected but cannot be decoded.
        """
        self._repo_root = os.path.abspath(os.fspath(repo_root))
        self._discrete_coeff_root = os.path.join(self._repo_root, _COEFF_STAGE_SUBDIR)
        self._r_handles: Dict[str, Union[str, np.ndarray]] = {}
        self._fabric_refresh_handles(
            aux_projection_shard_fp=aux_projection_shard_fp,
            skip_aux_shard=skip_aux_shard,
        )

        aliases: Dict[str, Union[str, np.ndarray]] = {}
        for stem, binding in self._r_handles.items():
            if "-" not in stem:
                continue
            nick = stem.split("_", 1)[0].split("-")[-1]
            if nick:
                aliases[nick] = binding
        for nick, binding in aliases.items():
            # A real stem always wins over a derived alias.
            self._r_handles.setdefault(nick, binding)

    def _fabric_refresh_handles(
        self, *, aux_projection_shard_fp: str | None, skip_aux_shard: bool
    ) -> None:
        """Repopulate ``self._r_handles`` from a shard or the voices directory.

        Shard candidates, in order; the first existing file wins:
          1. the explicit argument;
          2. ``$VV_AUX_PROJECTION_PATH``;
          3. ``<repo_root>/aux_lm_residual_projection.safetensors``.

        Without a shard, every ``*.wav`` file (suffix case-insensitive)
        directly inside ``<repo_root>/voices`` is registered by stem.

        Raises:
            ValueError: if the selected shard is malformed.
        """
        self._r_handles.clear()

        shard_fp: str | None = None
        if not skip_aux_shard:
            candidates = (
                (aux_projection_shard_fp or "").strip(),
                os.environ.get("VV_AUX_PROJECTION_PATH", "").strip(),
                _default_aux_shard_fp(self._repo_root),
            )
            shard_fp = next((p for p in candidates if p and os.path.isfile(p)), None)

        if shard_fp:
            try:
                decoded = _materialize_latent_prompt_embeddings(shard_fp)
            except ValueError as exc:
                raise ValueError(
                    f"AUX shard assembly failed ({shard_fp}): {exc}"
                ) from exc
            self._r_handles = dict(sorted(decoded.items()))
            print(
                f"Mounted auxiliary LM projection shard ({len(self._r_handles)} tensors): {shard_fp}"
            )
            print(f"Residual routing keys: {', '.join(self._r_handles.keys())}")
            return

        # isdir rather than exists: a plain file here would crash listdir.
        if not os.path.isdir(self._discrete_coeff_root):
            print(
                f"Warning: coefficient directory missing at {self._discrete_coeff_root}"
            )
            return

        found: Dict[str, Union[str, np.ndarray]] = {}
        # Sorted so that stem collisions ("a.wav" vs "a.WAV") resolve
        # deterministically.
        for fname in sorted(os.listdir(self._discrete_coeff_root)):
            fp = os.path.join(self._discrete_coeff_root, fname)
            if fname.lower().endswith(".wav") and os.path.isfile(fp):
                found[os.path.splitext(fname)[0]] = fp
        self._r_handles = dict(sorted(found.items()))
        print(
            f"Discrete coefficient files staged: {len(self._r_handles)} under {self._discrete_coeff_root}"
        )
        print(f"Residual routing keys: {', '.join(self._r_handles.keys())}")

    def _fabric_pick_residual_snapshot(
        self, shard_slice_query: str
    ) -> Union[str, np.ndarray]:
        """Return the voice sample (waveform or WAV path) for a slice query.

        Matching is delegated to ``_resolve_aux_coeff_tensor``; a warning is
        printed when it falls back to the default slice.

        Raises:
            ValueError: if no handles are mounted.
        """
        if not self._r_handles:
            raise ValueError(
                f"No residual handles mounted. Add WAV files under {_COEFF_STAGE_SUBDIR}/ at the repo root, place aux_lm_residual_projection.safetensors next to config.json, or set VV_AUX_PROJECTION_PATH / VOCENCE_AUX_PROJECTION_SHARD."
            )
        binding, used_key, req_norm, used_default = _resolve_aux_coeff_tensor(
            self._r_handles, shard_slice_query
        )
        if used_default:
            print(
                f"Warning: auxiliary slice '{req_norm}' not in shard; using default '{used_key}'."
            )
        return binding


def _partition_lm_conditioning_manifest(
    raw_manifest_txt: str,
) -> Tuple[List[str], List[str]]:
    """Split a "Speaker N: text" script into normalized turns and speaker ids.

    Lines that do not start a new speaker header are appended (space-joined)
    to the current turn. Blank lines are ignored. Text appearing before the
    first header is discarded, and so is a header with no text at all.

    Returns:
        ``(turns, lane_ids)``, where each turn is ``"Speaker <id>: <text>"`` and
        ``lane_ids[i]`` is the speaker number of ``turns[i]`` (as a string).
    """
    turns: List[str] = []
    lane_ids: List[str] = []
    active_lane: str | None = None
    payload = ""
    for raw_line in raw_manifest_txt.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        match = re.match(_LANE_HEAD_PATTERN, line, re.IGNORECASE)
        if match:
            if active_lane and payload:
                turns.append(f"{_MODEL_DIALOGUE_ROLE_MARK} {active_lane}: {payload.strip()}")
                lane_ids.append(active_lane)
            active_lane = match.group(1).strip()
            payload = match.group(2).strip()
        elif payload:
            payload += " " + line
        else:
            payload = line
    if active_lane and payload:
        turns.append(f"{_MODEL_DIALOGUE_ROLE_MARK} {active_lane}: {payload.strip()}")
        lane_ids.append(active_lane)
    return (turns, lane_ids)


def _parse_instruction_params(instruction: str) -> Dict[str, str]:
    """Parse ``"|key: value|key2: value2|"`` into ``{key.lower(): value}``.

    Parts without ``:`` are skipped; later duplicate keys win.
    """
    params: Dict[str, str] = {}
    for part in instruction.strip().strip("|").split("|"):
        if ":" not in part:
            continue
        key, value = part.split(":", 1)
        params[key.strip().lower()] = value.strip()
    return params


# Vocence aux slice slugs: gender_pitch_speed_age_group_emotion_tone_accent
_SLICE_SLUG_FIELDS: Tuple[str, ...] = (
    "gender",
    "pitch",
    "speed",
    "age_group",
    "emotion",
    "tone",
    "accent",
)
# When the composed slug is missing from the available keys, candidates are
# scored by field matches in this importance order (highest weight first).
_SLICE_MATCH_WEIGHT_ORDER: Tuple[str, ...] = (
    "gender",
    "emotion",
    "accent",
    "speed",
    "age_group",
    "tone",
    "pitch",
)
_STRUCTURED_PROSODY_KEYS = frozenset(_SLICE_SLUG_FIELDS) | frozenset({"age"})
# Weights 2**28, 2**24, ... Each weight exceeds the sum of all lower ones, so
# scoring behaves lexicographically in _SLICE_MATCH_WEIGHT_ORDER.
_SLICE_MATCH_WEIGHTS: Tuple[int, ...] = tuple(
    1 << (28 - i * 4) for i in range(len(_SLICE_MATCH_WEIGHT_ORDER))
)


def _norm_prosody_token(s: str) -> str:
    """Lower-case, strip, and replace spaces with underscores."""
    return s.strip().lower().replace(" ", "_")


def _parse_slice_slug(slice_id: str) -> Optional[Dict[str, str]]:
    """Split a slug into its seven fields, or return None if it does not fit.

    NOTE(review): values that themselves contain ``_`` (e.g. ``young_adult``,
    which ``_natural_language_prosody_attrs`` can produce) give more than seven
    parts. Such keys are therefore never parsed or fuzzy-scored.
    """
    t = slice_id.strip()
    if not t:
        return None
    parts = t.split("_")
    if len(parts) != len(_SLICE_SLUG_FIELDS):
        return None
    return {f: _norm_prosody_token(p) for f, p in zip(_SLICE_SLUG_FIELDS, parts)}


def _attrs_to_slice_slug(attrs: Dict[str, str]) -> str:
    """Join the seven slug fields of ``attrs`` with ``_`` (KeyError if one is missing)."""
    return "_".join(_norm_prosody_token(attrs[f]) for f in _SLICE_SLUG_FIELDS)


def _default_slice_attrs() -> Dict[str, str]:
    """Return the fields of ``DEFAULT_AUX_SLICE_ID``, or blanks if it does not parse."""
    parsed = _parse_slice_slug(DEFAULT_AUX_SLICE_ID)
    if parsed is not None:
        return dict(parsed)
    return {f: "" for f in _SLICE_SLUG_FIELDS}


def _instruction_has_structured_prosody(p: Dict[str, str]) -> bool:
    """True if any key of ``p`` is a slug field name (``age`` counts as ``age_group``)."""
    for k in p:
        lk = k.lower()
        if lk == "age":
            lk = "age_group"
        if lk in _STRUCTURED_PROSODY_KEYS:
            return True
    return False


def _instruction_prosody_attrs(p: Dict[str, str]) -> Dict[str, str]:
    """Overlay non-blank slug fields from ``p`` onto the default slice attributes."""
    out = _default_slice_attrs()
    for k, v in p.items():
        if not v.strip():
            continue
        lk = k.lower()
        if lk == "age":
            lk = "age_group"
        if lk not in _SLICE_SLUG_FIELDS:
            continue
        out[lk] = _norm_prosody_token(v)
    return out


def _pick_best_aux_slice_key(
    desired_attrs: Dict[str, str], available_keys: AbstractSet[str]
) -> str:
    """Choose the available key that best matches ``desired_attrs``.

    The exact composed slug is returned if present. Otherwise keys that parse
    as slugs are scored with ``_SLICE_MATCH_WEIGHTS``, and ties go to the
    lexicographically smallest key. If no key parses, the smallest key is
    returned, or ``DEFAULT_AUX_SLICE_ID`` when ``available_keys`` is empty.
    """
    desired_slug = _attrs_to_slice_slug(desired_attrs)
    if desired_slug in available_keys:
        return desired_slug
    parsed: List[Tuple[str, Dict[str, str]]] = []
    for k in available_keys:
        pd = _parse_slice_slug(k)
        if pd is not None:
            parsed.append((k, pd))
    if not parsed:
        if available_keys:
            return sorted(available_keys)[0]
        return DEFAULT_AUX_SLICE_ID
    best_key: Optional[str] = None
    best_score = -1
    for k, cattrs in parsed:
        sc = 0
        for field, w in zip(_SLICE_MATCH_WEIGHT_ORDER, _SLICE_MATCH_WEIGHTS):
            if desired_attrs.get(field) == cattrs.get(field):
                sc += w
        if sc > best_score or (sc == best_score and best_key is not None and k < best_key):
            best_score = sc
            best_key = k
    assert best_key is not None  # parsed is non-empty, so the loop assigned it
    return best_key


def _natural_language_prosody_attrs(instruction: str) -> Optional[Dict[str, str]]:
    """Best-effort map of validator natural-language instructions to aux slice fields.

    Keyword heuristics over the lower-cased instruction. Fields not mentioned
    get a neutral/typical value (gender "neutral", pitch "mid", speed "normal",
    age "adult", accent "neutral"). Emotion and tone keep the default slice's
    values unless a listed keyword appears. Returns None for a blank
    instruction.
    """
    low = instruction.lower()
    if not low.strip():
        return None
    attrs = _default_slice_attrs()

    def _word(pattern: str) -> bool:
        # Whole-word match, so "low" does not fire on "slowly" or "below".
        return re.search(rf"\b(?:{pattern})\b", low) is not None

    if "female" in low:
        attrs["gender"] = "female"
    elif "male" in low:
        attrs["gender"] = "male"
    else:
        attrs["gender"] = "neutral"

    # "low-pitched" / "high-pitched" are covered: "-" is a word boundary.
    if _word("low(?:er|est)?") and "pitch" in low:
        attrs["pitch"] = "low"
    elif _word("high(?:er|est)?") and "pitch" in low:
        attrs["pitch"] = "high"
    else:
        attrs["pitch"] = "mid"

    if "slow" in low:
        attrs["speed"] = "slow"
    elif "fast" in low:
        attrs["speed"] = "fast"
    else:
        attrs["speed"] = "normal"

    if "child" in low:
        attrs["age_group"] = "child"
    elif "senior" in low or "elderly" in low:
        attrs["age_group"] = "senior"
    elif "young" in low:
        attrs["age_group"] = "young_adult"
    else:
        attrs["age_group"] = "adult"

    for emo in ("happy", "sad", "angry", "calm", "excited", "serious", "fearful", "neutral"):
        if emo in low:
            attrs["emotion"] = emo
            break

    for tone in ("warm", "cold", "friendly", "formal", "casual", "authoritative"):
        if tone in low:
            attrs["tone"] = tone
            break

    padded = f" {low} "
    if "american" in low or " us " in padded:
        attrs["accent"] = "us"
    elif "british" in low or " uk " in padded:
        attrs["accent"] = "uk"
    elif "australian" in low:
        attrs["accent"] = "au"
    elif "indian" in low:
        attrs["accent"] = "in"
    else:
        attrs["accent"] = "neutral"
    return attrs


def _prosody_shard_tags_for_lanes(
    instruction: str,
    unique_lanes: List[str],
    *,
    aux_slice_keys: Optional[AbstractSet[str]] = None,
) -> Dict[str, str]:
    """Map each speaker lane to a voice-slice tag derived from ``instruction``.

    Precedence of instruction keys:
      1. ``prosody`` / ``shards`` / ``prosody_shards`` (comma list);
      2. ``speakers`` (comma list);
      3. ``voice`` / ``speaker``;
      4. structured slug fields;
      5. natural-language heuristics.

    With ``aux_slice_keys`` given, cases 4 and 5 snap to the best available key.
    If there are fewer tags than lanes, the last tag is repeated.
    """
    p = _parse_instruction_params(instruction)
    if "prosody" in p or "shards" in p or "prosody_shards" in p:
        raw = p.get("prosody") or p.get("shards") or p.get("prosody_shards") or ""
        tags = [x.strip() for x in raw.split(",") if x.strip()]
    elif "speakers" in p:
        tags = [x.strip() for x in p["speakers"].split(",") if x.strip()]
    elif p.get("voice") or p.get("speaker"):
        tags = [(p.get("voice") or p.get("speaker") or "").strip()]
    elif _instruction_has_structured_prosody(p):
        merged = _instruction_prosody_attrs(p)
        if aux_slice_keys:
            tags = [_pick_best_aux_slice_key(merged, aux_slice_keys)]
        else:
            tags = [_attrs_to_slice_slug(merged)]
    else:
        nl_attrs = _natural_language_prosody_attrs(instruction)
        if nl_attrs and aux_slice_keys:
            tags = [_pick_best_aux_slice_key(nl_attrs, aux_slice_keys)]
        elif nl_attrs:
            tags = [_attrs_to_slice_slug(nl_attrs)]
        else:
            tags = [DEFAULT_AUX_SLICE_ID]
    if not tags:
        tags = [DEFAULT_AUX_SLICE_ID]
    while len(tags) < len(unique_lanes):
        tags.append(tags[-1])
    return {lane: tags[i] for i, lane in enumerate(unique_lanes)}


def _manifest_from_text(text: str) -> str:
    """Return ``text`` stripped, prefixed with "Speaker 1: " unless any line is already a speaker header."""
    stripped = text.strip()
    header = rf"^{re.escape(_MODEL_DIALOGUE_ROLE_MARK)}\s+\d+:"
    if re.search(header, stripped, re.MULTILINE | re.IGNORECASE):
        return stripped
    return f"{_MODEL_DIALOGUE_ROLE_MARK} 1: {stripped}"


def _build_prefill_slices(
    fabric: _QxResidualFabric,
    routing_lane_ids: List[str],
    lane_to_slice_tag: Dict[str, str],
) -> List[Union[str, np.ndarray]]:
    """Return one voice sample per distinct lane, in first-appearance order.

    A lane missing from ``lane_to_slice_tag`` is looked up as ``"lane_<id>"``.
    """
    unique_lanes = list(dict.fromkeys(routing_lane_ids))
    return [
        fabric._fabric_pick_residual_snapshot(lane_to_slice_tag.get(lane, f"lane_{lane}"))
        for lane in unique_lanes
    ]


class Miner:
    """Vocence miner: loads QWEN3Vox from a local HF repo and synthesizes speech.

    The repo root must contain ``config.json``. It may contain
    ``vocence_config.yaml``, which must define ``model_name`` and may define
    ``runtime`` / ``generation`` / ``limits`` sections. Reference voice audio
    for prefill comes from :class:`_QxResidualFabric`.

    Environment variables read:
      * ``VOCENCE_AUX_PROJECTION_SHARD``
      * ``VOCENCE_PREFER_DISCRETE_COEFF_DIR``
      * ``VOCENCE_SEED``
      * ``VOCENCE_CFG_SCALE``
      * ``VOCENCE_DISABLE_PREFILL``
      * ``VOCENCE_CHECKPOINT_PATH`` (read when the model is first loaded)
    """

    REPO_SENTINEL = "config.json"
    SETTINGS_FILE = "vocence_config.yaml"
    WARMUP_TIMEOUT = 240.0

    @staticmethod
    def _env_flag(name: str) -> bool:
        """True if env var ``name`` is "1", "true" or "yes" (case-insensitive)."""
        return os.environ.get(name, "").lower() in ("1", "true", "yes")

    def __init__(self, path_hf_repo: Path) -> None:
        """Validate the repo, mount voice samples, and read runtime env settings.

        The model itself is loaded lazily on first access to ``self.model``.

        Raises:
            FileNotFoundError: if ``config.json`` is missing.
            ValueError: if ``model_name`` is missing from the settings file or
                the voice shard is malformed.
            RuntimeError: if no voice samples were found.
        """
        self.root = Path(path_hf_repo).resolve()
        if not (self.root / self.REPO_SENTINEL).is_file():
            raise FileNotFoundError(
                f"{self.REPO_SENTINEL} not present in {self.root}"
            )
        # Touch the cached property so a bad settings file fails at construction.
        _ = self.model_name

        aux_cli = os.environ.get("VOCENCE_AUX_PROJECTION_SHARD", "").strip()
        self._fabric_q = _QxResidualFabric(
            str(self.root),
            aux_projection_shard_fp=aux_cli or None,
            skip_aux_shard=self._env_flag("VOCENCE_PREFER_DISCRETE_COEFF_DIR"),
        )
        if not self._fabric_q._r_handles:
            raise RuntimeError(
                "No auxiliary conditioning handles mounted in repo; set VV_AUX_PROJECTION_PATH / VOCENCE_AUX_PROJECTION_SHARD, or ship aux_lm_residual_projection.safetensors at the repo root."
            )

        seed_s = os.environ.get("VOCENCE_SEED", "").strip()
        if seed_s:
            seed = int(seed_s)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        self._cfg_scale = float(os.environ.get("VOCENCE_CFG_SCALE", "1.3"))
        self._disable_prefill = self._env_flag("VOCENCE_DISABLE_PREFILL")
        self._processor: Optional[QWEN3VoxProcessor] = None
        self._device: str = "cpu"
        self._sample_rate: int = 22050

    def __repr__(self) -> str:
        return f"{type(self).__name__}(root={str(getattr(self, 'root', None))!r})"

    @cached_property
    def model_name(self) -> str:
        """``model_name`` from the settings file.

        Raises:
            ValueError: if the key is missing or blank.
        """
        raw = self._load_yaml(self.root / self.SETTINGS_FILE)
        name = str(raw.get("model_name") or "").strip()
        if not name:
            raise ValueError("vocence_config.yaml missing model_name")
        return name

    @cached_property
    def settings(self) -> SimpleNamespace:
        """Runtime, generation, and limit settings with defaults filled in."""
        raw = self._load_yaml(self.root / self.SETTINGS_FILE)
        rt = raw.get("runtime") or {}
        gen = raw.get("generation") or {}
        lim = raw.get("limits") or {}
        return SimpleNamespace(
            language=str(
                lim.get("default_language") or rt.get("default_language") or "English"
            ),
            sample_rate=int(gen.get("sample_rate", 24000)),
            max_instruction_chars=int(lim.get("max_instruction_chars", 600)),
            max_text_chars=int(lim.get("max_text_chars", 2000)),
            prefer_cuda=str(rt.get("device_preference", "cuda")).lower() == "cuda",
            prefer_bf16=str(rt.get("dtype", "bfloat16")).lower() == "bfloat16",
            prefer_flash=bool(rt.get("use_flash_attention_2", False)),
        )

    @cached_property
    def model(self) -> QWEN3VoxForConditionalGenerationInference:
        """The loaded engine; the first access also loads the processor."""
        return self._instantiate_engine()

    def _instantiate_engine(self) -> QWEN3VoxForConditionalGenerationInference:
        """Select the device/dtype, load processor and weights, and apply an optional LoRA.

        Attention implementations are tried in order until one loads.

        Raises:
            RuntimeError: if every attention implementation fails to load.
        """
        s = self.settings
        model_name = self.model_name
        if s.prefer_cuda and torch.cuda.is_available():
            self._device = "cuda"
        elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            self._device = "mps"
        else:
            self._device = "cpu"

        if self._device == "cuda":
            load_dtype = torch.bfloat16 if s.prefer_bf16 else torch.float32
            attn_attempts = (
                ("flash_attention_2", "sdpa")
                if s.prefer_flash
                else ("sdpa", "flash_attention_2")
            )
        else:
            # MPS and CPU: fp32 with SDPA only.
            load_dtype = torch.float32
            attn_attempts = ("sdpa",)

        self._processor = QWEN3VoxProcessor.from_pretrained(model_name)
        last_failure: Optional[BaseException] = None
        engine: Optional[QWEN3VoxForConditionalGenerationInference] = None
        for attn_impl in attn_attempts:
            try:
                engine = self._load_model_weights(model_name, load_dtype, attn_impl)
            except Exception as exc:
                last_failure = exc
                continue
            dtype_tag = "bf16" if load_dtype is torch.bfloat16 else "fp32"
            print(
                f"[Miner] QWEN3Vox ready :: device={self._device} "
                f"dtype={dtype_tag} attn={attn_impl}"
            )
            break
        if engine is None:
            raise RuntimeError(f"QWEN3Vox failed to load :: {last_failure!r}")

        ckpt = os.environ.get("VOCENCE_CHECKPOINT_PATH", "").strip()
        if ckpt:
            load_lora_assets(engine, ckpt)
        engine.train(False)
        engine.set_ddpm_inference_steps(num_steps=10)
        proc_sr = int(
            getattr(self._processor.audio_processor, "sampling_rate", 22050)
        )
        self._sample_rate = proc_sr if proc_sr > 0 else s.sample_rate
        return engine

    def _load_model_weights(
        self, model_name: str, load_dtype: torch.dtype, attn_impl: str
    ) -> QWEN3VoxForConditionalGenerationInference:
        """Load weights onto ``self._device``.

        MPS is loaded without a ``device_map`` and moved afterwards.
        """
        on_mps = self._device == "mps"
        m = QWEN3VoxForConditionalGenerationInference.from_pretrained(
            model_name,
            torch_dtype=load_dtype,
            attn_implementation=attn_impl,
            device_map=None if on_mps else self._device,
        )
        if on_mps:
            m.to("mps")
        return m

    def warmup(self) -> None:
        """Run one synthesis in a daemon thread, bounded by ``WARMUP_TIMEOUT``.

        On timeout the worker thread is left running.

        Raises:
            RuntimeError: if the trial raised or did not finish in time.
        """
        outcome: dict[str, Any] = {"done": False, "err": None}

        def _trial() -> None:
            try:
                self.generate_wav(
                    instruction=(
                        "An adult male with a neutral British accent, speaking at a "
                        "normal pace in a mid-range pitch, sounding calm and formal."
                    ),
                    text="This is a warmup utterance for the QWEN3Vox engine.",
                )
                outcome["done"] = True
            except Exception as exc:
                outcome["err"] = repr(exc)

        worker = threading.Thread(target=_trial, daemon=True)
        worker.start()
        worker.join(timeout=self.WARMUP_TIMEOUT)
        if outcome["done"]:
            return
        if outcome["err"] is not None:
            raise RuntimeError(f"warmup failed: {outcome['err']}")
        raise RuntimeError(
            f"warmup did not complete within {self.WARMUP_TIMEOUT}s: "
            "no completion signal"
        )

    @staticmethod
    def _load_yaml(path: Path) -> dict[str, Any]:
        """Parse ``path`` with ``yaml.safe_load``.

        A missing file or an empty/falsy document yields ``{}``.

        Raises:
            ValueError: if the document is not a mapping.
        """
        if not path.is_file():
            return {}
        from yaml import safe_load

        with path.open("r", encoding="utf-8") as fh:
            data = safe_load(fh)
        if not data:
            return {}
        if not isinstance(data, dict):
            raise ValueError(f"{path} must contain a YAML mapping at the top level")
        return data

    def _speech_tensor_to_numpy(self, speech: torch.Tensor) -> np.ndarray:
        """Convert a speech tensor to a 1-D float32 numpy array.

        Leading size-1 dimensions are squeezed away. Any remaining
        multi-dimensional shape (or a 0-d scalar) is flattened with
        ``reshape(-1)``.
        """
        t = speech.detach().cpu().float()
        # squeeze(0) is a no-op when dim 0 is not 1, so only loop while it is;
        # otherwise shapes like (2, N) would loop forever.
        while t.dim() > 1 and t.shape[0] == 1:
            t = t.squeeze(0)
        if t.dim() != 1:
            t = t.reshape(-1)
        return t.numpy().astype(np.float32, copy=False)

    def generate_wav(self, instruction: str, text: str) -> Tuple[np.ndarray, int]:
        """Synthesize audio. Pass validator text and instruction verbatim (length caps only).

        - instruction → processor system_prompt (tokenized as-is; no NL parsing
          or rewriting). A blank instruction restores the processor's original
          prompt.
        - text → script body (plain transcript or existing Speaker N: lines; no
          wrapping).
        - The voice sample passed to the processor is always the handle for
          ``DEFAULT_AUX_SLICE_ID``, regardless of the instruction.

        The processor's system prompt is restored afterwards, even on error.

        Returns:
            ``(waveform, sample_rate)``.

        Raises:
            RuntimeError: if the processor is missing or no speech is produced.
        """
        s = self.settings
        if s.max_instruction_chars > 0:
            instruction = instruction[: s.max_instruction_chars]
        if s.max_text_chars > 0:
            text = text[: s.max_text_chars]

        inference_model = self.model  # first access loads model and processor
        processor = self._processor
        if processor is None:
            raise RuntimeError("processor not initialized after model load")

        default_system = getattr(processor, "_vocence_default_system_prompt", None)
        if default_system is None:
            processor._vocence_default_system_prompt = processor.system_prompt
            default_system = processor.system_prompt

        # Resolve the voice sample before touching shared processor state, so a
        # failure here cannot leave a stale system prompt behind.
        prefill = self._fabric_q._fabric_pick_residual_snapshot(DEFAULT_AUX_SLICE_ID)
        processor.system_prompt = instruction if instruction.strip() else default_system
        try:
            inputs = processor(
                text=[text],
                voice_samples=[[prefill]],
                padding=True,
                return_tensors="pt",
                return_attention_mask=True,
            )
            for k, v in inputs.items():
                if torch.is_tensor(v):
                    inputs[k] = v.to(self._device)
            with torch.inference_mode():
                outputs = inference_model.generate(
                    **inputs,
                    max_new_tokens=None,
                    cfg_scale=self._cfg_scale,
                    tokenizer=processor.tokenizer,
                    generation_config={"do_sample": False},
                    verbose=False,
                    is_prefill=not self._disable_prefill,
                )
            if not outputs.speech_outputs or outputs.speech_outputs[0] is None:
                raise RuntimeError("QWEN3Vox returned no speech output.")
            wav = self._speech_tensor_to_numpy(outputs.speech_outputs[0])
            return (wav, self._sample_rate)
        finally:
            processor.system_prompt = default_system