| from __future__ import annotations |
|
|
|
|
| import io |
| import json |
| import os |
| import re |
| import sys |
| import threading |
| import traceback |
| from pathlib import Path |
| from typing import AbstractSet, Any, Dict, List, Optional, Sequence, Tuple, Union |
| import numpy as np |
| import torch |
| from transformers.utils import logging as hf_logging |
| import math |
| import random |
| import warnings |
| from dataclasses import dataclass |
|
|
| try: |
| import librosa |
| except Exception: |
| librosa = None |
| try: |
| import resampy |
| except Exception: |
| resampy = None |
|
|
|
|
| def _resample_if_needed(wav: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray: |
| if orig_sr == target_sr: |
| return wav.astype(np.float32, copy=False) |
| if resampy is not None: |
| return resampy.resample(wav.astype(np.float32), orig_sr, target_sr) |
| if librosa is not None: |
| return librosa.resample( |
| y=wav.astype(np.float32), orig_sr=orig_sr, target_sr=target_sr |
| ) |
| warnings.warn( |
| "No resampler available; treating audio as target_sr without resampling. Install resampy or librosa.", |
| RuntimeWarning, |
| ) |
| return wav.astype(np.float32, copy=False) |
|
|
|
|
class QWEN3VoxDataset:
    """Map-style wrapper yielding ``{"text", "audio", "voice_prompts"}`` dicts.

    ``dataset`` must support ``len()`` and integer indexing, and its items must
    be mappings keyed by the configured column names. When an item carries no
    usable voice prompt, a random crop of the item's own target audio
    (loaded at ``_PROMPT_SAMPLE_RATE``) is used as the prompt instead.
    """

    # Sample rate at which a fallback prompt is cut from the target audio.
    _PROMPT_SAMPLE_RATE = 22050

    def __init__(
        self,
        dataset: Any,
        text_column: str = "text",
        audio_column: str = "audio",
        voice_prompts_column: Optional[str] = "voice_prompts",
    ) -> None:
        self.dataset = dataset
        self.text_column = text_column
        self.audio_column = audio_column
        self.voice_prompts_column = voice_prompts_column

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return example ``idx`` with ``voice_prompts`` normalized to a list.

        ``voice_prompts`` is the user-provided prompt(s) when present and
        non-empty (a non-list prompt is wrapped in a one-element list). Otherwise
        it is a one-element list holding a random crop of the target audio, or
        ``None`` when no crop could be produced.
        """
        item = self.dataset[idx]
        data: Dict[str, Any] = {
            "text": item[self.text_column],
            "audio": item[self.audio_column],
        }
        user_provided_prompt = None
        if self.voice_prompts_column and self.voice_prompts_column in item:
            user_provided_prompt = item[self.voice_prompts_column]
        if self._is_missing_prompt(user_provided_prompt):
            data["voice_prompts"] = self._random_prompt_from_target(
                item[self.audio_column], idx
            )
        elif isinstance(user_provided_prompt, list):
            data["voice_prompts"] = user_provided_prompt
        else:
            data["voice_prompts"] = [user_provided_prompt]
        return data

    @staticmethod
    def _is_missing_prompt(prompt: Any) -> bool:
        """Return True when ``prompt`` is ``None`` or empty.

        Arrays and tensors are checked by element count because calling
        ``bool()`` on a multi-element array/tensor raises instead of answering.
        """
        if prompt is None:
            return True
        if isinstance(prompt, np.ndarray):
            return prompt.size == 0
        if isinstance(prompt, torch.Tensor):
            return prompt.numel() == 0
        return not prompt

    def _random_prompt_from_target(
        self, audio: Any, idx: int
    ) -> Optional[List[np.ndarray]]:
        """Cut a random prompt from ``audio``; ``None`` if too short or on failure.

        The crop length is drawn uniformly between ``min(5 s, clip/4)`` and
        ``min(15 s, clip/2)``; clips whose upper bound is at most 0.1 s get no
        prompt. Failures are reported as warnings (best effort), not raised.
        """
        target_sr = self._PROMPT_SAMPLE_RATE
        try:
            wav_array = _load_audio_to_24k(audio, target_sr=target_sr)
            audio_len_seconds = len(wav_array) / target_sr
            max_len_sec = min(15.0, audio_len_seconds / 2.0)
            # Lower bound never exceeds the upper bound.
            min_len_sec = min(5.0, audio_len_seconds / 4.0, max_len_sec)
            if max_len_sec <= 0.1:
                return None
            prompt_len_sec = random.uniform(min_len_sec, max_len_sec)
            prompt_len_samples = int(prompt_len_sec * target_sr)
            # prompt_len_samples <= len/2, so the start range is never empty.
            start_sample = random.randint(0, len(wav_array) - prompt_len_samples)
            return [wav_array[start_sample : start_sample + prompt_len_samples]]
        except Exception as e:
            warnings.warn(f"Could not create voice prompt for item {idx}: {e}")
            return None
|
|
|
|
| def _apply_silence_with_crossfade( |
| wav: np.ndarray, |
| *, |
| sample_rate: int, |
| pre_silence_sec: float = 0.25, |
| pre_crossfade_sec: float = 0.25, |
| post_crossfade_sec: float = 0.25, |
| post_silence_sec: float = 0.75, |
| ) -> np.ndarray: |
| wav = np.asarray(wav, dtype=np.float32).reshape(-1) |
| start_sil_samples = int(round(pre_silence_sec * sample_rate)) |
| end_sil_samples = int(round(post_silence_sec * sample_rate)) |
| pre_crossfade_samples = int(round(pre_crossfade_sec * sample_rate)) |
| post_crossfade_samples = int(round(post_crossfade_sec * sample_rate)) |
| total_len = wav.shape[0] |
| if total_len == 0: |
| pieces: List[np.ndarray] = [] |
| if start_sil_samples > 0: |
| pieces.append(np.zeros(start_sil_samples, dtype=np.float32)) |
| if end_sil_samples > 0: |
| pieces.append(np.zeros(end_sil_samples, dtype=np.float32)) |
| return np.concatenate(pieces) if pieces else wav |
| start_len = min(pre_crossfade_samples, total_len) |
| remaining_after_start = max(total_len - start_len, 0) |
| end_len = min(post_crossfade_samples, remaining_after_start) |
| middle_end_idx = total_len - end_len |
| start_segment = wav[:start_len] |
| middle_segment = wav[start_len:middle_end_idx] |
| end_segment = wav[middle_end_idx:] |
|
|
| def _linear_fade(num_samples: int, start: float, end: float) -> np.ndarray: |
| if num_samples <= 0: |
| return np.zeros((0,), dtype=np.float32) |
| return np.linspace(start, end, num_samples, endpoint=True, dtype=np.float32) |
|
|
| start_crossfade = start_segment * _linear_fade(start_len, 0.0, 1.0) |
| end_crossfade = end_segment * _linear_fade(end_segment.shape[0], 1.0, 0.0) |
| pieces: List[np.ndarray] = [] |
| if start_sil_samples > 0: |
| pieces.append(np.zeros(start_sil_samples, dtype=np.float32)) |
| if start_crossfade.size > 0: |
| pieces.append(start_crossfade.astype(np.float32, copy=False)) |
| if middle_segment.size > 0: |
| pieces.append(middle_segment.astype(np.float32, copy=False)) |
| if end_crossfade.size > 0: |
| pieces.append(end_crossfade.astype(np.float32, copy=False)) |
| if end_sil_samples > 0: |
| pieces.append(np.zeros(end_sil_samples, dtype=np.float32)) |
| return np.concatenate(pieces) |
|
|
|
|
| def _load_audio_to_24k( |
| audio: Union[str, np.ndarray, torch.Tensor, Dict[str, Any]], |
| *, |
| target_sr: int = 22050, |
| augment_with_silence: bool = False, |
| ) -> np.ndarray: |
| if isinstance(audio, np.ndarray): |
| wav_out = audio.astype(np.float32) |
| elif isinstance(audio, torch.Tensor): |
| wav_out = audio.detach().cpu().float().numpy() |
| elif isinstance(audio, str): |
| if librosa is None: |
| raise RuntimeError( |
| "librosa is required to load audio file paths. Please pip install librosa." |
| ) |
| wav, sr = librosa.load(audio, sr=None, mono=True) |
| wav_out = _resample_if_needed(wav, int(sr), target_sr) |
| elif isinstance(audio, dict) and "array" in audio and ("sampling_rate" in audio): |
| arr = np.asarray(audio["array"], dtype=np.float32) |
| sr = int(audio["sampling_rate"]) |
| wav_out = _resample_if_needed(arr, sr, target_sr) |
| else: |
| raise ValueError(f"Unsupported audio type: {type (audio )}") |
| wav_out = np.asarray(wav_out, dtype=np.float32) |
| if augment_with_silence: |
| wav_out = _apply_silence_with_crossfade(wav_out, sample_rate=target_sr) |
| return wav_out |
|
|
|
|
@dataclass
class QWEN3VoxCollator:
    """Collate QWEN3Vox text-to-speech examples into a padded training batch.

    Per example, the processor tokenizes the text (plus the voice prompts,
    unless they are randomly dropped). The collator then appends:

      * one ``speech_diffusion_id`` placeholder per latent frame of the target audio;
      * ``speech_end_id``;
      * an EOS id, when the tokenizer has one.

    It also builds acoustic input/loss masks aligned with those ids.

    Padding works as follows:

      * Token sequences are right-padded.
      * All speech segments of the batch are zero-padded into one
        ``speech_tensors`` array. Within each example the voice-prompt
        segments come first, then the target.

    Returned keys: ``input_ids``, ``attention_mask``, ``speech_tensors``,
    ``speech_masks``, ``speech_semantic_tensors``, ``acoustic_input_mask``,
    ``acoustic_loss_mask``, ``speeches_loss_input``.
    """

    # Callable processor. Its ``.tokenizer`` must expose ``speech_diffusion_id``
    # and ``speech_end_id``; ``acoustic_tokenizer`` / ``semantic_tokenizer`` are
    # looked up optionally.
    processor: Any
    # Passed to the processor and used to left-truncate the extended sequence.
    max_length: Optional[int] = None
    # Audio samples per latent frame; used when the acoustic tokenizer cannot
    # report the target's latent length.
    speech_compress_ratio: int = 3200
    # Width every semantic feature matrix is padded/truncated to.
    semantic_vae_dim: int = 128
    # Must be True (with a processor.semantic_tokenizer) for collation to succeed.
    compute_semantics: bool = False
    # Enables cheap sanity asserts on the output tensors.
    debug_checks: bool = False
    text_field: str = "text"
    audio_field: str = "audio"
    voice_prompts_field: str = "voice_prompts"
    # Probability of ignoring an example's voice prompts; clamped to [0, 1].
    voice_prompt_drop_rate: float = 0.0

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate ``features`` (mappings keyed by the ``*_field`` names).

        Raises:
            ValueError: on an empty batch, when ``max_length`` would cut into
                acoustic tokens, or when no pad/EOS id is available.
            RuntimeError: when semantic features cannot be computed.
        """
        if len(features) == 0:
            raise ValueError("Cannot collate an empty batch: features is empty.")
        # Clamp once per batch instead of once per example.
        drop_rate = min(max(self.voice_prompt_drop_rate, 0.0), 1.0)

        sample_input_ids: List[List[int]] = []
        sample_attention_masks: List[List[int]] = []
        sample_acoustic_input_masks: List[List[bool]] = []
        sample_acoustic_loss_masks: List[List[bool]] = []
        all_speech_waveforms: List[np.ndarray] = []
        all_speech_latent_lengths: List[int] = []
        per_segment_is_target: List[bool] = []

        for ex in features:
            (
                ids,
                attn,
                input_mask,
                loss_mask,
                prompt_segments,
                wav_target,
                target_latent_len,
            ) = self._encode_example(ex, drop_rate)
            sample_input_ids.append(ids)
            sample_attention_masks.append(attn)
            sample_acoustic_input_masks.append(input_mask)
            sample_acoustic_loss_masks.append(loss_mask)
            for prompt_wav, prompt_latent_len in prompt_segments:
                all_speech_waveforms.append(prompt_wav)
                all_speech_latent_lengths.append(prompt_latent_len)
                per_segment_is_target.append(False)
            all_speech_waveforms.append(wav_target)
            all_speech_latent_lengths.append(target_latent_len)
            per_segment_is_target.append(True)

        (
            input_ids_tensor,
            attention_mask_tensor,
            acoustic_input_mask_tensor,
            acoustic_loss_mask_tensor,
        ) = self._pad_token_sequences(
            sample_input_ids,
            sample_attention_masks,
            sample_acoustic_input_masks,
            sample_acoustic_loss_masks,
        )
        # Every example contributes its target waveform, so this list is non-empty.
        (
            speech_tensors_tensor,
            speech_masks_tensor,
            speeches_loss_input_tensor,
            speech_semantic_tensors,
        ) = self._build_speech_tensors(
            all_speech_waveforms, all_speech_latent_lengths, per_segment_is_target
        )

        if self.debug_checks:
            assert (input_ids_tensor >= 0).all(), "input_ids contains negative indices"
            assert (
                speech_tensors_tensor.dim() == 2
            ), "Expected speech_tensors 2D [segments, samples]"
        return {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_tensor,
            "speech_tensors": speech_tensors_tensor,
            "speech_masks": speech_masks_tensor,
            "speech_semantic_tensors": speech_semantic_tensors,
            "acoustic_input_mask": acoustic_input_mask_tensor,
            "acoustic_loss_mask": acoustic_loss_mask_tensor,
            "speeches_loss_input": speeches_loss_input_tensor,
        }

    def _encode_example(
        self, ex: Dict[str, Any], drop_rate: float
    ) -> Tuple[
        List[int], List[int], List[bool], List[bool],
        List[Tuple[np.ndarray, int]], np.ndarray, int,
    ]:
        """Tokenize one example and extend it with target-speech placeholders.

        Returns ``(ids, attention, acoustic_input_mask, acoustic_loss_mask,
        prompt_segments, wav_target, target_latent_len)``, where
        ``prompt_segments`` holds ``(waveform, latent_len)`` per voice prompt.
        """
        text: str = ex.get(self.text_field, "")
        voice_prompts = ex.get(self.voice_prompts_field)
        target_audio = ex.get(self.audio_field)
        # random.random() is only drawn when prompts exist (short-circuit).
        keep_prompts = voice_prompts is not None and random.random() >= drop_rate
        proc = self.processor(
            text=[text],
            voice_samples=[voice_prompts] if keep_prompts else None,
            padding=False,
            truncation=False,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = proc["input_ids"]
        ids = input_ids[0].tolist()
        attn = proc.get("attention_mask", torch.ones_like(input_ids))[0].tolist()
        speech_input_mask = proc.get("speech_input_mask")
        if speech_input_mask is None:
            speech_input_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        speech_input_mask_list = speech_input_mask[0].tolist()

        wav_target = _load_audio_to_24k(
            target_audio, target_sr=22050, augment_with_silence=True
        )
        target_latent_len = self._target_latent_len(wav_target)

        tokenizer = self.processor.tokenizer
        ids_ext = ids + [tokenizer.speech_diffusion_id] * target_latent_len
        attn_ext = attn + [1] * target_latent_len
        input_mask = speech_input_mask_list + [True] * target_latent_len
        # Only the appended target placeholders contribute to the acoustic loss.
        loss_mask = [False] * len(speech_input_mask_list) + [True] * target_latent_len

        ids_ext.append(tokenizer.speech_end_id)
        attn_ext.append(1)
        input_mask.append(False)
        loss_mask.append(False)
        eos_token_id = getattr(tokenizer, "eos_id", None)
        if eos_token_id is None:
            eos_token_id = getattr(tokenizer, "eos_token_id", None)
        if eos_token_id is not None and eos_token_id >= 0:
            ids_ext.append(eos_token_id)
            attn_ext.append(1)
            input_mask.append(False)
            loss_mask.append(False)

        if self.max_length is not None and len(ids_ext) > self.max_length:
            cut = len(ids_ext) - int(self.max_length)
            # Only tokens before the first acoustic position may be dropped.
            leading_non_acoustic = next(
                (pos for pos, flag in enumerate(input_mask) if flag), len(input_mask)
            )
            if cut > leading_non_acoustic:
                raise ValueError(
                    f"--max_length={self.max_length} would truncate into acoustic tokens. Needed cut={cut}, but only {leading_non_acoustic} leading non-acoustic tokens available. Increase max_length or shorten text/voice-prompt preamble."
                )
            ids_ext = ids_ext[cut:]
            attn_ext = attn_ext[cut:]
            input_mask = input_mask[cut:]
            loss_mask = loss_mask[cut:]

        prompt_segments: List[Tuple[np.ndarray, int]] = []
        if proc.get("speech_tensors") is not None:
            voice_np = proc["speech_tensors"].cpu().numpy()
            voice_masks = proc["speech_masks"].cpu().numpy().astype(bool)
            for seg_idx in range(voice_np.shape[0]):
                prompt_segments.append(
                    (voice_np[seg_idx], int(voice_masks[seg_idx].sum()))
                )
        return (
            ids_ext,
            attn_ext,
            input_mask,
            loss_mask,
            prompt_segments,
            wav_target,
            target_latent_len,
        )

    def _target_latent_len(self, wav_target: np.ndarray) -> int:
        """Latent frame count of the target audio (always >= 1).

        Uses the leading dimension of ``processor.acoustic_tokenizer.encode``
        when available and positive; otherwise falls back to
        ``ceil(len(wav_target) / speech_compress_ratio)``.
        """
        target_latent_len = None
        try:
            acoustic_tok = getattr(self.processor, "acoustic_tokenizer", None)
            if acoustic_tok is not None and hasattr(acoustic_tok, "encode"):
                frames = self._leading_dim(acoustic_tok.encode(wav_target))
                if frames is not None and frames > 0:
                    target_latent_len = frames
        except Exception:
            # Best effort: any tokenizer failure falls back to the ratio estimate.
            target_latent_len = None
        if target_latent_len is None:
            target_latent_len = max(
                1, int(math.ceil(len(wav_target) / float(self.speech_compress_ratio)))
            )
        return target_latent_len

    @staticmethod
    def _leading_dim(enc_out: Any) -> Optional[int]:
        """Size of axis 0 of ``enc_out``, unwrapping up to two list/tuple levels."""
        try:
            if hasattr(enc_out, "shape") and len(getattr(enc_out, "shape", [])) >= 1:
                return int(enc_out.shape[0])
            cand = enc_out
            for _ in range(2):
                if isinstance(cand, (list, tuple)) and len(cand) > 0:
                    cand = cand[0]
            if hasattr(cand, "shape") and len(getattr(cand, "shape", [])) >= 1:
                return int(cand.shape[0])
        except Exception:
            return None
        return None

    def _resolve_pad_token_id(self) -> int:
        """Pad id: tokenizer ``pad_token_id``, else ``eos_token_id`` (must be >= 0)."""
        tok = self.processor.tokenizer
        pad_token_id = getattr(tok, "pad_token_id", None)
        if pad_token_id is None or pad_token_id < 0:
            pad_token_id = getattr(tok, "eos_token_id", None)
        if pad_token_id is None or pad_token_id < 0:
            raise ValueError(
                "Tokenizer has no pad_token_id or eos_token_id; please set one or pass a valid pad id."
            )
        return pad_token_id

    def _pad_token_sequences(
        self,
        input_ids: List[List[int]],
        attention_masks: List[List[int]],
        input_masks: List[List[bool]],
        loss_masks: List[List[bool]],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Right-pad the per-example lists to the longest ``input_ids`` and tensorize."""
        max_seq_len = max(len(x) for x in input_ids)
        pad_token_id = self._resolve_pad_token_id()
        padded_ids, padded_attn, padded_in, padded_loss = [], [], [], []
        for ids, attn, in_mask, loss_mask in zip(
            input_ids, attention_masks, input_masks, loss_masks
        ):
            pad_len = max_seq_len - len(ids)
            padded_ids.append(ids + [pad_token_id] * pad_len)
            padded_attn.append(attn + [0] * pad_len)
            padded_in.append(in_mask + [False] * pad_len)
            padded_loss.append(loss_mask + [False] * pad_len)
        return (
            torch.tensor(padded_ids, dtype=torch.long),
            torch.tensor(padded_attn, dtype=torch.long),
            torch.tensor(padded_in, dtype=torch.bool),
            torch.tensor(padded_loss, dtype=torch.bool),
        )

    def _build_speech_tensors(
        self,
        waveforms: List[np.ndarray],
        latent_lengths: List[int],
        is_target: List[bool],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Zero-pad speech segments and build their latent masks and semantic features.

        ``speeches_loss_input`` copies the latent mask for target segments only.
        """
        max_wave_len = max(w.shape[0] for w in waveforms)
        padded_speeches = np.zeros((len(waveforms), max_wave_len), dtype=np.float32)
        for i, w in enumerate(waveforms):
            padded_speeches[i, : w.shape[0]] = w
        max_latent_len = max(latent_lengths) if latent_lengths else 1
        speech_masks_np = np.zeros((len(waveforms), max_latent_len), dtype=np.bool_)
        for i, latent_len in enumerate(latent_lengths):
            speech_masks_np[i, :latent_len] = True
        speeches_loss_input_np = np.zeros_like(speech_masks_np, dtype=np.bool_)
        for i, target in enumerate(is_target):
            if target:
                speeches_loss_input_np[i] = speech_masks_np[i]
        return (
            torch.tensor(padded_speeches, dtype=torch.float32),
            torch.tensor(speech_masks_np, dtype=torch.bool),
            torch.tensor(speeches_loss_input_np, dtype=torch.bool),
            self._semantic_features(waveforms, max_latent_len),
        )

    def _semantic_features(
        self, waveforms: List[np.ndarray], max_latent_len: int
    ) -> torch.Tensor:
        """Encode each waveform to ``[max_latent_len, semantic_vae_dim]`` features.

        Requires ``compute_semantics`` and a non-None ``processor.semantic_tokenizer``.
        A segment whose encoding fails gets all-zero features (with a warning).
        Encoder outputs are zero-padded or truncated along both axes.
        """
        if not (
            self.compute_semantics
            and hasattr(self.processor, "semantic_tokenizer")
            and self.processor.semantic_tokenizer is not None
        ):
            hint = (
                ""
                if self.compute_semantics
                else " (compute_semantics is False; set compute_semantics=True to encode them.)"
            )
            raise RuntimeError(
                "Semantic features are required but could not be computed. Ensure processor.semantic_tokenizer is available or precompute and provide features."
                + hint
            )
        sem_tok = self.processor.semantic_tokenizer
        dim = self.semantic_vae_dim
        sem_feats: List[np.ndarray] = []
        for w in waveforms:
            try:
                sem = np.asarray(sem_tok.encode(w), dtype=np.float32)
            except Exception as err:
                warnings.warn(
                    f"Semantic tokenizer failed on a speech segment; using zero features ({type(err).__name__}: {err})"
                )
                sem = np.zeros((0, dim), dtype=np.float32)
            if sem.ndim != 2:
                raise RuntimeError(
                    f"Semantic tokenizer returned unexpected shape {sem.shape}. Expect [T, D]."
                )
            num_frames, width = sem.shape
            if width < dim:
                sem = np.concatenate(
                    [sem, np.zeros((num_frames, dim - width), dtype=np.float32)], axis=1
                )
            elif width > dim:
                sem = sem[:, :dim]
            if num_frames < max_latent_len:
                sem = np.concatenate(
                    [sem, np.zeros((max_latent_len - num_frames, dim), dtype=np.float32)],
                    axis=0,
                )
            elif num_frames > max_latent_len:
                sem = sem[:max_latent_len]
            sem_feats.append(sem.astype(np.float32))
        return torch.tensor(np.stack(sem_feats, axis=0), dtype=torch.float32)
|
|
|
|
| ' QWEN3Vox_AcousticTokenizer model configuration' |
| from typing import Dict, List, Optional, Tuple |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.utils import logging |
| from transformers.models.qwen2.configuration_qwen2 import Qwen2Config |
|
|
# Module logger obtained through transformers' logging utilities.
logger = logging.get_logger(name=__name__)
|
|
|
|
class QWEN3VoxAcousticTokenizerConfig(PretrainedConfig):
    """Configuration of the QWEN3Vox acoustic tokenizer.

    Stores the encoder/decoder hyper-parameters. ``decoder_ratios`` defaults
    to the encoder ratios when not given. Ratio sequences are stored as fresh
    lists, so config instances never share a mutable list with each other,
    with the default argument, or with the caller.
    """

    model_type = "vibevoice_acoustic_tokenizer"

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0.5,
        std_dist_type: str = "gaussian",
        mixer_layer: str = "depthwise_conv",
        conv_norm: str = "none",
        pad_mode: str = "constant",
        disable_last_norm: bool = True,
        layernorm: str = "RMSNorm",
        layernorm_eps: float = 1e-05,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-06,
        weight_init_value: float = 0.01,
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[Sequence[int]] = (8, 5, 5, 4, 2, 2),
        encoder_depths: str = "3-3-3-3-3-3-8",
        decoder_n_filters: int = 32,
        decoder_ratios: Optional[Sequence[int]] = None,
        decoder_depths: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer
        self.encoder_n_filters = encoder_n_filters
        self.encoder_ratios = None if encoder_ratios is None else list(encoder_ratios)
        self.encoder_depths = encoder_depths
        # Decoder falls back to the encoder ratios, as an independent list.
        source_ratios = encoder_ratios if decoder_ratios is None else decoder_ratios
        self.decoder_ratios = None if source_ratios is None else list(source_ratios)
        self.decoder_n_filters = decoder_n_filters
        self.decoder_depths = decoder_depths
|
|
|
|
class QWEN3VoxSemanticTokenizerConfig(PretrainedConfig):
    """Configuration of the QWEN3Vox semantic tokenizer (encoder side only).

    ``encoder_ratios`` is stored as a fresh list so config instances never
    share a mutable list with each other, the default, or the caller.
    """

    model_type = "vibevoice_semantic_tokenizer"

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0,
        std_dist_type: str = "none",
        mixer_layer: str = "depthwise_conv",
        conv_norm: str = "none",
        pad_mode: str = "constant",
        disable_last_norm: bool = True,
        layernorm: str = "RMSNorm",
        layernorm_eps: float = 1e-05,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-06,
        weight_init_value: float = 0.01,
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[Sequence[int]] = (8, 5, 5, 4, 2, 2),
        encoder_depths: str = "3-3-3-3-3-3-8",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer
        self.encoder_n_filters = encoder_n_filters
        self.encoder_ratios = None if encoder_ratios is None else list(encoder_ratios)
        self.encoder_depths = encoder_depths
|
|
|
|
class QWEN3VoxDiffusionHeadConfig(PretrainedConfig):
    """Configuration of the QWEN3Vox diffusion head.

    Holds the head's size settings (``hidden_size``, ``head_layers``,
    ``head_ffn_ratio``, ``rms_norm_eps``, ``latent_size``, ``speech_vae_dim``)
    and its diffusion settings (``prediction_type``, ``diffusion_type`` and the
    ``ddpm_*`` fields). These attributes are assigned before
    ``PretrainedConfig.__init__`` processes the remaining keyword arguments.
    """

    model_type = "vibevoice_diffusion_head"

    def __init__(
        self,
        hidden_size=768,
        head_layers=4,
        head_ffn_ratio=3.0,
        rms_norm_eps=1e-05,
        latent_size=64,
        speech_vae_dim=None,
        prediction_type="v_prediction",
        diffusion_type="ddpm",
        ddpm_num_steps=1000,
        ddpm_num_inference_steps=30,
        ddpm_beta_schedule="cosine",
        ddpm_batch_mul=4,
        **kwargs,
    ):
        head_settings = {
            "hidden_size": hidden_size,
            "head_layers": head_layers,
            "head_ffn_ratio": head_ffn_ratio,
            "rms_norm_eps": rms_norm_eps,
            "latent_size": latent_size,
            "speech_vae_dim": speech_vae_dim,
            "prediction_type": prediction_type,
            "diffusion_type": diffusion_type,
            "ddpm_num_steps": ddpm_num_steps,
            "ddpm_num_inference_steps": ddpm_num_inference_steps,
            "ddpm_beta_schedule": ddpm_beta_schedule,
            "ddpm_batch_mul": ddpm_batch_mul,
        }
        # Same assignment order as listing the attributes one by one.
        for attr_name, attr_value in head_settings.items():
            setattr(self, attr_name, attr_value)
        super().__init__(**kwargs)
|
|
|
|
class QWEN3VoxConfig(PretrainedConfig):
    """Composite configuration for the QWEN3Vox model.

    Bundles the acoustic-tokenizer, semantic-tokenizer, Qwen2 decoder and
    diffusion-head configurations. Each may be passed as:

      * ``None``, to use that sub-config's defaults;
      * a ``dict``, e.g. loaded from JSON — it is copied, never modified;
      * an instance of the matching config class.

    Raises:
        ValueError: a decoder dict whose ``model_type`` is not ``"qwen2"``.
        TypeError: a sub-config of any other type, which would otherwise leave
            the corresponding attribute unset.
    """

    model_type = "vibevoice"
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig,
        "semantic_tokenizer_config": QWEN3VoxSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
        "diffusion_head_config": QWEN3VoxDiffusionHeadConfig,
    }
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        diffusion_head_config=None,
        **kwargs,
    ):
        kwargs["_attn_implementation_autoset"] = False
        self.acoustic_tokenizer_config = self._build_sub_config(
            "acoustic_tokenizer_config",
            acoustic_tokenizer_config,
            "vibevoice_acoustic_tokenizer",
        )
        self.semantic_tokenizer_config = self._build_sub_config(
            "semantic_tokenizer_config",
            semantic_tokenizer_config,
            "vibevoice_semantic_tokenizer",
        )
        self.decoder_config = self._build_decoder_config(decoder_config)
        self.diffusion_head_config = self._build_sub_config(
            "diffusion_head_config",
            diffusion_head_config,
            "vibevoice_diffusion_head",
        )
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, "vae_dim", 128)
        super().__init__(**kwargs)

    @classmethod
    def _build_sub_config(cls, key, value, model_type):
        """Return the ``sub_configs[key]`` instance for ``value`` (None/dict/instance)."""
        config_cls = cls.sub_configs[key]
        if value is None:
            return config_cls()
        if isinstance(value, dict):
            # Build from a copy so the caller's dict is left untouched.
            return config_cls(**{**value, "model_type": model_type})
        if isinstance(value, config_cls):
            return value
        raise TypeError(
            f"{key} must be None, a dict or a {config_cls.__name__}, got {type(value).__name__}"
        )

    @staticmethod
    def _build_decoder_config(decoder_config):
        """Return a ``Qwen2Config`` for ``decoder_config`` (None/dict/instance)."""
        if decoder_config is None:
            return Qwen2Config()
        if isinstance(decoder_config, dict):
            if decoder_config.get("model_type", "") == "qwen2":
                return Qwen2Config(**decoder_config)
            raise ValueError(
                f"Unsupported decoder model type: {decoder_config.get('model_type', '')}"
            )
        if isinstance(decoder_config, Qwen2Config):
            return decoder_config
        raise TypeError(
            f"decoder_config must be None, a dict or a Qwen2Config, got {type(decoder_config).__name__}"
        )
|
|
|
|
class QWEN3VoxASRConfig(PretrainedConfig):
    """Composite configuration for the QWEN3Vox ASR variant (no diffusion head).

    Sub-configs may be ``None`` (defaults), a ``dict`` (copied, never
    modified) or an instance of the matching config class. Text-model
    attributes (``vocab_size``, ``hidden_size``, ...) are read-only
    properties that delegate to ``decoder_config``.

    NOTE(review): ``model_type`` is ``"vibevoice"``, the same as
    ``QWEN3VoxConfig``; registering both with an Auto* registry would clash —
    confirm intended.

    Raises:
        ValueError: a decoder dict whose ``model_type`` is not ``"qwen2"``.
        TypeError: a sub-config of any other type.
    """

    model_type = "vibevoice"
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig,
        "semantic_tokenizer_config": QWEN3VoxSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
    }
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        **kwargs,
    ):
        kwargs["_attn_implementation_autoset"] = False
        self.acoustic_tokenizer_config = self._build_sub_config(
            "acoustic_tokenizer_config",
            acoustic_tokenizer_config,
            "vibevoice_acoustic_tokenizer",
        )
        self.semantic_tokenizer_config = self._build_sub_config(
            "semantic_tokenizer_config",
            semantic_tokenizer_config,
            "vibevoice_semantic_tokenizer",
        )
        self.decoder_config = self._build_decoder_config(decoder_config)
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, "vae_dim", 128)
        super().__init__(**kwargs)

    @classmethod
    def _build_sub_config(cls, key, value, model_type):
        """Return the ``sub_configs[key]`` instance for ``value`` (None/dict/instance)."""
        config_cls = cls.sub_configs[key]
        if value is None:
            return config_cls()
        if isinstance(value, dict):
            # Build from a copy so the caller's dict is left untouched.
            return config_cls(**{**value, "model_type": model_type})
        if isinstance(value, config_cls):
            return value
        raise TypeError(
            f"{key} must be None, a dict or a {config_cls.__name__}, got {type(value).__name__}"
        )

    @staticmethod
    def _build_decoder_config(decoder_config):
        """Return a ``Qwen2Config`` for ``decoder_config`` (None/dict/instance)."""
        if decoder_config is None:
            return Qwen2Config()
        if isinstance(decoder_config, dict):
            if decoder_config.get("model_type", "") == "qwen2":
                return Qwen2Config(**decoder_config)
            raise ValueError(
                f"Unsupported decoder model type: {decoder_config.get('model_type', '')}"
            )
        if isinstance(decoder_config, Qwen2Config):
            return decoder_config
        raise TypeError(
            f"decoder_config must be None, a dict or a Qwen2Config, got {type(decoder_config).__name__}"
        )

    def get_text_config(self, decoder: bool = False, **kwargs):
        """Return the decoder (text) config.

        Extra keyword arguments are accepted and ignored. Presumably newer
        transformers releases pass e.g. ``encoder=``; verify against the
        installed version.
        """
        return self.decoder_config

    @property
    def vocab_size(self):
        return self.decoder_config.vocab_size

    @property
    def num_attention_heads(self):
        return self.decoder_config.num_attention_heads

    @property
    def num_key_value_heads(self):
        return self.decoder_config.num_key_value_heads

    @property
    def hidden_size(self):
        return self.decoder_config.hidden_size

    @property
    def num_hidden_layers(self):
        return self.decoder_config.num_hidden_layers

    @property
    def head_dim(self):
        """Decoder ``head_dim``; derived as hidden_size // heads when absent or None."""
        head_dim = getattr(self.decoder_config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.num_attention_heads
        return head_dim
|
|
|
|
# Public names exported by ``from <module> import *``.
# NOTE(review): only the configuration classes are listed; the dataset,
# collator and streamer classes defined in this file are not exported —
# confirm this is intended.
__all__ = [
    "QWEN3VoxAcousticTokenizerConfig",
    "QWEN3VoxSemanticTokenizerConfig",
    "QWEN3VoxDiffusionHeadConfig",
    "QWEN3VoxConfig",
    "QWEN3VoxASRConfig",
]
| import torch |
| import asyncio |
| from queue import Queue |
| from typing import TYPE_CHECKING, Optional |
| from transformers.generation import BaseStreamer |
|
|
|
|
class AudioStreamer(BaseStreamer):
    """Streams generated audio chunks into one ``queue.Queue`` per batch sample.

    The producer calls ``put`` with chunks plus the batch indices they belong
    to, and ``end`` to finish samples. Consumers read via ``get_stream``
    (one sample) or by iterating the streamer (all samples). ``stop_signal``
    is the sentinel object enqueued by ``end``; consumers recognize it by
    identity.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        self.batch_size = batch_size
        self.stop_signal = stop_signal
        self.timeout = timeout
        self.audio_queues = [Queue() for _ in range(batch_size)]
        self.finished_flags = [False for _ in range(batch_size)]
        self.sample_indices_map = {}

    def _active_index(self, sample_idx: Any) -> Optional[int]:
        """Normalize ``sample_idx`` (tensor or int) to an int.

        Return ``None`` when it is out of range or the sample already ended.
        Negative values are rejected rather than wrapping to the last queue.
        """
        idx = int(sample_idx.item()) if torch.is_tensor(sample_idx) else int(sample_idx)
        if 0 <= idx < self.batch_size and not self.finished_flags[idx]:
            return idx
        return None

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Enqueue a detached CPU copy of ``audio_chunks[i]`` for ``sample_indices[i]``.

        Chunks for finished or out-of-range samples are dropped.
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = self._active_index(sample_idx)
            if idx is not None:
                self.audio_queues[idx].put(
                    audio_chunks[i].detach().cpu(), timeout=self.timeout
                )

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Enqueue ``stop_signal`` for the given samples (all when ``None``) once."""
        indices = range(self.batch_size) if sample_indices is None else sample_indices
        for sample_idx in indices:
            idx = self._active_index(sample_idx)
            if idx is not None:
                self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                self.finished_flags[idx] = True

    def __iter__(self):
        return AudioBatchIterator(self)

    def get_stream(self, sample_idx: int):
        """Return a blocking iterator over the chunks of one sample."""
        if sample_idx >= self.batch_size:
            raise ValueError(
                f"Sample index {sample_idx} exceeds batch size {self.batch_size}"
            )
        if sample_idx < 0:
            raise ValueError(f"Sample index {sample_idx} must be non-negative")
        return AudioSampleIterator(self, sample_idx)
|
|
|
|
class AudioSampleIterator:
    """Blocking iterator over the audio chunks queued for one sample.

    Stops when the streamer's ``stop_signal`` object is dequeued. If
    ``streamer.timeout`` elapses first, ``queue.Empty`` propagates to the caller.
    """

    def __init__(self, streamer: "AudioStreamer", sample_idx: int):
        self.streamer = streamer
        self.sample_idx = sample_idx

    def __iter__(self):
        return self

    def __next__(self):
        value = self.streamer.audio_queues[self.sample_idx].get(
            timeout=self.streamer.timeout
        )
        # Identity check: ``==`` against a tensor chunk is elementwise and
        # would be ambiguous in a boolean context.
        if value is self.streamer.stop_signal:
            raise StopIteration()
        return value
|
|
|
|
class AudioBatchIterator:
    """Polling iterator yielding ``{sample_idx: chunk}`` across all samples.

    Each step takes at most one pending item from every still-active sample's
    queue without blocking. If only stop signals (or nothing) were found, it
    sleeps 10 ms and polls again. Iteration ends once every sample has
    delivered ``stop_signal``.
    """

    def __init__(self, streamer: "AudioStreamer"):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __iter__(self):
        return self

    def __next__(self):
        import time
        from queue import Empty

        # Loop instead of recursing, so long stalls cannot hit the recursion limit.
        while self.active_samples:
            batch_chunks = {}
            finished = set()
            for idx in self.active_samples:
                try:
                    value = self.streamer.audio_queues[idx].get(block=False)
                except Empty:
                    continue
                # Identity check: ``==`` on tensor chunks is elementwise.
                if value is self.streamer.stop_signal:
                    finished.add(idx)
                else:
                    batch_chunks[idx] = value
            self.active_samples -= finished
            if batch_chunks:
                return batch_chunks
            if self.active_samples:
                time.sleep(0.01)
        raise StopIteration()
|
|
|
|
class AsyncAudioStreamer(AudioStreamer):
    """``AudioStreamer`` variant backed by ``asyncio.Queue`` objects.

    Must be constructed while an event loop is running: that loop is captured
    via ``asyncio.get_running_loop()``. ``put``/``end`` schedule queue writes
    on it with ``call_soon_threadsafe``, presumably so a producer thread can
    feed async consumers — confirm with callers.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        super().__init__(batch_size, stop_signal, timeout)
        # Replace the thread queues created by the parent with asyncio queues.
        self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
        self.loop = asyncio.get_running_loop()

    def _active_index(self, sample_idx: Any) -> Optional[int]:
        """Normalize ``sample_idx`` (tensor or int) to an int.

        Return ``None`` when it is out of range (including negative) or the
        sample already ended.
        """
        idx = int(sample_idx.item()) if torch.is_tensor(sample_idx) else int(sample_idx)
        if 0 <= idx < self.batch_size and not self.finished_flags[idx]:
            return idx
        return None

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Schedule a detached CPU copy of ``audio_chunks[i]`` for ``sample_indices[i]``."""
        for i, sample_idx in enumerate(sample_indices):
            idx = self._active_index(sample_idx)
            if idx is not None:
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, audio_chunks[i].detach().cpu()
                )

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Schedule ``stop_signal`` for the given samples (all when ``None``) once."""
        indices = range(self.batch_size) if sample_indices is None else sample_indices
        for sample_idx in indices:
            idx = self._active_index(sample_idx)
            if idx is not None:
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, self.stop_signal
                )
                self.finished_flags[idx] = True

    async def get_stream(self, sample_idx: int):
        """Async-iterate the chunks of one sample until ``stop_signal`` arrives."""
        if sample_idx >= self.batch_size:
            raise ValueError(
                f"Sample index {sample_idx} exceeds batch size {self.batch_size}"
            )
        if sample_idx < 0:
            raise ValueError(f"Sample index {sample_idx} must be non-negative")
        while True:
            value = await self.audio_queues[sample_idx].get()
            # Identity check: ``==`` on tensor chunks is elementwise.
            if value is self.stop_signal:
                break
            yield value

    def __aiter__(self):
        return AsyncAudioBatchIterator(self)
|
|
|
|
class AsyncAudioBatchIterator:
    """Async iterator yielding ``{sample_idx: chunk}`` dicts from an AsyncAudioStreamer.

    Each step waits until at least one still-active sample has produced a
    value, then returns every chunk that became available during that wait.
    A sample is retired once its stop signal is received, and iteration ends
    when every sample has been retired.
    """

    def __init__(self, streamer: AsyncAudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next dict of available chunks.

        Raises:
            StopAsyncIteration: every sample has delivered its stop signal.
            asyncio.TimeoutError: ``streamer.timeout`` elapsed without any
                active sample producing a value.
        """
        # Loop instead of recursing: a round made only of stop signals used to
        # recurse, and an expired timeout recursed without bound.
        while self.active_samples:
            tasks = {
                idx: asyncio.create_task(self._get_chunk(idx))
                for idx in self.active_samples
            }
            try:
                done, _ = await asyncio.wait(
                    tasks.values(),
                    return_when=asyncio.FIRST_COMPLETED,
                    timeout=self.streamer.timeout,
                )
            finally:
                # Cancel the waits that did not finish. This also runs if we
                # are cancelled ourselves, so no stray getter keeps consuming
                # items. A cancelled asyncio.Queue.get() does not remove an item.
                for task in tasks.values():
                    if not task.done():
                        task.cancel()
            if not done:
                raise asyncio.TimeoutError(
                    f"No audio chunk received within {self.streamer.timeout} seconds"
                )
            batch_chunks = {}
            for idx, task in tasks.items():
                if task not in done:
                    continue
                value = task.result()
                # Identity check: ``==`` against a tensor chunk can be ambiguous.
                if value is self.streamer.stop_signal:
                    self.active_samples.discard(idx)
                else:
                    batch_chunks[idx] = value
            if batch_chunks:
                return batch_chunks
        raise StopAsyncIteration()

    async def _get_chunk(self, idx):
        """Await the next value from sample ``idx``'s queue."""
        return await self.streamer.audio_queues[idx].get()
|
|
|
|
| 'Tokenization classes for QWEN3Vox.' |
| from typing import List, Optional, Union |
| from transformers.utils import logging |
| from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer |
| from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast |
|
|
# Module-level logger from transformers' logging utilities.
logger = logging.get_logger(__name__)
|
|
|
|
class QWEN3VoxTextTokenizer(Qwen2Tokenizer):
    """Slow (pure-Python) Qwen2 tokenizer extended with speech marker tokens.

    ``<|vision_start|>``, ``<|vision_end|>`` and ``<|vision_pad|>`` are
    registered as additional special tokens. Their ids are exposed as the
    speech start, speech end and speech diffusion ids respectively.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        add_special_tokens=True,
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            errors=errors,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            add_special_tokens=add_special_tokens,
            **kwargs,
        )
        self._add_q3_sp_tok()

    def _add_q3_sp_tok(self):
        """Register the speech marker tokens and cache their ids.

        Returns:
            The count reported by ``add_special_tokens`` (tokens newly added).
        """
        start_tok, end_tok, diffusion_tok = (
            "<|vision_start|>",
            "<|vision_end|>",
            "<|vision_pad|>",
        )
        added = self.add_special_tokens(
            {"additional_special_tokens": [start_tok, end_tok, diffusion_tok]}
        )
        to_id = self.convert_tokens_to_ids
        self._speech_start_id = to_id(start_tok)
        self._speech_end_id = to_id(end_tok)
        self._speech_diffusion_id = to_id(diffusion_tok)
        self._eos_id = to_id("<|endoftext|>")
        return added

    @property
    def eos_id(self) -> int:
        """Id of ``<|endoftext|>``."""
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        """Id of ``<|vision_start|>``, used as the speech-start marker."""
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        """Id of ``<|vision_end|>``, used as the speech-end marker."""
        return self._speech_end_id

    @property
    def speech_diffusion_id(self) -> int:
        """Id of ``<|vision_pad|>``, used as the speech-diffusion placeholder."""
        return self._speech_diffusion_id

    @property
    def pad_id(self) -> int:
        """Constant -100 rather than a vocabulary id.

        NOTE(review): the fast tokenizer in this module returns the
        ``<|image_pad|>`` id here instead; confirm the difference is intended.
        """
        return -100
|
|
|
|
class QWEN3VoxTextTokenizerFast(Qwen2TokenizerFast):
    """Fast Qwen2 tokenizer extended with speech marker tokens.

    Registers ``<|vision_start|>``, ``<|vision_end|>`` and ``<|vision_pad|>`` as
    additional special tokens, like ``QWEN3VoxTextTokenizer``. Unlike that
    class, it takes the EOS id from ``eos_token_id`` and uses the
    ``<|image_pad|>`` id as the padding id.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )
        self._add_q3_sp_tok()

    def _add_q3_sp_tok(self):
        """Register the speech marker tokens and cache the ids exposed as properties.

        Returns:
            The count reported by ``add_special_tokens`` (tokens newly added).
        """
        start_tok, end_tok, diffusion_tok = (
            "<|vision_start|>",
            "<|vision_end|>",
            "<|vision_pad|>",
        )
        added = self.add_special_tokens(
            {"additional_special_tokens": [start_tok, end_tok, diffusion_tok]}
        )
        to_id = self.convert_tokens_to_ids
        self._speech_start_id = to_id(start_tok)
        self._speech_end_id = to_id(end_tok)
        self._speech_diffusion_id = to_id(diffusion_tok)
        self._eos_id = self.eos_token_id
        self._pad_id = to_id("<|image_pad|>")
        return added

    @property
    def eos_id(self) -> int:
        """The tokenizer's ``eos_token_id``."""
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        """Id of ``<|vision_start|>``, used as the speech-start marker."""
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        """Id of ``<|vision_end|>``, used as the speech-end marker."""
        return self._speech_end_id

    @property
    def speech_diffusion_id(self) -> int:
        """Id of ``<|vision_pad|>``, used as the speech-diffusion placeholder."""
        return self._speech_diffusion_id

    @property
    def pad_id(self) -> int:
        """Id of ``<|image_pad|>``."""
        return self._pad_id
|
|
|
|
# The ASR tokenizer name is an alias of the fast tokenizer class (same object).
QWEN3VoxASRTextTokenizerFast = QWEN3VoxTextTokenizerFast


# Names exported by star-imports of this tokenizer section; the ASR alias above
# is not listed.
__all__ = [
    'QWEN3VoxTextTokenizer',
    'QWEN3VoxTextTokenizerFast',
]
| "Utilities for loading fine-tuned LoRA adapters and connector weights." |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Optional |
| import torch |
| import torch.nn as nn |
| from transformers.utils import logging |
|
|
# Logger used by the adapter-loading helpers below (transformers logging).
logger = logging.get_logger(__name__)
|
|
|
|
@dataclass
class _LoadReport:
    """Record of which adapter assets ``load_lora_assets`` applied to a model."""

    # Language-model LoRA loaded. None of the loader helpers next to this class
    # currently set it.
    language_model: bool = False
    # Diffusion-head LoRA loaded. None of the loader helpers next to this class
    # currently set it.
    diffusion_head_lora: bool = False
    # Full diffusion-head weights loaded from diffusion_head_full.bin.
    diffusion_head_full: bool = False
    # Acoustic connector state dict loaded.
    acoustic_connector: bool = False
    # Semantic connector state dict loaded.
    semantic_connector: bool = False
    # Directory the assets were resolved from (see _resolve_adapter_root).
    adapter_root: Optional[Path] = None
|
|
|
|
class _DiffusionHeadForwardShim(nn.Module):
    """Adapt calls to a diffusion head so arguments may be positional or keyword.

    The wrapped ``base`` is always invoked positionally as
    ``base(noisy_images, timesteps, condition)``. Positional arguments beyond
    the third and unrecognised keyword arguments are ignored. A slot that is
    supplied neither way is passed as ``None``.
    """

    def __init__(self, base: nn.Module):
        super().__init__()
        self.base = base

    def forward(self, *args, **kwargs):
        """Call ``base`` with the three arguments resolved from ``args``/``kwargs``."""
        names = ("noisy_images", "timesteps", "condition")
        # Positional arguments fill the leading slots. Remaining slots fall back
        # to the keyword of the same name. Previously, any call with 1-2
        # positional arguments discarded them and read only kwargs.
        values = list(args[:3]) + [kwargs.get(name) for name in names[len(args):]]
        return self.base(*values)
|
|
|
|
def _resolve_adapter_root(checkpoint_path: Path) -> Path:
    """Return the directory that holds the adapter assets.

    A file path is replaced by its parent directory. If that directory has a
    ``lora`` subdirectory, the subdirectory is returned; otherwise the
    directory itself is returned.
    """
    base_dir = checkpoint_path.parent if checkpoint_path.is_file() else checkpoint_path
    lora_dir = base_dir / "lora"
    return lora_dir if lora_dir.exists() else base_dir
|
|
|
|
def _load_connector(
    module: Optional[nn.Module], path: Path, device: torch.device
) -> bool:
    """Load a connector state dict from ``path`` into ``module``.

    Args:
        module: connector to update; when None nothing is loaded.
        path: ``torch.save``-d state dict file.
        device: target device for the loaded tensors and the module.

    Returns:
        False if ``module`` is None or ``path`` does not exist. Otherwise True,
        after a non-strict ``load_state_dict`` (missing/unexpected keys are
        logged as warnings) and moving the module to ``device``.
    """
    if module is None or not path.exists():
        return False
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot run arbitrary code while loading.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    missing, unexpected = module.load_state_dict(state_dict, strict=False)
    if missing:
        logger.warning(f"Connector load missing keys: {missing }")
    if unexpected:
        logger.warning(f"Connector load unexpected keys: {unexpected }")
    module.to(device)
    return True
|
|
|
|
def _load_diffusion_head(
    model, adapter_root: Path, device: torch.device, report: _LoadReport
) -> None:
    """Load full diffusion-head weights into ``model.model.prediction_head``.

    A PEFT LoRA adapter under ``<adapter_root>/diffusion_head`` is detected but
    never loaded (PEFT loading is disabled here); a warning is logged instead.
    Full weights are read from ``diffusion_head/diffusion_head_full.bin``,
    falling back to ``<adapter_root>/diffusion_head_full.bin``. When found,
    they are loaded non-strictly, the head is moved to ``device`` and
    ``report.diffusion_head_full`` is set.

    peft is not required: it was previously imported (and demanded) here
    without ever being used, which blocked full-weight loading.
    """
    diff_dir = adapter_root / "diffusion_head"
    has_lora_adapter = (diff_dir / "adapter_config.json").exists() and (
        (diff_dir / "adapter_model.bin").exists()
        or (diff_dir / "adapter_model.safetensors").exists()
    )
    if has_lora_adapter:
        # Warn but keep going: the full-weights file recommended by this
        # message may sit next to the adapter and should still be loaded.
        logger.warning(
            f"Skipping diffusion-head LoRA at {diff_dir }; "
            "PeftModel.from_pretrained is not allowed in miner.py (use full weights .bin)."
        )
    full_path = diff_dir / "diffusion_head_full.bin"
    if not full_path.exists():
        full_path = adapter_root / "diffusion_head_full.bin"
    if not full_path.exists():
        return
    logger.info(f"Loading full diffusion head weights from {full_path }")
    # weights_only=True: refuse arbitrary pickled objects in the checkpoint.
    state_dict = torch.load(full_path, map_location=device, weights_only=True)
    head = model.model.prediction_head
    missing, unexpected = head.load_state_dict(state_dict, strict=False)
    if missing:
        logger.warning(f"Diffusion head load missing keys: {missing }")
    if unexpected:
        logger.warning(f"Diffusion head load unexpected keys: {unexpected }")
    head.to(device)
    report.diffusion_head_full = True
|
|
|
|
def _load_language_model(
    model, adapter_root: Path, device: torch.device, report: _LoadReport
) -> None:
    """Detect a language-model LoRA adapter under ``adapter_root`` and skip it.

    PEFT-based loading is disabled in this build. If ``adapter_config.json``
    plus ``adapter_model.bin`` or ``adapter_model.safetensors`` is present, a
    warning is logged and nothing is loaded. ``model``, ``device`` and
    ``report`` are currently unused, but kept for interface compatibility.

    The previous version raised RuntimeError when peft was not installed, even
    though peft was never used; that requirement is dropped.
    """
    config_file = adapter_root / "adapter_config.json"
    bin_file = adapter_root / "adapter_model.bin"
    safe_tensors_file = adapter_root / "adapter_model.safetensors"
    if not (config_file.exists() and (bin_file.exists() or safe_tensors_file.exists())):
        return
    logger.warning(
        f"Skipping language-model LoRA at {adapter_root }; "
        "PeftModel.from_pretrained is not allowed in miner.py (use full weights .bin)."
    )
|
|
|
|
def load_lora_assets(
    model, checkpoint_dir: str, device: Optional[torch.device] = None
) -> _LoadReport:
    """Load fine-tuned adapter assets from ``checkpoint_dir`` into ``model``.

    The checkpoint directory (or a file's parent, preferring a ``lora``
    subdirectory) is scanned for language-model and diffusion-head assets and
    for ``acoustic_connector/pytorch_model.bin`` and
    ``semantic_connector/pytorch_model.bin``.

    Args:
        model: model exposing ``model.model`` (connectors and prediction head).
        checkpoint_dir: checkpoint directory or a file inside it.
        device: target device. Defaults to the device of the model's first
            parameter, or CPU when the model has no parameters.

    Returns:
        A ``_LoadReport`` describing what was loaded.

    Raises:
        FileNotFoundError: the resolved adapter directory does not exist.
    """
    adapter_root = _resolve_adapter_root(Path(checkpoint_dir))
    if not adapter_root.exists():
        raise FileNotFoundError(f"Adapter directory not found: {adapter_root }")
    if device is not None:
        inferred_device = device
    else:
        # Default to None so a parameterless model does not crash with StopIteration.
        first_param = next(model.parameters(), None)
        inferred_device = (
            first_param.device if first_param is not None else torch.device("cpu")
        )
    report = _LoadReport(adapter_root=adapter_root)
    _load_language_model(model, adapter_root, inferred_device, report)
    _load_diffusion_head(model, adapter_root, inferred_device, report)
    ac_path = adapter_root / "acoustic_connector" / "pytorch_model.bin"
    if _load_connector(
        getattr(model.model, "acoustic_connector", None), ac_path, inferred_device
    ):
        report.acoustic_connector = True
    se_path = adapter_root / "semantic_connector" / "pytorch_model.bin"
    if _load_connector(
        getattr(model.model, "semantic_connector", None), se_path, inferred_device
    ):
        report.semantic_connector = True
    # Only the boolean flags count. The previous check used report.__dict__,
    # which includes the always-truthy adapter_root, so this warning never fired.
    loaded_any = any(
        (
            report.language_model,
            report.diffusion_head_lora,
            report.diffusion_head_full,
            report.acoustic_connector,
            report.semantic_connector,
        )
    )
    if not loaded_any:
        logger.warning(
            "No adapter assets were loaded. Ensure the checkpoint directory is correct and contains LoRA weights."
        )
    return report
|
|
|
|
| import math |
| from typing import List, Optional, Tuple, Union |
| import numpy as np |
| import torch |
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.utils import deprecate |
| from diffusers.utils.torch_utils import randn_tensor |
| from diffusers.schedulers.scheduling_utils import ( |
| KarrasDiffusionSchedulers, |
| SchedulerMixin, |
| SchedulerOutput, |
| ) |
|
|
|
|
def betas_for_alpha_bar(
    num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine"
):
    """Turn a continuous alpha-bar curve into per-step betas.

    For step ``i`` of ``N = num_diffusion_timesteps`` the beta is
    ``1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)``, capped at ``max_beta``.

    Args:
        num_diffusion_timesteps: number of betas to produce.
        max_beta: upper bound applied to every beta.
        alpha_transform_type: ``"cosine"``, ``"exp"``, ``"cauchy"`` or
            ``"laplace"``.

    Returns:
        1-D ``torch.float32`` tensor of length ``num_diffusion_timesteps``.

    Raises:
        ValueError: unsupported ``alpha_transform_type``.
    """

    def _cosine(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp(t):
        return math.exp(t * -12.0)

    def _cauchy(t, gamma=1, mu=3):
        snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9)
        return 1 - 1 / (math.exp(snr) + 1.1)

    def _laplace(t, mu=0, b=1):
        snr = mu - b * math.copysign(1, 0.5 - t) * math.log(
            1 - 2 * abs(t - 0.5) * 0.98
        )
        return 1 - 1 / (math.exp(snr) + 1.02)

    candidates = (
        ("cosine", _cosine),
        ("exp", _exp),
        ("cauchy", _cauchy),
        ("laplace", _laplace),
    )
    alpha_bar = next(
        (fn for name, fn in candidates if alpha_transform_type == name), None
    )
    if alpha_bar is None:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type }")

    steps = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
|
|
|
|
def rescale_zero_terminal_snr(betas):
    """Rescale a beta schedule so the last timestep has zero SNR.

    ``sqrt(alpha_bar)`` is shifted so its final entry becomes 0, then scaled so
    its first entry is unchanged, and converted back to betas. The input tensor
    is not modified.

    Args:
        betas: 1-D tensor of betas.

    Returns:
        New tensor of rescaled betas; its last entry equals 1.
    """
    sqrt_abar = torch.cumprod(1.0 - betas, dim=0).sqrt()
    first = sqrt_abar[0]
    last = sqrt_abar[-1]
    # Out-of-place arithmetic, so ``first``/``last`` (views of the old tensor)
    # remain valid without cloning.
    sqrt_abar = (sqrt_abar - last) * (first / (first - last))
    abar = sqrt_abar**2
    alphas = torch.cat([abar[0:1], abar[1:] / abar[:-1]])
    return 1 - alphas
|
|
|
|
| class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): |
| _compatibles = [e.name for e in KarrasDiffusionSchedulers] |
| order = 1 |
|
|
| @register_to_config |
| def __init__( |
| self, |
| num_train_timesteps: int = 1000, |
| beta_start: float = 0.0001, |
| beta_end: float = 0.02, |
| beta_schedule: str = "linear", |
| trained_betas: Optional[Union[np.ndarray, List[float]]] = None, |
| solver_order: int = 2, |
| prediction_type: str = "epsilon", |
| thresholding: bool = False, |
| dynamic_thresholding_ratio: float = 0.995, |
| sample_max_value: float = 1.0, |
| algorithm_type: str = "dpmsolver++", |
| solver_type: str = "midpoint", |
| lower_order_final: bool = True, |
| euler_at_final: bool = False, |
| use_karras_sigmas: Optional[bool] = False, |
| use_lu_lambdas: Optional[bool] = False, |
| final_sigmas_type: Optional[str] = "zero", |
| lambda_min_clipped: float = -float("inf"), |
| variance_type: Optional[str] = None, |
| timestep_spacing: str = "linspace", |
| steps_offset: int = 0, |
| rescale_betas_zero_snr: bool = False, |
| ): |
| if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: |
| deprecation_message = f"algorithm_type {algorithm_type } is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" |
| deprecate( |
| "algorithm_types dpmsolver and sde-dpmsolver", |
| "1.0.0", |
| deprecation_message, |
| ) |
| if trained_betas is not None: |
| self.betas = torch.tensor(trained_betas, dtype=torch.float32) |
| elif beta_schedule == "linear": |
| self.betas = torch.linspace( |
| beta_start, beta_end, num_train_timesteps, dtype=torch.float32 |
| ) |
| elif beta_schedule == "scaled_linear": |
| self.betas = ( |
| torch.linspace( |
| beta_start**0.5, |
| beta_end**0.5, |
| num_train_timesteps, |
| dtype=torch.float32, |
| ) |
| ** 2 |
| ) |
| elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine": |
| self.betas = betas_for_alpha_bar( |
| num_train_timesteps, alpha_transform_type="cosine" |
| ) |
| elif beta_schedule == "cauchy": |
| self.betas = betas_for_alpha_bar( |
| num_train_timesteps, alpha_transform_type="cauchy" |
| ) |
| elif beta_schedule == "laplace": |
| self.betas = betas_for_alpha_bar( |
| num_train_timesteps, alpha_transform_type="laplace" |
| ) |
| else: |
| raise NotImplementedError( |
| f"{beta_schedule } is not implemented for {self .__class__ }" |
| ) |
| if rescale_betas_zero_snr: |
| self.betas = rescale_zero_terminal_snr(self.betas) |
| self.alphas = 1.0 - self.betas |
| self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) |
| if rescale_betas_zero_snr: |
| self.alphas_cumprod[-1] = 2 ** (-24) |
| self.alpha_t = torch.sqrt(self.alphas_cumprod) |
| self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) |
| self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) |
| self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 |
| self.init_noise_sigma = 1.0 |
| if algorithm_type not in [ |
| "dpmsolver", |
| "dpmsolver++", |
| "sde-dpmsolver", |
| "sde-dpmsolver++", |
| ]: |
| if algorithm_type == "deis": |
| self.register_to_config(algorithm_type="dpmsolver++") |
| else: |
| raise NotImplementedError( |
| f"{algorithm_type } is not implemented for {self .__class__ }" |
| ) |
| if solver_type not in ["midpoint", "heun"]: |
| if solver_type in ["logrho", "bh1", "bh2"]: |
| self.register_to_config(solver_type="midpoint") |
| else: |
| raise NotImplementedError( |
| f"{solver_type } is not implemented for {self .__class__ }" |
| ) |
| if ( |
| algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] |
| and final_sigmas_type == "zero" |
| ): |
| raise ValueError( |
| f"`final_sigmas_type` {final_sigmas_type } is not supported for `algorithm_type` {algorithm_type }. Please choose `sigma_min` instead." |
| ) |
| self.num_inference_steps = None |
| timesteps = np.linspace( |
| 0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32 |
| )[::-1].copy() |
| self.timesteps = torch.from_numpy(timesteps) |
| self.model_outputs = [None] * solver_order |
| self.lower_order_nums = 0 |
| self._step_index = None |
| self._begin_index = None |
| self.sigmas = self.sigmas.to("cpu") |
|
|
    @property
    def step_index(self):
        """Index of the current solver step into ``self.sigmas``.

        Reset to None by ``__init__`` and ``set_timesteps``; it is presumably
        initialised and advanced by the step logic elsewhere in this class.
        """
        return self._step_index
|
|
    @property
    def begin_index(self):
        """Starting step index set via ``set_begin_index`` (None until set).

        Also reset to None by ``__init__`` and ``set_timesteps``.
        """
        return self._begin_index
|
|
    def set_begin_index(self, begin_index: int = 0):
        """Record the step index the sampling run should start from.

        Args:
            begin_index: starting index; stored as-is (no validation).
        """
        self._begin_index = begin_index
|
|
    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """Build the inference timestep and sigma schedules.

        Exactly one of ``num_inference_steps`` or ``timesteps`` must be given.
        Sets ``self.timesteps`` (int64, on ``device``) and ``self.sigmas``
        (float32, kept on CPU, with one extra final sigma appended). Also
        resets the multistep state (``model_outputs``, ``lower_order_nums``,
        ``_step_index``, ``_begin_index``).

        Args:
            num_inference_steps: number of denoising steps to generate.
            device: device for ``self.timesteps``.
            timesteps: explicit custom timesteps; not allowed together with
                ``use_karras_sigmas`` or ``use_lu_lambdas``.

        Raises:
            ValueError: invalid argument combination, unknown
                ``timestep_spacing`` or unknown ``final_sigmas_type``.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError(
                "Must pass exactly one of `num_inference_steps` or `timesteps`."
            )
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError(
                "Can only pass one of `num_inference_steps` or `custom_timesteps`."
            )
        if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError(
                "Cannot use `timesteps` with `config.use_karras_sigmas = True`"
            )
        if timesteps is not None and self.config.use_lu_lambdas:
            raise ValueError(
                "Cannot use `timesteps` with `config.use_lu_lambdas = True`"
            )
        if timesteps is not None:
            timesteps = np.array(timesteps).astype(np.int64)
        else:
            # Count, from the high-timestep end, the entries whose lambda_t lies
            # below lambda_min_clipped, and exclude them from the schedule
            # (assumes lambda_t decreases with t, so the flipped array is ascending).
            clipped_idx = torch.searchsorted(
                torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped
            )
            last_timestep = (
                (self.config.num_train_timesteps - clipped_idx).numpy().item()
            )
            if self.config.timestep_spacing == "linspace":
                # num_inference_steps + 1 evenly spaced points, descending; the
                # trailing 0 is dropped.
                timesteps = (
                    np.linspace(0, last_timestep - 1, num_inference_steps + 1)
                    .round()[::-1][:-1]
                    .copy()
                    .astype(np.int64)
                )
            elif self.config.timestep_spacing == "leading":
                # Integer multiples of step_ratio, descending, 0 dropped, then shifted.
                step_ratio = last_timestep // (num_inference_steps + 1)
                timesteps = (
                    (np.arange(0, num_inference_steps + 1) * step_ratio)
                    .round()[::-1][:-1]
                    .copy()
                    .astype(np.int64)
                )
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                # Count down from last_timestep in steps of num_train_timesteps / N.
                step_ratio = self.config.num_train_timesteps / num_inference_steps
                timesteps = (
                    np.arange(last_timestep, 0, -step_ratio)
                    .round()
                    .copy()
                    .astype(np.int64)
                )
                timesteps -= 1
            else:
                raise ValueError(
                    f"{self .config .timestep_spacing } is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                )
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        if self.config.use_karras_sigmas:
            # Karras spacing: build sigmas first, then map each back to a (rounded) timestep.
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(
                in_sigmas=sigmas, num_inference_steps=num_inference_steps
            )
            timesteps = np.array(
                [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
            ).round()
        elif self.config.use_lu_lambdas:
            # Uniform spacing in log-sigma, then mapped back to timesteps.
            lambdas = np.flip(log_sigmas.copy())
            lambdas = self._convert_to_lu(
                in_lambdas=lambdas, num_inference_steps=num_inference_steps
            )
            sigmas = np.exp(lambdas)
            timesteps = np.array(
                [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
            ).round()
        else:
            # Linearly interpolate the training sigmas at the chosen timesteps.
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        # Final sigma appended after the last timestep: either the smallest
        # training sigma or exactly 0.
        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self .config .final_sigmas_type }"
            )
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(
            device=device, dtype=torch.int64
        )
        self.num_inference_steps = len(timesteps)
        self.model_outputs = [None] * self.config.solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None
        # Sigmas stay on CPU (presumably to avoid device syncs when indexing them per step).
        self.sigmas = self.sigmas.to("cpu")
|
|
| def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: |
| dtype = sample.dtype |
| batch_size, channels, *remaining_dims = sample.shape |
| if dtype not in (torch.float32, torch.float64): |
| sample = sample.float() |
| sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) |
| abs_sample = sample.abs() |
| s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) |
| s = torch.clamp(s, min=1, max=self.config.sample_max_value) |
| s = s.unsqueeze(1) |
| sample = torch.clamp(sample, -s, s) / s |
| sample = sample.reshape(batch_size, channels, *remaining_dims) |
| sample = sample.to(dtype) |
| return sample |
|
|
| def _sigma_to_t(self, sigma, log_sigmas): |
| log_sigma = np.log(np.maximum(sigma, 1e-10)) |
| dists = log_sigma - log_sigmas[:, np.newaxis] |
| low_idx = ( |
| np.cumsum(dists >= 0, axis=0) |
| .argmax(axis=0) |
| .clip(max=log_sigmas.shape[0] - 2) |
| ) |
| high_idx = low_idx + 1 |
| low = log_sigmas[low_idx] |
| high = log_sigmas[high_idx] |
| w = (low - log_sigma) / (low - high) |
| w = np.clip(w, 0, 1) |
| t = (1 - w) * low_idx + w * high_idx |
| t = t.reshape(sigma.shape) |
| return t |
|
|
| def _sigma_to_alpha_sigma_t(self, sigma): |
| alpha_t = 1 / (sigma**2 + 1) ** 0.5 |
| sigma_t = sigma * alpha_t |
| return (alpha_t, sigma_t) |
|
|
| def _convert_to_karras( |
| self, in_sigmas: torch.Tensor, num_inference_steps |
| ) -> torch.Tensor: |
| if hasattr(self.config, "sigma_min"): |
| sigma_min = self.config.sigma_min |
| else: |
| sigma_min = None |
| if hasattr(self.config, "sigma_max"): |
| sigma_max = self.config.sigma_max |
| else: |
| sigma_max = None |
| sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() |
| sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() |
| rho = 7.0 |
| ramp = np.linspace(0, 1, num_inference_steps) |
| min_inv_rho = sigma_min ** (1 / rho) |
| max_inv_rho = sigma_max ** (1 / rho) |
| sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho |
| return sigmas |
|
|
| def _convert_to_lu( |
| self, in_lambdas: torch.Tensor, num_inference_steps |
| ) -> torch.Tensor: |
| lambda_min: float = in_lambdas[-1].item() |
| lambda_max: float = in_lambdas[0].item() |
| rho = 1.0 |
| ramp = np.linspace(0, 1, num_inference_steps) |
| min_inv_rho = lambda_min ** (1 / rho) |
| max_inv_rho = lambda_max ** (1 / rho) |
| lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho |
| return lambdas |
|
|
    def convert_model_output(
        self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs
    ) -> torch.Tensor:
        """Convert the raw network output into the quantity the solver update uses.

        For ``dpmsolver++`` / ``sde-dpmsolver++`` the predicted clean sample
        (x0) is returned. For ``dpmsolver`` / ``sde-dpmsolver`` the predicted
        noise (epsilon) is returned. ``config.prediction_type`` says how to read
        ``model_output`` (``epsilon``, ``sample`` or ``v_prediction``), using
        the sigma at ``self.step_index``. Optional dynamic thresholding is
        applied to the x0 estimate.

        Args:
            model_output: raw output of the denoising network.
            *args: legacy positional ``timestep`` (deprecated, ignored) and
                ``sample``.
            sample: current noisy sample; required, either as a keyword or as
                the second positional argument.

        Raises:
            ValueError: ``sample`` is missing or ``prediction_type`` is unsupported.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyward argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
            # Data-prediction solvers: return an x0 estimate.
            if self.config.prediction_type == "epsilon":
                if self.config.variance_type in ["learned", "learned_range"]:
                    # Keep only the first 3 channels along dim 1 (presumably the
                    # mean prediction of an image model; confirm for audio latents).
                    model_output = model_output[:, :3]
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = alpha_t * sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self .config .prediction_type } must be one of `epsilon`, `sample`, or `v_prediction` for the DPMSolverMultistepScheduler."
                )
            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)
            return x0_pred
        elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            # Noise-prediction solvers: return an epsilon estimate.
            if self.config.prediction_type == "epsilon":
                if self.config.variance_type in ["learned", "learned_range"]:
                    epsilon = model_output[:, :3]
                else:
                    epsilon = model_output
            elif self.config.prediction_type == "sample":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = (sample - alpha_t * model_output) / sigma_t
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = alpha_t * model_output + sigma_t * sample
            else:
                raise ValueError(
                    f"prediction_type given as {self .config .prediction_type } must be one of `epsilon`, `sample`, or `v_prediction` for the DPMSolverMultistepScheduler."
                )
            if self.config.thresholding:
                # Threshold in x0 space, then convert back to epsilon.
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * epsilon) / alpha_t
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = (sample - alpha_t * x0_pred) / sigma_t
            return epsilon
|
|
    def dpm_solver_first_order_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Take one first-order step from ``sigmas[step_index]`` to ``sigmas[step_index + 1]``.

        Args:
            model_output: converted output from ``convert_model_output`` (x0
                for ``++`` variants, epsilon otherwise).
            *args: legacy positional ``timestep``, ``prev_timestep`` (both
                deprecated and ignored) and ``sample``.
            sample: current sample; required (keyword or third positional).
            noise: required for the ``sde-*`` algorithm types (asserted).

        Returns:
            The sample at the next sigma. ``x_t`` would be unbound for any other
            ``algorithm_type``; ``__init__`` restricts the accepted values.

        Raises:
            ValueError: ``sample`` is missing.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError(" missing `sample` as a required keyward argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        # t = target (next) sigma, s = current sigma.
        sigma_t, sigma_s = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
        )
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        # lambda = log(alpha) - log(sigma); h is the step size in lambda.
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
        h = lambda_t - lambda_s
        if self.config.algorithm_type == "dpmsolver++":
            x_t = (
                sigma_t / sigma_s * sample
                - alpha_t * (torch.exp(-h) - 1.0) * model_output
            )
        elif self.config.algorithm_type == "dpmsolver":
            x_t = (
                alpha_t / alpha_s * sample
                - sigma_t * (torch.exp(h) - 1.0) * model_output
            )
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            x_t = (
                sigma_t / sigma_s * torch.exp(-h) * sample
                + alpha_t * (1 - torch.exp(-2.0 * h)) * model_output
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            x_t = (
                alpha_t / alpha_s * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
        return x_t
|
|
    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Take one second-order multistep update using the last two model outputs.

        Uses sigmas at ``step_index - 1`` (s1), ``step_index`` (s0) and
        ``step_index + 1`` (t). ``model_output_list[-1]`` is the output at s0 and
        ``model_output_list[-2]`` the output at s1. ``config.solver_type``
        (``midpoint`` / ``heun``) selects the correction term.

        Args:
            model_output_list: converted model outputs, most recent last.
            *args: legacy positional ``timestep_list``, ``prev_timestep`` (both
                deprecated and ignored) and ``sample``.
            sample: current sample; required (keyword or third positional).
            noise: required for the ``sde-*`` algorithm types (asserted).

        Returns:
            The sample at the next sigma. ``x_t`` would be unbound for algorithm
            or solver types not handled below; ``__init__`` restricts both.

        Raises:
            ValueError: ``sample`` is missing.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError(" missing `sample` as a required keyward argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
        m0, m1 = (model_output_list[-1], model_output_list[-2])
        # h: upcoming step in lambda; h_0: previous step; r0 = their ratio.
        h, h_0 = (lambda_t - lambda_s0, lambda_s0 - lambda_s1)
        r0 = h_0 / h
        # D0: latest output; D1: difference (m0 - m1) rescaled by 1 / r0.
        D0, D1 = (m0, 1.0 / r0 * (m0 - m1))
        if self.config.algorithm_type == "dpmsolver++":
            if self.config.solver_type == "midpoint":
                x_t = (
                    sigma_t / sigma_s0 * sample
                    - alpha_t * (torch.exp(-h) - 1.0) * D0
                    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    sigma_t / sigma_s0 * sample
                    - alpha_t * (torch.exp(-h) - 1.0) * D0
                    + alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0) * D1
                )
        elif self.config.algorithm_type == "dpmsolver":
            if self.config.solver_type == "midpoint":
                x_t = (
                    alpha_t / alpha_s0 * sample
                    - sigma_t * (torch.exp(h) - 1.0) * D0
                    - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    alpha_t / alpha_s0 * sample
                    - sigma_t * (torch.exp(h) - 1.0) * D0
                    - sigma_t * ((torch.exp(h) - 1.0) / h - 1.0) * D1
                )
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    sigma_t / sigma_s0 * torch.exp(-h) * sample
                    + alpha_t * (1 - torch.exp(-2.0 * h)) * D0
                    + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    sigma_t / sigma_s0 * torch.exp(-h) * sample
                    + alpha_t * (1 - torch.exp(-2.0 * h)) * D0
                    + alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0) * D1
                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    alpha_t / alpha_s0 * sample
                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - sigma_t * (torch.exp(h) - 1.0) * D1
                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    alpha_t / alpha_s0 * sample
                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
                )
        return x_t
|
|
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """One step of the third-order multistep DPM-Solver.

    Moves ``sample`` from ``self.sigmas[self.step_index]`` to
    ``self.sigmas[self.step_index + 1]`` using the three most recent converted
    model outputs (newest last in ``model_output_list``).

    Args:
        model_output_list: History of converted model outputs; the last three
            entries are used.
        *args: Legacy positional ``timestep_list``, ``prev_timestep`` and
            ``sample``. The first two are deprecated and ignored.
        sample: Current sample.

    Returns:
        The sample at the next sigma.

    Raises:
        ValueError: If ``sample`` is missing, or if ``self.config.algorithm_type``
            has no third-order formula here (only ``"dpmsolver++"`` and
            ``"dpmsolver"`` are implemented).
    """
    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError(" missing`sample` as a required keyward argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    algorithm_type = self.config.algorithm_type
    if algorithm_type not in ("dpmsolver++", "dpmsolver"):
        # Without this check the SDE variants fell through both branches below
        # and failed with an UnboundLocalError on `x_t`.
        raise ValueError(
            f"Third-order multistep update is not implemented for algorithm_type={algorithm_type!r}; "
            "use solver_order <= 2 for this algorithm type."
        )
    sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
        self.sigmas[self.step_index - 2],
    )
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
    alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
    # lambda = log(alpha / sigma) for each of the four points.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
    lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
    m0, m1, m2 = (
        model_output_list[-1],
        model_output_list[-2],
        model_output_list[-3],
    )
    h, h_0, h_1 = (
        lambda_t - lambda_s0,
        lambda_s0 - lambda_s1,
        lambda_s1 - lambda_s2,
    )
    r0, r1 = (h_0 / h, h_1 / h)
    # Finite-difference estimates of the first and second derivatives.
    D0 = m0
    D1_0, D1_1 = (1.0 / r0 * (m0 - m1), 1.0 / r1 * (m1 - m2))
    D1 = D1_0 + r0 / (r0 + r1) * (D1_0 - D1_1)
    D2 = 1.0 / (r0 + r1) * (D1_0 - D1_1)
    if algorithm_type == "dpmsolver++":
        x_t = (
            sigma_t / sigma_s0 * sample
            - alpha_t * (torch.exp(-h) - 1.0) * D0
            + alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0) * D1
            - alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5) * D2
        )
    else:  # "dpmsolver"
        x_t = (
            alpha_t / alpha_s0 * sample
            - sigma_t * (torch.exp(h) - 1.0) * D0
            - sigma_t * ((torch.exp(h) - 1.0) / h - 1.0) * D1
            - sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5) * D2
        )
    return x_t
|
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Locate ``timestep`` in a timestep schedule.

    Args:
        timestep: Value to look up.
        schedule_timesteps: Schedule to search; defaults to ``self.timesteps``.

    Returns:
        The index of the match. When the value appears more than once, the
        second occurrence is used. When it is absent, the last index of
        ``self.timesteps`` is returned.
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (schedule == timestep).nonzero()
    match_count = len(matches)
    if match_count == 0:
        return len(self.timesteps) - 1
    chosen = 1 if match_count > 1 else 0
    return matches[chosen].item()
|
|
| def _init_step_index(self, timestep): |
| if self.begin_index is None: |
| if isinstance(timestep, torch.Tensor): |
| timestep = timestep.to(self.timesteps.device) |
| self._step_index = self.index_for_timestep(timestep) |
| else: |
| self._step_index = self._begin_index |
|
|
def step(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """Advance ``sample`` by one multistep DPM-Solver step.

    Args:
        model_output: Raw model output for ``sample`` at ``timestep``.
        timestep: Current timestep; only used to initialise the step index on
            the first call.
        sample: Current noisy sample.
        generator: RNG forwarded to ``randn_tensor`` for the SDE variants.
        variance_noise: Pre-drawn noise for the SDE variants; drawn with
            ``generator`` when omitted.
        return_dict: Return ``SchedulerOutput`` instead of a 1-tuple.

    Returns:
        ``SchedulerOutput(prev_sample=...)`` or ``(prev_sample,)``. The sample is
        cast to the dtype of the converted model output.

    Raises:
        ValueError: If ``set_timesteps`` has not been called.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )
    if self.step_index is None:
        self._init_step_index(timestep)

    cfg = self.config
    n_steps = len(self.timesteps)
    current = self.step_index
    # Use lower-order updates on the last step(s) when the config asks for it.
    final_first_order = current == n_steps - 1 and (
        cfg.euler_at_final
        or (cfg.lower_order_final and n_steps < 15)
        or cfg.final_sigmas_type == "zero"
    )
    final_second_order = (
        current == n_steps - 2 and cfg.lower_order_final and n_steps < 15
    )

    model_output = self.convert_model_output(model_output, sample=sample)
    # Slide the history window left and store the newest output last.
    for slot in range(cfg.solver_order - 1):
        self.model_outputs[slot] = self.model_outputs[slot + 1]
    self.model_outputs[-1] = model_output

    # The solver arithmetic runs in float32.
    sample = sample.to(torch.float32)
    if cfg.algorithm_type not in ["sde-dpmsolver", "sde-dpmsolver++"]:
        noise = None
    elif variance_noise is None:
        noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=torch.float32,
        )
    else:
        noise = variance_noise.to(device=model_output.device, dtype=torch.float32)

    if cfg.solver_order == 1 or self.lower_order_nums < 1 or final_first_order:
        prev_sample = self.dpm_solver_first_order_update(
            model_output, sample=sample, noise=noise
        )
    elif cfg.solver_order == 2 or self.lower_order_nums < 2 or final_second_order:
        prev_sample = self.multistep_dpm_solver_second_order_update(
            self.model_outputs, sample=sample, noise=noise
        )
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(
            self.model_outputs, sample=sample
        )

    if self.lower_order_nums < cfg.solver_order:
        self.lower_order_nums += 1

    prev_sample = prev_sample.to(model_output.dtype)
    self._step_index += 1
    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
|
|
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """Forward-diffuse samples: ``alpha_t[t] * original_samples + sigma_t[t] * noise``.

    ``self.alpha_t`` and ``self.sigma_t`` are moved to the samples' device and
    dtype and indexed by ``timesteps``. They are then flattened and
    right-padded with singleton dims so they broadcast against
    ``original_samples``.
    """
    device, dtype = original_samples.device, original_samples.dtype
    rank = len(original_samples.shape)
    alpha_table = self.alpha_t.to(device).to(dtype)
    sigma_table = self.sigma_t.to(device).to(dtype)
    index = timesteps.to(device)

    def _broadcastable(table):
        coeff = table[index].flatten()
        while len(coeff.shape) < rank:
            coeff = coeff.unsqueeze(-1)
        return coeff

    return _broadcastable(alpha_table) * original_samples + _broadcastable(sigma_table) * noise
|
|
def get_velocity(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """Velocity target: ``alpha_t[t] * noise - sigma_t[t] * original_samples``.

    The coefficient tables are gathered and broadcast the same way as in
    ``add_noise``.
    """
    device, dtype = original_samples.device, original_samples.dtype
    rank = len(original_samples.shape)
    alpha_table = self.alpha_t.to(device).to(dtype)
    sigma_table = self.sigma_t.to(device).to(dtype)
    index = timesteps.to(device)

    def _broadcastable(table):
        coeff = table[index].flatten()
        while len(coeff.shape) < rank:
            coeff = coeff.unsqueeze(-1)
        return coeff

    return _broadcastable(alpha_table) * noise - _broadcastable(sigma_table) * original_samples
|
|
| def __len__(self): |
| return self.config.num_train_timesteps |
|
|
|
|
# NOTE(review): this bare string is a no-op expression, not a module docstring,
# because it is not the first statement of the file.
'\nProcessor class for QWEN3Vox models.\n'
import os
import json
import warnings
from typing import List, Optional, Union, Dict, Any
import numpy as np
import torch
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.utils import logging


# Logger for the processor section; rebinds the module-level `logger` name.
logger = logging.get_logger(__name__)
|
|
|
|
class AudioNormalizer:
    """Normalize a waveform to a target RMS level (dBFS), then prevent clipping.

    Args:
        target_dB_FS: Target RMS level in dB relative to full scale.
        eps: Small constant added to denominators to avoid division by zero.
    """

    def __init__(self, target_dB_FS: float = -25, eps: float = 1e-06):
        self.target_dB_FS = target_dB_FS
        self.eps = eps

    def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
        """Scale ``audio`` so its RMS equals ``10 ** (target_dB_FS / 20)``.

        Returns:
            ``(normalized_audio, rms, scalar)``: the scaled audio, the input's
            RMS, and the gain that was applied.
        """
        if audio.size == 0:
            # np.mean of an empty array is NaN (plus a RuntimeWarning); an
            # empty signal has zero energy.
            rms = 0.0
        elif np.issubdtype(audio.dtype, np.integer):
            # Squaring integer PCM (e.g. int16) in its own dtype overflows.
            rms = np.sqrt(np.mean(audio.astype(np.float64) ** 2))
        else:
            rms = np.sqrt(np.mean(audio**2))
        scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
        normalized_audio = audio * scalar
        return (normalized_audio, rms, scalar)

    def avoid_clipping(
        self, audio: np.ndarray, scalar: Optional[float] = None
    ) -> tuple:
        """Divide ``audio`` by ``scalar`` so the peak does not exceed 1.0.

        When ``scalar`` is None it is derived from the data: ``peak + eps`` if
        the peak magnitude exceeds 1.0, otherwise 1.0.

        Returns:
            ``(scaled_audio, scalar)``.
        """
        if scalar is None:
            # np.max raises on an empty array; an empty signal cannot clip.
            max_val = np.max(np.abs(audio)) if audio.size > 0 else 0.0
            scalar = max_val + self.eps if max_val > 1.0 else 1.0
        return (audio / scalar, scalar)

    def __call__(self, audio: np.ndarray) -> np.ndarray:
        """Apply loudness normalization followed by clipping protection."""
        audio, _, _ = self.tailor_dB_FS(audio)
        audio, _ = self.avoid_clipping(audio)
        return audio
|
|
|
|
class QWEN3VoxTokenizerProcessor(FeatureExtractionMixin):
    """Load, down-mix and (optionally) loudness-normalize raw audio for QWEN3Vox.

    Args:
        sampling_rate: Sampling rate the model expects. Files decoded through
            librosa are resampled to it; arrays passed in directly are not
            (a mismatching ``sampling_rate`` argument only logs a warning).
        normalize_audio: Apply ``AudioNormalizer`` by default.
        target_dB_FS: Target loudness for the normalizer.
        eps: Numerical-stability constant for the normalizer.
        **kwargs: Forwarded to ``FeatureExtractionMixin.__init__``.
    """

    model_input_names = ["input_features"]

    def __init__(
        self,
        sampling_rate: int = 22050,
        normalize_audio: bool = True,
        target_dB_FS: float = -25,
        eps: float = 1e-06,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.sampling_rate = sampling_rate
        self.normalize_audio = normalize_audio
        if self.normalize_audio:
            self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps)
        else:
            self.normalizer = None
        # Serialized by `to_dict`; also used to build a normalizer on demand.
        self.feature_extractor_dict = {
            "sampling_rate": sampling_rate,
            "normalize_audio": normalize_audio,
            "target_dB_FS": target_dB_FS,
            "eps": eps,
        }

    def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
        """Return a 1-D waveform.

        1-D input is returned as-is. For 2-D input, a size-2 axis is averaged
        (axis 0 is checked first) and a size-1 axis is squeezed.

        Raises:
            ValueError: For any other shape.
        """
        if audio.ndim == 1:
            return audio
        if audio.ndim == 2:
            if audio.shape[0] == 2:
                return np.mean(audio, axis=0)
            if audio.shape[1] == 2:
                return np.mean(audio, axis=1)
            if audio.shape[0] == 1:
                return audio.squeeze(0)
            if audio.shape[1] == 1:
                return audio.squeeze(1)
            raise ValueError(f"Unexpected audio shape: {audio.shape}")
        raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}")

    def _process_single_audio(
        self,
        audio: Union[np.ndarray, List[float]],
        normalize: Optional[bool] = None,
    ) -> np.ndarray:
        """Convert one waveform to mono float32 and optionally normalize it.

        Args:
            audio: Array or list of samples.
            normalize: Override for ``self.normalize_audio``; None uses the
                instance setting.
        """
        if not isinstance(audio, np.ndarray):
            audio = np.array(audio, dtype=np.float32)
        else:
            audio = audio.astype(np.float32)
        audio = self._ensure_mono(audio)
        if normalize is None:
            normalize = self.normalize_audio
        if normalize:
            normalizer = self.normalizer
            if normalizer is None:
                # Built with normalize_audio=False but normalization was
                # requested: create a normalizer from the stored settings
                # instead of silently skipping it.
                normalizer = AudioNormalizer(
                    target_dB_FS=self.feature_extractor_dict["target_dB_FS"],
                    eps=self.feature_extractor_dict["eps"],
                )
            audio = normalizer(audio)
        return audio

    def __call__(
        self,
        audio: Union[
            str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str]
        ] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[str] = None,
        **kwargs,
    ):
        """Process one waveform/path or a batch of them.

        Args:
            audio: A path, an array/list of samples, or a list of paths/arrays.
            sampling_rate: Sampling rate of ``audio``; only compared against
                ``self.sampling_rate`` to emit a warning.
            return_tensors: ``"pt"`` or ``"np"`` to get a stacked
                ``(batch, 1, samples)`` tensor/array; otherwise a single array
                or a list of arrays is returned.

        Returns:
            ``{"audio": features}``.

        Raises:
            ValueError: If ``audio`` is None or an empty list.
        """
        if audio is None:
            raise ValueError("Audio input is required")
        if sampling_rate is not None and sampling_rate != self.sampling_rate:
            logger.warning(
                f"Input sampling rate ({sampling_rate}) differs from expected sampling rate ({self.sampling_rate}). Please resample your audio."
            )
        if isinstance(audio, str):
            audio = self._load_audio_from_path(audio)
            is_batched = False
        elif isinstance(audio, list):
            if len(audio) == 0:
                raise ValueError("Empty audio list provided")
            if all(isinstance(item, str) for item in audio):
                audio = [self._load_audio_from_path(path) for path in audio]
                is_batched = True
            else:
                is_batched = isinstance(audio[0], (np.ndarray, list))
        else:
            is_batched = False
        if is_batched:
            processed_audio = [self._process_single_audio(a) for a in audio]
        else:
            processed_audio = [self._process_single_audio(audio)]
        if return_tensors == "pt":
            if len(processed_audio) == 1:
                input_features = (
                    torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1)
                )
            else:
                input_features = torch.stack(
                    [torch.from_numpy(a) for a in processed_audio]
                ).unsqueeze(1)
        elif return_tensors == "np":
            if len(processed_audio) == 1:
                input_features = processed_audio[0][np.newaxis, np.newaxis, :]
            else:
                input_features = np.stack(processed_audio)[:, np.newaxis, :]
        else:
            input_features = (
                processed_audio[0] if len(processed_audio) == 1 else processed_audio
            )
        return {"audio": input_features}

    def _load_audio_from_path(self, audio_path: str) -> np.ndarray:
        """Load a waveform from an audio file, a ``.pt`` tensor or a ``.npy`` array.

        Raises:
            ValueError: For unsupported file extensions.
        """
        file_ext = os.path.splitext(audio_path)[1].lower()
        if file_ext in [".wav", ".mp3", ".flac", ".m4a", ".ogg"]:
            import librosa

            audio_array, _ = librosa.load(audio_path, sr=self.sampling_rate, mono=True)
            return audio_array
        if file_ext == ".pt":
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # trusted files.
            loaded = torch.load(audio_path, map_location="cpu")
            if isinstance(loaded, torch.Tensor):
                # .float() lets bf16/half tensors convert to numpy.
                audio_array = loaded.detach().float().numpy()
            else:
                audio_array = np.asarray(loaded)
            return np.squeeze(audio_array).astype(np.float32)
        if file_ext == ".npy":
            return np.load(audio_path).astype(np.float32)
        raise ValueError(
            f"Unsupported file format: {file_ext}. Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy"
        )

    def preprocess_audio(
        self,
        audio_path_or_array: Union[str, np.ndarray],
        normalize: Optional[bool] = None,
    ) -> np.ndarray:
        """Load (if a path) and process one waveform.

        Args:
            audio_path_or_array: File path or sample array.
            normalize: Override for ``self.normalize_audio`` for this call only;
                instance state is not modified.
        """
        if isinstance(audio_path_or_array, str):
            audio_array = self._load_audio_from_path(audio_path_or_array)
        else:
            audio_array = np.array(audio_path_or_array, dtype=np.float32)
        return self._process_single_audio(audio_array, normalize=normalize)

    def to_dict(self) -> Dict[str, Any]:
        """Return a copy of the constructor settings (safe for callers to mutate)."""
        return dict(self.feature_extractor_dict)

    def save_audio(
        self,
        audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
        output_path: str = "output.wav",
        sampling_rate: Optional[int] = None,
        normalize: bool = False,
        batch_prefix: str = "audio_",
    ):
        """Write audio to WAV file(s) with soundfile.

        A list, or an array/tensor with 3+ dims and batch size > 1, is written
        to ``output_path`` treated as a directory, as
        ``{batch_prefix}{i}.wav``. Anything else is written to
        ``output_path`` itself.

        Returns:
            List of written file paths.

        Raises:
            ImportError: If soundfile is not installed.
            ValueError: For unsupported ``audio`` types.
        """
        if sampling_rate is None:
            sampling_rate = self.sampling_rate
        try:
            import soundfile as sf
        except ImportError as err:
            raise ImportError(
                "soundfile is required to save audio files. Install it with: pip install soundfile"
            ) from err
        if isinstance(audio, torch.Tensor):
            audio_np = audio.float().detach().cpu().numpy()
        elif isinstance(audio, np.ndarray):
            audio_np = audio
        elif isinstance(audio, list):
            # Convert item by item so mixed tensor/ndarray lists also work.
            audio_np = [
                a.float().detach().cpu().numpy() if isinstance(a, torch.Tensor) else np.asarray(a)
                for a in audio
            ]
        else:
            raise ValueError(f"Unsupported audio type: {type(audio)}")
        saved_paths = []
        if isinstance(audio_np, list):
            output_dir = output_path
            os.makedirs(output_dir, exist_ok=True)
            for i, audio_item in enumerate(audio_np):
                audio_item = self._prepare_audio_for_save(audio_item, normalize)
                file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
                sf.write(file_path, audio_item, sampling_rate)
                saved_paths.append(file_path)
        elif len(audio_np.shape) >= 3:
            batch_size = audio_np.shape[0]
            if batch_size > 1:
                output_dir = output_path
                os.makedirs(output_dir, exist_ok=True)
                for i in range(batch_size):
                    single_audio = audio_np[i]
                    if len(single_audio.shape) > 1 and single_audio.shape[0] == 1:
                        single_audio = single_audio.squeeze(0)
                    single_audio = self._prepare_audio_for_save(single_audio, normalize)
                    file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
                    sf.write(file_path, single_audio, sampling_rate)
                    saved_paths.append(file_path)
            else:
                audio_item = audio_np.squeeze()
                audio_item = self._prepare_audio_for_save(audio_item, normalize)
                sf.write(output_path, audio_item, sampling_rate)
                saved_paths.append(output_path)
        else:
            audio_item = self._prepare_audio_for_save(audio_np, normalize)
            sf.write(output_path, audio_item, sampling_rate)
            saved_paths.append(output_path)
        return saved_paths

    def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
        """Drop a leading size-1 axis and optionally peak-normalize to [-1, 1]."""
        if len(audio.shape) > 1 and audio.shape[0] == 1:
            audio = audio.squeeze(0)
        if normalize:
            max_val = np.abs(audio).max()
            if max_val > 0:
                audio = audio / max_val
        return audio
|
|
|
|
# Public names of the processor section.
# NOTE(review): `__all__` is reassigned again further down in this file, so if
# it is executed as a single module only the last assignment takes effect.
__all__ = [
    'QWEN3VoxTokenizerProcessor',
    "AudioNormalizer",
]
| import math |
| import torch |
|
|
|
|
class UniformSampler:
    """Draw integer diffusion timesteps uniformly from ``[0, timesteps)``.

    Args:
        timesteps: Exclusive upper bound of the sampled values.
    """

    def __init__(self, timesteps=1000):
        self.timesteps = timesteps

    def sample(self, batch_size, device):
        """Return a ``(batch_size,)`` tensor of random timesteps on ``device``."""
        shape = (batch_size,)
        return torch.randint(0, self.timesteps, shape, device=device)
|
|
|
|
class LogitNormalSampler:
    """Draw integer timesteps with weights from a Gaussian over ``logit(t)``.

    ``t`` is a uniform grid of ``timesteps`` points on ``[0, 1]``. Index ``i``
    is weighted by the N(m, s) density evaluated at ``logit(t_i)``. No
    ``1 / (t (1 - t))`` Jacobian factor is applied. The endpoints ``t = 0`` and
    ``t = 1`` get zero weight.

    Args:
        timesteps: Number of grid points (and range of sampled indices).
        m: Mean of the Gaussian in logit space.
        s: Standard deviation of the Gaussian in logit space.
    """

    def __init__(self, timesteps=1000, m=0, s=1):
        self.timesteps = timesteps
        grid = torch.linspace(0, 1, timesteps)
        logits = torch.log(grid / (1 - grid))
        gaussian = torch.exp(-0.5 * (logits - m) ** 2 / s**2)
        self.prob = gaussian / (s * math.sqrt(2 * math.pi))

    def sample(self, batch_size, device):
        """Draw ``batch_size`` indices (with replacement) and move them to ``device``."""
        draws = torch.multinomial(self.prob, batch_size, replacement=True)
        return draws.to(device)
|
|
|
|
# NOTE(review): bare string expression mid-file; it is not a docstring.
' QWEN3Vox Streaming model configuration'
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config


# Logger for the streaming-config section; rebinds the module-level `logger` name.
logger = logging.get_logger(__name__)
|
|
|
|
class QWEN3VoxStreamingConfig(PretrainedConfig):
    """Composite configuration for the QWEN3Vox streaming model.

    Holds three sub-configs: the acoustic tokenizer, the Qwen2 decoder and
    the diffusion head. It also stores ``tts_backbone_num_hidden_layers``.
    Each sub-config argument may be ``None`` (defaults), a ``dict`` (which is
    instantiated), or an already-built config instance.

    Raises:
        ValueError: If a decoder dict has a ``model_type`` other than ``"qwen2"``.
        TypeError: If a sub-config argument has an unsupported type.
    """

    model_type = 'vibevoice_streaming'
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": QWEN3VoxAcousticTokenizerConfig,
        "decoder_config": Qwen2Config,
        "diffusion_head_config": QWEN3VoxDiffusionHeadConfig,
    }
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        decoder_config=None,
        diffusion_head_config=None,
        tts_backbone_num_hidden_layers=20,
        **kwargs,
    ):
        kwargs["_attn_implementation_autoset"] = False

        if acoustic_tokenizer_config is None:
            self.acoustic_tokenizer_config = self.sub_configs[
                "acoustic_tokenizer_config"
            ]()
        elif isinstance(acoustic_tokenizer_config, dict):
            # Build from a copy so the caller's dict is not mutated.
            acoustic_kwargs = dict(
                acoustic_tokenizer_config, model_type='vibevoice_acoustic_tokenizer'
            )
            self.acoustic_tokenizer_config = self.sub_configs[
                "acoustic_tokenizer_config"
            ](**acoustic_kwargs)
        elif isinstance(acoustic_tokenizer_config, QWEN3VoxAcousticTokenizerConfig):
            self.acoustic_tokenizer_config = acoustic_tokenizer_config
        else:
            raise TypeError(
                "acoustic_tokenizer_config must be None, a dict or a QWEN3VoxAcousticTokenizerConfig, "
                f"got {type(acoustic_tokenizer_config).__name__}"
            )

        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            if decoder_config.get("model_type", "") == "qwen2":
                self.decoder_config = Qwen2Config(**decoder_config)
            else:
                raise ValueError(
                    f"Unsupported decoder model type: {decoder_config.get('model_type', '')}"
                )
        elif isinstance(decoder_config, Qwen2Config):
            self.decoder_config = decoder_config
        else:
            raise TypeError(
                "decoder_config must be None, a dict or a Qwen2Config, "
                f"got {type(decoder_config).__name__}"
            )

        if diffusion_head_config is None:
            self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
        elif isinstance(diffusion_head_config, dict):
            # Build from a copy so the caller's dict is not mutated.
            head_kwargs = dict(diffusion_head_config, model_type='vibevoice_diffusion_head')
            self.diffusion_head_config = self.sub_configs["diffusion_head_config"](
                **head_kwargs
            )
        elif isinstance(diffusion_head_config, QWEN3VoxDiffusionHeadConfig):
            self.diffusion_head_config = diffusion_head_config
        else:
            raise TypeError(
                "diffusion_head_config must be None, a dict or a QWEN3VoxDiffusionHeadConfig, "
                f"got {type(diffusion_head_config).__name__}"
            )

        # Latent size of the acoustic VAE; falls back to 64 if the tokenizer
        # config has no `vae_dim` attribute.
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, "vae_dim", 64)
        self.tts_backbone_num_hidden_layers = tts_backbone_num_hidden_layers
        super().__init__(**kwargs)
|
|
|
|
# Public names of the streaming-config section.
# NOTE(review): rebinds `__all__`; earlier assignments in this file are replaced.
__all__ = [
    'QWEN3VoxStreamingConfig'
]
| import math |
| from typing import Optional, Tuple, Union |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers.models.auto import AutoModel |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.activations import ACT2FN |
| from transformers.utils import logging |
|
|
# Logger for the diffusion-head section; rebinds the module-level `logger` name.
logger = logging.get_logger(__name__)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Normalization is computed in float32 and cast back to the input dtype.
    The optional learnable ``weight`` (shape ``(dim,)``, initialised to ones)
    is then applied.

    Args:
        dim: Size of the normalized (last) dimension.
        eps: Added to the mean square before the reciprocal square root.
        elementwise_affine: Whether to create ``weight``.
        memory_efficient: Accepted for API compatibility; unused.

    NOTE(review): another class named ``RMSNorm`` (default ``eps=1e-05``, with a
    ``weight_shape`` argument) is defined later in this file. If the file runs
    as a single module, that later definition replaces this one at module scope.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-06,
        elementwise_affine=True,
        memory_efficient=False,
    ):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not elementwise_affine:
            self.register_parameter("weight", None)
        else:
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        normed = self._norm(x.float()).type_as(x)
        return normed if self.weight is None else normed * self.weight

    def extra_repr(self) -> str:
        return f"dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}"
|
|
|
|
def modulate(x, shift, scale):
    """Apply adaLN-style modulation: scale ``x`` by ``1 + scale``, then add ``shift``."""
    gain = 1 + scale
    return x * gain + shift
|
|
|
|
class TimestepEmbedder(nn.Module):
    """Embed diffusion timesteps into ``hidden_size``-dim vectors.

    Sinusoidal features of size ``frequency_embedding_size`` are passed through
    a bias-free ``Linear -> SiLU -> Linear`` MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=False),
            ACT2FN["silu"],
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Build sinusoidal timestep features.

        Args:
            t: 1-D tensor of (possibly fractional) timesteps.
            dim: Output feature size; an odd size is zero-padded by one column.
            max_period: Sets the lowest frequency, ``1 / max_period``.

        Returns:
            ``(len(t), dim)`` tensor ``[cos(t * f), sin(t * f)]``. A floating-point
            ``t`` keeps its dtype. An integer ``t`` yields float32 features,
            because casting sin/cos values to an integer dtype would truncate
            them to 0/±1.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        if torch.is_floating_point(t):
            embedding = embedding.to(t.dtype)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
|
|
|
|
class FeedForwardNetwork(nn.Module):
    """Gated feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free.
    """

    def __init__(self, embed_dim, ffn_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.gate_proj = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, embed_dim, bias=False)
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        activated_gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(activated_gate * self.up_proj(x))
|
|
|
|
class HeadLayer(nn.Module):
    """Residual feed-forward layer modulated by a conditioning vector.

    The conditioning vector ``c`` is projected to ``(shift, scale, gate)``. The
    RMS-normalized input is modulated by ``shift``/``scale``, passed through
    the FFN, gated, and added back to the input.
    """

    def __init__(self, embed_dim, ffn_dim, cond_dim, norm_eps=1e-05):
        super().__init__()
        self.embed_dim = embed_dim
        self.cond_dim = cond_dim
        self.ffn_dim = ffn_dim
        self.ffn = FeedForwardNetwork(embed_dim, ffn_dim)
        self.norm = RMSNorm(embed_dim, eps=norm_eps)
        self.adaLN_modulation = nn.Sequential(
            ACT2FN["silu"], nn.Linear(cond_dim, 3 * embed_dim, bias=False)
        )

    def forward(self, x, c):
        shift, scale, gate = self.adaLN_modulation(c).chunk(3, dim=-1)
        modulated = modulate(self.norm(x), shift, scale)
        return x + gate * self.ffn(modulated)
|
|
|
|
class FinalLayer(nn.Module):
    """Output layer: conditioned modulation of a non-affine RMSNorm, then a linear projection."""

    def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-05):
        super().__init__()
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, output_size, bias=False)
        self.adaLN_modulation = nn.Sequential(
            ACT2FN["silu"], nn.Linear(cond_size, 2 * hidden_size, bias=False)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        return self.linear(modulate(self.norm_final(x), shift, scale))
|
|
|
|
class QWEN3VoxDiffusionHead(PreTrainedModel):
    """Diffusion prediction head.

    Projects noisy latents to ``hidden_size`` and builds a conditioning vector
    as ``cond_proj(condition) + t_embedder(timesteps)``. The hidden state then
    goes through ``head_layers`` modulated FFN layers and a final projection
    back to ``latent_size``.
    """

    config_class = QWEN3VoxDiffusionHeadConfig
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        hidden = config.hidden_size
        latent_size = config.latent_size
        self.cond_dim = hidden
        self.noisy_images_proj = nn.Linear(latent_size, hidden, bias=False)
        self.cond_proj = nn.Linear(hidden, self.cond_dim, bias=False)
        self.t_embedder = TimestepEmbedder(self.cond_dim)
        ffn_dim = int(hidden * config.head_ffn_ratio)
        self.layers = nn.ModuleList(
            HeadLayer(
                embed_dim=hidden,
                ffn_dim=ffn_dim,
                cond_dim=self.cond_dim,
                norm_eps=config.rms_norm_eps,
            )
            for _ in range(config.head_layers)
        )
        self.final_layer = FinalLayer(
            hidden_size=hidden,
            output_size=latent_size,
            cond_size=self.cond_dim,
            norm_eps=config.rms_norm_eps,
        )
        self.initialize_weights()

    def initialize_weights(self):
        """Small-normal init for the timestep MLP; zero init for all adaLN projections and the output projection."""
        for mlp_index in (0, 2):
            nn.init.normal_(self.t_embedder.mlp[mlp_index].weight, std=0.02)
        for layer in self.layers:
            nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)

    def forward(self, noisy_images, timesteps, condition):
        hidden_states = self.noisy_images_proj(noisy_images)
        timestep_emb = self.t_embedder(timesteps)
        cond = self.cond_proj(condition) + timestep_emb
        for layer in self.layers:
            hidden_states = layer(hidden_states, cond)
        return self.final_layer(hidden_states, cond)
|
|
|
|
# Register the head with transformers' AutoModel under its config class.
AutoModel.register(QWEN3VoxDiffusionHeadConfig, QWEN3VoxDiffusionHead)
# Public names of the diffusion-head section (rebinds any earlier `__all__`).
__all__ = [
    'QWEN3VoxDiffusionHead'
]
| import math |
| import typing as tp |
| from functools import partial |
| from dataclasses import dataclass, field |
| from typing import Dict, List, Optional, Tuple, Union |
| import copy |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers.models.auto import AutoModel |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.utils import logging |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.activations import ACT2FN |
|
|
# Logger for the acoustic-tokenizer section; rebinds the module-level `logger`.
logger = logging.get_logger(__name__)
import os


# Optional fused RMSNorm kernel from NVIDIA apex. It is used only when apex is
# importable AND the OPTIMIZE_FOR_SPEED environment variable is a non-zero
# integer (default "0", i.e. disabled).
try:
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
except ImportError:
    APEX_AVAILABLE = False
    logger.warning("APEX FusedRMSNorm not available, using native implementation")
else:
    _optimize_for_speed = os.getenv("OPTIMIZE_FOR_SPEED", "0")
    try:
        APEX_AVAILABLE = int(_optimize_for_speed) != 0
    except ValueError:
        # Previously a non-integer value crashed the module import.
        logger.warning(
            "Ignoring non-integer OPTIMIZE_FOR_SPEED=%r; treating it as 0",
            _optimize_for_speed,
        )
        APEX_AVAILABLE = False
    if APEX_AVAILABLE:
        logger.info("APEX FusedRMSNorm is available and will be used for optimization")
    else:
        logger.warning(
            "APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0"
        )
|
|
|
|
class ConvLayerNorm(nn.LayerNorm):
    """LayerNorm over dim 1 of a conv-style ``(batch, channels, time)`` tensor.

    The input is transposed so channels are last, normalized in float32, cast
    back to the input dtype and transposed back. Works with
    ``elementwise_affine=False`` (and ``bias=False``), where ``weight``/``bias``
    are None.
    """

    def __init__(
        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
    ):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = x.transpose(1, 2)
        # weight / bias are None when the affine transform is disabled.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        x = nn.functional.layer_norm(
            x.float(),
            self.normalized_shape,
            weight,
            bias,
            self.eps,
        ).type_as(x)
        return x.transpose(1, 2)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Computed in float32, cast back to the input dtype, then scaled by an
    optional learnable ``weight`` initialised to ones.

    Args:
        dim: Size of the normalized (last) dimension.
        eps: Added to the mean square before the reciprocal square root.
        elementwise_affine: Whether to create ``weight``.
        weight_shape: Shape of ``weight``; defaults to ``(dim,)``.
    """

    def __init__(
        self, dim: int, eps: float = 1e-05, elementwise_affine=True, weight_shape=None
    ):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not elementwise_affine:
            self.register_parameter("weight", None)
        else:
            shape = (dim,) if weight_shape is None else weight_shape
            self.weight = nn.Parameter(torch.ones(shape))

    def _norm(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        normed = self._norm(x.float()).type_as(x)
        return normed if self.weight is None else normed * self.weight

    def extra_repr(self) -> str:
        return f"dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}"
|
|
|
|
class ConvRMSNorm(RMSNorm):
    """RMSNorm over dim 1 of a conv-style ``(batch, channels, time)`` tensor.

    Uses apex's ``fused_rms_norm_affine`` when the module-level
    ``APEX_AVAILABLE`` flag is set and the norm has a weight. Otherwise it uses
    the native float32 path.
    """

    def __init__(
        self, dim: int, eps: float = 1e-05, elementwise_affine=True, weight_shape=None
    ):
        super().__init__(dim, eps, elementwise_affine, weight_shape)

    def forward(self, x):
        channels_last = x.transpose(1, 2)
        if APEX_AVAILABLE and self.elementwise_affine:
            normed = fused_rms_norm_affine(
                channels_last, self.weight, self.weight.shape, self.eps
            )
        else:
            normed = self._norm(channels_last.float()).type_as(channels_last)
            if self.weight is not None:
                normed = normed * self.weight
        return normed.transpose(1, 2)
|
|
|
|
# Normalization names accepted by `apply_parametrization_norm` and `get_norm_module`.
CONV_NORMALIZATIONS = frozenset(
    {
        "none",
        "weight_norm",
        "spectral_norm",
        "time_layer_norm",
        "layer_norm",
        "time_group_norm",
    }
)
|
|
|
|
def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
    """Wrap ``module`` with a weight parametrization.

    Args:
        module: Module to wrap (typically a conv layer).
        norm: ``"weight_norm"`` or ``"spectral_norm"`` apply the corresponding
            ``torch.nn.utils`` function; any other name in
            ``CONV_NORMALIZATIONS`` returns ``module`` unchanged.

    Raises:
        ValueError: If ``norm`` is not in ``CONV_NORMALIZATIONS``. This is an
            explicit check rather than an ``assert``, so it is not stripped
            under ``python -O``.
    """
    if norm not in CONV_NORMALIZATIONS:
        raise ValueError(
            f"Unknown conv normalization {norm!r}; expected one of {sorted(CONV_NORMALIZATIONS)}"
        )
    if norm == "weight_norm":
        # NOTE(review): legacy torch.nn.utils.weight_norm is kept; presumably
        # existing checkpoints depend on its parameter names. Confirm before
        # migrating to torch.nn.utils.parametrizations.weight_norm.
        return nn.utils.weight_norm(module)
    if norm == "spectral_norm":
        return nn.utils.spectral_norm(module)
    return module
|
|
|
|
def get_norm_module(
    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
) -> nn.Module:
    """Build the output-normalization module paired with a conv layer.

    Args:
        module: The conv layer whose ``out_channels`` sizes the norm.
        causal: Whether the layer is used causally; ``time_group_norm`` rejects it.
        norm: ``"layer_norm"`` -> ``ConvLayerNorm``; ``"time_group_norm"`` ->
            ``GroupNorm(1, out_channels)``; any other accepted name ->
            ``nn.Identity()``.
        **norm_kwargs: Forwarded to the norm constructor.

    Raises:
        ValueError: Unknown ``norm``, or ``time_group_norm`` with ``causal=True``.
        TypeError: ``module`` is not a convolution for the norms that need
            ``out_channels``.
    """
    if norm not in CONV_NORMALIZATIONS:
        raise ValueError(
            f"Unknown conv normalization {norm!r}; expected one of {sorted(CONV_NORMALIZATIONS)}"
        )
    if norm == "layer_norm":
        if not isinstance(module, nn.modules.conv._ConvNd):
            raise TypeError(f"layer_norm requires a conv module, got {type(module).__name__}")
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    if norm == "time_group_norm":
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        if not isinstance(module, nn.modules.conv._ConvNd):
            raise TypeError(f"time_group_norm requires a conv module, got {type(module).__name__}")
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    # NOTE(review): "time_layer_norm" is accepted but falls through to Identity here.
    return nn.Identity()
|
|
|
|
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Return how many samples to append to ``x`` so the last conv window is complete.

    The number of output frames is rounded up. The result is the length that
    makes exactly that many full windows fit, minus the current length
    ``x.shape[-1]``.
    """
    current_length = x.shape[-1]
    frame_count = (current_length - kernel_size + padding_total) / stride + 1
    target_length = (math.ceil(frame_count) - 1) * stride + (kernel_size - padding_total)
    return target_length - current_length
|
|
|
|
def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zero",
    value: float = 0.0,
):
    """Pad the last dimension of ``x``, tolerating reflect padding on short inputs.

    Args:
        x: Tensor to pad along its last dimension.
        paddings: ``(left, right)`` non-negative pad sizes.
        mode: ``"zero"`` (the default, an alias for ``"constant"``) or any
            mode accepted by ``torch.nn.functional.pad``.
        value: Fill value for constant/zero padding.

    For ``"reflect"``, if the input is not longer than the largest pad,
    zeros are appended first so the reflection is valid, and they are
    trimmed off afterwards.
    """
    if mode == "zero":
        # F.pad has no "zero" mode, so the previous default raised at runtime.
        mode = "constant"
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    return F.pad(x, paddings, mode, value)
|
|
|
|
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip ``paddings = (left, right)`` samples from the ends of the last dimension.

    Asserts that both pads are non-negative and together fit within ``x``.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    assert left + right <= x.shape[-1]
    stop = x.shape[-1] - right
    return x[..., left:stop]
|
|
|
|
class NormConv1d(nn.Module):
    """``nn.Conv1d`` with an optional weight parametrization and output normalization.

    Positional and extra keyword arguments go to ``nn.Conv1d``. ``norm``
    selects both the parametrization (``apply_parametrization_norm``) and the
    output norm (``get_norm_module``). ``norm_kwargs`` is forwarded to the
    output norm.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
        **kwargs,
    ):
        super().__init__()
        # None instead of a shared mutable `{}` default.
        if norm_kwargs is None:
            norm_kwargs = {}
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x
|
|
|
|
class NormConvTranspose1d(nn.Module):
    """``nn.ConvTranspose1d`` with an optional weight parametrization and a follow-up norm.

    Positional and extra keyword arguments go to ``nn.ConvTranspose1d``.
    ``norm`` selects the parametrization (``apply_parametrization_norm``) and
    the post-conv module (``get_norm_module``, which also receives
    ``norm_kwargs``).
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        raw_convtr = nn.ConvTranspose1d(*args, **kwargs)
        self.convtr = apply_parametrization_norm(raw_convtr, norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Run the transposed convolution, then the configured norm module."""
        return self.norm(self.convtr(x))
|
|
|
|
class QWEN3VoxTokenizerStreamingCache:
    """Store of per-layer, per-sample tensors used for streaming convolutions.

    Entries live in ``self.cache`` keyed by ``(layer_id, sample_index)``. Each
    value is the detached slice ``states[i]`` for one sample.
    """

    def __init__(self):
        self.cache = {}

    def get(
        self, layer_id: str, sample_indices: torch.Tensor
    ) -> Optional[torch.Tensor]:
        """Stack the cached states of ``layer_id`` for ``sample_indices`` along dim 0.

        Returns ``None`` if any requested sample has no entry, or if no sample
        is requested. When the states are at least 2-D, shorter ones are
        left-padded with zeros on the last axis to the longest length before
        stacking.
        """
        states = []
        for idx in sample_indices.tolist():
            key = (layer_id, idx)
            if key not in self.cache:
                return None
            states.append(self.cache[key])
        if not states:
            # Nothing requested: report "no cached state" instead of letting
            # torch.stack fail on an empty list.
            return None
        if states[0].dim() >= 2:
            max_length = max(state.shape[-1] for state in states)
            states = [
                F.pad(state, (max_length - state.shape[-1], 0), mode="constant", value=0)
                if state.shape[-1] < max_length
                else state
                for state in states
            ]
        return torch.stack(states, dim=0)

    def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
        """Store ``states[i].detach()`` under ``(layer_id, sample_indices[i])``."""
        for i, idx in enumerate(sample_indices.tolist()):
            self.cache[(layer_id, idx)] = states[i].detach()

    def set_to_zero(self, sample_indices: torch.Tensor):
        """Replace every layer's entry for the given samples with zeros of the same shape."""
        targets = set(sample_indices.tolist())  # built once instead of per cache key
        for key, tensor in list(self.cache.items()):
            if key[1] in targets:
                self.cache[key] = torch.zeros_like(tensor)

    def clear(
        self,
        layer_id: Optional[str] = None,
        sample_indices: Optional[torch.Tensor] = None,
    ):
        """Remove cache entries.

        The filters combine:
        - neither argument: drop everything;
        - only ``layer_id``: drop that layer for all samples;
        - only ``sample_indices``: drop those samples for all layers;
        - both: drop the given samples of that layer.
        """
        if layer_id is None and sample_indices is None:
            self.cache.clear()
            return
        if sample_indices is None:
            stale = [key for key in self.cache if key[0] == layer_id]
        elif layer_id is None:
            # Previously this combination fell through every branch and did nothing.
            wanted = set(sample_indices.tolist())
            stale = [key for key in self.cache if key[1] in wanted]
        else:
            stale = [(layer_id, idx) for idx in sample_indices.tolist()]
        for key in stale:
            self.cache.pop(key, None)
|
|
|
|
class SConv1d(nn.Module):
    """1-D convolution that manages its own padding and supports causal streaming.

    The convolution is a :class:`NormConv1d`.

    Without a cache, the input is padded before convolving:
    - causal layers put ``padding_total`` on the left;
    - non-causal layers split it, with the larger half on the left;
    - both add ``get_extra_padding_for_conv1d`` samples on the right.

    With a cache, the layer must be causal. It prepends the stored context for
    each sample index (zeros on the first call) and convolves without padding.
    It then stores the trailing ``context_size`` samples of the concatenated
    input for the next call.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
    ):
        super().__init__()
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        # The same quantity serves as total padding (non-streaming) and as the
        # number of context samples carried between chunks (streaming).
        reach = (kernel_size - 1) * dilation - (stride - 1)
        self.context_size = reach
        self.padding_total = reach
        self._layer_id = None

    @property
    def layer_id(self):
        """Cache key prefix, created lazily from ``id(self)``."""
        if self._layer_id is None:
            self._layer_id = f"sconv1d_{id(self)}"
        return self._layer_id

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[QWEN3VoxTokenizerStreamingCache] = None,
        sample_indices: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        debug: bool = False,
    ) -> torch.Tensor:
        """Use the streaming path when ``use_cache`` is truthy and a cache is given."""
        batch_size, _channels, _steps = x.shape
        if use_cache and cache is not None:
            assert self.causal, "Streaming mode is only supported for causal convolutions"
            assert (
                sample_indices is not None
            ), "sample_indices must be provided for streaming mode"
            assert len(sample_indices) == batch_size, "sample_indices must match batch size"
            return self._forward_streaming(x, cache, sample_indices, debug)
        return self._forward_non_streaming(x, debug=debug)

    def _forward_streaming(
        self,
        x: torch.Tensor,
        cache: QWEN3VoxTokenizerStreamingCache,
        sample_indices: torch.Tensor,
        debug: bool = False,
    ) -> torch.Tensor:
        """Convolve ``[cached context | x]`` unpadded and refresh the cached context."""
        batch_size, channels, _steps = x.shape
        history = cache.get(self.layer_id, sample_indices)
        if history is None:
            if self.context_size > 0:
                history = torch.zeros(
                    batch_size, channels, self.context_size, device=x.device, dtype=x.dtype
                )
                if debug:
                    print(
                        f"[DEBUG] Initialized cache with shape: {history.shape}, context_size={self.context_size}"
                    )
            else:
                history = torch.zeros(batch_size, channels, 0, device=x.device, dtype=x.dtype)
                if debug:
                    print("[DEBUG] No context needed (kernel_size=stride)")
        combined = torch.cat([history, x], dim=2) if history.shape[2] > 0 else x
        if debug:
            print(
                f"[DEBUG] Input shape: {x.shape}, Cache shape: {history.shape}, Combined: {combined.shape}"
            )
        output = self.conv(combined)
        if debug:
            print(f"[DEBUG] Output shape: {output.shape}")
        if self.context_size > 0:
            # Keep at most the last ``context_size`` samples (all of them if shorter).
            keep_from = max(combined.shape[2] - self.context_size, 0)
            new_cache = combined[:, :, keep_from:]
            if debug:
                print(f"[DEBUG] New cache shape: {new_cache.shape}")
            cache.set(self.layer_id, sample_indices, new_cache)
        return output

    def _forward_non_streaming(
        self, x: torch.Tensor, debug: bool = False
    ) -> torch.Tensor:
        """Pad ``x`` (causal: left-only, else split) plus right alignment padding, then convolve."""
        _batch, _channels, _steps = x.shape
        padding_total = self.padding_total
        extra_padding = get_extra_padding_for_conv1d(
            x, self.kernel_size, self.stride, padding_total
        )
        if debug:
            print(
                f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}"
            )
        if self.causal:
            left, right = padding_total, extra_padding
        else:
            half = padding_total // 2
            left, right = padding_total - half, half + extra_padding
        # A "constant" pad fills with 0, which is already pad1d's default ``value``.
        x = pad1d(x, (left, right), mode=self.pad_mode)
        if debug:
            print(f"[DEBUG NON-STREAMING] After padding: {x.shape}")
        output = self.conv(x)
        if debug:
            print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}")
        return output
|
|
|
|
class SConvTranspose1d(nn.Module):
    """Transposed 1-D convolution that trims its own padding and can run in streaming mode.

    The convolution is a :class:`NormConvTranspose1d`. After it runs,
    ``padding_total = kernel_size - stride`` samples are trimmed:
    - causal: ``ceil(padding_total * trim_right_ratio)`` from the right and
      the rest from the left;
    - non-causal: split with the larger half on the left.

    In streaming mode, up to ``context_size = kernel_size - 1`` trailing input
    frames per sample are cached. The layer recomputes the transposed
    convolution over ``[cache | chunk]``. On the first chunk it returns the
    whole trimmed output; afterwards it returns the trailing
    ``T * stride`` samples for a chunk of ``T`` frames (or everything, if
    fewer are available).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        norm: str = "none",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        bias: bool = True,
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
            bias=bias,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Samples trimmed from the raw transposed-conv output.
        self.padding_total = kernel_size - stride
        # Maximum number of trailing input frames carried between streaming calls.
        self.context_size = kernel_size - 1
        self._layer_id = None

    @property
    def layer_id(self):
        """Cache key prefix, created lazily from ``id(self)``."""
        if self._layer_id is None:
            self._layer_id = f"sconvtr1d_{id(self)}"
        return self._layer_id

    def _trim_output(self, y: torch.Tensor) -> torch.Tensor:
        """Remove ``padding_total`` samples from the last axis of ``y`` (split as in the class doc)."""
        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
        else:
            padding_right = self.padding_total // 2
        padding_left = self.padding_total - padding_right
        if padding_left + padding_right > 0:
            y = unpad1d(y, (padding_left, padding_right))
        return y

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[QWEN3VoxTokenizerStreamingCache] = None,
        sample_indices: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        debug: bool = False,
    ) -> torch.Tensor:
        """Use the streaming path when ``use_cache`` is truthy and a cache is given."""
        B, C, T = x.shape
        if not use_cache or cache is None:
            return self._forward_non_streaming(x, debug=debug)
        assert (
            sample_indices is not None
        ), "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == B, "sample_indices must match batch size"
        return self._forward_streaming(x, cache, sample_indices, debug)

    def _forward_streaming(
        self,
        x: torch.Tensor,
        cache: QWEN3VoxTokenizerStreamingCache,
        sample_indices: torch.Tensor,
        debug: bool = False,
    ) -> torch.Tensor:
        """Recompute over ``[cache | x]`` and return only the samples produced for ``x``."""
        B, C, T = x.shape
        cached_input = cache.get(self.layer_id, sample_indices)
        if cached_input is None:
            cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
            if debug:
                print("[DEBUG] Initialized empty cache for transposed conv")
        full_input = torch.cat([cached_input, x], dim=2)
        if debug:
            print(
                f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}"
            )
        full_output = self.convtr(full_input)
        if debug:
            print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}")
        full_output = self._trim_output(full_output)
        if debug:
            print(f"[DEBUG] After unpadding: {full_output.shape}")
        if cached_input.shape[2] == 0:
            output = full_output
        else:
            expected_new_output = T * self.stride
            available = full_output.shape[2]
            if available >= expected_new_output:
                # Positive start index: ``[-0:]`` would return the whole tensor
                # for an empty chunk instead of an empty slice.
                output = full_output[:, :, available - expected_new_output:]
            else:
                output = full_output
        if debug:
            print(f"[DEBUG] Final streaming output shape: {output.shape}")
        # Same reasoning: ``[-0:]`` with context_size == 0 (kernel_size == 1)
        # would keep the whole input and let the cache grow on every call.
        keep_from = max(full_input.shape[2] - self.context_size, 0)
        new_cache = full_input[:, :, keep_from:]
        if debug:
            print(f"[DEBUG] New cache shape: {new_cache.shape}")
        cache.set(self.layer_id, sample_indices, new_cache)
        return output

    def _forward_non_streaming(
        self, x: torch.Tensor, debug: bool = False
    ) -> torch.Tensor:
        """Run the transposed convolution on the whole input and trim the padding."""
        if debug:
            print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}")
        y = self.convtr(x)
        if debug:
            print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}")
        y = self._trim_output(y)
        if debug:
            print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}")
        return y
|
|
|
|
class FFN(nn.Module):
    """Feed-forward layer over the last axis: Linear -> GELU -> Linear.

    Maps ``embed_dim -> ffn_dim -> embed_dim``; ``bias`` applies to both
    linear layers.
    """

    def __init__(self, embed_dim, ffn_dim, bias=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.linear1 = nn.Linear(embed_dim, ffn_dim, bias=bias)
        self.gelu = ACT2FN["gelu"]
        self.linear2 = nn.Linear(ffn_dim, embed_dim, bias=bias)

    def forward(self, x):
        """Expand, activate, and project back to ``embed_dim``."""
        return self.linear2(self.gelu(self.linear1(x)))
|
|
|
|
class Convlayer(nn.Module):
    """Thin wrapper that exposes an :class:`SConv1d` as ``self.conv``.

    ``forward`` calls the conv without any streaming cache.

    NOTE(review): the default ``pad_mode="zeros"`` is not one of
    ``torch.nn.functional.pad``'s modes. The conv's non-streaming path
    forwards ``pad_mode`` to ``pad1d``, so confirm that ``pad1d`` accepts it
    (callers in this section pass ``pad_mode`` explicitly).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_mode="zeros",
        norm="weight_norm",
        causal=True,
    ):
        super().__init__()
        conv_options = dict(
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            pad_mode=pad_mode,
            norm=norm,
            causal=causal,
        )
        self.conv = SConv1d(in_channels, out_channels, kernel_size, **conv_options)

    def forward(self, x):
        """Apply the wrapped convolution (non-streaming)."""
        return self.conv(x)
|
|
|
|
| class Block1D(nn.Module): |
|
|
| def __init__( |
| self, |
| dim, |
| kernel_size=7, |
| drop_path=0.0, |
| mixer_layer="conv", |
| layer_scale_init_value=1e-06, |
| **kwargs, |
| ): |
| super().__init__() |
| if kwargs.get("layernorm", "LN") == "LN": |
| self.norm = ConvLayerNorm(dim, eps=kwargs.get("eps", 1e-06)) |
| self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get("eps", 1e-06)) |
| elif kwargs.get("layernorm", "RMSNorm") == "RMSNorm": |
| self.norm = ConvRMSNorm(dim, eps=kwargs.get("eps", 1e-06)) |
| self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get("eps", 1e-06)) |
| if mixer_layer == "conv": |
| self.mixer = Convlayer( |
| dim, |
| dim, |
| groups=kwargs.get("groups", 1), |
| kernel_size=kernel_size, |
| pad_mode=kwargs.get("pad_mode", "reflect"), |
| norm=kwargs.get("norm", "none"), |
| causal=kwargs.get("causal", True), |
| bias=kwargs.get("bias", True), |
| ) |
| elif mixer_layer == "depthwise_conv": |
| self.mixer = Convlayer( |
| dim, |
| dim, |
| groups=dim, |
| kernel_size=kernel_size, |
| pad_mode=kwargs.get("pad_mode", "reflect"), |
| norm=kwargs.get("norm", "none"), |
| causal=kwargs.get("causal", True), |
| bias=kwargs.get("bias", True), |
| ) |
| else: |
| raise ValueError(f"Unsupported mixer layer: {mixer_layer }") |
| self.ffn = FFN( |
| dim, kwargs.get("ffn_expansion", 4) * dim, bias=kwargs.get("bias", False) |
| ) |
| self.drop_path = ( |
| nn.Identity() if drop_path <= 0.0 else nn.modules.DropPath(drop_path) |
| ) |
| if layer_scale_init_value > 0: |
| self.gamma = nn.Parameter( |
| layer_scale_init_value * torch.ones(dim), requires_grad=True |
| ) |
| self.ffn_gamma = nn.Parameter( |
| layer_scale_init_value * torch.ones(dim), requires_grad=True |
| ) |
| else: |
| self.gamma = None |
| self.ffn_gamma = None |
|
|
| def forward(self, x): |
| residual = x |
| x = self.norm(x) |
| x = self.mixer(x) |
| if self.gamma is not None: |
| x = x * self.gamma.unsqueeze(-1) |
| x = residual + self.drop_path(x) |
| residual = x |
| x = self.ffn_norm(x) |
| x = x.permute(0, 2, 1) |
| x = self.ffn(x) |
| x = x.permute(0, 2, 1) |
| if self.ffn_gamma is not None: |
| x = x * self.ffn_gamma.unsqueeze(-1) |
| x = residual + self.drop_path(x) |
| return x |
|
|
|
|
class TokenizerEncoder(nn.Module):
    """Convolutional encoder built from SConv1d downsamplers and Block1D stages.

    ``downsample_layers[0]`` is a stem conv (``channels -> n_filters``).
    ``downsample_layers[i + 1]`` is a strided SConv1d with
    ``stride = ratios[i]`` and ``kernel = 2 * ratios[i]`` that doubles the
    channel count; ``ratios`` is ``config.ratios`` reversed.

    ``stages[i]`` holds ``depths[i]`` Block1D modules of width
    ``n_filters * 2**i``, and runs right after ``downsample_layers[i]``. A
    final norm (or Identity) and a head SConv1d map to ``dimension`` channels.

    NOTE(review): this presumably expects ``len(depths) == len(ratios) + 1``;
    confirm against configs.
    """

    def __init__(self, config):
        super().__init__()
        self.channels = config.channels
        self.dimension = config.dimension
        self.n_filters = config.n_filters
        self.ratios = list(reversed(config.ratios))
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-06)
        elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        if layernorm == "RMSNorm":
            norm_type = partial(ConvRMSNorm, elementwise_affine=elementwise_affine)
        elif layernorm == "LN":
            norm_type = ConvLayerNorm
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        stem_conv = SConv1d(
            self.channels,
            self.n_filters,
            kernel_size,
            norm=norm,
            norm_kwargs=norm_params,
            causal=self.causal,
            pad_mode=pad_mode,
            bias=bias,
        )
        self.downsample_layers = nn.ModuleList([nn.Sequential(stem_conv)])
        for level, ratio in enumerate(self.ratios):
            width = self.n_filters * 2**level
            self.downsample_layers.append(
                nn.Sequential(
                    SConv1d(
                        width,
                        self.n_filters * 2 ** (level + 1),
                        kernel_size=ratio * 2,
                        stride=ratio,
                        causal=self.causal,
                        pad_mode=pad_mode,
                        norm=norm,
                        bias=bias,
                    )
                )
            )

        make_block = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )
        # Drop-path probability rises linearly across all blocks of all stages.
        drop_rates = torch.linspace(0, drop_path_rate, sum(self.depths)).tolist()
        self.stages = nn.ModuleList()
        start = 0
        for level, depth in enumerate(self.depths):
            width = self.n_filters * 2**level
            blocks = [
                make_block(dim=width, drop_path=drop_rates[start + j])
                for j in range(depth)
            ]
            self.stages.append(nn.Sequential(*blocks))
            start += depth

        self.norm = (
            nn.Identity() if disable_last_norm else norm_type(width, eps=layernorm_eps)
        )
        self.head = SConv1d(
            width,
            self.dimension,
            kernel_size=last_kernel_size,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
        )

    @staticmethod
    def _run_block(block, x, stream):
        """Apply one block, passing the streaming arguments to its SConv1d mixer.

        The Block1D residual math is reproduced inline (without drop_path) so
        that the mixer conv can receive ``stream``. Any other block is
        simply called.
        """
        mixer = getattr(block, "mixer", None)
        if not isinstance(getattr(mixer, "conv", None), SConv1d):
            return block(x)
        residual = x
        y = mixer.conv(block.norm(x), **stream)
        if block.gamma is not None:
            y = y * block.gamma.unsqueeze(-1)
        x = residual + y
        y = block.ffn(block.ffn_norm(x).permute(0, 2, 1)).permute(0, 2, 1)
        if block.ffn_gamma is not None:
            y = y * block.ffn_gamma.unsqueeze(-1)
        return x + y

    def forward_features(
        self, x, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Run stem/downsamplers and stages; return the final-normed features."""
        stream = dict(
            cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug
        )
        for level in range(len(self.depths)):
            for layer in self.downsample_layers[level]:
                x = layer(x, **stream) if isinstance(layer, SConv1d) else layer(x)
            for block in self.stages[level]:
                x = self._run_block(block, x, stream)
        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Encode ``x``: features followed by the head projection to ``dimension``."""
        stream = dict(
            cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug
        )
        return self.head(self.forward_features(x, **stream), **stream)
|
|
|
|
class TokenizerDecoder(nn.Module):
    """Convolutional decoder built from transposed-conv upsamplers and Block1D stages.

    ``upsample_layers[0]`` is a stem SConv1d
    (``dimension -> n_filters * 2**(len(depths) - 1)``).
    ``upsample_layers[i + 1]`` is an SConvTranspose1d with
    ``stride = ratios[i]`` and ``kernel = 2 * ratios[i]`` that halves the
    channel count.

    ``stages[i]`` holds ``depths[i]`` Block1D modules of width
    ``n_filters * 2**(len(depths) - 1 - i)``, and runs right after
    ``upsample_layers[i]``. A final norm (or Identity) and a head SConv1d map
    to ``channels`` outputs.
    """

    def __init__(self, config):
        super().__init__()
        self.dimension = config.dimension
        self.channels = config.channels
        self.n_filters = config.n_filters
        self.ratios = config.ratios
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-06)
        trim_right_ratio = getattr(config, "trim_right_ratio", 1.0)
        elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        if layernorm == "RMSNorm":
            norm_type = partial(ConvRMSNorm, elementwise_affine=elementwise_affine)
        elif layernorm == "LN":
            norm_type = ConvLayerNorm
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        top = len(self.depths) - 1  # exponent of the widest (first) stage
        stem_conv = SConv1d(
            self.dimension,
            self.n_filters * 2**top,
            kernel_size,
            norm=norm,
            norm_kwargs=norm_params,
            causal=self.causal,
            pad_mode=pad_mode,
            bias=bias,
        )
        self.upsample_layers = nn.ModuleList([nn.Sequential(stem_conv)])
        for step, ratio in enumerate(self.ratios):
            width = self.n_filters * 2 ** (top - step)
            self.upsample_layers.append(
                nn.Sequential(
                    SConvTranspose1d(
                        width,
                        self.n_filters * 2 ** (top - step - 1),
                        kernel_size=ratio * 2,
                        stride=ratio,
                        norm=norm,
                        norm_kwargs=norm_params,
                        bias=bias,
                        causal=self.causal,
                        trim_right_ratio=trim_right_ratio,
                    )
                )
            )

        make_block = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )
        # Drop-path probability rises linearly across all blocks of all stages.
        drop_rates = torch.linspace(0, drop_path_rate, sum(self.depths)).tolist()
        self.stages = nn.ModuleList()
        start = 0
        for level, depth in enumerate(self.depths):
            width = self.n_filters * 2 ** (top - level)
            blocks = [
                make_block(dim=width, drop_path=drop_rates[start + j])
                for j in range(depth)
            ]
            self.stages.append(nn.Sequential(*blocks))
            start += depth

        self.norm = (
            nn.Identity() if disable_last_norm else norm_type(width, eps=layernorm_eps)
        )
        self.head = SConv1d(
            width,
            self.channels,
            kernel_size=last_kernel_size,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
        )

    @staticmethod
    def _run_block(block, x, stream):
        """Apply one block, passing the streaming arguments to its SConv1d mixer.

        The Block1D residual math is reproduced inline (without drop_path) so
        that the mixer conv can receive ``stream``. Any other block is
        simply called.
        """
        mixer = getattr(block, "mixer", None)
        if not isinstance(getattr(mixer, "conv", None), SConv1d):
            return block(x)
        residual = x
        y = mixer.conv(block.norm(x), **stream)
        if block.gamma is not None:
            y = y * block.gamma.unsqueeze(-1)
        x = residual + y
        y = block.ffn(block.ffn_norm(x).permute(0, 2, 1)).permute(0, 2, 1)
        if block.ffn_gamma is not None:
            y = y * block.ffn_gamma.unsqueeze(-1)
        return x + y

    def forward_features(
        self, x, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Run stem/upsamplers and stages; return the final-normed features."""
        stream = dict(
            cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug
        )
        for level in range(len(self.depths)):
            for layer in self.upsample_layers[level]:
                if isinstance(layer, (SConv1d, SConvTranspose1d)):
                    x = layer(x, **stream)
                else:
                    x = layer(x)
            for block in self.stages[level]:
                x = self._run_block(block, x, stream)
        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Decode ``x``: features followed by the head projection to ``channels``."""
        stream = dict(
            cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug
        )
        return self.head(self.forward_features(x, **stream), **stream)
|
|
|
|
@dataclass
class QWEN3VoxTokenizerEncoderOutput:
    """Encoder latents plus an optional noise scale.

    Attributes:
        mean: Encoder latents.
        std: Scalar or tensor noise scale; ``None`` when no noise is attached.
    """

    mean: torch.Tensor
    std: Optional[Union[float, torch.Tensor]] = None

    def sample(self, dist_type="fix"):
        """Return a ``(latents, scale)`` pair according to ``dist_type``.

        - ``"fix"``: ``mean + std * N(0, 1)``, returned with ``std``.
        - ``"gaussian"``: draw one scale per batch item as
          ``N(0, 1) * std / 0.8``, broadcast it over the remaining axes, and
          return ``mean + scale * N(0, 1)`` with that scale.
        - anything else: ``(mean, std)`` unchanged.
        """
        if dist_type == "fix":
            noisy = self.mean + self.std * torch.randn_like(self.mean)
            return (noisy, self.std)
        if dist_type != "gaussian":
            return (self.mean, self.std)
        per_item = torch.randn(
            self.mean.size(0), device=self.mean.device, dtype=self.mean.dtype
        ) * (self.std / 0.8)
        while per_item.dim() < self.mean.dim():
            per_item = per_item[..., None]
        return (self.mean + per_item * torch.randn_like(self.mean), per_item)

    def kl(self):
        """Element-wise squared distance of ``mean`` from zero (no reduction)."""
        return F.mse_loss(self.mean, torch.zeros_like(self.mean), reduction="none")

    def mode(self):
        """Return the latent means."""
        return self.mean
|
|
|
|
class QWEN3VoxAcousticTokenizerModel(PreTrainedModel):
    """Acoustic tokenizer: TokenizerEncoder -> latent sampling -> TokenizerDecoder.

    - ``encode`` returns the latents as ``mean`` (axes 1/2 swapped relative
      to the encoder output) with the ``fix_std`` buffer as ``std``.
    - ``sampling`` adds noise ("fix" or "gaussian").
    - ``decode`` maps latents back to audio.
    """

    config_class = QWEN3VoxAcousticTokenizerConfig
    base_model_prefix = 'vibevoice_acoustic_tokenizer'
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]

    def __init__(self, config):
        super().__init__(config)
        # Non-persistent buffer: follows the module's device/dtype moves but is
        # not written to the state_dict.
        self.register_buffer("fix_std", torch.tensor(config.fix_std), persistent=False)
        self.std_dist_type = getattr(config, "std_dist_type", "fix")
        encoder_depths = self._parse_depths(config.encoder_depths)
        if config.decoder_depths is None:
            # Default decoder layout mirrors the encoder.
            decoder_depths = list(reversed(encoder_depths))
        else:
            # Honour an explicit layout given either as "3-3-8" or as a
            # sequence; sequences used to be silently replaced by the mirror.
            decoder_depths = list(self._parse_depths(config.decoder_depths))
        encoder_config = self._build_sub_config(
            config,
            n_filters=config.encoder_n_filters,
            ratios=config.encoder_ratios,
            depths=encoder_depths,
        )
        decoder_config = self._build_sub_config(
            config,
            n_filters=config.decoder_n_filters,
            ratios=config.decoder_ratios,
            depths=decoder_depths,
        )
        self.encoder = TokenizerEncoder(encoder_config)
        self.decoder = TokenizerDecoder(decoder_config)
        self.apply(self._init_weights)

    @staticmethod
    def _parse_depths(depths):
        """Turn a dash-separated string like ``"3-3-8"`` into ints; return other values as-is."""
        if isinstance(depths, str):
            return [int(d) for d in depths.split("-")]
        return depths

    @staticmethod
    def _build_sub_config(config, n_filters, ratios, depths):
        """Deep-copy ``config`` and set the attribute names the encoder/decoder modules read."""
        sub_config = copy.deepcopy(config)
        sub_config.dimension = config.vae_dim
        sub_config.n_filters = n_filters
        sub_config.ratios = ratios
        sub_config.depths = depths
        sub_config.norm = config.conv_norm
        sub_config.pad_mode = config.pad_mode
        sub_config.bias = config.conv_bias
        sub_config.layernorm_eps = config.layernorm_eps
        sub_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        sub_config.mixer_layer = config.mixer_layer
        sub_config.layer_scale_init_value = config.layer_scale_init_value
        sub_config.disable_last_norm = config.disable_last_norm
        return sub_config

    def _init_weights(self, module):
        """Normal(0, weight_init_value) for Linear/Conv1d weights, zero biases, unit LayerNorm."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(
        self, audio, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Encode ``audio``; latents are returned with axes 1 and 2 swapped."""
        latents = self.encoder(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        return QWEN3VoxTokenizerEncoderOutput(
            mean=latents.permute(0, 2, 1), std=self.fix_std
        )

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """Sample latents with ``dist_type`` (defaults to ``self.std_dist_type``).

        Raises:
            ValueError: if the type is neither "fix" nor "gaussian".
        """
        dist_type = dist_type or self.std_dist_type
        if dist_type in ("fix", "gaussian"):
            return encoder_output.sample(dist_type=dist_type)
        raise ValueError(
            f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'"
        )

    @torch.no_grad()
    def decode(
        self, latents, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Decode latents to audio.

        Latents whose axis 1 equals ``config.vae_dim`` are passed through
        unchanged; anything else has axes 1 and 2 swapped first.
        NOTE(review): this heuristic is ambiguous when the time axis length
        equals ``vae_dim``.
        """
        if latents.shape[1] != self.config.vae_dim:
            latents = latents.permute(0, 2, 1)
        return self.decoder(
            latents,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )

    def forward(
        self, audio, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Encode, sample, and decode; return ``(reconstructed, sampled_latents)``.

        NOTE(review): encode/sampling/decode run under ``torch.no_grad()``, so
        this does not build an autograd graph; confirm that is intended.
        """
        encoder_output = self.encode(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        sampled_latents, _ = self.sampling(encoder_output)
        reconstructed = self.decode(
            sampled_latents,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        return (reconstructed, sampled_latents)
|
|
|
|
class QWEN3VoxSemanticTokenizerModel(PreTrainedModel):
    """Encoder-only semantic tokenizer: audio -> latent means, with no decoder and no sampling noise."""

    config_class = QWEN3VoxSemanticTokenizerConfig
    base_model_prefix = 'vibevoice_semantic_tokenizer'
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder"]

    def __init__(self, config):
        super().__init__(config)
        depths = config.encoder_depths
        if isinstance(depths, str):
            depths = [int(part) for part in depths.split("-")]
        encoder_config = copy.deepcopy(config)
        # Attribute names expected by TokenizerEncoder, derived from this config.
        overrides = {
            "dimension": config.vae_dim,
            "n_filters": config.encoder_n_filters,
            "ratios": config.encoder_ratios,
            "depths": depths,
            "norm": config.conv_norm,
            "pad_mode": config.pad_mode,
            "bias": config.conv_bias,
            "layernorm_eps": config.layernorm_eps,
            "layernorm_elementwise_affine": config.layernorm_elementwise_affine,
            "mixer_layer": config.mixer_layer,
            "layer_scale_init_value": config.layer_scale_init_value,
            "disable_last_norm": config.disable_last_norm,
        }
        for attr_name, attr_value in overrides.items():
            setattr(encoder_config, attr_name, attr_value)
        self.encoder = TokenizerEncoder(encoder_config)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, weight_init_value) for Linear/Conv1d weights, zero biases, unit LayerNorm."""
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(
        self, audio, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Encode ``audio``; latents are returned with axes 1 and 2 swapped and no std."""
        hidden = self.encoder(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        return QWEN3VoxTokenizerEncoderOutput(mean=hidden.permute(0, 2, 1))

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """Return ``(mean, std)`` without noise; ``dist_type`` is ignored."""
        return encoder_output.sample(dist_type="none")

    def forward(
        self, audio, cache=None, sample_indices=None, use_cache=False, debug=False
    ):
        """Encode ``audio`` and return ``(None, latents)``; there is no reconstruction."""
        encoded = self.encode(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        latents = self.sampling(encoded, dist_type="none")[0]
        return (None, latents)
|
|
|
|
# Map each tokenizer config class to its model class in the transformers
# AutoModel registry (import-time side effect).
AutoModel.register(QWEN3VoxAcousticTokenizerConfig, QWEN3VoxAcousticTokenizerModel)
AutoModel.register(QWEN3VoxSemanticTokenizerConfig, QWEN3VoxSemanticTokenizerModel)
# Public names of the tokenizer section.
# NOTE(review): this file appears to concatenate several modules (another
# module docstring and imports follow), so a later ``__all__`` assignment
# would replace this one; confirm which export list is intended.
__all__ = [
    'QWEN3VoxTokenizerStreamingCache',
    'QWEN3VoxAcousticTokenizerModel',
    'QWEN3VoxSemanticTokenizerModel',
]
| '\nProcessor class for QWEN3Vox ASR models.\n' |
| import os |
| import json |
| import math |
| import warnings |
| from typing import List, Optional, Union, Dict, Any, Tuple |
| import numpy as np |
| import torch |
| from transformers.tokenization_utils_base import BatchEncoding |
| from transformers.utils import TensorType, logging |
|
|
# Module logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
# System-turn text that QWEN3VoxASRProcessor renders through the tokenizer's
# chat template at the start of every ASR prompt.
SYSTEM_PROMPT = "You are a helpful assistant that transcribes audio input into text output in JSON format."
|
|
|
|
class QWEN3VoxASRProcessor:
    """Prepare audio plus chat-formatted prompts for QWEN3Vox speech recognition.

    Each clip is turned into a mono ``float32`` waveform at ``target_sample_rate``
    (optionally loudness-normalized by ``AudioNormalizer``). In the user turn it
    is represented by ``ceil(num_samples / speech_tok_compress_ratio)``
    speech-pad tokens wrapped in speech-start / speech-end tokens.
    ``acoustic_input_mask`` is 1 wherever ``input_ids`` holds the speech-pad id.
    """

    # Maps keys the model may emit in its JSON transcription to normalized
    # names. Order matters: if both an alias pair is present ("Start time" and
    # "Start"), the later entry in this mapping wins.
    _TRANSCRIPTION_KEY_MAPPING = {
        "Start time": "start_time",
        "Start": "start_time",
        "End time": "end_time",
        "End": "end_time",
        "Speaker ID": "speaker_id",
        "Speaker": "speaker_id",
        "Content": "text",
    }

    def __init__(
        self,
        tokenizer=None,
        audio_processor=None,
        speech_tok_compress_ratio=320,
        target_sample_rate=22050,
        normalize_audio=True,
        **kwargs,
    ):
        """Create the processor.

        Args:
            tokenizer: Text tokenizer. Required in practice:
                ``_cache_special_tokens`` queries it immediately.
            audio_processor: Optional audio processor. When ``None``, a
                ``QWEN3VoxTokenizerProcessor`` is built from
                ``target_sample_rate`` and ``normalize_audio``.
            speech_tok_compress_ratio: Audio samples per speech token.
            target_sample_rate: Sample rate all audio is converted to.
            normalize_audio: Whether to apply ``AudioNormalizer`` to waveforms.
            **kwargs: Accepted and ignored.
        """
        self.tokenizer = tokenizer
        if audio_processor is None:
            audio_processor = QWEN3VoxTokenizerProcessor(
                sampling_rate=target_sample_rate, normalize_audio=normalize_audio
            )
        self.audio_processor = audio_processor
        # NOTE(review): this default (320) differs from the 3200 fallback used
        # by ``from_pretrained`` and by QWEN3VoxProcessor — confirm intended.
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.target_sample_rate = target_sample_rate
        self.normalize_audio = normalize_audio
        self.audio_normalizer = AudioNormalizer() if normalize_audio else None
        self._cache_special_tokens()

    def _cache_special_tokens(self):
        """Resolve and cache the special-token ids used to build prompts.

        Each id is taken from the tokenizer attribute of the same name when it
        exists, and otherwise looked up from its literal token string.
        """
        tokenizer = self.tokenizer
        if hasattr(tokenizer, "speech_start_id"):
            self.speech_start_id = tokenizer.speech_start_id
        else:
            self.speech_start_id = tokenizer.convert_tokens_to_ids("<|speech_start|>")
        if hasattr(tokenizer, "speech_end_id"):
            self.speech_end_id = tokenizer.speech_end_id
        else:
            self.speech_end_id = tokenizer.convert_tokens_to_ids("<|speech_end|>")
        if hasattr(tokenizer, "speech_pad_id"):
            self.speech_pad_id = tokenizer.speech_pad_id
        else:
            self.speech_pad_id = tokenizer.convert_tokens_to_ids("<|speech_pad|>")
        if hasattr(tokenizer, "pad_id"):
            self.pad_id = tokenizer.pad_id
        elif hasattr(tokenizer, "pad_token_id"):
            self.pad_id = tokenizer.pad_token_id
        else:
            self.pad_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the processor from a local directory or a hub repo.

        Reads ``preprocessor_config.json`` (from disk, else via ``cached_file``)
        and falls back to defaults if it cannot be loaded. The tokenizer is
        loaded from ``pretrained_model_name_or_path``; remaining ``kwargs`` are
        forwarded to ``cached_file`` and the tokenizer's ``from_pretrained``.
        """
        from transformers.utils import cached_file

        model_name = str(pretrained_model_name_or_path)
        # Strip this processor-only option before kwargs are forwarded anywhere.
        # It used to be popped only when the config lacked the key, so it could
        # leak into cached_file / the tokenizer. The tokenizer is loaded from
        # ``model_name`` regardless of its value.
        kwargs.pop("language_model_pretrained_name", None)

        config = {}
        config_path = os.path.join(model_name, "preprocessor_config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
        else:
            try:
                config_file = cached_file(
                    model_name, "preprocessor_config.json", **kwargs
                )
                with open(config_file, "r") as f:
                    config = json.load(f)
            except Exception as e:
                # Best effort: a missing config is not fatal, defaults apply.
                logger.warning(f"Could not load preprocessor_config.json: {e}")
                logger.warning("Using default configuration")

        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        target_sample_rate = config.get("target_sample_rate", 22050)
        normalize_audio = config.get("normalize_audio", True)
        logger.info(f"Loading tokenizer from repo {model_name}")
        tokenizer = QWEN3VoxASRTextTokenizerFast.from_pretrained(
            model_name, **kwargs
        )
        audio_processor = QWEN3VoxTokenizerProcessor(
            sampling_rate=target_sample_rate,
            normalize_audio=normalize_audio,
            target_dB_FS=config.get("target_dB_FS", -25),
            eps=config.get("eps", 1e-06),
        )
        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            target_sample_rate=target_sample_rate,
            normalize_audio=normalize_audio,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Write ``preprocessor_config.json`` into ``save_directory``.

        The directory is created if needed; ``kwargs`` are ignored.
        """
        os.makedirs(save_directory, exist_ok=True)
        processor_config = {
            "processor_class": "QWEN3VoxASRProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "target_sample_rate": self.target_sample_rate,
            "normalize_audio": self.normalize_audio,
            # Persist the audio processor's actual settings (as QWEN3VoxProcessor
            # does) rather than hard-coded defaults, so a customized
            # target_dB_FS / eps survives a save/load round-trip.
            "target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25),
            "eps": getattr(self.audio_processor, "eps", 1e-06),
        }
        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, "w") as f:
            json.dump(processor_config, f, indent=2)
        logger.info(f"Processor configuration saved in {config_path}")

    def __call__(
        self,
        audio: Optional[
            Union[
                str,
                np.ndarray,
                torch.Tensor,
                List[Union[str, np.ndarray, torch.Tensor]],
            ]
        ] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        padding: bool = True,
        max_length: Optional[int] = None,
        truncation: bool = False,
        add_generation_prompt: bool = True,
        use_streaming: bool = True,
        context_info: Optional[str] = None,
        **kwargs,
    ) -> BatchEncoding:
        """Build model inputs for one clip or a list of clips.

        Args:
            audio: A file path, waveform array/tensor, or a list of those.
            sampling_rate: Sample rate of in-memory waveforms. When given and
                different from ``target_sample_rate`` they are resampled. File
                paths always use the file's own rate.
            return_tensors: ``"pt"`` for torch tensors, otherwise Python lists /
                numpy arrays (unwrapped when there is a single clip).
            padding, max_length, truncation: See ``_batch_encode``.
            add_generation_prompt, use_streaming: Forwarded to
                ``_process_single_audio``.
            context_info: Optional extra text added to the user prompt.

        Raises:
            ValueError: If ``audio`` is ``None``.
        """
        if audio is None:
            raise ValueError("Audio input is required for ASR processing")
        audio_list = audio if isinstance(audio, list) else [audio]
        all_encodings = [
            self._process_single_audio(
                audio_input,
                sampling_rate=sampling_rate,
                add_generation_prompt=add_generation_prompt,
                use_streaming=use_streaming,
                context_info=context_info,
            )
            for audio_input in audio_list
        ]
        return self._batch_encode(
            all_encodings,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_tensors=return_tensors,
        )

    def _to_waveform(
        self,
        audio: Union[str, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
    ) -> np.ndarray:
        """Return ``audio`` as a ``float32`` array at ``target_sample_rate``.

        File paths are read with soundfile, down-mixed to mono and resampled
        with librosa when the file rate differs. Arrays and tensors are
        squeezed, and are resampled only when ``sampling_rate`` is given and
        differs from ``target_sample_rate``.
        """
        if isinstance(audio, str):
            import soundfile as sf

            audio_array, file_sr = sf.read(audio)
            if audio_array.ndim > 1:
                # soundfile returns (frames, channels); average to mono.
                audio_array = audio_array.mean(axis=1)
            if file_sr != self.target_sample_rate:
                import librosa

                audio_array = librosa.resample(
                    audio_array, orig_sr=file_sr, target_sr=self.target_sample_rate
                )
            return audio_array.astype(np.float32)

        if isinstance(audio, torch.Tensor):
            # detach/float: plain .numpy() fails for grad-tracking or bf16 tensors.
            audio_array = audio.detach().cpu().float().numpy()
        else:
            audio_array = np.array(audio, dtype=np.float32)
        if audio_array.ndim > 1:
            audio_array = audio_array.squeeze()
        if sampling_rate is not None and sampling_rate != self.target_sample_rate:
            # Previously ``sampling_rate`` was ignored, so duration and speech
            # token count were computed as if the audio were at target rate.
            audio_array = _resample_if_needed(
                audio_array, sampling_rate, self.target_sample_rate
            )
        return audio_array.astype(np.float32)

    def _process_single_audio(
        self,
        audio: Union[str, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        add_generation_prompt: bool = True,
        use_streaming: bool = True,
        context_info: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Build the unpadded encoding for a single clip.

        Returns a dict with ``input_ids`` (system turn + user turn tokens),
        ``acoustic_input_mask``, the processed waveform ``speech`` and
        ``vae_tok_len`` (number of speech-pad tokens). ``add_generation_prompt``
        and ``use_streaming`` are accepted for API compatibility but do not
        change the prompt built here.
        """
        audio_array = self._to_waveform(audio, sampling_rate)
        if self.normalize_audio and self.audio_normalizer:
            audio_array = self.audio_normalizer(audio_array)
        audio_duration = len(audio_array) / self.target_sample_rate
        vae_tok_len = math.ceil(len(audio_array) / self.speech_tok_compress_ratio)

        system_prompt_text = self.tokenizer.apply_chat_template(
            [{"role": "system", "content": SYSTEM_PROMPT}], tokenize=False
        )
        system_tokens = self.tokenizer.encode(system_prompt_text)
        sp_start_token = self.tokenizer.convert_ids_to_tokens(self.speech_start_id)
        sp_pad_token = self.tokenizer.convert_ids_to_tokens(self.speech_pad_id)
        sp_end_token = self.tokenizer.convert_ids_to_tokens(self.speech_end_id)

        keys_text = ", ".join(["Start time", "End time", "Speaker ID", "Content"])
        if context_info and context_info.strip():
            user_suffix = (
                f"This is a {audio_duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\nPlease transcribe it with these keys: "
                + keys_text
            )
        else:
            user_suffix = (
                f"This is a {audio_duration:.2f} seconds audio, please transcribe it with these keys: "
                + keys_text
            )
        user_input_string = (
            sp_start_token + sp_pad_token * vae_tok_len + sp_end_token
            + "\n"
            + user_suffix
        )
        user_tokens = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": user_input_string}], tokenize=True
        )
        full_tokens = system_tokens + user_tokens
        acoustic_input_mask = [
            1 if token == self.speech_pad_id else 0 for token in full_tokens
        ]
        return {
            "input_ids": full_tokens,
            "acoustic_input_mask": acoustic_input_mask,
            "speech": audio_array,
            "vae_tok_len": vae_tok_len,
        }

    def _batch_encode(
        self,
        encodings: List[Dict[str, Any]],
        padding: bool = True,
        max_length: Optional[int] = None,
        truncation: bool = False,
        return_tensors: Optional[str] = None,
    ) -> BatchEncoding:
        """Collate per-clip encodings into a ``BatchEncoding``.

        With ``padding`` the token sequences are LEFT-padded with ``pad_id`` to
        ``max_length`` (or the longest sequence). ``truncation`` only applies in
        that case. Without truncation, a sequence longer than ``max_length``
        stays longer. Waveforms are right-padded with zeros. ``speech_masks``
        marks the first ``vae_tok_len`` positions of each row. With
        ``return_tensors != "pt"`` a single-clip batch is unwrapped to its
        element.
        """
        input_ids_list = [enc["input_ids"] for enc in encodings]
        acoustic_masks_list = [enc["acoustic_input_mask"] for enc in encodings]
        speech_list = [enc["speech"] for enc in encodings]
        vae_tok_lens = [enc["vae_tok_len"] for enc in encodings]
        if padding:
            if max_length is not None:
                target_length = max_length
            else:
                target_length = max(len(ids) for ids in input_ids_list)
            padded_input_ids = []
            padded_acoustic_masks = []
            attention_masks = []
            for input_ids, acoustic_mask in zip(input_ids_list, acoustic_masks_list):
                if truncation and len(input_ids) > target_length:
                    input_ids = input_ids[:target_length]
                    acoustic_mask = acoustic_mask[:target_length]
                padding_length = target_length - len(input_ids)
                padded_input_ids.append([self.pad_id] * padding_length + input_ids)
                padded_acoustic_masks.append([0] * padding_length + acoustic_mask)
                attention_masks.append([0] * padding_length + [1] * len(input_ids))
            input_ids_list = padded_input_ids
            acoustic_masks_list = padded_acoustic_masks
        else:
            attention_masks = [[1] * len(ids) for ids in input_ids_list]

        max_speech_length = max(len(s) for s in speech_list)
        padded_speeches = np.zeros(
            (len(speech_list), max_speech_length), dtype=np.float32
        )
        speech_masks = np.zeros((len(speech_list), max(vae_tok_lens)), dtype=bool)
        for i, (speech, vae_len) in enumerate(zip(speech_list, vae_tok_lens)):
            padded_speeches[i, : len(speech)] = speech
            speech_masks[i, :vae_len] = True

        batch_encoding = BatchEncoding()
        if return_tensors == "pt":
            batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
            batch_encoding["attention_mask"] = torch.tensor(
                attention_masks, dtype=torch.long
            )
            batch_encoding["acoustic_input_mask"] = torch.tensor(
                acoustic_masks_list, dtype=torch.bool
            )
            batch_encoding["speech_tensors"] = torch.tensor(
                padded_speeches, dtype=torch.float32
            )
            batch_encoding["speech_masks"] = torch.tensor(
                speech_masks, dtype=torch.bool
            )
        else:
            batch_encoding["input_ids"] = (
                input_ids_list if len(input_ids_list) > 1 else input_ids_list[0]
            )
            batch_encoding["attention_mask"] = (
                attention_masks if len(attention_masks) > 1 else attention_masks[0]
            )
            batch_encoding["acoustic_input_mask"] = (
                acoustic_masks_list
                if len(acoustic_masks_list) > 1
                else acoustic_masks_list[0]
            )
            batch_encoding["speech_tensors"] = (
                padded_speeches if len(padded_speeches) > 1 else padded_speeches[0]
            )
            batch_encoding["speech_masks"] = (
                speech_masks if len(speech_masks) > 1 else speech_masks[0]
            )
        return batch_encoding

    def batch_decode(self, *args, **kwargs):
        """Delegate to ``self.tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to ``self.tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @staticmethod
    def _extract_json_payload(text: str) -> Any:
        """Decode the first JSON value embedded in model output ``text``.

        Raises ``json.JSONDecodeError`` when no valid JSON is found.
        """
        fence = "```json"
        fence_pos = text.find(fence)
        if fence_pos != -1:
            start = fence_pos + len(fence)
            end = text.find("```", start)
            if end == -1:
                # Unterminated fence (e.g. generation was cut off). Previously
                # find() returned -1 here and the slice dropped the last char.
                end = len(text)
            return json.loads(text[start:end].strip())
        start = text.find("[")
        if start == -1:
            start = text.find("{")
        if start == -1:
            return json.loads(text)
        # raw_decode locates the end of the first JSON value and, unlike naive
        # bracket counting, ignores brackets inside strings such as "[laughs]".
        value, _ = json.JSONDecoder().raw_decode(text, start)
        return value

    def post_process_transcription(self, text: str) -> List[Dict[str, Any]]:
        """Parse the model's JSON transcription into a list of segment dicts.

        Accepts a fenced ```json block (closing fence optional) or the first
        JSON array/object embedded in free text. Known keys are renamed via
        ``_TRANSCRIPTION_KEY_MAPPING``. Other keys, non-dict items and items
        with no known key are dropped. Returns ``[]`` after logging a warning
        when parsing fails.
        """
        try:
            result = self._extract_json_payload(text)
            if isinstance(result, dict):
                result = [result]
            cleaned_result = []
            for item in result:
                if not isinstance(item, dict):
                    continue
                cleaned_item = {
                    mapped_key: item[key]
                    for key, mapped_key in self._TRANSCRIPTION_KEY_MAPPING.items()
                    if key in item
                }
                if cleaned_item:
                    cleaned_result.append(cleaned_item)
            return cleaned_result
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse JSON from transcription: {e}")
            logger.debug(f"Raw text: {text}")
            return []
        except Exception as e:
            # Deliberate best effort: malformed output yields no segments.
            logger.warning(f"Error post-processing transcription: {e}")
            return []

    @property
    def model_input_names(self):
        """Names of the tensors produced by ``__call__``."""
        return [
            "input_ids",
            "attention_mask",
            "acoustic_input_mask",
            "speech_tensors",
            "speech_masks",
        ]
|
|
|
|
# Public names of the ASR-processor section. NOTE(review): ``__all__`` is
# rebound several times in this file; only the last assignment is effective.
__all__ = [
    'QWEN3VoxASRProcessor'
]
| import math |
| import warnings |
| from typing import List, Optional, Union, Dict, Any, Tuple |
| import os |
| import re |
| import numpy as np |
| import torch |
| from transformers.tokenization_utils_base import ( |
| BatchEncoding, |
| PaddingStrategy, |
| PreTokenizedInput, |
| TextInput, |
| TruncationStrategy, |
| ) |
| from transformers.utils import TensorType, logging |
|
|
# Module logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
|
|
|
|
| class QWEN3VoxProcessor: |
|
|
| def __init__( |
| self, |
| tokenizer=None, |
| audio_processor=None, |
| speech_tok_compress_ratio=3200, |
| db_normalize=True, |
| **kwargs, |
| ): |
| self.tokenizer = tokenizer |
| self.audio_processor = audio_processor |
| self.speech_tok_compress_ratio = speech_tok_compress_ratio |
| self.db_normalize = db_normalize |
| self.audio_normalizer = AudioNormalizer() if db_normalize else None |
| self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n" |
|
|
| @classmethod |
| def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): |
| import os |
| import json |
| from transformers.utils import cached_file |
|
|
| model_name = str(pretrained_model_name_or_path) |
| config_path = os.path.join( |
| model_name, "preprocessor_config.json" |
| ) |
| config = None |
| if os.path.exists(config_path): |
| with open(config_path, "r") as f: |
| config = json.load(f) |
| else: |
| try: |
| config_file = cached_file( |
| model_name, "preprocessor_config.json", **kwargs |
| ) |
| with open(config_file, "r") as f: |
| config = json.load(f) |
| except Exception as e: |
| logger.warning( |
| f"Could not load preprocessor_config.json from {model_name }: {e }" |
| ) |
| logger.warning("Using default configuration") |
| config = {"speech_tok_compress_ratio": 3200, "db_normalize": True} |
| speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200) |
| db_normalize = config.get("db_normalize", True) |
| language_model_pretrained_name = config.get( |
| "language_model_pretrained_name", None |
| ) or kwargs.pop("language_model_pretrained_name", None) |
| if not language_model_pretrained_name: |
| language_model_pretrained_name = model_name |
| logger.info(f"Loading tokenizer from repo {model_name }") |
| tokenizer = QWEN3VoxTextTokenizerFast.from_pretrained( |
| model_name, **kwargs |
| ) |
| if "audio_processor" in config: |
| audio_config = config["audio_processor"] |
| audio_processor = QWEN3VoxTokenizerProcessor( |
| sampling_rate=audio_config.get("sampling_rate", 22050), |
| normalize_audio=audio_config.get("normalize_audio", True), |
| target_dB_FS=audio_config.get("target_dB_FS", -25), |
| eps=audio_config.get("eps", 1e-06), |
| ) |
| else: |
| audio_processor = QWEN3VoxTokenizerProcessor() |
| return cls( |
| tokenizer=tokenizer, |
| audio_processor=audio_processor, |
| speech_tok_compress_ratio=speech_tok_compress_ratio, |
| db_normalize=db_normalize, |
| ) |
|
|
| def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): |
| import os |
| import json |
|
|
| os.makedirs(save_directory, exist_ok=True) |
| processor_config = { |
| "processor_class": "QWEN3VoxProcessor", |
| "speech_tok_compress_ratio": self.speech_tok_compress_ratio, |
| "db_normalize": self.db_normalize, |
| "audio_processor": { |
| "feature_extractor_type": "QWEN3VoxTokenizerProcessor", |
| "sampling_rate": getattr(self.audio_processor, "sampling_rate", 22050), |
| "normalize_audio": getattr( |
| self.audio_processor, "normalize_audio", True |
| ), |
| "target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25), |
| "eps": getattr(self.audio_processor, "eps", 1e-06), |
| }, |
| } |
| config_path = os.path.join(save_directory, "preprocessor_config.json") |
| with open(config_path, "w") as f: |
| json.dump(processor_config, f, indent=2) |
| logger.info(f"Processor configuration saved in {config_path }") |
|
|
| def __call__( |
| self, |
| text: Optional[ |
| Union[ |
| str, |
| List[str], |
| TextInput, |
| PreTokenizedInput, |
| List[TextInput], |
| List[PreTokenizedInput], |
| ] |
| ] = None, |
| voice_samples: Optional[ |
| Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]] |
| ] = None, |
| padding: Union[bool, str, PaddingStrategy] = True, |
| truncation: Union[bool, str, TruncationStrategy] = False, |
| max_length: Optional[int] = None, |
| return_tensors: Optional[Union[str, TensorType]] = None, |
| return_attention_mask: bool = True, |
| **kwargs, |
| ) -> BatchEncoding: |
| if isinstance(text, str) or ( |
| isinstance(text, list) and len(text) > 0 and (not isinstance(text[0], str)) |
| ): |
| texts = [text] |
| is_batched = False |
| else: |
| texts = text |
| is_batched = True |
| if voice_samples is not None: |
| if not is_batched or isinstance(voice_samples[0], (str, np.ndarray)): |
| voice_samples_list = [voice_samples] |
| else: |
| voice_samples_list = voice_samples |
| else: |
| voice_samples_list = [None] * len(texts) |
| all_encodings = [] |
| for text_input, voice_input in zip(texts, voice_samples_list): |
| encoding = self._process_single(text_input, voice_input) |
| all_encodings.append(encoding) |
| batch_encoding = self._batch_encode( |
| all_encodings, |
| padding=padding, |
| truncation=truncation, |
| max_length=max_length, |
| return_tensors=return_tensors, |
| return_attention_mask=return_attention_mask, |
| ) |
| return batch_encoding |
|
|
| def _process_single( |
| self, |
| text: Union[str, TextInput], |
| voice_samples: Optional[List[Union[str, np.ndarray]]] = None, |
| ) -> Dict[str, Any]: |
| script = None |
| if isinstance(text, str): |
| if text.endswith(".json") and os.path.exists(text): |
| script = self._convert_json_to_script(text) |
| elif text.endswith(".txt") and os.path.exists(text): |
| script = self._convert_text_to_script(text) |
| else: |
| script = text |
| if script is None: |
| raise ValueError(f"Could not process input text: {text }") |
| parsed_lines = self._parse_script(script) |
| all_speakers = list(set((speaker_id for speaker_id, _ in parsed_lines))) |
| system_tokens = self.tokenizer.encode(self.system_prompt) |
| if voice_samples: |
| voice_tokens, voice_speech_inputs, voice_speech_masks = ( |
| self._create_voice_prompt(voice_samples[: len(all_speakers)]) |
| ) |
| else: |
| voice_tokens, voice_speech_inputs, voice_speech_masks = ([], [], []) |
| full_tokens = system_tokens + voice_tokens |
| speech_input_mask = [False] * len(system_tokens) + voice_speech_masks |
| full_tokens += self.tokenizer.encode(" Text input:\n", add_special_tokens=False) |
| speech_input_mask += [False] * len( |
| self.tokenizer.encode(" Text input:\n", add_special_tokens=False) |
| ) |
| for speaker_id, speaker_text in parsed_lines: |
| speaker_text_tokens = self.tokenizer.encode( |
| f" Speaker {speaker_id }:{speaker_text }\n", add_special_tokens=False |
| ) |
| full_tokens += speaker_text_tokens |
| speech_input_mask += [False] * len(speaker_text_tokens) |
| full_tokens += self.tokenizer.encode( |
| " Speech output:\n", add_special_tokens=False |
| ) + [self.tokenizer.speech_start_id] |
| speech_input_mask += [False] * ( |
| len(self.tokenizer.encode(" Speech output:\n", add_special_tokens=False)) |
| + 1 |
| ) |
| return { |
| "input_ids": full_tokens, |
| "speech_inputs": voice_speech_inputs if voice_speech_inputs else None, |
| "speech_input_mask": speech_input_mask, |
| "parsed_script": parsed_lines, |
| "all_speakers": all_speakers, |
| } |
|
|
| def _batch_encode( |
| self, |
| encodings: List[Dict[str, Any]], |
| padding: Union[bool, str, PaddingStrategy] = True, |
| truncation: Union[bool, str, TruncationStrategy] = False, |
| max_length: Optional[int] = None, |
| return_tensors: Optional[Union[str, TensorType]] = None, |
| return_attention_mask: bool = True, |
| ) -> BatchEncoding: |
| input_ids_list = [enc["input_ids"] for enc in encodings] |
| speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings] |
| if isinstance(padding, bool): |
| padding_strategy = ( |
| PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD |
| ) |
| elif isinstance(padding, str): |
| padding_strategy = PaddingStrategy(padding) |
| else: |
| padding_strategy = padding |
| if padding_strategy != PaddingStrategy.DO_NOT_PAD: |
| if padding_strategy == PaddingStrategy.LONGEST: |
| max_len = max((len(ids) for ids in input_ids_list)) |
| elif ( |
| padding_strategy == PaddingStrategy.MAX_LENGTH |
| and max_length is not None |
| ): |
| max_len = max_length |
| else: |
| max_len = max((len(ids) for ids in input_ids_list)) |
| padded_input_ids = [] |
| attention_masks = [] |
| padded_speech_input_masks = [] |
| for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list): |
| if truncation and len(input_ids) > max_len: |
| input_ids = input_ids[:max_len] |
| speech_mask = speech_mask[:max_len] |
| padding_length = max_len - len(input_ids) |
| padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids |
| attention_mask = [0] * padding_length + [1] * len(input_ids) |
| padded_speech_mask = [False] * padding_length + speech_mask |
| padded_input_ids.append(padded_ids) |
| attention_masks.append(attention_mask) |
| padded_speech_input_masks.append(padded_speech_mask) |
| input_ids_list = padded_input_ids |
| speech_input_masks_list = padded_speech_input_masks |
| else: |
| attention_masks = ( |
| [[1] * len(ids) for ids in input_ids_list] |
| if return_attention_mask |
| else None |
| ) |
| all_speech_inputs = [] |
| has_speech = False |
| for enc in encodings: |
| if enc["speech_inputs"] is not None: |
| all_speech_inputs.extend(enc["speech_inputs"]) |
| has_speech = True |
| batch_encoding = BatchEncoding() |
| if return_tensors is not None: |
| batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long) |
| if return_attention_mask and attention_masks is not None: |
| batch_encoding["attention_mask"] = torch.tensor( |
| attention_masks, dtype=torch.long |
| ) |
| batch_encoding["speech_input_mask"] = torch.tensor( |
| speech_input_masks_list, dtype=torch.bool |
| ) |
| else: |
| batch_encoding["input_ids"] = input_ids_list |
| if return_attention_mask and attention_masks is not None: |
| batch_encoding["attention_mask"] = attention_masks |
| batch_encoding["speech_input_mask"] = speech_input_masks_list |
| if has_speech: |
| speech_dict = self.prepare_speech_inputs( |
| all_speech_inputs, return_tensors=return_tensors |
| ) |
| batch_encoding["speech_tensors"] = speech_dict["padded_speeches"] |
| batch_encoding["speech_masks"] = speech_dict["speech_masks"] |
| else: |
| batch_encoding["speech_tensors"] = None |
| batch_encoding["speech_masks"] = None |
| batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings] |
| batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings] |
| return batch_encoding |
|
|
| def _create_voice_prompt( |
| self, speaker_samples: List[Union[str, np.ndarray]] |
| ) -> Tuple[List[int], List[np.ndarray], List[bool]]: |
| vae_token_id = self.tokenizer.speech_diffusion_id |
| voice_full_tokens = self.tokenizer.encode( |
| " Voice input:\n", add_special_tokens=False |
| ) |
| voice_speech_inputs = [] |
| voice_speech_masks = [False] * len(voice_full_tokens) |
| for speaker_id, speaker_audio in enumerate(speaker_samples): |
| prefix_tokens = self.tokenizer.encode( |
| f" Speaker {speaker_id }:", add_special_tokens=False |
| ) |
| if isinstance(speaker_audio, str): |
| wav = self.audio_processor._load_audio_from_path(speaker_audio) |
| elif isinstance(speaker_audio, dict): |
| if "array" in speaker_audio: |
| wav = np.array(speaker_audio["array"], dtype=np.float32) |
| elif "audio" in speaker_audio: |
| wav = np.array(speaker_audio["audio"], dtype=np.float32) |
| else: |
| raise ValueError( |
| f"Dictionary audio input must have 'array' or 'audio' key, got: {speaker_audio .keys ()}" |
| ) |
| else: |
| wav = np.array(speaker_audio, dtype=np.float32) |
| if self.db_normalize and self.audio_normalizer: |
| wav = self.audio_normalizer(wav) |
| vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio) |
| speaker_tokens = ( |
| prefix_tokens |
| + [self.tokenizer.speech_start_id] |
| + [vae_token_id] * vae_tok_len |
| + [self.tokenizer.speech_end_id] |
| + self.tokenizer.encode("\n", add_special_tokens=False) |
| ) |
| vae_input_mask = ( |
| [False] * len(prefix_tokens) |
| + [False] |
| + [True] * vae_tok_len |
| + [False] |
| + [False] |
| ) |
| voice_full_tokens.extend(speaker_tokens) |
| voice_speech_masks.extend(vae_input_mask) |
| voice_speech_inputs.append(wav) |
| return (voice_full_tokens, voice_speech_inputs, voice_speech_masks) |
|
|
| def prepare_speech_inputs( |
| self, |
| speech_inputs: List[np.ndarray], |
| return_tensors: Optional[Union[str, TensorType]] = None, |
| device: Optional[Union[str, torch.device]] = None, |
| dtype: Optional[torch.dtype] = None, |
| ) -> Dict[str, Any]: |
| if not speech_inputs: |
| return {"padded_speeches": None, "speech_masks": None} |
| vae_tok_seqlens = [ |
| math.ceil(s.shape[0] / self.speech_tok_compress_ratio) |
| for s in speech_inputs |
| ] |
| max_speech_length = max((s.shape[0] for s in speech_inputs)) |
| if speech_inputs[0].ndim == 1: |
| padded_speeches = np.full( |
| (len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32 |
| ) |
| else: |
| padded_speeches = np.full( |
| (len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), |
| fill_value=0, |
| dtype=np.float32, |
| ) |
| speech_masks = np.zeros( |
| (len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_ |
| ) |
| for i, (speech, vae_tok_length) in enumerate( |
| zip(speech_inputs, vae_tok_seqlens) |
| ): |
| padded_speeches[i, : len(speech)] = speech |
| speech_masks[i, :vae_tok_length] = True |
| result = {"padded_speeches": padded_speeches, "speech_masks": speech_masks} |
| if return_tensors == "pt": |
| result["padded_speeches"] = torch.tensor( |
| padded_speeches, device=device, dtype=dtype or torch.float32 |
| ) |
| result["speech_masks"] = torch.tensor( |
| speech_masks, device=device, dtype=torch.bool |
| ) |
| return result |
|
|
| def _convert_json_to_script(self, json_file: str) -> str: |
| import json |
|
|
| with open(json_file, "r", encoding="utf-8") as f: |
| data = json.load(f) |
| if not isinstance(data, list): |
| raise ValueError("JSON file must contain a list of speaker entries") |
| script_lines = [] |
| for item in data: |
| if not isinstance(item, dict): |
| logger.warning(f"Skipping non-dict entry: {item }") |
| continue |
| speaker = item.get("speaker") |
| text = item.get("text") |
| if speaker is None or text is None: |
| logger.warning(f"Skipping entry missing speaker or text: {item }") |
| continue |
| try: |
| speaker_id = int(speaker) |
| except (ValueError, TypeError): |
| logger.warning(f"Invalid speaker ID: {speaker }, skipping entry") |
| continue |
| text = text.strip() |
| if text: |
| script_lines.append(f"Speaker {speaker_id }: {text }") |
| if not script_lines: |
| raise ValueError("No valid entries found in JSON file") |
| return "\n".join(script_lines) |
|
|
| def _convert_text_to_script(self, text_file: str) -> str: |
| with open(text_file, "r", encoding="utf-8") as f: |
| lines = f.readlines() |
| script_lines = [] |
| current_speaker = 1 |
| for line in lines: |
| line = line.strip() |
| if not line: |
| continue |
| speaker_match = re.match( |
| "^Speaker\\s+(\\d+)\\s*:\\s*(.*)$", line, re.IGNORECASE |
| ) |
| if speaker_match: |
| speaker_id = int(speaker_match.group(1)) |
| text = speaker_match.group(2).strip() |
| if text: |
| script_lines.append(f"Speaker {speaker_id }: {text }") |
| else: |
| script_lines.append(f"Speaker {current_speaker }: {line }") |
| if not script_lines: |
| raise ValueError("No valid content found in text file") |
| return "\n".join(script_lines) |
|
|
| def _parse_script(self, script: str) -> List[Tuple[int, str]]: |
| stripped = script.strip() |
| if not stripped: |
| raise ValueError( |
| "No valid speaker lines found in script (empty text). " |
| "If training with HuggingFace Trainer, set remove_unused_columns=False " |
| "so dataset columns like `text` are not stripped before the collator." |
| ) |
| non_empty = [ln.strip() for ln in stripped.split("\n") if ln.strip()] |
| if not non_empty: |
| raise ValueError("No valid speaker lines found in script") |
| _speaker_line = r"^Speaker\s+(\d+)\s*:\s*(.*)$" |
| if not any(re.match(_speaker_line, ln, re.IGNORECASE) for ln in non_empty): |
| |
| collapsed = " ".join(stripped.split()) |
| return [(0, " " + collapsed)] |
| parsed_lines: List[Tuple[int, str]] = [] |
| speaker_ids: List[int] = [] |
| for line in non_empty: |
| match = re.match(_speaker_line, line, re.IGNORECASE) |
| if match: |
| speaker_id = int(match.group(1)) |
| text = " " + match.group(2).strip() |
| parsed_lines.append((speaker_id, text)) |
| speaker_ids.append(speaker_id) |
| else: |
| logger.warning(f"Could not parse line: '{line }'") |
| if not parsed_lines: |
| raise ValueError("No valid speaker lines found in script") |
| min_speaker_id = min(speaker_ids) |
| if min_speaker_id > 0: |
| normalized_lines = [] |
| for speaker_id, text in parsed_lines: |
| normalized_lines.append((speaker_id - 1, text)) |
| return normalized_lines |
| else: |
| return parsed_lines |
|
|
| def _merge_inputs( |
| self, text_inputs: BatchEncoding, audio_inputs: Dict |
| ) -> BatchEncoding: |
| merged = BatchEncoding(text_inputs) |
| if "audio" in audio_inputs: |
| merged["speech_inputs"] = audio_inputs["audio"] |
| if "streaming" in audio_inputs: |
| merged["streaming"] = audio_inputs["streaming"] |
| return merged |
|
|
| def batch_decode(self, *args, **kwargs): |
| return self.tokenizer.batch_decode(*args, **kwargs) |
|
|
| def decode(self, *args, **kwargs): |
| return self.tokenizer.decode(*args, **kwargs) |
|
|
@property
def model_input_names(self):
    """Ordered, de-duplicated union of the model input names.

    Concatenates the tokenizer's names, the audio processor's names and the
    speech-specific ``"speech_inputs"`` / ``"speech_input_mask"`` keys, then
    keeps only the first occurrence of each name.
    """
    candidates = (
        self.tokenizer.model_input_names
        + self.audio_processor.model_input_names
        + ["speech_inputs", "speech_input_mask"]
    )
    seen = set()
    ordered = []
    for name in candidates:
        if name in seen:
            continue
        seen.add(name)
        ordered.append(name)
    return ordered
|
|
def save_audio(
    self,
    audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
    output_path: str = "output.wav",
    sampling_rate: Optional[int] = None,
    normalize: bool = False,
    batch_prefix: str = "audio_",
) -> str:
    """Write audio to disk by delegating to ``self.audio_processor.save_audio``.

    Args:
        audio: A waveform tensor/array or a list of them.
        output_path: Destination path passed through to the audio processor.
        sampling_rate: Optional sampling rate passed through unchanged.
        normalize: Passed through unchanged.
        batch_prefix: Passed through unchanged.

    Returns:
        Whatever the audio processor's ``save_audio`` returns.
    """
    options = dict(
        output_path=output_path,
        sampling_rate=sampling_rate,
        normalize=normalize,
        batch_prefix=batch_prefix,
    )
    return self.audio_processor.save_audio(audio, **options)
|
|
|
|
# This file concatenates several original modules, each of which declared its
# own ``__all__``. Extend the existing list instead of overwriting it, so that
# ``from <module> import *`` still exports names declared by earlier sections.
__all__ = list(
    dict.fromkeys([*globals().get("__all__", []), "QWEN3VoxProcessor"])
)

# QWEN3Vox Streaming Processor
#
# This processor handles input preparation for the streaming 0.5B model,
# including text tokenization and cached voice prompt handling.
| import math |
| import warnings |
| from typing import List, Optional, Union, Dict, Any, Tuple |
| import os |
| import re |
| import numpy as np |
| import torch |
| from transformers.tokenization_utils_base import ( |
| BatchEncoding, |
| PaddingStrategy, |
| PreTokenizedInput, |
| TextInput, |
| TruncationStrategy, |
| ) |
| from transformers.utils import TensorType, logging |
|
|
# Module-level logger from transformers' logging utilities (``logging`` here is
# ``transformers.utils.logging``, imported just above).
logger = logging.get_logger(__name__)
|
|
|
|
class QWEN3VoxStreamingProcessor:
    """Input preparation for the streaming QWEN3Vox model.

    Wraps a text tokenizer and an audio processor. Streaming inputs are built
    with :meth:`process_input_with_cached_prompt`, which pairs a text script
    with a previously cached voice prompt. ``__call__`` is deliberately not
    supported.
    """

    def __init__(
        self,
        tokenizer=None,
        audio_processor=None,
        speech_tok_compress_ratio=3200,
        db_normalize=True,
        **kwargs,
    ):
        """Store the sub-processors and configuration.

        Args:
            tokenizer: Text tokenizer; :meth:`process_input_with_cached_prompt`
                uses its ``encode`` method and ``pad_id`` attribute.
            audio_processor: Audio processor used by :meth:`save_audio`,
                :meth:`model_input_names` and :meth:`save_pretrained`.
            speech_tok_compress_ratio: Number of audio samples per speech
                token; used to size ``speech_masks``.
            db_normalize: When true, an ``AudioNormalizer`` is created.
            **kwargs: Accepted and ignored.
        """
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.db_normalize = db_normalize
        self.audio_normalizer = AudioNormalizer() if db_normalize else None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a processor from a local directory or a hub repo id.

        ``preprocessor_config.json`` is read from the directory when it exists
        locally. Otherwise it is resolved with
        ``transformers.utils.cached_file``. If that fails, warnings are logged
        and default settings are used. The tokenizer is always loaded from
        ``pretrained_model_name_or_path``.

        Args:
            pretrained_model_name_or_path: Local path or repo id.
            **kwargs: Forwarded to ``cached_file`` and to the tokenizer's
                ``from_pretrained``.
        """
        import json
        import os

        from transformers.utils import cached_file

        model_name = str(pretrained_model_name_or_path)
        config_path = os.path.join(model_name, "preprocessor_config.json")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            try:
                config_file = cached_file(
                    model_name, "preprocessor_config.json", **kwargs
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    config = json.load(f)
            except Exception as e:
                # Best effort: an unavailable/unreadable config falls back to defaults.
                logger.warning(
                    f"Could not load preprocessor_config.json from {model_name}: {e}"
                )
                logger.warning("Using default configuration")
                config = {"speech_tok_compress_ratio": 3200, "db_normalize": True}
        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        db_normalize = config.get("db_normalize", True)
        logger.info(f"Loading tokenizer from repo {model_name}")
        tokenizer = QWEN3VoxTextTokenizerFast.from_pretrained(model_name, **kwargs)
        if "audio_processor" in config:
            audio_config = config["audio_processor"]
            audio_processor = QWEN3VoxTokenizerProcessor(
                sampling_rate=audio_config.get("sampling_rate", 22050),
                normalize_audio=audio_config.get("normalize_audio", True),
                target_dB_FS=audio_config.get("target_dB_FS", -25),
                eps=audio_config.get("eps", 1e-06),
            )
        else:
            audio_processor = QWEN3VoxTokenizerProcessor()
        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            db_normalize=db_normalize,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Write ``preprocessor_config.json`` into ``save_directory``.

        The directory is created if needed. Only the processor configuration
        is written; the tokenizer is not saved by this method.
        """
        import json
        import os

        os.makedirs(save_directory, exist_ok=True)
        processor_config = {
            "processor_class": "QWEN3VoxStreamingProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "db_normalize": self.db_normalize,
            "audio_processor": {
                "feature_extractor_type": "QWEN3VoxTokenizerProcessor",
                "sampling_rate": getattr(self.audio_processor, "sampling_rate", 22050),
                "normalize_audio": getattr(
                    self.audio_processor, "normalize_audio", True
                ),
                "target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25),
                "eps": getattr(self.audio_processor, "eps", 1e-06),
            },
        }
        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(processor_config, f, indent=2)
        logger.info(f"Processor configuration saved in {config_path}")

    def __call__(self, *args, **kwargs) -> BatchEncoding:
        """Unsupported entry point; always raises ``NotImplementedError``.

        Arguments are accepted so that calls such as ``processor(text=...)``
        reach the explanatory error instead of failing with a ``TypeError``.
        """
        raise NotImplementedError(
            'QWEN3VoxStreamingProcessor.__call__ is not implemented. Use process_input_with_cached_prompt for streaming inputs.'
        )

    def process_input_with_cached_prompt(
        self,
        text: Optional[str] = None,
        cached_prompt: Optional[Dict[str, Any]] = None,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """Build a single-sample streaming batch from text and a cached prompt.

        The script is stripped, terminated with a newline, and encoded
        without special tokens into ``tts_text_ids``. ``input_ids`` and
        ``tts_lm_input_ids`` are filled with ``tokenizer.pad_id``, one per
        cached position. The cached lengths are read via ``.size(1)`` of
        ``cached_prompt["lm"]["last_hidden_state"]`` and
        ``cached_prompt["tts_lm"]["last_hidden_state"]``; these are presumably
        (batch, seq, hidden) tensors — confirm against the caching code.

        Raises:
            ValueError: If ``text`` or ``cached_prompt`` is ``None``.
        """
        if text is None:
            raise ValueError(
                "`text` is required for process_input_with_cached_prompt."
            )
        if cached_prompt is None:
            raise ValueError(
                "`cached_prompt` is required for process_input_with_cached_prompt."
            )
        script_tokens = self.tokenizer.encode(
            text.strip() + "\n", add_special_tokens=False
        )
        input_id_length = cached_prompt["lm"]["last_hidden_state"].size(1)
        tts_lm_input_id_length = cached_prompt["tts_lm"]["last_hidden_state"].size(1)
        pad_id = self.tokenizer.pad_id
        encoding = {
            "input_ids": [pad_id] * input_id_length,
            "tts_lm_input_ids": [pad_id] * tts_lm_input_id_length,
            "tts_text_ids": script_tokens,
            "speech_inputs": None,
            "speech_input_mask": [False] * tts_lm_input_id_length,
        }
        return self._batch_encode(
            [encoding],
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=return_attention_mask,
        )

    def _batch_encode(
        self,
        encodings: List[Dict[str, Any]],
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
    ) -> BatchEncoding:
        """Collate per-sample encodings into a ``BatchEncoding``.

        ``padding``, ``truncation`` and ``max_length`` are accepted for API
        compatibility but are not used, so sequences must already share a
        length when tensors are requested. Any non-``None``
        ``return_tensors`` produces torch tensors for the id/mask fields.
        ``speech_input_mask`` is always included. Attention masks (all ones)
        are included only when ``return_attention_mask`` is true.
        """
        input_ids_list = [enc["input_ids"] for enc in encodings]
        tts_lm_input_ids_list = [enc["tts_lm_input_ids"] for enc in encodings]
        tts_text_ids_list = [enc["tts_text_ids"] for enc in encodings]
        speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]
        attention_masks = (
            [[1] * len(ids) for ids in input_ids_list]
            if return_attention_mask
            else None
        )
        tts_lm_attention_masks = (
            [[1] * len(ids) for ids in tts_lm_input_ids_list]
            if return_attention_mask
            else None
        )
        all_speech_inputs = []
        has_speech = False
        for enc in encodings:
            if enc["speech_inputs"] is not None:
                all_speech_inputs.extend(enc["speech_inputs"])
                has_speech = True
        batch_encoding = BatchEncoding()
        if return_tensors is not None:
            batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
            batch_encoding["tts_lm_input_ids"] = torch.tensor(
                tts_lm_input_ids_list, dtype=torch.long
            )
            batch_encoding["tts_text_ids"] = torch.tensor(
                tts_text_ids_list, dtype=torch.long
            )
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = torch.tensor(
                    attention_masks, dtype=torch.long
                )
                batch_encoding["tts_lm_attention_mask"] = torch.tensor(
                    tts_lm_attention_masks, dtype=torch.long
                )
            batch_encoding["speech_input_mask"] = torch.tensor(
                speech_input_masks_list, dtype=torch.bool
            )
        else:
            batch_encoding["input_ids"] = input_ids_list
            batch_encoding["tts_lm_input_ids"] = tts_lm_input_ids_list
            batch_encoding["tts_text_ids"] = tts_text_ids_list
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = attention_masks
                batch_encoding["tts_lm_attention_mask"] = tts_lm_attention_masks
            batch_encoding["speech_input_mask"] = speech_input_masks_list
        if has_speech:
            speech_dict = self.prepare_speech_inputs(
                all_speech_inputs, return_tensors=return_tensors
            )
            batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
            batch_encoding["speech_masks"] = speech_dict["speech_masks"]
        else:
            batch_encoding["speech_tensors"] = None
            batch_encoding["speech_masks"] = None
        return batch_encoding

    def prepare_speech_inputs(
        self,
        speech_inputs: List[np.ndarray],
        return_tensors: Optional[Union[str, TensorType]] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Zero-pad speech arrays to a common length and build token masks.

        Each input's first dimension is its sample count. The inputs are
        padded to the longest one. The layout (1-D, or with a trailing
        channel dimension) is taken from the first input.
        ``speech_masks[i, :ceil(len_i / speech_tok_compress_ratio)]`` is
        ``True``. With ``return_tensors == "pt"`` both outputs are returned
        as torch tensors on ``device``; otherwise they are numpy arrays.

        Returns:
            ``{"padded_speeches": ..., "speech_masks": ...}``; both values are
            ``None`` for an empty input list.
        """
        if not speech_inputs:
            return {"padded_speeches": None, "speech_masks": None}
        vae_tok_seqlens = [
            math.ceil(s.shape[0] / self.speech_tok_compress_ratio)
            for s in speech_inputs
        ]
        max_speech_length = max(s.shape[0] for s in speech_inputs)
        if speech_inputs[0].ndim == 1:
            padded_speeches = np.full(
                (len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32
            )
        else:
            padded_speeches = np.full(
                (len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]),
                fill_value=0,
                dtype=np.float32,
            )
        speech_masks = np.zeros(
            (len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_
        )
        for i, (speech, vae_tok_length) in enumerate(
            zip(speech_inputs, vae_tok_seqlens)
        ):
            padded_speeches[i, : len(speech)] = speech
            speech_masks[i, :vae_tok_length] = True
        result = {"padded_speeches": padded_speeches, "speech_masks": speech_masks}
        if return_tensors == "pt":
            result["padded_speeches"] = torch.tensor(
                padded_speeches, device=device, dtype=dtype or torch.float32
            )
            result["speech_masks"] = torch.tensor(
                speech_masks, device=device, dtype=torch.bool
            )
        return result

    def batch_decode(self, *args, **kwargs):
        """Forward to ``self.tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to ``self.tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Ordered, de-duplicated tokenizer + audio-processor + speech input names."""
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(
            dict.fromkeys(
                tokenizer_input_names
                + audio_processor_input_names
                + ["speech_inputs", "speech_input_mask"]
            )
        )

    def save_audio(
        self,
        audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
        output_path: str = "output.wav",
        sampling_rate: Optional[int] = None,
        normalize: bool = False,
        batch_prefix: str = "audio_",
    ) -> str:
        """Delegate to ``self.audio_processor.save_audio`` with the same options."""
        return self.audio_processor.save_audio(
            audio,
            output_path=output_path,
            sampling_rate=sampling_rate,
            normalize=normalize,
            batch_prefix=batch_prefix,
        )
|
|
|
|
# This file concatenates several original modules; extend ``__all__`` rather
# than replacing it, so that names exported by earlier sections are kept.
__all__ = list(
    dict.fromkeys([*globals().get("__all__", []), "QWEN3VoxStreamingProcessor"])
)

# QWEN3Vox Streaming Model Architecture (0.5B)
#
# This module implements the streaming-optimized version of QWEN3Vox for
# real-time TTS. Key differences from the multi-speaker model:
# - No semantic tokenizer (only acoustic)
# - Split language model architecture: lower layers for text, upper layers for TTS
# - Optimized for low-latency generation
| from dataclasses import dataclass |
| from typing import Dict, List, Optional, Tuple, Union, Callable |
| from tqdm import tqdm |
| import copy |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.distributed as dist |
| from transformers.models.auto import AutoModel, AutoModelForCausalLM |
| from transformers.activations import ACT2FN |
| from transformers.modeling_outputs import ( |
| CausalLMOutput, |
| BaseModelOutputWithPast, |
| ModelOutput, |
| ) |
| from transformers.models.llama.modeling_llama import LlamaRMSNorm |
| from transformers import modeling_utils |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.utils import logging |
|
|
# Module-level logger from transformers' logging utilities.
logger = logging.get_logger(__name__)

# Make sure ``modeling_utils.ALL_PARALLEL_STYLES`` exists and is not None.
# Some transformers versions presumably lack it; supply a fallback list.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
|
|
|
class BinaryClassifier(nn.Module):
    """Two-layer MLP producing a single logit per input vector.

    ``fc1`` maps ``hidden_size -> hidden_size`` and is followed by a ReLU;
    ``fc2`` then maps ``hidden_size -> 1``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
|
|
|
|
class SpeechConnector(nn.Module):
    """Project speech latents into the language-model hidden space.

    Pipeline: ``Linear(input_dim -> output_dim)`` -> ``LlamaRMSNorm(output_dim)``
    -> ``Linear(output_dim -> output_dim)``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Registration order (fc1, norm, fc2) is kept so state-dict keys match.
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-06)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        """Apply fc1, norm and fc2 in sequence; extra kwargs are ignored."""
        for layer in (self.fc1, self.norm, self.fc2):
            features = layer(features)
        return features
|
|
|
|
class QWEN3VoxStreamingPreTrainedModel(PreTrainedModel):
    """Base class binding ``QWEN3VoxStreamingConfig`` to transformers' ``PreTrainedModel``."""

    config_class = QWEN3VoxStreamingConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize the parameters of ``module`` in place.

        - ``QWEN3VoxDiffusionHead`` modules use their own ``initialize_weights``.
        - ``nn.Linear`` weights are drawn from N(0, std) and biases zeroed.
          ``std`` is the ``initializer_range`` of ``language_model_config``,
          else of ``decoder_config``, else 0.02.
        - ``nn.LayerNorm`` weight is set to 1 and bias to 0. Parameters that
          are absent (``elementwise_affine=False`` or ``bias=False``) are skipped.
        """
        if isinstance(module, QWEN3VoxDiffusionHead):
            module.initialize_weights()
            return
        if hasattr(self.config, "language_model_config") and hasattr(
            self.config.language_model_config, "initializer_range"
        ):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, "decoder_config") and hasattr(
            self.config.decoder_config, "initializer_range"
        ):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None for non-affine (or bias-free) LayerNorms.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
|
|
|
|
class QWEN3VoxStreamingModel(QWEN3VoxStreamingPreTrainedModel):
    """Streaming QWEN3Vox backbone with a split language model.

    The decoder config's layers are divided between ``language_model`` and
    ``tts_language_model``:

    - ``language_model`` gets ``num_hidden_layers - tts_backbone_num_hidden_layers``
      layers, and its final norm is replaced by ``nn.Identity``.
    - ``tts_language_model`` gets ``tts_backbone_num_hidden_layers`` layers.

    Also holds the TTS input-type embedding, acoustic tokenizer, acoustic
    connector, diffusion prediction head and noise scheduler. ``forward`` is
    disabled; call the sub-models directly.
    """

    def __init__(self, config):
        super().__init__(config)
        # Target dtype for the speech modules; string names are resolved on torch.
        if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32
        lm_config = copy.deepcopy(config.decoder_config)
        lm_backbone_num_hidden_layers = (
            getattr(lm_config, "num_hidden_layers", 24)
            - config.tts_backbone_num_hidden_layers
        )
        lm_config.num_hidden_layers = lm_backbone_num_hidden_layers
        self.language_model = AutoModel.from_config(lm_config)
        # The lower (text) stack's final norm is removed. Presumably its output
        # feeds the TTS stack rather than being a final hidden state — confirm.
        self.language_model.norm = nn.Identity()
        tts_lm_config = copy.deepcopy(lm_config)
        tts_lm_config.num_hidden_layers = config.tts_backbone_num_hidden_layers
        self.tts_language_model = AutoModel.from_config(tts_lm_config)
        self.tts_input_types = nn.Embedding(
            num_embeddings=2, embedding_dim=config.decoder_config.hidden_size
        )
        self.acoustic_tokenizer = AutoModel.from_config(
            config.acoustic_tokenizer_config
        ).to(dtype)
        self.acoustic_connector = SpeechConnector(
            config.acoustic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        # NaN marks the scaling/bias factors as "not yet computed".
        self.register_buffer("speech_scaling_factor", torch.tensor(float("nan")))
        self.register_buffer("speech_bias_factor", torch.tensor(float("nan")))
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(
            dtype
        )
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type,
        )

    def get_input_embeddings(self):
        """Return the text language model's token embedding module.

        Uses ``language_model.embed_tokens`` when present. Otherwise it
        searches ``language_model.fullmap`` (if the backbone has one) for the
        entry whose ``orig_name`` is ``"embed_tokens.weight"``.

        Raises:
            AttributeError: If no embedding can be located. This is an explicit
                raise; the previous ``assert False`` is stripped under ``python -O``.
        """
        if hasattr(self.language_model, "embed_tokens"):
            return self.language_model.embed_tokens
        for name, attr in getattr(self.language_model, "fullmap", {}).items():
            if attr.orig_name == "embed_tokens.weight":
                return getattr(self.language_model, name)
        raise AttributeError(
            "Could not locate input embeddings on language_model "
            "(no `embed_tokens` and no matching `fullmap` entry)."
        )

    def set_input_embeddings(self, value):
        """Replace ``language_model.embed_tokens`` with ``value``."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None):
        """Install an acoustic tokenizer and switch it to eval mode (if not None)."""
        self.acoustic_tokenizer = acoustic_tokenizer
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.train(False)

    def forward(self, *args, **kwargs):
        """Always raises; use the sub-language-models directly."""
        raise RuntimeError(
            'QWEN3VoxStreamingModel.forward is intentionally disabled. Use `model.language_model(...)` or `model.tts_language_model(...)` instead.'
        )
|
|
|
|
# Make ``AutoModel.from_config`` resolve QWEN3VoxStreamingConfig to this model.
AutoModel.register(QWEN3VoxStreamingConfig, QWEN3VoxStreamingModel)

# This file concatenates several original modules; extend ``__all__`` rather
# than replacing it, so that names exported by earlier sections are kept.
__all__ = list(
    dict.fromkeys(
        [
            *globals().get("__all__", []),
            "QWEN3VoxStreamingPreTrainedModel",
            "QWEN3VoxStreamingModel",
            "BinaryClassifier",
            "SpeechConnector",
        ]
    )
)
| from dataclasses import dataclass |
| from typing import Dict, List, Optional, Tuple, Union, Callable |
| from tqdm import tqdm |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.distributed as dist |
| from transformers.models.auto import AutoModel, AutoModelForCausalLM |
| from transformers.activations import ACT2FN |
| from transformers.modeling_outputs import ( |
| CausalLMOutput, |
| BaseModelOutputWithPast, |
| ModelOutput, |
| ) |
| from transformers.models.llama.modeling_llama import LlamaRMSNorm |
| from transformers import modeling_utils |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.utils import logging |
|
|
# Module-level logger from transformers' logging utilities.
logger = logging.get_logger(__name__)

# Ensure ``modeling_utils.ALL_PARALLEL_STYLES`` is defined and non-None,
# falling back to a fixed list when the installed transformers lacks it.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
|
|
|
@dataclass
class QWEN3VoxCausalLMOutputWithPast(ModelOutput):
    """Output of the QWEN3Vox conditional-generation forward pass.

    ``ModelOutput`` can be consumed positionally (e.g. via ``to_tuple()``),
    so the field order below should not be changed.
    """

    # Language-model loss; None when not computed.
    loss: Optional[torch.FloatTensor] = None
    # Loss of the diffusion prediction head on speech latents.
    diffusion_loss: Optional[torch.FloatTensor] = None
    # Number of speech latent positions used for the diffusion loss.
    speech_token_num: Optional[int] = None
    # Output of ``lm_head`` over the final hidden states.
    logits: torch.FloatTensor = None
    # Key/value cache from the language model, if requested.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Per-layer hidden states, if requested.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Per-layer attention weights, if requested.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
|
|
|
@dataclass
class QWEN3VoxGenerationOutput(ModelOutput):
    """Container pairing generated token sequences with speech outputs.

    NOTE(review): presumably returned by the model's generation routine,
    which is not visible here — confirm.
    """

    # Generated token ids.
    sequences: torch.LongTensor = None
    # Optional list of speech tensors; presumably one waveform per sample — confirm.
    speech_outputs: Optional[List[torch.FloatTensor]] = None
|
|
|
|
class SpeechConnector(nn.Module):
    """Map speech latents of size ``input_dim`` to ``output_dim`` features.

    Applies a linear projection, an RMS norm (eps 1e-6) and a second linear
    layer, in that order.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Keep fc1 -> norm -> fc2 registration order for checkpoint compatibility.
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-06)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        """Return ``fc2(norm(fc1(features)))``; extra kwargs are ignored."""
        return self.fc2(self.norm(self.fc1(features)))
|
|
|
|
class QWEN3VoxPreTrainedModel(PreTrainedModel):
    """Base class binding ``QWEN3VoxConfig`` to transformers' ``PreTrainedModel``."""

    config_class = QWEN3VoxConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize the parameters of ``module`` in place.

        - ``QWEN3VoxDiffusionHead`` modules use their own ``initialize_weights``.
        - ``nn.Linear`` weights are drawn from N(0, std) and biases zeroed.
          ``std`` is the ``initializer_range`` of ``language_model_config``,
          else of ``decoder_config``, else 0.02.
        - ``nn.LayerNorm`` weight is set to 1 and bias to 0. Parameters that
          are absent (``elementwise_affine=False`` or ``bias=False``) are skipped.
        """
        if isinstance(module, QWEN3VoxDiffusionHead):
            module.initialize_weights()
            return
        if hasattr(self.config, "language_model_config") and hasattr(
            self.config.language_model_config, "initializer_range"
        ):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, "decoder_config") and hasattr(
            self.config.decoder_config, "initializer_range"
        ):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None for non-affine (or bias-free) LayerNorms.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
|
|
|
|
class QWEN3VoxModel(QWEN3VoxPreTrainedModel):
    """QWEN3Vox backbone: a language model plus acoustic/semantic speech modules.

    Holds the decoder language model, acoustic and semantic tokenizers, their
    connectors into the LM hidden size, the diffusion prediction head and the
    noise scheduler. ``forward`` runs only the language model.
    """

    def __init__(self, config):
        super().__init__(config)
        # Target dtype for the speech modules; string names are resolved on torch.
        if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32
        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)
        self.acoustic_tokenizer = AutoModel.from_config(
            config.acoustic_tokenizer_config
        ).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(
            config.semantic_tokenizer_config
        ).to(dtype)
        self.acoustic_connector = SpeechConnector(
            config.acoustic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        self.semantic_connector = SpeechConnector(
            config.semantic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        # NaN marks the scaling/bias factors as "not yet computed".
        self.register_buffer("speech_scaling_factor", torch.tensor(float("nan")))
        self.register_buffer("speech_bias_factor", torch.tensor(float("nan")))
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(
            dtype
        )
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type,
        )

    def get_input_embeddings(self):
        """Return the language model's token embedding module.

        Uses ``language_model.embed_tokens`` when present. Otherwise it
        searches ``language_model.fullmap`` (if the backbone has one) for the
        entry whose ``orig_name`` is ``"embed_tokens.weight"``.

        Raises:
            AttributeError: If no embedding can be located. This is an explicit
                raise; the previous ``assert False`` is stripped under ``python -O``.
        """
        if hasattr(self.language_model, "embed_tokens"):
            return self.language_model.embed_tokens
        for name, attr in getattr(self.language_model, "fullmap", {}).items():
            if attr.orig_name == "embed_tokens.weight":
                return getattr(self.language_model, name)
        raise AttributeError(
            "Could not locate input embeddings on language_model "
            "(no `embed_tokens` and no matching `fullmap` entry)."
        )

    def set_input_embeddings(self, value):
        """Replace ``language_model.embed_tokens`` with ``value``."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Install speech tokenizers and switch each non-None one to eval mode."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.train(False)
        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.train(False)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the language model and re-wrap its output.

        Returns the language model's raw output when ``return_dict`` is false
        (defaulting to ``config.use_return_dict``). Otherwise returns a
        ``BaseModelOutputWithPast`` with its fields.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        if not return_dict:
            return outputs
        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
| class QWEN3VoxForConditionalGeneration(QWEN3VoxPreTrainedModel): |
| _tied_weights_keys = ["lm_head.weight"] |
| _tp_plan = {"lm_head": "colwise_rep"} |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| self.model = QWEN3VoxModel(config) |
| self.vocab_size = config.decoder_config.vocab_size |
| self.lm_head = nn.Linear( |
| config.decoder_config.hidden_size, self.vocab_size, bias=False |
| ) |
| self.post_init() |
|
|
| def get_input_embeddings(self): |
| return self.model.get_input_embeddings() |
|
|
| def set_input_embeddings(self, value): |
| self.model.set_input_embeddings(value) |
|
|
| def get_output_embeddings(self): |
| return self.lm_head |
|
|
| def set_decoder(self, decoder): |
| self.model.language_model = decoder |
|
|
| def get_decoder(self): |
| return self.model.language_model |
|
|
| def tie_weights(self): |
| if getattr(self.config.decoder_config, "tie_word_embeddings", False): |
| output_embeddings = self.get_output_embeddings() |
| input_embeddings = self.get_input_embeddings() |
| if hasattr(input_embeddings, "weight"): |
| output_embeddings.weight = input_embeddings.weight |
| else: |
| output_embeddings.weight = input_embeddings |
| if getattr(output_embeddings, "bias", None) is not None: |
| output_embeddings.bias.data = nn.functional.pad( |
| output_embeddings.bias.data, |
| ( |
| 0, |
| output_embeddings.weight.shape[0] |
| - output_embeddings.bias.shape[0], |
| ), |
| "constant", |
| 0, |
| ) |
| print("Tied input and output embeddings using standard assignment.") |
| else: |
| print("tie_word_embeddings is False, not tying weights.") |
|
|
| def set_output_embeddings(self, new_embeddings): |
| self.lm_head = new_embeddings |
|
|
| def forward_speech_features( |
| self, |
| speech_tensors=None, |
| speech_masks=None, |
| speech_type="audio", |
| return_unmask=False, |
| ): |
| if speech_tensors is None: |
| vae_dim = self.config.acoustic_tokenizer_config.vae_dim |
| audio_features = torch.zeros(1, 1, vae_dim).to( |
| self.get_input_embeddings().weight |
| ) |
| connect_features = self.model.acoustic_connector(audio_features) |
| return (audio_features, connect_features) |
| else: |
| with torch.no_grad(): |
| if speech_type == "audio": |
| with torch.no_grad(): |
| frames = self.model.acoustic_tokenizer.encode( |
| speech_tensors.unsqueeze(1) |
| )[0][0] |
| audio_tokens = frames.sample( |
| self.model.acoustic_tokenizer.std_dist_type |
| )[0] |
| elif speech_type == "vae": |
| vae_dim = self.config.acoustic_tokenizer_config.vae_dim |
| speech_mode = speech_tensors.reshape( |
| speech_tensors.size(0), -1, vae_dim |
| ) |
| batch_size = speech_mode.size(0) |
| value = self.model.acoustic_tokenizer.fix_std / 0.8 |
| std = ( |
| torch.randn( |
| batch_size, |
| dtype=speech_mode.dtype, |
| device=speech_mode.device, |
| ) |
| * value |
| ) |
| std = std.view(-1, *[1] * (speech_mode.dim() - 1)) |
| audio_tokens = speech_mode + std * torch.randn( |
| speech_mode.shape |
| ).to(speech_mode) |
| else: |
| raise NotImplementedError( |
| f"Speech type {speech_type } not implemented" |
| ) |
| if torch.isnan(self.model.speech_scaling_factor) or torch.isnan( |
| self.model.speech_bias_factor |
| ): |
| scaling_factor = 1.0 / audio_tokens[speech_masks].flatten().std() |
| bias_factor = -audio_tokens[speech_masks].flatten().mean() |
| if dist.is_available() and dist.is_initialized(): |
| dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM) |
| dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM) |
| world_size = dist.get_world_size() |
| self.model.speech_scaling_factor.copy_( |
| scaling_factor / world_size |
| ) |
| self.model.speech_bias_factor.copy_(bias_factor / world_size) |
| print( |
| f"Speech scaling factor (distributed): {self .model .speech_scaling_factor }, bias factor: {self .model .speech_bias_factor }", |
| flush=True, |
| ) |
| else: |
| self.model.speech_scaling_factor.copy_(scaling_factor) |
| self.model.speech_bias_factor.copy_(bias_factor) |
| print( |
| f"Speech scaling factor (single process): {self .model .speech_scaling_factor }, bias factor: {self .model .speech_bias_factor }", |
| flush=True, |
| ) |
| audio_features = ( |
| audio_tokens + self.model.speech_bias_factor |
| ) * self.model.speech_scaling_factor |
| connect_features = self.model.acoustic_connector(audio_features) |
| if return_unmask: |
| return (audio_features, connect_features) |
| return (audio_features[speech_masks], connect_features[speech_masks]) |
|
|
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[List[torch.FloatTensor]] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = False, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| speech_tensors: Optional[torch.FloatTensor] = None, |
| speech_masks: Optional[torch.BoolTensor] = None, |
| speeches_loss_input: Optional[torch.FloatTensor] = None, |
| speech_semantic_tensors: Optional[torch.FloatTensor] = None, |
| acoustic_input_mask: Optional[torch.BoolTensor] = None, |
| acoustic_loss_mask: Optional[torch.BoolTensor] = None, |
| ddpm_batch_mul: int = 1, |
| **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]], |
| ) -> Union[Tuple, QWEN3VoxCausalLMOutputWithPast]: |
| return_dict = ( |
| return_dict if return_dict is not None else self.config.use_return_dict |
| ) |
| x = self.get_input_embeddings()(input_ids) |
| semantic_speech_all_connect_features = self.model.semantic_connector( |
| speech_semantic_tensors |
| ) |
| if speeches_loss_input is not None: |
| speech_all_features, speech_all_connect_features = ( |
| self.forward_speech_features( |
| speech_tensors=( |
| speech_tensors.type_as(x) |
| if speech_tensors is not None |
| else None |
| ), |
| speech_masks=speech_masks, |
| speech_type=kwargs.get("speech_type", "audio"), |
| return_unmask=True, |
| ) |
| ) |
| if speech_tensors is not None: |
| if semantic_speech_all_connect_features is not None: |
| x[acoustic_input_mask] = ( |
| speech_all_connect_features[speech_masks] |
| + semantic_speech_all_connect_features[speech_masks] |
| ) |
| else: |
| x[acoustic_input_mask] = speech_all_connect_features[speech_masks] |
| target_latent_mask = speeches_loss_input & speech_masks |
| speech_features = speech_all_features[target_latent_mask] |
| speech_connect_features = speech_all_connect_features[ |
| target_latent_mask |
| ] |
| else: |
| speech_features, speech_connect_features = self.forward_speech_features( |
| speech_tensors=( |
| speech_tensors.type_as(x) if speech_tensors is not None else None |
| ), |
| speech_masks=speech_masks, |
| speech_type=kwargs.get("speech_type", "audio"), |
| ) |
| if speech_tensors is not None: |
| x[acoustic_input_mask] = speech_connect_features |
| outputs = self.model( |
| input_ids=None, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=x, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=False, |
| return_dict=return_dict, |
| cache_position=cache_position, |
| ) |
| hidden_states = outputs.last_hidden_state |
| logits = self.lm_head(hidden_states) |
| loss = None |
| if labels is not None: |
| pass |
| diffusion_loss = None |
| if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0: |
| condition_features = hidden_states[acoustic_loss_mask] |
| speech_len, latent_size = speech_features.shape |
| noise = torch.randn( |
| (speech_len * ddpm_batch_mul, latent_size), |
| device=hidden_states.device, |
| dtype=hidden_states.dtype, |
| ) |
| timesteps = torch.multinomial( |
| torch.ones(self.config.diffusion_head_config.ddpm_num_steps), |
| speech_len * ddpm_batch_mul, |
| replacement=True, |
| ).to(hidden_states.device) |
| speech_features_repeated = speech_features.repeat_interleave( |
| ddpm_batch_mul, dim=0 |
| ) |
| condition_features_repeated = condition_features.repeat_interleave( |
| ddpm_batch_mul, dim=0 |
| ) |
| noisy_speech_features = self.model.noise_scheduler.add_noise( |
| speech_features_repeated, noise, timesteps |
| ) |
| model_output = self.model.prediction_head( |
| noisy_speech_features, timesteps.type_as(x), condition_features_repeated |
| ) |
| prediction_type = self.config.diffusion_head_config.prediction_type |
| if prediction_type == "epsilon": |
| target_for_loss = noise |
| elif prediction_type == "v_prediction": |
| target_for_loss = self.model.noise_scheduler.get_velocity( |
| speech_features_repeated, noise, timesteps |
| ) |
| else: |
| raise NotImplementedError( |
| f"Prediction type {prediction_type } not implemented" |
| ) |
| diffusion_loss = F.mse_loss( |
| model_output.float(), target_for_loss.float(), reduction="sum" |
| ) |
| if latent_size > 0 and ddpm_batch_mul > 0: |
| diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul |
| else: |
| diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device) |
| else: |
| diffusion_loss = ( |
| sum((p.sum() for p in self.model.prediction_head.parameters())) * 0.0 |
| ) |
| diffusion_loss += ( |
| sum((p.sum() for p in self.model.acoustic_connector.parameters())) * 0.0 |
| ) |
| diffusion_loss += ( |
| sum((p.sum() for p in self.model.semantic_connector.parameters())) * 0.0 |
| ) |
| if not return_dict: |
| output = (logits, speech_len) + outputs.to_tuple()[1:] |
| return (loss, diffusion_loss) + output |
| return QWEN3VoxCausalLMOutputWithPast( |
| loss=loss, |
| diffusion_loss=diffusion_loss, |
| speech_token_num=speech_len if speech_tensors is not None else 0, |
| logits=logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
|
|
# Make the non-streaming QWEN3Vox classes discoverable through the transformers
# Auto* factories.
AutoModel.register(QWEN3VoxConfig, QWEN3VoxModel)
AutoModelForCausalLM.register(QWEN3VoxConfig, QWEN3VoxForConditionalGeneration)

# NOTE(review): this file redefines `__all__` several times; at module level only the
# last assignment is in effect for `from ... import *`.
__all__ = [
    "QWEN3VoxModel",
    "QWEN3VoxPreTrainedModel",
    "QWEN3VoxForConditionalGeneration",
    "QWEN3VoxCausalLMOutputWithPast",
    "QWEN3VoxGenerationOutput",
]
'\nQWEN3Vox Processors\n\nThis module provides processors for preparing inputs for QWEN3Vox models:\n- QWEN3VoxProcessor: For multi-speaker models (1.5B, 7B)\n- QWEN3VoxStreamingProcessor: For streaming model (0.5B)\n'
__all__ = [
    "QWEN3VoxProcessor",
    "QWEN3VoxStreamingProcessor",
    "QWEN3VoxTokenizerProcessor",
    "AudioNormalizer",
    "QWEN3VoxASRProcessor",
]
'\nQWEN3Vox Streaming Inference Model (0.5B)\n\nThis module implements the inference engine for real-time streaming TTS.\nKey features:\n- Window-based text/speech interleaving for streaming\n- Binary EOS classifier for end-of-speech detection\n- Classifier-free guidance for speech quality\n- Audio streaming support\n'
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional, Tuple, Union, Callable |
| from tqdm import tqdm |
| import torch |
| import torch.nn as nn |
| from transformers.models.auto import AutoModel, AutoModelForCausalLM |
| from transformers.generation import ( |
| GenerationMixin, |
| GenerationConfig, |
| LogitsProcessor, |
| LogitsProcessorList, |
| StoppingCriteriaList, |
| ) |
| from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput |
| from transformers import modeling_utils |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.utils import logging |
|
|
logger = logging.get_logger(__name__)

# Presumably a compatibility shim for transformers versions in which
# `modeling_utils.ALL_PARALLEL_STYLES` is missing or None — install a default list.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]

# Streaming interleave granularity used by the streaming `generate` loop:
# text tokens consumed per text window ...
TTS_TEXT_WINDOW_SIZE = 5
# ... and speech latents sampled after each text window.
TTS_SPEECH_WINDOW_SIZE = 6
|
|
|
|
def _update_model_kwargs_for_generation(
    outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1
) -> Dict[str, Any]:
    """Advance generation bookkeeping by ``num_new_tokens`` positions (in place).

    * ``past_key_values`` is replaced by the cache carried on ``outputs``;
    * ``attention_mask`` gets ``num_new_tokens`` columns of ones appended;
    * ``cache_position`` becomes the ``num_new_tokens`` positions that follow
      its current last entry, on the same device.

    The streaming ``generate`` loop calls this with a whole text-window size so
    that several upcoming tokens are reserved in one update.

    Returns:
        The same ``model_kwargs`` dict that was passed in.
    """
    model_kwargs["past_key_values"] = outputs.past_key_values

    previous_mask = model_kwargs["attention_mask"]
    appended_ones = previous_mask.new_ones((previous_mask.shape[0], num_new_tokens))
    model_kwargs["attention_mask"] = torch.cat((previous_mask, appended_ones), dim=-1)

    previous_positions = model_kwargs["cache_position"]
    first_new_position = previous_positions[-1] + 1
    model_kwargs["cache_position"] = torch.arange(
        first_new_position, first_new_position + num_new_tokens
    ).to(previous_positions.device)
    return model_kwargs
|
|
|
|
@dataclass
class QWEN3VoxLMHeadOutputWithPast(BaseModelOutputWithPast):
    """Base-model output plus the logits of a lightweight head.

    Returned by streaming / lightweight forwards (e.g. the TTS LM forward, whose
    head is the end-of-speech classifier). It has no loss or diffusion fields.
    """

    # Head scores computed from `last_hidden_state`; None when no head was run.
    logits: Optional[torch.FloatTensor] = None
|
|
|
|
@dataclass
class QWEN3VoxGenerationOutput(ModelOutput):
    """Result of QWEN3Vox speech generation.

    In the streaming ``generate`` below, the fields are filled as follows:
    ``sequences`` holds the final TTS-LM token ids, ``speech_outputs`` holds one
    concatenated audio tensor per sample (None if that sample produced no audio,
    or the whole field is None when speech was not requested), and
    ``reach_max_step_sample`` holds a per-sample flag for "stopped because the
    length limit was hit".
    """

    # Annotated Optional because every field defaults to None.
    sequences: Optional[torch.LongTensor] = None
    speech_outputs: Optional[List[torch.FloatTensor]] = None
    reach_max_step_sample: Optional[torch.BoolTensor] = None
|
|
|
|
class QWEN3VoxStreamingForConditionalGenerationInference(
    QWEN3VoxStreamingPreTrainedModel, GenerationMixin
):
    """Inference-only wrapper around the streaming QWEN3Vox TTS model.

    Two transformer stacks are driven in lock-step:

    * the text LM (``self.model.language_model``, see ``forward_lm``) encodes
      the incoming text;
    * the TTS LM (``self.model.tts_language_model``, see ``forward_tts_lm``)
      consumes the text LM's hidden states and the fed-back acoustic embeddings.
      Its last hidden state conditions the diffusion head and feeds a binary
      end-of-speech classifier.

    The unified ``forward`` is disabled; use ``forward_lm``, ``forward_tts_lm``
    or ``generate``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = QWEN3VoxStreamingModel(config)
        # Scores end-of-speech from the TTS LM's last hidden state.
        self.tts_eos_classifier = BinaryClassifier(config.decoder_config.hidden_size)
        self.ddpm_inference_steps = (
            config.diffusion_head_config.ddpm_num_inference_steps
        )
        self.post_init()

    @property
    def noise_scheduler(self):
        """Diffusion noise scheduler owned by the inner model."""
        return self.model.noise_scheduler

    @property
    def prediction_head(self):
        """Diffusion prediction head owned by the inner model."""
        return self.model.prediction_head

    @property
    def speech_scaling_factor(self):
        """Latent scaling factor (latents are divided by it before decoding)."""
        return self.model.speech_scaling_factor

    @property
    def speech_bias_factor(self):
        """Latent bias (subtracted after scaling, before decoding)."""
        return self.model.speech_bias_factor

    @property
    def acoustic_tokenizer(self):
        """Acoustic tokenizer used to decode speech latents into audio."""
        return self.model.acoustic_tokenizer

    @property
    def acoustic_connector(self):
        """Projects speech latents back into the TTS LM embedding space."""
        return self.model.acoustic_connector

    def tie_weights(self):
        """Tie ``lm_head`` to the text embeddings when configured and both exist."""
        if not getattr(self.config, "tie_word_embeddings", False):
            return
        if hasattr(self, "lm_head") and hasattr(
            self.model.language_model, "embed_tokens"
        ):
            self.lm_head.weight = self.model.language_model.embed_tokens.weight

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        # This inference model defines no lm_head.
        return None

    def set_output_embeddings(self, new_embeddings):
        raise RuntimeError(
            "Output embeddings (lm_head) are not defined for this model. Create one before calling set_output_embeddings if needed."
        )

    def set_speech_tokenizers(self, acoustic_tokenizer=None):
        self.model.set_speech_tokenizers(acoustic_tokenizer)

    def set_ddpm_inference_steps(self, num_steps=None):
        """Set the number of diffusion sampling steps (falsy -> config default)."""
        self.ddpm_inference_steps = (
            num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
        )

    def forward_lm(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the text language model (``self.model.language_model``).

        ``inputs_embeds`` is computed from ``input_ids`` when not given. A
        ``BaseModelOutputWithPast`` is always returned; ``return_dict`` is accepted
        for API compatibility only.

        Raises:
            NotImplementedError: if ``labels`` is given (no loss in this version).
        """
        if labels is not None:
            raise NotImplementedError(
                "Loss computation is not implemented in this version."
            )
        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
        # Always request a ModelOutput from the backbone: fields are read by name
        # below, which fails on the plain tuple returned with return_dict=False.
        outputs = self.model.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def forward_tts_lm(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        lm_last_hidden_state: Optional[torch.FloatTensor] = None,
        tts_text_masks: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> Union[Tuple, QWEN3VoxLMHeadOutputWithPast]:
        """Run the TTS language model and score end-of-speech.

        The trailing ``lm_last_hidden_state.shape[1]`` input embeddings are replaced
        by ``lm_last_hidden_state``. Either text-LM hidden states or acoustic
        embeddings are passed here; ``generate`` passes both kinds. A type embedding
        of ``tts_text_masks`` is then added; ``generate`` uses 1 for text steps and
        0 for speech steps. ``logits`` are the EOS classifier's scores at the last
        position. A ``QWEN3VoxLMHeadOutputWithPast`` is always returned;
        ``return_dict`` is accepted for API compatibility only.

        Raises:
            NotImplementedError: if ``labels`` is given (no loss in this version).
        """
        if labels is not None:
            raise NotImplementedError(
                "Loss computation is not implemented in this version."
            )
        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
        else:
            # The slice assignment below is in place; don't mutate the caller's tensor.
            inputs_embeds = inputs_embeds.clone()
        start_idx = inputs_embeds.shape[1] - lm_last_hidden_state.shape[1]
        inputs_embeds[:, start_idx:, :] = lm_last_hidden_state
        inputs_embeds = inputs_embeds + self.model.tts_input_types(
            tts_text_masks.long()
        )
        outputs = self.model.tts_language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        logits = self.tts_eos_classifier(hidden_states[:, -1, :])
        return QWEN3VoxLMHeadOutputWithPast(
            logits=logits,
            past_key_values=outputs.past_key_values,
            last_hidden_state=hidden_states,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def forward(self, *args, **kwargs):
        raise RuntimeError(
            "Unified forward is disabled. Use `forward_lm`, `forward_tts_lm`, or `generate` instead."
        )

    def _build_generate_config_model_kwargs(
        self, generation_config, inputs, tokenizer, return_processors=False, **kwargs
    ):
        """Prepare a generation config, model kwargs and input ids for one stream.

        ``generation_config`` may be None, a mapping of GenerationConfig fields, or a
        ``GenerationConfig`` instance (copied, never mutated). The tokenizer's
        bos/eos/pad and speech special-token ids always override the supplied
        values. The cache is prepared, ``cache_position`` is initialised, and all
        tensor kwargs are moved to ``self.device``.

        Returns:
            ``(generation_config, model_kwargs, input_ids)``, plus
            ``(logits_processor, stopping_criteria)`` when ``return_processors``.
        """
        special_token_ids = {
            "bos_token_id": tokenizer.bos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "pad_token_id": tokenizer.pad_token_id,
        }
        if isinstance(generation_config, GenerationConfig):
            import copy

            # `GenerationConfig(**config)` fails for config objects; copy instead.
            generation_config = copy.deepcopy(generation_config)
            for name, value in special_token_ids.items():
                setattr(generation_config, name, value)
        else:
            # Merge (rather than pass twice) so a mapping that already carries token
            # ids does not raise "got multiple values for keyword argument".
            config_kwargs = dict(generation_config or {})
            config_kwargs.update(special_token_ids)
            generation_config = GenerationConfig(**config_kwargs)
        generation_config, model_kwargs = self._prepare_generation_config(
            generation_config,
            True,
            speech_start_id=tokenizer.speech_start_id,
            speech_end_id=tokenizer.speech_end_id,
            speech_diffusion_id=tokenizer.speech_diffusion_id,
            **kwargs,
        )
        generation_config.speech_start_id = tokenizer.speech_start_id
        generation_config.speech_end_id = tokenizer.speech_end_id
        generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]
        device = self.device
        self._prepare_special_tokens(generation_config, True, device=device)
        generation_config.use_cache = True
        model_kwargs["use_cache"] = generation_config.use_cache
        input_ids = inputs_tensor.to(self.device)
        input_ids_length = input_ids.shape[1]
        has_default_max_length = (
            kwargs.get("max_length") is None
            and generation_config.max_length is not None
        )
        has_default_min_length = (
            kwargs.get("min_length") is None
            and generation_config.min_length is not None
        )
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            has_default_min_length=has_default_min_length,
            model_input_name=model_input_name,
            inputs_tensor=inputs_tensor,
            input_ids_length=input_ids_length,
        )
        max_cache_length = generation_config.max_length - 1
        self._prepare_cache_for_generation(
            generation_config, model_kwargs, None, batch_size, max_cache_length, device
        )
        model_kwargs["cache_position"] = torch.arange(
            input_ids_length, device=device, dtype=torch.long
        )
        for k, v in model_kwargs.items():
            if isinstance(v, torch.Tensor):
                model_kwargs[k] = v.to(device=device)
        if not return_processors:
            return (generation_config, model_kwargs, input_ids)
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=None,
            logits_processor=LogitsProcessorList(),
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
        )
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config,
            stopping_criteria=StoppingCriteriaList(),
        )
        return (
            generation_config,
            model_kwargs,
            input_ids,
            logits_processor,
            stopping_criteria,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        speech_tensors: Optional[torch.FloatTensor] = None,
        speech_masks: Optional[torch.BoolTensor] = None,
        speech_input_mask: Optional[torch.BoolTensor] = None,
        tts_text_ids: Optional[torch.LongTensor] = None,
        return_speech: bool = True,
        cfg_scale: float = 1.0,
        stop_check_fn: Optional[Callable[[], bool]] = None,
        **kwargs,
    ) -> Union[torch.LongTensor, QWEN3VoxGenerationOutput]:
        """Stream speech for ``tts_text_ids`` and return the generated audio.

        The text is consumed in windows of ``TTS_TEXT_WINDOW_SIZE`` tokens.
        1. Each window is run through the text LM and then the TTS LM.
        2. Up to ``TTS_SPEECH_WINDOW_SIZE`` speech latents are sampled with
           classifier-free guidance (negative prompt ``"<|image_pad|>"``).
        3. Each latent is decoded to audio with the streaming acoustic tokenizer
           and fed back into the positive and negative TTS LMs.

        After the text runs out, speech windows continue until the EOS classifier
        fires, the TTS sequence exceeds its ``max_length``, or ``stop_check_fn()``
        returns True.

        Keyword arguments read from ``kwargs``:
            tokenizer (required), input_ids (required), tts_lm_input_ids,
            tts_lm_attention_mask, all_prefilled_outputs (required; dict with keys
            "lm", "tts_lm", "neg_lm", "neg_tts_lm"), max_new_tokens (defaults to
            the decoder's ``max_position_embeddings`` minus the TTS prompt length),
            verbose (default False), show_progress_bar (default True).

        ``logits_processor``, ``stopping_criteria``, ``prefix_allowed_tokens_fn``,
        ``synced_gpus``, ``assistant_model``, ``negative_prompt_*`` and
        ``speech_*`` are accepted for signature compatibility but not used.

        Returns:
            QWEN3VoxGenerationOutput: final TTS-LM ids, per-sample concatenated
            audio (None when ``return_speech`` is False), and per-sample flags
            for "stopped at the length limit".

        Raises:
            ValueError: a required input is missing or the batch size is not 1.
        """
        tokenizer = kwargs.pop("tokenizer", None)
        if tokenizer is None:
            raise ValueError("`generate` requires a `tokenizer` keyword argument.")
        neg_text_input_id = tokenizer.convert_tokens_to_ids("<|image_pad|>")
        tts_lm_input_ids = kwargs.pop("tts_lm_input_ids", None)
        tts_lm_attention_mask = kwargs.pop("tts_lm_attention_mask", None)
        all_prefilled_outputs = kwargs.pop("all_prefilled_outputs", None)
        if all_prefilled_outputs is None:
            raise ValueError(
                "`generate` requires `all_prefilled_outputs` with keys "
                "'lm', 'tts_lm', 'neg_lm' and 'neg_tts_lm'."
            )
        if tts_text_ids is None:
            raise ValueError("`generate` requires `tts_text_ids`.")
        tts_text_ids = tts_text_ids.to(self.device)
        if kwargs.get("max_new_tokens", None) is None:
            kwargs["max_new_tokens"] = (
                self.config.decoder_config.max_position_embeddings
                - tts_lm_input_ids.shape[-1]
            )
        (
            generation_config,
            model_kwargs,
            input_ids,
            logits_processor,
            stopping_criteria,
        ) = self._build_generate_config_model_kwargs(
            generation_config, inputs, tokenizer, return_processors=True, **kwargs
        )

        prompt_ids = kwargs["input_ids"]

        def _negative_prompt_kwargs():
            # Fresh one-token "<|image_pad|>" prompt for the negative (unconditional)
            # branch of classifier-free guidance.
            return {
                "input_ids": torch.full(
                    (prompt_ids.shape[0], 1),
                    neg_text_input_id,
                    dtype=torch.long,
                    device=prompt_ids.device,
                ),
                "attention_mask": torch.ones(
                    (prompt_ids.shape[0], 1),
                    dtype=torch.long,
                    device=prompt_ids.device,
                ),
                "max_new_tokens": kwargs.get("max_new_tokens", 100),
            }

        negative_generation_config, negative_model_kwargs, negative_input_ids = (
            self._build_generate_config_model_kwargs(
                None, None, tokenizer, return_processors=False, **_negative_prompt_kwargs()
            )
        )
        tts_lm_kwargs = {
            "input_ids": tts_lm_input_ids,
            "attention_mask": tts_lm_attention_mask,
            "max_new_tokens": kwargs.get("max_new_tokens", 100),
        }
        tts_lm_generation_config, tts_lm_model_kwargs, tts_lm_input_ids = (
            self._build_generate_config_model_kwargs(
                None, None, tokenizer, return_processors=False, **tts_lm_kwargs
            )
        )
        (
            tts_lm_negative_generation_config,
            tts_lm_negative_model_kwargs,
            tts_lm_negative_input_ids,
        ) = self._build_generate_config_model_kwargs(
            None, None, tokenizer, return_processors=False, **_negative_prompt_kwargs()
        )

        acoustic_cache = QWEN3VoxTokenizerStreamingCache()
        batch_size = input_ids.shape[0]
        if batch_size != 1:
            raise ValueError("Currently only supports batch size == 1")
        device = input_ids.device
        finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
        verbose = kwargs.get("verbose", False)
        audio_chunks = [[] for _ in range(batch_size)]
        tts_text_window_index = 0
        reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
        max_length = tts_lm_generation_config.max_length
        first_text_window_size = min(TTS_TEXT_WINDOW_SIZE, tts_text_ids.shape[1])

        # The prompts were prefilled by the caller; continue from those caches.
        outputs = all_prefilled_outputs["lm"]
        tts_lm_outputs = all_prefilled_outputs["tts_lm"]
        negative_outputs = all_prefilled_outputs["neg_lm"]
        tts_lm_negative_outputs = all_prefilled_outputs["neg_tts_lm"]
        # The positive streams reserve room for the whole first text window.
        model_kwargs = _update_model_kwargs_for_generation(
            outputs, model_kwargs, num_new_tokens=first_text_window_size
        )
        tts_lm_model_kwargs = _update_model_kwargs_for_generation(
            tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=first_text_window_size
        )
        negative_model_kwargs = self._update_model_kwargs_for_generation(
            negative_outputs, negative_model_kwargs, is_encoder_decoder=False
        )
        tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
            tts_lm_negative_outputs,
            tts_lm_negative_model_kwargs,
            is_encoder_decoder=False,
        )

        step = tts_lm_input_ids.shape[1]
        total_generated_speech_tokens = 0
        total_prefilled_text_tokens = 0

        def _progress_description():
            return (
                f"Prefilled {total_prefilled_text_tokens} text tokens, "
                f"generated {total_generated_speech_tokens} speech tokens, "
                f"current step ({step} / {max_length})"
            )

        if kwargs.get("show_progress_bar", True):
            progress_bar = tqdm(
                total=max_length,
                desc=f"Prefilled {step} tokens, current step ({step} / {max_length})",
                initial=step,
                leave=False,
            )
        else:
            progress_bar = None

        # Batch size is 1, so sample 0 is the only sample ever generated.
        diffusion_indices = torch.LongTensor([0])
        reached_max_length = False
        while True:
            if stop_check_fn is not None and stop_check_fn():
                if verbose:
                    print(f"Generation stopped externally at step {step + 1}")
                # The streamer is closed once, right after the loop.
                break
            if finished_tags.all():
                if progress_bar is not None:
                    progress_bar.set_description("Generation complete")
                break

            # ---- text window: feed the next chunk of text through both LMs ----
            window_start = tts_text_window_index * TTS_TEXT_WINDOW_SIZE
            cur_input_tts_text_ids = tts_text_ids[
                :, window_start : window_start + TTS_TEXT_WINDOW_SIZE
            ]
            next_text_window_size = tts_text_ids[
                :,
                window_start
                + TTS_TEXT_WINDOW_SIZE : window_start
                + 2 * TTS_TEXT_WINDOW_SIZE,
            ].shape[1]
            tts_text_window_index += 1
            if cur_input_tts_text_ids.shape[1] > 0:
                input_ids = torch.cat([input_ids, cur_input_tts_text_ids], dim=-1)
                tts_lm_input_ids = torch.cat(
                    [tts_lm_input_ids, cur_input_tts_text_ids], dim=-1
                )
                if tts_lm_input_ids.shape[1] > max_length:
                    if verbose:
                        print(
                            f"Reached maximum generation length {max_length}, stopped it."
                        )
                    reach_max_step_sample |= ~finished_tags
                    break
                step += cur_input_tts_text_ids.shape[1]
                total_prefilled_text_tokens += cur_input_tts_text_ids.shape[1]
                if progress_bar is not None:
                    progress_bar.update(cur_input_tts_text_ids.shape[1])
                    progress_bar.set_description(_progress_description())
                model_inputs = self.prepare_inputs_for_generation(
                    input_ids, **model_kwargs
                )
                outputs = self.forward_lm(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )
                model_kwargs = _update_model_kwargs_for_generation(
                    outputs, model_kwargs, num_new_tokens=next_text_window_size
                )
                tts_lm_model_inputs = self.prepare_inputs_for_generation(
                    tts_lm_input_ids, **tts_lm_model_kwargs
                )
                tts_lm_additional_inputs = {
                    "tts_text_masks": torch.ones_like(tts_lm_input_ids[:, -1:]),
                    "lm_last_hidden_state": outputs.last_hidden_state,
                }
                tts_lm_outputs = self.forward_tts_lm(
                    **tts_lm_model_inputs,
                    **tts_lm_additional_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )
                tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
                    tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
                )

            # ---- speech window: sample, decode and feed back speech latents ----
            for cur_speech_index in range(TTS_SPEECH_WINDOW_SIZE):
                positive_condition = tts_lm_outputs.last_hidden_state[
                    diffusion_indices, -1, :
                ]
                negative_condition = tts_lm_negative_outputs.last_hidden_state[
                    diffusion_indices, -1, :
                ]
                speech_latent = self.sample_speech_tokens(
                    positive_condition, negative_condition, cfg_scale=cfg_scale
                ).unsqueeze(1)
                scaled_latent = speech_latent / self.model.speech_scaling_factor.to(
                    speech_latent.device
                ) - self.model.speech_bias_factor.to(speech_latent.device)
                audio_chunk = self.model.acoustic_tokenizer.decode(
                    scaled_latent.to(self.model.acoustic_tokenizer.device),
                    cache=acoustic_cache,
                    sample_indices=diffusion_indices.to(
                        self.model.acoustic_tokenizer.device
                    ),
                    use_cache=True,
                    debug=False,
                )
                for i, sample_idx in enumerate(diffusion_indices):
                    idx = sample_idx.item()
                    if not finished_tags[idx]:
                        audio_chunks[idx].append(audio_chunk[i])
                if audio_streamer is not None:
                    audio_streamer.put(audio_chunk, diffusion_indices)
                acoustic_embed = self.model.acoustic_connector(speech_latent)
                tts_lm_input_ids = torch.cat(
                    [tts_lm_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])],
                    dim=-1,
                )
                if tts_lm_input_ids.shape[1] > max_length:
                    # No room to feed this latent back. Stop the whole generation:
                    # leaving only this inner loop would re-enter it forever once
                    # the text is exhausted.
                    if verbose:
                        print(
                            f"Reached maximum generation length {max_length}, stopped it."
                        )
                    reach_max_step_sample |= ~finished_tags
                    reached_max_length = True
                    break
                step += 1
                total_generated_speech_tokens += 1
                if progress_bar is not None:
                    progress_bar.update(1)
                    progress_bar.set_description(_progress_description())

                # Positive TTS LM consumes the acoustic embedding (speech type = 0).
                tts_lm_model_inputs = self.prepare_inputs_for_generation(
                    tts_lm_input_ids, **tts_lm_model_kwargs
                )
                tts_lm_additional_inputs = {
                    "tts_text_masks": torch.zeros_like(tts_lm_input_ids[:, -1:]),
                    "lm_last_hidden_state": acoustic_embed,
                }
                tts_lm_outputs = self.forward_tts_lm(
                    **tts_lm_model_inputs,
                    **tts_lm_additional_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )
                if (
                    cur_speech_index == TTS_SPEECH_WINDOW_SIZE - 1
                    and next_text_window_size > 0
                ):
                    # Last latent of the window: reserve room for the next text window.
                    tts_lm_model_kwargs = _update_model_kwargs_for_generation(
                        tts_lm_outputs,
                        tts_lm_model_kwargs,
                        num_new_tokens=next_text_window_size,
                    )
                else:
                    tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
                        tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
                    )

                # Negative TTS LM consumes the same acoustic embedding.
                tts_lm_negative_input_ids = torch.cat(
                    [
                        tts_lm_negative_input_ids,
                        torch.ones_like(tts_lm_input_ids[:, -1:]),
                    ],
                    dim=-1,
                )
                tts_lm_negative_model_inputs = self.prepare_inputs_for_generation(
                    tts_lm_negative_input_ids, **tts_lm_negative_model_kwargs
                )
                tts_lm_negative_additional_inputs = {
                    "tts_text_masks": torch.zeros_like(
                        tts_lm_negative_input_ids[:, -1:]
                    ),
                    "lm_last_hidden_state": acoustic_embed,
                }
                tts_lm_negative_outputs = self.forward_tts_lm(
                    **tts_lm_negative_model_inputs,
                    **tts_lm_negative_additional_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )
                tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
                    tts_lm_negative_outputs,
                    tts_lm_negative_model_kwargs,
                    is_encoder_decoder=False,
                )

                tts_eos_probs = torch.sigmoid(
                    self.tts_eos_classifier(
                        tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]
                    )
                )
                if tts_eos_probs[0].item() > 0.5:
                    finished_tags[diffusion_indices] = True
                    if audio_streamer is not None:
                        audio_streamer.end(diffusion_indices)
                    # Do not keep sampling/streaming audio for a finished sample.
                    break
            if reached_max_length:
                break

        if progress_bar is not None:
            progress_bar.close()
        if audio_streamer is not None:
            audio_streamer.end()
        final_audio_outputs = [
            torch.cat(sample_chunks, dim=-1) if sample_chunks else None
            for sample_chunks in audio_chunks
        ]
        if reach_max_step_sample.any():
            print(f"Reached maximum generation length {max_length}, stopped it.")
        return QWEN3VoxGenerationOutput(
            sequences=tts_lm_input_ids,
            speech_outputs=final_audio_outputs if return_speech else None,
            reach_max_step_sample=reach_max_step_sample,
        )

    @torch.no_grad()
    def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
        """Sample speech latents with classifier-free guidance.

        ``condition`` and ``neg_condition`` are stacked into one batch. At every
        scheduler timestep, the prediction head runs on the first half of the
        current noisy latents, duplicated for both conditions. The guided
        prediction ``uncond + cfg_scale * (cond - uncond)`` drives the scheduler
        update.

        Returns:
            The first half of the final latents, of width ``config.acoustic_vae_dim``
            (one row per ``condition`` row, assuming both conditions have equally
            many rows).
        """
        self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps)
        condition = torch.cat([condition, neg_condition], dim=0).to(
            self.model.prediction_head.device
        )
        speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(
            condition
        )
        for t in self.model.noise_scheduler.timesteps:
            half = speech[: len(speech) // 2]
            combined = torch.cat([half, half], dim=0)
            eps = self.model.prediction_head(
                combined, t.repeat(combined.shape[0]).to(combined), condition=condition
            )
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample
        return speech[: len(speech) // 2]
|
|
|
|
# Expose the streaming inference model through AutoModelForCausalLM.
AutoModelForCausalLM.register(
    QWEN3VoxStreamingConfig,
    QWEN3VoxStreamingForConditionalGenerationInference,
)

# Public API of the streaming inference section.
__all__ = [
    "QWEN3VoxStreamingForConditionalGenerationInference",
    "QWEN3VoxGenerationOutput",
    "QWEN3VoxLMHeadOutputWithPast",
    "TTS_TEXT_WINDOW_SIZE",
    "TTS_SPEECH_WINDOW_SIZE",
]
| import logging |
| import os |
| from dataclasses import dataclass, field |
| from typing import Any, Dict, List, Optional, Tuple |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from datasets import load_dataset, DatasetDict, VerificationMode |
| from transformers import HfArgumentParser, Trainer, set_seed, TrainerCallback |
| from transformers import TrainingArguments as HfTrainingArguments |
| from peft import LoraConfig, get_peft_model, TaskType |
|
|
logger = logging.getLogger(name=__name__)  # stdlib logger for the fine-tuning script section
| import copy |
| import torch |
| from transformers import TrainerCallback |
|
|
|
|
class EmaCallback(TrainerCallback):
    """Maintain an exponential moving average (EMA) of one submodule's weights.

    The shadow copy is updated after every optimizer step:
    ``shadow = decay * shadow + (1 - decay) * weights``.

    Hugging Face ``Trainer`` fires ``on_evaluate`` / ``on_save`` only *after* it
    has evaluated / saved, and has no ``on_evaluate_end`` / ``on_save_end``
    events. So the EMA weights are swapped in at ``on_step_end`` /
    ``on_epoch_end`` whenever the control flags announce an evaluation or save.
    The training weights are restored at the next ``on_step_begin``, so the
    optimizer never continues from the EMA copy. At ``on_train_end`` the EMA
    weights are loaded into the module.

    Args:
        attr_path: dot-separated attribute path from the trainer's model to the
            tracked submodule.
        decay: EMA decay factor.
        device: device on which the shadow copy is kept.
    """

    def __init__(self, attr_path="model.prediction_head", decay=0.999, device="cpu"):
        self.attr_path = attr_path
        self.decay = float(decay)
        self.device = torch.device(device)
        # name -> averaged tensor (floating-point entries in float32); None before training.
        self.shadow = None
        # Training weights stashed while the EMA weights are swapped in; None otherwise.
        self._orig = None

    def _get_module(self, model):
        """Resolve ``self.attr_path`` on ``model``."""
        mod = model
        for name in self.attr_path.split("."):
            mod = getattr(mod, name)
        return mod

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        head = self._get_module(model)
        shadow = {}
        for key, value in head.state_dict().items():
            value = value.detach()
            if value.is_floating_point():
                # Average in float32: with decay ~0.999 the per-step increment is below
                # bf16/fp16 resolution and would be rounded away, freezing the EMA.
                shadow[key] = value.to(
                    device=self.device, dtype=torch.float32, copy=True
                )
            else:
                shadow[key] = value.to(self.device).clone()
        self.shadow = shadow

    def on_step_begin(self, args, state, control, model=None, **kwargs):
        # Ensure the upcoming optimizer step trains the real weights.
        self._swap_back(model)

    def on_step_end(self, args, state, control, model=None, **kwargs):
        if self.shadow is None:
            return
        # The average must be taken over the training weights, never the EMA copy.
        self._swap_back(model)
        head = self._get_module(model)
        with torch.no_grad():
            for key, value in head.state_dict().items():
                average = self.shadow[key]
                value = value.detach().to(device=average.device, dtype=average.dtype)
                if average.is_floating_point():
                    average.mul_(self.decay).add_(value, alpha=1.0 - self.decay)
                else:
                    # Integer buffers cannot be averaged; track the latest value.
                    average.copy_(value)
        self._swap_in_if_eval_or_save_pending(control, model)

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        self._swap_in_if_eval_or_save_pending(control, model)

    def _swap_in_if_eval_or_save_pending(self, control, model):
        """Load EMA weights when the trainer is about to evaluate or save."""
        if self._orig is None and (control.should_evaluate or control.should_save):
            self._swap_in_ema(model)

    def _swap_in_ema(self, model):
        """Stash the current weights in ``self._orig`` and load the EMA weights."""
        if self.shadow is None:
            return
        head = self._get_module(model)
        self._orig = {k: v.detach().clone() for k, v in head.state_dict().items()}
        head.load_state_dict(self.shadow, strict=False)

    def _swap_back(self, model):
        """Restore the weights stashed by ``_swap_in_ema`` (no-op if none)."""
        if self._orig is None:
            return
        head = self._get_module(model)
        head.load_state_dict(self._orig, strict=False)
        self._orig = None

    def on_evaluate_end(self, args, state, control, model=None, **kwargs):
        # Not a Trainer event; kept for callers that invoke it explicitly.
        self._swap_back(model)

    def on_save_end(self, args, state, control, model=None, **kwargs):
        # Not a Trainer event; kept for callers that invoke it explicitly.
        self._swap_back(model)

    def on_train_end(self, args, state, control, model=None, **kwargs):
        if self._orig is None:
            self._swap_in_ema(model)
        elif self.shadow is not None:
            # Already swapped for a final eval/save: keep the stashed training
            # weights and (re)load the EMA weights as the final state.
            self._get_module(model).load_state_dict(self.shadow, strict=False)
|
|
|
|
@dataclass
class ModelArguments:
    """Model, freezing and LoRA options for QWEN3Vox fine-tuning.

    Fields carry ``help`` metadata in the HfArgumentParser style.
    """

    # --- checkpoint / processor locations ---
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to QWEN3Vox base model with config.json"},
    )
    processor_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to processor dir (preprocessor_config.json). Defaults to model path."
        },
    )
    cache_dir: Optional[str] = None

    # --- frozen pretrained components ---
    freeze_acoustic_tokenizer: bool = True
    freeze_semantic_tokenizer: bool = True

    # --- LoRA hyper-parameters for the LLM blocks ---
    lora_r: int = 8
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: str = field(
        default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
        metadata={"help": "Comma-separated list of target module names in the LLM blocks"},
    )

    # --- what else gets trained ---
    lora_wrap_diffusion_head: bool = field(
        default=False,
        metadata={"help": "Wrap diffusion head with PEFT LoRA"},
    )
    train_diffusion_head: bool = field(
        default=False,
        metadata={"help": "Train diffusion prediction head (full fine-tune)"},
    )
    train_connectors: bool = field(
        default=False,
        metadata={"help": "Train acoustic/semantic connectors (full fine-tune)"},
    )
    layers_to_freeze: Optional[str] = field(
        default=None,
        metadata={
            "help": "Comma-separated indices of diffusion head layers to freeze (e.g., '0,1,5,7,8')."
        },
    )
|
|
|
|
@dataclass
class DataArguments:
    """Command-line arguments describing the training / evaluation data."""

    # HF hub dataset; ignored by main when train_jsonl is given.
    dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "HF dataset name or 'json' with --train_jsonl for local files"
        },
    )
    dataset_config_name: Optional[str] = None
    train_split_name: str = "train"
    eval_split_name: Optional[str] = "validation"
    # Column names forwarded to QWEN3VoxDataset.
    text_column_name: str = "text"
    audio_column_name: str = "audio"
    voice_prompts_column_name: Optional[str] = "voice_prompts"
    # Fraction of the train split held out for eval when --do_eval is set and
    # no eval split exists.
    eval_split_size: float = 0.0
    # True selects VerificationMode.NO_CHECKS when loading the dataset.
    ignore_verifications: bool = False
    # Passed to the collator as max_length.
    max_length: Optional[int] = None
    train_jsonl: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to local train JSONL with {text, audio, [voice_prompts]}"
        },
    )
    validation_jsonl: Optional[str] = field(
        default=None,
        metadata={"help": "Optional path to local validation JSONL"},
    )
    voice_prompt_drop_rate: float = field(
        default=0.0,
        metadata={
            "help": "Probability to drop conditioning voice prompt during training (0.0 keep always, 1.0 drop always)."
        },
    )
|
|
|
|
@dataclass
class CustomTrainingArguments(HfTrainingArguments):
    """Trainer arguments extended with loss weighting, clipping and debug switches."""

    # Each selected speech latent is repeated this many times (with fresh
    # noise and timesteps) in the diffusion loss.
    ddpm_batch_mul: int = 1
    # total loss = ce_loss_weight * CE + diffusion_loss_weight * diffusion.
    ce_loss_weight: float = 1.0
    diffusion_loss_weight: float = 1.0
    # Detailed CE logging switches; debug_ce_topk is not read in the visible code.
    debug_ce_details: bool = False
    debug_ce_topk: int = 5
    debug_ce_max_examples: int = 1
    debug_ce_every_n_steps: int = 200
    gradient_clipping: bool = field(
        default=False,
        metadata={
            "help": "Enable gradient clipping using max_grad_norm (set via --max_grad_norm, default 1.0). When False, disables clipping by forcing max_grad_norm=0.0."
        },
    )
    debug_save: bool = field(
        default=False,
        metadata={
            "help": "If set, saves model components BEFORE training starts, into output_dir/debug_initial."
        },
    )
|
|
|
|
def build_lora_config(args: ModelArguments) -> LoraConfig:
    """Build the LoRA config applied to the language-model blocks.

    ``args.lora_target_modules`` is split on commas. Each entry is
    whitespace-stripped and blank entries are dropped.
    """
    stripped = (name.strip() for name in args.lora_target_modules.split(","))
    target_modules = [name for name in stripped if name]
    return LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        target_modules=target_modules,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
    )
|
|
|
|
def build_head_lora_config(args: ModelArguments) -> LoraConfig:
    """Build the LoRA config for the diffusion prediction head.

    Uses the same rank / alpha / dropout as the LLM adapter. The target names
    are presumably the linear sub-modules inside the head; confirm against
    the head's module names.
    """
    head_targets = "noisy_images_proj cond_proj gate_proj up_proj down_proj linear".split()
    shared = dict(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
    )
    return LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        target_modules=head_targets,
        **shared,
    )
|
|
|
|
def mask_for_ce(
    labels: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    acoustic_input_mask: Optional[torch.Tensor],
    pad_id: int = -100,
) -> torch.Tensor:
    """Build next-token CE targets that skip padding and acoustic positions.

    The targets are ``labels[:, 1:]`` so that they align with
    ``logits[:, :-1]``. A target is replaced by ``pad_id`` when its position
    is not attended (``attention_mask != 1``) or holds an acoustic
    placeholder.

    Args:
        labels: token ids, at least 2-D (batch, seq).
        attention_mask: same leading shape as ``labels``. ``None`` or an
            empty tensor means every position is attended.
        acoustic_input_mask: same leading shape as ``labels``; truthy where an
            acoustic placeholder sits. Any dtype is accepted (cast to bool).
            ``None`` disables acoustic masking.
        pad_id: value written into ignored target positions.

    Returns:
        A new tensor shaped like ``labels[:, 1:]``. ``labels`` is not modified.
    """
    shifted = labels[:, 1:].contiguous()
    if attention_mask is not None and attention_mask.numel() > 0:
        keep = attention_mask[:, 1:].eq(1)
    else:
        keep = torch.ones_like(shifted, dtype=torch.bool)
    if acoustic_input_mask is not None:
        # Cast explicitly: `~` on an integer mask is a bitwise NOT (0 -> -1),
        # which would turn the mask into an integer *index* tensor.
        keep = keep & ~acoustic_input_mask[:, 1:].bool()
    return shifted.masked_fill(~keep, pad_id)
|
|
|
|
| def _patch_acoustic_encode_for_legacy_indexing(model_obj, logger_): |
| try: |
| acoustic = getattr( |
| getattr(model_obj, "model", model_obj), "acoustic_tokenizer", None |
| ) |
| if acoustic is None or not hasattr(acoustic, "encode"): |
| logger_.warning("No acoustic_tokenizer.encode() found to patch.") |
| return |
| base_encode = acoustic.encode |
|
|
| def encode_wrapped(*args, **kwargs): |
| out = base_encode(*args, **kwargs) |
| try: |
| _ = out[0][0] |
| return out |
| except Exception: |
| pass |
| if isinstance(out, dict): |
| for k in ("frames", "codes", "tokens", "latents", "hidden_states"): |
| if k in out: |
| return [[out[k]]] |
| if len(out) > 0: |
| return [[next(iter(out.values()))]] |
| for attr in ("frames", "codes", "tokens", "latents", "hidden_states"): |
| if hasattr(out, attr): |
| return [[getattr(out, attr)]] |
| try: |
| if isinstance(out, torch.Tensor): |
| return [[out]] |
| except Exception: |
| pass |
| return [[out]] |
|
|
| acoustic.encode = encode_wrapped |
| logger_.info( |
| "Patched acoustic_tokenizer.encode() to return [[...]] for legacy indexing." |
| ) |
| except Exception as e: |
| logger_.warning(f"Failed to patch acoustic_tokenizer.encode(): {e }") |
|
|
|
|
| def main() -> None: |
| parser = HfArgumentParser((ModelArguments, DataArguments, CustomTrainingArguments)) |
| model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
| logging.basicConfig( |
| format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| datefmt="%m/%d/%Y %H:%M:%S", |
| level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, |
| ) |
| logger.info("Training/evaluation parameters %s", training_args) |
| set_seed(training_args.seed) |
| if not getattr(training_args, "gradient_clipping", False): |
| if hasattr(training_args, "max_grad_norm"): |
| training_args.max_grad_norm = 0.0 |
| logger.info( |
| "Gradient clipping disabled (set max_grad_norm=0.0). Use --gradient_clipping to enable." |
| ) |
| else: |
| if ( |
| not hasattr(training_args, "max_grad_norm") |
| or training_args.max_grad_norm is None |
| or training_args.max_grad_norm <= 0 |
| ): |
| training_args.max_grad_norm = 1.0 |
| logger.info( |
| f"Gradient clipping enabled: max_grad_norm={training_args .max_grad_norm }" |
| ) |
| model_name = model_args.model_name_or_path |
| if model_name is None: |
| raise ValueError( |
| "--model_name_or_path (or --processor_name_or_path) must be provided" |
| ) |
| processor: QWEN3VoxProcessor = QWEN3VoxProcessor.from_pretrained(model_name) |
| tok = processor.tokenizer |
| for required in ["speech_start_id", "speech_diffusion_id", "speech_end_id"]: |
| if not hasattr(tok, required) or getattr(tok, required) is None: |
| raise RuntimeError(f"Tokenizer missing required special id: {required }") |
| dtype = torch.float32 |
| if training_args.bf16: |
| dtype = torch.bfloat16 |
| elif getattr(training_args, "fp16", False): |
| dtype = torch.float16 |
| model = QWEN3VoxForConditionalGeneration.from_pretrained( |
| model_name, torch_dtype=dtype |
| ) |
| _patch_acoustic_encode_for_legacy_indexing(model, logger) |
| processor.semantic_tokenizer = getattr(model.model, "semantic_tokenizer", None) |
| try: |
| in_emb_mod = model.get_input_embeddings() |
| out_emb_mod = model.get_output_embeddings() |
| in_w = getattr(in_emb_mod, "weight", None) |
| out_w = getattr(out_emb_mod, "weight", None) |
| shared_ptr = bool( |
| in_w is not None |
| and out_w is not None |
| and (in_w.data_ptr() == out_w.data_ptr()) |
| ) |
| values_equal = False |
| if in_w is not None and out_w is not None and (in_w.shape == out_w.shape): |
| try: |
| values_equal = bool(torch.allclose(in_w, out_w)) |
| except Exception: |
| values_equal = False |
| try: |
| tie_cfg = getattr( |
| getattr(model.config, "decoder_config", model.config), |
| "tie_word_embeddings", |
| None, |
| ) |
| except Exception: |
| tie_cfg = getattr(model.config, "tie_word_embeddings", None) |
| logger.info( |
| f"LM head diagnostics -> shared_params={shared_ptr }, values_equal={values_equal }, tie_word_embeddings={tie_cfg }" |
| ) |
| if out_w is not None: |
| logger.info( |
| f"LM head requires_grad before freeze: {bool (out_w .requires_grad )}" |
| ) |
| except Exception as e: |
| logger.warning(f"LM head tie diagnostics failed: {e }") |
| try: |
| emb_module = model.get_input_embeddings() |
| head_module = model.get_output_embeddings() |
| if hasattr(emb_module, "weight") and hasattr(head_module, "weight"): |
| if ( |
| emb_module.weight.shape == head_module.weight.shape |
| and emb_module.weight.data_ptr() != head_module.weight.data_ptr() |
| ): |
| with torch.no_grad(): |
| head_module.weight = emb_module.weight |
| logger.info( |
| "Force-tied LM head weight to input embeddings (pointer share)." |
| ) |
| except Exception as e: |
| logger.warning(f"Force-tie of LM head failed: {e }") |
| try: |
| special_names = ["speech_start_id", "speech_diffusion_id", "speech_end_id"] |
| try: |
| vocab_size = int(getattr(model.config.decoder_config, "vocab_size", 0)) |
| except Exception: |
| vocab_size = 0 |
| in_emb_mod = model.get_input_embeddings() |
| out_emb_mod = model.get_output_embeddings() |
| in_w = getattr(in_emb_mod, "weight", None) |
| out_w = getattr(out_emb_mod, "weight", None) |
| for name in special_names: |
| val = getattr(tok, name, None) |
| exists = val is not None |
| in_range = exists and isinstance(val, int) and (0 <= val < vocab_size) |
| equal_row = None |
| if ( |
| in_range |
| and in_w is not None |
| and (out_w is not None) |
| and (in_w.shape == out_w.shape) |
| and (in_w.size(0) > val) |
| ): |
| try: |
| equal_row = bool(torch.allclose(in_w[val], out_w[val])) |
| except Exception: |
| equal_row = False |
| decoded_str = None |
| if exists and isinstance(val, int): |
| try: |
| decoded_str = tok.decode([val]) |
| except Exception: |
| try: |
| decoded_str = tok.convert_ids_to_tokens(val) |
| except Exception: |
| decoded_str = "<decode_failed>" |
| logger.info( |
| f"Special token check -> {name }={val }, decoded='{decoded_str }', exists={exists }, in_vocab_range={in_range }, emb_vs_head_row_equal={equal_row }" |
| ) |
| except Exception as e: |
| logger.warning(f"Special token ID/row validation failed: {e }") |
| try: |
| logger.info("=== TOKENIZER DIAGNOSTICS ===") |
| logger.info(f"Tokenizer class: {type (tok ).__name__ }") |
| logger.info(f"Tokenizer vocab_size: {tok .vocab_size }") |
| with torch.no_grad(): |
| simple_text = "The cat sat on the mat." |
| simple_ids = torch.tensor( |
| [tok.encode(simple_text, add_special_tokens=True)], device=model.device |
| ) |
| simple_mask = torch.ones_like(simple_ids) |
| x = model.get_input_embeddings()(simple_ids) |
| outputs = model.model( |
| inputs_embeds=x, attention_mask=simple_mask, return_dict=True |
| ) |
| logits = model.lm_head(outputs.last_hidden_state) |
| shift_logits = logits[:, :-1, :].contiguous() |
| shift_labels = simple_ids[:, 1:].contiguous() |
| ce_loss = F.cross_entropy( |
| shift_logits.view(-1, shift_logits.size(-1)), |
| shift_labels.view(-1), |
| reduction="mean", |
| ) |
| logger.info(f"Simple text CE loss: {ce_loss .item ():.4f}") |
| except Exception as e: |
| logger.warning(f"Tokenizer diagnostics failed: {e }") |
| if hasattr(model.config, "use_cache") and training_args.do_train: |
| model.config.use_cache = False |
| if model_args.freeze_acoustic_tokenizer and hasattr( |
| model.model, "acoustic_tokenizer" |
| ): |
| for p in model.model.acoustic_tokenizer.parameters(): |
| p.requires_grad = False |
| if model_args.freeze_semantic_tokenizer and hasattr( |
| model.model, "semantic_tokenizer" |
| ): |
| for p in model.model.semantic_tokenizer.parameters(): |
| p.requires_grad = False |
| lora_cfg = build_lora_config(model_args) |
| tm_lower = [ |
| s.strip().lower() |
| for s in model_args.lora_target_modules.split(",") |
| if s.strip() |
| ] |
| skip_lm_lora = len(tm_lower) == 0 or all( |
| (t in ("none", "off", "disable", "disabled") for t in tm_lower) |
| ) |
| if not skip_lm_lora: |
| model.model.language_model = get_peft_model( |
| model.model.language_model, lora_cfg |
| ) |
| else: |
| logger.info("Skipping LLM LoRA wrapping (lora_target_modules indicates none).") |
| try: |
| model.tie_weights() |
| except Exception: |
| pass |
| for _, p in model.named_parameters(): |
| p.requires_grad = False |
| try: |
| for n, p in model.model.language_model.named_parameters(): |
| if "lora_A" in n or "lora_B" in n: |
| p.requires_grad = True |
| except Exception: |
| logger.warning("Could not re-enable LoRA params on language_model.") |
| if getattr(model_args, "lora_wrap_diffusion_head", False) and hasattr( |
| model.model, "prediction_head" |
| ): |
|
|
| class _HeadForwardShim(nn.Module): |
|
|
| def __init__(self, base: nn.Module): |
| super().__init__() |
| self.base = base |
|
|
| def forward(self, *args, **kwargs): |
| if len(args) >= 3: |
| noisy_images, timesteps, condition = args[:3] |
| else: |
| noisy_images = kwargs.get("noisy_images") |
| timesteps = kwargs.get("timesteps") |
| condition = kwargs.get("condition") |
| return self.base(noisy_images, timesteps, condition) |
|
|
| try: |
| shim = _HeadForwardShim(model.model.prediction_head) |
| model.model.prediction_head = get_peft_model( |
| shim, build_head_lora_config(model_args) |
| ) |
| for n, p in model.model.prediction_head.named_parameters(): |
| if "lora_A" in n or "lora_B" in n: |
| p.requires_grad = True |
| except Exception as e: |
| logger.warning(f"Could not LoRA-wrap diffusion head: {e }") |
| if getattr(model_args, "train_diffusion_head", False) and hasattr( |
| model.model, "prediction_head" |
| ): |
| for p in model.model.prediction_head.parameters(): |
| p.requires_grad = True |
| if model_args.layers_to_freeze is not None and hasattr( |
| model.model, "prediction_head" |
| ): |
| head_params = list(model.model.prediction_head.named_parameters()) |
| try: |
| indices_to_freeze = { |
| int(x.strip()) |
| for x in model_args.layers_to_freeze.split(",") |
| if x.strip() |
| } |
| frozen_count = 0 |
| for i, (name, param) in enumerate(head_params): |
| if i in indices_to_freeze: |
| param.requires_grad = False |
| frozen_count += 1 |
| logger.info(f"Froze layer [{i }]: {name }") |
| logger.info( |
| f"Successfully froze {frozen_count } parameter groups in the diffusion head." |
| ) |
| except Exception as e: |
| logger.error(f"Could not parse --layers_to_freeze: {e }") |
| raise |
| if getattr(model_args, "train_connectors", False): |
| if hasattr(model.model, "acoustic_connector"): |
| for p in model.model.acoustic_connector.parameters(): |
| p.requires_grad = True |
| if hasattr(model.model, "semantic_connector"): |
| for p in model.model.semantic_connector.parameters(): |
| p.requires_grad = True |
| else: |
| if hasattr(model.model, "acoustic_connector"): |
| for p in model.model.acoustic_connector.parameters(): |
| p.requires_grad = False |
| if hasattr(model.model, "semantic_connector"): |
| for p in model.model.semantic_connector.parameters(): |
| p.requires_grad = False |
| try: |
| emb = model.get_input_embeddings() |
| if hasattr(emb, "weight"): |
| emb.weight.requires_grad_(False) |
| head = model.get_output_embeddings() |
| if head is not None and hasattr(head, "weight"): |
| head.weight.requires_grad_(False) |
| except Exception: |
| pass |
|
|
| def _sum_params(named_iter): |
| return sum((p.numel() for _, p in named_iter if p.requires_grad)) |
|
|
| try: |
| lm_lora = ( |
| _sum_params(model.model.language_model.named_parameters()) |
| if hasattr(model.model, "language_model") |
| else 0 |
| ) |
| pred_head_train = ( |
| _sum_params(model.model.prediction_head.named_parameters()) |
| if hasattr(model.model, "prediction_head") |
| else 0 |
| ) |
| ac_conn_train = ( |
| _sum_params(model.model.acoustic_connector.named_parameters()) |
| if hasattr(model.model, "acoustic_connector") |
| else 0 |
| ) |
| se_conn_train = ( |
| _sum_params(model.model.semantic_connector.named_parameters()) |
| if hasattr(model.model, "semantic_connector") |
| else 0 |
| ) |
| total_trainable = sum( |
| (p.numel() for p in model.parameters() if p.requires_grad) |
| ) |
| logger.info( |
| f"Trainable by block -> LLM-LoRA: {lm_lora :,} | diff_head: {pred_head_train :,} | ac_conn: {ac_conn_train :,} | se_conn: {se_conn_train :,}" |
| ) |
| logger.info("TOTAL trainable: %s", f"{total_trainable :,}") |
| except Exception: |
| pass |
| verification_mode = ( |
| VerificationMode.NO_CHECKS |
| if data_args.ignore_verifications |
| else VerificationMode.BASIC_CHECKS |
| ) |
| if data_args.train_jsonl is not None: |
| data_files: Dict[str, str] = {"train": data_args.train_jsonl} |
| if data_args.validation_jsonl is not None: |
| data_files["validation"] = data_args.validation_jsonl |
| raw = load_dataset( |
| "json", |
| data_files=data_files, |
| verification_mode=verification_mode, |
| cache_dir=model_args.cache_dir, |
| ) |
| else: |
| if data_args.dataset_name is None: |
| raise ValueError( |
| "Provide --dataset_name (HF datasets) or use --train_jsonl/--validation_jsonl for local files." |
| ) |
| raw = load_dataset( |
| data_args.dataset_name, |
| data_args.dataset_config_name, |
| verification_mode=verification_mode, |
| cache_dir=model_args.cache_dir, |
| ) |
| train_ds = raw[data_args.train_split_name] |
| eval_ds = None |
| if training_args.do_eval: |
| if data_args.eval_split_name and data_args.eval_split_name in raw: |
| eval_ds = raw[data_args.eval_split_name] |
| elif ( |
| data_args.eval_split_size |
| and data_args.eval_split_size > 0 |
| and (len(train_ds) > 1) |
| ): |
| split = train_ds.train_test_split( |
| test_size=data_args.eval_split_size, seed=training_args.seed |
| ) |
| train_ds, eval_ds = (split["train"], split["test"]) |
| train_dataset = QWEN3VoxDataset( |
| train_ds, |
| text_column=data_args.text_column_name, |
| audio_column=data_args.audio_column_name, |
| voice_prompts_column=data_args.voice_prompts_column_name, |
| ) |
| eval_dataset = None |
| if eval_ds is not None: |
| eval_dataset = QWEN3VoxDataset( |
| eval_ds, |
| text_column=data_args.text_column_name, |
| audio_column=data_args.audio_column_name, |
| voice_prompts_column=data_args.voice_prompts_column_name, |
| ) |
| speech_compress_ratio = getattr(processor, "speech_tok_compress_ratio", 3200) |
| semantic_dim = getattr(model.config, "semantic_vae_dim", None) |
| if semantic_dim is None: |
| try: |
| semantic_dim = int( |
| getattr(model.config.semantic_tokenizer_config, "vae_dim", 128) |
| ) |
| except Exception: |
| semantic_dim = 128 |
| compute_semantics_flag = ( |
| hasattr(processor, "semantic_tokenizer") |
| and processor.semantic_tokenizer is not None |
| ) |
| data_collator = QWEN3VoxCollator( |
| processor=processor, |
| max_length=data_args.max_length, |
| speech_compress_ratio=speech_compress_ratio, |
| semantic_vae_dim=semantic_dim, |
| compute_semantics=compute_semantics_flag, |
| debug_checks=False, |
| voice_prompt_drop_rate=data_args.voice_prompt_drop_rate, |
| ) |
|
|
class LoRADebugCallback(TrainerCallback):
    """Log whether LoRA adapter weights exist, are trainable and actually move.

    At train start it records the L2 norm of every parameter whose name
    contains ``lora_A`` or ``lora_B``. At step 1, and at every multiple of
    ``log_every_n_steps``, it reports how many of those norms changed since
    the previous report. Failures are logged as warnings, never raised.
    """

    def __init__(self, log_every_n_steps: int = 50):
        self.log_every_n_steps = max(1, int(log_every_n_steps))
        self.prev_param_norms: Dict[str, float] = {}
        self.lora_param_names: List[str] = []

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        try:
            if model is None:
                return
            params_by_name: Dict[str, torch.nn.Parameter] = dict(model.named_parameters())
            lora_names = [
                name for name in params_by_name if "lora_A" in name or "lora_B" in name
            ]
            self.lora_param_names = lora_names
            for name in lora_names:
                self.prev_param_norms[name] = float(params_by_name[name].data.norm().item())

            total = len(lora_names)
            trainable = 0
            count_a = 0
            count_b = 0
            zero_b = 0
            for name in lora_names:
                param = params_by_name[name]
                if param.requires_grad:
                    trainable += 1
                if "lora_A" in name:
                    count_a += 1
                if "lora_B" in name:
                    count_b += 1
                    # lora_B starting at exactly zero is the usual LoRA init.
                    if float(param.data.norm().item()) == 0.0:
                        zero_b += 1

            logger.info(
                f"LoRA debug: found {total} LoRA params (A={count_a}, B={count_b}); trainable={trainable}. Initial lora_B_zero={zero_b}."
            )
            if total == 0:
                logger.warning(
                    "LoRA debug: No LoRA parameters found. Check lora_target_modules."
                )
            if trainable != total:
                logger.warning(
                    "LoRA debug: Some LoRA params are frozen. They should be trainable."
                )
        except Exception as e:
            logger.warning(f"LoRA debug (on_train_begin) failed: {e}")

    def on_step_end(self, args, state, control, model=None, **kwargs):
        try:
            if model is None or not self.lora_param_names:
                return
            step = int(getattr(state, "global_step", 0) or 0)
            if step != 1 and step % self.log_every_n_steps:
                return

            params_by_name: Dict[str, torch.nn.Parameter] = dict(model.named_parameters())
            moved_a = 0
            moved_b = 0
            zero_b = 0
            tolerance = 1e-12
            for name in self.lora_param_names:
                param = params_by_name.get(name, None)
                if param is None:
                    continue
                before = self.prev_param_norms.get(name, 0.0)
                now = float(param.data.norm().item())
                moved = abs(now - before) > tolerance
                if "lora_A" in name and moved:
                    moved_a += 1
                if "lora_B" in name:
                    if moved:
                        moved_b += 1
                    if now == 0.0:
                        zero_b += 1
                self.prev_param_norms[name] = now

            count_a = sum(1 for name in self.lora_param_names if "lora_A" in name)
            count_b = sum(1 for name in self.lora_param_names if "lora_B" in name)
            logger.info(
                f"LoRA debug step {step}: changed A {moved_a}/{count_a}, changed B {moved_b}/{count_b}, lora_B_zero_now={zero_b}."
            )
        except Exception as e:
            logger.warning(f"LoRA debug (on_step_end) failed: {e}")
|
|
| class QWEN3VoxTrainer(Trainer): |
|
|
def training_forward(
    self, model: QWEN3VoxForConditionalGeneration, inputs: Dict[str, Any]
):
    """Run the fine-tuning forward pass: LM logits plus the diffusion loss.

    Embeddings at ``acoustic_input_mask`` positions are replaced by the
    acoustic connector output of ``speech_tensors``. When available, the
    semantic connector output is added to it. When ``acoustic_loss_mask``
    selects at least one position:

    * the hidden state one step *before* each selected position conditions
      the prediction head;
    * the head is trained on noised copies of the selected latents, repeated
      ``training_args.ddpm_batch_mul`` times with the MSE normalized per
      element.

    Otherwise a zero-valued loss over the head / connector parameters is
    returned instead.

    Returns:
        ``QWEN3VoxCausalLMOutputWithPast`` with ``loss=None`` (text CE is
        computed by ``compute_loss``), ``diffusion_loss``, and
        ``speech_token_num``. The latter is the number of latent frames that
        received diffusion loss (0 when the diffusion branch is skipped).

    Raises:
        NotImplementedError: for a diffusion ``prediction_type`` other than
            ``"epsilon"`` or ``"v_prediction"``.
    """
    input_ids = inputs.get("input_ids")
    attention_mask = inputs.get("attention_mask")
    position_ids = inputs.get("position_ids")
    past_key_values = inputs.get("past_key_values")
    use_cache = inputs.get("use_cache", False)
    output_attentions = inputs.get("output_attentions")
    return_dict = inputs.get("return_dict", True)
    cache_position = inputs.get("cache_position")
    speech_tensors = inputs.get("speech_tensors")
    speech_masks = inputs.get("speech_masks")
    speeches_loss_input = inputs.get("speeches_loss_input")
    speech_semantic_tensors = inputs.get("speech_semantic_tensors")
    acoustic_input_mask = inputs.get("acoustic_input_mask")
    acoustic_loss_mask = inputs.get("acoustic_loss_mask")
    ddpm_batch_mul = training_args.ddpm_batch_mul
    speech_type = "audio"
    # Latent frames receiving diffusion loss. It must be bound even when the
    # diffusion branch is skipped with speech present; previously that raised
    # UnboundLocalError at the return statement.
    speech_len = 0

    x = model.get_input_embeddings()(input_ids)
    semantic_speech_all_connect_features = (
        model.model.semantic_connector(speech_semantic_tensors)
        if speech_semantic_tensors is not None
        else None
    )
    if speeches_loss_input is not None:
        speech_all_features, speech_all_connect_features = model.forward_speech_features(
            speech_tensors=(
                speech_tensors.type_as(x) if speech_tensors is not None else None
            ),
            speech_masks=speech_masks,
            speech_type=speech_type,
            return_unmask=True,
        )
        if speech_tensors is not None:
            if semantic_speech_all_connect_features is not None:
                x[acoustic_input_mask] = (
                    speech_all_connect_features[speech_masks]
                    + semantic_speech_all_connect_features[speech_masks]
                )
            else:
                x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
            # Only latents flagged for loss are diffused. The old
            # connector-count vs acoustic_input_mask assertion sat inside
            # `except Exception: pass`, so it could never fire; it was dropped.
            speech_features = speech_all_features[speeches_loss_input & speech_masks]
    else:
        speech_features, speech_connect_features = model.forward_speech_features(
            speech_tensors=(
                speech_tensors.type_as(x) if speech_tensors is not None else None
            ),
            speech_masks=speech_masks,
            speech_type=speech_type,
        )
        if speech_tensors is not None:
            x[acoustic_input_mask] = speech_connect_features

    outputs = model.model(
        input_ids=None,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=x,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=False,
        return_dict=return_dict,
        cache_position=cache_position,
    )
    hidden_states = outputs.last_hidden_state
    logits = model.lm_head(hidden_states)
    loss = None

    if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
        # The hidden state at position i conditions the latent at position i + 1.
        cond_mask = torch.zeros_like(acoustic_loss_mask, dtype=torch.bool)
        cond_mask[:, :-1] = acoustic_loss_mask[:, 1:]
        cond_mask[:, 0] = False
        condition_features = hidden_states[cond_mask]
        speech_len, latent_size = speech_features.shape
        if condition_features.shape[0] != speech_len:
            logger.warning(
                f"Mismatch: condition_features={condition_features.shape[0]} vs speech_features={speech_len}"
            )
        noise = torch.randn(
            (speech_len * ddpm_batch_mul, latent_size),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        timesteps = torch.multinomial(
            torch.ones(model.config.diffusion_head_config.ddpm_num_steps),
            speech_len * ddpm_batch_mul,
            replacement=True,
        ).to(hidden_states.device)
        speech_features_repeated = speech_features.repeat_interleave(
            ddpm_batch_mul, dim=0
        )
        condition_features_repeated = condition_features.repeat_interleave(
            ddpm_batch_mul, dim=0
        )
        noisy_speech_features = model.model.noise_scheduler.add_noise(
            speech_features_repeated, noise, timesteps
        )
        model_output = model.model.prediction_head(
            noisy_speech_features,
            timesteps.type_as(x),
            condition_features_repeated,
        )
        prediction_type = model.config.diffusion_head_config.prediction_type
        if prediction_type == "epsilon":
            target_for_loss = noise
        elif prediction_type == "v_prediction":
            target_for_loss = model.model.noise_scheduler.get_velocity(
                speech_features_repeated, noise, timesteps
            )
        else:
            raise NotImplementedError(
                f"Prediction type {prediction_type} not implemented"
            )
        diffusion_loss = F.mse_loss(
            model_output.float(), target_for_loss.float(), reduction="sum"
        )
        if latent_size > 0 and ddpm_batch_mul > 0:
            diffusion_loss = (
                diffusion_loss / latent_size / ddpm_batch_mul / max(speech_len, 1)
            )
        else:
            diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
    else:
        # Zero-valued loss touching the head and connector parameters,
        # presumably so distributed training sees them in the graph every step.
        diffusion_loss = (
            sum(p.sum() for p in model.model.prediction_head.parameters()) * 0.0
        )
        diffusion_loss += (
            sum(p.sum() for p in model.model.acoustic_connector.parameters()) * 0.0
        )
        diffusion_loss += (
            sum(p.sum() for p in model.model.semantic_connector.parameters()) * 0.0
        )

    return QWEN3VoxCausalLMOutputWithPast(
        loss=loss,
        diffusion_loss=diffusion_loss,
        speech_token_num=speech_len,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
|
def compute_loss(
    self,
    model: QWEN3VoxForConditionalGeneration,
    inputs: Dict[str, Any],
    return_outputs=False,
    num_items_in_batch: Optional[int] = None,
):
    """Return ``ce_loss_weight * CE + diffusion_loss_weight * diffusion``.

    Text CE uses ``input_ids`` as labels, shifted by one and masked by
    ``mask_for_ce``, so padding and acoustic placeholder positions are
    ignored. ``speech_semantic_tensors`` is handled before the forward pass:

    * when present as a tensor, it is cast to the semantic connector's dtype;
    * when missing, it is filled with zeros shaped from ``speech_masks``.

    Token/latent bookkeeping and the per-component losses go to ``self.log``
    on a best-effort basis.

    Args:
        model: the model being trained.
        inputs: the collated batch; ``speech_semantic_tensors`` may be
            written back into it.
        return_outputs: also return the forward outputs when True.
        num_items_in_batch: accepted for Trainer compatibility; not used.

    Returns:
        ``total`` or ``(total, outputs)``.
    """
    labels = inputs.get("input_ids")
    attention_mask = inputs.get("attention_mask")
    acoustic_input_mask = inputs.get("acoustic_input_mask")

    # Align semantic features with the connector's dtype, or synthesize zeros.
    sem = inputs.get("speech_semantic_tensors", None)
    try:
        target_dtype = next(model.model.semantic_connector.parameters()).dtype
    except Exception:
        target_dtype = model.get_input_embeddings().weight.dtype
    if sem is None:
        sm = inputs.get("speech_masks")
        if sm is not None:
            zeros = torch.zeros(
                sm.size(0),
                sm.size(1),
                getattr(model.config, "semantic_vae_dim", 128),
                dtype=target_dtype,
                device=sm.device,
            )
            inputs["speech_semantic_tensors"] = zeros
    elif isinstance(sem, torch.Tensor):
        inputs["speech_semantic_tensors"] = sem.to(dtype=target_dtype)

    outputs = self.training_forward(model, inputs)

    # Best-effort bookkeeping: never let logging break a training step.
    try:
        al_mask = inputs.get("acoustic_loss_mask")
        sp_masks = inputs.get("speech_masks")
        sp_loss_sel = inputs.get("speeches_loss_input")
        num_tok_total = (
            int(acoustic_input_mask.sum().item())
            if acoustic_input_mask is not None
            else 0
        )
        num_tok_loss = int(al_mask.sum().item()) if al_mask is not None else 0
        num_lat_total = int(sp_masks.sum().item()) if sp_masks is not None else 0
        num_lat_loss = (
            int((sp_loss_sel & sp_masks).sum().item())
            if sp_loss_sel is not None and sp_masks is not None
            else 0
        )
        self.log(
            {
                "debug/num_tok_total": float(num_tok_total),
                "debug/num_tok_loss": float(num_tok_loss),
                "debug/num_lat_total": float(num_lat_total),
                "debug/num_lat_loss": float(num_lat_loss),
            }
        )
        if sp_loss_sel is not None and sp_masks is not None and al_mask is not None:
            if num_tok_loss != num_lat_loss:
                logger.warning(
                    f"Loss selection mismatch: acoustic_loss_mask={num_tok_loss} vs speeches_loss_input={num_lat_loss}"
                )
    except Exception:
        pass

    logits = outputs.logits
    ce_labels = mask_for_ce(labels, attention_mask, acoustic_input_mask, pad_id=-100)
    shift_logits = logits[:, :-1, :].contiguous()
    loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    ce_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), ce_labels.view(-1))
    if not bool(ce_labels.ne(-100).any()):
        # The mean over zero target tokens is NaN, which would poison the total
        # loss and every gradient. Use a zero that stays attached to the graph.
        ce_loss = shift_logits.sum() * 0.0
    try:
        self._debug_ce(shift_logits, ce_labels, attention_mask, acoustic_input_mask)
    except Exception as e:
        logger.warning(f"Failed invoking CE debug: {e}")

    diffusion_loss = (
        outputs.diffusion_loss
        if outputs.diffusion_loss is not None
        else torch.tensor(0.0, device=ce_loss.device)
    )
    total = (
        training_args.ce_loss_weight * ce_loss
        + training_args.diffusion_loss_weight * diffusion_loss
    )

    try:
        prefix = "train" if model.training else "eval"
        self.log(
            {
                f"{prefix}/ce_loss": ce_loss.detach().item(),
                f"{prefix}/diffusion_loss": (
                    diffusion_loss.detach().item()
                    if isinstance(diffusion_loss, torch.Tensor)
                    else float(diffusion_loss)
                ),
            }
        )
        if (
            hasattr(self, "optimizer")
            and self.optimizer is not None
            and len(self.optimizer.param_groups) > 0
        ):
            lr_val = self.optimizer.param_groups[0].get("lr", None)
            if lr_val is not None:
                self.log({"train/learning_rate_real": float(lr_val)})
    except Exception:
        pass
    return (total, outputs) if return_outputs else total
|
|
def _debug_ce(
    self,
    shift_logits: torch.Tensor,
    ce_labels: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    acoustic_input_mask: Optional[torch.Tensor],
):
    """Log per-token CE statistics when ``training_args.debug_ce_details`` is set.

    Runs only at global step 0 or 1, or at multiples of
    ``debug_ce_every_n_steps``. It reports:

    * the number of tokens in the loss;
    * their mean CE;
    * the mean CE of the first ``debug_ce_max_examples`` rows (``None`` for
      rows without loss tokens).

    ``attention_mask`` and ``acoustic_input_mask`` are accepted but not used;
    ``ce_labels`` already encodes that masking. Failures are logged, never
    raised.
    """
    try:
        if not getattr(training_args, "debug_ce_details", False):
            return
        current_step = int(getattr(self.state, "global_step", 0) or 0)
        interval = max(
            1, int(getattr(training_args, "debug_ce_every_n_steps", 200) or 200)
        )
        if current_step > 1 and current_step % interval != 0:
            return
        with torch.no_grad():
            vocab_size = shift_logits.size(-1)
            token_losses = F.cross_entropy(
                shift_logits.view(-1, vocab_size),
                ce_labels.view(-1),
                reduction="none",
                ignore_index=-100,
            ).view_as(ce_labels)
            in_loss = ce_labels.ne(-100)
            n_in_loss = int(in_loss.sum().item())
            if n_in_loss > 0:
                mean_loss = float(token_losses[in_loss].mean().item())
            else:
                mean_loss = float("nan")
            limit = max(
                1, int(getattr(training_args, "debug_ce_max_examples", 1) or 1)
            )
            example_means = []
            for row in range(min(ce_labels.size(0), limit)):
                row_mask = in_loss[row]
                if int(row_mask.sum().item()) > 0:
                    example_means.append(float(token_losses[row][row_mask].mean().item()))
                else:
                    example_means.append(float("nan"))
            # NaN != NaN, so rows without loss tokens are reported as None.
            rounded = [round(v, 4) if v == v else None for v in example_means]
            logger.info(
                f"CE debug: tokens_in_loss={n_in_loss}, avg_loss={mean_loss:.4f}, per_example_avgs={rounded}"
            )
    except Exception as e:
        logger.warning(f"CE detailed debug failed: {e}")
|
|
def _save(self, output_dir: Optional[str] = None, state_dict=None) -> None:
    """Save the LoRA adapter and auxiliary modules instead of a full checkpoint.

    This overrides the Trainer's save hook. Everything is written below
    ``<output_dir or self.args.output_dir>/lora``:

    * ``language_model.save_pretrained(lora/)``, if the method exists;
    * ``prediction_head.save_pretrained(lora/diffusion_head/)``, if the
      method exists;
    * the prediction head's ``state_dict`` as ``diffusion_head_full.bin``,
      both in ``lora/`` and in ``lora/diffusion_head/``;
    * ``acoustic_connector`` / ``semantic_connector`` state dicts as
      ``lora/<name>/pytorch_model.bin``.

    Saving is best effort. Each step above is attempted independently, and a
    failure is logged as a warning instead of preventing the remaining steps.

    Args:
        output_dir: Destination root; defaults to ``self.args.output_dir``.
        state_dict: Accepted for compatibility with the Trainer hook; unused.
    """
    try:
        lora_out = os.path.join(output_dir or self.args.output_dir, "lora")
        os.makedirs(lora_out, exist_ok=True)
        inner = self.model.model
    except Exception as e:
        logger.warning(f"Failed to save LoRA assets: {e}")
        return

    ph_dir = os.path.join(lora_out, "diffusion_head")

    def _save_language_model() -> None:
        # Adapter weights of the decoder (save_pretrained of the wrapped LM).
        language_model = getattr(inner, "language_model", None)
        if hasattr(language_model, "save_pretrained"):
            language_model.save_pretrained(lora_out)

    def _save_head_pretrained() -> None:
        # Prediction head in save_pretrained format, when supported.
        pred_head = getattr(inner, "prediction_head", None)
        if hasattr(pred_head, "save_pretrained"):
            os.makedirs(ph_dir, exist_ok=True)
            pred_head.save_pretrained(ph_dir)

    def _save_head_full() -> None:
        # Full prediction-head state dict, written to both locations.
        pred_head = getattr(inner, "prediction_head", None)
        if pred_head is not None and hasattr(pred_head, "state_dict"):
            sd = pred_head.state_dict()
            torch.save(sd, os.path.join(lora_out, "diffusion_head_full.bin"))
            os.makedirs(ph_dir, exist_ok=True)
            torch.save(sd, os.path.join(ph_dir, "diffusion_head_full.bin"))

    def _save_connector(attr_name: str) -> None:
        # Connector state dict under lora/<attr_name>/pytorch_model.bin.
        connector = getattr(inner, attr_name, None)
        if connector is not None:
            conn_dir = os.path.join(lora_out, attr_name)
            os.makedirs(conn_dir, exist_ok=True)
            torch.save(
                connector.state_dict(), os.path.join(conn_dir, "pytorch_model.bin")
            )

    save_steps = (
        _save_language_model,
        _save_head_pretrained,
        _save_head_full,
        lambda: _save_connector("acoustic_connector"),
        lambda: _save_connector("semantic_connector"),
    )
    for save_step in save_steps:
        try:
            save_step()
        except Exception as e:
            logger.warning(f"Failed to save LoRA assets: {e}")
|
|
| ema_cb = EmaCallback(attr_path="model.prediction_head", decay=0.999, device="cpu") |
| trainer = QWEN3VoxTrainer( |
| model=model, |
| args=training_args, |
| train_dataset=train_dataset, |
| eval_dataset=eval_dataset, |
| data_collator=data_collator, |
| callbacks=[ |
| ema_cb, |
| LoRADebugCallback( |
| log_every_n_steps=int(getattr(training_args, "logging_steps", 50) or 50) |
| ), |
| ], |
| ) |
| if getattr(training_args, "debug_save", False): |
| try: |
| debug_dir = os.path.join(training_args.output_dir, "debug_initial") |
| lora_out = os.path.join(debug_dir, "lora") |
| os.makedirs(lora_out, exist_ok=True) |
| logger.info( |
| f"[debug_save] Saving initial (pre-training) model components to: {debug_dir }" |
| ) |
| try: |
| if hasattr(model.model.language_model, "save_pretrained"): |
| model.model.language_model.save_pretrained(lora_out) |
| except Exception as e_lm: |
| logger.warning(f"[debug_save] Failed to save language_model: {e_lm }") |
| try: |
| if hasattr(model.model, "prediction_head") and hasattr( |
| model.model.prediction_head, "save_pretrained" |
| ): |
| model.model.prediction_head.save_pretrained( |
| os.path.join(lora_out, "diffusion_head") |
| ) |
| except Exception as e_head: |
| logger.warning( |
| f"[debug_save] Failed to save prediction_head: {e_head }" |
| ) |
| try: |
| ph = getattr(model.model, "prediction_head", None) |
| if ph is not None and hasattr(ph, "state_dict"): |
| sd = ph.state_dict() |
| torch.save(sd, os.path.join(lora_out, "diffusion_head_full.bin")) |
| os.makedirs(os.path.join(lora_out, "diffusion_head"), exist_ok=True) |
| torch.save( |
| sd, |
| os.path.join( |
| lora_out, "diffusion_head", "diffusion_head_full.bin" |
| ), |
| ) |
| except Exception as e: |
| logger.warning(f"[debug_save] Failed to save FULL diffusion head: {e }") |
| try: |
| ac_conn = getattr(model.model, "acoustic_connector", None) |
| if ac_conn is not None: |
| ac_dir = os.path.join(lora_out, "acoustic_connector") |
| os.makedirs(ac_dir, exist_ok=True) |
| torch.save( |
| ac_conn.state_dict(), os.path.join(ac_dir, "pytorch_model.bin") |
| ) |
| except Exception as e_ac: |
| logger.warning( |
| f"[debug_save] Failed to save acoustic_connector: {e_ac }" |
| ) |
| try: |
| se_conn = getattr(model.model, "semantic_connector", None) |
| if se_conn is not None: |
| se_dir = os.path.join(lora_out, "semantic_connector") |
| os.makedirs(se_dir, exist_ok=True) |
| torch.save( |
| se_conn.state_dict(), os.path.join(se_dir, "pytorch_model.bin") |
| ) |
| except Exception as e_se: |
| logger.warning( |
| f"[debug_save] Failed to save semantic_connector: {e_se }" |
| ) |
| except Exception as e: |
| logger.warning( |
| f"[debug_save] Unexpected failure saving initial components: {e }" |
| ) |
| if getattr(training_args, "gradient_checkpointing", False): |
| try: |
| model.gradient_checkpointing_enable() |
| except Exception: |
| logger.warning("Failed to enable gradient checkpointing on the model.") |
| if training_args.do_train: |
| trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) |
| lora_out = os.path.join(training_args.output_dir, "lora") |
| os.makedirs(lora_out, exist_ok=True) |
| lm = getattr(model.model, "language_model", None) |
| if hasattr(lm, "save_pretrained"): |
| lm.save_pretrained(lora_out) |
| ph = getattr(model.model, "prediction_head", None) |
| if hasattr(ph, "save_pretrained"): |
| ph_dir = os.path.join(lora_out, "diffusion_head") |
| os.makedirs(ph_dir, exist_ok=True) |
| ph.save_pretrained(ph_dir) |
| try: |
| if ph is not None and hasattr(ph, "state_dict"): |
| sd = ph.state_dict() |
| torch.save(sd, os.path.join(lora_out, "diffusion_head_full.bin")) |
| ph_dir = os.path.join(lora_out, "diffusion_head") |
| os.makedirs(ph_dir, exist_ok=True) |
| torch.save(sd, os.path.join(ph_dir, "diffusion_head_full.bin")) |
| except Exception as e: |
| logger.warning(f"Failed to save FULL diffusion head at end: {e }") |
| try: |
| ac = getattr(model.model, "acoustic_connector", None) |
| if ac is not None: |
| ac_dir = os.path.join(lora_out, "acoustic_connector") |
| os.makedirs(ac_dir, exist_ok=True) |
| torch.save(ac.state_dict(), os.path.join(ac_dir, "pytorch_model.bin")) |
| except Exception as e: |
| logger.warning(f"Failed to save acoustic_connector: {e }") |
| try: |
| se = getattr(model.model, "semantic_connector", None) |
| if se is not None: |
| se_dir = os.path.join(lora_out, "semantic_connector") |
| os.makedirs(se_dir, exist_ok=True) |
| torch.save(se.state_dict(), os.path.join(se_dir, "pytorch_model.bin")) |
| except Exception as e: |
| logger.warning(f"Failed to save semantic_connector: {e }") |
| if training_args.do_eval and eval_dataset is not None: |
| trainer.evaluate() |
|
|
|
|
# Script entry point.
# NOTE(review): this file continues below with further imports and class
# definitions. When run as a script, those lines have not executed yet at the
# moment main() is called here, so main() cannot rely on anything defined
# after this guard.
if __name__ == "__main__":
    main()
| from typing import List, Optional, Tuple, Union |
| import torch |
| import torch.nn as nn |
| from transformers.models.auto import AutoModel, AutoModelForCausalLM |
| from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast |
| from transformers import modeling_utils |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import logging |
| from transformers.generation import GenerationMixin |
|
|
logger = logging.get_logger(__name__)

# Install a default ALL_PARALLEL_STYLES list when transformers.modeling_utils
# does not define it or defines it as None.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
|
|
|
class QWEN3VoxASRPreTrainedModel(PreTrainedModel):
    """Shared transformers base class for the QWEN3Vox ASR models.

    Declares the config class, the ``"model"`` base-model prefix and the
    attention/cache backends the subclasses advertise, and implements the
    ``_init_weights`` hook.
    """

    config_class = QWEN3VoxASRConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` parameters.

        The std for Linear weights is taken from the first of
        ``language_model_config.initializer_range`` and
        ``decoder_config.initializer_range`` that exists; otherwise it is 0.02.

        * Linear: weights are drawn from ``N(0, std)`` and biases are zeroed.
        * LayerNorm: weight is set to 1 and bias to 0, when those parameters
          exist.
        * Other module types are left untouched.
        """
        std = 0.02
        for sub_config_name in ("language_model_config", "decoder_config"):
            sub_config = getattr(self.config, sub_config_name, None)
            if hasattr(sub_config, "initializer_range"):
                std = sub_config.initializer_range
                break
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight and bias are None with elementwise_affine=False, and bias
            # alone is None with bias=False.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
|
|
|
|
class QWEN3VoxASRModel(QWEN3VoxASRPreTrainedModel):
    """Decoder backbone plus the speech tokenizers and connectors for ASR.

    The decoder language model is built from ``config.decoder_config``
    and is not cast here. The acoustic and semantic tokenizers and their
    connectors are cast to ``config.torch_dtype`` (float32 if unset).
    ``forward`` only runs the language model; merging speech features into
    ``inputs_embeds`` is left to the caller.
    """

    def __init__(self, config):
        super().__init__(config)
        if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32
        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)
        self.acoustic_tokenizer = AutoModel.from_config(
            config.acoustic_tokenizer_config
        ).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(
            config.semantic_tokenizer_config
        ).to(dtype)
        # Project tokenizer latents (vae_dim) to the decoder hidden size.
        self.acoustic_connector = SpeechConnector(
            config.acoustic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        self.semantic_connector = SpeechConnector(
            config.semantic_vae_dim, lm_config.hidden_size
        ).to(dtype)

    def get_input_embeddings(self):
        """Return the decoder's token-embedding module.

        It uses ``language_model.embed_tokens`` when that attribute exists.
        Otherwise it searches the optional ``language_model.fullmap`` mapping
        for the entry whose ``orig_name`` is ``"embed_tokens.weight"``.

        Raises:
            AssertionError: if no embedding module can be located. This is
                raised explicitly, so it still fires under ``python -O``.
        """
        if hasattr(self.language_model, "embed_tokens"):
            return self.language_model.embed_tokens
        fullmap = getattr(self.language_model, "fullmap", None) or {}
        for name, attr in fullmap.items():
            if attr.orig_name == "embed_tokens.weight":
                return getattr(self.language_model, name)
        raise AssertionError("should not arrive here")

    def set_input_embeddings(self, value):
        """Replace ``language_model.embed_tokens`` with ``value``."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Install (or clear, with ``None``) the speech tokenizers in eval mode."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.train(False)
        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.train(False)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder language model.

        Returns the decoder's raw output when ``return_dict`` resolves to
        False. Otherwise it repackages that output as a
        ``BaseModelOutputWithPast``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        if not return_dict:
            return outputs
        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
class QWEN3VoxASRForConditionalGeneration(QWEN3VoxASRPreTrainedModel, GenerationMixin):
    """Speech-to-text model: QWEN3VoxASRModel backbone plus an LM head.

    Speech features from :meth:`encode_speech` are written into the token
    embeddings at the positions flagged by ``acoustic_input_mask``. The
    decoder output is then projected to the vocabulary by ``lm_head``.
    """

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    @staticmethod
    def _resolve_config_dtype(config) -> torch.dtype:
        """Return ``config.torch_dtype`` as a ``torch.dtype`` (float32 if unset)."""
        dtype = getattr(config, "torch_dtype", None)
        if dtype is None:
            return torch.float32
        if isinstance(dtype, str):
            return getattr(torch, dtype)
        return dtype

    def __init__(self, config):
        super().__init__(config)
        self.model = QWEN3VoxASRModel(config)
        self.vocab_size = config.decoder_config.vocab_size
        self.lm_head = nn.Linear(
            config.decoder_config.hidden_size, self.vocab_size, bias=False
        ).to(self._resolve_config_dtype(config))
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token-embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the backbone's token-embedding module."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone's decoder language model."""
        self.model.language_model = decoder

    def get_decoder(self):
        """Return the backbone's decoder language model."""
        return self.model.language_model

    def tie_weights(self):
        """Share the LM head weight with the input embeddings.

        Only does this when ``decoder_config.tie_word_embeddings`` is truthy.
        """
        if getattr(self.config.decoder_config, "tie_word_embeddings", False):
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            if hasattr(input_embeddings, "weight"):
                output_embeddings.weight = input_embeddings.weight
            else:
                output_embeddings.weight = input_embeddings

    def _encode_speech_streaming(
        self, speech_tensors, segment_samples: int, encode_semantic: bool
    ):
        """Encode long audio chunk by chunk through the tokenizers' streaming caches.

        Returns ``(audio_tokens, semantic_tokens)``:

        * ``audio_tokens`` is sampled from the concatenated acoustic means with
          the tokenizer's fixed std.
        * ``semantic_tokens`` is the concatenated semantic means, or ``None``
          when ``encode_semantic`` is false (the semantic tokenizer is not run).

        Raises:
            ValueError: if ``segment_samples`` is not positive.
        """
        if segment_samples <= 0:
            raise ValueError("segment_length must be positive")
        batch_size, total_samples = speech_tensors.shape
        acoustic_cache = QWEN3VoxTokenizerStreamingCache()
        semantic_cache = QWEN3VoxTokenizerStreamingCache() if encode_semantic else None
        sample_indices = torch.arange(batch_size, device=speech_tensors.device)
        starts = list(range(0, total_samples, segment_samples))
        acoustic_means, semantic_means = [], []
        for seg_idx, start in enumerate(starts):
            chunk = speech_tensors[:, start : start + segment_samples].contiguous()
            if chunk.numel() == 0:
                continue
            is_final = seg_idx == len(starts) - 1
            acoustic_means.append(
                self.model.acoustic_tokenizer.encode(
                    chunk.unsqueeze(1),
                    cache=acoustic_cache,
                    sample_indices=sample_indices,
                    use_cache=True,
                    is_final_chunk=is_final,
                ).mean
            )
            if encode_semantic:
                semantic_means.append(
                    self.model.semantic_tokenizer.encode(
                        chunk.unsqueeze(1),
                        cache=semantic_cache,
                        sample_indices=sample_indices,
                        use_cache=True,
                        is_final_chunk=is_final,
                    ).mean
                )
        acoustic_mean_full = torch.cat(acoustic_means, dim=1).contiguous()
        acoustic_encoder_output = QWEN3VoxTokenizerEncoderOutput(
            mean=acoustic_mean_full, std=self.model.acoustic_tokenizer.fix_std
        )
        audio_tokens = acoustic_encoder_output.sample(
            dist_type=self.model.acoustic_tokenizer.std_dist_type
        )[0]
        semantic_tokens = (
            torch.cat(semantic_means, dim=1).contiguous() if encode_semantic else None
        )
        return audio_tokens, semantic_tokens

    def encode_speech(
        self,
        speech_tensors: torch.FloatTensor,
        speech_masks: Optional[torch.BoolTensor] = None,
        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
        streaming_segment_duration: float = 60.0,
        sample_rate: int = 22050,
    ):
        """Encode waveforms into connector-space speech embeddings.

        Args:
            speech_tensors: Waveforms, ``(samples,)`` or ``(batch, samples)``.
                They are cast to the config dtype.
            speech_masks: Optional boolean mask over the connector outputs.
                When given, only the selected positions are returned, flattened
                by boolean indexing.
            speech_semantic_tensors: Optional precomputed semantic latents.
                When given, they are fed to ``semantic_connector`` directly and
                the semantic tokenizer is not run, in both the one-shot and the
                streaming path.
            streaming_segment_duration: Inputs longer than this many seconds
                are encoded in chunks via the tokenizers' streaming caches.
            sample_rate: Sample rate (Hz) used to turn
                ``streaming_segment_duration`` into a sample count. It defaults
                to 22050, the value this method previously hard-coded.

        Returns:
            Acoustic connector features plus semantic connector features.

        Note:
            Everything, including both connectors, runs under
            ``torch.no_grad()``. No gradient reaches the connectors through
            this method.
        """
        speech_tensors = speech_tensors.to(self._resolve_config_dtype(self.config))
        if speech_tensors.ndim == 1:
            speech_tensors = speech_tensors.unsqueeze(0)
        total_samples = speech_tensors.shape[1]
        segment_samples = int(streaming_segment_duration * sample_rate)
        use_streaming = total_samples > segment_samples
        with torch.no_grad():
            if not use_streaming:
                encoder_output = self.model.acoustic_tokenizer.encode(
                    speech_tensors.unsqueeze(1)
                )
                audio_tokens = encoder_output.sample(
                    dist_type=self.model.acoustic_tokenizer.std_dist_type
                )[0]
                acoustic_features = self.model.acoustic_connector(audio_tokens)
                if speech_semantic_tensors is not None:
                    semantic_tokens = speech_semantic_tensors
                else:
                    semantic_tokens = self.model.semantic_tokenizer.encode(
                        speech_tensors.unsqueeze(1)
                    ).mean
                semantic_features = self.model.semantic_connector(semantic_tokens)
            else:
                audio_tokens, semantic_tokens = self._encode_speech_streaming(
                    speech_tensors,
                    segment_samples,
                    encode_semantic=speech_semantic_tensors is None,
                )
                acoustic_features = self.model.acoustic_connector(audio_tokens)
                if speech_semantic_tensors is not None:
                    semantic_tokens = speech_semantic_tensors
                semantic_features = self.model.semantic_connector(semantic_tokens)
        if speech_masks is not None:
            return acoustic_features[speech_masks] + semantic_features[speech_masks]
        return acoustic_features + semantic_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        speech_tensors: Optional[torch.FloatTensor] = None,
        speech_masks: Optional[torch.BoolTensor] = None,
        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
        acoustic_input_mask: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutput]:
        """Compute LM logits and, when ``labels`` is given, the shifted CE loss.

        When both ``speech_tensors`` and ``acoustic_input_mask`` are given,
        the encoded speech overwrites ``inputs_embeds`` at the masked
        positions. The loss uses next-token shifting and
        ``nn.CrossEntropyLoss``'s default ``ignore_index`` (-100).
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
        if speech_tensors is not None and acoustic_input_mask is not None:
            speech_features = self.encode_speech(
                speech_tensors=speech_tensors,
                speech_masks=speech_masks,
                speech_semantic_tensors=speech_semantic_tensors,
            )
            # Index-put requires matching dtypes on source and destination.
            inputs_embeds[acoustic_input_mask] = speech_features.to(
                inputs_embeds.dtype
            )
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            # Use the actual logit width so a resized lm_head still works.
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return QWEN3VoxCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        speech_tensors=None,
        speech_masks=None,
        speech_semantic_tensors=None,
        acoustic_input_mask=None,
        **kwargs,
    ):
        """Build the per-step model inputs for ``GenerationMixin``.

        The steps are:

        1. Trim ``input_ids`` to the tokens not yet in the cache. Both legacy
           tuple caches and ``Cache`` objects are supported.
        2. Derive ``position_ids`` from ``attention_mask``.
        3. Build ``cache_position``.
        4. Use ``inputs_embeds`` only when there is no cache.
        5. Forward the speech inputs only on the prefill step
           (``cache_position[0] == 0``).
        """
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, tuple):
                past_length = past_key_values[0][0].shape[2]
            else:
                past_length = past_key_values.get_seq_length()
            if input_ids is not None and input_ids.shape[1] > past_length:
                input_ids = input_ids[:, past_length:]
        if position_ids is None and attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values is not None and input_ids is not None:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        if cache_position is None:
            # past_length already handles tuple caches, which have no get_seq_length().
            step_tensor = input_ids if input_ids is not None else inputs_embeds
            cache_position = torch.arange(
                past_length,
                past_length + step_tensor.shape[1],
                device=step_tensor.device,
            )
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        is_prefill = (
            cache_position is not None
            and len(cache_position) > 0
            and (cache_position[0] == 0)
        )
        if is_prefill:
            model_inputs.update(
                {
                    "speech_tensors": speech_tensors,
                    "speech_masks": speech_masks,
                    "speech_semantic_tensors": speech_semantic_tensors,
                    "acoustic_input_mask": acoustic_input_mask,
                }
            )
        else:
            model_inputs.update(
                {
                    "speech_tensors": None,
                    "speech_masks": None,
                    "speech_semantic_tensors": None,
                    "acoustic_input_mask": None,
                }
            )
        model_inputs.update(kwargs)
        return model_inputs
|
|
|
|
# Make the ASR classes resolvable through the transformers Auto* factories
# for QWEN3VoxASRConfig.
AutoModel.register(QWEN3VoxASRConfig, QWEN3VoxASRModel)
AutoModelForCausalLM.register(
    QWEN3VoxASRConfig, QWEN3VoxASRForConditionalGeneration
)

__all__ = [
    "QWEN3VoxASRPreTrainedModel",
    "QWEN3VoxASRModel",
    "QWEN3VoxASRForConditionalGeneration",
]
| from dataclasses import dataclass |
| from typing import Dict, List, Optional, Tuple, Union, Callable |
| from tqdm import tqdm |
| import torch |
| import torch.nn as nn |
| from transformers.models.auto import AutoModel, AutoModelForCausalLM |
| from transformers.generation import ( |
| GenerationMixin, |
| GenerationConfig, |
| LogitsProcessor, |
| LogitsProcessorList, |
| StoppingCriteriaList, |
| ) |
| from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput |
| from transformers import modeling_utils |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.utils import logging |
|
|
logger = logging.get_logger(__name__)

# Install a default ALL_PARALLEL_STYLES list when transformers.modeling_utils
# does not define it or defines it as None.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
|
|
|
class QWEN3VoxTokenConstraintProcessor(LogitsProcessor):
    """Logits processor that restricts generation to a whitelist of token ids.

    Every vocabulary entry outside ``valid_token_ids`` gets score ``-inf``.
    Whitelisted entries keep their original score.
    """

    def __init__(self, valid_token_ids: List[int], device: torch.device = None):
        self.valid_token_ids = torch.tensor(
            valid_token_ids, dtype=torch.long, device=device
        )

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Return ``scores`` with every non-whitelisted column set to ``-inf``.

        ``scores`` is indexed as ``[:, token_id]``, so rows are batch entries
        and columns are vocabulary ids.
        """
        # Follow the scores' device so a processor built for one device still
        # works when generation runs on another.
        allowed = self.valid_token_ids.to(scores.device)
        keep = torch.zeros_like(scores, dtype=torch.bool)
        keep[:, allowed] = True
        # masked_fill rather than adding a -inf mask: +inf + -inf would be NaN.
        return scores.masked_fill(~keep, float("-inf"))
|
|
|
|
| class QWEN3VoxForConditionalGenerationInference(QWEN3VoxPreTrainedModel, GenerationMixin): |
| _tied_weights_keys = ["lm_head.weight"] |
| _tp_plan = {"lm_head": "colwise_rep"} |
|
|
def __init__(self, config):
    """Build the backbone, the untied LM head and the default DDPM step count."""
    super().__init__(config)
    self.model = QWEN3VoxModel(config)
    decoder_cfg = config.decoder_config
    self.lm_head = nn.Linear(
        decoder_cfg.hidden_size, decoder_cfg.vocab_size, bias=False
    )
    # Default number of diffusion inference steps; see set_ddpm_inference_steps.
    default_steps = config.diffusion_head_config.ddpm_num_inference_steps
    self.ddpm_inference_steps = default_steps
    self.post_init()
|
|
@property
def noise_scheduler(self):
    """Noise scheduler held by the backbone (``self.model``)."""
    backbone = self.model
    return backbone.noise_scheduler
|
|
@property
def prediction_head(self):
    """Prediction (diffusion) head held by the backbone."""
    backbone = self.model
    return backbone.prediction_head
|
|
@property
def speech_scaling_factor(self):
    """Scale applied to biased acoustic latents, held by the backbone."""
    backbone = self.model
    return backbone.speech_scaling_factor
|
|
@property
def speech_bias_factor(self):
    """Bias added to acoustic latents before scaling, held by the backbone."""
    backbone = self.model
    return backbone.speech_bias_factor
|
|
@property
def acoustic_tokenizer(self):
    """Acoustic speech tokenizer held by the backbone."""
    backbone = self.model
    return backbone.acoustic_tokenizer
|
|
@property
def semantic_tokenizer(self):
    """Semantic speech tokenizer held by the backbone."""
    backbone = self.model
    return backbone.semantic_tokenizer
|
|
@property
def acoustic_connector(self):
    """Acoustic connector (latents -> decoder embeddings) held by the backbone."""
    backbone = self.model
    return backbone.acoustic_connector
|
|
@property
def semantic_connector(self):
    """Semantic connector held by the backbone."""
    backbone = self.model
    return backbone.semantic_connector
|
|
def tie_weights(self):
    """Share ``lm_head.weight`` with ``language_model.embed_tokens.weight``.

    This does nothing unless all three hold:

    * ``config.tie_word_embeddings`` is truthy,
    * ``lm_head`` exists,
    * the backbone's language model exposes ``embed_tokens``.
    """
    if getattr(self.config, "tie_word_embeddings", False):
        if hasattr(self, "lm_head"):
            language_model = self.model.language_model
            if hasattr(language_model, "embed_tokens"):
                self.lm_head.weight = language_model.embed_tokens.weight
|
|
def get_input_embeddings(self):
    """Return the backbone's input-embedding module."""
    backbone = self.model
    return backbone.get_input_embeddings()
|
|
def set_input_embeddings(self, value):
    """Replace the backbone's input-embedding module with ``value``."""
    backbone = self.model
    backbone.set_input_embeddings(value)
|
|
def get_output_embeddings(self):
    """Return the LM head projecting hidden states to vocabulary logits."""
    head = self.lm_head
    return head
|
|
def set_output_embeddings(self, new_embeddings):
    """Install ``new_embeddings`` as the LM head."""
    setattr(self, "lm_head", new_embeddings)
|
|
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
    """Forward the acoustic/semantic tokenizers to the backbone's setter."""
    backbone = self.model
    backbone.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer)
|
|
def set_ddpm_inference_steps(self, num_steps=None):
    """Set the diffusion step count.

    A falsy ``num_steps`` (``None``/0) restores the config default.
    """
    if num_steps:
        self.ddpm_inference_steps = num_steps
    else:
        self.ddpm_inference_steps = (
            self.config.diffusion_head_config.ddpm_num_inference_steps
        )
|
|
def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"):
    """Convert speech input into scaled acoustic latents and connector embeddings.

    Args:
        speech_tensors: For ``"audio"``, waveforms passed to the acoustic
            tokenizer's ``encode`` (after ``unsqueeze(1)``). For ``"pt"``,
            precomputed latent means paired with the tokenizer's ``fix_std``.
        speech_masks: Boolean mask, moved to CPU, that selects which connector
            outputs are returned.
        speech_type: ``"audio"`` or ``"pt"``.

    Returns:
        ``(acoustic_features, acoustic_connected)``. The first is
        ``(latents + speech_bias_factor) * speech_scaling_factor``. The second
        is the acoustic connector output at the masked positions.

    Raises:
        NotImplementedError: for any other ``speech_type``.
    """
    with torch.no_grad():
        if speech_type == "audio":
            encoder_output = self.model.acoustic_tokenizer.encode(
                speech_tensors.unsqueeze(1)
            )
        elif speech_type == "pt":
            encoder_output = QWEN3VoxTokenizerEncoderOutput(
                mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std
            )
        else:
            raise NotImplementedError(f"Speech type {speech_type} not implemented")
        latents = encoder_output.sample(
            dist_type=self.model.acoustic_tokenizer.std_dist_type
        )[0]
        target_device = latents.device
        bias = self.model.speech_bias_factor.to(target_device)
        scale = self.model.speech_scaling_factor.to(target_device)
        acoustic_features = (latents + bias) * scale
        connected = self.model.acoustic_connector(acoustic_features)
        return (acoustic_features, connected[speech_masks.cpu()])
|
|
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    speech_tensors: Optional[torch.FloatTensor] = None,
    speech_masks: Optional[torch.BoolTensor] = None,
    speech_input_mask: Optional[torch.BoolTensor] = None,
    logits_to_keep: Union[int, slice] = 0,
    **kwargs,
) -> Union[Tuple, QWEN3VoxLMHeadOutputWithPast]:
    """Run the backbone and project the kept positions to vocabulary logits.

    When ``speech_tensors`` and ``speech_masks`` are given, the connector
    embeddings of the speech overwrite ``inputs_embeds`` at
    ``speech_input_mask``.

    ``logits_to_keep`` selects which positions are projected:

    * an int ``k`` keeps the last ``k`` positions;
    * 0 keeps all positions;
    * a slice is used as-is.

    Returns:
        A ``QWEN3VoxLMHeadOutputWithPast``. If the backbone returned a plain
        tuple (``return_dict=False``), returns ``(logits,) + backbone[1:]``
        instead.

    Raises:
        NotImplementedError: if ``labels`` is given. This is checked before
            any computation.
    """
    if labels is not None:
        raise NotImplementedError(
            "Loss computation is not implemented in this version."
        )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    if inputs_embeds is None:
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
    if speech_tensors is not None and speech_masks is not None:
        _, speech_embeds = self._process_speech_inputs(
            speech_tensors.to(self.dtype), speech_masks
        )
        if speech_input_mask is not None:
            # Index-put requires matching dtypes on source and destination.
            inputs_embeds[speech_input_mask] = speech_embeds.to(inputs_embeds.dtype)
    outputs = self.model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )
    hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )
    logits = self.lm_head(hidden_states[:, slice_indices, :])
    if isinstance(outputs, tuple):
        # A tuple backbone output has no .past_key_values / .attentions.
        return (logits,) + outputs[1:]
    return QWEN3VoxLMHeadOutputWithPast(
        logits=logits,
        past_key_values=outputs.past_key_values,
        last_hidden_state=hidden_states,
        attentions=outputs.attentions,
    )
|
|
def _build_generate_config_model_kwargs(
    self, generation_config, inputs, tokenizer, return_processors=False, **kwargs
):
    """Prepare the generation config, model kwargs and input ids for ``generate``.

    ``generation_config`` may be ``None``, a ``GenerationConfig`` or a mapping
    of generation options. A ``GenerationConfig`` is copied, not mutated. In
    every case the tokenizer's bos/eos/pad ids take precedence. The speech
    special-token ids are copied from the tokenizer. The cache, the
    ``cache_position`` and device placement of tensor kwargs are set up here.

    Returns:
        ``(generation_config, model_kwargs, input_ids)``. When
        ``return_processors`` is true, the default logits processors and
        stopping criteria are appended to that tuple.
    """
    import copy

    token_ids = {
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if generation_config is None:
        generation_config = GenerationConfig(**token_ids)
    elif isinstance(generation_config, GenerationConfig):
        # A GenerationConfig is not a mapping, so **-unpacking it fails; copy
        # it instead so the caller's object is left untouched.
        generation_config = copy.deepcopy(generation_config)
        for key, value in token_ids.items():
            setattr(generation_config, key, value)
    else:
        # Mapping of options: tokenizer ids override any duplicates instead of
        # raising "got multiple values for keyword argument".
        generation_config = GenerationConfig(**{**generation_config, **token_ids})
    generation_config, model_kwargs = self._prepare_generation_config(
        generation_config,
        True,
        speech_start_id=tokenizer.speech_start_id,
        speech_end_id=tokenizer.speech_end_id,
        speech_diffusion_id=tokenizer.speech_diffusion_id,
        **kwargs,
    )
    generation_config.speech_start_id = tokenizer.speech_start_id
    generation_config.speech_end_id = tokenizer.speech_end_id
    generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    batch_size = inputs_tensor.shape[0]
    device = self.device
    self._prepare_special_tokens(generation_config, True, device=device)
    generation_config.use_cache = True
    model_kwargs["use_cache"] = generation_config.use_cache
    input_ids = inputs_tensor.to(self.device)
    input_ids_length = input_ids.shape[1]
    has_default_max_length = (
        kwargs.get("max_length") is None
        and generation_config.max_length is not None
    )
    has_default_min_length = (
        kwargs.get("min_length") is None
        and generation_config.min_length is not None
    )
    generation_config = self._prepare_generated_length(
        generation_config=generation_config,
        has_default_max_length=has_default_max_length,
        has_default_min_length=has_default_min_length,
        model_input_name=model_input_name,
        inputs_tensor=inputs_tensor,
        input_ids_length=input_ids_length,
    )
    max_cache_length = generation_config.max_length - 1
    self._prepare_cache_for_generation(
        generation_config, model_kwargs, None, batch_size, max_cache_length, device
    )
    model_kwargs["cache_position"] = torch.arange(
        input_ids_length, device=device, dtype=torch.long
    )
    for k, v in model_kwargs.items():
        if isinstance(v, torch.Tensor):
            model_kwargs[k] = v.to(device=device)
    if not return_processors:
        return (generation_config, model_kwargs, input_ids)
    logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_length,
        encoder_input_ids=inputs_tensor,
        prefix_allowed_tokens_fn=None,
        logits_processor=LogitsProcessorList(),
        device=inputs_tensor.device,
        model_kwargs=model_kwargs,
    )
    stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=StoppingCriteriaList(),
    )
    return (
        generation_config,
        model_kwargs,
        input_ids,
        logits_processor,
        stopping_criteria,
    )
|
|
@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[
        Callable[[int, torch.Tensor], List[int]]
    ] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    speech_tensors: Optional[torch.FloatTensor] = None,
    speech_masks: Optional[torch.BoolTensor] = None,
    speech_input_mask: Optional[torch.BoolTensor] = None,
    is_prefill: bool = True,
    return_speech: bool = True,
    cfg_scale: float = 1.0,
    stop_check_fn: Optional[Callable[[], bool]] = None,
    tqdm_class: Optional[type] = None,
    **kwargs,
) -> Union[torch.LongTensor, QWEN3VoxGenerationOutput]:
    """Autoregressively generate tokens and, on speech-diffusion steps, audio.

    Each loop step does the following:
    - Runs the model on the current sequence.
    - Scores the last position with ``logits_processor``, to which a
      ``QWEN3VoxTokenConstraintProcessor`` is appended. That processor is built
      from the speech start/end/diffusion ids, the EOS id and (if set) the BOS id.
    - Picks the next token by multinomial sampling when
      ``generation_config.do_sample`` is true, otherwise by argmax.

    For samples whose next token is ``speech_diffusion_id``:
    - A latent is drawn with ``sample_speech_tokens``. It is conditioned on the
      last hidden state of the main pass and of a parallel "negative" pass
      (started from ``speech_start_id``), with ``cfg_scale`` as the guidance weight.
    - The latent is decoded to an audio chunk by the acoustic tokenizer, and that
      chunk is re-encoded by the semantic tokenizer.
    - The sum of the two connector embeddings is fed back as that sample's next
      input embedding.

    Options read from ``kwargs``:
        tokenizer: required; its ``speech_start_id`` etc. are dereferenced.
        input_ids: required; indexed directly as ``kwargs["input_ids"]``.
        max_new_tokens: defaults to the decoder's ``max_position_embeddings``
            minus the prompt length.
        max_length_times: step budget as a multiple of the prompt length
            (default 2).
        verbose (default False): print per-step status messages.
        show_progress_bar (default True): wrap the step loop in ``tqdm_class``
            (or ``tqdm``).
        refresh_negative (default True): reset the negative branch when a
            sample emits ``speech_start_id``, and run the negative forward only
            on diffusion steps. When false, the negative forward runs every step.
        parsed_scripts, all_speakers_list: popped, but not used in this body.

    The ``logits_processor`` and ``stopping_criteria`` arguments are overwritten
    by the values returned from ``_build_generate_config_model_kwargs``.
    ``stopping_criteria``, ``prefix_allowed_tokens_fn``, ``synced_gpus``,
    ``assistant_model``, ``negative_prompt_ids`` and
    ``negative_prompt_attention_mask`` are not referenced below.

    ``speech_tensors``, ``speech_masks`` and ``speech_input_mask`` are passed to
    the model only on the first (prefill) step, and only if ``is_prefill`` is true.

    ``stop_check_fn`` is polled before each step. If it returns true, generation
    stops and ``audio_streamer`` is ended.

    Returns:
        QWEN3VoxGenerationOutput with the following fields:
        - ``sequences``: prompt plus generated ids.
        - ``speech_outputs``: per sample, the concatenation of its audio chunks
          along the last dim, or ``None`` if it produced none. The whole field is
          ``None`` when ``return_speech`` is false.
        - ``reach_max_step_sample``: per-sample bool, true for samples stopped by
          a length limit.
    """
    tokenizer = kwargs.pop("tokenizer", None)
    parsed_scripts = kwargs.pop("parsed_scripts", None)
    all_speakers_list = kwargs.pop("all_speakers_list", None)
    max_length_times = kwargs.pop("max_length_times", 2)
    if kwargs.get("max_new_tokens", None) is None:
        kwargs["max_new_tokens"] = (
            self.config.decoder_config.max_position_embeddings
            - kwargs["input_ids"].shape[-1]
        )
    (
        generation_config,
        model_kwargs,
        input_ids,
        logits_processor,
        stopping_criteria,
    ) = self._build_generate_config_model_kwargs(
        generation_config, inputs, tokenizer, return_processors=True, **kwargs
    )
    # Negative (unconditional) branch: one speech_start token per sample, used as
    # the CFG counterpart of the main sequence during diffusion sampling.
    negative_kwargs = {
        "input_ids": torch.full(
            (kwargs["input_ids"].shape[0], 1),
            tokenizer.speech_start_id,
            dtype=torch.long,
            device=kwargs["input_ids"].device,
        ),
        "attention_mask": torch.ones(
            (kwargs["input_ids"].shape[0], 1),
            dtype=torch.long,
            device=kwargs["input_ids"].device,
        ),
        "max_new_tokens": kwargs.get("max_new_tokens", 100),
    }
    negative_generation_config, negative_model_kwargs, negative_input_ids = (
        self._build_generate_config_model_kwargs(
            None, None, tokenizer, return_processors=False, **negative_kwargs
        )
    )
    # Streaming caches for the acoustic decoder and the semantic encoder; they are
    # zeroed per sample when that sample emits speech_end_id.
    acoustic_cache = QWEN3VoxTokenizerStreamingCache()
    semantic_cache = QWEN3VoxTokenizerStreamingCache()
    batch_size = input_ids.shape[0]
    device = input_ids.device
    finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
    # Per-sample count of negative-branch shift corrections applied so far; used as
    # the position at which the next correction is inserted.
    correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
    inputs_embeds = None
    verbose = kwargs.get("verbose", False)
    audio_chunks = [[] for _ in range(batch_size)]
    initial_length = input_ids.shape[-1]
    initial_length_per_sample = model_kwargs["attention_mask"].sum(dim=-1)
    # Only these ids may be generated (enforced by the constraint processor).
    valid_tokens = [
        generation_config.speech_start_id,
        generation_config.speech_end_id,
        generation_config.speech_diffusion_id,
        generation_config.eos_token_id,
    ]
    if (
        hasattr(generation_config, "bos_token_id")
        and generation_config.bos_token_id is not None
    ):
        valid_tokens.append(generation_config.bos_token_id)
    token_constraint_processor = QWEN3VoxTokenConstraintProcessor(
        valid_tokens, device=device
    )
    # NOTE(review): logits_processor was just returned by the builder, so this
    # None check presumably never fires; kept as a defensive fallback.
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    logits_processor.append(token_constraint_processor)
    # Global step budget: bounded by max_length and by max_length_times * prompt
    # length (padded batch length here, per-sample unpadded length below).
    max_steps = min(
        generation_config.max_length - initial_length,
        int(max_length_times * initial_length),
    )
    max_step_per_sample = torch.min(
        generation_config.max_length - initial_length_per_sample,
        (max_length_times * initial_length_per_sample).long(),
    )
    reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
    if kwargs.get("show_progress_bar", True):
        tqdm_fn = tqdm_class if tqdm_class is not None else tqdm
        progress_bar = tqdm_fn(range(max_steps), desc="Generating", leave=False)
    else:
        progress_bar = range(max_steps)
    for step in progress_bar:
        # --- Early-exit checks -------------------------------------------------
        if stop_check_fn is not None and stop_check_fn():
            if verbose:
                print(f"Generation stopped externally at step {step +1 }")
            # NOTE(review): end() is called again unconditionally after the loop,
            # so the streamer is ended twice on this path — confirm that is harmless.
            if audio_streamer is not None:
                audio_streamer.end()
            break
        if audio_streamer is not None and hasattr(audio_streamer, "finished_flags"):
            if any(audio_streamer.finished_flags):
                if verbose:
                    print(f"Audio generation stopped externally at step {step +1 }")
                break
        if finished_tags.all():
            if hasattr(progress_bar, "set_description"):
                progress_bar.set_description("Generation complete")
            break
        if input_ids.shape[-1] >= generation_config.max_length:
            print(
                f"Reached maximum generation length {generation_config .max_length }, stopped it."
            )
            reached_samples = torch.arange(batch_size, device=device)[
                ~finished_tags
            ]
            if reached_samples.numel() > 0:
                reach_max_step_sample[reached_samples] = True
            break
        if hasattr(progress_bar, "set_description"):
            active_samples = (~finished_tags).sum().item()
            progress_bar.set_description(
                f"Generating (active: {active_samples }/{batch_size })"
            )
        # --- Positive (main) forward pass ---------------------------------------
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        if is_prefill:
            # Speech prompt tensors are only fed on the first step.
            prefill_inputs = {}
            if speech_tensors is not None:
                prefill_inputs["speech_tensors"] = speech_tensors.to(device=device)
            if speech_masks is not None:
                prefill_inputs["speech_masks"] = speech_masks.to(device)
            if speech_input_mask is not None:
                prefill_inputs["speech_input_mask"] = speech_input_mask.to(device)
            is_prefill = False
        else:
            # Later steps feed the embeddings built at the end of the previous step
            # (token embeddings, replaced by speech embeddings for diffusion samples).
            _ = model_inputs.pop("inputs_embeds", None)
            prefill_inputs = {"inputs_embeds": inputs_embeds}
        outputs = self(
            **model_inputs,
            **prefill_inputs,
            logits_to_keep=1,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False
        )
        # --- Next-token selection -----------------------------------------------
        next_token_logits = outputs.logits[:, -1, :].to(
            copy=True, dtype=torch.float32, device=input_ids.device
        )
        next_token_scores = logits_processor(input_ids, next_token_logits)
        if generation_config.do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)
        # Finished samples keep emitting EOS so the batch stays rectangular.
        next_tokens[finished_tags] = generation_config.eos_token_id
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        # Without refresh_negative, the negative branch is advanced every step.
        if not kwargs.get("refresh_negative", True):
            negative_model_inputs = self.prepare_inputs_for_generation(
                negative_input_ids, **negative_model_kwargs
            )
            if (
                negative_model_inputs["inputs_embeds"] is None
                and inputs_embeds is not None
            ):
                negative_model_inputs["inputs_embeds"] = inputs_embeds
                negative_model_inputs["input_ids"] = None
            negative_outputs = self(
                **negative_model_inputs,
                logits_to_keep=0,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            negative_model_kwargs = self._update_model_kwargs_for_generation(
                negative_outputs, negative_model_kwargs, is_encoder_decoder=False
            )
            negative_input_ids = torch.cat(
                [negative_input_ids, next_tokens[:, None]], dim=-1
            )
        # --- Per-sample termination ---------------------------------------------
        if (next_tokens == generation_config.eos_token_id).any():
            eos_indices = (
                (next_tokens == generation_config.eos_token_id)
                .nonzero(as_tuple=False)
                .squeeze(1)
            )
            new_eos_indices = eos_indices[~finished_tags[eos_indices]]
            if new_eos_indices.numel() > 0:
                finished_tags[new_eos_indices] = True
                if verbose:
                    print(
                        f"Samples {new_eos_indices .tolist ()} reached EOS token at step {step +1 }.",
                        flush=True,
                    )
                if audio_streamer is not None:
                    audio_streamer.end(new_eos_indices)
        max_length_reached = step >= max_step_per_sample
        new_max_length_indices = torch.nonzero(
            max_length_reached & ~finished_tags, as_tuple=False
        ).squeeze(1)
        if new_max_length_indices.numel() > 0:
            finished_tags[new_max_length_indices] = True
            reach_max_step_sample[new_max_length_indices] = True
            if verbose:
                print(
                    f"Samples {new_max_length_indices .tolist ()} reached max generation length at step {step +1 }.",
                    flush=True,
                )
            if audio_streamer is not None:
                audio_streamer.end(new_max_length_indices)
        # speech_end: reset the streaming tokenizer caches of those samples.
        diffusion_end_indices = (
            (next_tokens == generation_config.speech_end_id)
            .nonzero(as_tuple=False)
            .squeeze(1)
        )
        if diffusion_end_indices.numel() > 0:
            acoustic_cache.set_to_zero(diffusion_end_indices)
            semantic_cache.set_to_zero(diffusion_end_indices)
        # speech_start: restart the negative branch for those samples.
        # - Mask every position except the last.
        # - Copy the cached K/V at position 0 into the last position.
        # - Set the last negative id to speech_start_id.
        diffusion_start_indices = torch.arange(batch_size, device=device)[
            ~finished_tags & (next_tokens == generation_config.speech_start_id)
        ]
        if diffusion_start_indices.numel() > 0 and kwargs.get(
            "refresh_negative", True
        ):
            for i, sample_idx in enumerate(diffusion_start_indices.tolist()):
                negative_model_kwargs["attention_mask"][sample_idx, :] = 0
                negative_model_kwargs["attention_mask"][sample_idx, -1] = 1
            # NOTE(review): relies on the cache object exposing per-layer
            # key_cache/value_cache lists indexed [batch, head, pos, dim]; newer
            # transformers Cache classes may not — verify against the pinned version.
            for layer_idx, (k_cache, v_cache) in enumerate(
                zip(
                    negative_model_kwargs["past_key_values"].key_cache,
                    negative_model_kwargs["past_key_values"].value_cache,
                )
            ):
                for sample_idx in diffusion_start_indices.tolist():
                    k_cache[sample_idx, :, -1, :] = k_cache[
                        sample_idx, :, 0, :
                    ].clone()
                    v_cache[sample_idx, :, -1, :] = v_cache[
                        sample_idx, :, 0, :
                    ].clone()
            for sample_idx in diffusion_start_indices.tolist():
                negative_input_ids[sample_idx, -1] = (
                    generation_config.speech_start_id
                )
        # Default next-step inputs: embeddings of the chosen tokens, shape (B, 1, H).
        next_inputs_embeds = self.model.get_input_embeddings()(
            next_tokens
        ).unsqueeze(1)
        # --- Speech diffusion step ----------------------------------------------
        diffusion_indices = torch.arange(batch_size, device=device)[
            ~finished_tags & (next_tokens == generation_config.speech_diffusion_id)
        ]
        if diffusion_indices.numel() > 0:
            if kwargs.get("refresh_negative", True):
                # The negative forward runs for the whole batch, not only the
                # diffusion samples.
                negative_model_inputs = self.prepare_inputs_for_generation(
                    negative_input_ids, **negative_model_kwargs
                )
                if (
                    negative_model_inputs["inputs_embeds"] is None
                    and inputs_embeds is not None
                ):
                    negative_model_inputs["inputs_embeds"] = inputs_embeds
                    negative_model_inputs["input_ids"] = None
                negative_outputs = self(
                    **negative_model_inputs,
                    logits_to_keep=0,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )
                negative_model_kwargs = self._update_model_kwargs_for_generation(
                    negative_outputs,
                    negative_model_kwargs,
                    is_encoder_decoder=False,
                )
                negative_input_ids = torch.cat(
                    [negative_input_ids, next_tokens[:, None]], dim=-1
                )
            # The negative forward above also appended a position for active
            # samples that are NOT in diffusion mode. For them the code:
            # - shifts mask, K/V and ids right by one from correct_cnt[sample];
            # - zeroes the mask at that slot.
            # Presumably this hides the extra step from the negative branch —
            # TODO confirm intent.
            non_diffusion_mask = ~finished_tags & (
                next_tokens != generation_config.speech_diffusion_id
            )
            if non_diffusion_mask.any():
                non_diffusion_indices = torch.arange(batch_size, device=device)[
                    non_diffusion_mask
                ]
                start_indices = correct_cnt[non_diffusion_indices]
                seq_len = negative_model_kwargs["attention_mask"].shape[1]
                for i, (sample_idx, start_idx) in enumerate(
                    zip(non_diffusion_indices.tolist(), start_indices.tolist())
                ):
                    if start_idx + 1 < seq_len - 1:
                        negative_model_kwargs["attention_mask"][
                            sample_idx, start_idx + 1 :
                        ] = negative_model_kwargs["attention_mask"][
                            sample_idx, start_idx:-1
                        ].clone()
                    negative_model_kwargs["attention_mask"][
                        sample_idx, start_idx
                    ] = 0
                for layer_idx, (k_cache, v_cache) in enumerate(
                    zip(
                        negative_model_kwargs["past_key_values"].key_cache,
                        negative_model_kwargs["past_key_values"].value_cache,
                    )
                ):
                    for sample_idx, start_idx in zip(
                        non_diffusion_indices.tolist(), start_indices.tolist()
                    ):
                        if start_idx + 1 < k_cache.shape[2] - 1:
                            k_cache[sample_idx, :, start_idx + 1 :, :] = k_cache[
                                sample_idx, :, start_idx:-1, :
                            ].clone()
                            v_cache[sample_idx, :, start_idx + 1 :, :] = v_cache[
                                sample_idx, :, start_idx:-1, :
                            ].clone()
                for sample_idx, start_idx in zip(
                    non_diffusion_indices.tolist(), start_indices.tolist()
                ):
                    if start_idx + 1 < negative_input_ids.shape[1] - 1:
                        negative_input_ids[sample_idx, start_idx + 1 :] = (
                            negative_input_ids[sample_idx, start_idx:-1].clone()
                        )
                correct_cnt[non_diffusion_indices] += 1
            # CFG conditions: last hidden state of the positive and negative passes.
            positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
            negative_condition = negative_outputs.last_hidden_state[
                diffusion_indices, -1, :
            ]
            speech_latent = self.sample_speech_tokens(
                positive_condition, negative_condition, cfg_scale=cfg_scale
            ).unsqueeze(1)
            # Undo the model's latent normalisation before decoding to audio.
            scaled_latent = speech_latent / self.model.speech_scaling_factor.to(
                speech_latent.device
            ) - self.model.speech_bias_factor.to(speech_latent.device)
            audio_chunk = self.model.acoustic_tokenizer.decode(
                scaled_latent.to(self.model.acoustic_tokenizer.device),
                cache=acoustic_cache,
                sample_indices=diffusion_indices.to(
                    self.model.acoustic_tokenizer.device
                ),
                use_cache=True,
                debug=False,
            )
            for i, sample_idx in enumerate(diffusion_indices):
                idx = sample_idx.item()
                if not finished_tags[idx]:
                    audio_chunks[idx].append(audio_chunk[i])
            if audio_streamer is not None:
                audio_streamer.put(audio_chunk, diffusion_indices)
            # Re-encode the decoded audio so the next input embedding carries both
            # acoustic and semantic information.
            semantic_features = self.model.semantic_tokenizer.encode(
                audio_chunk,
                cache=semantic_cache,
                sample_indices=diffusion_indices,
                use_cache=True,
                debug=False,
            ).mean
            acoustic_embed = self.model.acoustic_connector(speech_latent)
            semantic_embed = self.model.semantic_connector(semantic_features)
            diffusion_embeds = acoustic_embed + semantic_embed
            next_inputs_embeds[diffusion_indices] = diffusion_embeds
        inputs_embeds = next_inputs_embeds
    if audio_streamer is not None:
        audio_streamer.end()
    final_audio_outputs = []
    for sample_chunks in audio_chunks:
        if sample_chunks:
            concatenated_audio = torch.cat(sample_chunks, dim=-1)
            final_audio_outputs.append(concatenated_audio)
        else:
            final_audio_outputs.append(None)
    return QWEN3VoxGenerationOutput(
        sequences=input_ids,
        speech_outputs=final_audio_outputs if return_speech else None,
        reach_max_step_sample=reach_max_step_sample,
    )
|
|
@torch.no_grad()
def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
    """Sample acoustic latents by classifier-free-guided diffusion.

    Args:
        condition: Conditional (positive) hidden states, one row per sample.
        neg_condition: Unconditional (negative) hidden states, same row count.
        cfg_scale: Guidance weight w. Each step uses
            ``uncond + w * (cond - uncond)`` as the noise prediction.

    Returns:
        Tensor with one ``acoustic_vae_dim`` latent per row of ``condition``.
    """
    scheduler = self.model.noise_scheduler
    scheduler.set_timesteps(self.ddpm_inference_steps)
    head = self.model.prediction_head

    # Conditional rows first, unconditional rows second, evaluated in one call.
    both_conditions = torch.cat([condition, neg_condition], dim=0)
    both_conditions = both_conditions.to(head.device)

    latents = torch.randn(both_conditions.shape[0], self.config.acoustic_vae_dim)
    latents = latents.to(both_conditions)

    for timestep in scheduler.timesteps:
        # Only the first half is meaningful; it is duplicated so that the same
        # latent is scored under both the conditional and unconditional context.
        guided_rows = latents[: len(latents) // 2]
        model_input = torch.cat([guided_rows, guided_rows], dim=0)
        timestep_batch = timestep.repeat(model_input.shape[0]).to(model_input)
        noise_pred = head(model_input, timestep_batch, condition=both_conditions)

        cond_pred, uncond_pred = torch.split(noise_pred, len(noise_pred) // 2, dim=0)
        guided = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
        step_output = scheduler.step(torch.cat([guided, guided], dim=0), timestep, latents)
        latents = step_output.prev_sample

    return latents[: len(latents) // 2]
|
|
|
|
# Register the inference class with AutoModelForCausalLM for QWEN3VoxConfig, so the
# Auto factory can build it from that config type.
AutoModelForCausalLM.register(QWEN3VoxConfig, QWEN3VoxForConditionalGenerationInference)
# Public API of the inference module.
# NOTE(review): this file appears to concatenate several modules (imports and a
# `__main__` guard repeat below); this __all__ lists only the inference class —
# confirm star-import users do not need the converter/merger helpers.
__all__ = [
    'QWEN3VoxForConditionalGenerationInference'
]
| import argparse |
| import json |
| import os |
| from pathlib import Path |
| import re |
| import torch |
| from typing import Dict, List, Tuple |
| from transformers.utils import logging |
|
|
# Module logger from transformers' logging utilities (`logging` here is
# transformers.utils.logging).
# NOTE(review): further down this file `logging` is re-imported from the standard
# library and `logger` is reassigned; functions look the global up at call time, so
# they will use whichever binding is current when they run.
logger = logging.get_logger(__name__)
|
|
|
|
def convert_q3_nnscaler_checkpoint_to_hf(
    checkpoint_path: str,
    pytorch_dump_folder_path: str,
    config_path: Optional[str] = None,
):
    """Convert an nnscaler training checkpoint into a HuggingFace model directory.

    The checkpoint dict must provide:
    - ``"model"``: a state dict whose ``model.model.*`` keys are remapped to
      ``model.*``.
    - ``"train_args"``: its ``["vars"]["model_args"]["config_path"]["relative_path"]``
      names the initial config JSON, and its ``["vars"]["data_args"]["tokenizer_path"]``
      gives the tokenizer name stored in the processor config.

    The initial config is looked up in the ``configs`` directory two levels above
    this file.

    Args:
        checkpoint_path: Checkpoint file to convert.
        pytorch_dump_folder_path: Output directory; created if missing. It
            receives ``config.json``, ``preprocessor_config.json`` and safetensors
            weight shards.
        config_path: Optional JSON config. If given, it replaces the config built
            from the initial config.

    Raises:
        FileNotFoundError: If the initial config JSON does not exist.
    """
    logger.info(f"Loading regular checkpoint from {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file, so only convert checkpoints
    # from trusted sources. The checkpoint also carries training metadata
    # ("train_args"), which is why loading is not restricted to tensors here.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    train_vars = checkpoint["train_args"]["vars"]
    init_config_name = train_vars["model_args"]["config_path"]["relative_path"]
    pretrained_name = train_vars["data_args"]["tokenizer_path"]
    init_config_path = (
        Path(__file__).parent.parent / "configs" / init_config_name.split("/")[-1]
    )
    if not init_config_path.exists():
        raise FileNotFoundError(
            f"Initial config file {init_config_path} not found. Please provide a valid path."
        )
    logger.info(f"Loading initial config from {init_config_path}")
    with open(init_config_path, "r", encoding="utf-8") as f:
        init_config = json.load(f)

    tie_word_embeddings = init_config["decoder_config"].get("tie_word_embeddings", True)
    logger.info(f"Tie word embeddings: {tie_word_embeddings}")
    init_config["decoder_config"]["use_cache"] = True
    # Merge into one dict so a top-level "tie_word_embeddings" entry in the JSON
    # cannot collide with the keyword (TypeError: got multiple values).
    config = QWEN3VoxConfig(**{**init_config, "tie_word_embeddings": tie_word_embeddings})

    # Rewrite only the leading "model.model." prefix; str.replace would also touch
    # any nested occurrence inside the key.
    prefix = "model.model."
    model_state_dict = {
        "model." + k[len(prefix):]: v
        for k, v in checkpoint["model"].items()
        if k.startswith(prefix)
    }
    if not tie_word_embeddings and "model.lm_head.weight" in checkpoint["model"]:
        model_state_dict["lm_head.weight"] = checkpoint["model"]["model.lm_head.weight"]

    if config_path:
        logger.info(f"Loading config from {config_path}")
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = QWEN3VoxConfig.from_dict(config_dict)

    # Build the model in bfloat16, then always restore the process-wide default
    # dtype, even if construction raises.
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    try:
        logger.info("Creating HuggingFace QWEN3VoxForConditionalGeneration model")
        model = QWEN3VoxForConditionalGeneration(config)
    finally:
        torch.set_default_dtype(original_dtype)

    logger.info("Loading weights into model")
    missing_keys, unexpected_keys = model.load_state_dict(
        model_state_dict, strict=False
    )
    if missing_keys:
        logger.warning(f"Missing keys: {missing_keys}")
    if unexpected_keys:
        logger.warning(f"Unexpected keys: {unexpected_keys}")

    os.makedirs(pytorch_dump_folder_path, exist_ok=True)
    logger.info(f"Saving model to {pytorch_dump_folder_path}")
    config.save_pretrained(pytorch_dump_folder_path)

    logger.info("Saving QWEN3Vox processor configuration")
    processor_config = {
        "processor_class": "QWEN3VoxProcessor",
        "speech_tok_compress_ratio": 3200,
        "db_normalize": True,
        "audio_processor": {
            "feature_extractor_type": "QWEN3VoxTokenizerProcessor",
            "sampling_rate": 22050,
            "normalize_audio": True,
            "target_dB_FS": -25,
            "eps": 1e-06,
        },
        "language_model_pretrained_name": pretrained_name,
    }
    processor_config_path = os.path.join(
        pytorch_dump_folder_path, "preprocessor_config.json"
    )
    with open(processor_config_path, "w", encoding="utf-8") as f:
        json.dump(processor_config, f, indent=2)
    logger.info(f"Saved processor config to {processor_config_path}")

    logger.info("Saving model weights with sharding...")
    model.save_pretrained(
        pytorch_dump_folder_path, max_shard_size="5GB", safe_serialization=True
    )
    logger.info(f"Model weights saved to {pytorch_dump_folder_path}")
    logger.info("Conversion complete!")

    logger.info("Verifying saved model...")
    # Reload purely as a smoke test that the output directory is loadable; the
    # reloaded instance itself is not needed.
    QWEN3VoxForConditionalGeneration.from_pretrained(str(pytorch_dump_folder_path))
    logger.info("Model successfully loaded from saved checkpoint!")
|
|
|
|
def main(argv: Optional[List[str]] = None):
    """Parse command-line options and run the nnscaler → HuggingFace conversion.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``; pass a list for programmatic use or tests.
    """
    parser = argparse.ArgumentParser(
        description="Convert a QWEN3Vox nnscaler training checkpoint into a HuggingFace model directory."
    )
    parser.add_argument(
        "--nnscaler_checkpoint_path",
        type=str,
        required=True,
        # The converter torch.load()s exactly this one file; it does not look for or
        # merge tensor-parallel part files.
        help=(
            "Path to the nnscaler training checkpoint (.pt file) containing the "
            "'model' state dict and 'train_args'. Tensor-parallel part files are "
            "not merged automatically."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        type=str,
        required=True,
        help="Path to the output PyTorch model directory",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default=None,
        help="Optional path to a config JSON file to override extracted config",
    )
    args = parser.parse_args(argv)
    convert_q3_nnscaler_checkpoint_to_hf(
        args.nnscaler_checkpoint_path, args.pytorch_dump_folder_path, args.config_path
    )
|
|
|
|
# Script entry point for the nnscaler checkpoint converter.
# NOTE(review): this guard sits in the middle of the file. When the file is run
# directly, `main()` here is the converter defined above and executes before any
# later definitions in the file exist — confirm this is intended.
if __name__ == "__main__":
    main()
'\nQWEN3Vox Universal Model Merger\n\nAutomatically detects and merges trained components back into the base model:\n- LLM LoRA adapters (detected only; merging is skipped here, so merge them offline)\n- Diffusion head (full fine-tune; LoRA-format weights are only loaded non-strictly)\n- Acoustic/Semantic connectors\n\nSupports all training configurations from train_vibevoice.py\n'
| import argparse |
| import logging |
| import os |
| import shutil |
| from typing import Dict, Optional |
| import torch |
|
|
# Configure root logging when this code is imported: INFO level, timestamped format.
# NOTE(review): basicConfig is a process-wide side effect and does nothing if the
# root logger already has handlers.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
# Standard-library logger; this reassigns any earlier module-level `logger`.
logger = logging.getLogger(__name__)
|
|
|
|
def detect_trained_components(checkpoint_path: str) -> Dict[str, bool]:
    """Report which trainable components have saved weights under *checkpoint_path*.

    Files checked:
    - llm_lora: ``adapter_config.json`` plus ``adapter_model.safetensors`` or
      ``adapter_model.bin``.
    - diffusion_head: any of ``adapter_model.safetensors``, ``adapter_model.bin``,
      ``model.safetensors`` or ``diffusion_head_full.bin`` inside
      ``diffusion_head/``, or ``diffusion_head_full.bin`` at the checkpoint root.
    - acoustic_connector / semantic_connector: ``<name>/pytorch_model.bin``.

    Args:
        checkpoint_path: Training output directory to scan.

    Returns:
        Mapping of component name to whether its weights were found.
    """

    def _path(*parts: str) -> str:
        return os.path.join(checkpoint_path, *parts)

    llm_lora = os.path.exists(_path("adapter_config.json")) and (
        os.path.exists(_path("adapter_model.safetensors"))
        or os.path.exists(_path("adapter_model.bin"))
    )

    diffusion_head_names = (
        "adapter_model.safetensors",
        "adapter_model.bin",
        "model.safetensors",
        "diffusion_head_full.bin",
    )
    # The root-level full checkpoint must count on its own: merge_diffusion_head
    # loads it even when no diffusion_head/ directory exists. A file inside the
    # directory already implies that the directory exists.
    diffusion_head = any(
        os.path.isfile(_path("diffusion_head", name)) for name in diffusion_head_names
    ) or os.path.isfile(_path("diffusion_head_full.bin"))

    return {
        "llm_lora": llm_lora,
        "diffusion_head": diffusion_head,
        "acoustic_connector": os.path.exists(
            _path("acoustic_connector", "pytorch_model.bin")
        ),
        "semantic_connector": os.path.exists(
            _path("semantic_connector", "pytorch_model.bin")
        ),
    }
|
|
|
|
def merge_llm_lora(model: QWEN3VoxForConditionalGeneration, checkpoint_path: str) -> None:
    """Do not merge the LLM LoRA adapter; only log a warning.

    PEFT loading is not permitted in this environment, so ``model`` is left
    untouched and ``checkpoint_path`` is ignored. Merge the adapter offline instead.
    """
    skip_message = (
        "LLM LoRA merge skipped: PeftModel.from_pretrained is not allowed in miner.py. "
        "Merge LoRA offline, then upload full safetensors to your HF repo."
    )
    logger.warning(skip_message)
|
|
|
|
def merge_diffusion_head(
    model: QWEN3VoxForConditionalGeneration, checkpoint_path: str
) -> dict:
    """Load trained diffusion-head weights from *checkpoint_path* into ``model``.

    The first existing file among these is used:
    - ``diffusion_head/model.safetensors``
    - ``diffusion_head/diffusion_head_full.bin``
    - ``diffusion_head_full.bin`` at the checkpoint root

    A state dict containing any ``lora_`` key is loaded non-strictly, because
    PEFT merging is unavailable here; a warning is logged when none of its keys
    match ``model.model.prediction_head``. Any other state dict is loaded strictly.

    Args:
        model: Model whose ``model.prediction_head`` is updated in place.
        checkpoint_path: Training output directory.

    Returns:
        The state dict read from disk.

    Raises:
        ValueError: If none of the candidate weight files exist.
    """
    logger.info("Merging diffusion head...")
    diffusion_head_dir = os.path.join(checkpoint_path, "diffusion_head")
    possible_files = [
        os.path.join(diffusion_head_dir, "model.safetensors"),
        os.path.join(diffusion_head_dir, "diffusion_head_full.bin"),
        os.path.join(checkpoint_path, "diffusion_head_full.bin"),
    ]
    trained_weights_path = next(
        (path for path in possible_files if os.path.exists(path)), None
    )
    if trained_weights_path is None:
        raise ValueError(
            "Diffusion head weights not found. Searched:\n"
            + "\n".join(f" - {p}" for p in possible_files)
        )
    logger.info(f"Loading from: {trained_weights_path}")
    if trained_weights_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        trained_state_dict = load_file(trained_weights_path)
    else:
        # weights_only=True: load tensors only, never unpickle arbitrary objects.
        trained_state_dict = torch.load(
            trained_weights_path, map_location="cpu", weights_only=True
        )

    head = model.model.prediction_head
    if any("lora_" in key for key in trained_state_dict):
        logger.warning(
            "Diffusion-head LoRA merge skipped (PeftModel.from_pretrained banned in miner.py); "
            "loading state_dict directly."
        )
        incompatible = head.load_state_dict(trained_state_dict, strict=False)
        # LoRA keys usually do not match the plain head, so report how many did
        # instead of implying the merge took effect.
        matched = len(trained_state_dict) - len(incompatible.unexpected_keys)
        if matched == 0:
            logger.warning(
                "None of the diffusion-head checkpoint keys match prediction_head; "
                "its weights were left unchanged."
            )
        else:
            logger.info(
                f"Loaded {matched}/{len(trained_state_dict)} diffusion-head tensors non-strictly"
            )
    else:
        logger.info("Detected full fine-tune format, replacing weights...")
        head.load_state_dict(trained_state_dict, strict=True)
    logger.info("✓ Diffusion head merge completed")
    return trained_state_dict
|
|
|
|
def merge_connectors(
    model: QWEN3VoxForConditionalGeneration,
    checkpoint_path: str,
    merge_acoustic: bool,
    merge_semantic: bool,
) -> None:
    """Strictly load trained connector weights into ``model`` in place.

    Each enabled connector is read from ``<checkpoint_path>/<name>/pytorch_model.bin``
    and loaded into ``model.model.<name>``.

    Args:
        model: Model whose connectors are updated.
        checkpoint_path: Training output directory.
        merge_acoustic: Load ``acoustic_connector``.
        merge_semantic: Load ``semantic_connector``.
    """
    # (enabled, sub-directory and attribute name, label used in log messages)
    connectors = (
        (merge_acoustic, "acoustic_connector", "Acoustic"),
        (merge_semantic, "semantic_connector", "Semantic"),
    )
    for enabled, name, label in connectors:
        if not enabled:
            continue
        logger.info(f"Merging {label.lower()} connector...")
        weights_path = os.path.join(checkpoint_path, name, "pytorch_model.bin")
        # weights_only=True: load tensors only, never unpickle arbitrary objects.
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        getattr(model.model, name).load_state_dict(state_dict, strict=True)
        logger.info(f"✓ {label} connector merge completed")
|
|
|
|
def verify_merge(
    base_model: QWEN3VoxForConditionalGeneration,
    merged_model: QWEN3VoxForConditionalGeneration,
    trained_state_dict: Optional[dict],
    component_name: str,
) -> None:
    """Check that a merged component differs from the base and matches the trained weights.

    Args:
        base_model: Model loaded from the original base checkpoint.
        merged_model: Model loaded from the merged output.
        trained_state_dict: Weights merged into the component, or ``None`` to skip
            the exact-match check.
        component_name: ``"diffusion_head"``, ``"acoustic_connector"`` or
            ``"semantic_connector"``. Any other name logs a warning and returns.

    Raises:
        ValueError:
            - The diffusion head is unchanged (the message contains
              "did not change").
            - Trained and merged weights differ.
            - The two modules' parameter counts differ.
    """
    logger.info(f"\n=== Verifying {component_name} merge ===")
    attr = {
        "diffusion_head": "prediction_head",
        "acoustic_connector": "acoustic_connector",
        "semantic_connector": "semantic_connector",
    }.get(component_name)
    if attr is None:
        logger.warning(f"Unknown component: {component_name}, skipping verification")
        return
    base_module = getattr(base_model.model, attr)
    merged_module = getattr(merged_model.model, attr)

    def _same(a: torch.Tensor, b: torch.Tensor) -> bool:
        # Compare shapes first: allclose would otherwise broadcast (e.g. (1,) vs
        # (N,)) or raise. Values are compared as float32 on CPU so that differing
        # dtypes or devices between the two loaded models cannot raise.
        if a.shape != b.shape:
            return False
        return torch.allclose(
            a.detach().float().cpu(), b.detach().float().cpu(), rtol=1e-05, atol=1e-08
        )

    base_state = base_module.state_dict()
    merged_state = merged_module.state_dict()
    logger.info("Checking if weights changed from base model...")
    changed_params = [
        key
        for key in base_state
        if key in merged_state and not _same(base_state[key], merged_state[key])
    ]
    if not changed_params:
        if component_name == "diffusion_head":
            raise ValueError(
                f"✗ ERROR: {component_name} weights did not change! Merge may have failed."
            )
        logger.info(f"✓ {component_name}: unchanged (was not trained)")
        return
    logger.info(
        f"✓ Weights changed: {len(changed_params)}/{len(base_state)} parameters modified"
    )

    if trained_state_dict is not None:
        logger.info("Verifying trained weights match merged model...")
        mismatches = []
        for key, trained_tensor in trained_state_dict.items():
            if key not in merged_state:
                mismatches.append(f"{key} (missing in merged)")
            elif trained_tensor.shape != merged_state[key].shape:
                mismatches.append(f"{key} (shape differs)")
            elif not _same(trained_tensor, merged_state[key]):
                mismatches.append(f"{key} (values differ)")
        if mismatches:
            logger.error("✗ Weight mismatches found:")
            for mm in mismatches[:5]:
                logger.error(f" - {mm}")
            if len(mismatches) > 5:
                logger.error(f" ... and {len(mismatches) - 5} more")
            raise ValueError("✗ ERROR: Trained and merged weights do not match!")
        logger.info(
            f"✓ All trained weights correctly merged: {len(trained_state_dict)} parameters verified"
        )

    base_params = sum(p.numel() for p in base_module.parameters())
    merged_params = sum(p.numel() for p in merged_module.parameters())
    if base_params != merged_params:
        raise ValueError(
            f"✗ ERROR: Parameter count mismatch: base={base_params:,} vs merged={merged_params:,}"
        )
    logger.info(f"✓ Parameter count matches: {merged_params:,}")
    logger.info(f"✓✓✓ {component_name} verification PASSED ✓✓✓")
|
|
|
|
def verify_models_only(base_model_path: str, merged_model_path: str) -> None:
    """Compare trainable components of two saved models without merging anything.

    Both models are loaded in float32. ``verify_merge`` is then run for the
    diffusion head and both connectors with no trained state dict. An unchanged
    component is reported as likely untrained; any other failure is logged (for
    non-ValueError exceptions) and re-raised.
    """
    logger.info("=== VERIFY-ONLY MODE ===")
    logger.info(f"Base model: {base_model_path}")
    logger.info(f"Merged model: {merged_model_path}")

    def _load_fp32(path):
        return QWEN3VoxForConditionalGeneration.from_pretrained(
            str(path), torch_dtype=torch.float32
        )

    logger.info("\nLoading base model...")
    base_model = _load_fp32(base_model_path)
    logger.info("Loading merged model...")
    merged_model = _load_fp32(merged_model_path)

    for component in ("diffusion_head", "acoustic_connector", "semantic_connector"):
        try:
            verify_merge(base_model, merged_model, None, component)
        except ValueError as err:
            if "did not change" not in str(err):
                raise
            logger.info(f"✓ {component}: unchanged (likely not trained)")
        except Exception as err:
            logger.error(f"✗ {component} verification failed: {err}")
            raise
    logger.info("\n✓✓✓ VERIFICATION COMPLETE ✓✓✓")
|
|
|
|
def merge_q3_model(
    base_model_path: str,
    checkpoint_path: str,
    output_path: str,
    output_format: str = "safetensors",
) -> None:
    """Merge trained QWEN3Vox components into a base model, save it, and verify.

    Steps:
      1. Detect which trained components exist under ``checkpoint_path``
         (``detect_trained_components``).
      2. Load the base model in float32 and merge every detected component
         (LLM LoRA, diffusion head, acoustic/semantic connectors) into it.
      3. Save the merged model to ``output_path`` and copy the base model's
         config/tokenizer files that exist in ``base_model_path``.
      4. Reload the original base model and the saved merged model and run
         ``verify_merge`` for the diffusion head / connectors that were merged.

    Args:
        base_model_path: Passed to ``from_pretrained`` for the base model and
            used as the source directory of the copied config/tokenizer files.
        checkpoint_path: Directory holding the trained components.
        output_path: Directory the merged model is written to (created if missing).
        output_format: ``"safetensors"`` (sharded at 5GB) or ``"bin"``.

    Raises:
        ValueError: ``output_format`` is unsupported (checked before any model
            is loaded) or no trained components are detected.
        Exception: any verification error is logged and re-raised.
    """
    import gc
    import shutil

    # Reject a bad format before paying for a full model load + merge.
    if output_format not in ("safetensors", "bin"):
        raise ValueError(
            f"Unknown output format: {output_format}. Use 'safetensors' or 'bin'"
        )

    logger.info(f"Scanning trained components in: {checkpoint_path}")
    components = detect_trained_components(checkpoint_path)
    logger.info("Detected trained components:")
    for name, trained in components.items():
        status = "✓ Found" if trained else "✗ Not found"
        logger.info(f" {name}: {status}")
    if not any(components.values()):
        raise ValueError("No trained components found in checkpoint path!")

    logger.info(f"\nLoading base model from: {base_model_path}")
    base_model = QWEN3VoxForConditionalGeneration.from_pretrained(
        str(base_model_path), torch_dtype=torch.float32
    )

    logger.info("\n=== Starting merge process ===")
    trained_diffusion_state = None
    if components["llm_lora"]:
        merge_llm_lora(base_model, checkpoint_path)
    if components["diffusion_head"]:
        trained_diffusion_state = merge_diffusion_head(base_model, checkpoint_path)
    if components["acoustic_connector"] or components["semantic_connector"]:
        merge_connectors(
            base_model,
            checkpoint_path,
            merge_acoustic=components["acoustic_connector"],
            merge_semantic=components["semantic_connector"],
        )

    logger.info(f"\n=== Saving merged model to: {output_path} ===")
    os.makedirs(output_path, exist_ok=True)
    if output_format == "safetensors":
        base_model.save_pretrained(
            output_path, max_shard_size="5GB", safe_serialization=True
        )
    else:
        base_model.save_pretrained(output_path, safe_serialization=False)

    # The merged weights are on disk now; drop our reference so the model can
    # be reclaimed before the two full reloads used for verification below.
    del base_model
    gc.collect()

    logger.info("Copying config and processor files...")
    # NOTE(review): save_pretrained already wrote config.json (and possibly
    # generation_config.json); copying the base model's versions overwrites
    # them. Confirm that is intended.
    files_to_copy = (
        "config.json",
        "preprocessor_config.json",
        "generation_config.json",
        "special_tokens_map.json",
        "tokenizer_config.json",
        "tokenizer.json",
        "vocab.json",
        "merges.txt",
    )
    for file_name in files_to_copy:
        src = os.path.join(base_model_path, file_name)
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(output_path, file_name))

    logger.info("\n=== Verifying merged model ===")
    try:
        logger.info("Reloading original base model for verification...")
        original_base_model = QWEN3VoxForConditionalGeneration.from_pretrained(
            str(base_model_path), torch_dtype=torch.float32
        )
        logger.info("Loading merged model for verification...")
        test_model = QWEN3VoxForConditionalGeneration.from_pretrained(str(output_path))
        logger.info("✓ Model loads successfully")
        if components["diffusion_head"]:
            try:
                verify_merge(
                    original_base_model,
                    test_model,
                    trained_diffusion_state,
                    "diffusion_head",
                )
            except ValueError as e:
                if "did not change" not in str(e):
                    raise
                logger.warning(
                    "Diffusion head weights unchanged after merge (often means "
                    "checkpoint matches base); continuing without failing merge."
                )
        if components["acoustic_connector"]:
            verify_merge(original_base_model, test_model, None, "acoustic_connector")
        if components["semantic_connector"]:
            verify_merge(original_base_model, test_model, None, "semantic_connector")
        logger.info("\n✓✓✓ Merge and verification completed successfully! ✓✓✓")
    except Exception as e:
        logger.error(f"✗ Verification failed: {e}")
        raise
|
|
|
|
def main():
    """Command-line entry point: merge trained components or verify a merge.

    With ``--verify_only`` it calls ``verify_models_only`` on
    ``--base_model_path`` / ``--output_path``. Otherwise ``--checkpoint_path``
    is required and ``merge_q3_model`` runs.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Universal merger for QWEN3Vox trained components",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # argparse expands %(prog)s to the invoked script name, so the examples
        # do not advertise a stale hard-coded filename.
        epilog=(
            "\nExamples:\n"
            " # Merge and verify\n"
            " python %(prog)s --base_model_path model --checkpoint_path output/lora --output_path merged\n"
            " \n"
            " # Verify existing merge (no actual merging)\n"
            " python %(prog)s --base_model_path model --output_path merged --verify_only\n"
            " "
        ),
    )
    parser.add_argument(
        "--base_model_path",
        type=str,
        required=True,
        help="Path to base QWEN3Vox model directory",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=False,
        help="Path to checkpoint directory (usually 'lora/' or 'checkpoint-XXX/lora/'). Not needed with --verify_only",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to save merged model (or path to verify with --verify_only)",
    )
    parser.add_argument(
        "--output_format",
        type=str,
        default="safetensors",
        choices=["safetensors", "bin"],
        help="Output format: 'safetensors' (recommended) or 'bin'",
    )
    parser.add_argument(
        "--verify_only",
        action="store_true",
        help="Only verify existing merge between base_model_path and output_path (no actual merging)",
    )
    args = parser.parse_args()
    if args.verify_only:
        verify_models_only(
            base_model_path=args.base_model_path, merged_model_path=args.output_path
        )
        return
    if not args.checkpoint_path:
        parser.error("--checkpoint_path is required unless using --verify_only")
    merge_q3_model(
        base_model_path=args.base_model_path,
        checkpoint_path=args.checkpoint_path,
        output_path=args.output_path,
        output_format=args.output_format,
    )


# NOTE(review): this guard sits mid-module, so running the file as a script
# calls main() before the definitions that follow it in the file exist.
if __name__ == "__main__":
    main()
# Leftover docstring of the module this section was bundled from. Here it is a
# bare string expression and has no runtime effect.
'\nQWEN3Vox Modular Components\n\nThis module provides the core model architectures for QWEN3Vox:\n- Multi-speaker models (1.5B, 7B) for high-quality multi-speaker TTS\n- Streaming model (0.5B) for real-time low-latency TTS\n'

# Public names re-exported by this section.
__all__ = [
    "QWEN3VoxConfig",
    "QWEN3VoxAcousticTokenizerConfig",
    "QWEN3VoxSemanticTokenizerConfig",
    "QWEN3VoxDiffusionHeadConfig",
    "QWEN3VoxASRConfig",
    "QWEN3VoxPreTrainedModel",
    "QWEN3VoxModel",
    "QWEN3VoxForConditionalGenerationInference",
    "QWEN3VoxASRPreTrainedModel",
    "QWEN3VoxASRModel",
    "QWEN3VoxASRForConditionalGeneration",
    "QWEN3VoxStreamingConfig",
    "QWEN3VoxStreamingPreTrainedModel",
    "QWEN3VoxStreamingModel",
    "QWEN3VoxStreamingForConditionalGenerationInference",
    "QWEN3VoxGenerationOutput",
    "BinaryClassifier",
    "SpeechConnector",
    "TTS_TEXT_WINDOW_SIZE",
    "TTS_SPEECH_WINDOW_SIZE",
    "QWEN3VoxTokenizerStreamingCache",
    "QWEN3VoxAcousticTokenizerModel",
    "QWEN3VoxSemanticTokenizerModel",
    "QWEN3VoxTextTokenizer",
    "QWEN3VoxTextTokenizerFast",
    "QWEN3VoxDiffusionHead",
    "AudioStreamer",
    "AsyncAudioStreamer",
    "load_lora_assets",
]

# Safetensors metadata key whose value is the JSON slot manifest of the
# auxiliary shard.
_AUX_SLOT_MANIFEST_K = "vv.pipeline.aux_slot_manifest"
# Slice id used as the fallback when a requested slice/tag is unavailable.
DEFAULT_AUX_SLICE_ID = "male_mid_normal_adult_serious_formal_uk"
|
|
|
|
def _resolve_aux_coeff_tensor(
    handles: Dict[str, Any],
    slice_query: str,
    *,
    default_slice_id: str = DEFAULT_AUX_SLICE_ID,
) -> Tuple[Any, str, str, bool]:
    """Look up a handle for *slice_query* with progressively looser matching.

    Resolution order (first hit wins):
      1. exact key equal to the stripped query;
      2. ``default_slice_id`` if present (so when the default exists, the
         substring step below is never reached);
      3. first key (in dict order) where either the key or the query is a
         case-insensitive substring of the other;
      4. the first key in dict order.

    Returns:
        ``(handle, key_used, stripped_query, used_default)``. ``used_default``
        is True only for step 2.

    Raises:
        ValueError: *handles* is empty.
    """
    wanted = slice_query.strip()
    if wanted in handles:
        return (handles[wanted], wanted, wanted, False)
    if default_slice_id in handles:
        return (handles[default_slice_id], default_slice_id, wanted, True)

    folded = wanted.lower()
    fuzzy_hit = next(
        (
            (key, value)
            for key, value in handles.items()
            if key.lower() in folded or folded in key.lower()
        ),
        None,
    )
    if fuzzy_hit is not None:
        hit_key, hit_value = fuzzy_hit
        return (hit_value, hit_key, wanted, False)

    if not handles:
        raise ValueError("empty auxiliary coefficient handle map")
    head_key = next(iter(handles))
    return (handles[head_key], head_key, wanted, False)
|
|
|
|
def _accum_tensor_key(slot_idx: int) -> str:
    """Return the shard tensor name holding the payload of slot *slot_idx* (zero-padded to 4 digits)."""
    template = "model.decoder.aux_residual.accum.{:04d}.u8_payload"
    return template.format(slot_idx)
|
|
|
|
def _default_aux_shard_fp(repo_root: str) -> str:
    """Return the conventional path of the auxiliary shard inside *repo_root*."""
    shard_file_name = "aux_lm_residual_projection.safetensors"
    return os.path.join(repo_root, shard_file_name)
|
|
|
|
def _materialize_latent_prompt_embeddings(
    blob_fp: str | os.PathLike[str],
) -> Dict[str, Any]:
    """Decode every audio payload stored in an auxiliary safetensors shard.

    Despite the function name, no embeddings are produced. The shard metadata
    entry ``_AUX_SLOT_MANIFEST_K`` must hold a JSON object whose ``"order"``
    field is a list of slot names. Slot ``i`` is the uint8 tensor named
    ``_accum_tensor_key(i)``, and its raw bytes are an encoded audio file that
    is decoded with ``librosa.load(..., sr=None, mono=True)``.

    Args:
        blob_fp: Path to the ``.safetensors`` shard.

    Returns:
        Mapping of slot name -> float32 waveform. The sample rate reported by
        librosa is discarded, so each waveform keeps the rate it was encoded
        at. NOTE(review): confirm downstream consumers expect that.

    Raises:
        ValueError: the manifest is missing or corrupt, a slot tensor is
            missing, or a payload cannot be decoded as audio.
        ImportError: librosa or safetensors is not installed.
    """
    import librosa
    from safetensors import safe_open

    blob_fp = os.fspath(blob_fp)
    decoded: Dict[str, Any] = {}
    with safe_open(blob_fp, framework="np") as f:
        meta = f.metadata()
        if not meta or _AUX_SLOT_MANIFEST_K not in meta:
            raise ValueError(
                "missing auxiliary slot manifest (not an LM projection safetensors shard)"
            )
        try:
            order = json.loads(meta[_AUX_SLOT_MANIFEST_K])["order"]
        except (json.JSONDecodeError, KeyError, TypeError) as exc:
            raise ValueError("corrupt auxiliary slot manifest") from exc
        # Without this check, a bare string would be split into one slot per
        # character, and non-string names would break key handling downstream.
        if not isinstance(order, list) or not all(isinstance(s, str) for s in order):
            raise ValueError("corrupt auxiliary slot manifest")
        stems_ordered: List[str] = list(order)

        tensor_names = set(f.keys())
        for i, stem in enumerate(stems_ordered):
            tk = _accum_tensor_key(i)
            if tk not in tensor_names:
                raise ValueError(f"missing tensor payload for slot {i}: {tk}")
            raw = np.asarray(f.get_tensor(tk), dtype=np.uint8).tobytes()
            try:
                mono, _native_sr = librosa.load(io.BytesIO(raw), sr=None, mono=True)
            except Exception as exc:
                # librosa/soundfile/audioread raise assorted types; normalise to
                # ValueError so callers that wrap ValueError add shard context.
                raise ValueError(
                    f"undecodable audio payload for slot {i} ({stem}): {exc}"
                ) from exc
            decoded[stem] = np.asarray(mono, dtype=np.float32)
    return decoded
|
|
|
|
# Role prefix of dialogue turns ("Speaker N: ..."). This is written as a plain
# literal rather than hidden behind chr() codes, so it is readable and greppable.
_MODEL_DIALOGUE_ROLE_MARK = "Speaker"
# Repo sub-directory scanned for reference ``.wav`` files.
_COEFF_STAGE_SUBDIR = "voices"
|
|
|
|
class _QxResidualFabric:
    """Name -> reference-audio lookup used to build prompt inputs.

    Despite the residual/coefficient vocabulary, every handle stored in
    ``self._r_handles`` is audio. It is either:

    * a float32 waveform decoded from the auxiliary safetensors shard by
      ``_materialize_latent_prompt_embeddings`` (used when a shard file is
      found), or
    * the path of a ``.wav`` file inside ``<repo_root>/<_COEFF_STAGE_SUBDIR>``.

    Keys are file/slot stems. A stem containing ``-`` also gets a short alias
    (``"en-Alice_woman"`` -> ``"Alice"``), unless that alias is empty or
    collides with a real stem.
    """

    def __init__(
        self,
        repo_root: str | os.PathLike[str],
        *,
        aux_projection_shard_fp: str | None = None,
        skip_aux_shard: bool = False,
    ):
        self._repo_root = os.path.abspath(os.fspath(repo_root))
        self._discrete_coeff_root = os.path.join(self._repo_root, _COEFF_STAGE_SUBDIR)
        self._r_handles: Dict[str, Union[str, np.ndarray]] = {}
        self._fabric_refresh_handles(
            aux_projection_shard_fp=aux_projection_shard_fp,
            skip_aux_shard=skip_aux_shard,
        )
        # Aliases are appended after the real stems and never replace one.
        self._r_handles.update(self._short_aliases(self._r_handles))

    @staticmethod
    def _short_aliases(
        handles: Dict[str, Union[str, np.ndarray]],
    ) -> Dict[str, Union[str, np.ndarray]]:
        """Map ``"<prefix>-<nick>_<rest>"`` stems to ``nick``; among colliding aliases the last wins."""
        aliases: Dict[str, Union[str, np.ndarray]] = {}
        for stem, binding in handles.items():
            if "-" not in stem:
                continue
            nick = stem.split("_", 1)[0].split("-")[-1]
            # Skip aliases that would shadow a real stem. Also skip empty ones
            # (e.g. "x-_y"), which would substring-match any query in the
            # fuzzy step of _resolve_aux_coeff_tensor.
            if nick and nick not in handles:
                aliases[nick] = binding
        return aliases

    def _locate_aux_shard(self, aux_projection_shard_fp: str | None) -> str | None:
        """Return the first existing shard path: explicit arg, then VV_AUX_PROJECTION_PATH, then repo default."""
        candidates = (
            (aux_projection_shard_fp or "").strip(),
            (os.environ.get("VV_AUX_PROJECTION_PATH") or "").strip(),
            _default_aux_shard_fp(self._repo_root),
        )
        return next((p for p in candidates if p and os.path.isfile(p)), None)

    def _fabric_refresh_handles(
        self, *, aux_projection_shard_fp: str | None, skip_aux_shard: bool
    ) -> None:
        """(Re)populate ``self._r_handles`` from the aux shard or the WAV directory.

        When ``skip_aux_shard`` is False and a shard file is found (see
        ``_locate_aux_shard``), its decoded slots are used. Otherwise every
        ``*.wav`` file (case-insensitive extension) directly inside the WAV
        directory is registered by stem. Progress is reported with ``print``.

        Raises:
            ValueError: a shard was found but could not be decoded.
        """
        self._r_handles.clear()
        blob_fp = None if skip_aux_shard else self._locate_aux_shard(aux_projection_shard_fp)
        if blob_fp:
            try:
                decoded = _materialize_latent_prompt_embeddings(blob_fp)
            except ValueError as exc:
                raise ValueError(f"AUX shard assembly failed ({blob_fp}): {exc}") from exc
            self._r_handles = dict(sorted(decoded.items()))
            print(
                f"Mounted auxiliary LM projection shard ({len(self._r_handles)} tensors): {blob_fp}"
            )
            print(f"Residual routing keys: {', '.join(self._r_handles.keys())}")
            return

        if not os.path.isdir(self._discrete_coeff_root):
            print(f"Warning: coefficient directory missing at {self._discrete_coeff_root}")
            return
        staged: Dict[str, Union[str, np.ndarray]] = {}
        # Sorted listing makes stem collisions (e.g. "a.wav" vs "a.WAV") deterministic.
        for entry in sorted(os.listdir(self._discrete_coeff_root)):
            full_path = os.path.join(self._discrete_coeff_root, entry)
            if entry.lower().endswith(".wav") and os.path.isfile(full_path):
                staged[os.path.splitext(entry)[0]] = full_path
        self._r_handles = dict(sorted(staged.items()))
        print(
            f"Discrete coefficient files staged: {len(self._r_handles)} under {self._discrete_coeff_root}"
        )
        print(f"Residual routing keys: {', '.join(self._r_handles.keys())}")

    def _fabric_pick_residual_snapshot(
        self, shard_slice_query: str
    ) -> Union[str, np.ndarray]:
        """Resolve *shard_slice_query* via ``_resolve_aux_coeff_tensor`` and return the handle.

        Prints a warning when the default slice was substituted.

        Raises:
            ValueError: no handles are mounted.
        """
        if not self._r_handles:
            raise ValueError(
                f"No residual handles mounted. Add WAV files under {_COEFF_STAGE_SUBDIR}/ at the repo root, place aux_lm_residual_projection.safetensors next to config.json, or set VV_AUX_PROJECTION_PATH / VOCENCE_AUX_PROJECTION_SHARD."
            )
        binding, used_key, req_norm, used_default = _resolve_aux_coeff_tensor(
            self._r_handles, shard_slice_query
        )
        if used_default:
            print(
                f"Warning: auxiliary slice '{req_norm}' not in shard; using default '{used_key}'."
            )
        return binding
|
|
|
|
def _partition_lm_conditioning_manifest(
    raw_manifest_txt: str,
) -> Tuple[List[str], List[str]]:
    """Split a "<role> N: text" script into normalised turns.

    A header line matches ``<_MODEL_DIALOGUE_ROLE_MARK> <digits>:``
    case-insensitively. Following non-header lines are appended to the current
    turn, separated by single spaces. Blank lines are ignored. Text before the
    first header is discarded, as is a header that ends up with no text.

    Returns:
        ``(turns, lane_ids)``: ``turns[i]`` is ``"<role> <id>: <text>"`` using
        the canonical role spelling, and ``lane_ids[i]`` is that turn's id string.
    """
    role = _MODEL_DIALOGUE_ROLE_MARK
    header_re = re.compile(rf"^{re.escape(role)}\s+(\d+):\s*(.*)$", re.IGNORECASE)
    turns: List[str] = []
    lane_ids: List[str] = []
    current_lane: str | None = None
    pending = ""

    def _flush() -> None:
        # Emit the turn collected so far, if it has both a lane and text.
        if current_lane and pending:
            turns.append(f"{role} {current_lane}: {pending.strip()}")
            lane_ids.append(current_lane)

    for raw_line in raw_manifest_txt.strip().split("\n"):
        text = raw_line.strip()
        if not text:
            continue
        header = header_re.match(text)
        if header is None:
            pending = f"{pending} {text}" if pending else text
            continue
        _flush()
        current_lane = header.group(1).strip()
        pending = header.group(2).strip()
    _flush()
    return (turns, lane_ids)
|
|
|
|
def _parse_instruction_params(instruction: str) -> Dict[str, str]:
    """Parse ``"|key: value|key2: value2|"`` into ``{key.lower(): value}``.

    Segments without ``:`` are skipped, only the first ``:`` splits a segment,
    keys and values are stripped, and a later duplicate key overrides an
    earlier one.
    """
    pairs = (
        segment.split(":", 1)
        for segment in instruction.strip().strip("|").split("|")
        if ":" in segment
    )
    return {key.strip().lower(): value.strip() for key, value in pairs}
|
|
|
|
| |
# Field order of a slice slug:
# "<gender>_<pitch>_<speed>_<age_group>_<emotion>_<tone>_<accent>".
_SLICE_SLUG_FIELDS: Tuple[str, ...] = (
    "gender", "pitch", "speed", "age_group", "emotion", "tone", "accent",
)

# Field priority when scoring candidate slices, most important first.
_SLICE_MATCH_WEIGHT_ORDER: Tuple[str, ...] = (
    "gender", "emotion", "accent", "speed", "age_group", "tone", "pitch",
)

# Instruction keys that count as structured prosody ("age" is accepted as an
# alias of "age_group").
_STRUCTURED_PROSODY_KEYS = frozenset((*_SLICE_SLUG_FIELDS, "age"))

# Weights 2**28, 2**24, ..., 2**4. Each is 16x the next, so a single
# higher-priority match outweighs every lower-priority match combined.
_SLICE_MATCH_WEIGHTS: Tuple[int, ...] = tuple(
    1 << (28 - 4 * rank) for rank, _field in enumerate(_SLICE_MATCH_WEIGHT_ORDER)
)
|
|
|
|
def _norm_prosody_token(s: str) -> str:
    """Strip, lower-case, and turn each space into an underscore (``" Young Adult"`` -> ``"young_adult"``)."""
    lowered = s.strip().lower()
    return "_".join(lowered.split(" "))
|
|
|
|
def _parse_slice_slug(slice_id: str) -> Optional[Dict[str, str]]:
    """Split an underscore-joined slug into ``_SLICE_SLUG_FIELDS``.

    Returns None for an empty id or when the part count differs from the
    number of slug fields. Each part is normalised with ``_norm_prosody_token``.
    """
    token = slice_id.strip()
    if token:
        pieces = token.split("_")
        if len(pieces) == len(_SLICE_SLUG_FIELDS):
            return dict(zip(_SLICE_SLUG_FIELDS, map(_norm_prosody_token, pieces)))
    return None
|
|
|
|
def _attrs_to_slice_slug(attrs: Dict[str, str]) -> str:
    """Join the normalised values of every slug field, in field order, with ``_``.

    Raises KeyError if *attrs* lacks any of ``_SLICE_SLUG_FIELDS``.
    """
    pieces: List[str] = []
    for field_name in _SLICE_SLUG_FIELDS:
        pieces.append(_norm_prosody_token(attrs[field_name]))
    return "_".join(pieces)
|
|
|
|
def _default_slice_attrs() -> Dict[str, str]:
    """Return a fresh copy of ``DEFAULT_AUX_SLICE_ID`` parsed into fields (all empty strings if it does not parse)."""
    parsed = _parse_slice_slug(DEFAULT_AUX_SLICE_ID)
    if parsed is None:
        return dict.fromkeys(_SLICE_SLUG_FIELDS, "")
    return dict(parsed)
|
|
|
|
def _instruction_has_structured_prosody(p: Dict[str, str]) -> bool:
    """True if any key of *p* (case-insensitive, ``age`` -> ``age_group``) is a structured prosody field."""
    return any(
        ("age_group" if key.lower() == "age" else key.lower()) in _STRUCTURED_PROSODY_KEYS
        for key in p
    )
|
|
|
|
def _instruction_prosody_attrs(p: Dict[str, str]) -> Dict[str, str]:
    """Overlay the slug fields given in *p* onto the default slice attributes.

    Keys are matched case-insensitively (``age`` maps to ``age_group``).
    Blank values and keys outside ``_SLICE_SLUG_FIELDS`` are ignored, and
    values are normalised with ``_norm_prosody_token``.
    """
    attrs = _default_slice_attrs()
    for raw_key, raw_val in p.items():
        if raw_val.strip():
            field_name = raw_key.lower()
            field_name = "age_group" if field_name == "age" else field_name
            if field_name in _SLICE_SLUG_FIELDS:
                attrs[field_name] = _norm_prosody_token(raw_val)
    return attrs
|
|
|
|
def _pick_best_aux_slice_key(
    desired_attrs: Dict[str, str], available_keys: AbstractSet[str]
) -> str:
    """Choose the available slice key that best matches *desired_attrs*.

    Resolution order:
      1. the exact slug built from *desired_attrs*, if available;
      2. among keys that ``_parse_slice_slug`` accepts, the highest total of
         ``_SLICE_MATCH_WEIGHTS`` over equal fields, with ties broken by the
         lexicographically smallest key;
      3. if no key parses, the smallest available key, or
         ``DEFAULT_AUX_SLICE_ID`` when *available_keys* is empty.
    """
    exact = _attrs_to_slice_slug(desired_attrs)
    if exact in available_keys:
        return exact

    candidates = [
        (key, fields)
        for key, fields in ((k, _parse_slice_slug(k)) for k in available_keys)
        if fields is not None
    ]
    if not candidates:
        return min(available_keys) if available_keys else DEFAULT_AUX_SLICE_ID

    def _match_score(fields: Dict[str, str]) -> int:
        return sum(
            weight
            for field_name, weight in zip(_SLICE_MATCH_WEIGHT_ORDER, _SLICE_MATCH_WEIGHTS)
            if desired_attrs.get(field_name) == fields.get(field_name)
        )

    # Highest score first, then smallest key; keys are unique so this is total.
    chosen_key, _chosen_fields = min(
        candidates, key=lambda pair: (-_match_score(pair[1]), pair[0])
    )
    return chosen_key
|
|
|
|
def _build_vocence_prompt(instruction: str, text: str) -> str:
    """Prefix *text* with a ``<description="...">`` tag holding *instruction* unescaped."""
    return '<description="{}"> {}'.format(instruction, text)
|
|
|
|
def _prosody_shard_tags_for_lanes(
    instruction: str,
    unique_lanes: List[str],
    *,
    aux_slice_keys: Optional[AbstractSet[str]] = None,
) -> Dict[str, str]:
    """Assign a slice tag to every lane id, derived from *instruction*.

    Tag source (first that applies):
      * ``prosody`` / ``shards`` / ``prosody_shards``: comma-separated list;
      * ``speakers``: comma-separated list;
      * ``voice`` or ``speaker``: a single tag;
      * structured prosody fields: the best match in *aux_slice_keys* when
        given, otherwise the slug of the requested attributes;
      * otherwise (or if the list is empty): ``DEFAULT_AUX_SLICE_ID``.

    A short tag list is padded by repeating its last tag. Extra tags are ignored.
    """
    params = _parse_instruction_params(instruction)

    def _csv(raw: str) -> List[str]:
        return [piece.strip() for piece in raw.split(",") if piece.strip()]

    if any(name in params for name in ("prosody", "shards", "prosody_shards")):
        tags = _csv(
            params.get("prosody") or params.get("shards") or params.get("prosody_shards") or ""
        )
    elif "speakers" in params:
        tags = _csv(params["speakers"])
    elif params.get("voice") or params.get("speaker"):
        tags = [(params.get("voice") or params.get("speaker") or "").strip()]
    elif _instruction_has_structured_prosody(params):
        wanted = _instruction_prosody_attrs(params)
        if aux_slice_keys:
            tags = [_pick_best_aux_slice_key(wanted, aux_slice_keys)]
        else:
            tags = [_attrs_to_slice_slug(wanted)]
    else:
        tags = []

    if not tags:
        tags = [DEFAULT_AUX_SLICE_ID]
    padded = tags + [tags[-1]] * (len(unique_lanes) - len(tags))
    return dict(zip(unique_lanes, padded))
|
|
|
|
def _manifest_from_text(text: str) -> str:
    """Return *text* stripped, prefixed with ``"Speaker 1: "`` unless some line already starts with ``Speaker <n>:``."""
    body = text.strip()
    header_found = re.search(
        r"^Speaker\s+\d+:", body, flags=re.MULTILINE | re.IGNORECASE
    ) is not None
    return body if header_found else "Speaker 1: " + body
|
|
|
|
def _build_prefill_slices(
    fabric: _QxResidualFabric,
    routing_lane_ids: List[str],
    lane_to_slice_tag: Dict[str, str],
) -> List[Union[str, np.ndarray]]:
    """Resolve one fabric handle per distinct lane, in first-appearance order.

    Lanes missing from *lane_to_slice_tag* are looked up as ``"lane_<id>"``.
    """
    ordered_lanes = list(dict.fromkeys(routing_lane_ids))
    return [
        fabric._fabric_pick_residual_snapshot(lane_to_slice_tag.get(lane, f"lane_{lane}"))
        for lane in ordered_lanes
    ]
|
|
|
|
class Miner:
    """QWEN3Vox text-to-speech wrapper configured from a model repo directory.

    ``__init__`` does the following:

    * reads ``model_name`` from ``<repo>/vocence_config.yaml``;
    * builds a ``_QxResidualFabric`` over the repo;
    * picks a device (cuda, then mps, then cpu);
    * loads the processor and model, and optionally LoRA assets.

    Environment variables read:

    * ``VOCENCE_AUX_PROJECTION_SHARD``: explicit shard path for the fabric.
    * ``VOCENCE_PREFER_DISCRETE_COEFF_DIR``: truthy -> the fabric skips the shard.
    * ``VOCENCE_SEED``: integer seed for torch (and all CUDA devices).
    * ``VOCENCE_CFG_SCALE``: passed to ``generate`` as ``cfg_scale`` (default 1.3).
    * ``VOCENCE_DISABLE_PREFILL``: truthy -> ``generate`` gets ``is_prefill=False``.
    * ``VOCENCE_CHECKPOINT_PATH``: optional path handed to ``load_lora_assets``.

    "Truthy" means the lower-cased value is ``1``, ``true`` or ``yes``.

    NOTE(review): ``generate_wav`` passes ``voice_samples=None`` and no method
    here reads ``self._fabric_q`` after construction. Confirm it is still needed.
    """

    def __init__(self, path_hf_repo: Path) -> None:
        self._repo_path = Path(path_hf_repo).resolve()
        import yaml

        with (self._repo_path / "vocence_config.yaml").open(encoding="utf-8") as f:
            cfg = yaml.safe_load(f) or {}
        model_name = str(cfg["model_name"]).strip()

        aux_cli = os.environ.get("VOCENCE_AUX_PROJECTION_SHARD", "").strip()
        self._fabric_q = _QxResidualFabric(
            str(self._repo_path),
            aux_projection_shard_fp=aux_cli or None,
            skip_aux_shard=self._env_flag("VOCENCE_PREFER_DISCRETE_COEFF_DIR"),
        )
        self._device = self._pick_device()

        seed_s = os.environ.get("VOCENCE_SEED", "").strip()
        if seed_s:
            seed = int(seed_s)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        self._cfg_scale = float(os.environ.get("VOCENCE_CFG_SCALE", "1.3"))
        self._disable_prefill = self._env_flag("VOCENCE_DISABLE_PREFILL")
        self._processor = QWEN3VoxProcessor.from_pretrained(model_name)

        # bf16 + FlashAttention-2 on CUDA; fp32 + SDPA on MPS and CPU.
        if self._device == "cuda":
            load_dtype, attn_impl_primary = torch.bfloat16, "flash_attention_2"
        else:
            load_dtype, attn_impl_primary = torch.float32, "sdpa"
        try:
            self._model = self._load_model_weights(model_name, load_dtype, attn_impl_primary)
        except Exception as exc:
            if attn_impl_primary != "flash_attention_2":
                raise
            # Retry with SDPA, but surface why the primary load failed instead
            # of swallowing it.
            warnings.warn(
                f"flash_attention_2 model load failed ({exc!r}); retrying with sdpa.",
                RuntimeWarning,
            )
            self._model = self._load_model_weights(model_name, load_dtype, "sdpa")

        ckpt = os.environ.get("VOCENCE_CHECKPOINT_PATH", "").strip()
        if ckpt:
            load_lora_assets(self._model, ckpt)
        self._model.train(False)
        self._model.set_ddpm_inference_steps(num_steps=10)
        self._sample_rate = int(
            getattr(self._processor.audio_processor, "sampling_rate", 22050)
        )

    @staticmethod
    def _env_flag(name: str) -> bool:
        """Return True when env var *name*, lower-cased, is ``1``, ``true`` or ``yes``."""
        return os.environ.get(name, "").lower() in ("1", "true", "yes")

    @staticmethod
    def _pick_device() -> str:
        """Return ``"cuda"`` if available, else ``"mps"`` if available, else ``"cpu"``."""
        if torch.cuda.is_available():
            return "cuda"
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def _load_model_weights(
        self, model_name: str, load_dtype: torch.dtype, attn_impl: str
    ) -> QWEN3VoxForConditionalGenerationInference:
        """Load the inference model onto ``self._device``.

        On MPS the model is loaded with ``device_map=None`` and then moved with
        ``.to("mps")``. On CUDA/CPU ``device_map`` is set to the device name.
        """
        if self._device == "mps":
            model = QWEN3VoxForConditionalGenerationInference.from_pretrained(
                model_name,
                torch_dtype=load_dtype,
                attn_implementation=attn_impl,
                device_map=None,
            )
            model.to("mps")
            return model
        return QWEN3VoxForConditionalGenerationInference.from_pretrained(
            model_name,
            torch_dtype=load_dtype,
            device_map=self._device,
            attn_implementation=attn_impl,
        )

    def warmup(self, timeout: float = 240.0) -> None:
        """Run one short generation on a daemon thread and wait up to *timeout* seconds.

        Raises:
            RuntimeError: the generation raised (chained to the original
                exception) or did not finish in time. On timeout the thread
                cannot be cancelled and keeps running in the background.
        """
        outcome: dict[str, Exception | None] = {"error": None}

        def _once() -> None:
            try:
                self.generate_wav(
                    instruction=(
                        "An adult male with an American accent, speaking at a normal pace "
                        "in a mid-range pitch with a calm, neutral tone."
                    ),
                    text="This is a warmup utterance for the voice engine.",
                )
            except Exception as exc:
                outcome["error"] = exc

        worker = threading.Thread(target=_once, daemon=True)
        worker.start()
        worker.join(timeout=timeout)
        # Decide timeout vs. failure from the thread state, not from whether
        # the exception message happens to be non-empty.
        if worker.is_alive():
            raise RuntimeError(f"warmup exceeded {timeout:g}s")
        error = outcome["error"]
        if error is not None:
            raise RuntimeError(str(error) or repr(error)) from error

    def _speech_tensor_to_numpy(self, speech: torch.Tensor) -> np.ndarray:
        """Convert a speech tensor to a 1-D float32 numpy array on the CPU.

        Leading size-1 axes are squeezed away (e.g. ``(1, 1, T)`` -> ``(T,)``).
        Any other shape, including scalars, is flattened with ``reshape(-1)``.
        """
        flat = speech.detach().cpu().float()
        # Only squeeze while the leading axis really is 1. squeeze(0) is a
        # no-op otherwise, and the old unconditional loop never terminated.
        while flat.dim() > 1 and flat.shape[0] == 1:
            flat = flat.squeeze(0)
        if flat.dim() != 1:
            flat = flat.reshape(-1)
        return flat.numpy().astype(np.float32, copy=False)

    def generate_wav(self, instruction: str, text: str) -> Tuple[np.ndarray, int]:
        """Synthesize *text* conditioned on *instruction*.

        Returns:
            ``(waveform, sample_rate)``: a 1-D float32 array and the
            processor's ``sampling_rate`` (22050 if the processor has none).

        Raises:
            RuntimeError: the model produced no speech output.
        """
        prompt = _build_vocence_prompt(instruction, text)
        inputs = self._processor(
            text=[prompt],
            voice_samples=None,
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        for key, value in inputs.items():
            if torch.is_tensor(value):
                inputs[key] = value.to(self._device)
        with torch.inference_mode():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=self._cfg_scale,
                tokenizer=self._processor.tokenizer,
                generation_config={"do_sample": False},
                verbose=False,
                is_prefill=not self._disable_prefill,
            )
        if not outputs.speech_outputs or outputs.speech_outputs[0] is None:
            raise RuntimeError("QWEN3Vox returned no speech output.")
        wav = self._speech_tensor_to_numpy(outputs.speech_outputs[0])
        return (wav, self._sample_rate)
|
|