diff --git "a/miner.py" "b/miner.py" new file mode 100644--- /dev/null +++ "b/miner.py" @@ -0,0 +1,6583 @@ +from __future__ import annotations + + +import io +import json +import os +import re +import sys +import threading +import traceback +from functools import cached_property +from pathlib import Path +from types import SimpleNamespace +from typing import AbstractSet, Any, Dict, List, Optional, Sequence, Tuple, Union +import numpy as np +import torch +from transformers.utils import logging as hf_logging +import math +import random +import warnings +from dataclasses import dataclass + +try: + import librosa +except Exception: + librosa = None +try: + import resampy +except Exception: + resampy = None + + +def _resample_if_needed(wav: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray: + if orig_sr == target_sr: + return wav.astype(np.float32, copy=False) + if resampy is not None: + return resampy.resample(wav.astype(np.float32), orig_sr, target_sr) + if librosa is not None: + return librosa.resample( + y=wav.astype(np.float32), orig_sr=orig_sr, target_sr=target_sr + ) + warnings.warn( + "No resampler available; treating audio as target_sr without resampling. Install resampy or librosa.", + RuntimeWarning, + ) + return wav.astype(np.float32, copy=False) + + +' QWEN3Vox_AcousticTokenizer model configuration' +from typing import Dict, List, Optional, Tuple +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config + +logger = logging.get_logger(__name__) + + +class QWEN3VoxAcousticTokenizerConfig(PretrainedConfig): + model_type = 'vibevoice_acoustic_tokenizer' + + def __init__( + self, + channels: int = 1, + corpus_normalize: float = 0.0, + causal: bool = True, + vae_dim: int = 64, + fix_std: float = 0.5, + std_dist_type: str = "gaussian", + mixer_layer: str = "depthwise_conv", + conv_norm: str = "none", + pad_mode: str = "constant", + disable_last_norm: bool = True, + layernorm: str = "RMSNorm", + layernorm_eps: float = 1e-05, + layernorm_elementwise_affine: bool = True, + conv_bias: bool = True, + layer_scale_init_value: float = 1e-06, + weight_init_value: float = 0.01, + encoder_n_filters: int = 32, + encoder_ratios: Optional[List[int]] = [8, 5, 5, 4, 2, 2], + encoder_depths: str = "3-3-3-3-3-3-8", + decoder_n_filters: int = 32, + decoder_ratios: Optional[List[int]] = None, + decoder_depths: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.channels = channels + self.corpus_normalize = corpus_normalize + self.causal = causal + self.vae_dim = vae_dim + self.fix_std = fix_std + self.std_dist_type = std_dist_type + self.conv_norm = conv_norm + self.pad_mode = pad_mode + self.layernorm_eps = layernorm_eps + self.disable_last_norm = disable_last_norm + self.layernorm = layernorm + self.layernorm_elementwise_affine = layernorm_elementwise_affine + self.conv_bias = conv_bias + self.layer_scale_init_value = layer_scale_init_value + self.weight_init_value = weight_init_value + self.mixer_layer = mixer_layer + self.encoder_n_filters = encoder_n_filters + self.encoder_ratios = encoder_ratios + self.encoder_depths = encoder_depths + self.decoder_ratios = ( + decoder_ratios if decoder_ratios is not None else encoder_ratios + ) + self.decoder_n_filters = decoder_n_filters + self.decoder_depths = decoder_depths + + +class QWEN3VoxSemanticTokenizerConfig(PretrainedConfig): + model_type = 'vibevoice_semantic_tokenizer' + + def __init__( + self, + channels: int = 1, + corpus_normalize: float = 0.0, + causal: bool = True, + vae_dim: int = 64, + fix_std: float = 0, + std_dist_type: str = "none", + mixer_layer: str = "depthwise_conv", + conv_norm: str = "none", + pad_mode: str = "constant", + disable_last_norm: bool = True, + layernorm: str = "RMSNorm", + layernorm_eps: float = 1e-05, + layernorm_elementwise_affine: bool = True, + conv_bias: bool = True, + layer_scale_init_value: float = 1e-06, + weight_init_value: float = 0.01, + encoder_n_filters: int = 32, + encoder_ratios: Optional[List[int]] = [8, 5, 5, 4, 2, 2], + encoder_depths: str = "3-3-3-3-3-3-8", + **kwargs, + ): + super().__init__(**kwargs) + self.channels = channels + self.corpus_normalize = corpus_normalize + self.causal = causal + self.vae_dim = vae_dim + self.fix_std = fix_std + self.std_dist_type = std_dist_type + self.conv_norm = conv_norm + self.pad_mode = pad_mode + self.layernorm_eps = layernorm_eps + self.disable_last_norm = disable_last_norm + self.layernorm = layernorm + self.layernorm_elementwise_affine = layernorm_elementwise_affine + self.conv_bias = conv_bias + self.layer_scale_init_value = layer_scale_init_value + self.weight_init_value = weight_init_value + self.mixer_layer = mixer_layer + self.encoder_n_filters = encoder_n_filters + self.encoder_ratios = encoder_ratios + self.encoder_depths = encoder_depths + + +class QWEN3VoxDiffusionHeadConfig(PretrainedConfig): + model_type = 'vibevoice_diffusion_head' + + def __init__( + self, + hidden_size=768, + head_layers=4, + head_ffn_ratio=3.0, + rms_norm_eps=1e-05, + latent_size=64, + speech_vae_dim=None, + prediction_type="v_prediction", + diffusion_type="ddpm", + ddpm_num_steps=1000, + ddpm_num_inference_steps=30, + ddpm_beta_schedule="cosine", + ddpm_batch_mul=4, + **kwargs, + ): + self.hidden_size = hidden_size + self.head_layers = head_layers + self.head_ffn_ratio = head_ffn_ratio + self.rms_norm_eps = rms_norm_eps + self.latent_size = latent_size + self.speech_vae_dim = speech_vae_dim + self.prediction_type = prediction_type + self.diffusion_type = diffusion_type + self.ddpm_num_steps = ddpm_num_steps + self.ddpm_num_inference_steps = ddpm_num_inference_steps + self.ddpm_beta_schedule = ddpm_beta_schedule + self.ddpm_batch_mul = ddpm_batch_mul + super().__init__(**kwargs) + + +class QWEN3VoxConfig(PretrainedConfig): + model_type = 'vibevoice' + is_composition = True + sub_configs = { + "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig, + "semantic_tokenizer_config": QWEN3VoxSemanticTokenizerConfig, + "decoder_config": Qwen2Config, + "diffusion_head_config": QWEN3VoxDiffusionHeadConfig, + } + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + acoustic_tokenizer_config=None, + semantic_tokenizer_config=None, + decoder_config=None, + diffusion_head_config=None, + **kwargs, + ): + kwargs["_attn_implementation_autoset"] = False + if acoustic_tokenizer_config is None: + self.acoustic_tokenizer_config = self.sub_configs[ + "acoustic_tokenizer_config" + ]() + elif isinstance(acoustic_tokenizer_config, dict): + acoustic_tokenizer_config["model_type"] = 'vibevoice_acoustic_tokenizer' + self.acoustic_tokenizer_config = self.sub_configs[ + "acoustic_tokenizer_config" + ](**acoustic_tokenizer_config) + elif isinstance(acoustic_tokenizer_config, QWEN3VoxAcousticTokenizerConfig): + self.acoustic_tokenizer_config = acoustic_tokenizer_config + if semantic_tokenizer_config is None: + self.semantic_tokenizer_config = self.sub_configs[ + "semantic_tokenizer_config" + ]() + elif isinstance(semantic_tokenizer_config, dict): + semantic_tokenizer_config["model_type"] = 'vibevoice_semantic_tokenizer' + self.semantic_tokenizer_config = self.sub_configs[ + "semantic_tokenizer_config" + ](**semantic_tokenizer_config) + elif isinstance(semantic_tokenizer_config, QWEN3VoxSemanticTokenizerConfig): + self.semantic_tokenizer_config = semantic_tokenizer_config + if decoder_config is None: + self.decoder_config = self.sub_configs["decoder_config"]() + elif isinstance(decoder_config, dict): + if decoder_config.get("model_type", "") == "qwen2": + self.decoder_config = Qwen2Config(**decoder_config) + else: + raise ValueError( + f"Unsupported decoder model type: {decoder_config .get ('model_type','')}" + ) + elif isinstance(decoder_config, (Qwen2Config,)): + self.decoder_config = decoder_config + if diffusion_head_config is None: + self.diffusion_head_config = self.sub_configs["diffusion_head_config"]() + elif isinstance(diffusion_head_config, dict): + diffusion_head_config["model_type"] = 'vibevoice_diffusion_head' + self.diffusion_head_config = self.sub_configs["diffusion_head_config"]( + **diffusion_head_config + ) + elif isinstance(diffusion_head_config, QWEN3VoxDiffusionHeadConfig): + self.diffusion_head_config = diffusion_head_config + self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64) + self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, "vae_dim", 128) + super().__init__(**kwargs) + + +class QWEN3VoxASRConfig(PretrainedConfig): + model_type = 'vibevoice' + is_composition = True + sub_configs = { + "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig, + "semantic_tokenizer_config": QWEN3VoxSemanticTokenizerConfig, + "decoder_config": Qwen2Config, + } + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + acoustic_tokenizer_config=None, + semantic_tokenizer_config=None, + decoder_config=None, + **kwargs, + ): + kwargs["_attn_implementation_autoset"] = False + if acoustic_tokenizer_config is None: + self.acoustic_tokenizer_config = self.sub_configs[ + "acoustic_tokenizer_config" + ]() + elif isinstance(acoustic_tokenizer_config, dict): + acoustic_tokenizer_config["model_type"] = 'vibevoice_acoustic_tokenizer' + self.acoustic_tokenizer_config = self.sub_configs[ + "acoustic_tokenizer_config" + ](**acoustic_tokenizer_config) + elif isinstance(acoustic_tokenizer_config, QWEN3VoxAcousticTokenizerConfig): + self.acoustic_tokenizer_config = acoustic_tokenizer_config + if semantic_tokenizer_config is None: + self.semantic_tokenizer_config = self.sub_configs[ + "semantic_tokenizer_config" + ]() + elif isinstance(semantic_tokenizer_config, dict): + semantic_tokenizer_config["model_type"] = 'vibevoice_semantic_tokenizer' + self.semantic_tokenizer_config = self.sub_configs[ + "semantic_tokenizer_config" + ](**semantic_tokenizer_config) + elif isinstance(semantic_tokenizer_config, QWEN3VoxSemanticTokenizerConfig): + self.semantic_tokenizer_config = semantic_tokenizer_config + if decoder_config is None: + self.decoder_config = self.sub_configs["decoder_config"]() + elif isinstance(decoder_config, dict): + if decoder_config.get("model_type", "") == "qwen2": + self.decoder_config = Qwen2Config(**decoder_config) + else: + raise ValueError( + f"Unsupported decoder model type: {decoder_config .get ('model_type','')}" + ) + elif isinstance(decoder_config, Qwen2Config): + self.decoder_config = decoder_config + self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64) + self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, "vae_dim", 128) + super().__init__(**kwargs) + + def get_text_config(self, decoder: bool = False): + return self.decoder_config + + @property + def vocab_size(self): + return self.decoder_config.vocab_size + + @property + def num_attention_heads(self): + return self.decoder_config.num_attention_heads + + @property + def num_key_value_heads(self): + return self.decoder_config.num_key_value_heads + + @property + def hidden_size(self): + return self.decoder_config.hidden_size + + @property + def num_hidden_layers(self): + return self.decoder_config.num_hidden_layers + + @property + def head_dim(self): + return getattr( + self.decoder_config, + "head_dim", + self.hidden_size // self.num_attention_heads, + ) + + +__all__ = [ + 'QWEN3VoxAcousticTokenizerConfig', + 'QWEN3VoxSemanticTokenizerConfig', + 'QWEN3VoxDiffusionHeadConfig', + 'QWEN3VoxConfig', + 'QWEN3VoxASRConfig', +] +import torch +import asyncio +from queue import Queue +from typing import TYPE_CHECKING, Optional +from transformers.generation import BaseStreamer + + +class AudioStreamer(BaseStreamer): + + def __init__( + self, + batch_size: int, + stop_signal: Optional[any] = None, + timeout: Optional[float] = None, + ): + self.batch_size = batch_size + self.stop_signal = stop_signal + self.timeout = timeout + self.audio_queues = [Queue() for _ in range(batch_size)] + self.finished_flags = [False for _ in range(batch_size)] + self.sample_indices_map = {} + + def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor): + for i, sample_idx in enumerate(sample_indices): + idx = sample_idx.item() + if idx < self.batch_size and (not self.finished_flags[idx]): + audio_chunk = audio_chunks[i].detach().cpu() + self.audio_queues[idx].put(audio_chunk, timeout=self.timeout) + + def end(self, sample_indices: Optional[torch.Tensor] = None): + if sample_indices is None: + for idx in range(self.batch_size): + if not self.finished_flags[idx]: + self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout) + self.finished_flags[idx] = True + else: + for sample_idx in sample_indices: + idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx + if idx < self.batch_size and (not self.finished_flags[idx]): + self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout) + self.finished_flags[idx] = True + + def __iter__(self): + return AudioBatchIterator(self) + + def get_stream(self, sample_idx: int): + if sample_idx >= self.batch_size: + raise ValueError( + f"Sample index {sample_idx } exceeds batch size {self .batch_size }" + ) + return AudioSampleIterator(self, sample_idx) + + +class AudioSampleIterator: + + def __init__(self, streamer: AudioStreamer, sample_idx: int): + self.streamer = streamer + self.sample_idx = sample_idx + + def __iter__(self): + return self + + def __next__(self): + value = self.streamer.audio_queues[self.sample_idx].get( + timeout=self.streamer.timeout + ) + if value == self.streamer.stop_signal: + raise StopIteration() + return value + + +class AudioBatchIterator: + + def __init__(self, streamer: AudioStreamer): + self.streamer = streamer + self.active_samples = set(range(streamer.batch_size)) + + def __iter__(self): + return self + + def __next__(self): + if not self.active_samples: + raise StopIteration() + batch_chunks = {} + samples_to_remove = set() + for idx in self.active_samples: + try: + value = self.streamer.audio_queues[idx].get(block=False) + if value == self.streamer.stop_signal: + samples_to_remove.add(idx) + else: + batch_chunks[idx] = value + except: + pass + self.active_samples -= samples_to_remove + if batch_chunks: + return batch_chunks + elif self.active_samples: + import time + + time.sleep(0.01) + return self.__next__() + else: + raise StopIteration() + + +class AsyncAudioStreamer(AudioStreamer): + + def __init__( + self, + batch_size: int, + stop_signal: Optional[any] = None, + timeout: Optional[float] = None, + ): + super().__init__(batch_size, stop_signal, timeout) + self.audio_queues = [asyncio.Queue() for _ in range(batch_size)] + self.loop = asyncio.get_running_loop() + + def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor): + for i, sample_idx in enumerate(sample_indices): + idx = sample_idx.item() + if idx < self.batch_size and (not self.finished_flags[idx]): + audio_chunk = audio_chunks[i].detach().cpu() + self.loop.call_soon_threadsafe( + self.audio_queues[idx].put_nowait, audio_chunk + ) + + def end(self, sample_indices: Optional[torch.Tensor] = None): + if sample_indices is None: + indices_to_end = range(self.batch_size) + else: + indices_to_end = [ + s.item() if torch.is_tensor(s) else s for s in sample_indices + ] + for idx in indices_to_end: + if idx < self.batch_size and (not self.finished_flags[idx]): + self.loop.call_soon_threadsafe( + self.audio_queues[idx].put_nowait, self.stop_signal + ) + self.finished_flags[idx] = True + + async def get_stream(self, sample_idx: int): + if sample_idx >= self.batch_size: + raise ValueError( + f"Sample index {sample_idx } exceeds batch size {self .batch_size }" + ) + while True: + value = await self.audio_queues[sample_idx].get() + if value == self.stop_signal: + break + yield value + + def __aiter__(self): + return AsyncAudioBatchIterator(self) + + +class AsyncAudioBatchIterator: + + def __init__(self, streamer: AsyncAudioStreamer): + self.streamer = streamer + self.active_samples = set(range(streamer.batch_size)) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self.active_samples: + raise StopAsyncIteration() + batch_chunks = {} + samples_to_remove = set() + tasks = { + idx: asyncio.create_task(self._get_chunk(idx)) + for idx in self.active_samples + } + done, pending = await asyncio.wait( + tasks.values(), + return_when=asyncio.FIRST_COMPLETED, + timeout=self.streamer.timeout, + ) + for task in pending: + task.cancel() + for idx, task in tasks.items(): + if task in done: + try: + value = await task + if value == self.streamer.stop_signal: + samples_to_remove.add(idx) + else: + batch_chunks[idx] = value + except asyncio.CancelledError: + pass + self.active_samples -= samples_to_remove + if batch_chunks: + return batch_chunks + elif self.active_samples: + return await self.__anext__() + else: + raise StopAsyncIteration() + + async def _get_chunk(self, idx): + return await self.streamer.audio_queues[idx].get() + + +'Tokenization classes for QWEN3Vox.' +from typing import List, Optional, Union +from transformers.utils import logging +from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer +from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast + +logger = logging.get_logger(__name__) + + +class QWEN3VoxTextTokenizer(Qwen2Tokenizer): + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + unk_token="<|endoftext|>", + bos_token=None, + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", + add_prefix_space=False, + add_special_tokens=True, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + merges_file=merges_file, + errors=errors, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + add_prefix_space=add_prefix_space, + add_special_tokens=add_special_tokens, + **kwargs, + ) + self._add_q3_sp_tok() + + def _add_q3_sp_tok(self): + special_tokens = { + "additional_special_tokens": [ + "<|vision_start|>", + "<|vision_end|>", + "<|vision_pad|>", + ] + } + num_added = self.add_special_tokens(special_tokens) + self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>") + self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>") + self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>") + self._eos_id = self.convert_tokens_to_ids("<|endoftext|>") + return num_added + + @property + def eos_id(self) -> int: + return self._eos_id + + @property + def speech_start_id(self) -> int: + return self._speech_start_id + + @property + def speech_end_id(self) -> int: + return self._speech_end_id + + @property + def speech_diffusion_id(self) -> int: + return self._speech_diffusion_id + + @property + def pad_id(self) -> int: + return -100 + + +class QWEN3VoxTextTokenizerFast(Qwen2TokenizerFast): + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file=None, + merges_file=None, + tokenizer_file=None, + unk_token="<|endoftext|>", + bos_token=None, + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", + add_prefix_space=False, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + merges_file=merges_file, + tokenizer_file=tokenizer_file, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) + self._add_q3_sp_tok() + + def _add_q3_sp_tok(self): + special_tokens = { + "additional_special_tokens": [ + "<|vision_start|>", + "<|vision_end|>", + "<|vision_pad|>", + ] + } + num_added = self.add_special_tokens(special_tokens) + self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>") + self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>") + self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>") + self._eos_id = self.eos_token_id + self._pad_id = self.convert_tokens_to_ids("<|image_pad|>") + return num_added + + @property + def eos_id(self) -> int: + return self._eos_id + + @property + def speech_start_id(self) -> int: + return self._speech_start_id + + @property + def speech_end_id(self) -> int: + return self._speech_end_id + + @property + def speech_diffusion_id(self) -> int: + return self._speech_diffusion_id + + @property + def pad_id(self) -> int: + return self._pad_id + + +QWEN3VoxASRTextTokenizerFast = QWEN3VoxTextTokenizerFast + +__all__ = [ + 'QWEN3VoxTextTokenizer', + 'QWEN3VoxTextTokenizerFast', +] +"Utilities for loading fine-tuned LoRA adapters and connector weights." +from dataclasses import dataclass +from pathlib import Path +from typing import Optional +import torch +import torch.nn as nn +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +@dataclass +class _LoadReport: + language_model: bool = False + diffusion_head_lora: bool = False + diffusion_head_full: bool = False + acoustic_connector: bool = False + semantic_connector: bool = False + adapter_root: Optional[Path] = None + + +class _DiffusionHeadForwardShim(nn.Module): + + def __init__(self, base: nn.Module): + super().__init__() + self.base = base + + def forward(self, *args, **kwargs): + if len(args) >= 3: + noisy_images, timesteps, condition = args[:3] + else: + noisy_images = kwargs.get("noisy_images") + timesteps = kwargs.get("timesteps") + condition = kwargs.get("condition") + return self.base(noisy_images, timesteps, condition) + + +import math +from typing import List, Optional, Tuple, Union +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import deprecate +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) + + +def betas_for_alpha_bar( + num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine" +): + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + elif alpha_transform_type == "cauchy": + + def alpha_bar_fn(t, gamma=1, mu=3): + snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9) + return 1 - 1 / (math.exp(snr) + 1.1) + + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t, mu=0, b=1): + snr = mu - b * math.copysign(1, 0.5 - t) * math.log( + 1 - 2 * abs(t - 0.5) * 0.98 + ) + return 1 - 1 / (math.exp(snr) + 1.02) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type }") + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(betas): + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + alphas_bar_sqrt -= alphas_bar_sqrt_T + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + alphas_bar = alphas_bar_sqrt**2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + return betas + + +class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + use_karras_sigmas: Optional[bool] = False, + use_lu_lambdas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "zero", + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type } is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate( + "algorithm_types dpmsolver and sde-dpmsolver", + "1.0.0", + deprecation_message, + ) + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace( + beta_start, beta_end, num_train_timesteps, dtype=torch.float32 + ) + elif beta_schedule == "scaled_linear": + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine": + self.betas = betas_for_alpha_bar( + num_train_timesteps, alpha_transform_type="cosine" + ) + elif beta_schedule == "cauchy": + self.betas = betas_for_alpha_bar( + num_train_timesteps, alpha_transform_type="cauchy" + ) + elif beta_schedule == "laplace": + self.betas = betas_for_alpha_bar( + num_train_timesteps, alpha_transform_type="laplace" + ) + else: + raise NotImplementedError( + f"{beta_schedule } is not implemented for {self .__class__ }" + ) + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + if rescale_betas_zero_snr: + self.alphas_cumprod[-1] = 2 ** (-24) + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + self.init_noise_sigma = 1.0 + if algorithm_type not in [ + "dpmsolver", + "dpmsolver++", + "sde-dpmsolver", + "sde-dpmsolver++", + ]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError( + f"{algorithm_type } is not implemented for {self .__class__ }" + ) + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError( + f"{solver_type } is not implemented for {self .__class__ }" + ) + if ( + algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] + and final_sigmas_type == "zero" + ): + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type } is not supported for `algorithm_type` {algorithm_type }. Please choose `sigma_min` instead." + ) + self.num_inference_steps = None + timesteps = np.linspace( + 0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32 + )[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") + + @property + def step_index(self): + return self._step_index + + @property + def begin_index(self): + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + if num_inference_steps is None and timesteps is None: + raise ValueError( + "Must pass exactly one of `num_inference_steps` or `timesteps`." + ) + if num_inference_steps is not None and timesteps is not None: + raise ValueError( + "Can only pass one of `num_inference_steps` or `custom_timesteps`." + ) + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError( + "Cannot use `timesteps` with `config.use_karras_sigmas = True`" + ) + if timesteps is not None and self.config.use_lu_lambdas: + raise ValueError( + "Cannot use `timesteps` with `config.use_lu_lambdas = True`" + ) + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.int64) + else: + clipped_idx = torch.searchsorted( + torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped + ) + last_timestep = ( + (self.config.num_train_timesteps - clipped_idx).numpy().item() + ) + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + timesteps = ( + np.arange(last_timestep, 0, -step_ratio) + .round() + .copy() + .astype(np.int64) + ) + timesteps -= 1 + else: + raise ValueError( + f"{self .config .timestep_spacing } is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + if self.config.use_karras_sigmas: + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + timesteps = np.array( + [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] + ).round() + elif self.config.use_lu_lambdas: + lambdas = np.flip(log_sigmas.copy()) + lambdas = self._convert_to_lu( + in_lambdas=lambdas, num_inference_steps=num_inference_steps + ) + sigmas = np.exp(lambdas) + timesteps = np.array( + [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] + ).round() + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self .config .final_sigmas_type }" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64 + ) + self.num_inference_steps = len(timesteps) + self.model_outputs = [None] * self.config.solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") + + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + if dtype not in (torch.float32, torch.float64): + sample = sample.float() + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + abs_sample = sample.abs() + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp(s, min=1, max=self.config.sample_max_value) + s = s.unsqueeze(1) + sample = torch.clamp(sample, -s, s) / s + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + return sample + + def _sigma_to_t(self, sigma, log_sigmas): + log_sigma = np.log(np.maximum(sigma, 1e-10)) + dists = log_sigma - log_sigmas[:, np.newaxis] + low_idx = ( + np.cumsum(dists >= 0, axis=0) + .argmax(axis=0) + .clip(max=log_sigmas.shape[0] - 2) + ) + high_idx = low_idx + 1 + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / (sigma**2 + 1) ** 0.5 + sigma_t = sigma * alpha_t + return (alpha_t, sigma_t) + + def _convert_to_karras( + self, in_sigmas: torch.Tensor, num_inference_steps + ) -> torch.Tensor: + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + rho = 7.0 + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def _convert_to_lu( + self, in_lambdas: torch.Tensor, num_inference_steps + ) -> torch.Tensor: + lambda_min: float = in_lambdas[-1].item() + lambda_max: float = in_lambdas[0].item() + rho = 1.0 + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = lambda_min ** (1 / rho) + max_inv_rho = lambda_max ** (1 / rho) + lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return lambdas + + def convert_model_output( + self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs + ) -> torch.Tensor: + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self .config .prediction_type } must be one of `epsilon`, `sample`, or `v_prediction` for the DPMSolverMultistepScheduler." + ) + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + return x0_pred + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self .config .prediction_type } must be one of `epsilon`, `sample`, or `v_prediction` for the DPMSolverMultistepScheduler." + ) + if self.config.thresholding: + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + return epsilon + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = ( + sigma_t / sigma_s * sample + - alpha_t * (torch.exp(-h) - 1.0) * model_output + ) + elif self.config.algorithm_type == "dpmsolver": + x_t = ( + alpha_t / alpha_s * sample + - sigma_t * (torch.exp(h) - 1.0) * model_output + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + sigma_t / sigma_s * torch.exp(-h) * sample + + alpha_t * (1 - torch.exp(-2.0 * h)) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + alpha_t / alpha_s * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + m0, m1 = (model_output_list[-1], model_output_list[-2]) + h, h_0 = (lambda_t - lambda_s0, lambda_s0 - lambda_s1) + r0 = h_0 / h + D0, D1 = (m0, 1.0 / r0 * (m0 - m1)) + if self.config.algorithm_type == "dpmsolver++": + if self.config.solver_type == "midpoint": + x_t = ( + sigma_t / sigma_s0 * sample + - alpha_t * (torch.exp(-h) - 1.0) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + sigma_t / sigma_s0 * sample + - alpha_t * (torch.exp(-h) - 1.0) * D0 + + alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + if self.config.solver_type == "midpoint": + x_t = ( + alpha_t / alpha_s0 * sample + - sigma_t * (torch.exp(h) - 1.0) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + alpha_t / alpha_s0 * sample + - sigma_t * (torch.exp(h) - 1.0) * D0 + - sigma_t * ((torch.exp(h) - 1.0) / h - 1.0) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + sigma_t / sigma_s0 * torch.exp(-h) * sample + + alpha_t * (1 - torch.exp(-2.0 * h)) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + sigma_t / sigma_s0 * torch.exp(-h) * sample + + alpha_t * (1 - torch.exp(-2.0 * h)) * D0 + + alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + alpha_t / alpha_s0 * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - sigma_t * (torch.exp(h) - 1.0) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + alpha_t / alpha_s0 * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + m0, m1, m2 = ( + model_output_list[-1], + model_output_list[-2], + model_output_list[-3], + ) + h, h_0, h_1 = ( + lambda_t - lambda_s0, + lambda_s0 - lambda_s1, + lambda_s1 - lambda_s2, + ) + r0, r1 = (h_0 / h, h_1 / h) + D0 = m0 + D1_0, D1_1 = (1.0 / r0 * (m0 - m1), 1.0 / r1 * (m1 - m2)) + D1 = D1_0 + r0 / (r0 + r1) * (D1_0 - D1_1) + D2 = 1.0 / (r0 + r1) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + x_t = ( + sigma_t / sigma_s0 * sample + - alpha_t * (torch.exp(-h) - 1.0) * D0 + + alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0) * D1 + - alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + x_t = ( + alpha_t / alpha_s0 * sample + - sigma_t * (torch.exp(h) - 1.0) * D0 + - sigma_t * ((torch.exp(h) - 1.0) / h - 1.0) * D1 + - sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5) * D2 + ) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + return step_index + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + if self.step_index is None: + self._init_step_index(timestep) + lower_order_final = self.step_index == len(self.timesteps) - 1 and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + self.step_index == len(self.timesteps) - 2 + and self.config.lower_order_final + and (len(self.timesteps) < 15) + ) + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + sample = sample.to(torch.float32) + if ( + self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] + and variance_noise is None + ): + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=torch.float32, + ) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to(device=model_output.device, dtype=torch.float32) + else: + noise = None + if ( + self.config.solver_order == 1 + or self.lower_order_nums < 1 + or lower_order_final + ): + prev_sample = self.dpm_solver_first_order_update( + model_output, sample=sample, noise=noise + ) + elif ( + self.config.solver_order == 2 + or self.lower_order_nums < 2 + or lower_order_second + ): + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, sample=sample, noise=noise + ) + else: + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, sample=sample + ) + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + prev_sample = prev_sample.to(model_output.dtype) + self._step_index += 1 + if not return_dict: + return (prev_sample,) + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype) + sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + alpha_t = alpha_t[timesteps].flatten() + while len(alpha_t.shape) < len(original_samples.shape): + alpha_t = alpha_t.unsqueeze(-1) + sigma_t = sigma_t[timesteps].flatten() + while len(sigma_t.shape) < len(original_samples.shape): + sigma_t = sigma_t.unsqueeze(-1) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def get_velocity( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype) + sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + alpha_t = alpha_t[timesteps].flatten() + while len(alpha_t.shape) < len(original_samples.shape): + alpha_t = alpha_t.unsqueeze(-1) + sigma_t = sigma_t[timesteps].flatten() + while len(sigma_t.shape) < len(original_samples.shape): + sigma_t = sigma_t.unsqueeze(-1) + velocity = alpha_t * noise - sigma_t * original_samples + return velocity + + def __len__(self): + return self.config.num_train_timesteps + + +'\nProcessor class for QWEN3Vox models.\n' +import os +import json +import warnings +from typing import List, Optional, Union, Dict, Any +import numpy as np +import torch +from transformers.feature_extraction_utils import FeatureExtractionMixin +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class AudioNormalizer: + + def __init__(self, target_dB_FS: float = -25, eps: float = 1e-06): + self.target_dB_FS = target_dB_FS + self.eps = eps + + def tailor_dB_FS(self, audio: np.ndarray) -> tuple: + rms = np.sqrt(np.mean(audio**2)) + scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps) + normalized_audio = audio * scalar + return (normalized_audio, rms, scalar) + + def avoid_clipping( + self, audio: np.ndarray, scalar: Optional[float] = None + ) -> tuple: + if scalar is None: + max_val = np.max(np.abs(audio)) + if max_val > 1.0: + scalar = max_val + self.eps + else: + scalar = 1.0 + return (audio / scalar, scalar) + + def __call__(self, audio: np.ndarray) -> np.ndarray: + audio, _, _ = self.tailor_dB_FS(audio) + audio, _ = self.avoid_clipping(audio) + return audio + + +class QWEN3VoxTokenizerProcessor(FeatureExtractionMixin): + model_input_names = ["input_features"] + + def __init__( + self, + sampling_rate: int = 22050, + normalize_audio: bool = True, + target_dB_FS: float = -25, + eps: float = 1e-06, + **kwargs, + ): + super().__init__(**kwargs) + self.sampling_rate = sampling_rate + self.normalize_audio = normalize_audio + if self.normalize_audio: + self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps) + else: + self.normalizer = None + self.feature_extractor_dict = { + "sampling_rate": sampling_rate, + "normalize_audio": normalize_audio, + "target_dB_FS": target_dB_FS, + "eps": eps, + } + + def _ensure_mono(self, audio: np.ndarray) -> np.ndarray: + if len(audio.shape) == 1: + return audio + elif len(audio.shape) == 2: + if audio.shape[0] == 2: + return np.mean(audio, axis=0) + elif audio.shape[1] == 2: + return np.mean(audio, axis=1) + elif audio.shape[0] == 1: + return audio.squeeze(0) + elif audio.shape[1] == 1: + return audio.squeeze(1) + else: + raise ValueError(f"Unexpected audio shape: {audio .shape }") + else: + raise ValueError(f"Audio should be 1D or 2D, got shape: {audio .shape }") + + def _process_single_audio( + self, audio: Union[np.ndarray, List[float]] + ) -> np.ndarray: + if not isinstance(audio, np.ndarray): + audio = np.array(audio, dtype=np.float32) + else: + audio = audio.astype(np.float32) + audio = self._ensure_mono(audio) + if self.normalize_audio and self.normalizer is not None: + audio = self.normalizer(audio) + return audio + + def __call__( + self, + audio: Union[ + str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str] + ] = None, + sampling_rate: Optional[int] = None, + return_tensors: Optional[str] = None, + **kwargs, + ): + if audio is None: + raise ValueError("Audio input is required") + if sampling_rate is not None and sampling_rate != self.sampling_rate: + logger.warning( + f"Input sampling rate ({sampling_rate }) differs from expected sampling rate ({self .sampling_rate }). Please resample your audio." + ) + if isinstance(audio, str): + audio = self._load_audio_from_path(audio) + is_batched = False + elif isinstance(audio, list): + if len(audio) == 0: + raise ValueError("Empty audio list provided") + if all((isinstance(item, str) for item in audio)): + audio = [self._load_audio_from_path(path) for path in audio] + is_batched = True + else: + is_batched = isinstance(audio[0], (np.ndarray, list)) + else: + is_batched = False + if is_batched: + processed_audio = [self._process_single_audio(a) for a in audio] + else: + processed_audio = [self._process_single_audio(audio)] + if return_tensors == "pt": + if len(processed_audio) == 1: + input_features = ( + torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1) + ) + else: + input_features = torch.stack( + [torch.from_numpy(a) for a in processed_audio] + ).unsqueeze(1) + elif return_tensors == "np": + if len(processed_audio) == 1: + input_features = processed_audio[0][np.newaxis, np.newaxis, :] + else: + input_features = np.stack(processed_audio)[:, np.newaxis, :] + else: + input_features = ( + processed_audio[0] if len(processed_audio) == 1 else processed_audio + ) + outputs = {"audio": input_features} + return outputs + + def _load_audio_from_path(self, audio_path: str) -> np.ndarray: + file_ext = os.path.splitext(audio_path)[1].lower() + if file_ext in [".wav", ".mp3", ".flac", ".m4a", ".ogg"]: + import librosa + + audio_array, sr = librosa.load(audio_path, sr=self.sampling_rate, mono=True) + return audio_array + elif file_ext == ".pt": + audio_tensor = torch.load(audio_path, map_location="cpu").squeeze() + if isinstance(audio_tensor, torch.Tensor): + audio_array = audio_tensor.numpy() + else: + audio_array = np.array(audio_tensor) + return audio_array.astype(np.float32) + elif file_ext == ".npy": + audio_array = np.load(audio_path) + return audio_array.astype(np.float32) + else: + raise ValueError( + f"Unsupported file format: {file_ext }. Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy, .npz" + ) + + def preprocess_audio( + self, + audio_path_or_array: Union[str, np.ndarray], + normalize: Optional[bool] = None, + ) -> np.ndarray: + if isinstance(audio_path_or_array, str): + audio_array = self._load_audio_from_path(audio_path_or_array) + else: + audio_array = np.array(audio_path_or_array, dtype=np.float32) + original_normalize = self.normalize_audio + if normalize is not None: + self.normalize_audio = normalize + try: + processed = self._process_single_audio(audio_array) + finally: + self.normalize_audio = original_normalize + return processed + + def to_dict(self) -> Dict[str, Any]: + return self.feature_extractor_dict + + def save_audio( + self, + audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]], + output_path: str = "output.wav", + sampling_rate: Optional[int] = None, + normalize: bool = False, + batch_prefix: str = "audio_", + ): + if sampling_rate is None: + sampling_rate = self.sampling_rate + try: + import soundfile as sf + except ImportError: + raise ImportError( + "soundfile is required to save audio files. Install it with: pip install soundfile" + ) + if isinstance(audio, torch.Tensor): + audio_np = audio.float().detach().cpu().numpy() + elif isinstance(audio, np.ndarray): + audio_np = audio + elif isinstance(audio, list): + if all((isinstance(a, torch.Tensor) for a in audio)): + audio_np = [a.float().detach().cpu().numpy() for a in audio] + else: + audio_np = audio + else: + raise ValueError(f"Unsupported audio type: {type (audio )}") + saved_paths = [] + if isinstance(audio_np, list): + output_dir = output_path + os.makedirs(output_dir, exist_ok=True) + for i, audio_item in enumerate(audio_np): + audio_item = self._prepare_audio_for_save(audio_item, normalize) + file_path = os.path.join(output_dir, f"{batch_prefix }{i }.wav") + sf.write(file_path, audio_item, sampling_rate) + saved_paths.append(file_path) + elif len(audio_np.shape) >= 3: + batch_size = audio_np.shape[0] + if batch_size > 1: + output_dir = output_path + os.makedirs(output_dir, exist_ok=True) + for i in range(batch_size): + single_audio = audio_np[i] + if len(single_audio.shape) > 1: + if single_audio.shape[0] == 1: + single_audio = single_audio.squeeze(0) + single_audio = self._prepare_audio_for_save(single_audio, normalize) + file_path = os.path.join(output_dir, f"{batch_prefix }{i }.wav") + sf.write(file_path, single_audio, sampling_rate) + saved_paths.append(file_path) + else: + audio_item = audio_np.squeeze() + audio_item = self._prepare_audio_for_save(audio_item, normalize) + sf.write(output_path, audio_item, sampling_rate) + saved_paths.append(output_path) + else: + audio_item = self._prepare_audio_for_save(audio_np, normalize) + sf.write(output_path, audio_item, sampling_rate) + saved_paths.append(output_path) + return saved_paths + + def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray: + if len(audio.shape) > 1 and audio.shape[0] == 1: + audio = audio.squeeze(0) + if normalize: + max_val = np.abs(audio).max() + if max_val > 0: + audio = audio / max_val + return audio + + +__all__ = [ + 'QWEN3VoxTokenizerProcessor', + "AudioNormalizer", +] +import math +import torch + + +class UniformSampler: + + def __init__(self, timesteps=1000): + self.timesteps = timesteps + + def sample(self, batch_size, device): + return torch.randint(0, self.timesteps, (batch_size,), device=device) + + +class LogitNormalSampler: + + def __init__(self, timesteps=1000, m=0, s=1): + self.timesteps = timesteps + timesteps = torch.linspace(0, 1, timesteps) + logit = torch.log(timesteps / (1 - timesteps)) + self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s**2) / ( + s * math.sqrt(2 * math.pi) + ) + + def sample(self, batch_size, device): + return torch.multinomial(self.prob, batch_size, replacement=True).to(device) + + +' QWEN3Vox Streaming model configuration' +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config + +logger = logging.get_logger(__name__) + + +class QWEN3VoxStreamingConfig(PretrainedConfig): + model_type = 'vibevoice_streaming' + is_composition = True + sub_configs = { + "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig, + "decoder_config": Qwen2Config, + "diffusion_head_config": QWEN3VoxDiffusionHeadConfig, + } + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + acoustic_tokenizer_config=None, + decoder_config=None, + diffusion_head_config=None, + tts_backbone_num_hidden_layers=20, + **kwargs, + ): + kwargs["_attn_implementation_autoset"] = False + if acoustic_tokenizer_config is None: + self.acoustic_tokenizer_config = self.sub_configs[ + "acoustic_tokenizer_config" + ]() + elif isinstance(acoustic_tokenizer_config, dict): + acoustic_tokenizer_config["model_type"] = 'vibevoice_acoustic_tokenizer' + self.acoustic_tokenizer_config = self.sub_configs[ + "acoustic_tokenizer_config" + ](**acoustic_tokenizer_config) + elif isinstance(acoustic_tokenizer_config, QWEN3VoxAcousticTokenizerConfig): + self.acoustic_tokenizer_config = acoustic_tokenizer_config + if decoder_config is None: + self.decoder_config = self.sub_configs["decoder_config"]() + elif isinstance(decoder_config, dict): + if decoder_config.get("model_type", "") == "qwen2": + self.decoder_config = Qwen2Config(**decoder_config) + else: + raise ValueError( + f"Unsupported decoder model type: {decoder_config .get ('model_type','')}" + ) + elif isinstance(decoder_config, (Qwen2Config,)): + self.decoder_config = decoder_config + if diffusion_head_config is None: + self.diffusion_head_config = self.sub_configs["diffusion_head_config"]() + elif isinstance(diffusion_head_config, dict): + diffusion_head_config["model_type"] = 'vibevoice_diffusion_head' + self.diffusion_head_config = self.sub_configs["diffusion_head_config"]( + **diffusion_head_config + ) + elif isinstance(diffusion_head_config, QWEN3VoxDiffusionHeadConfig): + self.diffusion_head_config = diffusion_head_config + self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64) + self.tts_backbone_num_hidden_layers = tts_backbone_num_hidden_layers + super().__init__(**kwargs) + + +__all__ = [ + 'QWEN3VoxStreamingConfig' +] +import math +from typing import Optional, Tuple, Union +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.auto import AutoModel +from transformers.modeling_utils import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class RMSNorm(nn.Module): + + def __init__( + self, + dim: int, + eps: float = 1e-06, + elementwise_affine=True, + memory_efficient=False, + ): + super().__init__() + self.dim = dim + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_parameter("weight", None) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + if self.weight is not None: + output = output * self.weight + return output + + def extra_repr(self) -> str: + return f"dim={self .dim }, eps={self .eps }, elementwise_affine={self .elementwise_affine }" + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=False), + ACT2FN["silu"], + nn.Linear(hidden_size, hidden_size, bias=False), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding.to(t.dtype) + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class FeedForwardNetwork(nn.Module): + + def __init__(self, embed_dim, ffn_dim): + super().__init__() + self.embed_dim = embed_dim + self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False) + self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False) + self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False) + self.act_fn = ACT2FN["silu"] + + def forward(self, x): + gate = self.gate_proj(x) + up = self.up_proj(x) + gate = self.act_fn(gate) + return self.down_proj(gate * up) + + +class HeadLayer(nn.Module): + + def __init__(self, embed_dim, ffn_dim, cond_dim, norm_eps=1e-05): + super().__init__() + self.embed_dim = embed_dim + self.cond_dim = cond_dim + self.ffn_dim = ffn_dim + self.ffn = FeedForwardNetwork(self.embed_dim, self.ffn_dim) + self.norm = RMSNorm(self.embed_dim, eps=norm_eps) + self.adaLN_modulation = nn.Sequential( + ACT2FN["silu"], nn.Linear(cond_dim, 3 * self.embed_dim, bias=False) + ) + + def forward(self, x, c): + shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1) + x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn)) + return x + + +class FinalLayer(nn.Module): + + def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-05): + super().__init__() + self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False) + self.linear = nn.Linear(hidden_size, output_size, bias=False) + self.adaLN_modulation = nn.Sequential( + ACT2FN["silu"], nn.Linear(cond_size, 2 * hidden_size, bias=False) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class QWEN3VoxDiffusionHead(PreTrainedModel): + config_class = QWEN3VoxDiffusionHeadConfig + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__(self, config): + super().__init__(config) + self.config = config + self.cond_dim = config.hidden_size + latent_size = config.latent_size + self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False) + self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False) + self.t_embedder = TimestepEmbedder(self.cond_dim) + ffn_dim = int(config.hidden_size * config.head_ffn_ratio) + self.layers = nn.ModuleList( + [ + HeadLayer( + embed_dim=config.hidden_size, + ffn_dim=ffn_dim, + cond_dim=self.cond_dim, + norm_eps=config.rms_norm_eps, + ) + for _ in range(config.head_layers) + ] + ) + self.final_layer = FinalLayer( + hidden_size=config.hidden_size, + output_size=latent_size, + cond_size=self.cond_dim, + norm_eps=config.rms_norm_eps, + ) + self.initialize_weights() + + def initialize_weights(self): + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for layer in self.layers: + nn.init.constant_(layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + + def forward(self, noisy_images, timesteps, condition): + x = self.noisy_images_proj(noisy_images) + t = self.t_embedder(timesteps) + condition = self.cond_proj(condition) + c = condition + t + for layer in self.layers: + x = layer(x, c) + x = self.final_layer(x, c) + return x + + +AutoModel.register(QWEN3VoxDiffusionHeadConfig, QWEN3VoxDiffusionHead) +__all__ = [ + 'QWEN3VoxDiffusionHead' +] +import math +import typing as tp +from functools import partial +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Union +import copy +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.auto import AutoModel +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging +from transformers.modeling_utils import PreTrainedModel +from transformers.activations import ACT2FN + +logger = logging.get_logger(__name__) +import os + +try: + from apex.normalization.fused_layer_norm import fused_rms_norm_affine + + APEX_AVAILABLE = True + logger.info("APEX FusedRMSNorm is available and will be used for optimization") + if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0: + APEX_AVAILABLE = False + logger.warning( + "APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0" + ) +except ImportError: + APEX_AVAILABLE = False + logger.warning("APEX FusedRMSNorm not available, using native implementation") + + +class ConvLayerNorm(nn.LayerNorm): + + def __init__( + self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs + ): + super().__init__(normalized_shape, **kwargs) + + def forward(self, x): + x = x.transpose(1, 2) + x = nn.functional.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float(), + self.bias.float(), + self.eps, + ).type_as(x) + x = x.transpose(1, 2) + return x + + +class RMSNorm(nn.Module): + + def __init__( + self, dim: int, eps: float = 1e-05, elementwise_affine=True, weight_shape=None + ): + super().__init__() + self.dim = dim + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + weight_shape = (dim,) if weight_shape is None else weight_shape + self.weight = nn.Parameter(torch.ones(weight_shape)) + else: + self.register_parameter("weight", None) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + if self.weight is not None: + output = output * self.weight + return output + + def extra_repr(self) -> str: + return f"dim={self .dim }, eps={self .eps }, elementwise_affine={self .elementwise_affine }" + + +class ConvRMSNorm(RMSNorm): + + def __init__( + self, dim: int, eps: float = 1e-05, elementwise_affine=True, weight_shape=None + ): + super().__init__(dim, eps, elementwise_affine, weight_shape) + + def forward(self, x): + x = x.transpose(1, 2) + if not APEX_AVAILABLE or not self.elementwise_affine: + output = self._norm(x.float()).type_as(x) + if self.weight is not None: + output = output * self.weight + else: + output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps) + output = output.transpose(1, 2) + return output + + +CONV_NORMALIZATIONS = frozenset( + [ + "none", + "weight_norm", + "spectral_norm", + "time_layer_norm", + "layer_norm", + "time_group_norm", + ] +) + + +def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == "weight_norm": + return nn.utils.weight_norm(module) + elif norm == "spectral_norm": + return nn.utils.spectral_norm(module) + else: + return module + + +def get_norm_module( + module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs +) -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == "layer_norm": + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == "time_group_norm": + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad1d( + x: torch.Tensor, + paddings: tp.Tuple[int, int], + mode: str = "zero", + value: float = 0.0, +): + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert padding_left + padding_right <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left:end] + + +class NormConv1d(nn.Module): + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.convtr = apply_parametrization_norm( + nn.ConvTranspose1d(*args, **kwargs), norm + ) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class QWEN3VoxTokenizerStreamingCache: + + def __init__(self): + self.cache = {} + + def get( + self, layer_id: str, sample_indices: torch.Tensor + ) -> Optional[torch.Tensor]: + states = [] + max_length = 0 + for idx in sample_indices.tolist(): + key = (layer_id, idx) + if key not in self.cache: + return None + state = self.cache[key] + states.append(state) + max_length = max(max_length, state.shape[-1]) + if len(states) > 0 and states[0].dim() >= 2: + padded_states = [] + for state in states: + if state.shape[-1] < max_length: + pad_size = max_length - state.shape[-1] + padded_state = F.pad(state, (pad_size, 0), mode="constant", value=0) + padded_states.append(padded_state) + else: + padded_states.append(state) + return torch.stack(padded_states, dim=0) + else: + return torch.stack(states, dim=0) + + def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor): + for i, idx in enumerate(sample_indices.tolist()): + key = (layer_id, idx) + self.cache[key] = states[i].detach() + + def set_to_zero(self, sample_indices: torch.Tensor): + for key in list(self.cache.keys()): + layer_id, sample_idx = key + if sample_idx in sample_indices.tolist(): + cached_tensor = self.cache[key] + self.cache[key] = torch.zeros_like(cached_tensor) + + def clear( + self, + layer_id: Optional[str] = None, + sample_indices: Optional[torch.Tensor] = None, + ): + if layer_id is None and sample_indices is None: + self.cache.clear() + elif layer_id is not None and sample_indices is None: + keys_to_remove = [k for k in self.cache.keys() if k[0] == layer_id] + for k in keys_to_remove: + del self.cache[k] + elif layer_id is not None and sample_indices is not None: + for idx in sample_indices.tolist(): + key = (layer_id, idx) + self.cache.pop(key, None) + + +class SConv1d(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = "reflect", + ): + super().__init__() + self.conv = NormConv1d( + in_channels, + out_channels, + kernel_size, + stride, + dilation=dilation, + groups=groups, + bias=bias, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.pad_mode = pad_mode + self.kernel_size = kernel_size + self.dilation = dilation + self.stride = stride + self.in_channels = in_channels + self.out_channels = out_channels + self.context_size = (kernel_size - 1) * dilation - (stride - 1) + self.padding_total = (kernel_size - 1) * dilation - (stride - 1) + self._layer_id = None + + @property + def layer_id(self): + if self._layer_id is None: + self._layer_id = f"sconv1d_{id (self )}" + return self._layer_id + + def forward( + self, + x: torch.Tensor, + cache: Optional[QWEN3VoxTokenizerStreamingCache] = None, + sample_indices: Optional[torch.Tensor] = None, + use_cache: bool = False, + debug: bool = False, + ) -> torch.Tensor: + B, C, T = x.shape + if not use_cache or cache is None: + return self._forward_non_streaming(x, debug=debug) + assert self.causal, "Streaming mode is only supported for causal convolutions" + assert ( + sample_indices is not None + ), "sample_indices must be provided for streaming mode" + assert len(sample_indices) == B, "sample_indices must match batch size" + return self._forward_streaming(x, cache, sample_indices, debug) + + def _forward_streaming( + self, + x: torch.Tensor, + cache: QWEN3VoxTokenizerStreamingCache, + sample_indices: torch.Tensor, + debug: bool = False, + ) -> torch.Tensor: + B, C, T = x.shape + cached_states = cache.get(self.layer_id, sample_indices) + if cached_states is None: + if self.context_size > 0: + cached_states = torch.zeros( + B, C, self.context_size, device=x.device, dtype=x.dtype + ) + if debug: + print( + f"[DEBUG] Initialized cache with shape: {cached_states .shape }, context_size={self .context_size }" + ) + else: + cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype) + if debug: + print(f"[DEBUG] No context needed (kernel_size=stride)") + if cached_states.shape[2] > 0: + input_with_context = torch.cat([cached_states, x], dim=2) + else: + input_with_context = x + if debug: + print( + f"[DEBUG] Input shape: {x .shape }, Cache shape: {cached_states .shape }, Combined: {input_with_context .shape }" + ) + output = self.conv(input_with_context) + if debug: + print(f"[DEBUG] Output shape: {output .shape }") + if self.context_size > 0: + total_input_length = input_with_context.shape[2] + if total_input_length >= self.context_size: + new_cache_start = total_input_length - self.context_size + new_cache = input_with_context[:, :, new_cache_start:] + else: + new_cache = input_with_context + if debug: + print(f"[DEBUG] New cache shape: {new_cache .shape }") + cache.set(self.layer_id, sample_indices, new_cache) + return output + + def _forward_non_streaming( + self, x: torch.Tensor, debug: bool = False + ) -> torch.Tensor: + B, C, T = x.shape + kernel_size = self.kernel_size + stride = self.stride + dilation = self.dilation + padding_total = self.padding_total + extra_padding = get_extra_padding_for_conv1d( + x, kernel_size, stride, padding_total + ) + if debug: + print( + f"[DEBUG NON-STREAMING] Input shape: {x .shape }, padding_total={padding_total }, extra_padding={extra_padding }" + ) + if self.causal: + if self.pad_mode == "constant": + x = pad1d( + x, (padding_total, extra_padding), mode=self.pad_mode, value=0 + ) + else: + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d( + x, (padding_left, padding_right + extra_padding), mode=self.pad_mode + ) + if debug: + print(f"[DEBUG NON-STREAMING] After padding: {x .shape }") + output = self.conv(x) + if debug: + print(f"[DEBUG NON-STREAMING] Output shape: {output .shape }") + return output + + +class SConvTranspose1d(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + causal: bool = False, + norm: str = "none", + trim_right_ratio: float = 1.0, + norm_kwargs: tp.Dict[str, tp.Any] = {}, + bias: bool = True, + ): + super().__init__() + self.convtr = NormConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + bias=bias, + ) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert ( + self.causal or self.trim_right_ratio == 1.0 + ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 + self.kernel_size = kernel_size + self.stride = stride + self.in_channels = in_channels + self.out_channels = out_channels + self.padding_total = kernel_size - stride + self.context_size = kernel_size - 1 + self._layer_id = None + + @property + def layer_id(self): + if self._layer_id is None: + self._layer_id = f"sconvtr1d_{id (self )}" + return self._layer_id + + def forward( + self, + x: torch.Tensor, + cache: Optional[QWEN3VoxTokenizerStreamingCache] = None, + sample_indices: Optional[torch.Tensor] = None, + use_cache: bool = False, + debug: bool = False, + ) -> torch.Tensor: + B, C, T = x.shape + if not use_cache or cache is None: + return self._forward_non_streaming(x, debug=debug) + assert ( + sample_indices is not None + ), "sample_indices must be provided for streaming mode" + assert len(sample_indices) == B, "sample_indices must match batch size" + return self._forward_streaming(x, cache, sample_indices, debug) + + def _forward_streaming( + self, + x: torch.Tensor, + cache: QWEN3VoxTokenizerStreamingCache, + sample_indices: torch.Tensor, + debug: bool = False, + ) -> torch.Tensor: + B, C, T = x.shape + cached_input = cache.get(self.layer_id, sample_indices) + if cached_input is None: + cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype) + if debug: + print(f"[DEBUG] Initialized empty cache for transposed conv") + full_input = torch.cat([cached_input, x], dim=2) + if debug: + print( + f"[DEBUG] Input shape: {x .shape }, Cache shape: {cached_input .shape }, Combined: {full_input .shape }" + ) + full_output = self.convtr(full_input) + if debug: + print(f"[DEBUG] Full transposed conv output shape: {full_output .shape }") + if self.causal: + padding_right = math.ceil(self.padding_total * self.trim_right_ratio) + padding_left = self.padding_total - padding_right + else: + padding_right = self.padding_total // 2 + padding_left = self.padding_total - padding_right + if padding_left + padding_right > 0: + full_output = unpad1d(full_output, (padding_left, padding_right)) + if debug: + print(f"[DEBUG] After unpadding: {full_output .shape }") + if cached_input.shape[2] == 0: + output = full_output + else: + expected_new_output = T * self.stride + if full_output.shape[2] >= expected_new_output: + output = full_output[:, :, -expected_new_output:] + else: + output = full_output + if debug: + print(f"[DEBUG] Final streaming output shape: {output .shape }") + if full_input.shape[2] > self.context_size: + new_cache = full_input[:, :, -self.context_size :] + else: + new_cache = full_input + if debug: + print(f"[DEBUG] New cache shape: {new_cache .shape }") + cache.set(self.layer_id, sample_indices, new_cache) + return output + + def _forward_non_streaming( + self, x: torch.Tensor, debug: bool = False + ) -> torch.Tensor: + if debug: + print(f"[DEBUG NON-STREAMING] Input shape: {x .shape }") + y = self.convtr(x) + if debug: + print(f"[DEBUG NON-STREAMING] After transposed conv: {y .shape }") + if self.causal: + padding_right = math.ceil(self.padding_total * self.trim_right_ratio) + padding_left = self.padding_total - padding_right + else: + padding_right = self.padding_total // 2 + padding_left = self.padding_total - padding_right + if padding_left + padding_right > 0: + y = unpad1d(y, (padding_left, padding_right)) + if debug: + print(f"[DEBUG NON-STREAMING] Final output shape: {y .shape }") + return y + + +class FFN(nn.Module): + + def __init__(self, embed_dim, ffn_dim, bias=False): + super().__init__() + self.embed_dim = embed_dim + self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias) + self.gelu = ACT2FN["gelu"] + self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias) + + def forward(self, x): + x = self.linear1(x) + x = self.gelu(x) + x = self.linear2(x) + return x + + +class Convlayer(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True, + pad_mode="zeros", + norm="weight_norm", + causal=True, + ): + super().__init__() + self.conv = SConv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + pad_mode=pad_mode, + norm=norm, + causal=causal, + ) + + def forward(self, x): + return self.conv(x) + + +class Block1D(nn.Module): + + def __init__( + self, + dim, + kernel_size=7, + drop_path=0.0, + mixer_layer="conv", + layer_scale_init_value=1e-06, + **kwargs, + ): + super().__init__() + if kwargs.get("layernorm", "LN") == "LN": + self.norm = ConvLayerNorm(dim, eps=kwargs.get("eps", 1e-06)) + self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get("eps", 1e-06)) + elif kwargs.get("layernorm", "RMSNorm") == "RMSNorm": + self.norm = ConvRMSNorm(dim, eps=kwargs.get("eps", 1e-06)) + self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get("eps", 1e-06)) + if mixer_layer == "conv": + self.mixer = Convlayer( + dim, + dim, + groups=kwargs.get("groups", 1), + kernel_size=kernel_size, + pad_mode=kwargs.get("pad_mode", "reflect"), + norm=kwargs.get("norm", "none"), + causal=kwargs.get("causal", True), + bias=kwargs.get("bias", True), + ) + elif mixer_layer == "depthwise_conv": + self.mixer = Convlayer( + dim, + dim, + groups=dim, + kernel_size=kernel_size, + pad_mode=kwargs.get("pad_mode", "reflect"), + norm=kwargs.get("norm", "none"), + causal=kwargs.get("causal", True), + bias=kwargs.get("bias", True), + ) + else: + raise ValueError(f"Unsupported mixer layer: {mixer_layer }") + self.ffn = FFN( + dim, kwargs.get("ffn_expansion", 4) * dim, bias=kwargs.get("bias", False) + ) + self.drop_path = ( + nn.Identity() if drop_path <= 0.0 else nn.modules.DropPath(drop_path) + ) + if layer_scale_init_value > 0: + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones(dim), requires_grad=True + ) + self.ffn_gamma = nn.Parameter( + layer_scale_init_value * torch.ones(dim), requires_grad=True + ) + else: + self.gamma = None + self.ffn_gamma = None + + def forward(self, x): + residual = x + x = self.norm(x) + x = self.mixer(x) + if self.gamma is not None: + x = x * self.gamma.unsqueeze(-1) + x = residual + self.drop_path(x) + residual = x + x = self.ffn_norm(x) + x = x.permute(0, 2, 1) + x = self.ffn(x) + x = x.permute(0, 2, 1) + if self.ffn_gamma is not None: + x = x * self.ffn_gamma.unsqueeze(-1) + x = residual + self.drop_path(x) + return x + + +class TokenizerEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.channels = config.channels + self.dimension = config.dimension + self.n_filters = config.n_filters + self.ratios = list(reversed(config.ratios)) + self.depths = config.depths + self.n_residual_layers = getattr(config, "n_residual_layers", 1) + self.hop_length = np.prod(self.ratios) + self.causal = config.causal + kernel_size = getattr(config, "kernel_size", 7) + last_kernel_size = getattr(config, "last_kernel_size", 7) + norm = getattr(config, "norm", "none") + norm_params = getattr(config, "norm_params", {}) + pad_mode = getattr(config, "pad_mode", "reflect") + bias = getattr(config, "bias", True) + layernorm = getattr(config, "layernorm", "LN") + layernorm_eps = getattr(config, "layernorm_eps", 1e-06) + layernorm_elementwise_affine = getattr( + config, "layernorm_elementwise_affine", True + ) + drop_path_rate = getattr(config, "drop_path_rate", 0.0) + mixer_layer = getattr(config, "mixer_layer", "conv") + layer_scale_init_value = getattr(config, "layer_scale_init_value", 0) + disable_last_norm = getattr(config, "disable_last_norm", False) + if layernorm == "LN": + norm_type = ConvLayerNorm + elif layernorm == "RMSNorm": + norm_type = partial( + ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine + ) + else: + raise ValueError(f"Unsupported norm type: {layernorm }") + stem = nn.Sequential( + SConv1d( + self.channels, + self.n_filters, + kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=self.causal, + pad_mode=pad_mode, + bias=bias, + ) + ) + self.downsample_layers = nn.ModuleList() + self.downsample_layers.append(stem) + for i in range(len(self.ratios)): + in_ch = self.n_filters * 2**i + out_ch = self.n_filters * 2 ** (i + 1) + downsample_layer = nn.Sequential( + SConv1d( + in_ch, + out_ch, + kernel_size=self.ratios[i] * 2, + stride=self.ratios[i], + causal=self.causal, + pad_mode=pad_mode, + norm=norm, + bias=bias, + ) + ) + self.downsample_layers.append(downsample_layer) + layer_type = partial( + Block1D, + mixer_layer=mixer_layer, + layernorm=layernorm, + eps=layernorm_eps, + causal=self.causal, + pad_mode=pad_mode, + norm=norm, + bias=bias, + layer_scale_init_value=layer_scale_init_value, + ) + self.stages = nn.ModuleList() + dp_rates = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + cur = 0 + for i in range(len(self.depths)): + in_ch = self.n_filters * 2**i + stage = nn.Sequential( + *[ + layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) + for j in range(self.depths[i]) + ] + ) + self.stages.append(stage) + cur += self.depths[i] + if not disable_last_norm: + self.norm = norm_type(in_ch, eps=layernorm_eps) + else: + self.norm = nn.Identity() + self.head = SConv1d( + in_ch, + self.dimension, + kernel_size=last_kernel_size, + causal=self.causal, + pad_mode=pad_mode, + norm=norm, + bias=bias, + ) + + def forward_features( + self, x, cache=None, sample_indices=None, use_cache=False, debug=False + ): + for i in range(len(self.depths)): + for layer in self.downsample_layers[i]: + if isinstance(layer, SConv1d): + x = layer( + x, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + else: + x = layer(x) + for block in self.stages[i]: + if ( + hasattr(block, "mixer") + and hasattr(block.mixer, "conv") + and isinstance(block.mixer.conv, SConv1d) + ): + residual = x + x = block.norm(x) + x = block.mixer.conv( + x, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + if block.gamma is not None: + x = x * block.gamma.unsqueeze(-1) + x = residual + x + residual = x + x = block.ffn_norm(x) + x = x.permute(0, 2, 1) + x = block.ffn(x) + x = x.permute(0, 2, 1) + if block.ffn_gamma is not None: + x = x * block.ffn_gamma.unsqueeze(-1) + x = residual + x + else: + x = block(x) + return self.norm(x) + + def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False): + x = self.forward_features( + x, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + x = self.head( + x, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + return x + + +class TokenizerDecoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.dimension = config.dimension + self.channels = config.channels + self.n_filters = config.n_filters + self.ratios = config.ratios + self.depths = config.depths + self.n_residual_layers = getattr(config, "n_residual_layers", 1) + self.hop_length = np.prod(self.ratios) + self.causal = config.causal + kernel_size = getattr(config, "kernel_size", 7) + last_kernel_size = getattr(config, "last_kernel_size", 7) + norm = getattr(config, "norm", "none") + norm_params = getattr(config, "norm_params", {}) + pad_mode = getattr(config, "pad_mode", "reflect") + bias = getattr(config, "bias", True) + layernorm = getattr(config, "layernorm", "LN") + layernorm_eps = getattr(config, "layernorm_eps", 1e-06) + trim_right_ratio = getattr(config, "trim_right_ratio", 1.0) + layernorm_elementwise_affine = getattr( + config, "layernorm_elementwise_affine", True + ) + drop_path_rate = getattr(config, "drop_path_rate", 0.0) + mixer_layer = getattr(config, "mixer_layer", "conv") + layer_scale_init_value = getattr(config, "layer_scale_init_value", 0) + disable_last_norm = getattr(config, "disable_last_norm", False) + if layernorm == "LN": + norm_type = ConvLayerNorm + elif layernorm == "RMSNorm": + norm_type = partial( + ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine + ) + else: + raise ValueError(f"Unsupported norm type: {layernorm }") + stem = nn.Sequential( + SConv1d( + self.dimension, + self.n_filters * 2 ** (len(self.depths) - 1), + kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=self.causal, + pad_mode=pad_mode, + bias=bias, + ) + ) + self.upsample_layers = nn.ModuleList() + self.upsample_layers.append(stem) + for i in range(len(self.ratios)): + in_ch = self.n_filters * 2 ** (len(self.depths) - 1 - i) + out_ch = self.n_filters * 2 ** (len(self.depths) - 1 - i - 1) + upsample_layer = nn.Sequential( + SConvTranspose1d( + in_ch, + out_ch, + kernel_size=self.ratios[i] * 2, + stride=self.ratios[i], + norm=norm, + norm_kwargs=norm_params, + bias=bias, + causal=self.causal, + trim_right_ratio=trim_right_ratio, + ) + ) + self.upsample_layers.append(upsample_layer) + layer_type = partial( + Block1D, + mixer_layer=mixer_layer, + layernorm=layernorm, + eps=layernorm_eps, + causal=self.causal, + pad_mode=pad_mode, + norm=norm, + bias=bias, + layer_scale_init_value=layer_scale_init_value, + ) + self.stages = nn.ModuleList() + dp_rates = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + cur = 0 + for i in range(len(self.depths)): + in_ch = self.n_filters * 2 ** (len(self.depths) - 1 - i) + stage = nn.Sequential( + *[ + layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) + for j in range(self.depths[i]) + ] + ) + self.stages.append(stage) + cur += self.depths[i] + if not disable_last_norm: + self.norm = norm_type(in_ch, eps=layernorm_eps) + else: + self.norm = nn.Identity() + self.head = SConv1d( + in_ch, + self.channels, + kernel_size=last_kernel_size, + causal=self.causal, + pad_mode=pad_mode, + norm=norm, + bias=bias, + ) + + def forward_features( + self, x, cache=None, sample_indices=None, use_cache=False, debug=False + ): + for i in range(len(self.depths)): + for layer in self.upsample_layers[i]: + if isinstance(layer, (SConv1d, SConvTranspose1d)): + x = layer( + x, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + else: + x = layer(x) + for block in self.stages[i]: + if ( + hasattr(block, "mixer") + and hasattr(block.mixer, "conv") + and isinstance(block.mixer.conv, SConv1d) + ): + residual = x + x = block.norm(x) + x = block.mixer.conv( + x, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + if block.gamma is not None: + x = x * block.gamma.unsqueeze(-1) + x = residual + x + residual = x + x = block.ffn_norm(x) + x = x.permute(0, 2, 1) + x = block.ffn(x) + x = x.permute(0, 2, 1) + if block.ffn_gamma is not None: + x = x * block.ffn_gamma.unsqueeze(-1) + x = residual + x + else: + x = block(x) + return self.norm(x) + + def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False): + x = self.forward_features( + x, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + x = self.head( + x, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + return x + + +@dataclass +class QWEN3VoxTokenizerEncoderOutput: + mean: torch.Tensor + std: Optional[Union[float, torch.Tensor]] = None + + def sample(self, dist_type="fix"): + if dist_type == "fix": + x = self.mean + self.std * torch.randn_like(self.mean) + return (x, self.std) + elif dist_type == "gaussian": + batch_size = self.mean.size(0) + value = self.std / 0.8 + std = ( + torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) + * value + ) + while std.dim() < self.mean.dim(): + std = std.unsqueeze(-1) + x = self.mean + std * torch.randn_like(self.mean) + return (x, std) + else: + return (self.mean, self.std) + + def kl(self): + target = torch.zeros_like(self.mean) + return F.mse_loss(self.mean, target, reduction="none") + + def mode(self): + return self.mean + + +class QWEN3VoxAcousticTokenizerModel(PreTrainedModel): + config_class = QWEN3VoxAcousticTokenizerConfig + base_model_prefix = 'vibevoice_acoustic_tokenizer' + _supports_flash_attn_2 = True + _supports_sdpa = True + _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"] + + def __init__(self, config): + super().__init__(config) + self.register_buffer("fix_std", torch.tensor(config.fix_std), persistent=False) + self.std_dist_type = getattr(config, "std_dist_type", "fix") + if isinstance(config.encoder_depths, str): + encoder_depths = [int(d) for d in config.encoder_depths.split("-")] + else: + encoder_depths = config.encoder_depths + if config.decoder_depths is not None and isinstance(config.decoder_depths, str): + decoder_depths = [int(d) for d in config.decoder_depths.split("-")] + else: + decoder_depths = list(reversed(encoder_depths)) + encoder_config = copy.deepcopy(config) + encoder_config.dimension = config.vae_dim + encoder_config.n_filters = config.encoder_n_filters + encoder_config.ratios = config.encoder_ratios + encoder_config.depths = encoder_depths + encoder_config.norm = config.conv_norm + encoder_config.pad_mode = config.pad_mode + encoder_config.bias = config.conv_bias + encoder_config.layernorm_eps = config.layernorm_eps + encoder_config.layernorm_elementwise_affine = ( + config.layernorm_elementwise_affine + ) + encoder_config.mixer_layer = config.mixer_layer + encoder_config.layer_scale_init_value = config.layer_scale_init_value + encoder_config.disable_last_norm = config.disable_last_norm + decoder_config = copy.deepcopy(config) + decoder_config.dimension = config.vae_dim + decoder_config.n_filters = config.decoder_n_filters + decoder_config.ratios = config.decoder_ratios + decoder_config.depths = decoder_depths + decoder_config.norm = config.conv_norm + decoder_config.pad_mode = config.pad_mode + decoder_config.bias = config.conv_bias + decoder_config.layernorm_eps = config.layernorm_eps + decoder_config.layernorm_elementwise_affine = ( + config.layernorm_elementwise_affine + ) + decoder_config.mixer_layer = config.mixer_layer + decoder_config.layer_scale_init_value = config.layer_scale_init_value + decoder_config.disable_last_norm = config.disable_last_norm + self.encoder = TokenizerEncoder(encoder_config) + self.decoder = TokenizerDecoder(decoder_config) + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=self.config.weight_init_value) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv1d): + nn.init.normal_(module.weight, std=self.config.weight_init_value) + if module.bias is not None: + nn.init.zeros_(module.bias) + + @torch.no_grad() + def encode( + self, audio, cache=None, sample_indices=None, use_cache=False, debug=False + ): + latents = self.encoder( + audio, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + return QWEN3VoxTokenizerEncoderOutput( + mean=latents.permute(0, 2, 1), std=self.fix_std + ) + + @torch.no_grad() + def sampling(self, encoder_output, dist_type=None): + dist_type = dist_type or self.std_dist_type + if dist_type == "fix": + return encoder_output.sample(dist_type="fix") + elif dist_type == "gaussian": + return encoder_output.sample(dist_type="gaussian") + else: + raise ValueError( + f"Unsupported dist_type: {dist_type }, expected 'fix' or 'gaussian'" + ) + + @torch.no_grad() + def decode( + self, latents, cache=None, sample_indices=None, use_cache=False, debug=False + ): + if latents.shape[1] == self.config.vae_dim: + pass + else: + latents = latents.permute(0, 2, 1) + audio = self.decoder( + latents, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + return audio + + def forward( + self, audio, cache=None, sample_indices=None, use_cache=False, debug=False + ): + encoder_output = self.encode( + audio, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + sampled_latents, _ = self.sampling(encoder_output) + reconstructed = self.decode( + sampled_latents, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + return (reconstructed, sampled_latents) + + +class QWEN3VoxSemanticTokenizerModel(PreTrainedModel): + config_class = QWEN3VoxSemanticTokenizerConfig + base_model_prefix = 'vibevoice_semantic_tokenizer' + _supports_flash_attn_2 = True + _supports_sdpa = True + _no_split_modules = ["TokenizerEncoder"] + + def __init__(self, config): + super().__init__(config) + if isinstance(config.encoder_depths, str): + encoder_depths = [int(d) for d in config.encoder_depths.split("-")] + else: + encoder_depths = config.encoder_depths + encoder_config = copy.deepcopy(config) + encoder_config.dimension = config.vae_dim + encoder_config.n_filters = config.encoder_n_filters + encoder_config.ratios = config.encoder_ratios + encoder_config.depths = encoder_depths + encoder_config.norm = config.conv_norm + encoder_config.pad_mode = config.pad_mode + encoder_config.bias = config.conv_bias + encoder_config.layernorm_eps = config.layernorm_eps + encoder_config.layernorm_elementwise_affine = ( + config.layernorm_elementwise_affine + ) + encoder_config.mixer_layer = config.mixer_layer + encoder_config.layer_scale_init_value = config.layer_scale_init_value + encoder_config.disable_last_norm = config.disable_last_norm + self.encoder = TokenizerEncoder(encoder_config) + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=self.config.weight_init_value) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv1d): + nn.init.normal_(module.weight, std=self.config.weight_init_value) + if module.bias is not None: + nn.init.zeros_(module.bias) + + @torch.no_grad() + def encode( + self, audio, cache=None, sample_indices=None, use_cache=False, debug=False + ): + latents = self.encoder( + audio, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + return QWEN3VoxTokenizerEncoderOutput(mean=latents.permute(0, 2, 1)) + + @torch.no_grad() + def sampling(self, encoder_output, dist_type=None): + return encoder_output.sample(dist_type="none") + + def forward( + self, audio, cache=None, sample_indices=None, use_cache=False, debug=False + ): + encoder_output = self.encode( + audio, + cache=cache, + sample_indices=sample_indices, + use_cache=use_cache, + debug=debug, + ) + sampled_latents, _ = self.sampling(encoder_output, dist_type="none") + return (None, sampled_latents) + + +AutoModel.register(QWEN3VoxAcousticTokenizerConfig, QWEN3VoxAcousticTokenizerModel) +AutoModel.register(QWEN3VoxSemanticTokenizerConfig, QWEN3VoxSemanticTokenizerModel) +__all__ = [ + 'QWEN3VoxTokenizerStreamingCache', + 'QWEN3VoxAcousticTokenizerModel', + 'QWEN3VoxSemanticTokenizerModel', +] +'\nProcessor class for QWEN3Vox ASR models.\n' +import os +import json +import math +import warnings +from typing import List, Optional, Union, Dict, Any, Tuple +import numpy as np +import torch +from transformers.tokenization_utils_base import BatchEncoding +from transformers.utils import TensorType, logging + +logger = logging.get_logger(__name__) +SYSTEM_PROMPT = "You are a helpful assistant that transcribes audio input into text output in JSON format." + + +class QWEN3VoxASRProcessor: + + def __init__( + self, + tokenizer=None, + audio_processor=None, + speech_tok_compress_ratio=320, + target_sample_rate=22050, + normalize_audio=True, + **kwargs, + ): + self.tokenizer = tokenizer + self.audio_processor = audio_processor or QWEN3VoxTokenizerProcessor( + sampling_rate=target_sample_rate, normalize_audio=normalize_audio + ) + self.speech_tok_compress_ratio = speech_tok_compress_ratio + self.target_sample_rate = target_sample_rate + self.normalize_audio = normalize_audio + if normalize_audio: + self.audio_normalizer = AudioNormalizer() + else: + self.audio_normalizer = None + self._cache_special_tokens() + + def _cache_special_tokens(self): + if hasattr(self.tokenizer, "speech_start_id"): + self.speech_start_id = self.tokenizer.speech_start_id + else: + self.speech_start_id = self.tokenizer.convert_tokens_to_ids( + "<|speech_start|>" + ) + if hasattr(self.tokenizer, "speech_end_id"): + self.speech_end_id = self.tokenizer.speech_end_id + else: + self.speech_end_id = self.tokenizer.convert_tokens_to_ids("<|speech_end|>") + if hasattr(self.tokenizer, "speech_pad_id"): + self.speech_pad_id = self.tokenizer.speech_pad_id + else: + self.speech_pad_id = self.tokenizer.convert_tokens_to_ids("<|speech_pad|>") + if hasattr(self.tokenizer, "pad_id"): + self.pad_id = self.tokenizer.pad_id + elif hasattr(self.tokenizer, "pad_token_id"): + self.pad_id = self.tokenizer.pad_token_id + else: + self.pad_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + import json + from transformers.utils import cached_file + + config_path = os.path.join( + pretrained_model_name_or_path, "preprocessor_config.json" + ) + config = {} + if os.path.exists(config_path): + with open(config_path, "r") as f: + config = json.load(f) + else: + try: + config_file = cached_file( + pretrained_model_name_or_path, "preprocessor_config.json", **kwargs + ) + with open(config_file, "r") as f: + config = json.load(f) + except Exception as e: + logger.warning(f"Could not load preprocessor_config.json: {e }") + logger.warning("Using default configuration") + speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200) + target_sample_rate = config.get("target_sample_rate", 22050) + normalize_audio = config.get("normalize_audio", True) + model_name = str(pretrained_model_name_or_path) + logger.info(f"Loading tokenizer from {model_name}") + if "qwen" in model_name.lower(): + tokenizer = QWEN3VoxASRTextTokenizerFast.from_pretrained(model_name, **kwargs) + else: + raise ValueError( + f"Unsupported tokenizer type for {language_model_pretrained_name }" + ) + audio_processor = QWEN3VoxTokenizerProcessor( + sampling_rate=target_sample_rate, + normalize_audio=normalize_audio, + target_dB_FS=config.get("target_dB_FS", -25), + eps=config.get("eps", 1e-06), + ) + return cls( + tokenizer=tokenizer, + audio_processor=audio_processor, + speech_tok_compress_ratio=speech_tok_compress_ratio, + target_sample_rate=target_sample_rate, + normalize_audio=normalize_audio, + ) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): + import json + + os.makedirs(save_directory, exist_ok=True) + processor_config = { + "processor_class": "QWEN3VoxASRProcessor", + "speech_tok_compress_ratio": self.speech_tok_compress_ratio, + "target_sample_rate": self.target_sample_rate, + "normalize_audio": self.normalize_audio, + "target_dB_FS": -25, + "eps": 1e-06, + } + config_path = os.path.join(save_directory, "preprocessor_config.json") + with open(config_path, "w") as f: + json.dump(processor_config, f, indent=2) + logger.info(f"Processor configuration saved in {config_path }") + + def __call__( + self, + audio: Optional[ + Union[ + str, + np.ndarray, + torch.Tensor, + List[Union[str, np.ndarray, torch.Tensor]], + ] + ] = None, + sampling_rate: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + padding: bool = True, + max_length: Optional[int] = None, + truncation: bool = False, + add_generation_prompt: bool = True, + use_streaming: bool = True, + context_info: Optional[str] = None, + **kwargs, + ) -> BatchEncoding: + if audio is None: + raise ValueError("Audio input is required for ASR processing") + if isinstance(audio, list): + is_batched = True + audio_list = audio + else: + is_batched = False + audio_list = [audio] + all_encodings = [] + for audio_input in audio_list: + encoding = self._process_single_audio( + audio_input, + sampling_rate=sampling_rate, + add_generation_prompt=add_generation_prompt, + use_streaming=use_streaming, + context_info=context_info, + ) + all_encodings.append(encoding) + batch_encoding = self._batch_encode( + all_encodings, + padding=padding, + max_length=max_length, + truncation=truncation, + return_tensors=return_tensors, + ) + return batch_encoding + + def _process_single_audio( + self, + audio: Union[str, np.ndarray, torch.Tensor], + sampling_rate: Optional[int] = None, + add_generation_prompt: bool = True, + use_streaming: bool = True, + context_info: Optional[str] = None, + ) -> Dict[str, Any]: + if isinstance(audio, str): + import soundfile as sf + + audio_array, file_sr = sf.read(audio) + if audio_array.ndim > 1: + audio_array = audio_array.mean(axis=1) + if file_sr != self.target_sample_rate: + import librosa + + audio_array = librosa.resample( + audio_array, orig_sr=file_sr, target_sr=self.target_sample_rate + ) + elif isinstance(audio, torch.Tensor): + audio_array = audio.cpu().numpy() + if audio_array.ndim > 1: + audio_array = audio_array.squeeze() + else: + audio_array = np.array(audio, dtype=np.float32) + if audio_array.ndim > 1: + audio_array = audio_array.squeeze() + audio_array = audio_array.astype(np.float32) + if self.normalize_audio and self.audio_normalizer: + audio_array = self.audio_normalizer(audio_array) + audio_duration = len(audio_array) / self.target_sample_rate + if use_streaming and audio_duration < 60.0: + use_streaming = False + vae_tok_len = math.ceil(len(audio_array) / self.speech_tok_compress_ratio) + system_prompt_text = self.tokenizer.apply_chat_template( + [{"role": "system", "content": SYSTEM_PROMPT}], tokenize=False + ) + system_tokens = self.tokenizer.encode(system_prompt_text) + sp_start_token = self.tokenizer.convert_ids_to_tokens(self.speech_start_id) + sp_pad_token = self.tokenizer.convert_ids_to_tokens(self.speech_pad_id) + sp_end_token = self.tokenizer.convert_ids_to_tokens(self.speech_end_id) + show_keys = ["Start time", "End time", "Speaker ID", "Content"] + if context_info and context_info.strip(): + user_suffix = ( + f"This is a {audio_duration :.2f} seconds audio, with extra info: {context_info .strip ()}\n\nPlease transcribe it with these keys: " + + ", ".join(show_keys) + ) + else: + user_suffix = ( + f"This is a {audio_duration :.2f} seconds audio, please transcribe it with these keys: " + + ", ".join(show_keys) + ) + user_input_string = ( + "".join([sp_start_token] + [sp_pad_token] * vae_tok_len + [sp_end_token]) + + "\n" + + user_suffix + ) + user_tokens = self.tokenizer.apply_chat_template( + [{"role": "user", "content": user_input_string}], tokenize=True + ) + full_tokens = system_tokens + user_tokens + acoustic_input_mask = [ + 1 if token == self.speech_pad_id else 0 for token in full_tokens + ] + return { + "input_ids": full_tokens, + "acoustic_input_mask": acoustic_input_mask, + "speech": audio_array, + "vae_tok_len": vae_tok_len, + } + + def _batch_encode( + self, + encodings: List[Dict[str, Any]], + padding: bool = True, + max_length: Optional[int] = None, + truncation: bool = False, + return_tensors: Optional[str] = None, + ) -> BatchEncoding: + input_ids_list = [enc["input_ids"] for enc in encodings] + acoustic_masks_list = [enc["acoustic_input_mask"] for enc in encodings] + speech_list = [enc["speech"] for enc in encodings] + vae_tok_lens = [enc["vae_tok_len"] for enc in encodings] + if padding: + if max_length is not None: + target_length = max_length + else: + target_length = max((len(ids) for ids in input_ids_list)) + padded_input_ids = [] + padded_acoustic_masks = [] + attention_masks = [] + for input_ids, acoustic_mask in zip(input_ids_list, acoustic_masks_list): + if truncation and len(input_ids) > target_length: + input_ids = input_ids[:target_length] + acoustic_mask = acoustic_mask[:target_length] + padding_length = target_length - len(input_ids) + padded_ids = [self.pad_id] * padding_length + input_ids + padded_acoustic = [0] * padding_length + acoustic_mask + attention_mask = [0] * padding_length + [1] * len(input_ids) + padded_input_ids.append(padded_ids) + padded_acoustic_masks.append(padded_acoustic) + attention_masks.append(attention_mask) + input_ids_list = padded_input_ids + acoustic_masks_list = padded_acoustic_masks + else: + attention_masks = [[1] * len(ids) for ids in input_ids_list] + max_speech_length = max((len(s) for s in speech_list)) + padded_speeches = np.zeros( + (len(speech_list), max_speech_length), dtype=np.float32 + ) + speech_masks = np.zeros((len(speech_list), max(vae_tok_lens)), dtype=bool) + for i, (speech, vae_len) in enumerate(zip(speech_list, vae_tok_lens)): + padded_speeches[i, : len(speech)] = speech + speech_masks[i, :vae_len] = True + batch_encoding = BatchEncoding() + if return_tensors == "pt": + batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long) + batch_encoding["attention_mask"] = torch.tensor( + attention_masks, dtype=torch.long + ) + batch_encoding["acoustic_input_mask"] = torch.tensor( + acoustic_masks_list, dtype=torch.bool + ) + batch_encoding["speech_tensors"] = torch.tensor( + padded_speeches, dtype=torch.float32 + ) + batch_encoding["speech_masks"] = torch.tensor( + speech_masks, dtype=torch.bool + ) + else: + batch_encoding["input_ids"] = ( + input_ids_list if len(input_ids_list) > 1 else input_ids_list[0] + ) + batch_encoding["attention_mask"] = ( + attention_masks if len(attention_masks) > 1 else attention_masks[0] + ) + batch_encoding["acoustic_input_mask"] = ( + acoustic_masks_list + if len(acoustic_masks_list) > 1 + else acoustic_masks_list[0] + ) + batch_encoding["speech_tensors"] = ( + padded_speeches if len(padded_speeches) > 1 else padded_speeches[0] + ) + batch_encoding["speech_masks"] = ( + speech_masks if len(speech_masks) > 1 else speech_masks[0] + ) + return batch_encoding + + def batch_decode(self, *args, **kwargs): + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.tokenizer.decode(*args, **kwargs) + + def post_process_transcription(self, text: str) -> List[Dict[str, Any]]: + try: + if "```json" in text: + json_start = text.find("```json") + 7 + json_end = text.find("```", json_start) + json_str = text[json_start:json_end].strip() + else: + json_start = text.find("[") + if json_start == -1: + json_start = text.find("{") + if json_start != -1: + bracket_count = 0 + json_end = json_start + for i in range(json_start, len(text)): + if text[i] in "[{": + bracket_count += 1 + elif text[i] in "]}": + bracket_count -= 1 + if bracket_count == 0: + json_end = i + 1 + break + json_str = text[json_start:json_end] + else: + json_str = text + result = json.loads(json_str) + if isinstance(result, dict): + result = [result] + cleaned_result = [] + for item in result: + if isinstance(item, dict): + cleaned_item = {} + key_mapping = { + "Start time": "start_time", + "Start": "start_time", + "End time": "end_time", + "End": "end_time", + "Speaker ID": "speaker_id", + "Speaker": "speaker_id", + "Content": "text", + } + for key, mapped_key in key_mapping.items(): + if key in item: + cleaned_item[mapped_key] = item[key] + if cleaned_item: + cleaned_result.append(cleaned_item) + return cleaned_result + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse JSON from transcription: {e }") + logger.debug(f"Raw text: {text }") + return [] + except Exception as e: + logger.warning(f"Error post-processing transcription: {e }") + return [] + + @property + def model_input_names(self): + return [ + "input_ids", + "attention_mask", + "acoustic_input_mask", + "speech_tensors", + "speech_masks", + ] + + +__all__ = [ + 'QWEN3VoxASRProcessor' +] +import math +import warnings +from typing import List, Optional, Union, Dict, Any, Tuple +import os +import re +import numpy as np +import torch +from transformers.tokenization_utils_base import ( + BatchEncoding, + PaddingStrategy, + PreTokenizedInput, + TextInput, + TruncationStrategy, +) +from transformers.utils import TensorType, logging + +logger = logging.get_logger(__name__) + + +class QWEN3VoxProcessor: + + def __init__( + self, + tokenizer=None, + audio_processor=None, + speech_tok_compress_ratio=3200, + db_normalize=True, + **kwargs, + ): + self.tokenizer = tokenizer + self.audio_processor = audio_processor + self.speech_tok_compress_ratio = speech_tok_compress_ratio + self.db_normalize = db_normalize + self.audio_normalizer = AudioNormalizer() if db_normalize else None + self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + import os + import json + from transformers.utils import cached_file + + config_path = os.path.join( + pretrained_model_name_or_path, "preprocessor_config.json" + ) + config = None + if os.path.exists(config_path): + with open(config_path, "r") as f: + config = json.load(f) + else: + try: + config_file = cached_file( + pretrained_model_name_or_path, "preprocessor_config.json", **kwargs + ) + with open(config_file, "r") as f: + config = json.load(f) + except Exception as e: + logger.warning( + f"Could not load preprocessor_config.json from {pretrained_model_name_or_path }: {e }" + ) + logger.warning("Using default configuration") + config = {"speech_tok_compress_ratio": 3200, "db_normalize": True} + speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200) + db_normalize = config.get("db_normalize", True) + model_name = str(pretrained_model_name_or_path) + logger.info(f"Loading tokenizer from {model_name}") + if "qwen" in model_name.lower(): + tokenizer = QWEN3VoxTextTokenizerFast.from_pretrained(model_name, **kwargs) + else: + raise ValueError( + f"Unsupported tokenizer type for {language_model_pretrained_name }. Supported types: Qwen, Llama, Gemma." + ) + if "audio_processor" in config: + audio_config = config["audio_processor"] + audio_processor = QWEN3VoxTokenizerProcessor( + sampling_rate=audio_config.get("sampling_rate", 22050), + normalize_audio=audio_config.get("normalize_audio", True), + target_dB_FS=audio_config.get("target_dB_FS", -25), + eps=audio_config.get("eps", 1e-06), + ) + else: + audio_processor = QWEN3VoxTokenizerProcessor() + return cls( + tokenizer=tokenizer, + audio_processor=audio_processor, + speech_tok_compress_ratio=speech_tok_compress_ratio, + db_normalize=db_normalize, + ) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): + import os + import json + + os.makedirs(save_directory, exist_ok=True) + processor_config = { + "processor_class": "QWEN3VoxProcessor", + "speech_tok_compress_ratio": self.speech_tok_compress_ratio, + "db_normalize": self.db_normalize, + "audio_processor": { + "feature_extractor_type": "QWEN3VoxTokenizerProcessor", + "sampling_rate": getattr(self.audio_processor, "sampling_rate", 22050), + "normalize_audio": getattr( + self.audio_processor, "normalize_audio", True + ), + "target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25), + "eps": getattr(self.audio_processor, "eps", 1e-06), + }, + } + config_path = os.path.join(save_directory, "preprocessor_config.json") + with open(config_path, "w") as f: + json.dump(processor_config, f, indent=2) + logger.info(f"Processor configuration saved in {config_path }") + + def __call__( + self, + text: Optional[ + Union[ + str, + List[str], + TextInput, + PreTokenizedInput, + List[TextInput], + List[PreTokenizedInput], + ] + ] = None, + voice_samples: Optional[ + Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]] + ] = None, + padding: Union[bool, str, PaddingStrategy] = True, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: bool = True, + **kwargs, + ) -> BatchEncoding: + if isinstance(text, str) or ( + isinstance(text, list) and len(text) > 0 and (not isinstance(text[0], str)) + ): + texts = [text] + is_batched = False + else: + texts = text + is_batched = True + if voice_samples is not None: + if not is_batched or isinstance(voice_samples[0], (str, np.ndarray)): + voice_samples_list = [voice_samples] + else: + voice_samples_list = voice_samples + else: + voice_samples_list = [None] * len(texts) + all_encodings = [] + for text_input, voice_input in zip(texts, voice_samples_list): + encoding = self._process_single(text_input, voice_input) + all_encodings.append(encoding) + batch_encoding = self._batch_encode( + all_encodings, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_attention_mask=return_attention_mask, + ) + return batch_encoding + + def _process_single( + self, + text: Union[str, TextInput], + voice_samples: Optional[List[Union[str, np.ndarray]]] = None, + ) -> Dict[str, Any]: + script = None + if isinstance(text, str): + if text.endswith(".json") and os.path.exists(text): + script = self._convert_json_to_script(text) + elif text.endswith(".txt") and os.path.exists(text): + script = self._convert_text_to_script(text) + else: + script = text + if script is None: + raise ValueError(f"Could not process input text: {text }") + parsed_lines = self._parse_script(script) + all_speakers = list(set((speaker_id for speaker_id, _ in parsed_lines))) + system_tokens = self.tokenizer.encode(self.system_prompt) + if voice_samples: + voice_tokens, voice_speech_inputs, voice_speech_masks = ( + self._create_voice_prompt(voice_samples[: len(all_speakers)]) + ) + else: + voice_tokens, voice_speech_inputs, voice_speech_masks = ([], [], []) + full_tokens = system_tokens + voice_tokens + speech_input_mask = [False] * len(system_tokens) + voice_speech_masks + full_tokens += self.tokenizer.encode(" Text input:\n", add_special_tokens=False) + speech_input_mask += [False] * len( + self.tokenizer.encode(" Text input:\n", add_special_tokens=False) + ) + for speaker_id, speaker_text in parsed_lines: + speaker_text_tokens = self.tokenizer.encode( + f" Speaker {speaker_id }:{speaker_text }\n", add_special_tokens=False + ) + full_tokens += speaker_text_tokens + speech_input_mask += [False] * len(speaker_text_tokens) + full_tokens += self.tokenizer.encode( + " Speech output:\n", add_special_tokens=False + ) + [self.tokenizer.speech_start_id] + speech_input_mask += [False] * ( + len(self.tokenizer.encode(" Speech output:\n", add_special_tokens=False)) + + 1 + ) + return { + "input_ids": full_tokens, + "speech_inputs": voice_speech_inputs if voice_speech_inputs else None, + "speech_input_mask": speech_input_mask, + "parsed_script": parsed_lines, + "all_speakers": all_speakers, + } + + def _batch_encode( + self, + encodings: List[Dict[str, Any]], + padding: Union[bool, str, PaddingStrategy] = True, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: bool = True, + ) -> BatchEncoding: + input_ids_list = [enc["input_ids"] for enc in encodings] + speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings] + if isinstance(padding, bool): + padding_strategy = ( + PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD + ) + elif isinstance(padding, str): + padding_strategy = PaddingStrategy(padding) + else: + padding_strategy = padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD: + if padding_strategy == PaddingStrategy.LONGEST: + max_len = max((len(ids) for ids in input_ids_list)) + elif ( + padding_strategy == PaddingStrategy.MAX_LENGTH + and max_length is not None + ): + max_len = max_length + else: + max_len = max((len(ids) for ids in input_ids_list)) + padded_input_ids = [] + attention_masks = [] + padded_speech_input_masks = [] + for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list): + if truncation and len(input_ids) > max_len: + input_ids = input_ids[:max_len] + speech_mask = speech_mask[:max_len] + padding_length = max_len - len(input_ids) + padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids + attention_mask = [0] * padding_length + [1] * len(input_ids) + padded_speech_mask = [False] * padding_length + speech_mask + padded_input_ids.append(padded_ids) + attention_masks.append(attention_mask) + padded_speech_input_masks.append(padded_speech_mask) + input_ids_list = padded_input_ids + speech_input_masks_list = padded_speech_input_masks + else: + attention_masks = ( + [[1] * len(ids) for ids in input_ids_list] + if return_attention_mask + else None + ) + all_speech_inputs = [] + has_speech = False + for enc in encodings: + if enc["speech_inputs"] is not None: + all_speech_inputs.extend(enc["speech_inputs"]) + has_speech = True + batch_encoding = BatchEncoding() + if return_tensors is not None: + batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long) + if return_attention_mask and attention_masks is not None: + batch_encoding["attention_mask"] = torch.tensor( + attention_masks, dtype=torch.long + ) + batch_encoding["speech_input_mask"] = torch.tensor( + speech_input_masks_list, dtype=torch.bool + ) + else: + batch_encoding["input_ids"] = input_ids_list + if return_attention_mask and attention_masks is not None: + batch_encoding["attention_mask"] = attention_masks + batch_encoding["speech_input_mask"] = speech_input_masks_list + if has_speech: + speech_dict = self.prepare_speech_inputs( + all_speech_inputs, return_tensors=return_tensors + ) + batch_encoding["speech_tensors"] = speech_dict["padded_speeches"] + batch_encoding["speech_masks"] = speech_dict["speech_masks"] + else: + batch_encoding["speech_tensors"] = None + batch_encoding["speech_masks"] = None + batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings] + batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings] + return batch_encoding + + def _create_voice_prompt( + self, speaker_samples: List[Union[str, np.ndarray]] + ) -> Tuple[List[int], List[np.ndarray], List[bool]]: + vae_token_id = self.tokenizer.speech_diffusion_id + voice_full_tokens = self.tokenizer.encode( + " Voice input:\n", add_special_tokens=False + ) + voice_speech_inputs = [] + voice_speech_masks = [False] * len(voice_full_tokens) + for speaker_id, speaker_audio in enumerate(speaker_samples): + prefix_tokens = self.tokenizer.encode( + f" Speaker {speaker_id }:", add_special_tokens=False + ) + if isinstance(speaker_audio, str): + wav = self.audio_processor._load_audio_from_path(speaker_audio) + elif isinstance(speaker_audio, dict): + if "array" in speaker_audio: + wav = np.array(speaker_audio["array"], dtype=np.float32) + elif "audio" in speaker_audio: + wav = np.array(speaker_audio["audio"], dtype=np.float32) + else: + raise ValueError( + f"Dictionary audio input must have 'array' or 'audio' key, got: {speaker_audio .keys ()}" + ) + else: + wav = np.array(speaker_audio, dtype=np.float32) + if self.db_normalize and self.audio_normalizer: + wav = self.audio_normalizer(wav) + vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio) + speaker_tokens = ( + prefix_tokens + + [self.tokenizer.speech_start_id] + + [vae_token_id] * vae_tok_len + + [self.tokenizer.speech_end_id] + + self.tokenizer.encode("\n", add_special_tokens=False) + ) + vae_input_mask = ( + [False] * len(prefix_tokens) + + [False] + + [True] * vae_tok_len + + [False] + + [False] + ) + voice_full_tokens.extend(speaker_tokens) + voice_speech_masks.extend(vae_input_mask) + voice_speech_inputs.append(wav) + return (voice_full_tokens, voice_speech_inputs, voice_speech_masks) + + def prepare_speech_inputs( + self, + speech_inputs: List[np.ndarray], + return_tensors: Optional[Union[str, TensorType]] = None, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + ) -> Dict[str, Any]: + if not speech_inputs: + return {"padded_speeches": None, "speech_masks": None} + vae_tok_seqlens = [ + math.ceil(s.shape[0] / self.speech_tok_compress_ratio) + for s in speech_inputs + ] + max_speech_length = max((s.shape[0] for s in speech_inputs)) + if speech_inputs[0].ndim == 1: + padded_speeches = np.full( + (len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32 + ) + else: + padded_speeches = np.full( + (len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), + fill_value=0, + dtype=np.float32, + ) + speech_masks = np.zeros( + (len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_ + ) + for i, (speech, vae_tok_length) in enumerate( + zip(speech_inputs, vae_tok_seqlens) + ): + padded_speeches[i, : len(speech)] = speech + speech_masks[i, :vae_tok_length] = True + result = {"padded_speeches": padded_speeches, "speech_masks": speech_masks} + if return_tensors == "pt": + result["padded_speeches"] = torch.tensor( + padded_speeches, device=device, dtype=dtype or torch.float32 + ) + result["speech_masks"] = torch.tensor( + speech_masks, device=device, dtype=torch.bool + ) + return result + + def _convert_json_to_script(self, json_file: str) -> str: + import json + + with open(json_file, "r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, list): + raise ValueError("JSON file must contain a list of speaker entries") + script_lines = [] + for item in data: + if not isinstance(item, dict): + logger.warning(f"Skipping non-dict entry: {item }") + continue + speaker = item.get("speaker") + text = item.get("text") + if speaker is None or text is None: + logger.warning(f"Skipping entry missing speaker or text: {item }") + continue + try: + speaker_id = int(speaker) + except (ValueError, TypeError): + logger.warning(f"Invalid speaker ID: {speaker }, skipping entry") + continue + text = text.strip() + if text: + script_lines.append(f"Speaker {speaker_id }: {text }") + if not script_lines: + raise ValueError("No valid entries found in JSON file") + return "\n".join(script_lines) + + def _convert_text_to_script(self, text_file: str) -> str: + with open(text_file, "r", encoding="utf-8") as f: + lines = f.readlines() + script_lines = [] + current_speaker = 1 + for line in lines: + line = line.strip() + if not line: + continue + speaker_match = re.match( + "^Speaker\\s+(\\d+)\\s*:\\s*(.*)$", line, re.IGNORECASE + ) + if speaker_match: + speaker_id = int(speaker_match.group(1)) + text = speaker_match.group(2).strip() + if text: + script_lines.append(f"Speaker {speaker_id }: {text }") + else: + script_lines.append(f"Speaker {current_speaker }: {line }") + if not script_lines: + raise ValueError("No valid content found in text file") + return "\n".join(script_lines) + + def _parse_script(self, script: str) -> List[Tuple[int, str]]: + stripped = script.strip() + if not stripped: + raise ValueError( + "No valid speaker lines found in script (empty text). " + "If training with HuggingFace Trainer, set remove_unused_columns=False " + "so dataset columns like `text` are not stripped before the collator." + ) + non_empty = [ln.strip() for ln in stripped.split("\n") if ln.strip()] + if not non_empty: + raise ValueError("No valid speaker lines found in script") + _speaker_line = r"^Speaker\s+(\d+)\s*:\s*(.*)$" + if not any(re.match(_speaker_line, ln, re.IGNORECASE) for ln in non_empty): + # JSONL / Vocence eval rows: plain prompt (e.g. text). + collapsed = " ".join(stripped.split()) + return [(0, " " + collapsed)] + lines = stripped.split("\n") + parsed_lines = [] + speaker_ids = [] + for line in non_empty: + match = re.match(_speaker_line, line, re.IGNORECASE) + if match: + speaker_id = int(match.group(1)) + text = " " + match.group(2).strip() + parsed_lines.append((speaker_id, text)) + speaker_ids.append(speaker_id) + else: + logger.warning(f"Could not parse line: '{line }'") + if not parsed_lines: + raise ValueError("No valid speaker lines found in script") + min_speaker_id = min(speaker_ids) + if min_speaker_id > 0: + normalized_lines = [] + for speaker_id, text in parsed_lines: + normalized_lines.append((speaker_id - 1, text)) + return normalized_lines + else: + return parsed_lines + + def _merge_inputs( + self, text_inputs: BatchEncoding, audio_inputs: Dict + ) -> BatchEncoding: + merged = BatchEncoding(text_inputs) + if "audio" in audio_inputs: + merged["speech_inputs"] = audio_inputs["audio"] + if "streaming" in audio_inputs: + merged["streaming"] = audio_inputs["streaming"] + return merged + + def batch_decode(self, *args, **kwargs): + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + audio_processor_input_names = self.audio_processor.model_input_names + return list( + dict.fromkeys( + tokenizer_input_names + + audio_processor_input_names + + ["speech_inputs", "speech_input_mask"] + ) + ) + + def save_audio( + self, + audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]], + output_path: str = "output.wav", + sampling_rate: Optional[int] = None, + normalize: bool = False, + batch_prefix: str = "audio_", + ) -> str: + return self.audio_processor.save_audio( + audio, + output_path=output_path, + sampling_rate=sampling_rate, + normalize=normalize, + batch_prefix=batch_prefix, + ) + + +__all__ = [ + 'QWEN3VoxProcessor' +] +'\nQWEN3Vox Streaming Processor\n\nThis processor handles input preparation for the streaming 0.5B model,\nincluding text tokenization and cached voice prompt handling.\n' +import math +import warnings +from typing import List, Optional, Union, Dict, Any, Tuple +import os +import re +import numpy as np +import torch +from transformers.tokenization_utils_base import ( + BatchEncoding, + PaddingStrategy, + PreTokenizedInput, + TextInput, + TruncationStrategy, +) +from transformers.utils import TensorType, logging + +logger = logging.get_logger(__name__) + + +class QWEN3VoxStreamingProcessor: + + def __init__( + self, + tokenizer=None, + audio_processor=None, + speech_tok_compress_ratio=3200, + db_normalize=True, + **kwargs, + ): + self.tokenizer = tokenizer + self.audio_processor = audio_processor + self.speech_tok_compress_ratio = speech_tok_compress_ratio + self.db_normalize = db_normalize + self.audio_normalizer = AudioNormalizer() if db_normalize else None + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + import os + import json + from transformers.utils import cached_file + + config_path = os.path.join( + pretrained_model_name_or_path, "preprocessor_config.json" + ) + config = None + if os.path.exists(config_path): + with open(config_path, "r") as f: + config = json.load(f) + else: + try: + config_file = cached_file( + pretrained_model_name_or_path, "preprocessor_config.json", **kwargs + ) + with open(config_file, "r") as f: + config = json.load(f) + except Exception as e: + logger.warning( + f"Could not load preprocessor_config.json from {pretrained_model_name_or_path }: {e }" + ) + logger.warning("Using default configuration") + config = {"speech_tok_compress_ratio": 3200, "db_normalize": True} + speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200) + db_normalize = config.get("db_normalize", True) + model_name = str(pretrained_model_name_or_path) + logger.info(f"Loading tokenizer from {model_name}") + if "qwen" in model_name.lower(): + tokenizer = QWEN3VoxTextTokenizerFast.from_pretrained(model_name, **kwargs) + else: + raise ValueError( + f"Unsupported tokenizer type for {language_model_pretrained_name }. Supported types: Qwen." + ) + if "audio_processor" in config: + audio_config = config["audio_processor"] + audio_processor = QWEN3VoxTokenizerProcessor( + sampling_rate=audio_config.get("sampling_rate", 22050), + normalize_audio=audio_config.get("normalize_audio", True), + target_dB_FS=audio_config.get("target_dB_FS", -25), + eps=audio_config.get("eps", 1e-06), + ) + else: + audio_processor = QWEN3VoxTokenizerProcessor() + return cls( + tokenizer=tokenizer, + audio_processor=audio_processor, + speech_tok_compress_ratio=speech_tok_compress_ratio, + db_normalize=db_normalize, + ) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): + import os + import json + + os.makedirs(save_directory, exist_ok=True) + processor_config = { + "processor_class": "QWEN3VoxStreamingProcessor", + "speech_tok_compress_ratio": self.speech_tok_compress_ratio, + "db_normalize": self.db_normalize, + "audio_processor": { + "feature_extractor_type": "QWEN3VoxTokenizerProcessor", + "sampling_rate": getattr(self.audio_processor, "sampling_rate", 22050), + "normalize_audio": getattr( + self.audio_processor, "normalize_audio", True + ), + "target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25), + "eps": getattr(self.audio_processor, "eps", 1e-06), + }, + } + config_path = os.path.join(save_directory, "preprocessor_config.json") + with open(config_path, "w") as f: + json.dump(processor_config, f, indent=2) + logger.info(f"Processor configuration saved in {config_path }") + + def __call__(self) -> BatchEncoding: + raise NotImplementedError( + 'QWEN3VoxStreamingProcessor.__call__ is not implemented. Use process_input_with_cached_prompt for streaming inputs.' + ) + + def process_input_with_cached_prompt( + self, + text: Optional[str] = None, + cached_prompt: Optional[Dict[str, Any]] = None, + padding: Union[bool, str, PaddingStrategy] = True, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: bool = True, + **kwargs, + ) -> BatchEncoding: + texts = [text] + cached_prompts = [cached_prompt] + is_batched = False + all_encodings = [] + for text_input, cached_prompt_input in zip(texts, cached_prompts): + script_tokens = self.tokenizer.encode( + text_input.strip() + "\n", add_special_tokens=False + ) + input_id_length = cached_prompt_input["lm"]["last_hidden_state"].size(1) + tts_lm_input_id_length = cached_prompt_input["tts_lm"][ + "last_hidden_state" + ].size(1) + input_ids = [self.tokenizer.pad_id] * input_id_length + tts_lm_input_ids = [self.tokenizer.pad_id] * tts_lm_input_id_length + speech_input_mask = [False] * tts_lm_input_id_length + encoding = { + "input_ids": input_ids, + "tts_lm_input_ids": tts_lm_input_ids, + "tts_text_ids": script_tokens, + "speech_inputs": None, + "speech_input_mask": speech_input_mask, + } + all_encodings.append(encoding) + batch_encoding = self._batch_encode( + all_encodings, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_attention_mask=return_attention_mask, + ) + return batch_encoding + + def _batch_encode( + self, + encodings: List[Dict[str, Any]], + padding: Union[bool, str, PaddingStrategy] = True, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: bool = True, + ) -> BatchEncoding: + input_ids_list = [enc["input_ids"] for enc in encodings] + tts_lm_input_ids_list = [enc["tts_lm_input_ids"] for enc in encodings] + tts_text_ids_list = [enc["tts_text_ids"] for enc in encodings] + speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings] + attention_masks = ( + [[1] * len(ids) for ids in input_ids_list] + if return_attention_mask + else None + ) + tts_lm_attention_masks = ( + [[1] * len(ids) for ids in tts_lm_input_ids_list] + if return_attention_mask + else None + ) + all_speech_inputs = [] + has_speech = False + for enc in encodings: + if enc["speech_inputs"] is not None: + all_speech_inputs.extend(enc["speech_inputs"]) + has_speech = True + batch_encoding = BatchEncoding() + if return_tensors is not None: + batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long) + batch_encoding["tts_lm_input_ids"] = torch.tensor( + tts_lm_input_ids_list, dtype=torch.long + ) + batch_encoding["tts_text_ids"] = torch.tensor( + tts_text_ids_list, dtype=torch.long + ) + if return_attention_mask and attention_masks is not None: + batch_encoding["attention_mask"] = torch.tensor( + attention_masks, dtype=torch.long + ) + batch_encoding["tts_lm_attention_mask"] = torch.tensor( + tts_lm_attention_masks, dtype=torch.long + ) + batch_encoding["speech_input_mask"] = torch.tensor( + speech_input_masks_list, dtype=torch.bool + ) + else: + batch_encoding["input_ids"] = input_ids_list + batch_encoding["tts_lm_input_ids"] = tts_lm_input_ids_list + batch_encoding["tts_text_ids"] = tts_text_ids_list + if return_attention_mask and attention_masks is not None: + batch_encoding["attention_mask"] = attention_masks + batch_encoding["tts_lm_attention_mask"] = tts_lm_attention_masks + batch_encoding["speech_input_mask"] = speech_input_masks_list + if has_speech: + speech_dict = self.prepare_speech_inputs( + all_speech_inputs, return_tensors=return_tensors + ) + batch_encoding["speech_tensors"] = speech_dict["padded_speeches"] + batch_encoding["speech_masks"] = speech_dict["speech_masks"] + else: + batch_encoding["speech_tensors"] = None + batch_encoding["speech_masks"] = None + return batch_encoding + + def prepare_speech_inputs( + self, + speech_inputs: List[np.ndarray], + return_tensors: Optional[Union[str, TensorType]] = None, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + ) -> Dict[str, Any]: + if not speech_inputs: + return {"padded_speeches": None, "speech_masks": None} + vae_tok_seqlens = [ + math.ceil(s.shape[0] / self.speech_tok_compress_ratio) + for s in speech_inputs + ] + max_speech_length = max((s.shape[0] for s in speech_inputs)) + if speech_inputs[0].ndim == 1: + padded_speeches = np.full( + (len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32 + ) + else: + padded_speeches = np.full( + (len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), + fill_value=0, + dtype=np.float32, + ) + speech_masks = np.zeros( + (len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_ + ) + for i, (speech, vae_tok_length) in enumerate( + zip(speech_inputs, vae_tok_seqlens) + ): + padded_speeches[i, : len(speech)] = speech + speech_masks[i, :vae_tok_length] = True + result = {"padded_speeches": padded_speeches, "speech_masks": speech_masks} + if return_tensors == "pt": + result["padded_speeches"] = torch.tensor( + padded_speeches, device=device, dtype=dtype or torch.float32 + ) + result["speech_masks"] = torch.tensor( + speech_masks, device=device, dtype=torch.bool + ) + return result + + def batch_decode(self, *args, **kwargs): + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + audio_processor_input_names = self.audio_processor.model_input_names + return list( + dict.fromkeys( + tokenizer_input_names + + audio_processor_input_names + + ["speech_inputs", "speech_input_mask"] + ) + ) + + def save_audio( + self, + audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]], + output_path: str = "output.wav", + sampling_rate: Optional[int] = None, + normalize: bool = False, + batch_prefix: str = "audio_", + ) -> str: + return self.audio_processor.save_audio( + audio, + output_path=output_path, + sampling_rate=sampling_rate, + normalize=normalize, + batch_prefix=batch_prefix, + ) + + +__all__ = [ + 'QWEN3VoxStreamingProcessor' +] +'\nQWEN3Vox Streaming Model Architecture (0.5B)\n\nThis module implements the streaming-optimized version of QWEN3Vox for real-time TTS.\nKey differences from the multi-speaker model:\n- No semantic tokenizer (only acoustic)\n- Split language model architecture: lower layers for text, upper layers for TTS\n- Optimized for low-latency generation\n' +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union, Callable +from tqdm import tqdm +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from transformers.models.auto import AutoModel, AutoModelForCausalLM +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + CausalLMOutput, + BaseModelOutputWithPast, + ModelOutput, +) +from transformers.models.llama.modeling_llama import LlamaRMSNorm +from transformers import modeling_utils +from transformers.modeling_utils import PreTrainedModel +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.utils import logging + +logger = logging.get_logger(__name__) +if ( + not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") + or modeling_utils.ALL_PARALLEL_STYLES is None +): + modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"] + + +class BinaryClassifier(nn.Module): + + def __init__(self, hidden_size): + super(BinaryClassifier, self).__init__() + self.fc1 = nn.Linear(hidden_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, 1) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + x = self.fc2(x) + return x + + +class SpeechConnector(nn.Module): + + def __init__(self, input_dim, output_dim): + super().__init__() + self.fc1 = nn.Linear(input_dim, output_dim) + self.norm = LlamaRMSNorm(output_dim, eps=1e-06) + self.fc2 = nn.Linear(output_dim, output_dim) + + def forward(self, features, **kwargs): + x = self.fc1(features) + x = self.norm(x) + x = self.fc2(x) + return x + + +class QWEN3VoxStreamingPreTrainedModel(PreTrainedModel): + config_class = QWEN3VoxStreamingConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + if isinstance(module, QWEN3VoxDiffusionHead): + module.initialize_weights() + return + if hasattr(self.config, "language_model_config") and hasattr( + self.config.language_model_config, "initializer_range" + ): + std = self.config.language_model_config.initializer_range + elif hasattr(self.config, "decoder_config") and hasattr( + self.config.decoder_config, "initializer_range" + ): + std = self.config.decoder_config.initializer_range + else: + std = 0.02 + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +class QWEN3VoxStreamingModel(QWEN3VoxStreamingPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + if hasattr(config, "torch_dtype") and config.torch_dtype is not None: + if isinstance(config.torch_dtype, str): + dtype = getattr(torch, config.torch_dtype) + else: + dtype = config.torch_dtype + else: + dtype = torch.float32 + lm_config = copy.deepcopy(config.decoder_config) + lm_backbone_num_hidden_layers = ( + getattr(lm_config, "num_hidden_layers", 24) + - config.tts_backbone_num_hidden_layers + ) + lm_config.num_hidden_layers = lm_backbone_num_hidden_layers + self.language_model = AutoModel.from_config(lm_config) + self.language_model.norm = nn.Identity() + tts_lm_config = copy.deepcopy(lm_config) + tts_lm_config.num_hidden_layers = config.tts_backbone_num_hidden_layers + self.tts_language_model = AutoModel.from_config(tts_lm_config) + self.tts_input_types = nn.Embedding( + num_embeddings=2, embedding_dim=config.decoder_config.hidden_size + ) + self.acoustic_tokenizer = AutoModel.from_config( + config.acoustic_tokenizer_config + ).to(dtype) + self.acoustic_connector = SpeechConnector( + config.acoustic_vae_dim, lm_config.hidden_size + ).to(dtype) + self.register_buffer("speech_scaling_factor", torch.tensor(float("nan"))) + self.register_buffer("speech_bias_factor", torch.tensor(float("nan"))) + self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to( + dtype + ) + self.noise_scheduler = DPMSolverMultistepScheduler( + num_train_timesteps=config.diffusion_head_config.ddpm_num_steps, + beta_schedule=config.diffusion_head_config.ddpm_beta_schedule, + prediction_type=config.diffusion_head_config.prediction_type, + ) + + def get_input_embeddings(self): + if hasattr(self.language_model, "embed_tokens"): + return self.language_model.embed_tokens + for name, attr in self.language_model.fullmap.items(): + if attr.orig_name == "embed_tokens.weight": + return getattr(self.language_model, name) + assert False, "should not arrive here" + + def set_input_embeddings(self, value): + self.language_model.embed_tokens = value + + def set_speech_tokenizers(self, acoustic_tokenizer=None): + self.acoustic_tokenizer = acoustic_tokenizer + if self.acoustic_tokenizer is not None: + self.acoustic_tokenizer.train(False) + + def forward(self, *args, **kwargs): + raise RuntimeError( + 'QWEN3VoxStreamingModel.forward is intentionally disabled. Use `model.language_model(...)` or `model.tts_language_model(...)` instead.' + ) + + +AutoModel.register(QWEN3VoxStreamingConfig, QWEN3VoxStreamingModel) +__all__ = [ + 'QWEN3VoxStreamingPreTrainedModel', + 'QWEN3VoxStreamingModel', + "BinaryClassifier", + "SpeechConnector", +] +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union, Callable +from tqdm import tqdm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from transformers.models.auto import AutoModel, AutoModelForCausalLM +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + CausalLMOutput, + BaseModelOutputWithPast, + ModelOutput, +) +from transformers.models.llama.modeling_llama import LlamaRMSNorm +from transformers import modeling_utils +from transformers.modeling_utils import PreTrainedModel +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.utils import logging + +logger = logging.get_logger(__name__) +if ( + not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") + or modeling_utils.ALL_PARALLEL_STYLES is None +): + modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"] + + +@dataclass +class QWEN3VoxCausalLMOutputWithPast(ModelOutput): + loss: Optional[torch.FloatTensor] = None + diffusion_loss: Optional[torch.FloatTensor] = None + speech_token_num: Optional[int] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class QWEN3VoxGenerationOutput(ModelOutput): + sequences: torch.LongTensor = None + speech_outputs: Optional[List[torch.FloatTensor]] = None + + +class SpeechConnector(nn.Module): + + def __init__(self, input_dim, output_dim): + super().__init__() + self.fc1 = nn.Linear(input_dim, output_dim) + self.norm = LlamaRMSNorm(output_dim, eps=1e-06) + self.fc2 = nn.Linear(output_dim, output_dim) + + def forward(self, features, **kwargs): + x = self.fc1(features) + x = self.norm(x) + x = self.fc2(x) + return x + + +class QWEN3VoxPreTrainedModel(PreTrainedModel): + config_class = QWEN3VoxConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + if isinstance(module, QWEN3VoxDiffusionHead): + module.initialize_weights() + return + if hasattr(self.config, "language_model_config") and hasattr( + self.config.language_model_config, "initializer_range" + ): + std = self.config.language_model_config.initializer_range + elif hasattr(self.config, "decoder_config") and hasattr( + self.config.decoder_config, "initializer_range" + ): + std = self.config.decoder_config.initializer_range + else: + std = 0.02 + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +class QWEN3VoxModel(QWEN3VoxPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + if hasattr(config, "torch_dtype") and config.torch_dtype is not None: + if isinstance(config.torch_dtype, str): + dtype = getattr(torch, config.torch_dtype) + else: + dtype = config.torch_dtype + else: + dtype = torch.float32 + lm_config = config.decoder_config + self.language_model = AutoModel.from_config(lm_config) + self.acoustic_tokenizer = AutoModel.from_config( + config.acoustic_tokenizer_config + ).to(dtype) + self.semantic_tokenizer = AutoModel.from_config( + config.semantic_tokenizer_config + ).to(dtype) + self.acoustic_connector = SpeechConnector( + config.acoustic_vae_dim, lm_config.hidden_size + ).to(dtype) + self.semantic_connector = SpeechConnector( + config.semantic_vae_dim, lm_config.hidden_size + ).to(dtype) + self.register_buffer("speech_scaling_factor", torch.tensor(float("nan"))) + self.register_buffer("speech_bias_factor", torch.tensor(float("nan"))) + self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to( + dtype + ) + self.noise_scheduler = DPMSolverMultistepScheduler( + num_train_timesteps=config.diffusion_head_config.ddpm_num_steps, + beta_schedule=config.diffusion_head_config.ddpm_beta_schedule, + prediction_type=config.diffusion_head_config.prediction_type, + ) + + def get_input_embeddings(self): + if hasattr(self.language_model, "embed_tokens"): + return self.language_model.embed_tokens + for name, attr in self.language_model.fullmap.items(): + if attr.orig_name == "embed_tokens.weight": + return getattr(self.language_model, name) + assert False, "should not arrive here" + + def set_input_embeddings(self, value): + self.language_model.embed_tokens = value + + def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None): + self.acoustic_tokenizer = acoustic_tokenizer + self.semantic_tokenizer = semantic_tokenizer + if self.acoustic_tokenizer is not None: + self.acoustic_tokenizer.train(False) + if self.semantic_tokenizer is not None: + self.semantic_tokenizer.train(False) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + if not return_dict: + return outputs + return BaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class QWEN3VoxForConditionalGeneration(QWEN3VoxPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = QWEN3VoxModel(config) + self.vocab_size = config.decoder_config.vocab_size + self.lm_head = nn.Linear( + config.decoder_config.hidden_size, self.vocab_size, bias=False + ) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_decoder(self, decoder): + self.model.language_model = decoder + + def get_decoder(self): + return self.model.language_model + + def tie_weights(self): + if getattr(self.config.decoder_config, "tie_word_embeddings", False): + output_embeddings = self.get_output_embeddings() + input_embeddings = self.get_input_embeddings() + if hasattr(input_embeddings, "weight"): + output_embeddings.weight = input_embeddings.weight + else: + output_embeddings.weight = input_embeddings + if getattr(output_embeddings, "bias", None) is not None: + output_embeddings.bias.data = nn.functional.pad( + output_embeddings.bias.data, + ( + 0, + output_embeddings.weight.shape[0] + - output_embeddings.bias.shape[0], + ), + "constant", + 0, + ) + print("Tied input and output embeddings using standard assignment.") + else: + print("tie_word_embeddings is False, not tying weights.") + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def forward_speech_features( + self, + speech_tensors=None, + speech_masks=None, + speech_type="audio", + return_unmask=False, + ): + if speech_tensors is None: + vae_dim = self.config.acoustic_tokenizer_config.vae_dim + audio_features = torch.zeros(1, 1, vae_dim).to( + self.get_input_embeddings().weight + ) + connect_features = self.model.acoustic_connector(audio_features) + return (audio_features, connect_features) + else: + with torch.no_grad(): + if speech_type == "audio": + with torch.no_grad(): + frames = self.model.acoustic_tokenizer.encode( + speech_tensors.unsqueeze(1) + )[0][0] + audio_tokens = frames.sample( + self.model.acoustic_tokenizer.std_dist_type + )[0] + elif speech_type == "vae": + vae_dim = self.config.acoustic_tokenizer_config.vae_dim + speech_mode = speech_tensors.reshape( + speech_tensors.size(0), -1, vae_dim + ) + batch_size = speech_mode.size(0) + value = self.model.acoustic_tokenizer.fix_std / 0.8 + std = ( + torch.randn( + batch_size, + dtype=speech_mode.dtype, + device=speech_mode.device, + ) + * value + ) + std = std.view(-1, *[1] * (speech_mode.dim() - 1)) + audio_tokens = speech_mode + std * torch.randn( + speech_mode.shape + ).to(speech_mode) + else: + raise NotImplementedError( + f"Speech type {speech_type } not implemented" + ) + if torch.isnan(self.model.speech_scaling_factor) or torch.isnan( + self.model.speech_bias_factor + ): + scaling_factor = 1.0 / audio_tokens[speech_masks].flatten().std() + bias_factor = -audio_tokens[speech_masks].flatten().mean() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM) + dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM) + world_size = dist.get_world_size() + self.model.speech_scaling_factor.copy_( + scaling_factor / world_size + ) + self.model.speech_bias_factor.copy_(bias_factor / world_size) + print( + f"Speech scaling factor (distributed): {self .model .speech_scaling_factor }, bias factor: {self .model .speech_bias_factor }", + flush=True, + ) + else: + self.model.speech_scaling_factor.copy_(scaling_factor) + self.model.speech_bias_factor.copy_(bias_factor) + print( + f"Speech scaling factor (single process): {self .model .speech_scaling_factor }, bias factor: {self .model .speech_bias_factor }", + flush=True, + ) + audio_features = ( + audio_tokens + self.model.speech_bias_factor + ) * self.model.speech_scaling_factor + connect_features = self.model.acoustic_connector(audio_features) + if return_unmask: + return (audio_features, connect_features) + return (audio_features[speech_masks], connect_features[speech_masks]) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + speech_tensors: Optional[torch.FloatTensor] = None, + speech_masks: Optional[torch.BoolTensor] = None, + speeches_loss_input: Optional[torch.FloatTensor] = None, + speech_semantic_tensors: Optional[torch.FloatTensor] = None, + acoustic_input_mask: Optional[torch.BoolTensor] = None, + acoustic_loss_mask: Optional[torch.BoolTensor] = None, + ddpm_batch_mul: int = 1, + **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]], + ) -> Union[Tuple, QWEN3VoxCausalLMOutputWithPast]: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + x = self.get_input_embeddings()(input_ids) + semantic_speech_all_connect_features = self.model.semantic_connector( + speech_semantic_tensors + ) + if speeches_loss_input is not None: + speech_all_features, speech_all_connect_features = ( + self.forward_speech_features( + speech_tensors=( + speech_tensors.type_as(x) + if speech_tensors is not None + else None + ), + speech_masks=speech_masks, + speech_type=kwargs.get("speech_type", "audio"), + return_unmask=True, + ) + ) + if speech_tensors is not None: + if semantic_speech_all_connect_features is not None: + x[acoustic_input_mask] = ( + speech_all_connect_features[speech_masks] + + semantic_speech_all_connect_features[speech_masks] + ) + else: + x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + target_latent_mask = speeches_loss_input & speech_masks + speech_features = speech_all_features[target_latent_mask] + speech_connect_features = speech_all_connect_features[ + target_latent_mask + ] + else: + speech_features, speech_connect_features = self.forward_speech_features( + speech_tensors=( + speech_tensors.type_as(x) if speech_tensors is not None else None + ), + speech_masks=speech_masks, + speech_type=kwargs.get("speech_type", "audio"), + ) + if speech_tensors is not None: + x[acoustic_input_mask] = speech_connect_features + outputs = self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=x, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=False, + return_dict=return_dict, + cache_position=cache_position, + ) + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + pass + diffusion_loss = None + if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0: + condition_features = hidden_states[acoustic_loss_mask] + speech_len, latent_size = speech_features.shape + noise = torch.randn( + (speech_len * ddpm_batch_mul, latent_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + timesteps = torch.multinomial( + torch.ones(self.config.diffusion_head_config.ddpm_num_steps), + speech_len * ddpm_batch_mul, + replacement=True, + ).to(hidden_states.device) + speech_features_repeated = speech_features.repeat_interleave( + ddpm_batch_mul, dim=0 + ) + condition_features_repeated = condition_features.repeat_interleave( + ddpm_batch_mul, dim=0 + ) + noisy_speech_features = self.model.noise_scheduler.add_noise( + speech_features_repeated, noise, timesteps + ) + model_output = self.model.prediction_head( + noisy_speech_features, timesteps.type_as(x), condition_features_repeated + ) + prediction_type = self.config.diffusion_head_config.prediction_type + if prediction_type == "epsilon": + target_for_loss = noise + elif prediction_type == "v_prediction": + target_for_loss = self.model.noise_scheduler.get_velocity( + speech_features_repeated, noise, timesteps + ) + else: + raise NotImplementedError( + f"Prediction type {prediction_type } not implemented" + ) + diffusion_loss = F.mse_loss( + model_output.float(), target_for_loss.float(), reduction="sum" + ) + if latent_size > 0 and ddpm_batch_mul > 0: + diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul + else: + diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device) + else: + diffusion_loss = ( + sum((p.sum() for p in self.model.prediction_head.parameters())) * 0.0 + ) + diffusion_loss += ( + sum((p.sum() for p in self.model.acoustic_connector.parameters())) * 0.0 + ) + diffusion_loss += ( + sum((p.sum() for p in self.model.semantic_connector.parameters())) * 0.0 + ) + if not return_dict: + output = (logits, speech_len) + outputs.to_tuple()[1:] + return (loss, diffusion_loss) + output + return QWEN3VoxCausalLMOutputWithPast( + loss=loss, + diffusion_loss=diffusion_loss, + speech_token_num=speech_len if speech_tensors is not None else 0, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +AutoModel.register(QWEN3VoxConfig, QWEN3VoxModel) +AutoModelForCausalLM.register(QWEN3VoxConfig, QWEN3VoxForConditionalGeneration) +__all__ = [ + 'QWEN3VoxModel', + 'QWEN3VoxPreTrainedModel', + 'QWEN3VoxForConditionalGeneration', + 'QWEN3VoxCausalLMOutputWithPast', + 'QWEN3VoxGenerationOutput', +] +'\nQWEN3Vox Processors\n\nThis module provides processors for preparing inputs for QWEN3Vox models:\n- QWEN3VoxProcessor: For multi-speaker models (1.5B, 7B)\n- QWEN3VoxStreamingProcessor: For streaming model (0.5B)\n' +__all__ = [ + 'QWEN3VoxProcessor', + 'QWEN3VoxStreamingProcessor', + 'QWEN3VoxTokenizerProcessor', + "AudioNormalizer", + 'QWEN3VoxASRProcessor', +] +'\nQWEN3Vox Streaming Inference Model (0.5B)\n\nThis module implements the inference engine for real-time streaming TTS.\nKey features:\n- Window-based text/speech interleaving for streaming\n- Binary EOS classifier for end-of-speech detection\n- Classifier-free guidance for speech quality\n- Audio streaming support\n' +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union, Callable +from tqdm import tqdm +import torch +import torch.nn as nn +from transformers.models.auto import AutoModel, AutoModelForCausalLM +from transformers.generation import ( + GenerationMixin, + GenerationConfig, + LogitsProcessor, + LogitsProcessorList, + StoppingCriteriaList, +) +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput +from transformers import modeling_utils +from transformers.modeling_utils import PreTrainedModel +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.utils import logging + +logger = logging.get_logger(__name__) +if ( + not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") + or modeling_utils.ALL_PARALLEL_STYLES is None +): + modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"] +TTS_TEXT_WINDOW_SIZE = 5 +TTS_SPEECH_WINDOW_SIZE = 6 + + +def _update_model_kwargs_for_generation( + outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1 +) -> Dict[str, Any]: + model_kwargs["past_key_values"] = getattr(outputs, "past_key_values") + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [ + attention_mask, + attention_mask.new_ones((attention_mask.shape[0], num_new_tokens)), + ], + dim=-1, + ) + model_kwargs["cache_position"] = torch.arange( + model_kwargs["cache_position"][-1] + 1, + model_kwargs["cache_position"][-1] + num_new_tokens + 1, + ).to(model_kwargs["cache_position"].device) + return model_kwargs + + +@dataclass +class QWEN3VoxStreamingLMOutputWithPast(BaseModelOutputWithPast): + logits: Optional[torch.FloatTensor] = None + + +@dataclass +class QWEN3VoxGenerationOutput(ModelOutput): + sequences: torch.LongTensor = None + speech_outputs: Optional[List[torch.FloatTensor]] = None + reach_max_step_sample: Optional[torch.BoolTensor] = None + + +class QWEN3VoxStreamingForConditionalGenerationInference( + QWEN3VoxStreamingPreTrainedModel, GenerationMixin +): + + def __init__(self, config): + super().__init__(config) + self.model = QWEN3VoxStreamingModel(config) + self.tts_eos_classifier = BinaryClassifier(config.decoder_config.hidden_size) + self.ddpm_inference_steps = ( + config.diffusion_head_config.ddpm_num_inference_steps + ) + self.post_init() + + @property + def noise_scheduler(self): + return self.model.noise_scheduler + + @property + def prediction_head(self): + return self.model.prediction_head + + @property + def speech_scaling_factor(self): + return self.model.speech_scaling_factor + + @property + def speech_bias_factor(self): + return self.model.speech_bias_factor + + @property + def acoustic_tokenizer(self): + return self.model.acoustic_tokenizer + + @property + def acoustic_connector(self): + return self.model.acoustic_connector + + def tie_weights(self): + if not getattr(self.config, "tie_word_embeddings", False): + return + if hasattr(self, "lm_head") and hasattr( + self.model.language_model, "embed_tokens" + ): + self.lm_head.weight = self.model.language_model.embed_tokens.weight + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return None + + def set_output_embeddings(self, new_embeddings): + raise RuntimeError( + "Output embeddings (lm_head) are not defined for this model. Create one before calling set_output_embeddings if needed." + ) + + def set_speech_tokenizers(self, acoustic_tokenizer=None): + self.model.set_speech_tokenizers(acoustic_tokenizer) + + def set_ddpm_inference_steps(self, num_steps=None): + self.ddpm_inference_steps = ( + num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps + ) + + def forward_lm( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if inputs_embeds is None: + inputs_embeds = self.model.get_input_embeddings()(input_ids) + outputs = self.model.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state + if labels is not None: + raise NotImplementedError( + "Loss computation is not implemented in this version." + ) + return BaseModelOutputWithPast( + past_key_values=outputs.past_key_values, + last_hidden_state=hidden_states, + attentions=outputs.attentions, + ) + + def forward_tts_lm( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + lm_last_hidden_state: Optional[torch.FloatTensor] = None, + tts_text_masks: Optional[torch.BoolTensor] = None, + **kwargs, + ) -> Union[Tuple, QWEN3VoxStreamingLMOutputWithPast]: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if inputs_embeds is None: + inputs_embeds = self.model.get_input_embeddings()(input_ids) + start_idx = inputs_embeds.shape[1] - lm_last_hidden_state.shape[1] + inputs_embeds[:, start_idx:, :] = lm_last_hidden_state + inputs_embeds = inputs_embeds + self.model.tts_input_types( + tts_text_masks.long() + ) + outputs = self.model.tts_language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state + logits = self.tts_eos_classifier(hidden_states[:, -1, :]) + if labels is not None: + raise NotImplementedError( + "Loss computation is not implemented in this version." + ) + return QWEN3VoxStreamingLMOutputWithPast( + logits=logits, + past_key_values=outputs.past_key_values, + last_hidden_state=hidden_states, + attentions=outputs.attentions, + ) + + def forward(self, *args, **kwargs): + raise RuntimeError( + "Unified forward is disabled. Use `forward_lm`, `forward_tts_lm`, or `generate` instead." + ) + + def _build_generate_config_model_kwargs( + self, generation_config, inputs, tokenizer, return_processors=False, **kwargs + ): + if generation_config is None: + generation_config = GenerationConfig( + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + ) + else: + generation_config = GenerationConfig( + **generation_config, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + ) + generation_config, model_kwargs = self._prepare_generation_config( + generation_config, + True, + speech_start_id=tokenizer.speech_start_id, + speech_end_id=tokenizer.speech_end_id, + speech_diffusion_id=tokenizer.speech_diffusion_id, + **kwargs, + ) + generation_config.speech_start_id = tokenizer.speech_start_id + generation_config.speech_end_id = tokenizer.speech_end_id + generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + device = self.device + self._prepare_special_tokens(generation_config, True, device=device) + generation_config.use_cache = True + model_kwargs["use_cache"] = generation_config.use_cache + input_ids = inputs_tensor.to(self.device) + input_ids_length = input_ids.shape[1] + has_default_max_length = ( + kwargs.get("max_length") is None + and generation_config.max_length is not None + ) + has_default_min_length = ( + kwargs.get("min_length") is None + and generation_config.min_length is not None + ) + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) + max_cache_length = generation_config.max_length - 1 + self._prepare_cache_for_generation( + generation_config, model_kwargs, None, batch_size, max_cache_length, device + ) + model_kwargs["cache_position"] = torch.arange( + input_ids_length, device=device, dtype=torch.long + ) + for k, v in model_kwargs.items(): + if isinstance(v, torch.Tensor): + model_kwargs[k] = v.to(device=device) + if return_processors: + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=None, + logits_processor=LogitsProcessorList(), + device=inputs_tensor.device, + model_kwargs=model_kwargs, + ) + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, + stopping_criteria=StoppingCriteriaList(), + ) + return ( + generation_config, + model_kwargs, + input_ids, + logits_processor, + stopping_criteria, + ) + else: + return (generation_config, model_kwargs, input_ids) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + speech_tensors: Optional[torch.FloatTensor] = None, + speech_masks: Optional[torch.BoolTensor] = None, + speech_input_mask: Optional[torch.BoolTensor] = None, + tts_text_ids: Optional[torch.LongTensor] = None, + return_speech: bool = True, + cfg_scale: float = 1.0, + stop_check_fn: Optional[Callable[[], bool]] = None, + **kwargs, + ) -> Union[torch.LongTensor, QWEN3VoxGenerationOutput]: + tokenizer = kwargs.pop("tokenizer", None) + neg_text_input_id = tokenizer.convert_tokens_to_ids("<|image_pad|>") + tts_lm_input_ids = kwargs.pop("tts_lm_input_ids", None) + tts_lm_attention_mask = kwargs.pop("tts_lm_attention_mask", None) + all_prefilled_outputs = kwargs.pop("all_prefilled_outputs", None) + tts_text_ids = tts_text_ids.to(self.device) + if kwargs.get("max_new_tokens", None) is None: + kwargs["max_new_tokens"] = ( + self.config.decoder_config.max_position_embeddings + - tts_lm_input_ids.shape[-1] + ) + ( + generation_config, + model_kwargs, + input_ids, + logits_processor, + stopping_criteria, + ) = self._build_generate_config_model_kwargs( + generation_config, inputs, tokenizer, return_processors=True, **kwargs + ) + negative_kwargs = { + "input_ids": torch.full( + (kwargs["input_ids"].shape[0], 1), + neg_text_input_id, + dtype=torch.long, + device=kwargs["input_ids"].device, + ), + "attention_mask": torch.ones( + (kwargs["input_ids"].shape[0], 1), + dtype=torch.long, + device=kwargs["input_ids"].device, + ), + "max_new_tokens": kwargs.get("max_new_tokens", 100), + } + negative_generation_config, negative_model_kwargs, negative_input_ids = ( + self._build_generate_config_model_kwargs( + None, None, tokenizer, return_processors=False, **negative_kwargs + ) + ) + tts_lm_kwargs = { + "input_ids": tts_lm_input_ids, + "attention_mask": tts_lm_attention_mask, + "max_new_tokens": kwargs.get("max_new_tokens", 100), + } + tts_lm_generation_config, tts_lm_model_kwargs, tts_lm_input_ids = ( + self._build_generate_config_model_kwargs( + None, None, tokenizer, return_processors=False, **tts_lm_kwargs + ) + ) + tts_lm_negative_kwargs = { + "input_ids": torch.full( + (kwargs["input_ids"].shape[0], 1), + neg_text_input_id, + dtype=torch.long, + device=kwargs["input_ids"].device, + ), + "attention_mask": torch.ones( + (kwargs["input_ids"].shape[0], 1), + dtype=torch.long, + device=kwargs["input_ids"].device, + ), + "max_new_tokens": kwargs.get("max_new_tokens", 100), + } + ( + tts_lm_negative_generation_config, + tts_lm_negative_model_kwargs, + tts_lm_negative_input_ids, + ) = self._build_generate_config_model_kwargs( + None, None, tokenizer, return_processors=False, **tts_lm_negative_kwargs + ) + acoustic_cache = QWEN3VoxTokenizerStreamingCache() + batch_size = input_ids.shape[0] + assert batch_size == 1, "Currently only supports batch size == 1" + device = input_ids.device + finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device) + verbose = kwargs.get("verbose", False) + audio_chunks = [[] for _ in range(batch_size)] + tts_text_window_index = 0 + reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device) + first_text_window_size = ( + TTS_TEXT_WINDOW_SIZE + if tts_text_ids.shape[1] >= TTS_TEXT_WINDOW_SIZE + else tts_text_ids.shape[1] + ) + outputs = all_prefilled_outputs["lm"] + tts_lm_outputs = all_prefilled_outputs["tts_lm"] + negative_outputs = all_prefilled_outputs["neg_lm"] + tts_lm_negative_outputs = all_prefilled_outputs["neg_tts_lm"] + model_kwargs = _update_model_kwargs_for_generation( + outputs, model_kwargs, num_new_tokens=first_text_window_size + ) + tts_lm_model_kwargs = _update_model_kwargs_for_generation( + tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=first_text_window_size + ) + negative_model_kwargs = self._update_model_kwargs_for_generation( + negative_outputs, negative_model_kwargs, is_encoder_decoder=False + ) + tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation( + tts_lm_negative_outputs, + tts_lm_negative_model_kwargs, + is_encoder_decoder=False, + ) + step = tts_lm_input_ids.shape[1] + total_generated_speech_tokens = 0 + total_prefilled_text_tokens = 0 + if kwargs.get("show_progress_bar", True): + progress_bar = tqdm( + total=tts_lm_generation_config.max_length, + desc=f"Prefilled {step } tokens, current step ({step } / {tts_lm_generation_config .max_length })", + initial=step, + leave=False, + ) + else: + progress_bar = None + while True: + if stop_check_fn is not None and stop_check_fn(): + if verbose: + print(f"Generation stopped externally at step {step +1 }") + if audio_streamer is not None: + audio_streamer.end() + break + if finished_tags.all(): + if hasattr(progress_bar, "set_description"): + progress_bar.set_description("Generation complete") + break + cur_input_tts_text_ids = tts_text_ids[ + :, + tts_text_window_index + * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 1) + * TTS_TEXT_WINDOW_SIZE, + ] + next_text_window_size = tts_text_ids[ + :, + (tts_text_window_index + 1) + * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 2) + * TTS_TEXT_WINDOW_SIZE, + ].shape[1] + tts_text_window_index += 1 + if cur_input_tts_text_ids.shape[1] > 0: + input_ids = torch.cat([input_ids, cur_input_tts_text_ids], dim=-1) + tts_lm_input_ids = torch.cat( + [tts_lm_input_ids, cur_input_tts_text_ids], dim=-1 + ) + if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length: + if verbose: + print( + f"Reached maximum generation length {generation_config .max_length }, stopped it." + ) + reached_samples = torch.arange(batch_size, device=device)[ + ~finished_tags + ] + if reached_samples.numel() > 0: + reach_max_step_sample[reached_samples] = True + break + step += cur_input_tts_text_ids.shape[1] + total_prefilled_text_tokens += cur_input_tts_text_ids.shape[1] + if progress_bar is not None: + progress_bar.update(cur_input_tts_text_ids.shape[1]) + progress_bar.set_description( + f"Prefilled {total_prefilled_text_tokens } text tokens, generated {total_generated_speech_tokens } speech tokens, current step ({step } / {tts_lm_generation_config .max_length })" + ) + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs + ) + outputs = self.forward_lm( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + model_kwargs = _update_model_kwargs_for_generation( + outputs, model_kwargs, num_new_tokens=next_text_window_size + ) + tts_lm_model_inputs = self.prepare_inputs_for_generation( + tts_lm_input_ids, **tts_lm_model_kwargs + ) + tts_lm_additional_inputs = { + "tts_text_masks": torch.ones_like(tts_lm_input_ids[:, -1:]), + "lm_last_hidden_state": outputs.last_hidden_state, + } + tts_lm_outputs = self.forward_tts_lm( + **tts_lm_model_inputs, + **tts_lm_additional_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + tts_lm_model_kwargs = self._update_model_kwargs_for_generation( + tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False + ) + diffusion_indices = torch.LongTensor([0]) + for cur_speech_index in range(TTS_SPEECH_WINDOW_SIZE): + positive_condition = tts_lm_outputs.last_hidden_state[ + diffusion_indices, -1, : + ] + negative_condition = tts_lm_negative_outputs.last_hidden_state[ + diffusion_indices, -1, : + ] + speech_latent = self.sample_speech_tokens( + positive_condition, negative_condition, cfg_scale=cfg_scale + ).unsqueeze(1) + scaled_latent = speech_latent / self.model.speech_scaling_factor.to( + speech_latent.device + ) - self.model.speech_bias_factor.to(speech_latent.device) + audio_chunk = self.model.acoustic_tokenizer.decode( + scaled_latent.to(self.model.acoustic_tokenizer.device), + cache=acoustic_cache, + sample_indices=diffusion_indices.to( + self.model.acoustic_tokenizer.device + ), + use_cache=True, + debug=False, + ) + for i, sample_idx in enumerate(diffusion_indices): + idx = sample_idx.item() + if not finished_tags[idx]: + audio_chunks[idx].append(audio_chunk[i]) + if audio_streamer is not None: + audio_streamer.put(audio_chunk, diffusion_indices) + acoustic_embed = self.model.acoustic_connector(speech_latent) + tts_lm_input_ids = torch.cat( + [tts_lm_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], + dim=-1, + ) + if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length: + break + step += 1 + total_generated_speech_tokens += 1 + if progress_bar is not None: + progress_bar.update(1) + progress_bar.set_description( + f"Prefilled {total_prefilled_text_tokens } text tokens, generated {total_generated_speech_tokens } speech tokens, current step ({step } / {tts_lm_generation_config .max_length })" + ) + tts_lm_model_inputs = self.prepare_inputs_for_generation( + tts_lm_input_ids, **tts_lm_model_kwargs + ) + tts_lm_additional_inputs = { + "tts_text_masks": torch.zeros_like(tts_lm_input_ids[:, -1:]), + "lm_last_hidden_state": acoustic_embed, + } + tts_lm_outputs = self.forward_tts_lm( + **tts_lm_model_inputs, + **tts_lm_additional_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + if ( + cur_speech_index == TTS_SPEECH_WINDOW_SIZE - 1 + and next_text_window_size > 0 + ): + tts_lm_model_kwargs = _update_model_kwargs_for_generation( + tts_lm_outputs, + tts_lm_model_kwargs, + num_new_tokens=next_text_window_size, + ) + else: + tts_lm_model_kwargs = self._update_model_kwargs_for_generation( + tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False + ) + tts_lm_negative_input_ids = torch.cat( + [ + tts_lm_negative_input_ids, + torch.ones_like(tts_lm_input_ids[:, -1:]), + ], + dim=-1, + ) + tts_lm_negative_model_inputs = self.prepare_inputs_for_generation( + tts_lm_negative_input_ids, **tts_lm_negative_model_kwargs + ) + tts_lm_negative_additional_inputs = { + "tts_text_masks": torch.zeros_like( + tts_lm_negative_input_ids[:, -1:] + ), + "lm_last_hidden_state": acoustic_embed, + } + tts_lm_negative_outputs = self.forward_tts_lm( + **tts_lm_negative_model_inputs, + **tts_lm_negative_additional_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation( + tts_lm_negative_outputs, + tts_lm_negative_model_kwargs, + is_encoder_decoder=False, + ) + tts_eos_logits = torch.sigmoid( + self.tts_eos_classifier( + tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :] + ) + ) + if tts_eos_logits[0].item() > 0.5: + finished_tags[diffusion_indices] = True + if audio_streamer is not None: + audio_streamer.end(diffusion_indices) + if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length: + if verbose: + print( + f"Reached maximum generation length {tts_lm_generation_config .max_length }, stopped it." + ) + reached_samples = torch.arange(batch_size, device=device)[ + ~finished_tags + ] + if reached_samples.numel() > 0: + reach_max_step_sample[reached_samples] = True + break + if audio_streamer is not None: + audio_streamer.end() + final_audio_outputs = [] + for sample_chunks in audio_chunks: + if sample_chunks: + concatenated_audio = torch.cat(sample_chunks, dim=-1) + final_audio_outputs.append(concatenated_audio) + else: + final_audio_outputs.append(None) + if reach_max_step_sample is not None and reach_max_step_sample.any(): + print( + f"Reached maximum generation length {tts_lm_generation_config .max_length }, stopped it." + ) + return QWEN3VoxGenerationOutput( + sequences=tts_lm_input_ids, + speech_outputs=final_audio_outputs if return_speech else None, + reach_max_step_sample=reach_max_step_sample, + ) + + @torch.no_grad() + def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0): + self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps) + condition = torch.cat([condition, neg_condition], dim=0).to( + self.model.prediction_head.device + ) + speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to( + condition + ) + for t in self.model.noise_scheduler.timesteps: + half = speech[: len(speech) // 2] + combined = torch.cat([half, half], dim=0) + eps = self.model.prediction_head( + combined, t.repeat(combined.shape[0]).to(combined), condition=condition + ) + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample + return speech[: len(speech) // 2] + + +AutoModelForCausalLM.register( + QWEN3VoxStreamingConfig, QWEN3VoxStreamingForConditionalGenerationInference +) +__all__ = [ + 'QWEN3VoxStreamingForConditionalGenerationInference', + 'QWEN3VoxGenerationOutput', + 'QWEN3VoxCausalLMOutputWithPast', + "TTS_TEXT_WINDOW_SIZE", + "TTS_SPEECH_WINDOW_SIZE", +] +def _default_aux_shard_fp(repo_root: str) -> str: + return os.path.join(repo_root, "aux_lm_residual_projection.safetensors") + + +def _materialize_latent_prompt_embeddings( + blob_fp: str | os.PathLike[str], +) -> Dict[str, Any]: + import librosa + from safetensors import safe_open + + blob_fp = os.fspath(blob_fp) + with safe_open(blob_fp, framework="np") as f: + meta = f.metadata() + if not meta or _AUX_SLOT_MANIFEST_K not in meta: + raise ValueError( + "missing auxiliary slot manifest (not an LM projection safetensors shard)" + ) + try: + manifest = json.loads(meta[_AUX_SLOT_MANIFEST_K]) + stems_ordered: List[str] = list(manifest["order"]) + except (json.JSONDecodeError, KeyError, TypeError) as exc: + raise ValueError("corrupt auxiliary slot manifest") from exc + _tensor_names = set(f.keys()) + _hz_q: Dict[str, Any] = {} + for i, stem in enumerate(stems_ordered): + tk = _accum_tensor_key(i) + if tk not in _tensor_names: + raise ValueError(f"missing tensor payload for slot {i }: {tk }") + arr_u8 = f.get_tensor(tk) + raw = np.asarray(arr_u8, dtype=np.uint8).tobytes() + _arr_mono, _unused_sr = librosa.load(io.BytesIO(raw), sr=None, mono=True) + _hz_q[stem] = np.asarray(_arr_mono, dtype=np.float32) + return _hz_q + + +_MODEL_DIALOGUE_ROLE_MARK = "".join( + (chr(_o) for _o in (83, 112, 101, 97, 107, 101, 114)) +) +_COEFF_STAGE_SUBDIR = "".join(("vo", "ices")) + + +class _QxResidualFabric: + + def __init__( + self, + repo_root: str | os.PathLike[str], + *, + aux_projection_shard_fp: str | None = None, + skip_aux_shard: bool = False, + ): + self._repo_root = os.path.abspath(os.fspath(repo_root)) + self._discrete_coeff_root = os.path.join(self._repo_root, _COEFF_STAGE_SUBDIR) + self._r_handles: Dict[str, Union[str, np.ndarray]] = {} + self._fabric_refresh_handles( + aux_projection_shard_fp=aux_projection_shard_fp, + skip_aux_shard=skip_aux_shard, + ) + _alias_merge: Dict[str, Union[str, np.ndarray]] = {} + for _orig_stem, _binding in self._r_handles.items(): + _alias_merge[_orig_stem] = _binding + if "-" not in _orig_stem: + continue + _nick = _orig_stem.split("_", 1)[0] + _nick = _nick.split("-")[-1] + _alias_merge[_nick] = _binding + self._r_handles.update(_alias_merge) + + def _fabric_refresh_handles( + self, *, aux_projection_shard_fp: str | None, skip_aux_shard: bool + ) -> None: + self._r_handles.clear() + if skip_aux_shard: + _blob_fp = None + else: + _cli_blob = (aux_projection_shard_fp or "").strip() + _env_blob = os.environ.get("VV_AUX_PROJECTION_PATH") or "" + _candidates = [ + p + for p in (_cli_blob, _env_blob, _default_aux_shard_fp(self._repo_root)) + if p + ] + _blob_fp = next((p for p in _candidates if os.path.isfile(p)), None) + if _blob_fp: + try: + _latent_q = _materialize_latent_prompt_embeddings(_blob_fp) + except ValueError as _vx: + raise ValueError( + f"AUX shard assembly failed ({_blob_fp }): {_vx }" + ) from _vx + self._r_handles = dict(sorted(_latent_q.items())) + print( + f"Mounted auxiliary LM projection shard ({len (self ._r_handles )} tensors): {_blob_fp }" + ) + print(f"Residual routing keys: {', '.join (self ._r_handles .keys ())}") + return + if not os.path.exists(self._discrete_coeff_root): + print( + f"Warning: coefficient directory missing at {self ._discrete_coeff_root }" + ) + return + _wav_iter = [ + f + for f in os.listdir(self._discrete_coeff_root) + if f.lower().endswith(".wav") + and os.path.isfile(os.path.join(self._discrete_coeff_root, f)) + ] + for _wf in _wav_iter: + _stem = os.path.splitext(_wf)[0] + self._r_handles[_stem] = os.path.join(self._discrete_coeff_root, _wf) + self._r_handles = dict(sorted(self._r_handles.items())) + self._r_handles = { + k: v + for k, v in self._r_handles.items() + if isinstance(v, str) and os.path.exists(v) + } + self._r_handles = dict(sorted(self._r_handles.items())) + print( + f"Discrete coefficient files staged: {len (self ._r_handles )} under {self ._discrete_coeff_root }" + ) + print(f"Residual routing keys: {', '.join (self ._r_handles .keys ())}") + + def _fabric_pick_residual_snapshot( + self, shard_slice_query: str + ) -> Union[str, np.ndarray]: + if not self._r_handles: + raise ValueError( + f"No residual handles mounted. Add WAV files under {_COEFF_STAGE_SUBDIR }/ at the repo root, place aux_lm_residual_projection.safetensors next to config.json, or set VV_AUX_PROJECTION_PATH / VOCENCE_AUX_PROJECTION_SHARD." + ) + _binding, _used_key, _req_norm, _used_default = _resolve_aux_coeff_tensor( + self._r_handles, shard_slice_query + ) + if _used_default: + print( + f"Warning: auxiliary slice '{_req_norm }' not in shard; using default '{_used_key }'." + ) + return _binding + + +def _partition_lm_conditioning_manifest( + raw_manifest_txt: str, +) -> Tuple[List[str], List[str]]: + lines = raw_manifest_txt.strip().split("\n") + _serialized_turns: List[str] = [] + _routing_lane_ids: List[str] = [] + _lane_head_pat = ( + f"^{re.escape(_MODEL_DIALOGUE_ROLE_MARK)}\\s+(\\d+):\\s*(.*)$" + ) + _active_lane_id: str | None = None + _lane_payload_accum = "" + for line in lines: + line = line.strip() + if not line: + continue + match = re.match(_lane_head_pat, line, flags=re.IGNORECASE) + if match: + if _active_lane_id and _lane_payload_accum: + _serialized_turns.append( + f"{_MODEL_DIALOGUE_ROLE_MARK } {_active_lane_id }: {_lane_payload_accum .strip ()}" + ) + _routing_lane_ids.append(_active_lane_id) + _active_lane_id = match.group(1).strip() + _lane_payload_accum = match.group(2).strip() + elif _lane_payload_accum: + _lane_payload_accum += " " + line + else: + _lane_payload_accum = line + if _active_lane_id and _lane_payload_accum: + _serialized_turns.append( + f"{_MODEL_DIALOGUE_ROLE_MARK } {_active_lane_id }: {_lane_payload_accum .strip ()}" + ) + _routing_lane_ids.append(_active_lane_id) + return (_serialized_turns, _routing_lane_ids) + + +def _parse_instruction_params(instruction: str) -> Dict[str, str]: + params: Dict[str, str] = {} + for part in instruction.strip().strip("|").split("|"): + if ":" not in part: + continue + key, value = part.split(":", 1) + params[key.strip().lower()] = value.strip() + return params + + +# Vocence aux slice slugs: gender_pitch_speed_age_group_emotion_tone_accent +_SLICE_SLUG_FIELDS: Tuple[str, ...] = ( + "gender", + "pitch", + "speed", + "age_group", + "emotion", + "tone", + "accent", +) +# When the composed slug is missing from the shard, score candidates by field matches +# in this importance order (highest weight first). +_SLICE_MATCH_WEIGHT_ORDER: Tuple[str, ...] = ( + "gender", + "emotion", + "accent", + "speed", + "age_group", + "tone", + "pitch", +) +_STRUCTURED_PROSODY_KEYS = frozenset(_SLICE_SLUG_FIELDS) | frozenset({"age"}) +_SLICE_MATCH_WEIGHTS: Tuple[int, ...] = tuple( + 1 << (28 - i * 4) for i in range(len(_SLICE_MATCH_WEIGHT_ORDER)) +) + + +def _norm_prosody_token(s: str) -> str: + return s.strip().lower().replace(" ", "_") + + +def _parse_slice_slug(slice_id: str) -> Optional[Dict[str, str]]: + t = slice_id.strip() + if not t: + return None + parts = t.split("_") + if len(parts) != len(_SLICE_SLUG_FIELDS): + return None + return {f: _norm_prosody_token(p) for f, p in zip(_SLICE_SLUG_FIELDS, parts)} + + +def _attrs_to_slice_slug(attrs: Dict[str, str]) -> str: + return "_".join(_norm_prosody_token(attrs[f]) for f in _SLICE_SLUG_FIELDS) + + +def _default_slice_attrs() -> Dict[str, str]: + parsed = _parse_slice_slug(DEFAULT_AUX_SLICE_ID) + if parsed is not None: + return dict(parsed) + return {f: "" for f in _SLICE_SLUG_FIELDS} + + +def _instruction_has_structured_prosody(p: Dict[str, str]) -> bool: + for k in p: + lk = k.lower() + if lk == "age": + lk = "age_group" + if lk in _STRUCTURED_PROSODY_KEYS: + return True + return False + + +def _instruction_prosody_attrs(p: Dict[str, str]) -> Dict[str, str]: + out = _default_slice_attrs() + for k, v in p.items(): + if not v.strip(): + continue + lk = k.lower() + if lk == "age": + lk = "age_group" + if lk not in _SLICE_SLUG_FIELDS: + continue + out[lk] = _norm_prosody_token(v) + return out + + +def _pick_best_aux_slice_key( + desired_attrs: Dict[str, str], available_keys: AbstractSet[str] +) -> str: + desired_slug = _attrs_to_slice_slug(desired_attrs) + if desired_slug in available_keys: + return desired_slug + parsed: List[Tuple[str, Dict[str, str]]] = [] + for k in available_keys: + pd = _parse_slice_slug(k) + if pd is not None: + parsed.append((k, pd)) + if not parsed: + if available_keys: + return sorted(available_keys)[0] + return DEFAULT_AUX_SLICE_ID + + best_key: Optional[str] = None + best_score = -1 + for k, cattrs in parsed: + sc = 0 + for field, w in zip(_SLICE_MATCH_WEIGHT_ORDER, _SLICE_MATCH_WEIGHTS): + if desired_attrs.get(field) == cattrs.get(field): + sc += w + if sc > best_score or (sc == best_score and best_key is not None and k < best_key): + best_score = sc + best_key = k + assert best_key is not None + return best_key + + +def _prosody_shard_tags_for_lanes( + instruction: str, + unique_lanes: List[str], + *, + aux_slice_keys: Optional[AbstractSet[str]] = None, +) -> Dict[str, str]: + p = _parse_instruction_params(instruction) + if "prosody" in p or "shards" in p or "prosody_shards" in p: + raw = p.get("prosody") or p.get("shards") or p.get("prosody_shards") or "" + tags = [x.strip() for x in raw.split(",") if x.strip()] + elif "speakers" in p: + tags = [x.strip() for x in p["speakers"].split(",") if x.strip()] + elif p.get("voice") or p.get("speaker"): + tags = [(p.get("voice") or p.get("speaker") or "").strip()] + elif _instruction_has_structured_prosody(p): + merged = _instruction_prosody_attrs(p) + if aux_slice_keys: + tags = [_pick_best_aux_slice_key(merged, aux_slice_keys)] + else: + tags = [_attrs_to_slice_slug(merged)] + else: + tags = [DEFAULT_AUX_SLICE_ID] + if not tags: + tags = [DEFAULT_AUX_SLICE_ID] + n = len(unique_lanes) + while len(tags) < n: + tags.append(tags[-1]) + return {lane: tags[i] for i, lane in enumerate(unique_lanes)} + + +def _embed_nl_instruction_in_text(instruction: str, text: str) -> str: + """Fold Vocence NL ``/speak`` instruction into LM text for the plain-script path. + + Structured instructions (``gender: male | …``, ``| prosody: … |``, etc.) keep + ``text`` as transcript-only; aux routing handles those via ``instruction``. + Natural-language instructions (validator default) become + `` {transcript}`` so training and inference match. + """ + instruction = (instruction or "").strip() + text = (text or "").strip() + if not text: + return text + if ' {text}' + + +def _manifest_from_text(text: str) -> str: + stripped = text.strip() + if re.search("^Speaker\\s+\\d+:", stripped, re.MULTILINE | re.IGNORECASE): + return stripped + return f"Speaker 1: {stripped }" + + +def _build_prefill_slices( + fabric: _QxResidualFabric, + routing_lane_ids: List[str], + lane_to_slice_tag: Dict[str, str], +) -> List[Union[str, np.ndarray]]: + unique_lanes: List[str] = [] + seen: set[str] = set() + for lane in routing_lane_ids: + if lane not in seen: + unique_lanes.append(lane) + seen.add(lane) + out: List[Union[str, np.ndarray]] = [] + for lane in unique_lanes: + slice_tag = lane_to_slice_tag.get(lane, f"lane_{lane }") + out.append(fabric._fabric_pick_residual_snapshot(slice_tag)) + return out + + +class Miner: + + REPO_SENTINEL = "config.json" + SETTINGS_FILE = "vocence_config.yaml" + WARMUP_TIMEOUT = 240.0 + + def __init__(self, path_hf_repo: Path) -> None: + self.root = Path(path_hf_repo).resolve() + if not (self.root / self.REPO_SENTINEL).is_file(): + raise FileNotFoundError( + f"{self.REPO_SENTINEL} not present in {self.root}" + ) + _repo_root = str(self.root) + aux_cli = os.environ.get("VOCENCE_AUX_PROJECTION_SHARD", "").strip() + prefer_discrete = os.environ.get( + "VOCENCE_PREFER_DISCRETE_COEFF_DIR", "" + ).lower() in ("1", "true", "yes") + self._fabric_q = _QxResidualFabric( + _repo_root, + aux_projection_shard_fp=aux_cli or None, + skip_aux_shard=prefer_discrete, + ) + if not self._fabric_q._r_handles: + raise RuntimeError( + "No auxiliary conditioning handles mounted in repo; set VV_AUX_PROJECTION_PATH / VOCENCE_AUX_PROJECTION_SHARD, or ship aux_lm_residual_projection.safetensors at the repo root." + ) + seed_s = os.environ.get("VOCENCE_SEED", "").strip() + if seed_s: + s = int(seed_s) + torch.manual_seed(s) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(s) + self._cfg_scale = float(os.environ.get("VOCENCE_CFG_SCALE", "1.3")) + self._disable_prefill = os.environ.get( + "VOCENCE_DISABLE_PREFILL", "" + ).lower() in ("1", "true", "yes") + cfg_raw = self._load_yaml(self.root / self.SETTINGS_FILE) + self._model_name = str(cfg_raw.get("model_name") or "").strip() + if not self._model_name: + raise RuntimeError("vocence_config.yaml must declare model_name for Vocence deployment.") + self._processor: Optional[QWEN3VoxProcessor] = None + self._device: str = "cpu" + self._sample_rate: int = 22050 + + def __repr__(self) -> str: + return f"" + + @cached_property + def settings(self) -> SimpleNamespace: + raw = self._load_yaml(self.root / self.SETTINGS_FILE) + rt = raw.get("runtime") or {} + gen = raw.get("generation") or {} + lim = raw.get("limits") or {} + return SimpleNamespace( + language=str( + lim.get("default_language") + or rt.get("default_language") + or "English" + ), + sample_rate=int(gen.get("sample_rate", 24000)), + max_instruction_chars=int(lim.get("max_instruction_chars", 600)), + max_text_chars=int(lim.get("max_text_chars", 2000)), + prefer_cuda=str(rt.get("device_preference", "cuda")).lower() == "cuda", + prefer_bf16=str(rt.get("dtype", "bfloat16")).lower() == "bfloat16", + prefer_flash=bool(rt.get("use_flash_attention_2", False)), + ) + + @cached_property + def model(self) -> QWEN3VoxForConditionalGenerationInference: + return self._instantiate_engine() + + def _instantiate_engine(self) -> QWEN3VoxForConditionalGenerationInference: + s = self.settings + model_name = self._model_name + if s.prefer_cuda and torch.cuda.is_available(): + self._device = "cuda" + elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): + self._device = "mps" + else: + self._device = "cpu" + + if self._device == "mps": + load_dtype = torch.float32 + attn_attempts = ("sdpa",) + elif self._device == "cuda": + load_dtype = ( + torch.bfloat16 if s.prefer_bf16 else torch.float32 + ) + attn_attempts = ( + ("flash_attention_2", "sdpa") + if s.prefer_flash + else ("sdpa", "flash_attention_2") + ) + else: + load_dtype = torch.float32 + attn_attempts = ("sdpa",) + + self._processor = QWEN3VoxProcessor.from_pretrained(model_name) + last_failure: Optional[BaseException] = None + engine: Optional[QWEN3VoxForConditionalGenerationInference] = None + for attn_impl in attn_attempts: + try: + engine = self._load_model_weights(model_name, load_dtype, attn_impl) + dtype_tag = "bf16" if load_dtype is torch.bfloat16 else "fp32" + print( + f"[Miner] QWEN3Vox ready :: device={self._device} " + f"dtype={dtype_tag} attn={attn_impl}" + ) + break + except Exception as exc: + last_failure = exc + if engine is None: + raise RuntimeError(f"QWEN3Vox failed to load :: {last_failure!r}") + + engine.train(False) + engine.set_ddpm_inference_steps(num_steps=10) + proc_sr = int( + getattr(self._processor.audio_processor, "sampling_rate", 22050) + ) + self._sample_rate = proc_sr if proc_sr > 0 else s.sample_rate + return engine + + def _load_model_weights( + self, model_name: str, load_dtype: torch.dtype, attn_impl: str + ) -> QWEN3VoxForConditionalGenerationInference: + if self._device == "mps": + m = QWEN3VoxForConditionalGenerationInference.from_pretrained( + model_name, + torch_dtype=load_dtype, + attn_implementation=attn_impl, + device_map=None, + ) + m.to("mps") + return m + if self._device == "cuda": + return QWEN3VoxForConditionalGenerationInference.from_pretrained( + model_name, + torch_dtype=load_dtype, + device_map="cuda", + attn_implementation=attn_impl, + ) + return QWEN3VoxForConditionalGenerationInference.from_pretrained( + model_name, + torch_dtype=load_dtype, + device_map="cpu", + attn_implementation=attn_impl, + ) + + def warmup(self) -> None: + outcome: dict[str, Any] = {"done": False, "err": None} + + def _trial() -> None: + try: + self.generate_wav( + instruction=f"| prosody: {DEFAULT_AUX_SLICE_ID } |", + text="This is a warmup utterance for the QWEN3Vox engine.", + ) + outcome["done"] = True + except Exception as exc: + outcome["err"] = repr(exc) + + worker = threading.Thread(target=_trial, daemon=True) + worker.start() + worker.join(timeout=self.WARMUP_TIMEOUT) + if not outcome["done"]: + raise RuntimeError( + f"warmup did not complete within {self.WARMUP_TIMEOUT}s: " + f"{outcome['err'] or 'no completion signal'}" + ) + + @staticmethod + def _load_yaml(path: Path) -> dict[str, Any]: + if not path.is_file(): + return {} + from yaml import safe_load + + with path.open("r", encoding="utf-8") as fh: + return safe_load(fh) or {} + + def _speech_tensor_to_numpy(self, speech: torch.Tensor) -> np.ndarray: + t = speech.detach().cpu().float() + while t.dim() > 1: + t = t.squeeze(0) + if t.dim() != 1: + t = t.reshape(-1) + return t.numpy().astype(np.float32, copy=False) + + def generate_wav(self, instruction: str, text: str) -> Tuple[np.ndarray, int]: + s = self.settings + if s.max_instruction_chars > 0: + instruction = instruction[: s.max_instruction_chars] + if s.max_text_chars > 0: + text = text[: s.max_text_chars] + inference_model = self.model + processor = self._processor + if processor is None: + raise RuntimeError("processor not initialized after model load") + lm_text = _embed_nl_instruction_in_text(instruction, text) + manifest = _manifest_from_text(lm_text) + serialized_turns, routing_lane_ids = _partition_lm_conditioning_manifest( + manifest + ) + if not serialized_turns: + raise ValueError("No parsable LM conditioning spans in text.") + unique_lanes: List[str] = [] + seen: set[str] = set() + for lane in routing_lane_ids: + if lane not in seen: + unique_lanes.append(lane) + seen.add(lane) + lane_to_slice = _prosody_shard_tags_for_lanes( + instruction, + unique_lanes, + aux_slice_keys=self._fabric_q._r_handles.keys(), + ) + prefill_slices = _build_prefill_slices( + self._fabric_q, routing_lane_ids, lane_to_slice + ) + full_script = "\n".join(serialized_turns) + full_script = full_script.replace("’", "'") + inputs = processor( + text=[full_script], + voice_samples=[prefill_slices], + padding=True, + return_tensors="pt", + return_attention_mask=True, + ) + target = self._device if self._device != "cpu" else "cpu" + for k, v in inputs.items(): + if torch.is_tensor(v): + inputs[k] = v.to(target) + with torch.inference_mode(): + outputs = inference_model.generate( + **inputs, + max_new_tokens=None, + cfg_scale=self._cfg_scale, + tokenizer=processor.tokenizer, + generation_config={"do_sample": False}, + verbose=False, + is_prefill=not self._disable_prefill, + ) + if not outputs.speech_outputs or outputs.speech_outputs[0] is None: + raise RuntimeError("QWEN3Vox returned no speech output.") + wav = self._speech_tensor_to_numpy(outputs.speech_outputs[0]) + return (wav, self._sample_rate)