Spaces:
Running
Running
masterr-main
#2
by
ohmp
- opened
- Dockerfile +0 -38
- README.md +8 -12
- __pycache__/app.cpython-312.pyc +0 -0
- app.py +0 -471
- download_models.py +0 -37
- pre-requirements.txt +0 -1
- qwen_asr/__init__.py +0 -25
- qwen_asr/__main__.py +0 -25
- qwen_asr/cli/demo.py +0 -523
- qwen_asr/cli/serve.py +0 -46
- qwen_asr/core/transformers_backend/__init__.py +0 -18
- qwen_asr/core/transformers_backend/configuration_qwen3_asr.py +0 -425
- qwen_asr/core/transformers_backend/modeling_qwen3_asr.py +0 -1361
- qwen_asr/core/transformers_backend/processing_qwen3_asr.py +0 -209
- qwen_asr/core/vllm_backend/__init__.py +0 -16
- qwen_asr/core/vllm_backend/qwen3_asr.py +0 -997
- qwen_asr/inference/assets/korean_dict_jieba.dict +0 -0
- qwen_asr/inference/qwen3_asr.py +0 -519
- qwen_asr/inference/qwen3_forced_aligner.py +0 -484
- qwen_asr/inference/utils.py +0 -497
- qwen_tts/__init__.py +0 -25
- qwen_tts/__main__.py +0 -24
- qwen_tts/cli/demo.py +0 -633
- qwen_tts/core/__init__.py +0 -19
- qwen_tts/core/models/__init__.py +0 -18
- qwen_tts/core/models/configuration_qwen3_tts.py +0 -502
- qwen_tts/core/models/modeling_qwen3_tts.py +0 -2246
- qwen_tts/core/models/processing_qwen3_tts.py +0 -106
- qwen_tts/core/tokenizer_12hz/configuration_qwen3_tts_tokenizer_v2.py +0 -172
- qwen_tts/core/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +0 -1025
- qwen_tts/core/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py +0 -332
- qwen_tts/core/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py +0 -1528
- qwen_tts/core/tokenizer_25hz/vq/assets/mel_filters.npz +0 -3
- qwen_tts/core/tokenizer_25hz/vq/core_vq.py +0 -523
- qwen_tts/core/tokenizer_25hz/vq/speech_vq.py +0 -357
- qwen_tts/core/tokenizer_25hz/vq/whisper_encoder.py +0 -406
- qwen_tts/inference/qwen3_tts_model.py +0 -874
- qwen_tts/inference/qwen3_tts_tokenizer.py +0 -411
- requirements.txt +0 -17
Dockerfile
DELETED
|
@@ -1,38 +0,0 @@
|
|
| 1 |
-
# Use Python 3.10 as the base image
|
| 2 |
-
FROM python:3.10-slim
|
| 3 |
-
|
| 4 |
-
# Set up a new user named "user" with user ID 1000
|
| 5 |
-
RUN useradd -m -u 1000 user
|
| 6 |
-
|
| 7 |
-
# Switch to the "user" user
|
| 8 |
-
USER user
|
| 9 |
-
|
| 10 |
-
# Set home and path environment variables
|
| 11 |
-
ENV HOME=/home/user \
|
| 12 |
-
PATH=/home/user/.local/bin:$PATH
|
| 13 |
-
|
| 14 |
-
# Set the working directory
|
| 15 |
-
WORKDIR $HOME/app
|
| 16 |
-
|
| 17 |
-
# Copy the requirements file first to leverage Docker cache
|
| 18 |
-
COPY --chown=user requirements.txt $HOME/app/requirements.txt
|
| 19 |
-
|
| 20 |
-
# Install dependencies
|
| 21 |
-
RUN pip install --no-cache-dir --upgrade pip && \
|
| 22 |
-
pip install --no-cache-dir -r requirements.txt
|
| 23 |
-
|
| 24 |
-
# Copy the rest of the application code
|
| 25 |
-
# Copy the rest of the application code
|
| 26 |
-
#COPY --chown=user download_models.py $HOME/app/download_models.py
|
| 27 |
-
|
| 28 |
-
# Download models during build
|
| 29 |
-
#RUN python download_models.py
|
| 30 |
-
|
| 31 |
-
COPY --chown=user . $HOME/app
|
| 32 |
-
|
| 33 |
-
# Expose the default port for Hugging Face Spaces
|
| 34 |
-
EXPOSE 7860
|
| 35 |
-
|
| 36 |
-
# Run the application
|
| 37 |
-
# We use python app.py since it calls graph.launch()
|
| 38 |
-
CMD ["python", "app.py"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,16 +1,12 @@
|
|
| 1 |
---
|
| 2 |
title: Ohm Audio Studio
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
| 8 |
-
pinned: true
|
| 9 |
license: apache-2.0
|
| 10 |
-
short_description:
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
-
|
| 14 |
-
- asr
|
| 15 |
-
- qwen
|
| 16 |
-
---
|
|
|
|
| 1 |
---
|
| 2 |
title: Ohm Audio Studio
|
| 3 |
+
emoji: 🏆
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
+
pinned: false
|
|
|
|
| 8 |
license: apache-2.0
|
| 9 |
+
short_description: Daggr+Qwen3 ASR + Qwen3 TTS
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
__pycache__/app.cpython-312.pyc
DELETED
|
Binary file (20.5 kB)
|
|
|
app.py
DELETED
|
@@ -1,471 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Ohm Audio Studio
|
| 3 |
-
======================
|
| 4 |
-
|
| 5 |
-
A professional interface for Qwen2-Audio ASR and TTS models.
|
| 6 |
-
This application uses Daggr and Gradio to provide a seamless user experience
|
| 7 |
-
for Voice Design, Voice Cloning, Custom Voice Synthesis, and Automatic Speech Recognition.
|
| 8 |
-
|
| 9 |
-
Author: Ohm
|
| 10 |
-
Date: 2026
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
import os
|
| 14 |
-
import gc
|
| 15 |
-
import base64
|
| 16 |
-
import io
|
| 17 |
-
import logging
|
| 18 |
-
import numpy as np
|
| 19 |
-
import torch
|
| 20 |
-
import torchaudio
|
| 21 |
-
import soundfile as sf
|
| 22 |
-
import gradio as gr
|
| 23 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 24 |
-
from dataclasses import dataclass
|
| 25 |
-
|
| 26 |
-
from huggingface_hub import snapshot_download, login
|
| 27 |
-
from daggr import FnNode, Graph
|
| 28 |
-
|
| 29 |
-
# Configure Logging
|
| 30 |
-
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 31 |
-
logger = logging.getLogger(__name__)
|
| 32 |
-
|
| 33 |
-
# --- Configuration ---
|
| 34 |
-
|
| 35 |
-
@dataclass
|
| 36 |
-
class AppConfig:
|
| 37 |
-
HF_TOKEN: Optional[str] = os.environ.get('HF_TOKEN')
|
| 38 |
-
OMP_NUM_THREADS: str = "1"
|
| 39 |
-
|
| 40 |
-
MODEL_SIZES = ["0.6B", "1.7B"]
|
| 41 |
-
SPEAKERS = [
|
| 42 |
-
"Aiden", "Dylan", "Eric", "Ono_anna", "Ryan", "Serena", "Sohee", "Uncle_fu", "Vivian"
|
| 43 |
-
]
|
| 44 |
-
|
| 45 |
-
TTS_LANGUAGES = [
|
| 46 |
-
"Auto", "English", "Japanese", "Korean", "French", "German",
|
| 47 |
-
"Spanish", "Portuguese", "Russian"
|
| 48 |
-
]
|
| 49 |
-
|
| 50 |
-
ASR_SUPPORTED_LANGUAGES = [
|
| 51 |
-
"English", "Arabic", "German", "French", "Spanish", "Portuguese",
|
| 52 |
-
"Indonesian", "Italian", "Korean", "Russian", "Thai", "Vietnamese",
|
| 53 |
-
"Japanese", "Turkish", "Hindi", "Malay", "Dutch", "Swedish", "Danish",
|
| 54 |
-
"Finnish", "Polish", "Czech", "Filipino", "Persian", "Greek",
|
| 55 |
-
"Romanian", "Hungarian", "Macedonian"
|
| 56 |
-
]
|
| 57 |
-
|
| 58 |
-
# Apply Environment Variables
|
| 59 |
-
os.environ["OMP_NUM_THREADS"] = AppConfig.OMP_NUM_THREADS
|
| 60 |
-
if AppConfig.HF_TOKEN:
|
| 61 |
-
login(token=AppConfig.HF_TOKEN)
|
| 62 |
-
|
| 63 |
-
# --- Utilities ---
|
| 64 |
-
|
| 65 |
-
class AudioUtils:
|
| 66 |
-
"""Utilities for audio processing and normalization."""
|
| 67 |
-
|
| 68 |
-
@staticmethod
|
| 69 |
-
def title_case_display(s: str) -> str:
|
| 70 |
-
s = (s or "").strip()
|
| 71 |
-
s = s.replace("_", " ")
|
| 72 |
-
return " ".join([w[:1].upper() + w[1:] if w else "" for w in s.split()])
|
| 73 |
-
|
| 74 |
-
@staticmethod
|
| 75 |
-
def build_choices_and_map(items: Optional[List[str]]) -> Tuple[List[str], Dict[str, str]]:
|
| 76 |
-
if not items:
|
| 77 |
-
return [], {}
|
| 78 |
-
display = [AudioUtils.title_case_display(x) for x in items]
|
| 79 |
-
mapping = {d: r for d, r in zip(display, items)}
|
| 80 |
-
return display, mapping
|
| 81 |
-
|
| 82 |
-
@staticmethod
|
| 83 |
-
def normalize_audio(wav: np.ndarray, eps: float = 1e-12, clip: bool = True) -> np.float32:
|
| 84 |
-
"""Normalize audio to float32 in [-1, 1] range."""
|
| 85 |
-
x = np.asarray(wav)
|
| 86 |
-
|
| 87 |
-
if np.issubdtype(x.dtype, np.integer):
|
| 88 |
-
info = np.iinfo(x.dtype)
|
| 89 |
-
if info.min < 0:
|
| 90 |
-
y = x.astype(np.float32) / max(abs(info.min), info.max)
|
| 91 |
-
else:
|
| 92 |
-
mid = (info.max + 1) / 2.0
|
| 93 |
-
y = (x.astype(np.float32) - mid) / mid
|
| 94 |
-
elif np.issubdtype(x.dtype, np.floating):
|
| 95 |
-
y = x.astype(np.float32)
|
| 96 |
-
m = np.max(np.abs(y)) if y.size else 0.0
|
| 97 |
-
if m > 1.0 + 1e-6:
|
| 98 |
-
y = y / (m + eps)
|
| 99 |
-
else:
|
| 100 |
-
y = x.astype(np.float32)
|
| 101 |
-
|
| 102 |
-
if clip:
|
| 103 |
-
y = np.clip(y, -1.0, 1.0)
|
| 104 |
-
|
| 105 |
-
if y.ndim > 1:
|
| 106 |
-
y = np.mean(y, axis=-1).astype(np.float32)
|
| 107 |
-
|
| 108 |
-
return y
|
| 109 |
-
|
| 110 |
-
@staticmethod
|
| 111 |
-
def process_input(audio_input: Any) -> Optional[Tuple[np.float32, int]]:
|
| 112 |
-
"""
|
| 113 |
-
Handles Filepaths, Data URIs (base64), and Numpy arrays.
|
| 114 |
-
Returns (numpy_float32, sample_rate_int)
|
| 115 |
-
"""
|
| 116 |
-
if audio_input is None:
|
| 117 |
-
return None
|
| 118 |
-
|
| 119 |
-
try:
|
| 120 |
-
# Handle Path or Base64
|
| 121 |
-
if isinstance(audio_input, str):
|
| 122 |
-
if audio_input.startswith("data:"):
|
| 123 |
-
try:
|
| 124 |
-
header, encoded = audio_input.split(",", 1)
|
| 125 |
-
data = base64.b64decode(encoded)
|
| 126 |
-
wav, sr = sf.read(io.BytesIO(data))
|
| 127 |
-
return AudioUtils.normalize_audio(wav), int(sr)
|
| 128 |
-
except Exception as e:
|
| 129 |
-
logger.error(f"Failed to decode base64 audio: {e}")
|
| 130 |
-
return None
|
| 131 |
-
|
| 132 |
-
if os.path.exists(audio_input):
|
| 133 |
-
wav_tensor, sr = torchaudio.load(audio_input)
|
| 134 |
-
wav = wav_tensor.mean(dim=0).numpy()
|
| 135 |
-
return AudioUtils.normalize_audio(wav), int(sr)
|
| 136 |
-
else:
|
| 137 |
-
logger.error(f"Input string is not a file or valid data URI: {audio_input[:50]}...")
|
| 138 |
-
return None
|
| 139 |
-
|
| 140 |
-
# Handle Tuple (sample_rate, data) or (data, sample_rate)
|
| 141 |
-
if isinstance(audio_input, tuple) and len(audio_input) == 2:
|
| 142 |
-
a0, a1 = audio_input
|
| 143 |
-
if isinstance(a0, int):
|
| 144 |
-
return AudioUtils.normalize_audio(a1), int(a0)
|
| 145 |
-
else:
|
| 146 |
-
return AudioUtils.normalize_audio(a0), int(a1)
|
| 147 |
-
|
| 148 |
-
# Handle Dictionary
|
| 149 |
-
if isinstance(audio_input, dict):
|
| 150 |
-
if "name" in audio_input:
|
| 151 |
-
return AudioUtils.process_input(audio_input["name"])
|
| 152 |
-
if "path" in audio_input:
|
| 153 |
-
return AudioUtils.process_input(audio_input["path"])
|
| 154 |
-
if "sampling_rate" in audio_input and "data" in audio_input:
|
| 155 |
-
return AudioUtils.normalize_audio(audio_input["data"]), int(audio_input["sampling_rate"])
|
| 156 |
-
|
| 157 |
-
return None
|
| 158 |
-
|
| 159 |
-
except Exception as e:
|
| 160 |
-
logger.error(f"Audio Processing Error: {e}")
|
| 161 |
-
return None
|
| 162 |
-
|
| 163 |
-
# --- Model Management ---
|
| 164 |
-
|
| 165 |
-
class ModelManager:
|
| 166 |
-
"""Manages loading and unloading of AI models."""
|
| 167 |
-
|
| 168 |
-
def __init__(self):
|
| 169 |
-
self._loaded_models = {}
|
| 170 |
-
|
| 171 |
-
def _get_model_path(self, model_type: str, model_size: str) -> str:
|
| 172 |
-
"""Download/Get model path based on type and size."""
|
| 173 |
-
if model_type == "ASR":
|
| 174 |
-
return "Qwen/Qwen3-ASR-1.7B"
|
| 175 |
-
return snapshot_download(f"Qwen/Qwen3-TTS-12Hz-{model_size}-{model_type}")
|
| 176 |
-
|
| 177 |
-
def get_model(self, model_type: str, model_size: str):
|
| 178 |
-
"""
|
| 179 |
-
Lazy load models. Unloads previous models if VRAM is tight.
|
| 180 |
-
"""
|
| 181 |
-
key = (model_type, model_size)
|
| 182 |
-
|
| 183 |
-
if key not in self._loaded_models:
|
| 184 |
-
logger.info(f"Clearing Cache before loading {model_type}...")
|
| 185 |
-
self._loaded_models.clear()
|
| 186 |
-
gc.collect()
|
| 187 |
-
if torch.cuda.is_available():
|
| 188 |
-
torch.cuda.empty_cache()
|
| 189 |
-
|
| 190 |
-
logger.info(f"Loading Model: {model_type} {model_size}...")
|
| 191 |
-
|
| 192 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 193 |
-
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
| 194 |
-
|
| 195 |
-
if model_type == "ASR":
|
| 196 |
-
from qwen_asr import Qwen3ASRModel
|
| 197 |
-
self._loaded_models[key] = Qwen3ASRModel.from_pretrained(
|
| 198 |
-
"Qwen/Qwen3-ASR-1.7B",
|
| 199 |
-
dtype=dtype,
|
| 200 |
-
device_map=device,
|
| 201 |
-
forced_aligner="Qwen/Qwen3-ForcedAligner-0.6B",
|
| 202 |
-
forced_aligner_kwargs=dict(dtype=dtype, device_map=device),
|
| 203 |
-
max_inference_batch_size=4,
|
| 204 |
-
attn_implementation="sdpa",
|
| 205 |
-
)
|
| 206 |
-
else:
|
| 207 |
-
from qwen_tts import Qwen3TTSModel
|
| 208 |
-
model_path = self._get_model_path(model_type, model_size)
|
| 209 |
-
self._loaded_models[key] = Qwen3TTSModel.from_pretrained(
|
| 210 |
-
model_path,
|
| 211 |
-
device_map=device,
|
| 212 |
-
dtype=dtype,
|
| 213 |
-
token=AppConfig.HF_TOKEN,
|
| 214 |
-
)
|
| 215 |
-
|
| 216 |
-
return self._loaded_models[key]
|
| 217 |
-
|
| 218 |
-
# --- Core Service ---
|
| 219 |
-
|
| 220 |
-
class QwenService:
|
| 221 |
-
"""Core service logic connecting the ModelManager and AudioUtils."""
|
| 222 |
-
|
| 223 |
-
def __init__(self):
|
| 224 |
-
self.models = ModelManager()
|
| 225 |
-
|
| 226 |
-
def _cleanup_resources(self):
|
| 227 |
-
gc.collect()
|
| 228 |
-
if torch.cuda.is_available():
|
| 229 |
-
torch.cuda.empty_cache()
|
| 230 |
-
|
| 231 |
-
def voice_design(self, text, language, voice_description):
|
| 232 |
-
"""Voice Design (Prompt-to-Speech)"""
|
| 233 |
-
self._cleanup_resources()
|
| 234 |
-
|
| 235 |
-
if not text: return None, "Text required"
|
| 236 |
-
if not voice_description: return None, "Description required"
|
| 237 |
-
|
| 238 |
-
try:
|
| 239 |
-
tts = self.models.get_model("VoiceDesign", "1.7B")
|
| 240 |
-
|
| 241 |
-
wavs, sr = tts.generate_voice_design(
|
| 242 |
-
text=text.strip(),
|
| 243 |
-
language=language,
|
| 244 |
-
instruct=voice_description.strip(),
|
| 245 |
-
non_streaming_mode=True,
|
| 246 |
-
max_new_tokens=2048,
|
| 247 |
-
)
|
| 248 |
-
return (sr, wavs[0]), "Success"
|
| 249 |
-
except Exception as e:
|
| 250 |
-
logger.exception("Voice Design Error")
|
| 251 |
-
return None, f"Error: {str(e)}"
|
| 252 |
-
|
| 253 |
-
def voice_clone(self, ref_audio, ref_text, target_text, language, use_xvector_only, model_size):
|
| 254 |
-
"""Voice Cloning (Zero-Shot)"""
|
| 255 |
-
self._cleanup_resources()
|
| 256 |
-
|
| 257 |
-
if not target_text: return None, "Target text required"
|
| 258 |
-
|
| 259 |
-
audio_tuple = AudioUtils.process_input(ref_audio)
|
| 260 |
-
if audio_tuple is None:
|
| 261 |
-
return None, "Error: Could not process reference audio. Please upload a valid WAV/MP3."
|
| 262 |
-
|
| 263 |
-
if not use_xvector_only and not ref_text:
|
| 264 |
-
return None, "Error: Reference text required (or check 'Use x-vector only')"
|
| 265 |
-
|
| 266 |
-
try:
|
| 267 |
-
tts = self.models.get_model("Base", model_size)
|
| 268 |
-
|
| 269 |
-
wavs, sr = tts.generate_voice_clone(
|
| 270 |
-
text=target_text.strip(),
|
| 271 |
-
language=language,
|
| 272 |
-
ref_audio=audio_tuple,
|
| 273 |
-
ref_text=ref_text.strip() if ref_text else None,
|
| 274 |
-
x_vector_only_mode=use_xvector_only,
|
| 275 |
-
max_new_tokens=2048,
|
| 276 |
-
)
|
| 277 |
-
return (sr, wavs[0]), "Success"
|
| 278 |
-
except Exception as e:
|
| 279 |
-
logger.exception("Voice Clone Error")
|
| 280 |
-
return None, f"Error: {str(e)}"
|
| 281 |
-
|
| 282 |
-
def custom_voice(self, text, language, speaker, instruct, model_size):
|
| 283 |
-
"""Standard TTS"""
|
| 284 |
-
self._cleanup_resources()
|
| 285 |
-
|
| 286 |
-
if not text: return None, "Text required"
|
| 287 |
-
|
| 288 |
-
try:
|
| 289 |
-
tts = self.models.get_model("CustomVoice", model_size)
|
| 290 |
-
|
| 291 |
-
wavs, sr = tts.generate_custom_voice(
|
| 292 |
-
text=text.strip(),
|
| 293 |
-
language=language,
|
| 294 |
-
speaker=speaker.lower().replace(" ", "_"),
|
| 295 |
-
instruct=instruct.strip() if instruct else None,
|
| 296 |
-
non_streaming_mode=True,
|
| 297 |
-
max_new_tokens=2048,
|
| 298 |
-
)
|
| 299 |
-
return (sr, wavs[0]), "Success"
|
| 300 |
-
except Exception as e:
|
| 301 |
-
logger.exception("Custom Voice Error")
|
| 302 |
-
return None, f"Error: {str(e)}"
|
| 303 |
-
|
| 304 |
-
def asr(self, audio_upload, lang_disp):
|
| 305 |
-
"""Automatic Speech Recognition"""
|
| 306 |
-
self._cleanup_resources()
|
| 307 |
-
|
| 308 |
-
if audio_upload is None:
|
| 309 |
-
return "", "", "No Audio"
|
| 310 |
-
|
| 311 |
-
processed_audio = AudioUtils.process_input(audio_upload)
|
| 312 |
-
if processed_audio is None:
|
| 313 |
-
return "", "", "Error processing audio"
|
| 314 |
-
|
| 315 |
-
language = None
|
| 316 |
-
if lang_disp and lang_disp != "Auto":
|
| 317 |
-
# Assuming ASR_LANG_MAP is globally available or we rebuild it
|
| 318 |
-
# For efficiency let's reuse if possible, or rebuild locally
|
| 319 |
-
_, mapping = AudioUtils.build_choices_and_map(AppConfig.ASR_SUPPORTED_LANGUAGES)
|
| 320 |
-
language = mapping.get(lang_disp, lang_disp)
|
| 321 |
-
|
| 322 |
-
try:
|
| 323 |
-
asr_model = self.models.get_model("ASR", "1.7B")
|
| 324 |
-
|
| 325 |
-
results = asr_model.transcribe(
|
| 326 |
-
audio=processed_audio,
|
| 327 |
-
language=language,
|
| 328 |
-
return_time_stamps=False,
|
| 329 |
-
)
|
| 330 |
-
|
| 331 |
-
if not isinstance(results, list) or len(results) != 1:
|
| 332 |
-
return "", "", "Unexpected result format"
|
| 333 |
-
|
| 334 |
-
r = results[0]
|
| 335 |
-
detected_lang = getattr(r, "language", "") or ""
|
| 336 |
-
transcribed_text = getattr(r, "text", "") or ""
|
| 337 |
-
|
| 338 |
-
return detected_lang, transcribed_text, "Success"
|
| 339 |
-
|
| 340 |
-
except Exception as e:
|
| 341 |
-
logger.exception("ASR Error")
|
| 342 |
-
return "", "", f"Error: {str(e)}"
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
# --- Graph Construction ---
|
| 346 |
-
|
| 347 |
-
# Initialize Service
|
| 348 |
-
service = QwenService()
|
| 349 |
-
ASR_LANG_DISPLAY, _ = AudioUtils.build_choices_and_map(AppConfig.ASR_SUPPORTED_LANGUAGES)
|
| 350 |
-
ASR_LANG_CHOICES = ["Auto"] + ASR_LANG_DISPLAY
|
| 351 |
-
|
| 352 |
-
# Define Nodes
|
| 353 |
-
voice_design_node = FnNode(
|
| 354 |
-
fn=service.voice_design,
|
| 355 |
-
inputs={
|
| 356 |
-
"text": gr.Textbox(
|
| 357 |
-
label="Text to Synthesize (Voice Design)",
|
| 358 |
-
lines=4,
|
| 359 |
-
value="Welcome to Ohm Audio Studio. Experience the future of voice design."
|
| 360 |
-
),
|
| 361 |
-
"language": gr.Dropdown(
|
| 362 |
-
label="Language (Voice Design)",
|
| 363 |
-
choices=AppConfig.TTS_LANGUAGES,
|
| 364 |
-
value="Auto"
|
| 365 |
-
),
|
| 366 |
-
"voice_description": gr.Textbox(
|
| 367 |
-
label="Voice Description (Voice Design)",
|
| 368 |
-
lines=3,
|
| 369 |
-
value="A professional, warm and inviting voice with a clear, confident tone."
|
| 370 |
-
),
|
| 371 |
-
},
|
| 372 |
-
outputs={
|
| 373 |
-
"generated_audio": gr.Audio(label="Generated Audio", type="numpy"),
|
| 374 |
-
"status": gr.Textbox(label="Status", interactive=False),
|
| 375 |
-
},
|
| 376 |
-
name="Voice Design"
|
| 377 |
-
)
|
| 378 |
-
|
| 379 |
-
custom_voice_node = FnNode(
|
| 380 |
-
fn=service.custom_voice,
|
| 381 |
-
inputs={
|
| 382 |
-
"text": gr.Textbox(
|
| 383 |
-
label="Text to Synthesize (Custom Voice)",
|
| 384 |
-
lines=4,
|
| 385 |
-
value="Welcome to Ohm Audio Studio coverage of the latest in AI audio technology."
|
| 386 |
-
),
|
| 387 |
-
"language": gr.Dropdown(
|
| 388 |
-
label="Language (Custom Voice)",
|
| 389 |
-
choices=AppConfig.TTS_LANGUAGES,
|
| 390 |
-
value="English"
|
| 391 |
-
),
|
| 392 |
-
"speaker": gr.Dropdown(
|
| 393 |
-
label="Speaker (Custom Voice)",
|
| 394 |
-
choices=AppConfig.SPEAKERS,
|
| 395 |
-
value="Ryan"
|
| 396 |
-
),
|
| 397 |
-
"instruct": gr.Textbox(
|
| 398 |
-
label="Style Instruction (Custom Voice)",
|
| 399 |
-
lines=2,
|
| 400 |
-
placeholder="e.g. Happy, Sad",
|
| 401 |
-
value="Neutral"
|
| 402 |
-
),
|
| 403 |
-
"model_size": gr.Dropdown(
|
| 404 |
-
label="Model Size (Custom Voice)",
|
| 405 |
-
choices=AppConfig.MODEL_SIZES,
|
| 406 |
-
value="1.7B"
|
| 407 |
-
),
|
| 408 |
-
},
|
| 409 |
-
outputs={
|
| 410 |
-
"tts_audio": gr.Audio(label="Generated Audio", type="numpy"),
|
| 411 |
-
"status": gr.Textbox(label="Status", interactive=False),
|
| 412 |
-
},
|
| 413 |
-
name="Custom Voice"
|
| 414 |
-
)
|
| 415 |
-
|
| 416 |
-
voice_clone_node = FnNode(
|
| 417 |
-
fn=service.voice_clone,
|
| 418 |
-
inputs={
|
| 419 |
-
"ref_audio": gr.Audio(label="Reference Audio (Voice Clone)", type="filepath"),
|
| 420 |
-
"ref_text": gr.Textbox(label="Reference Transcript (Voice Clone)", lines=2),
|
| 421 |
-
"target_text": gr.Textbox(label="Target Text (Voice Clone)", lines=4),
|
| 422 |
-
"language": gr.Dropdown(
|
| 423 |
-
label="Language (Voice Clone)",
|
| 424 |
-
choices=AppConfig.TTS_LANGUAGES,
|
| 425 |
-
value="Auto"
|
| 426 |
-
),
|
| 427 |
-
"use_xvector_only": gr.Checkbox(label="Use x-vector only (Voice Clone)", value=False),
|
| 428 |
-
"model_size": gr.Dropdown(
|
| 429 |
-
label="Model Size (Voice Clone)",
|
| 430 |
-
choices=AppConfig.MODEL_SIZES,
|
| 431 |
-
value="1.7B"
|
| 432 |
-
),
|
| 433 |
-
},
|
| 434 |
-
outputs={
|
| 435 |
-
"cloned_audio": gr.Audio(label="Cloned Audio", type="numpy"),
|
| 436 |
-
"status": gr.Textbox(label="Status", interactive=False),
|
| 437 |
-
},
|
| 438 |
-
name="Voice Clone"
|
| 439 |
-
)
|
| 440 |
-
|
| 441 |
-
asr_node = FnNode(
|
| 442 |
-
fn=service.asr,
|
| 443 |
-
inputs={
|
| 444 |
-
"audio_upload": gr.Audio(
|
| 445 |
-
label="Upload Audio (Qwen3 ASR)",
|
| 446 |
-
type="numpy",
|
| 447 |
-
sources=["upload", "microphone"]
|
| 448 |
-
),
|
| 449 |
-
"lang_disp": gr.Dropdown(
|
| 450 |
-
label="Language (Qwen3 ASR)",
|
| 451 |
-
choices=ASR_LANG_CHOICES,
|
| 452 |
-
value="Auto"
|
| 453 |
-
),
|
| 454 |
-
},
|
| 455 |
-
outputs={
|
| 456 |
-
"detected_lang": gr.Textbox(label="Detected Language", interactive=False),
|
| 457 |
-
"transcription": gr.Textbox(label="Transcription Result", lines=6, interactive=True),
|
| 458 |
-
"status": gr.Textbox(label="Status", interactive=False),
|
| 459 |
-
},
|
| 460 |
-
name="Qwen3 ASR"
|
| 461 |
-
)
|
| 462 |
-
|
| 463 |
-
# Create and Launch Graph
|
| 464 |
-
graph = Graph(
|
| 465 |
-
name="Ohm-Audio-Studio",
|
| 466 |
-
nodes=[voice_design_node, custom_voice_node, voice_clone_node, asr_node]
|
| 467 |
-
)
|
| 468 |
-
|
| 469 |
-
if __name__ == "__main__":
|
| 470 |
-
port = int(os.environ.get("PORT", 7860))
|
| 471 |
-
graph.launch(host="0.0.0.0", port=port)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
download_models.py
DELETED
|
@@ -1,37 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import os
|
| 3 |
-
from huggingface_hub import snapshot_download
|
| 4 |
-
|
| 5 |
-
# List of models used in the app
|
| 6 |
-
models_to_download = [
|
| 7 |
-
# ASR Models
|
| 8 |
-
"Qwen/Qwen3-ASR-1.7B",
|
| 9 |
-
"Qwen/Qwen3-ForcedAligner-0.6B",
|
| 10 |
-
|
| 11 |
-
# TTS Models (1.7B versions as they are the default/main ones used)
|
| 12 |
-
"Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
|
| 13 |
-
"Qwen/Qwen3-TTS-12Hz-1.7B-Base",
|
| 14 |
-
"Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
|
| 15 |
-
|
| 16 |
-
# Uncomment if you plan to use 0.6B TTS models
|
| 17 |
-
# "Qwen/Qwen3-TTS-12Hz-0.6B-VoiceDesign",
|
| 18 |
-
# "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
|
| 19 |
-
# "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
|
| 20 |
-
]
|
| 21 |
-
|
| 22 |
-
print(f"Starting download of {len(models_to_download)} models...")
|
| 23 |
-
|
| 24 |
-
for model_id in models_to_download:
|
| 25 |
-
print(f"\nDownloading: {model_id}")
|
| 26 |
-
try:
|
| 27 |
-
# If HF_TOKEN is empty/None, treat as anonymous (public models only)
|
| 28 |
-
token = os.environ.get('HF_TOKEN')
|
| 29 |
-
if token and not token.strip():
|
| 30 |
-
token = None
|
| 31 |
-
|
| 32 |
-
path = snapshot_download(repo_id=model_id, token=token)
|
| 33 |
-
print(f"Successfully downloaded to: {path}")
|
| 34 |
-
except Exception as e:
|
| 35 |
-
print(f"Failed to download {model_id}: {e}")
|
| 36 |
-
|
| 37 |
-
print("\nAll downloads complete!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pre-requirements.txt
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
pip>=23.0.0
|
|
|
|
|
|
qwen_asr/__init__.py
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
"""
|
| 17 |
-
qwen_asr: Qwen3-ASR package.
|
| 18 |
-
"""
|
| 19 |
-
|
| 20 |
-
from .inference.qwen3_asr import Qwen3ASRModel
|
| 21 |
-
from .inference.qwen3_forced_aligner import Qwen3ForcedAligner
|
| 22 |
-
|
| 23 |
-
from .inference.utils import parse_asr_output
|
| 24 |
-
|
| 25 |
-
__all__ = ["__version__"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_asr/__main__.py
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
def main():
|
| 17 |
-
print(
|
| 18 |
-
"qwen_asr package.\n"
|
| 19 |
-
"Use CLI entrypoints:\n"
|
| 20 |
-
" - qwen-asr-demo\n"
|
| 21 |
-
" - qwen-asr-serve\n"
|
| 22 |
-
)
|
| 23 |
-
|
| 24 |
-
if __name__ == "__main__":
|
| 25 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_asr/cli/demo.py
DELETED
|
@@ -1,523 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
"""
|
| 5 |
-
A gradio demo for Qwen3 ASR models.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import argparse
|
| 9 |
-
import base64
|
| 10 |
-
import io
|
| 11 |
-
import json
|
| 12 |
-
import os
|
| 13 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 14 |
-
|
| 15 |
-
import gradio as gr
|
| 16 |
-
import numpy as np
|
| 17 |
-
import torch
|
| 18 |
-
from qwen_asr import Qwen3ASRModel
|
| 19 |
-
from scipy.io.wavfile import write as wav_write
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def _title_case_display(s: str) -> str:
|
| 23 |
-
s = (s or "").strip()
|
| 24 |
-
s = s.replace("_", " ")
|
| 25 |
-
return " ".join([w[:1].upper() + w[1:] if w else "" for w in s.split()])
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def _build_choices_and_map(items: Optional[List[str]]) -> Tuple[List[str], Dict[str, str]]:
|
| 29 |
-
if not items:
|
| 30 |
-
return [], {}
|
| 31 |
-
display = [_title_case_display(x) for x in items]
|
| 32 |
-
mapping = {d: r for d, r in zip(display, items)}
|
| 33 |
-
return display, mapping
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def _dtype_from_str(s: str) -> torch.dtype:
|
| 37 |
-
s = (s or "").strip().lower()
|
| 38 |
-
if s in ("bf16", "bfloat16"):
|
| 39 |
-
return torch.bfloat16
|
| 40 |
-
if s in ("fp16", "float16", "half"):
|
| 41 |
-
return torch.float16
|
| 42 |
-
if s in ("fp32", "float32"):
|
| 43 |
-
return torch.float32
|
| 44 |
-
raise ValueError(f"Unsupported torch dtype: {s}. Use bfloat16/float16/float32.")
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def _normalize_audio(wav, eps=1e-12, clip=True):
|
| 48 |
-
x = np.asarray(wav)
|
| 49 |
-
|
| 50 |
-
if np.issubdtype(x.dtype, np.integer):
|
| 51 |
-
info = np.iinfo(x.dtype)
|
| 52 |
-
if info.min < 0:
|
| 53 |
-
y = x.astype(np.float32) / max(abs(info.min), info.max)
|
| 54 |
-
else:
|
| 55 |
-
mid = (info.max + 1) / 2.0
|
| 56 |
-
y = (x.astype(np.float32) - mid) / mid
|
| 57 |
-
elif np.issubdtype(x.dtype, np.floating):
|
| 58 |
-
y = x.astype(np.float32)
|
| 59 |
-
m = np.max(np.abs(y)) if y.size else 0.0
|
| 60 |
-
if m > 1.0 + 1e-6:
|
| 61 |
-
y = y / (m + eps)
|
| 62 |
-
else:
|
| 63 |
-
raise TypeError(f"Unsupported dtype: {x.dtype}")
|
| 64 |
-
|
| 65 |
-
if clip:
|
| 66 |
-
y = np.clip(y, -1.0, 1.0)
|
| 67 |
-
|
| 68 |
-
if y.ndim > 1:
|
| 69 |
-
y = np.mean(y, axis=-1).astype(np.float32)
|
| 70 |
-
|
| 71 |
-
return y
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def _audio_to_tuple(audio: Any) -> Optional[Tuple[np.ndarray, int]]:
|
| 75 |
-
"""
|
| 76 |
-
Accept gradio audio:
|
| 77 |
-
- {"sampling_rate": int, "data": np.ndarray}
|
| 78 |
-
- (sr, np.ndarray) [some gradio versions]
|
| 79 |
-
Return: (wav_float32_mono, sr)
|
| 80 |
-
"""
|
| 81 |
-
if audio is None:
|
| 82 |
-
return None
|
| 83 |
-
|
| 84 |
-
if isinstance(audio, dict) and "sampling_rate" in audio and "data" in audio:
|
| 85 |
-
sr = int(audio["sampling_rate"])
|
| 86 |
-
wav = _normalize_audio(audio["data"])
|
| 87 |
-
return wav, sr
|
| 88 |
-
|
| 89 |
-
if isinstance(audio, tuple) and len(audio) == 2:
|
| 90 |
-
a0, a1 = audio
|
| 91 |
-
if isinstance(a0, int):
|
| 92 |
-
sr = int(a0)
|
| 93 |
-
wav = _normalize_audio(a1)
|
| 94 |
-
return wav, sr
|
| 95 |
-
if isinstance(a1, int):
|
| 96 |
-
wav = _normalize_audio(a0)
|
| 97 |
-
sr = int(a1)
|
| 98 |
-
return wav, sr
|
| 99 |
-
|
| 100 |
-
return None
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def _parse_audio_any(audio: Any) -> Union[str, Tuple[np.ndarray, int]]:
|
| 104 |
-
if audio is None:
|
| 105 |
-
raise ValueError("Audio is required.")
|
| 106 |
-
at = _audio_to_tuple(audio)
|
| 107 |
-
if at is not None:
|
| 108 |
-
return at
|
| 109 |
-
raise ValueError("Unsupported audio input format.")
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
def build_parser() -> argparse.ArgumentParser:
|
| 113 |
-
parser = argparse.ArgumentParser(
|
| 114 |
-
prog="qwen-asr-demo",
|
| 115 |
-
description=(
|
| 116 |
-
"Launch a Gradio demo for Qwen3 ASR models (Transformers / vLLM).\n\n"
|
| 117 |
-
"Examples:\n"
|
| 118 |
-
" qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B\n"
|
| 119 |
-
" qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B --aligner-checkpoint Qwen/Qwen3-ForcedAligner-0.6B\n"
|
| 120 |
-
" qwen-asr-demo --backend vllm --cuda-visible-devices 0\n"
|
| 121 |
-
" qwen-asr-demo --backend transformers --backend-kwargs '{\"device_map\":\"cuda:0\",\"dtype\":\"bfloat16\",\"attn_implementation\":\"flash_attention_2\"}'\n"
|
| 122 |
-
" qwen-asr-demo --backend vllm --backend-kwargs '{\"gpu_memory_utilization\":0.85}'\n"
|
| 123 |
-
),
|
| 124 |
-
formatter_class=argparse.RawTextHelpFormatter,
|
| 125 |
-
add_help=True,
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
parser.add_argument("--asr-checkpoint", required=True, help="Qwen3-ASR model checkpoint path or HF repo id.")
|
| 129 |
-
parser.add_argument(
|
| 130 |
-
"--aligner-checkpoint",
|
| 131 |
-
default=None,
|
| 132 |
-
help="Qwen3-ForcedAligner checkpoint path or HF repo id (optional; enables timestamps when provided).",
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
parser.add_argument(
|
| 136 |
-
"--backend",
|
| 137 |
-
default="transformers",
|
| 138 |
-
choices=["transformers", "vllm"],
|
| 139 |
-
help="Backend for ASR model loading (default: transformers).",
|
| 140 |
-
)
|
| 141 |
-
|
| 142 |
-
parser.add_argument(
|
| 143 |
-
"--cuda-visible-devices",
|
| 144 |
-
default="0",
|
| 145 |
-
help=(
|
| 146 |
-
"Set CUDA_VISIBLE_DEVICES for the demo process (default: 0). "
|
| 147 |
-
"Use e.g. '0' or '1'"
|
| 148 |
-
),
|
| 149 |
-
)
|
| 150 |
-
|
| 151 |
-
parser.add_argument(
|
| 152 |
-
"--backend-kwargs",
|
| 153 |
-
default=None,
|
| 154 |
-
help=(
|
| 155 |
-
"JSON dict for backend-specific kwargs excluding checkpoints.\n"
|
| 156 |
-
"Examples:\n"
|
| 157 |
-
" transformers: '{\"device_map\":\"cuda:0\",\"dtype\":\"bfloat16\",\"attn_implementation\":\"flash_attention_2\",\"max_inference_batch_size\":32}'\n"
|
| 158 |
-
" vllm : '{\"gpu_memory_utilization\":0.8,\"max_inference_batch_size\":32}'\n"
|
| 159 |
-
),
|
| 160 |
-
)
|
| 161 |
-
parser.add_argument(
|
| 162 |
-
"--aligner-kwargs",
|
| 163 |
-
default=None,
|
| 164 |
-
help=(
|
| 165 |
-
"JSON dict for forced aligner kwargs (only used when --aligner-checkpoint is set).\n"
|
| 166 |
-
"Example: '{\"dtype\":\"bfloat16\",\"device_map\":\"cuda:0\"}'\n"
|
| 167 |
-
),
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
# Gradio server args
|
| 171 |
-
parser.add_argument("--ip", default="0.0.0.0", help="Server bind IP for Gradio (default: 0.0.0.0).")
|
| 172 |
-
parser.add_argument("--port", type=int, default=8000, help="Server port for Gradio (default: 8000).")
|
| 173 |
-
parser.add_argument(
|
| 174 |
-
"--share/--no-share",
|
| 175 |
-
dest="share",
|
| 176 |
-
default=False,
|
| 177 |
-
action=argparse.BooleanOptionalAction,
|
| 178 |
-
help="Whether to create a public Gradio link (default: disabled).",
|
| 179 |
-
)
|
| 180 |
-
parser.add_argument("--concurrency", type=int, default=16, help="Gradio queue concurrency (default: 16).")
|
| 181 |
-
|
| 182 |
-
# HTTPS args
|
| 183 |
-
parser.add_argument("--ssl-certfile", default=None, help="Path to SSL certificate file for HTTPS (optional).")
|
| 184 |
-
parser.add_argument("--ssl-keyfile", default=None, help="Path to SSL key file for HTTPS (optional).")
|
| 185 |
-
parser.add_argument(
|
| 186 |
-
"--ssl-verify/--no-ssl-verify",
|
| 187 |
-
dest="ssl_verify",
|
| 188 |
-
default=True,
|
| 189 |
-
action=argparse.BooleanOptionalAction,
|
| 190 |
-
help="Whether to verify SSL certificate (default: enabled).",
|
| 191 |
-
)
|
| 192 |
-
|
| 193 |
-
return parser
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
def _parse_json_dict(s: Optional[str], *, name: str) -> Dict[str, Any]:
|
| 197 |
-
if s is None or not str(s).strip():
|
| 198 |
-
return {}
|
| 199 |
-
try:
|
| 200 |
-
obj = json.loads(s)
|
| 201 |
-
except Exception as e:
|
| 202 |
-
raise ValueError(f"Invalid JSON for {name}: {e}")
|
| 203 |
-
if not isinstance(obj, dict):
|
| 204 |
-
raise ValueError(f"{name} must be a JSON object (dict).")
|
| 205 |
-
return obj
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
def _apply_cuda_visible_devices(cuda_visible_devices: str) -> None:
|
| 209 |
-
v = (cuda_visible_devices or "").strip()
|
| 210 |
-
if not v:
|
| 211 |
-
return
|
| 212 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = v
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
def _default_backend_kwargs(backend: str) -> Dict[str, Any]:
|
| 216 |
-
if backend == "transformers":
|
| 217 |
-
return dict(
|
| 218 |
-
dtype=torch.bfloat16,
|
| 219 |
-
device_map="cuda:0",
|
| 220 |
-
attn_implementation="flash_attention_2",
|
| 221 |
-
max_inference_batch_size=32,
|
| 222 |
-
)
|
| 223 |
-
else:
|
| 224 |
-
return dict(
|
| 225 |
-
gpu_memory_utilization=0.8,
|
| 226 |
-
max_inference_batch_size=32,
|
| 227 |
-
)
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
def _default_aligner_kwargs() -> Dict[str, Any]:
|
| 231 |
-
return dict(
|
| 232 |
-
dtype=torch.bfloat16,
|
| 233 |
-
device_map="cuda:0",
|
| 234 |
-
)
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
def _merge_dicts(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
|
| 238 |
-
out = dict(base)
|
| 239 |
-
out.update(override)
|
| 240 |
-
return out
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
def _coerce_special_types(d: Dict[str, Any]) -> Dict[str, Any]:
|
| 244 |
-
out: Dict[str, Any] = {}
|
| 245 |
-
for k, v in d.items():
|
| 246 |
-
if k == "dtype" and isinstance(v, str):
|
| 247 |
-
out[k] = _dtype_from_str(v)
|
| 248 |
-
else:
|
| 249 |
-
out[k] = v
|
| 250 |
-
return out
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
def _make_timestamp_html(audio_upload: Any, timestamps: Any) -> str:
|
| 254 |
-
"""
|
| 255 |
-
Build HTML with per-token audio slices, using base64 data URLs (no filesystem caching).
|
| 256 |
-
Expect timestamps as list[dict] with keys: text, start_time, end_time (ms).
|
| 257 |
-
"""
|
| 258 |
-
at = _audio_to_tuple(audio_upload)
|
| 259 |
-
if at is None:
|
| 260 |
-
raise ValueError("Audio input is required for visualization.")
|
| 261 |
-
audio, sr = at
|
| 262 |
-
|
| 263 |
-
if not timestamps:
|
| 264 |
-
return "<div style='color:#666'>No timestamps to visualize.</div>"
|
| 265 |
-
if not isinstance(timestamps, list):
|
| 266 |
-
raise ValueError("Timestamps must be a list (JSON array).")
|
| 267 |
-
|
| 268 |
-
html_content = """
|
| 269 |
-
<style>
|
| 270 |
-
.word-alignment-container { display: flex; flex-wrap: wrap; gap: 10px; }
|
| 271 |
-
.word-box {
|
| 272 |
-
border: 1px solid #ddd; border-radius: 8px; padding: 10px;
|
| 273 |
-
background-color: #f9f9f9; box-shadow: 0 2px 4px rgba(0,0,0,0.06);
|
| 274 |
-
text-align: center;
|
| 275 |
-
}
|
| 276 |
-
.word-text { font-size: 18px; font-weight: 700; margin-bottom: 5px; }
|
| 277 |
-
.word-time { font-size: 12px; color: #666; margin-bottom: 8px; }
|
| 278 |
-
.word-audio audio { width: 140px; height: 30px; }
|
| 279 |
-
details { border: 1px solid #ddd; border-radius: 6px; padding: 10px; background-color: #f7f7f7; }
|
| 280 |
-
summary { font-weight: 700; cursor: pointer; }
|
| 281 |
-
</style>
|
| 282 |
-
"""
|
| 283 |
-
|
| 284 |
-
html_content += """
|
| 285 |
-
<details open>
|
| 286 |
-
<summary>Timestamps Visualization (时间戳可视化结果)</summary>
|
| 287 |
-
<div class="word-alignment-container" style="margin-top: 14px;">
|
| 288 |
-
"""
|
| 289 |
-
|
| 290 |
-
for item in timestamps:
|
| 291 |
-
if not isinstance(item, dict):
|
| 292 |
-
continue
|
| 293 |
-
word = str(item.get("text", "") or "")
|
| 294 |
-
start = item.get("start_time", None)
|
| 295 |
-
end = item.get("end_time", None)
|
| 296 |
-
if start is None or end is None:
|
| 297 |
-
continue
|
| 298 |
-
|
| 299 |
-
start = float(start)
|
| 300 |
-
end = float(end)
|
| 301 |
-
if end <= start:
|
| 302 |
-
continue
|
| 303 |
-
|
| 304 |
-
start_sample = max(0, int(start * sr))
|
| 305 |
-
end_sample = min(len(audio), int(end * sr))
|
| 306 |
-
if end_sample <= start_sample:
|
| 307 |
-
continue
|
| 308 |
-
|
| 309 |
-
seg = audio[start_sample:end_sample]
|
| 310 |
-
seg_i16 = (np.clip(seg, -1.0, 1.0) * 32767.0).astype(np.int16)
|
| 311 |
-
|
| 312 |
-
mem = io.BytesIO()
|
| 313 |
-
wav_write(mem, sr, seg_i16)
|
| 314 |
-
mem.seek(0)
|
| 315 |
-
b64 = base64.b64encode(mem.read()).decode("utf-8")
|
| 316 |
-
audio_src = f"data:audio/wav;base64,{b64}"
|
| 317 |
-
|
| 318 |
-
html_content += f"""
|
| 319 |
-
<div class="word-box">
|
| 320 |
-
<div class="word-text">{word}</div>
|
| 321 |
-
<div class="word-time">{start} - {end} s</div>
|
| 322 |
-
<div class="word-audio">
|
| 323 |
-
<audio controls preload="none" src="{audio_src}"></audio>
|
| 324 |
-
</div>
|
| 325 |
-
</div>
|
| 326 |
-
"""
|
| 327 |
-
|
| 328 |
-
html_content += "</div></details>"
|
| 329 |
-
return html_content
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
def build_demo(
|
| 333 |
-
asr: Qwen3ASRModel,
|
| 334 |
-
asr_ckpt: str,
|
| 335 |
-
backend: str,
|
| 336 |
-
aligner_ckpt: Optional[str] = None,
|
| 337 |
-
) -> gr.Blocks:
|
| 338 |
-
supported_langs_raw = asr.get_supported_languages()
|
| 339 |
-
lang_choices_disp, lang_map = _build_choices_and_map([x for x in supported_langs_raw])
|
| 340 |
-
lang_choices = ["Auto"] + lang_choices_disp
|
| 341 |
-
|
| 342 |
-
has_aligner = bool(aligner_ckpt)
|
| 343 |
-
|
| 344 |
-
theme = gr.themes.Soft(
|
| 345 |
-
font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
|
| 346 |
-
)
|
| 347 |
-
css = ".gradio-container {max-width: none !important;}"
|
| 348 |
-
|
| 349 |
-
with gr.Blocks(theme=theme, css=css) as demo:
|
| 350 |
-
gr.Markdown(
|
| 351 |
-
f"""
|
| 352 |
-
# Qwen3 ASR Demo
|
| 353 |
-
**Backend:** `{backend}`
|
| 354 |
-
**ASR Checkpoint:** `{asr_ckpt}`
|
| 355 |
-
**Forced Aligner:** `{aligner_ckpt if aligner_ckpt else "(none)"}`
|
| 356 |
-
"""
|
| 357 |
-
)
|
| 358 |
-
|
| 359 |
-
with gr.Row():
|
| 360 |
-
with gr.Column(scale=2):
|
| 361 |
-
audio_in = gr.Audio(label="Audio Input (上传音频)", type="numpy")
|
| 362 |
-
lang_in = gr.Dropdown(
|
| 363 |
-
label="Language (语种)",
|
| 364 |
-
choices=lang_choices,
|
| 365 |
-
value="Auto",
|
| 366 |
-
interactive=True,
|
| 367 |
-
)
|
| 368 |
-
if has_aligner:
|
| 369 |
-
ts_in = gr.Checkbox(
|
| 370 |
-
label="Return Timestamps (是否返回时间戳)",
|
| 371 |
-
value=True,
|
| 372 |
-
)
|
| 373 |
-
else:
|
| 374 |
-
ts_in = gr.State(False)
|
| 375 |
-
|
| 376 |
-
btn = gr.Button("Transcribe (识别)", variant="primary")
|
| 377 |
-
|
| 378 |
-
with gr.Column(scale=2):
|
| 379 |
-
out_lang = gr.Textbox(label="Detected Language", lines=1)
|
| 380 |
-
out_text = gr.Textbox(label="Result Text", lines=12)
|
| 381 |
-
|
| 382 |
-
if has_aligner:
|
| 383 |
-
with gr.Column(scale=3):
|
| 384 |
-
out_ts = gr.JSON(label="Timestamps(时间戳结果)")
|
| 385 |
-
viz_btn = gr.Button("Visualize Timestamps (可视化时间戳)", variant="secondary")
|
| 386 |
-
else:
|
| 387 |
-
with gr.Column(scale=3):
|
| 388 |
-
out_ts = gr.State(None)
|
| 389 |
-
viz_btn = gr.State(None)
|
| 390 |
-
|
| 391 |
-
# Put the visualization panel below the three columns
|
| 392 |
-
if has_aligner:
|
| 393 |
-
with gr.Row():
|
| 394 |
-
out_ts_html = gr.HTML(label="Timestamps Visualization (时间戳可视化结果)")
|
| 395 |
-
else:
|
| 396 |
-
out_ts_html = gr.State("")
|
| 397 |
-
|
| 398 |
-
def run(audio_upload: Any, lang_disp: str, return_ts: bool):
|
| 399 |
-
audio_obj = _parse_audio_any(audio_upload)
|
| 400 |
-
|
| 401 |
-
language = None
|
| 402 |
-
if lang_disp and lang_disp != "Auto":
|
| 403 |
-
language = lang_map.get(lang_disp, lang_disp)
|
| 404 |
-
|
| 405 |
-
return_ts = bool(return_ts) and has_aligner
|
| 406 |
-
|
| 407 |
-
results = asr.transcribe(
|
| 408 |
-
audio=audio_obj,
|
| 409 |
-
language=language,
|
| 410 |
-
return_time_stamps=return_ts,
|
| 411 |
-
)
|
| 412 |
-
if not isinstance(results, list) or len(results) != 1:
|
| 413 |
-
raise RuntimeError(
|
| 414 |
-
f"Unexpected result size: {type(results)} "
|
| 415 |
-
f"len={len(results) if isinstance(results, list) else 'N/A'}"
|
| 416 |
-
)
|
| 417 |
-
|
| 418 |
-
r = results[0]
|
| 419 |
-
|
| 420 |
-
if has_aligner:
|
| 421 |
-
ts_payload = None
|
| 422 |
-
if return_ts:
|
| 423 |
-
ts_payload = [
|
| 424 |
-
dict(
|
| 425 |
-
text=getattr(t, "text", None),
|
| 426 |
-
start_time=getattr(t, "start_time", None),
|
| 427 |
-
end_time=getattr(t, "end_time", None),
|
| 428 |
-
)
|
| 429 |
-
for t in (getattr(r, "time_stamps", None) or [])
|
| 430 |
-
]
|
| 431 |
-
return (
|
| 432 |
-
getattr(r, "language", "") or "",
|
| 433 |
-
getattr(r, "text", "") or "",
|
| 434 |
-
gr.update(value=ts_payload) if return_ts else gr.update(value=None),
|
| 435 |
-
gr.update(value=""), # clear html on each transcribe
|
| 436 |
-
)
|
| 437 |
-
else:
|
| 438 |
-
return (
|
| 439 |
-
getattr(r, "language", "") or "",
|
| 440 |
-
getattr(r, "text", "") or "",
|
| 441 |
-
)
|
| 442 |
-
|
| 443 |
-
def visualize(audio_upload: Any, timestamps_json: Any):
|
| 444 |
-
return _make_timestamp_html(audio_upload, timestamps_json)
|
| 445 |
-
|
| 446 |
-
if has_aligner:
|
| 447 |
-
btn.click(
|
| 448 |
-
run,
|
| 449 |
-
inputs=[audio_in, lang_in, ts_in],
|
| 450 |
-
outputs=[out_lang, out_text, out_ts, out_ts_html],
|
| 451 |
-
)
|
| 452 |
-
viz_btn.click(
|
| 453 |
-
visualize,
|
| 454 |
-
inputs=[audio_in, out_ts],
|
| 455 |
-
outputs=[out_ts_html],
|
| 456 |
-
)
|
| 457 |
-
else:
|
| 458 |
-
btn.click(
|
| 459 |
-
run,
|
| 460 |
-
inputs=[audio_in, lang_in, ts_in],
|
| 461 |
-
outputs=[out_lang, out_text],
|
| 462 |
-
)
|
| 463 |
-
|
| 464 |
-
return demo
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
def main(argv=None) -> int:
|
| 468 |
-
parser = build_parser()
|
| 469 |
-
args = parser.parse_args(argv)
|
| 470 |
-
|
| 471 |
-
_apply_cuda_visible_devices(args.cuda_visible_devices)
|
| 472 |
-
|
| 473 |
-
backend = args.backend
|
| 474 |
-
asr_ckpt = args.asr_checkpoint
|
| 475 |
-
aligner_ckpt = args.aligner_checkpoint
|
| 476 |
-
|
| 477 |
-
user_backend_kwargs = _parse_json_dict(args.backend_kwargs, name="--backend-kwargs")
|
| 478 |
-
user_aligner_kwargs = _parse_json_dict(args.aligner_kwargs, name="--aligner-kwargs")
|
| 479 |
-
|
| 480 |
-
backend_kwargs = _merge_dicts(_default_backend_kwargs(backend), user_backend_kwargs)
|
| 481 |
-
backend_kwargs = _coerce_special_types(backend_kwargs)
|
| 482 |
-
|
| 483 |
-
forced_aligner = None
|
| 484 |
-
forced_aligner_kwargs = None
|
| 485 |
-
if aligner_ckpt:
|
| 486 |
-
forced_aligner = aligner_ckpt
|
| 487 |
-
aligner_kwargs = _merge_dicts(_default_aligner_kwargs(), user_aligner_kwargs)
|
| 488 |
-
forced_aligner_kwargs = _coerce_special_types(aligner_kwargs)
|
| 489 |
-
|
| 490 |
-
if backend == "transformers":
|
| 491 |
-
asr = Qwen3ASRModel.from_pretrained(
|
| 492 |
-
asr_ckpt,
|
| 493 |
-
forced_aligner=forced_aligner,
|
| 494 |
-
forced_aligner_kwargs=forced_aligner_kwargs,
|
| 495 |
-
**backend_kwargs,
|
| 496 |
-
)
|
| 497 |
-
else:
|
| 498 |
-
asr = Qwen3ASRModel.LLM(
|
| 499 |
-
model=asr_ckpt,
|
| 500 |
-
forced_aligner=forced_aligner,
|
| 501 |
-
forced_aligner_kwargs=forced_aligner_kwargs,
|
| 502 |
-
**backend_kwargs,
|
| 503 |
-
)
|
| 504 |
-
|
| 505 |
-
demo = build_demo(asr, asr_ckpt, backend, aligner_ckpt=aligner_ckpt)
|
| 506 |
-
|
| 507 |
-
launch_kwargs: Dict[str, Any] = dict(
|
| 508 |
-
server_name=args.ip,
|
| 509 |
-
server_port=args.port,
|
| 510 |
-
share=args.share,
|
| 511 |
-
ssl_verify=True if args.ssl_verify else False,
|
| 512 |
-
)
|
| 513 |
-
if args.ssl_certfile is not None:
|
| 514 |
-
launch_kwargs["ssl_certfile"] = args.ssl_certfile
|
| 515 |
-
if args.ssl_keyfile is not None:
|
| 516 |
-
launch_kwargs["ssl_keyfile"] = args.ssl_keyfile
|
| 517 |
-
|
| 518 |
-
demo.queue(default_concurrency_limit=int(args.concurrency)).launch(**launch_kwargs)
|
| 519 |
-
return 0
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
if __name__ == "__main__":
|
| 523 |
-
raise SystemExit(main())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_asr/cli/serve.py
DELETED
|
@@ -1,46 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
import sys
|
| 17 |
-
|
| 18 |
-
from qwen_asr.core.transformers_backend import (
|
| 19 |
-
Qwen3ASRConfig,
|
| 20 |
-
Qwen3ASRForConditionalGeneration,
|
| 21 |
-
Qwen3ASRProcessor,
|
| 22 |
-
)
|
| 23 |
-
from transformers import AutoConfig, AutoModel, AutoProcessor
|
| 24 |
-
|
| 25 |
-
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
|
| 26 |
-
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
|
| 27 |
-
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
|
| 28 |
-
|
| 29 |
-
try:
|
| 30 |
-
from qwen_asr.core.vllm_backend import Qwen3ASRForConditionalGeneration
|
| 31 |
-
from vllm import ModelRegistry
|
| 32 |
-
ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGeneration)
|
| 33 |
-
except Exception as e:
|
| 34 |
-
raise ImportError(
|
| 35 |
-
"vLLM is not available, to use qwen-asr-serve, please install with: pip install qwen-asr[vllm]"
|
| 36 |
-
) from e
|
| 37 |
-
|
| 38 |
-
from vllm.entrypoints.cli.main import main as vllm_main
|
| 39 |
-
|
| 40 |
-
def main():
|
| 41 |
-
sys.argv.insert(1, "serve")
|
| 42 |
-
vllm_main()
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
if __name__ == "__main__":
|
| 46 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_asr/core/transformers_backend/__init__.py
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
from .configuration_qwen3_asr import Qwen3ASRConfig
|
| 17 |
-
from .modeling_qwen3_asr import Qwen3ASRForConditionalGeneration
|
| 18 |
-
from .processing_qwen3_asr import Qwen3ASRProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_asr/core/transformers_backend/configuration_qwen3_asr.py
DELETED
|
@@ -1,425 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
from transformers.configuration_utils import PretrainedConfig
|
| 16 |
-
from transformers.utils import logging
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
logger = logging.get_logger(__name__)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class Qwen3ASRAudioEncoderConfig(PretrainedConfig):
|
| 23 |
-
r"""
|
| 24 |
-
This is the configuration class to store the configuration of a [`Qwen3ASRAudioEncoder`]. It is used to instantiate a
|
| 25 |
-
Qwen3-ASR audio encoder according to the specified arguments, defining the model architecture. Instantiating a
|
| 26 |
-
configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2-Audio
|
| 27 |
-
architecture.
|
| 28 |
-
|
| 29 |
-
e.g. [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
|
| 30 |
-
|
| 31 |
-
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 32 |
-
documentation from [`PretrainedConfig`] for more information.
|
| 33 |
-
|
| 34 |
-
Args:
|
| 35 |
-
num_mel_bins (`int`, *optional*, defaults to 128):
|
| 36 |
-
Number of mel features used per input features. Should correspond to the value used in the
|
| 37 |
-
`Qwen3ASRProcessor` class.
|
| 38 |
-
encoder_layers (`int`, *optional*, defaults to 32):
|
| 39 |
-
Number of encoder layers.
|
| 40 |
-
encoder_attention_heads (`int`, *optional*, defaults to 20):
|
| 41 |
-
Number of attention heads for each attention layer in the Transformer encoder.
|
| 42 |
-
encoder_ffn_dim (`int`, *optional*, defaults to 5120):
|
| 43 |
-
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
|
| 44 |
-
d_model (`int`, *optional*, defaults to 1280):
|
| 45 |
-
Dimensionality of the layers.
|
| 46 |
-
dropout (`float`, *optional*, defaults to 0.0):
|
| 47 |
-
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 48 |
-
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 49 |
-
The dropout ratio for the attention probabilities.
|
| 50 |
-
activation_function (`str`, *optional*, defaults to `"gelu"`):
|
| 51 |
-
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 52 |
-
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 53 |
-
activation_dropout (`float`, *optional*, defaults to 0.0):
|
| 54 |
-
The dropout ratio for activations inside the fully connected layer.
|
| 55 |
-
scale_embedding (`bool`, *optional*, defaults to `False`):
|
| 56 |
-
Scale embeddings by diving by sqrt(d_model).
|
| 57 |
-
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 58 |
-
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 59 |
-
max_source_positions (`int`, *optional*, defaults to 1500):
|
| 60 |
-
The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
|
| 61 |
-
n_window (`int`, *optional*, defaults to 100):
|
| 62 |
-
The chunk for conv and flash attn in AudioEncoder.
|
| 63 |
-
output_dim (`int`, *optional*, defaults to 3584):
|
| 64 |
-
The output dimension of AudioEncoder.
|
| 65 |
-
|
| 66 |
-
Example:
|
| 67 |
-
|
| 68 |
-
```python
|
| 69 |
-
>>> from transformers import Qwen3ASRAudioEncoderConfig, Qwen3ASRAudioEncoder
|
| 70 |
-
|
| 71 |
-
>>> # Initializing a Qwen3ASRAudioEncoderConfig
|
| 72 |
-
>>> configuration = Qwen3ASRAudioEncoderConfig()
|
| 73 |
-
|
| 74 |
-
>>> # Initializing a Qwen3ASRAudioEncoder (with random weights)
|
| 75 |
-
>>> model = Qwen3ASRAudioEncoder(configuration)
|
| 76 |
-
|
| 77 |
-
>>> # Accessing the model configuration
|
| 78 |
-
>>> configuration = model.config
|
| 79 |
-
```"""
|
| 80 |
-
|
| 81 |
-
model_type = "qwen3_asr_audio_encoder"
|
| 82 |
-
|
| 83 |
-
def __init__(
|
| 84 |
-
self,
|
| 85 |
-
num_mel_bins=128,
|
| 86 |
-
encoder_layers=32,
|
| 87 |
-
encoder_attention_heads=20,
|
| 88 |
-
encoder_ffn_dim=5120,
|
| 89 |
-
d_model=1280,
|
| 90 |
-
dropout=0,
|
| 91 |
-
attention_dropout=0,
|
| 92 |
-
activation_function="gelu",
|
| 93 |
-
activation_dropout=0,
|
| 94 |
-
scale_embedding=False,
|
| 95 |
-
initializer_range=0.02,
|
| 96 |
-
max_source_positions=1500,
|
| 97 |
-
n_window=100,
|
| 98 |
-
output_dim=3584,
|
| 99 |
-
n_window_infer=400,
|
| 100 |
-
conv_chunksize=500,
|
| 101 |
-
downsample_hidden_size=480,
|
| 102 |
-
**kwargs,
|
| 103 |
-
):
|
| 104 |
-
super().__init__(**kwargs)
|
| 105 |
-
|
| 106 |
-
self.num_mel_bins = num_mel_bins
|
| 107 |
-
self.d_model = d_model
|
| 108 |
-
self.encoder_layers = encoder_layers
|
| 109 |
-
self.encoder_attention_heads = encoder_attention_heads
|
| 110 |
-
self.encoder_ffn_dim = encoder_ffn_dim
|
| 111 |
-
self.dropout = dropout
|
| 112 |
-
self.attention_dropout = attention_dropout
|
| 113 |
-
self.activation_function = activation_function
|
| 114 |
-
self.activation_dropout = activation_dropout
|
| 115 |
-
self.num_hidden_layers = encoder_layers
|
| 116 |
-
self.initializer_range = initializer_range
|
| 117 |
-
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
| 118 |
-
self.max_source_positions = max_source_positions
|
| 119 |
-
self.n_window = n_window
|
| 120 |
-
self.output_dim = output_dim
|
| 121 |
-
self.n_window_infer = n_window_infer
|
| 122 |
-
self.conv_chunksize = conv_chunksize
|
| 123 |
-
self.downsample_hidden_size = downsample_hidden_size
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
class Qwen3ASRTextConfig(PretrainedConfig):
|
| 127 |
-
r"""
|
| 128 |
-
This is the configuration class to store the configuration of a [`Qwen3ASRTextModel`]. It is used to instantiate a
|
| 129 |
-
Qwen3-ASR model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 130 |
-
with the defaults will yield a similar configuration to that of
|
| 131 |
-
Qwen3-ASR-1.7B [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
|
| 132 |
-
|
| 133 |
-
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 134 |
-
documentation from [`PretrainedConfig`] for more information.
|
| 135 |
-
|
| 136 |
-
Args:
|
| 137 |
-
vocab_size (`int`, *optional*, defaults to 151936):
|
| 138 |
-
Vocabulary size of the Qwen3ASR model. Defines the number of different tokens that can be represented by the
|
| 139 |
-
`inputs_ids` passed when calling [`Qwen3ASRModel`]
|
| 140 |
-
hidden_size (`int`, *optional*, defaults to 4096):
|
| 141 |
-
Dimension of the hidden representations.
|
| 142 |
-
intermediate_size (`int`, *optional*, defaults to 22016):
|
| 143 |
-
Dimension of the MLP representations.
|
| 144 |
-
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 145 |
-
Number of hidden layers in the Transformer encoder.
|
| 146 |
-
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 147 |
-
Number of attention heads for each attention layer in the Transformer encoder.
|
| 148 |
-
num_key_value_heads (`int`, *optional*, defaults to 32):
|
| 149 |
-
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 150 |
-
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 151 |
-
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 152 |
-
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 153 |
-
by meanpooling all the original heads within that group. For more details, check out [this
|
| 154 |
-
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
|
| 155 |
-
head_dim (`int`, *optional*, defaults to 128):
|
| 156 |
-
The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
|
| 157 |
-
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 158 |
-
The non-linear activation function (function or string) in the decoder.
|
| 159 |
-
max_position_embeddings (`int`, *optional*, defaults to 128000):
|
| 160 |
-
The maximum sequence length that this model might ever be used with.
|
| 161 |
-
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 162 |
-
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 163 |
-
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 164 |
-
The epsilon used by the rms normalization layers.
|
| 165 |
-
use_cache (`bool`, *optional*, defaults to `True`):
|
| 166 |
-
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 167 |
-
relevant if `config.is_decoder=True`.
|
| 168 |
-
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 169 |
-
Whether the model's input and output word embeddings should be tied.
|
| 170 |
-
rope_theta (`float`, *optional*, defaults to 5000000.0):
|
| 171 |
-
The base period of the RoPE embeddings.
|
| 172 |
-
rope_scaling (`Dict`, *optional*):
|
| 173 |
-
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 174 |
-
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 175 |
-
accordingly.
|
| 176 |
-
Expected contents:
|
| 177 |
-
`rope_type` (`str`):
|
| 178 |
-
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 179 |
-
'llama3'], with 'default' being the original RoPE implementation.
|
| 180 |
-
`factor` (`float`, *optional*):
|
| 181 |
-
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 182 |
-
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 183 |
-
original maximum pre-trained length.
|
| 184 |
-
`original_max_position_embeddings` (`int`, *optional*):
|
| 185 |
-
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
| 186 |
-
pretraining.
|
| 187 |
-
`attention_factor` (`float`, *optional*):
|
| 188 |
-
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 189 |
-
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 190 |
-
`factor` field to infer the suggested value.
|
| 191 |
-
`beta_fast` (`float`, *optional*):
|
| 192 |
-
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 193 |
-
ramp function. If unspecified, it defaults to 32.
|
| 194 |
-
`beta_slow` (`float`, *optional*):
|
| 195 |
-
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 196 |
-
ramp function. If unspecified, it defaults to 1.
|
| 197 |
-
`short_factor` (`list[float]`, *optional*):
|
| 198 |
-
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 199 |
-
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 200 |
-
size divided by the number of attention heads divided by 2
|
| 201 |
-
`long_factor` (`list[float]`, *optional*):
|
| 202 |
-
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
| 203 |
-
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 204 |
-
size divided by the number of attention heads divided by 2
|
| 205 |
-
`low_freq_factor` (`float`, *optional*):
|
| 206 |
-
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
| 207 |
-
`high_freq_factor` (`float`, *optional*):
|
| 208 |
-
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
| 209 |
-
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
| 210 |
-
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 211 |
-
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 212 |
-
The dropout ratio for the attention probabilities.
|
| 213 |
-
|
| 214 |
-
```python
|
| 215 |
-
>>> from transformers import Qwen3ASRTextModel, Qwen3ASRTextConfig
|
| 216 |
-
|
| 217 |
-
>>> # Initializing a Qwen3ASR style configuration
|
| 218 |
-
>>> configuration = Qwen3ASRTextConfig()
|
| 219 |
-
|
| 220 |
-
>>> # Initializing a model from the Qwen3-VL-7B style configuration
|
| 221 |
-
>>> model = Qwen3ASRTextModel(configuration)
|
| 222 |
-
|
| 223 |
-
>>> # Accessing the model configuration
|
| 224 |
-
>>> configuration = model.config
|
| 225 |
-
```"""
|
| 226 |
-
|
| 227 |
-
model_type = "qwen3_asr_text"
|
| 228 |
-
base_config_key = "text_config"
|
| 229 |
-
|
| 230 |
-
def __init__(
|
| 231 |
-
self,
|
| 232 |
-
vocab_size=151936,
|
| 233 |
-
hidden_size=4096,
|
| 234 |
-
intermediate_size=22016,
|
| 235 |
-
num_hidden_layers=32,
|
| 236 |
-
num_attention_heads=32,
|
| 237 |
-
num_key_value_heads=32,
|
| 238 |
-
head_dim=128,
|
| 239 |
-
hidden_act="silu",
|
| 240 |
-
max_position_embeddings=128000,
|
| 241 |
-
initializer_range=0.02,
|
| 242 |
-
rms_norm_eps=1e-6,
|
| 243 |
-
use_cache=True,
|
| 244 |
-
tie_word_embeddings=False,
|
| 245 |
-
rope_theta=5000000.0,
|
| 246 |
-
rope_scaling=None,
|
| 247 |
-
attention_bias=False,
|
| 248 |
-
attention_dropout=0.0,
|
| 249 |
-
**kwargs,
|
| 250 |
-
):
|
| 251 |
-
self.vocab_size = vocab_size
|
| 252 |
-
self.max_position_embeddings = max_position_embeddings
|
| 253 |
-
self.hidden_size = hidden_size
|
| 254 |
-
self.intermediate_size = intermediate_size
|
| 255 |
-
self.num_hidden_layers = num_hidden_layers
|
| 256 |
-
self.num_attention_heads = num_attention_heads
|
| 257 |
-
|
| 258 |
-
# for backward compatibility
|
| 259 |
-
if num_key_value_heads is None:
|
| 260 |
-
num_key_value_heads = num_attention_heads
|
| 261 |
-
|
| 262 |
-
self.num_key_value_heads = num_key_value_heads
|
| 263 |
-
self.head_dim = head_dim
|
| 264 |
-
self.hidden_act = hidden_act
|
| 265 |
-
self.initializer_range = initializer_range
|
| 266 |
-
self.rms_norm_eps = rms_norm_eps
|
| 267 |
-
self.use_cache = use_cache
|
| 268 |
-
self.rope_theta = rope_theta
|
| 269 |
-
self.rope_scaling = rope_scaling
|
| 270 |
-
self.attention_bias = attention_bias
|
| 271 |
-
self.attention_dropout = attention_dropout
|
| 272 |
-
# Validate the correctness of rotary position embeddings parameters
|
| 273 |
-
# BC: if there is a 'type' field, move it to 'rope_type'.
|
| 274 |
-
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 275 |
-
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 276 |
-
|
| 277 |
-
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
class Qwen3ASRThinkerConfig(PretrainedConfig):
|
| 281 |
-
r"""
|
| 282 |
-
This is the configuration class to store the configuration of a [`Qwen3ASRThinker`]. It is used to instantiate a
|
| 283 |
-
Qwen3-ASR-Thinker model according to the specified arguments, defining the model architecture. Instantiating a
|
| 284 |
-
configuration with the defaults will yield a similar configuration to that of the thinker component of the Qwen3-Omni
|
| 285 |
-
architecture.
|
| 286 |
-
|
| 287 |
-
e.g. [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
|
| 288 |
-
|
| 289 |
-
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 290 |
-
documentation from [`PretrainedConfig`] for more information.
|
| 291 |
-
|
| 292 |
-
Args:
|
| 293 |
-
audio_config (`dict`, *optional*):
|
| 294 |
-
The config dictionary of the audio backbone.
|
| 295 |
-
text_config (`dict`, *optional*):
|
| 296 |
-
The config dictionary of the text backbone.
|
| 297 |
-
audio_token_id (`int`, *optional*, defaults to 151646):
|
| 298 |
-
The audio token id to encode the audio prompt.
|
| 299 |
-
audio_start_token_id (`int`, *optional*, defaults to 151647):
|
| 300 |
-
The audio start token id to encode the audio prompt.
|
| 301 |
-
user_token_id (`int`, *optional*, defaults to 872):
|
| 302 |
-
The user token id to encode the user token.
|
| 303 |
-
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 304 |
-
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 305 |
-
|
| 306 |
-
Example:
|
| 307 |
-
|
| 308 |
-
```python
|
| 309 |
-
>>> from transformers import Qwen3ASRThinkerModel, Qwen3ASRThinkerConfig
|
| 310 |
-
|
| 311 |
-
>>> # Initializing a default Qwen3ASRThinkerConfig
|
| 312 |
-
>>> configuration = Qwen3ASRThinkerConfig()
|
| 313 |
-
|
| 314 |
-
>>> # Initializing a model (with random weights) from the default configuration
|
| 315 |
-
>>> model = Qwen3ASRThinkerModel(configuration)
|
| 316 |
-
|
| 317 |
-
>>> # Accessing the model configuration
|
| 318 |
-
>>> configuration = model.config
|
| 319 |
-
```"""
|
| 320 |
-
|
| 321 |
-
model_type = "qwen3_asr_thinker"
|
| 322 |
-
|
| 323 |
-
attribute_map = {}
|
| 324 |
-
sub_configs = {
|
| 325 |
-
"audio_config": Qwen3ASRAudioEncoderConfig,
|
| 326 |
-
"text_config": Qwen3ASRTextConfig,
|
| 327 |
-
}
|
| 328 |
-
|
| 329 |
-
def __init__(
|
| 330 |
-
self,
|
| 331 |
-
audio_config=None,
|
| 332 |
-
text_config=None,
|
| 333 |
-
audio_token_id=151646,
|
| 334 |
-
audio_start_token_id=151647,
|
| 335 |
-
user_token_id=872,
|
| 336 |
-
initializer_range=0.02,
|
| 337 |
-
**kwargs,
|
| 338 |
-
):
|
| 339 |
-
super().__init__(**kwargs)
|
| 340 |
-
self.user_token_id = user_token_id
|
| 341 |
-
self.audio_start_token_id = audio_start_token_id
|
| 342 |
-
self.initializer_range = initializer_range
|
| 343 |
-
|
| 344 |
-
if isinstance(audio_config, dict):
|
| 345 |
-
audio_config = Qwen3ASRAudioEncoderConfig(**audio_config)
|
| 346 |
-
elif audio_config is None:
|
| 347 |
-
audio_config = Qwen3ASRAudioEncoderConfig()
|
| 348 |
-
self.audio_config = audio_config
|
| 349 |
-
|
| 350 |
-
if isinstance(text_config, dict):
|
| 351 |
-
text_config = Qwen3ASRTextConfig(**text_config)
|
| 352 |
-
elif text_config is None:
|
| 353 |
-
text_config = Qwen3ASRTextConfig()
|
| 354 |
-
self.text_config = text_config
|
| 355 |
-
self.audio_token_id = audio_token_id
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
class Qwen3ASRConfig(PretrainedConfig):
|
| 359 |
-
"""
|
| 360 |
-
This is the configuration class to store the configuration of a [`Qwen3ASRForConditionalGeneration`]. It is used to instantiate a Qwen3ASR
|
| 361 |
-
model according to the specified sub-models configurations, defining the model architecture.
|
| 362 |
-
|
| 363 |
-
Instantiating a configuration with the defaults will yield a similar configuration to that of the
|
| 364 |
-
[Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B) architecture.
|
| 365 |
-
|
| 366 |
-
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 367 |
-
documentation from [`PretrainedConfig`] for more information.
|
| 368 |
-
|
| 369 |
-
Args:
|
| 370 |
-
thinker_config (`dict`, *optional*): Configuration of the underlying thinker sub-model.
|
| 371 |
-
support_languages (`List[str]`, *optional*): The languages supported by the model.
|
| 372 |
-
|
| 373 |
-
Example:
|
| 374 |
-
|
| 375 |
-
```python
|
| 376 |
-
>>> from transformers import (
|
| 377 |
-
... Qwen3ASRThinkerConfig,
|
| 378 |
-
... Qwen3ASRForConditionalGeneration,
|
| 379 |
-
... Qwen3ASRConfig,
|
| 380 |
-
... )
|
| 381 |
-
|
| 382 |
-
>>> # Initializing a Qwen3ASR style configuration
|
| 383 |
-
>>> configuration = Qwen3ASRConfig()
|
| 384 |
-
|
| 385 |
-
>>> # Initializing a model from the configuration
|
| 386 |
-
>>> model = Qwen3ASRForConditionalGeneration(configuration)
|
| 387 |
-
|
| 388 |
-
>>> # Accessing the model configuration
|
| 389 |
-
>>> configuration = model.config
|
| 390 |
-
```"""
|
| 391 |
-
|
| 392 |
-
model_type = "qwen3_asr"
|
| 393 |
-
sub_configs = {
|
| 394 |
-
"thinker_config": Qwen3ASRThinkerConfig,
|
| 395 |
-
}
|
| 396 |
-
|
| 397 |
-
def __init__(
|
| 398 |
-
self,
|
| 399 |
-
thinker_config=None,
|
| 400 |
-
support_languages=None,
|
| 401 |
-
**kwargs,
|
| 402 |
-
):
|
| 403 |
-
super().__init__(**kwargs)
|
| 404 |
-
if thinker_config is None:
|
| 405 |
-
thinker_config = {}
|
| 406 |
-
|
| 407 |
-
self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config)
|
| 408 |
-
self.support_languages = support_languages
|
| 409 |
-
|
| 410 |
-
def get_text_config(self, decoder=False) -> "PretrainedConfig":
|
| 411 |
-
"""
|
| 412 |
-
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
|
| 413 |
-
itself. On specific composite models, it is under a set of valid names.
|
| 414 |
-
|
| 415 |
-
Args:
|
| 416 |
-
decoder (`Optional[bool]`, *optional*, defaults to `False`):
|
| 417 |
-
If set to `True`, then only search for decoder config names.
|
| 418 |
-
"""
|
| 419 |
-
# Overridden for deeply nested config like Qwen2.5-Omni. We don't have any omni model
|
| 420 |
-
# except for Qwen yet. This has to be generalized if more deeply nested configs are
|
| 421 |
-
# added. NOTE: currently method used only by vLLM
|
| 422 |
-
return self.thinker_config.get_text_config()
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
__all__ = ["Qwen3ASRConfig", "Qwen3ASRThinkerConfig", "Qwen3ASRAudioEncoderConfig"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_asr/core/transformers_backend/modeling_qwen3_asr.py
DELETED
|
@@ -1,1361 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
import math
|
| 16 |
-
from dataclasses import dataclass
|
| 17 |
-
from typing import Callable, Optional, Union
|
| 18 |
-
|
| 19 |
-
import numpy as np
|
| 20 |
-
import torch
|
| 21 |
-
from torch import nn
|
| 22 |
-
from torch.nn import functional as F
|
| 23 |
-
|
| 24 |
-
from transformers.activations import ACT2FN
|
| 25 |
-
from transformers.cache_utils import Cache, DynamicCache
|
| 26 |
-
from transformers.generation import GenerationMixin
|
| 27 |
-
from transformers.integrations import use_kernel_forward_from_hub
|
| 28 |
-
from transformers.masking_utils import create_causal_mask
|
| 29 |
-
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 30 |
-
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 31 |
-
from transformers.modeling_outputs import (
|
| 32 |
-
BaseModelOutput,
|
| 33 |
-
BaseModelOutputWithPast,
|
| 34 |
-
MoeCausalLMOutputWithPast,
|
| 35 |
-
)
|
| 36 |
-
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 37 |
-
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 38 |
-
from transformers.processing_utils import Unpack
|
| 39 |
-
from transformers.utils import auto_docstring, can_return_tuple
|
| 40 |
-
from transformers.utils.deprecation import deprecate_kwarg
|
| 41 |
-
from transformers.utils.generic import TransformersKwargs, check_model_inputs
|
| 42 |
-
|
| 43 |
-
from .configuration_qwen3_asr import (
|
| 44 |
-
Qwen3ASRAudioEncoderConfig,
|
| 45 |
-
Qwen3ASRConfig,
|
| 46 |
-
Qwen3ASRThinkerConfig,
|
| 47 |
-
)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
@use_kernel_forward_from_hub("RMSNorm")
|
| 51 |
-
class Qwen3ASRTextRMSNorm(nn.Module):
|
| 52 |
-
def __init__(self, hidden_size, eps: float = 1e-6) -> None:
|
| 53 |
-
"""
|
| 54 |
-
Qwen3ASRTextRMSNorm is equivalent to T5LayerNorm
|
| 55 |
-
"""
|
| 56 |
-
super().__init__()
|
| 57 |
-
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 58 |
-
self.variance_epsilon = eps
|
| 59 |
-
|
| 60 |
-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 61 |
-
input_dtype = hidden_states.dtype
|
| 62 |
-
hidden_states = hidden_states.to(torch.float32)
|
| 63 |
-
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 64 |
-
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 65 |
-
return self.weight * hidden_states.to(input_dtype)
|
| 66 |
-
|
| 67 |
-
def extra_repr(self):
|
| 68 |
-
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def rotate_half(x):
|
| 72 |
-
"""Rotates half the hidden dims of the input."""
|
| 73 |
-
x1 = x[..., : x.shape[-1] // 2]
|
| 74 |
-
x2 = x[..., x.shape[-1] // 2 :]
|
| 75 |
-
return torch.cat((-x2, x1), dim=-1)
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 79 |
-
"""
|
| 80 |
-
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 81 |
-
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 82 |
-
"""
|
| 83 |
-
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 84 |
-
if n_rep == 1:
|
| 85 |
-
return hidden_states
|
| 86 |
-
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 87 |
-
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def eager_attention_forward(
|
| 91 |
-
module: nn.Module,
|
| 92 |
-
query: torch.Tensor,
|
| 93 |
-
key: torch.Tensor,
|
| 94 |
-
value: torch.Tensor,
|
| 95 |
-
attention_mask: Optional[torch.Tensor],
|
| 96 |
-
scaling: float,
|
| 97 |
-
dropout: float = 0.0,
|
| 98 |
-
**kwargs: Unpack[TransformersKwargs],
|
| 99 |
-
):
|
| 100 |
-
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 101 |
-
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 102 |
-
|
| 103 |
-
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 104 |
-
if attention_mask is not None:
|
| 105 |
-
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 106 |
-
attn_weights = attn_weights + causal_mask
|
| 107 |
-
|
| 108 |
-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 109 |
-
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 110 |
-
attn_output = torch.matmul(attn_weights, value_states)
|
| 111 |
-
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 112 |
-
|
| 113 |
-
return attn_output, attn_weights
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 117 |
-
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 118 |
-
|
| 119 |
-
Args:
|
| 120 |
-
q (`torch.Tensor`): The query tensor.
|
| 121 |
-
k (`torch.Tensor`): The key tensor.
|
| 122 |
-
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 123 |
-
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 124 |
-
position_ids (`torch.Tensor`, *optional*):
|
| 125 |
-
Deprecated and unused.
|
| 126 |
-
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 127 |
-
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 128 |
-
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 129 |
-
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 130 |
-
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 131 |
-
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 132 |
-
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 133 |
-
Returns:
|
| 134 |
-
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 135 |
-
"""
|
| 136 |
-
cos = cos.unsqueeze(unsqueeze_dim)
|
| 137 |
-
sin = sin.unsqueeze(unsqueeze_dim)
|
| 138 |
-
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 139 |
-
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 140 |
-
return q_embed, k_embed
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
class Qwen3ASRTextAttention(nn.Module):
|
| 144 |
-
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 145 |
-
|
| 146 |
-
def __init__(self, config: Qwen3ASRConfig, layer_idx: int):
|
| 147 |
-
super().__init__()
|
| 148 |
-
self.config = config
|
| 149 |
-
self.layer_idx = layer_idx
|
| 150 |
-
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 151 |
-
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 152 |
-
self.scaling = self.head_dim**-0.5
|
| 153 |
-
self.attention_dropout = config.attention_dropout
|
| 154 |
-
self.is_causal = True
|
| 155 |
-
|
| 156 |
-
self.q_proj = nn.Linear(
|
| 157 |
-
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 158 |
-
)
|
| 159 |
-
self.k_proj = nn.Linear(
|
| 160 |
-
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 161 |
-
)
|
| 162 |
-
self.v_proj = nn.Linear(
|
| 163 |
-
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 164 |
-
)
|
| 165 |
-
self.o_proj = nn.Linear(
|
| 166 |
-
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 167 |
-
)
|
| 168 |
-
self.q_norm = Qwen3ASRTextRMSNorm(
|
| 169 |
-
self.head_dim, eps=config.rms_norm_eps
|
| 170 |
-
) # unlike olmo, only on the head dim!
|
| 171 |
-
self.k_norm = Qwen3ASRTextRMSNorm(
|
| 172 |
-
self.head_dim, eps=config.rms_norm_eps
|
| 173 |
-
) # thus post q_norm does not need reshape
|
| 174 |
-
|
| 175 |
-
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 176 |
-
def forward(
|
| 177 |
-
self,
|
| 178 |
-
hidden_states: torch.Tensor,
|
| 179 |
-
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 180 |
-
attention_mask: Optional[torch.Tensor],
|
| 181 |
-
past_key_values: Optional[Cache] = None,
|
| 182 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 183 |
-
**kwargs: Unpack[FlashAttentionKwargs],
|
| 184 |
-
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 185 |
-
input_shape = hidden_states.shape[:-1]
|
| 186 |
-
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 187 |
-
|
| 188 |
-
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
| 189 |
-
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
| 190 |
-
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 191 |
-
|
| 192 |
-
cos, sin = position_embeddings
|
| 193 |
-
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 194 |
-
|
| 195 |
-
if past_key_values is not None:
|
| 196 |
-
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 197 |
-
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 198 |
-
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 199 |
-
|
| 200 |
-
attention_interface: Callable = eager_attention_forward
|
| 201 |
-
if self.config._attn_implementation != "eager":
|
| 202 |
-
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 203 |
-
|
| 204 |
-
attn_output, attn_weights = attention_interface(
|
| 205 |
-
self,
|
| 206 |
-
query_states,
|
| 207 |
-
key_states,
|
| 208 |
-
value_states,
|
| 209 |
-
attention_mask,
|
| 210 |
-
dropout=0.0 if not self.training else self.attention_dropout,
|
| 211 |
-
scaling=self.scaling,
|
| 212 |
-
**kwargs,
|
| 213 |
-
)
|
| 214 |
-
|
| 215 |
-
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 216 |
-
attn_output = self.o_proj(attn_output)
|
| 217 |
-
return attn_output, attn_weights
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
class Qwen3ASRTextMLP(nn.Module):
|
| 221 |
-
def __init__(self, config):
|
| 222 |
-
super().__init__()
|
| 223 |
-
self.config = config
|
| 224 |
-
self.hidden_size = config.hidden_size
|
| 225 |
-
self.intermediate_size = config.intermediate_size
|
| 226 |
-
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 227 |
-
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 228 |
-
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 229 |
-
self.act_fn = ACT2FN[config.hidden_act]
|
| 230 |
-
|
| 231 |
-
def forward(self, x):
|
| 232 |
-
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 233 |
-
return down_proj
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
class Qwen3ASRThinkerTextDecoderLayer(GradientCheckpointingLayer):
|
| 237 |
-
def __init__(self, config: Qwen3ASRConfig, layer_idx: int):
|
| 238 |
-
super().__init__()
|
| 239 |
-
self.hidden_size = config.hidden_size
|
| 240 |
-
|
| 241 |
-
self.self_attn = Qwen3ASRTextAttention(config=config, layer_idx=layer_idx)
|
| 242 |
-
|
| 243 |
-
self.mlp = Qwen3ASRTextMLP(config)
|
| 244 |
-
self.input_layernorm = Qwen3ASRTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 245 |
-
self.post_attention_layernorm = Qwen3ASRTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 246 |
-
|
| 247 |
-
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 248 |
-
def forward(
|
| 249 |
-
self,
|
| 250 |
-
hidden_states: torch.Tensor,
|
| 251 |
-
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 252 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 253 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 254 |
-
past_key_values: Optional[Cache] = None,
|
| 255 |
-
use_cache: Optional[bool] = False,
|
| 256 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 257 |
-
**kwargs: Unpack[TransformersKwargs],
|
| 258 |
-
) -> torch.Tensor:
|
| 259 |
-
residual = hidden_states
|
| 260 |
-
hidden_states = self.input_layernorm(hidden_states)
|
| 261 |
-
# Self Attention
|
| 262 |
-
hidden_states, _ = self.self_attn(
|
| 263 |
-
hidden_states=hidden_states,
|
| 264 |
-
attention_mask=attention_mask,
|
| 265 |
-
position_ids=position_ids,
|
| 266 |
-
past_key_values=past_key_values,
|
| 267 |
-
use_cache=use_cache,
|
| 268 |
-
cache_position=cache_position,
|
| 269 |
-
position_embeddings=position_embeddings,
|
| 270 |
-
**kwargs,
|
| 271 |
-
)
|
| 272 |
-
hidden_states = residual + hidden_states
|
| 273 |
-
|
| 274 |
-
# Fully Connected
|
| 275 |
-
residual = hidden_states
|
| 276 |
-
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 277 |
-
hidden_states = self.mlp(hidden_states)
|
| 278 |
-
hidden_states = residual + hidden_states
|
| 279 |
-
return hidden_states
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
@auto_docstring
class Qwen3ASRPreTrainedModel(PreTrainedModel):
    # Base class wiring Qwen3-ASR modules into the transformers PreTrainedModel
    # machinery (weight init, attention-backend selection, gradient checkpointing).
    config: Qwen3ASRConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # KV caches are handled manually, so skip them during automatic device placement.
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Modules whose per-layer outputs can be captured by transformers' recording hooks.
    _can_record_outputs = {
        "attentions": Qwen3ASRTextAttention,
    }
|
| 297 |
-
|
| 298 |
-
@dataclass
class Qwen3ASRThinkerCausalLMOutputWithPast(MoeCausalLMOutputWithPast):
    r"""
    Output type for the thinker LM head; extends `MoeCausalLMOutputWithPast`
    with the multimodal-RoPE offset so generation can reuse it across steps.

    Args:
        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            The rope index difference between sequence length and multimodal rope.
    """

    # Cached RoPE offset per batch element; None until computed.
    rope_deltas: Optional[torch.LongTensor] = None
-
|
| 308 |
-
|
| 309 |
-
def _get_feat_extract_output_lengths(input_lengths):
|
| 310 |
-
"""
|
| 311 |
-
Computes the output length of the convolutional layers and the output length of the audio encoder
|
| 312 |
-
"""
|
| 313 |
-
|
| 314 |
-
input_lengths_leave = input_lengths % 100
|
| 315 |
-
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
| 316 |
-
output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
| 317 |
-
return output_lengths
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
class Qwen3ASRPreTrainedModelForConditionalGeneration(Qwen3ASRPreTrainedModel):
    # Shared helpers for the conditional-generation heads: causal-mask building,
    # chunk index computation, and multimodal RoPE position-id computation.

    def _prepare_4d_causal_attention_mask_with_cache_position(
        self,
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        min_dtype: float,
        cache_position: torch.Tensor,
        batch_size: int,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            min_dtype (`float`):
                The minimum value representable with the dtype `dtype`.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            # Start fully masked (min_dtype everywhere) and carve out allowed positions.
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Keys beyond the current cache position remain masked.
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # Re-apply padding from the 2D mask: positions where causal says
                # "allowed" (0) but the 2D mask says "padding" (0) sum to 0.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask

    def get_chunked_index(
        self, token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int
    ) -> list[tuple[int, int]]:
        """
        Splits token index list into chunks based on token value ranges.

        Given a list of token indices, returns a list of (start, end) index tuples representing
        slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.

        For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
        - the first chunk contains token values < 1000,
        - the second chunk contains values >= 1000 and < 2000, and so on.

        Parameters:
            token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
                token index values.
            tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
            remove_index (`int`): An index id to subtract from `token_indices` before chunking.

        Returns:
            `list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
            and end (exclusive) indices of a chunk in `token_indices`.
        """

        def _iter():
            i, start_idx = 0, 0  # skip bos token
            current_chunk = 1
            while i < len(token_indices):  # skip eos token
                if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
                    yield (start_idx, i)
                    start_idx = i
                    current_chunk += 1
                i += 1
            yield (start_idx, len(token_indices))

        return list(_iter())

    def get_rope_index(
        self,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate the rope index in LLM.

        Explanation:
            Each embedding sequence contains text embedding. Position ids count
            unmasked tokens per row; the same ids are replicated across the three
            multimodal RoPE axes.

        Args:
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

        Returns:
            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
            mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`)
        """
        mrope_position_deltas = []

        # Cumulative count of unmasked tokens gives per-token positions; masked
        # slots are overwritten with a fixed placeholder (1).
        position_ids = attention_mask.float().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        # Replicate across the 3 multimodal RoPE axes (text uses identical ids).
        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
        max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
        mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)

        return position_ids, mrope_position_deltas
-
|
| 450 |
-
|
| 451 |
-
class Qwen3ASRAudioAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.d_model
        self.num_heads = config.encoder_attention_heads
        self.dropout = config.attention_dropout
        self.head_dim = self.embed_dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        self.config = config

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        # NOTE(review): `attention_dropout` is hard-coded to 0.0 here even though
        # `config.attention_dropout` was stored in `self.dropout` above — confirm intentional.
        self.attention_dropout = 0.0
        self.is_decoder = False
        self.is_causal = False  # encoder attention is bidirectional
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        # NOTE(review): despite the docstring above, `hidden_states.size()` is
        # unpacked into two values, so the input here is packed 2D
        # `(total_seq_len, channels)` with sequences delimited by `cu_seqlens`.

        seq_length, _ = hidden_states.size()

        # Project and split into heads: (seq, num_heads, head_dim).
        query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
        key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
        value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)

        # Rearrange to (1, num_heads, seq, head_dim) — a pseudo-batch of 1 so the
        # shared attention interface can be used over the packed sequence.
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        # Dispatch to the configured attention backend (eager/SDPA/FA2).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output
-
|
| 522 |
-
|
| 523 |
-
class Qwen3ASRAudioEncoderLayer(GradientCheckpointingLayer):
    # One audio-encoder transformer layer: pre-LayerNorm self-attention followed
    # by a pre-LayerNorm two-layer feed-forward block, both with residuals.

    def __init__(self, config: Qwen3ASRAudioEncoderConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = Qwen3ASRAudioAttention(config)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer (packed sequence;
                forwarded unchanged in shape to `Qwen3ASRAudioAttention`).
            cu_seqlens (`torch.Tensor`): cumulative sequence lengths delimiting the
                independent attention blocks within the packed input.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask where
                padding elements are indicated by very large negative values.

        Returns:
            A 1-tuple containing the transformed hidden states.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        # fp16 overflow guard: clamp just inside the representable range.
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        return outputs
-
|
| 578 |
-
|
| 579 |
-
class SinusoidsPositionEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal position table, as in the original Transformer.

    A `(length, channels)` buffer is precomputed once: the first `channels // 2`
    columns hold sines and the remaining half cosines, over geometrically spaced
    timescales from 1 up to `max_timescale`.
    """

    def __init__(self, length, channels, max_timescale=10000):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
        half = channels // 2
        # Geometric progression of inverse timescales: 1 ... 1/max_timescale.
        log_step = np.log(max_timescale) / (half - 1)
        inv_timescales = torch.exp(-log_step * torch.arange(half).float())
        # Outer product of positions and inverse timescales: (length, half).
        angles = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
        table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        # Not persisted: the table is fully determined by the constructor args.
        self.register_buffer("positional_embedding", table, persistent=False)

    def forward(self, seqlen: int):
        """Return the first `seqlen` rows of the precomputed table."""
        return self.positional_embedding[:seqlen, :]
-
|
| 596 |
-
|
| 597 |
-
@auto_docstring(
    custom_intro="""
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`Qwen3ASRAudioEncoderLayer`].
    """
)
class Qwen3ASRAudioEncoder(Qwen3ASRPreTrainedModel):
    config: Qwen3ASRAudioEncoderConfig
    main_input_name = "input_features"
    _no_split_modules = ["Qwen3ASRAudioEncoderLayer"]
    _supports_sdpa = True

    def __init__(self, config: Qwen3ASRAudioEncoderConfig):
        super().__init__(config)
        self.dropout = config.dropout

        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
        self.n_window = config.n_window
        self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim)
        self.layers = nn.ModuleList([Qwen3ASRAudioEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.ln_post = nn.LayerNorm(config.d_model)
        self.gradient_checkpointing = False
        # Three stride-2 2D convolutions: 8x downsampling in time and mel axes.
        self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
        self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
        self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
        # Projects the flattened (channels x reduced-mel) conv output to d_model.
        self.conv_out = nn.Linear(
            config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
            config.d_model,
            bias=False,
        )
        self.proj1 = nn.Linear(config.d_model, config.d_model)
        self.act = ACT2FN[config.activation_function]
        self.proj2 = nn.Linear(config.d_model, config.output_dim)
        self.n_window_infer = self.config.n_window_infer
        self.conv_chunksize = self.config.conv_chunksize
        # Initialize weights and apply final processing
        self.post_init()

    def _freeze_parameters(self):
        # Freeze the whole encoder (used when only the LM is being trained).
        for param in self.parameters():
            param.requires_grad = False
        self._requires_grad = False

    def get_input_embeddings(self) -> nn.Module:
        # NOTE(review): `self.conv1` is never defined on this class (the conv
        # stack is `conv2d1..conv2d3`); calling this would raise AttributeError.
        # Confirm whether this should return `self.conv2d1`.
        return self.conv1

    def set_input_embeddings(self, value: nn.Module):
        # NOTE(review): see get_input_embeddings — `conv1` does not match the
        # attributes created in __init__.
        self.conv1 = value

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
        # NOTE: the created attention mask only approximates the ragged FA2 attention by
        # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
        # blocks. Though it will not be a 100% match for FA2's `varlen` path
        if self.config._attn_implementation == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        # Unmask (set to 0) each block-diagonal square so attention stays within a chunk.
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    @auto_docstring
    def forward(
        self,
        input_features,
        feature_lens=None,
        aftercnn_lens=None,
    ):
        r"""
        feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
            mel length
        aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
            mel length after cnn. NOTE(review): this argument is ignored — it is
            recomputed from `feature_lens` on the first line below.
        """
        aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
        # Split each utterance into fixed windows of 2*n_window mel frames.
        chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()

        chunk_lengths = torch.tensor(
            [self.n_window * 2] * chunk_num.sum(),
            dtype=torch.long,
            device=feature_lens.device,
        )
        # The last chunk of each utterance keeps the remainder (or a full window).
        tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
        chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
        chunk_lengths[chunk_lengths == 0] = self.n_window * 2

        chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
        padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2)
        feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
        padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
            [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn],
            batch_first=True,
        )
        padded_feature = padded_feature.unsqueeze(1)
        # Split to chunk to avoid OOM during convolution
        padded_embeds = []
        for chunk in padded_feature.split(self.conv_chunksize, dim=0):
            padded_embed = F.gelu(self.conv2d1(chunk))
            padded_embed = F.gelu(self.conv2d2(padded_embed))
            padded_embed = F.gelu(self.conv2d3(padded_embed))
            padded_embeds.append(padded_embed)
        padded_embed = torch.cat(padded_embeds, dim=0)
        # Flatten (channels, reduced-mel) into features per time step, then project.
        b, c, f, t = padded_embed.size()
        padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f))

        positional_embedding = (
            self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
            .unsqueeze(0)
            .to(padded_embed.dtype)
        )
        padded_embed = padded_embed + positional_embedding
        # Drop padding positions -> packed (total_seq, d_model) representation.
        hidden_states = padded_embed[padded_mask_after_cnn]
        # Build cumulative sequence lengths for block-wise attention at inference
        # window size `n_window_infer`.
        cu_chunk_lens = [0]
        window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2))
        for cnn_len in aftercnn_lens:
            cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
            remainder = cnn_len % window_aftercnn
            if remainder != 0:
                cu_chunk_lens += [remainder]
        cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                cu_seqlens,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.ln_post(hidden_states)
        hidden_states = self.proj1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.proj2(hidden_states)
        return BaseModelOutput(last_hidden_state=hidden_states)

    def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"):
        """
        Pads a sequence of tensors to their maximum length on indicated `padding_side`.
        Then prepares a mask so that pad tokens are not attended to.
        """
        max_len = tensor_len.max()
        dim = tensor_list[0].shape[0]
        padded_tensor = torch.full(
            size=(len(tensor_list), dim, max_len),
            fill_value=padding_value,
            dtype=self.dtype,
            device=tensor_list[0].device,
        )

        batch_mask = torch.zeros(
            (len(tensor_len), max_len),
            dtype=torch.long,
            device=padded_tensor.device,
        )
        for i, length in enumerate(tensor_len):
            batch_mask[i, :length] = 1
            padded_tensor[i, :, :length] = tensor_list[i]

        # Single stride-2 reduction of the mask lengths.
        feature_lens_after_cnn = (tensor_len - 1) // 2 + 1
        max_len_after_cnn = feature_lens_after_cnn.max()
        batch_mask_after_cnn = torch.zeros(
            (len(tensor_len), max_len_after_cnn),
            dtype=torch.long,
            device=padded_tensor.device,
        )
        for i, length in enumerate(feature_lens_after_cnn):
            batch_mask_after_cnn[i, :length] = 1
        return (
            padded_tensor,
            batch_mask.unsqueeze(1),
            batch_mask_after_cnn.bool(),
        )
-
|
| 780 |
-
|
| 781 |
-
class Qwen3ASRThinkerTextRotaryEmbedding(nn.Module):
    # Rotary position embedding with interleaved multimodal RoPE (3 axes: T/H/W).
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3ASRConfig, device=None):
        super().__init__()
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", "default")
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic-rope updates can restore the unscaled frequencies.
        self.original_inv_freq = self.inv_freq

        # NOTE(review): unlike the guarded access above, this assumes
        # `config.rope_scaling` is not None — confirm configs always set it.
        self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary embeddings.
        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THTHWHTHW...TT], preserving frequency continuity.
        args:
            freqs: (3, bs, seq_len, head_dim // 2)
            mrope_section: (3,)
        returns:
            freqs_t: (bs, seq_len, head_dim // 2)
        """
        freqs_t = freqs[0]  # just overwrite the first dimension T
        for dim, offset in enumerate((1, 2), start=1):  # H, W
            # Every 3rd frequency slot (starting at `offset`) takes the H/W value.
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # In contrast to other models, Qwen3ASRThinker has different position ids for the grids
        # So we expand the inv_freq to shape (3, ...)
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
|
| 839 |
-
|
| 840 |
-
class Qwen3ASRThinkerTextMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        # Expansion width may be overridden per layer; defaults to the config value.
        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
-
|
| 855 |
-
|
| 856 |
-
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3ASRThinkerTextRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen3ASRThinkerTextRMSNorm is equivalent to T5LayerNorm: divide by the
        root-mean-square of the last dimension, then scale by a learned weight.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute in float32 for numerical stability, then cast back.
        orig_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
|
| 876 |
-
|
| 877 |
-
class Qwen3ASRThinkerTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx  # used by the KV cache to address this layer
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # GQA: number of query heads sharing each key/value head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3ASRThinkerTextRMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3ASRThinkerTextRMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # thus post q_norm does not need reshape
        self.sliding_window = None  # full attention (no sliding window)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Causal self-attention with per-head-dim QK RMSNorm and rotary embeddings.

        Returns `(attn_output, attn_weights)`; weights availability depends on the
        attention backend in use.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, normalize per head dim (QK-norm), then move heads to dim 1.
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured attention backend (eager/SDPA/FA2).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
-
|
| 955 |
-
|
| 956 |
-
@auto_docstring(
|
| 957 |
-
custom_intro=(
|
| 958 |
-
"Text part of Qwen3ASRThinker, "
|
| 959 |
-
)
|
| 960 |
-
)
|
| 961 |
-
class Qwen3ASRThinkerTextModel(Qwen3ASRPreTrainedModel):
|
| 962 |
-
config: Qwen3ASRConfig
|
| 963 |
-
_no_split_modules = ["Qwen3ASRThinkerTextDecoderLayer"]
|
| 964 |
-
config_class = Qwen3ASRConfig
|
| 965 |
-
_can_record_outputs = {
|
| 966 |
-
"hidden_states": Qwen3ASRThinkerTextDecoderLayer,
|
| 967 |
-
"attentions": Qwen3ASRThinkerTextAttention,
|
| 968 |
-
}
|
| 969 |
-
|
| 970 |
-
def __init__(self, config: Qwen3ASRConfig):
|
| 971 |
-
super().__init__(config)
|
| 972 |
-
self.padding_idx = config.pad_token_id
|
| 973 |
-
self.vocab_size = config.vocab_size
|
| 974 |
-
|
| 975 |
-
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 976 |
-
self.layers = nn.ModuleList(
|
| 977 |
-
[Qwen3ASRThinkerTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 978 |
-
)
|
| 979 |
-
self.norm = Qwen3ASRTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 980 |
-
self.rotary_emb = Qwen3ASRThinkerTextRotaryEmbedding(config)
|
| 981 |
-
self.gradient_checkpointing = False
|
| 982 |
-
|
| 983 |
-
# Initialize weights and apply final processing
|
| 984 |
-
self.post_init()
|
| 985 |
-
|
| 986 |
-
@check_model_inputs()
|
| 987 |
-
@auto_docstring
|
| 988 |
-
def forward(
|
| 989 |
-
self,
|
| 990 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 991 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 992 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 993 |
-
past_key_values: Optional[Cache] = None,
|
| 994 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 995 |
-
use_cache: Optional[bool] = None,
|
| 996 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 997 |
-
**kwargs: Unpack[FlashAttentionKwargs],
|
| 998 |
-
) -> Union[tuple, BaseModelOutputWithPast]:
|
| 999 |
-
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1000 |
-
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 1001 |
-
|
| 1002 |
-
# torch.jit.trace() doesn't support cache objects in the output
|
| 1003 |
-
if use_cache and past_key_values is None and not torch.jit.is_tracing():
|
| 1004 |
-
past_key_values = DynamicCache(config=self.config)
|
| 1005 |
-
|
| 1006 |
-
if inputs_embeds is None:
|
| 1007 |
-
inputs_embeds = self.embed_tokens(input_ids)
|
| 1008 |
-
|
| 1009 |
-
if cache_position is None:
|
| 1010 |
-
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1011 |
-
cache_position = torch.arange(
|
| 1012 |
-
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 1013 |
-
)
|
| 1014 |
-
|
| 1015 |
-
# the hard coded `3` is for temporal, height and width.
|
| 1016 |
-
if position_ids is None:
|
| 1017 |
-
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
| 1018 |
-
elif position_ids.ndim == 2:
|
| 1019 |
-
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
| 1020 |
-
|
| 1021 |
-
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
|
| 1022 |
-
text_position_ids = position_ids[0]
|
| 1023 |
-
position_ids = position_ids[1:]
|
| 1024 |
-
else:
|
| 1025 |
-
text_position_ids = position_ids[0]
|
| 1026 |
-
|
| 1027 |
-
attention_mask = create_causal_mask(
|
| 1028 |
-
config=self.config,
|
| 1029 |
-
input_embeds=inputs_embeds,
|
| 1030 |
-
attention_mask=attention_mask,
|
| 1031 |
-
cache_position=cache_position,
|
| 1032 |
-
past_key_values=past_key_values,
|
| 1033 |
-
position_ids=text_position_ids,
|
| 1034 |
-
)
|
| 1035 |
-
|
| 1036 |
-
hidden_states = inputs_embeds
|
| 1037 |
-
|
| 1038 |
-
# create position embeddings to be shared across the decoder layers
|
| 1039 |
-
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 1040 |
-
|
| 1041 |
-
# decoder layers
|
| 1042 |
-
for layer_idx, decoder_layer in enumerate(self.layers):
|
| 1043 |
-
layer_outputs = decoder_layer(
|
| 1044 |
-
hidden_states,
|
| 1045 |
-
attention_mask=attention_mask,
|
| 1046 |
-
position_ids=text_position_ids,
|
| 1047 |
-
past_key_values=past_key_values,
|
| 1048 |
-
cache_position=cache_position,
|
| 1049 |
-
position_embeddings=position_embeddings,
|
| 1050 |
-
**kwargs,
|
| 1051 |
-
)
|
| 1052 |
-
hidden_states = layer_outputs
|
| 1053 |
-
|
| 1054 |
-
hidden_states = self.norm(hidden_states)
|
| 1055 |
-
|
| 1056 |
-
return BaseModelOutputWithPast(
|
| 1057 |
-
last_hidden_state=hidden_states,
|
| 1058 |
-
past_key_values=past_key_values,
|
| 1059 |
-
)
|
| 1060 |
-
|
| 1061 |
-
|
| 1062 |
-
@auto_docstring(
|
| 1063 |
-
custom_intro="""
|
| 1064 |
-
The Qwen3ASRThinker model which consists of a audio backbone and a language model.
|
| 1065 |
-
"""
|
| 1066 |
-
)
|
| 1067 |
-
class Qwen3ASRThinkerForConditionalGeneration(Qwen3ASRPreTrainedModelForConditionalGeneration, GenerationMixin):
|
| 1068 |
-
config: Qwen3ASRThinkerConfig
|
| 1069 |
-
base_model_prefix = "thinker"
|
| 1070 |
-
_tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
|
| 1071 |
-
_no_split_modules = [
|
| 1072 |
-
"Qwen3ASRAudioEncoderLayer",
|
| 1073 |
-
"Qwen3ASRThinkerTextDecoderLayer",
|
| 1074 |
-
]
|
| 1075 |
-
_can_record_outputs = {
|
| 1076 |
-
"hidden_states": Qwen3ASRThinkerTextDecoderLayer,
|
| 1077 |
-
"attentions": Qwen3ASRThinkerTextAttention,
|
| 1078 |
-
}
|
| 1079 |
-
|
| 1080 |
-
def __init__(self, config):
|
| 1081 |
-
super().__init__(config)
|
| 1082 |
-
self.audio_tower = Qwen3ASRAudioEncoder._from_config(config.audio_config)
|
| 1083 |
-
self.vocab_size = config.text_config.vocab_size
|
| 1084 |
-
self.model = Qwen3ASRThinkerTextModel._from_config(config.text_config)
|
| 1085 |
-
if "forced_aligner" in config.model_type:
|
| 1086 |
-
self.lm_head = nn.Linear(config.text_config.hidden_size, config.classify_num, bias=False)
|
| 1087 |
-
else:
|
| 1088 |
-
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
| 1089 |
-
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
| 1090 |
-
self.rope_deltas = None
|
| 1091 |
-
self.post_init()
|
| 1092 |
-
|
| 1093 |
-
def get_input_embeddings(self):
|
| 1094 |
-
return self.model.get_input_embeddings()
|
| 1095 |
-
|
| 1096 |
-
def set_input_embeddings(self, value):
|
| 1097 |
-
self.model.set_input_embeddings(value)
|
| 1098 |
-
|
| 1099 |
-
def get_audio_features(
|
| 1100 |
-
self,
|
| 1101 |
-
input_features: torch.FloatTensor,
|
| 1102 |
-
feature_attention_mask: Optional[torch.LongTensor] = None,
|
| 1103 |
-
audio_feature_lengths: Optional[torch.LongTensor] = None,
|
| 1104 |
-
):
|
| 1105 |
-
"""
|
| 1106 |
-
Encodes audios into continuous embeddings that can be forwarded to the language model.
|
| 1107 |
-
|
| 1108 |
-
Args:
|
| 1109 |
-
input_features (`torch.FloatTensor`):
|
| 1110 |
-
The tensors corresponding to the input audios.
|
| 1111 |
-
feature_attention_mask (`torch.LongTensor`, *optional*):
|
| 1112 |
-
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
|
| 1113 |
-
audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
|
| 1114 |
-
The length of feature shape of each audio in LLM.
|
| 1115 |
-
"""
|
| 1116 |
-
if feature_attention_mask is not None:
|
| 1117 |
-
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
| 1118 |
-
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
|
| 1119 |
-
else:
|
| 1120 |
-
audio_feature_lengths = None
|
| 1121 |
-
|
| 1122 |
-
feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
| 1123 |
-
audio_outputs = self.audio_tower(
|
| 1124 |
-
input_features,
|
| 1125 |
-
feature_lens=feature_lens,
|
| 1126 |
-
)
|
| 1127 |
-
audio_features = audio_outputs.last_hidden_state
|
| 1128 |
-
|
| 1129 |
-
return audio_features
|
| 1130 |
-
|
| 1131 |
-
def get_placeholder_mask(
|
| 1132 |
-
self,
|
| 1133 |
-
input_ids: torch.LongTensor,
|
| 1134 |
-
inputs_embeds: torch.FloatTensor,
|
| 1135 |
-
):
|
| 1136 |
-
"""
|
| 1137 |
-
Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
|
| 1138 |
-
equal to the length of multimodal features. If the lengths are different, an error is raised.
|
| 1139 |
-
"""
|
| 1140 |
-
if input_ids is None:
|
| 1141 |
-
special_audio_mask = (
|
| 1142 |
-
inputs_embeds
|
| 1143 |
-
== self.get_input_embeddings()(
|
| 1144 |
-
torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 1145 |
-
)
|
| 1146 |
-
).all(-1)
|
| 1147 |
-
else:
|
| 1148 |
-
special_audio_mask = input_ids == self.config.audio_token_id
|
| 1149 |
-
|
| 1150 |
-
special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 1151 |
-
return special_audio_mask
|
| 1152 |
-
|
| 1153 |
-
@can_return_tuple
|
| 1154 |
-
@auto_docstring
|
| 1155 |
-
def forward(
|
| 1156 |
-
self,
|
| 1157 |
-
input_ids=None,
|
| 1158 |
-
input_features=None,
|
| 1159 |
-
attention_mask=None,
|
| 1160 |
-
feature_attention_mask=None,
|
| 1161 |
-
audio_feature_lengths=None,
|
| 1162 |
-
position_ids=None,
|
| 1163 |
-
past_key_values=None,
|
| 1164 |
-
inputs_embeds=None,
|
| 1165 |
-
rope_deltas=None,
|
| 1166 |
-
labels=None,
|
| 1167 |
-
use_cache=None,
|
| 1168 |
-
cache_position=None,
|
| 1169 |
-
**kwargs,
|
| 1170 |
-
) -> Union[tuple, Qwen3ASRThinkerCausalLMOutputWithPast]:
|
| 1171 |
-
r"""
|
| 1172 |
-
feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
|
| 1173 |
-
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
|
| 1174 |
-
- 1 for tokens that are **not masked**,
|
| 1175 |
-
- 0 for tokens that are **masked**.
|
| 1176 |
-
audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
|
| 1177 |
-
The length of feature shape of each audio in LLM.
|
| 1178 |
-
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
| 1179 |
-
The rope index difference between sequence length and multimodal rope.
|
| 1180 |
-
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1181 |
-
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1182 |
-
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1183 |
-
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1184 |
-
"""
|
| 1185 |
-
|
| 1186 |
-
if inputs_embeds is None:
|
| 1187 |
-
# 1. Extract the input embeddings
|
| 1188 |
-
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 1189 |
-
|
| 1190 |
-
# 2. Merge text, audios
|
| 1191 |
-
if input_features is not None:
|
| 1192 |
-
audio_features = self.get_audio_features(
|
| 1193 |
-
input_features,
|
| 1194 |
-
feature_attention_mask=feature_attention_mask,
|
| 1195 |
-
audio_feature_lengths=audio_feature_lengths,
|
| 1196 |
-
)
|
| 1197 |
-
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 1198 |
-
audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
|
| 1199 |
-
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
| 1200 |
-
|
| 1201 |
-
if feature_attention_mask is not None:
|
| 1202 |
-
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
| 1203 |
-
else:
|
| 1204 |
-
audio_feature_lengths = None
|
| 1205 |
-
|
| 1206 |
-
if attention_mask is not None and position_ids is None:
|
| 1207 |
-
if (
|
| 1208 |
-
cache_position is None
|
| 1209 |
-
or (cache_position is not None and cache_position[0] == 0)
|
| 1210 |
-
or self.rope_deltas is None
|
| 1211 |
-
):
|
| 1212 |
-
delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
|
| 1213 |
-
position_ids, rope_deltas = self.get_rope_index(
|
| 1214 |
-
attention_mask,
|
| 1215 |
-
)
|
| 1216 |
-
rope_deltas = rope_deltas - delta0
|
| 1217 |
-
self.rope_deltas = rope_deltas
|
| 1218 |
-
else:
|
| 1219 |
-
batch_size, seq_length = input_ids.shape
|
| 1220 |
-
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
|
| 1221 |
-
position_ids = torch.arange(seq_length, device=input_ids.device)
|
| 1222 |
-
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
| 1223 |
-
position_ids = position_ids.add(delta)
|
| 1224 |
-
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
| 1225 |
-
|
| 1226 |
-
outputs = self.model(
|
| 1227 |
-
attention_mask=attention_mask,
|
| 1228 |
-
position_ids=position_ids,
|
| 1229 |
-
past_key_values=past_key_values,
|
| 1230 |
-
inputs_embeds=inputs_embeds,
|
| 1231 |
-
use_cache=use_cache,
|
| 1232 |
-
cache_position=cache_position,
|
| 1233 |
-
**kwargs,
|
| 1234 |
-
)
|
| 1235 |
-
|
| 1236 |
-
hidden_states = outputs[0]
|
| 1237 |
-
logits = self.lm_head(hidden_states)
|
| 1238 |
-
|
| 1239 |
-
loss = None
|
| 1240 |
-
if labels is not None:
|
| 1241 |
-
loss = self.loss_function(
|
| 1242 |
-
logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
|
| 1243 |
-
)
|
| 1244 |
-
|
| 1245 |
-
return Qwen3ASRThinkerCausalLMOutputWithPast(
|
| 1246 |
-
loss=loss,
|
| 1247 |
-
logits=logits,
|
| 1248 |
-
hidden_states=outputs.hidden_states,
|
| 1249 |
-
attentions=outputs.attentions,
|
| 1250 |
-
past_key_values=outputs.past_key_values,
|
| 1251 |
-
rope_deltas=self.rope_deltas,
|
| 1252 |
-
)
|
| 1253 |
-
|
| 1254 |
-
def prepare_inputs_for_generation(
|
| 1255 |
-
self,
|
| 1256 |
-
input_ids,
|
| 1257 |
-
past_key_values=None,
|
| 1258 |
-
attention_mask=None,
|
| 1259 |
-
inputs_embeds=None,
|
| 1260 |
-
cache_position=None,
|
| 1261 |
-
position_ids=None,
|
| 1262 |
-
use_cache=True,
|
| 1263 |
-
input_features=None,
|
| 1264 |
-
feature_attention_mask=None,
|
| 1265 |
-
**kwargs,
|
| 1266 |
-
):
|
| 1267 |
-
model_inputs = super().prepare_inputs_for_generation(
|
| 1268 |
-
input_ids,
|
| 1269 |
-
past_key_values=past_key_values,
|
| 1270 |
-
attention_mask=attention_mask,
|
| 1271 |
-
inputs_embeds=inputs_embeds,
|
| 1272 |
-
cache_position=cache_position,
|
| 1273 |
-
position_ids=position_ids,
|
| 1274 |
-
use_cache=use_cache,
|
| 1275 |
-
input_features=input_features,
|
| 1276 |
-
feature_attention_mask=feature_attention_mask,
|
| 1277 |
-
**kwargs,
|
| 1278 |
-
)
|
| 1279 |
-
|
| 1280 |
-
model_inputs["position_ids"] = None
|
| 1281 |
-
|
| 1282 |
-
if cache_position[0] != 0:
|
| 1283 |
-
model_inputs["input_features"] = None
|
| 1284 |
-
|
| 1285 |
-
return model_inputs
|
| 1286 |
-
|
| 1287 |
-
|
| 1288 |
-
@auto_docstring
|
| 1289 |
-
class Qwen3ASRThinkerTextPreTrainedModel(PreTrainedModel):
|
| 1290 |
-
config = Qwen3ASRConfig
|
| 1291 |
-
base_model_prefix = "model"
|
| 1292 |
-
supports_gradient_checkpointing = True
|
| 1293 |
-
_no_split_modules = ["Qwen3ASRThinkerTextDecoderLayer"]
|
| 1294 |
-
_skip_keys_device_placement = ["past_key_values"]
|
| 1295 |
-
_supports_flash_attn = True
|
| 1296 |
-
_supports_sdpa = True
|
| 1297 |
-
_supports_flex_attn = True
|
| 1298 |
-
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
|
| 1299 |
-
_supports_attention_backend = True
|
| 1300 |
-
_can_record_outputs = {
|
| 1301 |
-
"hidden_states": Qwen3ASRThinkerTextDecoderLayer,
|
| 1302 |
-
"attentions": Qwen3ASRThinkerTextAttention,
|
| 1303 |
-
}
|
| 1304 |
-
config_class = Qwen3ASRConfig
|
| 1305 |
-
|
| 1306 |
-
|
| 1307 |
-
class Qwen3ASRForConditionalGeneration(Qwen3ASRPreTrainedModel, GenerationMixin):
|
| 1308 |
-
config_class = Qwen3ASRConfig
|
| 1309 |
-
|
| 1310 |
-
def __init__(self, config: Qwen3ASRConfig):
|
| 1311 |
-
super().__init__(config)
|
| 1312 |
-
self.config = config
|
| 1313 |
-
|
| 1314 |
-
self.thinker = Qwen3ASRThinkerForConditionalGeneration._from_config(config.thinker_config)
|
| 1315 |
-
self.post_init()
|
| 1316 |
-
|
| 1317 |
-
def get_support_languages(self):
|
| 1318 |
-
return self.config.support_languages
|
| 1319 |
-
|
| 1320 |
-
@torch.no_grad()
|
| 1321 |
-
def generate(
|
| 1322 |
-
self,
|
| 1323 |
-
input_ids: Optional[torch.Tensor] = None,
|
| 1324 |
-
max_new_tokens: int = 8192,
|
| 1325 |
-
eos_token_id: int | list[int] = [151645, 151643],
|
| 1326 |
-
**kwargs,
|
| 1327 |
-
):
|
| 1328 |
-
shared_kwargs = {}
|
| 1329 |
-
thinker_kwargs = {
|
| 1330 |
-
"max_new_tokens": max_new_tokens,
|
| 1331 |
-
"eos_token_id": eos_token_id,
|
| 1332 |
-
}
|
| 1333 |
-
|
| 1334 |
-
for key, value in kwargs.items():
|
| 1335 |
-
# Process special input values
|
| 1336 |
-
if key == "feature_attention_mask":
|
| 1337 |
-
thinker_kwargs[key] = value
|
| 1338 |
-
elif key in ("input_features", "attention_mask"):
|
| 1339 |
-
thinker_kwargs[key] = value
|
| 1340 |
-
# Put other key to shared kwargs
|
| 1341 |
-
else:
|
| 1342 |
-
shared_kwargs[key] = value
|
| 1343 |
-
|
| 1344 |
-
# Merge kwargs
|
| 1345 |
-
for key, value in shared_kwargs.items():
|
| 1346 |
-
if key not in thinker_kwargs:
|
| 1347 |
-
thinker_kwargs[key] = value
|
| 1348 |
-
|
| 1349 |
-
thinker_result = self.thinker.generate(input_ids=input_ids, return_dict_in_generate=True, **thinker_kwargs)
|
| 1350 |
-
|
| 1351 |
-
return thinker_result
|
| 1352 |
-
|
| 1353 |
-
|
| 1354 |
-
__all__ = [
|
| 1355 |
-
"Qwen3ASRForConditionalGeneration",
|
| 1356 |
-
"Qwen3ASRThinkerTextModel",
|
| 1357 |
-
"Qwen3ASRThinkerForConditionalGeneration",
|
| 1358 |
-
"Qwen3ASRPreTrainedModel",
|
| 1359 |
-
"Qwen3ASRPreTrainedModelForConditionalGeneration",
|
| 1360 |
-
"Qwen3ASRThinkerTextPreTrainedModel",
|
| 1361 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_asr/core/transformers_backend/processing_qwen3_asr.py
DELETED
|
@@ -1,209 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
import re
|
| 16 |
-
|
| 17 |
-
import numpy as np
|
| 18 |
-
|
| 19 |
-
from transformers.audio_utils import AudioInput
|
| 20 |
-
from transformers.feature_extraction_utils import BatchFeature
|
| 21 |
-
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
|
| 22 |
-
from transformers.tokenization_utils_base import TextInput
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class Qwen3ASRProcessorKwargs(ProcessingKwargs, total=False):
|
| 26 |
-
_defaults = {
|
| 27 |
-
"text_kwargs": {
|
| 28 |
-
"padding": False,
|
| 29 |
-
"padding_side": "left",
|
| 30 |
-
},
|
| 31 |
-
"audio_kwargs": {
|
| 32 |
-
"sampling_rate": 16000,
|
| 33 |
-
"padding": True,
|
| 34 |
-
"return_attention_mask": True,
|
| 35 |
-
},
|
| 36 |
-
}
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def _get_feat_extract_output_lengths(input_lengths):
|
| 40 |
-
"""
|
| 41 |
-
Computes the output length of the convolutional layers and the output length of the audio encoder
|
| 42 |
-
"""
|
| 43 |
-
|
| 44 |
-
input_lengths_leave = input_lengths % 100
|
| 45 |
-
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
| 46 |
-
output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
| 47 |
-
return output_lengths
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
class Qwen3ASRProcessor(ProcessorMixin):
|
| 51 |
-
r"""
|
| 52 |
-
Constructs a Qwen3ASR processor.
|
| 53 |
-
[`Qwen3ASRProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`], and [`Qwen2TokenizerFast`]. See the
|
| 54 |
-
[`~Qwen3ASRProcessor.__call__`] and [`~Qwen3ASRProcessor.decode`] for more information.
|
| 55 |
-
|
| 56 |
-
Args:
|
| 57 |
-
feature_extractor ([`WhisperFeatureExtractor`], *optional*):
|
| 58 |
-
The audio feature extractor.
|
| 59 |
-
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
| 60 |
-
The text tokenizer.
|
| 61 |
-
chat_template (`Optional[str]`, *optional*):
|
| 62 |
-
The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
|
| 63 |
-
"""
|
| 64 |
-
|
| 65 |
-
attributes = ["feature_extractor", "tokenizer"]
|
| 66 |
-
feature_extractor_class = "WhisperFeatureExtractor"
|
| 67 |
-
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
| 68 |
-
|
| 69 |
-
def __init__(
|
| 70 |
-
self, feature_extractor=None, tokenizer=None, chat_template=None
|
| 71 |
-
):
|
| 72 |
-
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
|
| 73 |
-
self.audio_token = self.tokenizer.audio_token
|
| 74 |
-
self.audio_bos_token = self.tokenizer.audio_bos_token
|
| 75 |
-
self.audio_eos_token = self.tokenizer.audio_eos_token
|
| 76 |
-
|
| 77 |
-
def __call__(
|
| 78 |
-
self,
|
| 79 |
-
text: TextInput = None,
|
| 80 |
-
audio: AudioInput = None,
|
| 81 |
-
**kwargs,
|
| 82 |
-
) -> BatchFeature:
|
| 83 |
-
"""
|
| 84 |
-
Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
|
| 85 |
-
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
| 86 |
-
the text. To prepare the audio(s), this method forwards the `audio` and `kwargs` arguments to
|
| 87 |
-
WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] if `audio` is not `None`. Please refer to the doctsring
|
| 88 |
-
of the above two methods for more information.
|
| 89 |
-
|
| 90 |
-
Args:
|
| 91 |
-
text (`str`, `List[str]`, `List[List[str]]`):
|
| 92 |
-
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 93 |
-
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 94 |
-
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 95 |
-
audio (`np.ndarray`, `List[np.ndarray]`):
|
| 96 |
-
The audio or batch of audio to be prepared. Each audio can be a NumPy array.
|
| 97 |
-
"""
|
| 98 |
-
|
| 99 |
-
if text is None:
|
| 100 |
-
raise ValueError("You need to specify either a `text` input to process.")
|
| 101 |
-
|
| 102 |
-
output_kwargs = self._merge_kwargs(
|
| 103 |
-
Qwen3ASRProcessorKwargs,
|
| 104 |
-
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 105 |
-
**kwargs,
|
| 106 |
-
)
|
| 107 |
-
|
| 108 |
-
if audio is not None:
|
| 109 |
-
output_kwargs["audio_kwargs"]["padding"] = True
|
| 110 |
-
output_kwargs["audio_kwargs"]["truncation"] = False
|
| 111 |
-
audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
|
| 112 |
-
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
|
| 113 |
-
"attention_mask"
|
| 114 |
-
) # rename feature_attention_mask to prevent conflicts later on
|
| 115 |
-
audio_inputs["input_features"] = audio_inputs.pop(
|
| 116 |
-
"input_features"
|
| 117 |
-
) # rename input_features to prevent conflicts later on
|
| 118 |
-
audio_lengths = iter(_get_feat_extract_output_lengths(audio_inputs["feature_attention_mask"].sum(-1)))
|
| 119 |
-
else:
|
| 120 |
-
audio_inputs = {}
|
| 121 |
-
audio_lengths = iter([])
|
| 122 |
-
|
| 123 |
-
if not isinstance(text, list):
|
| 124 |
-
text = [text]
|
| 125 |
-
|
| 126 |
-
text = self.replace_multimodal_special_tokens(
|
| 127 |
-
text,
|
| 128 |
-
audio_lengths,
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
-
texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
| 132 |
-
|
| 133 |
-
return BatchFeature(
|
| 134 |
-
data={**texts_inputs, **audio_inputs},
|
| 135 |
-
tensor_type=kwargs.get("return_tensors"),
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
def replace_multimodal_special_tokens(
|
| 139 |
-
self,
|
| 140 |
-
text,
|
| 141 |
-
audio_lengths,
|
| 142 |
-
):
|
| 143 |
-
|
| 144 |
-
processed_text = []
|
| 145 |
-
for sample in text:
|
| 146 |
-
positions = []
|
| 147 |
-
special_tokens = [re.escape(tok) for tok in [self.audio_token]]
|
| 148 |
-
pattern = "|".join(special_tokens)
|
| 149 |
-
positions = sorted([(match.start(), match.group()) for match in re.finditer(pattern, sample)])
|
| 150 |
-
positions.sort(key=lambda x: x[0])
|
| 151 |
-
|
| 152 |
-
for _, special_token in positions:
|
| 153 |
-
if special_token == self.audio_token:
|
| 154 |
-
sample = sample.replace(self.audio_token, "<|audio_placeholder|>" * next(audio_lengths), 1)
|
| 155 |
-
|
| 156 |
-
sample = sample.replace("<|audio_placeholder|>", self.audio_token)
|
| 157 |
-
processed_text.append(sample)
|
| 158 |
-
return processed_text
|
| 159 |
-
|
| 160 |
-
def get_chunked_index(self, token_indices: np.ndarray, tokens_per_chunk: int) -> list[tuple[int, int]]:
|
| 161 |
-
"""
|
| 162 |
-
Splits token index list into chunks based on token value ranges.
|
| 163 |
-
|
| 164 |
-
Given a list of token indices, returns a list of (start, end) index tuples representing
|
| 165 |
-
slices of the list where the token values fall within successive ranges of `t_ntoken_per_chunk`.
|
| 166 |
-
|
| 167 |
-
For example, if `t_ntoken_per_chunk` is 1000, the function will create chunks such that:
|
| 168 |
-
- the first chunk contains token values < 1000,
|
| 169 |
-
- the second chunk contains values >= 1000 and < 2000, and so on.
|
| 170 |
-
|
| 171 |
-
Parameters:
|
| 172 |
-
token_indices (`np.ndarray`): A monotonically increasing list of token index values.
|
| 173 |
-
t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
|
| 174 |
-
|
| 175 |
-
Returns:
|
| 176 |
-
`list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
|
| 177 |
-
and end (exclusive) indices of a chunk in `token_indices`.
|
| 178 |
-
"""
|
| 179 |
-
|
| 180 |
-
def _iter():
|
| 181 |
-
i, start_idx = 0, 0 # skip bos token
|
| 182 |
-
current_chunk = 1
|
| 183 |
-
while i < len(token_indices): # skip eos token
|
| 184 |
-
if token_indices[i] >= current_chunk * tokens_per_chunk:
|
| 185 |
-
yield (start_idx, i)
|
| 186 |
-
start_idx = i
|
| 187 |
-
current_chunk += 1
|
| 188 |
-
i += 1
|
| 189 |
-
yield (start_idx, len(token_indices))
|
| 190 |
-
|
| 191 |
-
return list(_iter())
|
| 192 |
-
|
| 193 |
-
def apply_chat_template(self, conversations, chat_template=None, **kwargs):
|
| 194 |
-
return super().apply_chat_template(conversations, chat_template, **kwargs)
|
| 195 |
-
|
| 196 |
-
@property
|
| 197 |
-
def model_input_names(self):
|
| 198 |
-
tokenizer_input_names = self.tokenizer.model_input_names
|
| 199 |
-
feature_extractor_input_names = self.feature_extractor.model_input_names
|
| 200 |
-
return list(
|
| 201 |
-
dict.fromkeys(
|
| 202 |
-
tokenizer_input_names
|
| 203 |
-
+ feature_extractor_input_names
|
| 204 |
-
+ ["feature_attention_mask"]
|
| 205 |
-
)
|
| 206 |
-
)
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
__all__ = ["Qwen3ASRProcessor"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_asr/core/vllm_backend/__init__.py
DELETED
|
@@ -1,16 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
from .qwen3_asr import Qwen3ASRForConditionalGeneration
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_asr/core/vllm_backend/qwen3_asr.py
DELETED
|
@@ -1,997 +0,0 @@
|
|
| 1 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
-
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
-
# Copyright 2026 The Qwen team.
|
| 4 |
-
# Copyright 2023 The vLLM team.
|
| 5 |
-
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 6 |
-
#
|
| 7 |
-
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 8 |
-
# and OPT implementations in this library. It has been modified from its
|
| 9 |
-
# original forms to accommodate minor architectural differences compared
|
| 10 |
-
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 11 |
-
#
|
| 12 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 13 |
-
# you may not use this file except in compliance with the License.
|
| 14 |
-
# You may obtain a copy of the License at
|
| 15 |
-
#
|
| 16 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 17 |
-
#
|
| 18 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 19 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 20 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 21 |
-
# See the License for the specific language governing permissions and
|
| 22 |
-
# limitations under the License.
|
| 23 |
-
"""Inference-only Qwen3-ASR model."""
|
| 24 |
-
|
| 25 |
-
from collections.abc import Iterable, Mapping, Sequence
|
| 26 |
-
from typing import Any, Literal, cast
|
| 27 |
-
|
| 28 |
-
import numpy as np
|
| 29 |
-
import torch
|
| 30 |
-
import torch.nn as nn
|
| 31 |
-
import torch.nn.functional as F
|
| 32 |
-
from transformers.feature_extraction_utils import BatchFeature
|
| 33 |
-
from transformers.models.whisper import WhisperFeatureExtractor
|
| 34 |
-
|
| 35 |
-
from vllm.config import MultiModalConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
| 36 |
-
from vllm.config.multimodal import BaseDummyOptions
|
| 37 |
-
from vllm.distributed import get_tensor_model_parallel_world_size
|
| 38 |
-
from vllm.inputs.data import PromptType
|
| 39 |
-
from vllm.logger import init_logger
|
| 40 |
-
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
|
| 41 |
-
from vllm.model_executor.layers.attention.mm_encoder_attention import (
|
| 42 |
-
MMEncoderAttention,
|
| 43 |
-
)
|
| 44 |
-
from vllm.model_executor.layers.linear import (
|
| 45 |
-
ColumnParallelLinear,
|
| 46 |
-
QKVParallelLinear,
|
| 47 |
-
RowParallelLinear,
|
| 48 |
-
)
|
| 49 |
-
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 50 |
-
from vllm.model_executor.models.interfaces import (
|
| 51 |
-
MultiModalEmbeddings,
|
| 52 |
-
SupportsMRoPE,
|
| 53 |
-
SupportsMultiModal,
|
| 54 |
-
SupportsPP,
|
| 55 |
-
SupportsTranscription,
|
| 56 |
-
)
|
| 57 |
-
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
| 58 |
-
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
|
| 59 |
-
from vllm.model_executor.models.qwen3_omni_moe_thinker import (
|
| 60 |
-
Qwen2_5OmniAudioFeatureInputs,
|
| 61 |
-
Qwen3OmniMoeThinkerMultiModalProcessor,
|
| 62 |
-
)
|
| 63 |
-
from vllm.model_executor.models.utils import (
|
| 64 |
-
AutoWeightsLoader,
|
| 65 |
-
WeightsMapper,
|
| 66 |
-
_merge_multimodal_embeddings,
|
| 67 |
-
maybe_prefix,
|
| 68 |
-
)
|
| 69 |
-
from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
|
| 70 |
-
from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 71 |
-
from vllm.multimodal.inputs import (
|
| 72 |
-
AudioItem,
|
| 73 |
-
ModalityData,
|
| 74 |
-
MultiModalDataDict,
|
| 75 |
-
MultiModalFeatureSpec,
|
| 76 |
-
MultiModalFieldConfig,
|
| 77 |
-
MultiModalKwargsItems,
|
| 78 |
-
)
|
| 79 |
-
from vllm.multimodal.parse import (
|
| 80 |
-
AudioProcessorItems,
|
| 81 |
-
DictEmbeddingItems,
|
| 82 |
-
ModalityDataItems,
|
| 83 |
-
MultiModalDataItems,
|
| 84 |
-
MultiModalDataParser,
|
| 85 |
-
)
|
| 86 |
-
from vllm.multimodal.processing import (
|
| 87 |
-
BaseProcessingInfo,
|
| 88 |
-
PromptReplacement,
|
| 89 |
-
PromptUpdate,
|
| 90 |
-
)
|
| 91 |
-
from vllm.sequence import IntermediateTensors
|
| 92 |
-
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
| 93 |
-
from vllm.tokenizers import cached_tokenizer_from_config
|
| 94 |
-
from vllm.transformers_utils.processor import cached_processor_from_config
|
| 95 |
-
from vllm.model_executor.models.vision import (
|
| 96 |
-
get_vit_attn_backend,
|
| 97 |
-
)
|
| 98 |
-
from ..transformers_backend.configuration_qwen3_asr import (
|
| 99 |
-
Qwen3ASRConfig,
|
| 100 |
-
Qwen3ASRThinkerConfig,
|
| 101 |
-
Qwen3ASRAudioEncoderConfig
|
| 102 |
-
)
|
| 103 |
-
from ..transformers_backend.processing_qwen3_asr import (
|
| 104 |
-
Qwen3ASRProcessor,
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
try:
|
| 108 |
-
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
| 109 |
-
except:
|
| 110 |
-
from vllm.multimodal.processing import BaseDummyInputsBuilder
|
| 111 |
-
|
| 112 |
-
logger = init_logger(__name__)
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
|
| 116 |
-
input_lengths_leave = input_lengths % 100
|
| 117 |
-
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
| 118 |
-
output_lengths = (
|
| 119 |
-
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
| 120 |
-
)
|
| 121 |
-
return output_lengths
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
# ============= Audio Encoder Components =============
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
class SinusoidsPositionEmbedding(nn.Module):
|
| 128 |
-
"""Sinusoidal position embedding for audio encoder."""
|
| 129 |
-
|
| 130 |
-
def __init__(self, length: int, channels: int, max_timescale: int = 10000):
|
| 131 |
-
super().__init__()
|
| 132 |
-
self.length = length
|
| 133 |
-
self.channels = channels
|
| 134 |
-
self.max_timescale = max_timescale
|
| 135 |
-
|
| 136 |
-
if channels % 2 != 0:
|
| 137 |
-
raise ValueError("SinusoidsPositionEmbedding needs even channels input")
|
| 138 |
-
|
| 139 |
-
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
| 140 |
-
inv_timescales = torch.exp(
|
| 141 |
-
-log_timescale_increment * torch.arange(channels // 2).float()
|
| 142 |
-
)
|
| 143 |
-
scaled_time = (
|
| 144 |
-
torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
| 145 |
-
)
|
| 146 |
-
positional_embedding = torch.cat(
|
| 147 |
-
[torch.sin(scaled_time), torch.cos(scaled_time)], dim=1
|
| 148 |
-
)
|
| 149 |
-
self.register_buffer(
|
| 150 |
-
"positional_embedding", positional_embedding, persistent=False
|
| 151 |
-
)
|
| 152 |
-
|
| 153 |
-
def forward(self, seqlen: int) -> torch.Tensor:
|
| 154 |
-
return self.positional_embedding[:seqlen, :]
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
class Qwen3ASRAudioAttention(nn.Module):
|
| 158 |
-
"""Multi-headed attention for Qwen3-Omni Audio Encoder using MMEncoderAttention."""
|
| 159 |
-
|
| 160 |
-
def __init__(
|
| 161 |
-
self,
|
| 162 |
-
config: Qwen3ASRAudioEncoderConfig,
|
| 163 |
-
multimodal_config: MultiModalConfig | None = None,
|
| 164 |
-
prefix: str = "",
|
| 165 |
-
):
|
| 166 |
-
super().__init__()
|
| 167 |
-
self.embed_dim = config.d_model
|
| 168 |
-
self.num_heads = config.encoder_attention_heads
|
| 169 |
-
self.head_dim = self.embed_dim // self.num_heads
|
| 170 |
-
tp_size = get_tensor_model_parallel_world_size()
|
| 171 |
-
self.num_local_heads = self.num_heads // tp_size
|
| 172 |
-
|
| 173 |
-
if (self.head_dim * self.num_heads) != self.embed_dim:
|
| 174 |
-
raise ValueError(
|
| 175 |
-
f"embed_dim must be divisible by num_heads (got `embed_dim`: "
|
| 176 |
-
f"{self.embed_dim} and `num_heads`: {self.num_heads})."
|
| 177 |
-
)
|
| 178 |
-
|
| 179 |
-
self.scaling = self.head_dim**-0.5
|
| 180 |
-
|
| 181 |
-
self.qkv = QKVParallelLinear(
|
| 182 |
-
hidden_size=self.embed_dim,
|
| 183 |
-
head_size=self.head_dim,
|
| 184 |
-
total_num_heads=self.num_heads,
|
| 185 |
-
total_num_kv_heads=self.num_heads,
|
| 186 |
-
bias=True,
|
| 187 |
-
prefix=f"{prefix}.qkv",
|
| 188 |
-
)
|
| 189 |
-
|
| 190 |
-
self.out_proj = RowParallelLinear(
|
| 191 |
-
input_size=self.embed_dim,
|
| 192 |
-
output_size=self.embed_dim,
|
| 193 |
-
bias=True,
|
| 194 |
-
prefix=f"{prefix}.out_proj",
|
| 195 |
-
)
|
| 196 |
-
|
| 197 |
-
self.attn = MMEncoderAttention(
|
| 198 |
-
num_heads=self.num_local_heads,
|
| 199 |
-
head_size=self.head_dim,
|
| 200 |
-
scale=self.scaling,
|
| 201 |
-
multimodal_config=multimodal_config,
|
| 202 |
-
)
|
| 203 |
-
|
| 204 |
-
def forward(
|
| 205 |
-
self,
|
| 206 |
-
hidden_states: torch.Tensor,
|
| 207 |
-
cu_seqlens: torch.Tensor,
|
| 208 |
-
max_seqlen: torch.Tensor | None,
|
| 209 |
-
) -> torch.Tensor:
|
| 210 |
-
seq_length, _ = hidden_states.size()
|
| 211 |
-
qkv, _ = self.qkv(hidden_states)
|
| 212 |
-
q, k, v = qkv.chunk(3, dim=-1)
|
| 213 |
-
q = q.view(1, seq_length, -1, self.head_dim)
|
| 214 |
-
k = k.view(1, seq_length, -1, self.head_dim)
|
| 215 |
-
v = v.view(1, seq_length, -1, self.head_dim)
|
| 216 |
-
|
| 217 |
-
attn_output = self.attn(
|
| 218 |
-
query=q,
|
| 219 |
-
key=k,
|
| 220 |
-
value=v,
|
| 221 |
-
cu_seqlens=cu_seqlens,
|
| 222 |
-
max_seqlen=max_seqlen,
|
| 223 |
-
)
|
| 224 |
-
|
| 225 |
-
attn_output = attn_output.view(seq_length, -1)
|
| 226 |
-
output, _ = self.out_proj(attn_output)
|
| 227 |
-
return output
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
class Qwen3ASRAudioEncoderLayer(nn.Module):
|
| 231 |
-
"""Transformer encoder layer for Qwen3-Omni Audio Encoder."""
|
| 232 |
-
|
| 233 |
-
def __init__(
|
| 234 |
-
self,
|
| 235 |
-
config: Qwen3ASRAudioEncoderConfig,
|
| 236 |
-
multimodal_config: MultiModalConfig | None = None,
|
| 237 |
-
prefix: str = "",
|
| 238 |
-
):
|
| 239 |
-
super().__init__()
|
| 240 |
-
self.embed_dim = config.d_model
|
| 241 |
-
self.self_attn = Qwen3ASRAudioAttention(
|
| 242 |
-
config, multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn"
|
| 243 |
-
)
|
| 244 |
-
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
| 245 |
-
self.activation_fn = _ACTIVATION_REGISTRY[config.activation_function]
|
| 246 |
-
self.fc1 = ColumnParallelLinear(
|
| 247 |
-
self.embed_dim,
|
| 248 |
-
config.encoder_ffn_dim,
|
| 249 |
-
bias=True,
|
| 250 |
-
prefix=f"{prefix}.fc1",
|
| 251 |
-
)
|
| 252 |
-
self.fc2 = RowParallelLinear(
|
| 253 |
-
config.encoder_ffn_dim,
|
| 254 |
-
self.embed_dim,
|
| 255 |
-
bias=True,
|
| 256 |
-
prefix=f"{prefix}.fc2",
|
| 257 |
-
)
|
| 258 |
-
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
| 259 |
-
|
| 260 |
-
def forward(
|
| 261 |
-
self,
|
| 262 |
-
hidden_states: torch.Tensor,
|
| 263 |
-
cu_seqlens: torch.Tensor,
|
| 264 |
-
max_seqlen: torch.Tensor | None,
|
| 265 |
-
) -> torch.Tensor:
|
| 266 |
-
"""
|
| 267 |
-
Args:
|
| 268 |
-
hidden_states: Input tensor of shape (seq_len, hidden_size)
|
| 269 |
-
cu_seqlens: Cumulative sequence lengths
|
| 270 |
-
max_seqlen: Maximum sequence length in the batch
|
| 271 |
-
"""
|
| 272 |
-
residual = hidden_states
|
| 273 |
-
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 274 |
-
hidden_states = self.self_attn(
|
| 275 |
-
hidden_states=hidden_states,
|
| 276 |
-
cu_seqlens=cu_seqlens,
|
| 277 |
-
max_seqlen=max_seqlen,
|
| 278 |
-
)
|
| 279 |
-
hidden_states = residual + hidden_states
|
| 280 |
-
|
| 281 |
-
residual = hidden_states
|
| 282 |
-
hidden_states = self.final_layer_norm(hidden_states)
|
| 283 |
-
hidden_states, _ = self.fc1(hidden_states)
|
| 284 |
-
hidden_states = self.activation_fn(hidden_states)
|
| 285 |
-
hidden_states, _ = self.fc2(hidden_states)
|
| 286 |
-
hidden_states = residual + hidden_states
|
| 287 |
-
|
| 288 |
-
# Clamp for numerical stability with fp16
|
| 289 |
-
if hidden_states.dtype == torch.float16:
|
| 290 |
-
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
| 291 |
-
hidden_states = torch.clamp(
|
| 292 |
-
hidden_states, min=-clamp_value, max=clamp_value
|
| 293 |
-
)
|
| 294 |
-
|
| 295 |
-
return hidden_states
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
class Qwen3ASRAudioEncoder(nn.Module):
|
| 299 |
-
"""vLLM-native Qwen3-ASR Audio Encoder."""
|
| 300 |
-
|
| 301 |
-
def __init__(
|
| 302 |
-
self,
|
| 303 |
-
config: Qwen3ASRAudioEncoderConfig,
|
| 304 |
-
multimodal_config: MultiModalConfig | None = None,
|
| 305 |
-
prefix: str = "",
|
| 306 |
-
):
|
| 307 |
-
super().__init__()
|
| 308 |
-
|
| 309 |
-
embed_dim = config.d_model
|
| 310 |
-
self.num_mel_bins = config.num_mel_bins
|
| 311 |
-
self.max_source_positions = config.max_source_positions
|
| 312 |
-
self.n_window = config.n_window
|
| 313 |
-
self.n_window_infer = config.n_window_infer
|
| 314 |
-
self.conv_chunksize = config.conv_chunksize
|
| 315 |
-
|
| 316 |
-
# Position embedding
|
| 317 |
-
self.positional_embedding = SinusoidsPositionEmbedding(
|
| 318 |
-
self.max_source_positions, embed_dim
|
| 319 |
-
)
|
| 320 |
-
|
| 321 |
-
# Convolutional layers for mel-spectrogram processing
|
| 322 |
-
self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
|
| 323 |
-
self.conv2d2 = nn.Conv2d(
|
| 324 |
-
config.downsample_hidden_size,
|
| 325 |
-
config.downsample_hidden_size,
|
| 326 |
-
3,
|
| 327 |
-
2,
|
| 328 |
-
padding=1,
|
| 329 |
-
)
|
| 330 |
-
self.conv2d3 = nn.Conv2d(
|
| 331 |
-
config.downsample_hidden_size,
|
| 332 |
-
config.downsample_hidden_size,
|
| 333 |
-
3,
|
| 334 |
-
2,
|
| 335 |
-
padding=1,
|
| 336 |
-
)
|
| 337 |
-
|
| 338 |
-
conv_out_dim = config.downsample_hidden_size * (
|
| 339 |
-
(((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2
|
| 340 |
-
)
|
| 341 |
-
self.conv_out = nn.Linear(conv_out_dim, config.d_model, bias=False)
|
| 342 |
-
|
| 343 |
-
# Transformer encoder layers
|
| 344 |
-
self.layers = nn.ModuleList(
|
| 345 |
-
[
|
| 346 |
-
Qwen3ASRAudioEncoderLayer(
|
| 347 |
-
config,
|
| 348 |
-
multimodal_config=multimodal_config,
|
| 349 |
-
prefix=f"{prefix}.layers.{i}",
|
| 350 |
-
)
|
| 351 |
-
for i in range(config.encoder_layers)
|
| 352 |
-
]
|
| 353 |
-
)
|
| 354 |
-
|
| 355 |
-
# Output layers
|
| 356 |
-
self.ln_post = nn.LayerNorm(config.d_model)
|
| 357 |
-
self.proj1 = nn.Linear(config.d_model, config.d_model)
|
| 358 |
-
self.act = _ACTIVATION_REGISTRY[config.activation_function]
|
| 359 |
-
self.proj2 = nn.Linear(config.d_model, config.output_dim)
|
| 360 |
-
|
| 361 |
-
# Get attention backend
|
| 362 |
-
attn_backend_override = (
|
| 363 |
-
multimodal_config.mm_encoder_attn_backend
|
| 364 |
-
if multimodal_config is not None
|
| 365 |
-
else None
|
| 366 |
-
)
|
| 367 |
-
self.attn_backend = get_vit_attn_backend(
|
| 368 |
-
head_size=config.d_model // config.encoder_attention_heads,
|
| 369 |
-
dtype=torch.get_default_dtype(),
|
| 370 |
-
attn_backend_override=attn_backend_override,
|
| 371 |
-
)
|
| 372 |
-
|
| 373 |
-
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
|
| 374 |
-
"""Compute max_seqlen only for flash attention backends."""
|
| 375 |
-
max_seqlen = None
|
| 376 |
-
if self.attn_backend in {
|
| 377 |
-
AttentionBackendEnum.FLASH_ATTN,
|
| 378 |
-
AttentionBackendEnum.ROCM_AITER_FA,
|
| 379 |
-
}:
|
| 380 |
-
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
| 381 |
-
return max_seqlen
|
| 382 |
-
|
| 383 |
-
@property
|
| 384 |
-
def dtype(self) -> torch.dtype:
|
| 385 |
-
return self.conv2d1.weight.dtype
|
| 386 |
-
|
| 387 |
-
@property
|
| 388 |
-
def device(self) -> torch.device:
|
| 389 |
-
return self.conv2d1.weight.device
|
| 390 |
-
|
| 391 |
-
def forward(
|
| 392 |
-
self,
|
| 393 |
-
input_features: torch.Tensor,
|
| 394 |
-
feature_lens: torch.Tensor,
|
| 395 |
-
aftercnn_lens: torch.Tensor,
|
| 396 |
-
):
|
| 397 |
-
# Compute chunk information
|
| 398 |
-
chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
|
| 399 |
-
|
| 400 |
-
chunk_lengths = torch.tensor(
|
| 401 |
-
[self.n_window * 2] * chunk_num.sum(),
|
| 402 |
-
dtype=torch.long,
|
| 403 |
-
device=feature_lens.device,
|
| 404 |
-
)
|
| 405 |
-
tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
|
| 406 |
-
chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
|
| 407 |
-
chunk_lengths[chunk_lengths == 0] = self.n_window * 2
|
| 408 |
-
|
| 409 |
-
# Split input features into chunks and pad
|
| 410 |
-
chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
|
| 411 |
-
padded_feature = nn.utils.rnn.pad_sequence(
|
| 412 |
-
chunk_list, batch_first=True
|
| 413 |
-
).transpose(1, 2)
|
| 414 |
-
|
| 415 |
-
# Compute feature lengths after CNN
|
| 416 |
-
feature_lens_after_cnn = self._get_cnn_output_lengths(chunk_lengths)
|
| 417 |
-
# Vectorized mask creation: avoid creating many small tensors
|
| 418 |
-
max_len_after_cnn = feature_lens_after_cnn.max().item()
|
| 419 |
-
indices = torch.arange(max_len_after_cnn, device=padded_feature.device)
|
| 420 |
-
padded_mask_after_cnn = indices.unsqueeze(0) < feature_lens_after_cnn.unsqueeze(
|
| 421 |
-
1
|
| 422 |
-
)
|
| 423 |
-
|
| 424 |
-
# Add channel dimension for conv2d
|
| 425 |
-
padded_feature = padded_feature.unsqueeze(1)
|
| 426 |
-
|
| 427 |
-
# Apply convolutional layers (chunk if needed to avoid OOM)
|
| 428 |
-
if padded_feature.size(0) <= self.conv_chunksize:
|
| 429 |
-
# Fast path: no chunking needed
|
| 430 |
-
padded_embed = F.gelu(self.conv2d1(padded_feature))
|
| 431 |
-
padded_embed = F.gelu(self.conv2d2(padded_embed))
|
| 432 |
-
padded_embed = F.gelu(self.conv2d3(padded_embed))
|
| 433 |
-
else:
|
| 434 |
-
# Chunked processing to avoid OOM
|
| 435 |
-
padded_embeds = []
|
| 436 |
-
for chunk in padded_feature.split(self.conv_chunksize, dim=0):
|
| 437 |
-
padded_embed = F.gelu(self.conv2d1(chunk))
|
| 438 |
-
padded_embed = F.gelu(self.conv2d2(padded_embed))
|
| 439 |
-
padded_embed = F.gelu(self.conv2d3(padded_embed))
|
| 440 |
-
padded_embeds.append(padded_embed)
|
| 441 |
-
padded_embed = torch.cat(padded_embeds, dim=0)
|
| 442 |
-
|
| 443 |
-
# (batch, channels, freq, time) -> (batch, time, channels*freq)
|
| 444 |
-
b, c, f, t = padded_embed.size()
|
| 445 |
-
padded_embed = self.conv_out(
|
| 446 |
-
padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
|
| 447 |
-
)
|
| 448 |
-
|
| 449 |
-
# Add positional embedding
|
| 450 |
-
positional_embedding = (
|
| 451 |
-
self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
|
| 452 |
-
.unsqueeze(0)
|
| 453 |
-
.to(padded_embed.dtype)
|
| 454 |
-
)
|
| 455 |
-
padded_embed = padded_embed + positional_embedding
|
| 456 |
-
|
| 457 |
-
# Extract valid hidden states and compute cu_seqlens
|
| 458 |
-
hidden_states = padded_embed[padded_mask_after_cnn]
|
| 459 |
-
|
| 460 |
-
# Compute cumulative sequence lengths for chunked attention
|
| 461 |
-
cu_chunk_lens = [0]
|
| 462 |
-
window_aftercnn = padded_mask_after_cnn.shape[-1] * (
|
| 463 |
-
self.n_window_infer // (self.n_window * 2)
|
| 464 |
-
)
|
| 465 |
-
# Use tolist() for efficient batch conversion from tensor to Python
|
| 466 |
-
for cnn_len in aftercnn_lens.tolist():
|
| 467 |
-
num_full_chunks = cnn_len // window_aftercnn
|
| 468 |
-
remainder = cnn_len % window_aftercnn
|
| 469 |
-
cu_chunk_lens.extend([window_aftercnn] * num_full_chunks)
|
| 470 |
-
if remainder:
|
| 471 |
-
cu_chunk_lens.append(remainder)
|
| 472 |
-
cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(
|
| 473 |
-
-1, dtype=torch.int32
|
| 474 |
-
)
|
| 475 |
-
|
| 476 |
-
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
|
| 477 |
-
|
| 478 |
-
# Apply transformer layers
|
| 479 |
-
for encoder_layer in self.layers:
|
| 480 |
-
hidden_states = encoder_layer(
|
| 481 |
-
hidden_states,
|
| 482 |
-
cu_seqlens,
|
| 483 |
-
max_seqlen,
|
| 484 |
-
)
|
| 485 |
-
|
| 486 |
-
# Apply output layers
|
| 487 |
-
hidden_states = self.ln_post(hidden_states)
|
| 488 |
-
hidden_states = self.proj1(hidden_states)
|
| 489 |
-
hidden_states = self.act(hidden_states)
|
| 490 |
-
hidden_states = self.proj2(hidden_states)
|
| 491 |
-
|
| 492 |
-
return hidden_states
|
| 493 |
-
|
| 494 |
-
def _get_cnn_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
|
| 495 |
-
"""Compute output lengths after the three conv2d layers."""
|
| 496 |
-
lengths = input_lengths
|
| 497 |
-
for _ in range(3):
|
| 498 |
-
lengths = (lengths - 1) // 2 + 1
|
| 499 |
-
return lengths
|
| 500 |
-
|
| 501 |
-
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
| 502 |
-
"""Load weights with mapping from HuggingFace format."""
|
| 503 |
-
stacked_params_mapping = [
|
| 504 |
-
# (param_name, shard_name, shard_id)
|
| 505 |
-
("self_attn.qkv.", "self_attn.q_proj.", "q"),
|
| 506 |
-
("self_attn.qkv.", "self_attn.k_proj.", "k"),
|
| 507 |
-
("self_attn.qkv.", "self_attn.v_proj.", "v"),
|
| 508 |
-
]
|
| 509 |
-
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
| 510 |
-
loaded_params: set[str] = set()
|
| 511 |
-
|
| 512 |
-
for name, loaded_weight in weights:
|
| 513 |
-
for param_name, weight_name, shard_id in stacked_params_mapping:
|
| 514 |
-
if weight_name not in name:
|
| 515 |
-
continue
|
| 516 |
-
name = name.replace(weight_name, param_name)
|
| 517 |
-
|
| 518 |
-
param = params_dict[name]
|
| 519 |
-
weight_loader = param.weight_loader
|
| 520 |
-
weight_loader(param, loaded_weight, shard_id)
|
| 521 |
-
break
|
| 522 |
-
else:
|
| 523 |
-
param = params_dict.get(name)
|
| 524 |
-
if param is not None:
|
| 525 |
-
weight_loader = getattr(
|
| 526 |
-
param, "weight_loader", default_weight_loader
|
| 527 |
-
)
|
| 528 |
-
weight_loader(param, loaded_weight)
|
| 529 |
-
loaded_params.add(name)
|
| 530 |
-
return loaded_params
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
class Qwen3ASRProcessingInfo(BaseProcessingInfo):
|
| 534 |
-
def get_hf_config(self):
|
| 535 |
-
return self.ctx.get_hf_config(Qwen3ASRConfig).thinker_config
|
| 536 |
-
|
| 537 |
-
def get_hf_processor(self, **kwargs: object) -> Qwen3ASRProcessor:
|
| 538 |
-
processor = self.ctx.get_hf_processor(
|
| 539 |
-
Qwen3ASRProcessor,
|
| 540 |
-
use_fast=kwargs.pop("use_fast", True),
|
| 541 |
-
**kwargs,
|
| 542 |
-
)
|
| 543 |
-
if not hasattr(processor, "audio_token"):
|
| 544 |
-
processor.audio_token = "<|audio_pad|>"
|
| 545 |
-
return processor
|
| 546 |
-
|
| 547 |
-
def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
|
| 548 |
-
hf_processor = self.get_hf_processor(**kwargs)
|
| 549 |
-
feature_extractor = hf_processor.feature_extractor
|
| 550 |
-
assert isinstance(feature_extractor, WhisperFeatureExtractor)
|
| 551 |
-
return feature_extractor
|
| 552 |
-
|
| 553 |
-
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
| 554 |
-
return {"audio": None}
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
class Qwen3ASRDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3ASRProcessingInfo]):
|
| 558 |
-
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
| 559 |
-
num_audios = mm_counts.get("audio", 0)
|
| 560 |
-
|
| 561 |
-
hf_processor = self.info.get_hf_processor()
|
| 562 |
-
audio_token = hf_processor.audio_token
|
| 563 |
-
|
| 564 |
-
return audio_token * num_audios
|
| 565 |
-
|
| 566 |
-
def get_dummy_mm_data(
|
| 567 |
-
self,
|
| 568 |
-
seq_len: int,
|
| 569 |
-
mm_counts: Mapping[str, int],
|
| 570 |
-
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
| 571 |
-
) -> MultiModalDataDict:
|
| 572 |
-
num_audios = mm_counts.get("audio", 0)
|
| 573 |
-
|
| 574 |
-
feature_extractor = self.info.get_feature_extractor()
|
| 575 |
-
|
| 576 |
-
target_audio_length = (
|
| 577 |
-
min(
|
| 578 |
-
feature_extractor.chunk_length,
|
| 579 |
-
30,
|
| 580 |
-
)
|
| 581 |
-
* feature_extractor.sampling_rate
|
| 582 |
-
)
|
| 583 |
-
|
| 584 |
-
audio_overrides = mm_options.get("audio") if mm_options else None
|
| 585 |
-
|
| 586 |
-
return {
|
| 587 |
-
"audio": self._get_dummy_audios(
|
| 588 |
-
length=target_audio_length,
|
| 589 |
-
num_audios=num_audios,
|
| 590 |
-
overrides=audio_overrides,
|
| 591 |
-
),
|
| 592 |
-
}
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
def _qwen3asr_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
| 596 |
-
audio_feature_lengths = hf_inputs.get("audio_feature_lengths", torch.empty((0,)))
|
| 597 |
-
return dict(
|
| 598 |
-
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
|
| 599 |
-
"audio", audio_feature_lengths, dim=1
|
| 600 |
-
),
|
| 601 |
-
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
|
| 602 |
-
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
|
| 603 |
-
)
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
class Qwen3ASRMultiModalDataParser(MultiModalDataParser):
|
| 607 |
-
def _parse_audio_data(
|
| 608 |
-
self,
|
| 609 |
-
data: dict[str, torch.Tensor] | ModalityData[AudioItem],
|
| 610 |
-
) -> ModalityDataItems[Any, Any] | None:
|
| 611 |
-
if isinstance(data, dict):
|
| 612 |
-
return DictEmbeddingItems(
|
| 613 |
-
data,
|
| 614 |
-
modality="audio",
|
| 615 |
-
required_fields={"input_audio_features", "audio_feature_lengths"},
|
| 616 |
-
fields_factory=_qwen3asr_field_config,
|
| 617 |
-
)
|
| 618 |
-
|
| 619 |
-
return super()._parse_audio_data(data)
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
class Qwen3ASRMultiModalProcessor(
|
| 623 |
-
Qwen3OmniMoeThinkerMultiModalProcessor,
|
| 624 |
-
):
|
| 625 |
-
def _get_data_parser(self) -> MultiModalDataParser:
|
| 626 |
-
feature_extractor = self.info.get_feature_extractor()
|
| 627 |
-
return Qwen3ASRMultiModalDataParser(
|
| 628 |
-
target_sr=feature_extractor.sampling_rate,
|
| 629 |
-
)
|
| 630 |
-
|
| 631 |
-
def _get_mm_fields_config(
|
| 632 |
-
self,
|
| 633 |
-
hf_inputs: BatchFeature,
|
| 634 |
-
hf_processor_mm_kwargs: Mapping[str, object],
|
| 635 |
-
) -> Mapping[str, MultiModalFieldConfig]:
|
| 636 |
-
return _qwen3asr_field_config(hf_inputs)
|
| 637 |
-
|
| 638 |
-
def _get_prompt_updates(
|
| 639 |
-
self,
|
| 640 |
-
mm_items: MultiModalDataItems,
|
| 641 |
-
hf_processor_mm_kwargs: Mapping[str, Any],
|
| 642 |
-
out_mm_kwargs: MultiModalKwargsItems,
|
| 643 |
-
) -> Sequence[PromptUpdate]:
|
| 644 |
-
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
| 645 |
-
tokenizer = self.info.get_tokenizer()
|
| 646 |
-
vocab = tokenizer.get_vocab()
|
| 647 |
-
|
| 648 |
-
audio_token = processor.audio_token
|
| 649 |
-
audio_token_id = vocab[audio_token]
|
| 650 |
-
|
| 651 |
-
out_mm_data = out_mm_kwargs.get_data()
|
| 652 |
-
audio_feature_lengths = out_mm_data.get("audio_feature_lengths")
|
| 653 |
-
feature_attention_mask = out_mm_data.get("feature_attention_mask")
|
| 654 |
-
if audio_feature_lengths is None and feature_attention_mask is None:
|
| 655 |
-
audio_output_lengths = []
|
| 656 |
-
elif audio_feature_lengths is not None:
|
| 657 |
-
audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths)
|
| 658 |
-
audio_output_lengths = audio_output_lens.tolist()
|
| 659 |
-
elif feature_attention_mask is not None:
|
| 660 |
-
assert isinstance(feature_attention_mask, torch.Tensor)
|
| 661 |
-
audio_output_lens = _get_feat_extract_output_lengths(
|
| 662 |
-
feature_attention_mask.sum(-1)
|
| 663 |
-
)
|
| 664 |
-
audio_output_lengths = audio_output_lens.tolist()
|
| 665 |
-
|
| 666 |
-
def get_replacement_qwen2_audio(item_idx: int):
|
| 667 |
-
num_features = audio_output_lengths[item_idx]
|
| 668 |
-
if num_features == 0:
|
| 669 |
-
audios = mm_items.get_items("audio", AudioProcessorItems)
|
| 670 |
-
audio = audios.get(item_idx)
|
| 671 |
-
raise ValueError(
|
| 672 |
-
f"The audio {audio} (len={len(audio)}) is too short "
|
| 673 |
-
"to be represented inside the model"
|
| 674 |
-
)
|
| 675 |
-
|
| 676 |
-
return [audio_token_id] * num_features
|
| 677 |
-
|
| 678 |
-
return [
|
| 679 |
-
PromptReplacement(
|
| 680 |
-
modality="audio",
|
| 681 |
-
target=audio_token,
|
| 682 |
-
replacement=get_replacement_qwen2_audio,
|
| 683 |
-
),
|
| 684 |
-
]
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
@MULTIMODAL_REGISTRY.register_processor(
|
| 688 |
-
Qwen3ASRMultiModalProcessor,
|
| 689 |
-
info=Qwen3ASRProcessingInfo,
|
| 690 |
-
dummy_inputs=Qwen3ASRDummyInputsBuilder,
|
| 691 |
-
)
|
| 692 |
-
class Qwen3ASRForConditionalGeneration(
|
| 693 |
-
nn.Module,
|
| 694 |
-
SupportsMultiModal,
|
| 695 |
-
SupportsPP,
|
| 696 |
-
SupportsMRoPE,
|
| 697 |
-
SupportsTranscription,
|
| 698 |
-
):
|
| 699 |
-
supported_languages = ISO639_1_SUPPORTED_LANGS
|
| 700 |
-
|
| 701 |
-
hf_to_vllm_mapper = WeightsMapper(
|
| 702 |
-
orig_to_new_prefix={
|
| 703 |
-
"thinker.lm_head.": "language_model.lm_head.",
|
| 704 |
-
"thinker.model.": "language_model.model.",
|
| 705 |
-
"thinker.": "",
|
| 706 |
-
}
|
| 707 |
-
)
|
| 708 |
-
|
| 709 |
-
@classmethod
|
| 710 |
-
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
| 711 |
-
if modality.startswith("audio"):
|
| 712 |
-
return "<|audio_start|><|audio_pad|><|audio_end|>"
|
| 713 |
-
|
| 714 |
-
raise ValueError("Only audio modality is supported")
|
| 715 |
-
|
| 716 |
-
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 717 |
-
super().__init__()
|
| 718 |
-
self.vllm_config = vllm_config # needed for torch compile forward context
|
| 719 |
-
thinker_config: Qwen3ASRThinkerConfig = (
|
| 720 |
-
vllm_config.model_config.hf_config.thinker_config
|
| 721 |
-
)
|
| 722 |
-
quant_config = vllm_config.quant_config
|
| 723 |
-
multimodal_config = vllm_config.model_config.multimodal_config
|
| 724 |
-
self.config = thinker_config
|
| 725 |
-
self.multimodal_config = multimodal_config
|
| 726 |
-
|
| 727 |
-
self.audio_tower = Qwen3ASRAudioEncoder(
|
| 728 |
-
thinker_config.audio_config,
|
| 729 |
-
multimodal_config=multimodal_config,
|
| 730 |
-
prefix=maybe_prefix(prefix, "audio_tower"),
|
| 731 |
-
)
|
| 732 |
-
self.quant_config = quant_config
|
| 733 |
-
|
| 734 |
-
self.language_model = Qwen3ForCausalLM(
|
| 735 |
-
vllm_config=vllm_config.with_hf_config(
|
| 736 |
-
thinker_config.text_config, architectures=["Qwen3ForCausalLM"]
|
| 737 |
-
),
|
| 738 |
-
prefix=maybe_prefix(prefix, "language_model"),
|
| 739 |
-
)
|
| 740 |
-
|
| 741 |
-
self.make_empty_intermediate_tensors = (
|
| 742 |
-
self.language_model.make_empty_intermediate_tensors
|
| 743 |
-
)
|
| 744 |
-
|
| 745 |
-
def _parse_and_validate_audio_input(
|
| 746 |
-
self, **kwargs: object
|
| 747 |
-
) -> Qwen2_5OmniAudioFeatureInputs | None:
|
| 748 |
-
input_audio_features = kwargs.pop("input_audio_features", None)
|
| 749 |
-
audio_feature_lengths = kwargs.pop("audio_feature_lengths", None)
|
| 750 |
-
feature_attention_mask = kwargs.pop("feature_attention_mask", None)
|
| 751 |
-
if input_audio_features is None:
|
| 752 |
-
return None
|
| 753 |
-
|
| 754 |
-
return Qwen2_5OmniAudioFeatureInputs(
|
| 755 |
-
type="audio_features",
|
| 756 |
-
input_features=input_audio_features,
|
| 757 |
-
audio_feature_lengths=audio_feature_lengths,
|
| 758 |
-
feature_attention_mask=feature_attention_mask,
|
| 759 |
-
)
|
| 760 |
-
|
| 761 |
-
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
| 762 |
-
mm_input_by_modality = {}
|
| 763 |
-
|
| 764 |
-
# Preserve the order of modalities if there are multiple of them
|
| 765 |
-
# from the order of kwargs.
|
| 766 |
-
for input_key in kwargs:
|
| 767 |
-
if (
|
| 768 |
-
input_key in ("input_audio_features")
|
| 769 |
-
and "audio" not in mm_input_by_modality
|
| 770 |
-
):
|
| 771 |
-
mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
|
| 772 |
-
**kwargs
|
| 773 |
-
)
|
| 774 |
-
return mm_input_by_modality
|
| 775 |
-
|
| 776 |
-
def _process_audio_input(
|
| 777 |
-
self,
|
| 778 |
-
audio_input: Qwen2_5OmniAudioFeatureInputs,
|
| 779 |
-
audio_hashes: list[str] | None = None,
|
| 780 |
-
cached_audio_features: torch.Tensor | None = None,
|
| 781 |
-
) -> torch.Tensor:
|
| 782 |
-
input_features = audio_input["input_features"]
|
| 783 |
-
audio_feature_lengths = audio_input["audio_feature_lengths"]
|
| 784 |
-
|
| 785 |
-
audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths)
|
| 786 |
-
|
| 787 |
-
audio_features = self.audio_tower(
|
| 788 |
-
input_features.to(self.audio_tower.dtype),
|
| 789 |
-
feature_lens=audio_feature_lengths,
|
| 790 |
-
aftercnn_lens=audio_output_lengths,
|
| 791 |
-
)
|
| 792 |
-
return audio_features.split(audio_output_lengths.tolist())
|
| 793 |
-
|
| 794 |
-
def get_language_model(self) -> torch.nn.Module:
|
| 795 |
-
return self.language_model
|
| 796 |
-
|
| 797 |
-
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
|
| 798 |
-
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
|
| 799 |
-
if not mm_input_by_modality:
|
| 800 |
-
return []
|
| 801 |
-
|
| 802 |
-
# The result multimodal_embeddings is tuple of tensors, with each
|
| 803 |
-
# tensor correspoending to a multimodal data item (image or video).
|
| 804 |
-
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
| 805 |
-
|
| 806 |
-
# NOTE: It is important to iterate over the keys in this dictionary
|
| 807 |
-
# to preserve the order of the modalities.
|
| 808 |
-
for modality in mm_input_by_modality:
|
| 809 |
-
multimodal_input = mm_input_by_modality[modality]
|
| 810 |
-
if modality == "audio":
|
| 811 |
-
audio_embeddings = self._process_audio_input(multimodal_input)
|
| 812 |
-
multimodal_embeddings += tuple(audio_embeddings)
|
| 813 |
-
return multimodal_embeddings
|
| 814 |
-
|
| 815 |
-
def embed_input_ids(
|
| 816 |
-
self,
|
| 817 |
-
input_ids: torch.Tensor,
|
| 818 |
-
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
| 819 |
-
*,
|
| 820 |
-
is_multimodal: torch.Tensor | None = None,
|
| 821 |
-
handle_oov_mm_token: bool = False,
|
| 822 |
-
) -> torch.Tensor:
|
| 823 |
-
inputs_embeds = self._embed_text_input_ids(
|
| 824 |
-
input_ids,
|
| 825 |
-
self.language_model.embed_input_ids,
|
| 826 |
-
is_multimodal=is_multimodal,
|
| 827 |
-
handle_oov_mm_token=handle_oov_mm_token,
|
| 828 |
-
)
|
| 829 |
-
|
| 830 |
-
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
| 831 |
-
return inputs_embeds
|
| 832 |
-
|
| 833 |
-
inputs_embeds = _merge_multimodal_embeddings(
|
| 834 |
-
inputs_embeds=inputs_embeds,
|
| 835 |
-
multimodal_embeddings=multimodal_embeddings,
|
| 836 |
-
is_multimodal=is_multimodal,
|
| 837 |
-
)
|
| 838 |
-
|
| 839 |
-
return inputs_embeds
|
| 840 |
-
|
| 841 |
-
def forward(
|
| 842 |
-
self,
|
| 843 |
-
input_ids: torch.Tensor,
|
| 844 |
-
positions: torch.Tensor,
|
| 845 |
-
intermediate_tensors: IntermediateTensors | None = None,
|
| 846 |
-
inputs_embeds: torch.Tensor | None = None,
|
| 847 |
-
**kwargs: object,
|
| 848 |
-
) -> torch.Tensor | IntermediateTensors:
|
| 849 |
-
if intermediate_tensors is not None:
|
| 850 |
-
inputs_embeds = None
|
| 851 |
-
|
| 852 |
-
hidden_states = self.language_model.model(
|
| 853 |
-
input_ids,
|
| 854 |
-
positions,
|
| 855 |
-
intermediate_tensors,
|
| 856 |
-
inputs_embeds=inputs_embeds,
|
| 857 |
-
)
|
| 858 |
-
|
| 859 |
-
return hidden_states
|
| 860 |
-
|
| 861 |
-
def compute_logits(
|
| 862 |
-
self,
|
| 863 |
-
hidden_states: torch.Tensor,
|
| 864 |
-
) -> torch.Tensor | None:
|
| 865 |
-
return self.language_model.compute_logits(hidden_states)
|
| 866 |
-
|
| 867 |
-
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
| 868 |
-
loader = AutoWeightsLoader(
|
| 869 |
-
self,
|
| 870 |
-
skip_prefixes=["talker.", "code2wav."],
|
| 871 |
-
)
|
| 872 |
-
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
| 873 |
-
|
| 874 |
-
return loaded_weights
|
| 875 |
-
|
| 876 |
-
def get_mrope_input_positions(
|
| 877 |
-
self,
|
| 878 |
-
input_tokens: list[int],
|
| 879 |
-
mm_features: list[MultiModalFeatureSpec],
|
| 880 |
-
) -> tuple[torch.Tensor, int]:
|
| 881 |
-
seq_len = len(input_tokens)
|
| 882 |
-
|
| 883 |
-
if not mm_features:
|
| 884 |
-
# No audio features, just return linear positions
|
| 885 |
-
llm_positions = (
|
| 886 |
-
torch.arange(seq_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
| 887 |
-
)
|
| 888 |
-
return llm_positions.clone(), 0
|
| 889 |
-
|
| 890 |
-
llm_pos_ids_list: list[torch.Tensor] = []
|
| 891 |
-
st = 0
|
| 892 |
-
|
| 893 |
-
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
|
| 894 |
-
offset = mm_feature.mm_position.offset
|
| 895 |
-
|
| 896 |
-
# Get audio feature length from mm_feature data
|
| 897 |
-
audio_feature_length = mm_feature.data["audio_feature_lengths"].data
|
| 898 |
-
if isinstance(audio_feature_length, torch.Tensor):
|
| 899 |
-
audio_feature_length = audio_feature_length.item()
|
| 900 |
-
audio_len = _get_feat_extract_output_lengths(
|
| 901 |
-
torch.tensor(audio_feature_length)
|
| 902 |
-
).item()
|
| 903 |
-
|
| 904 |
-
# Text segment before audio (includes audio_start token)
|
| 905 |
-
text_len = offset - st
|
| 906 |
-
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
| 907 |
-
text_positions = (
|
| 908 |
-
torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
| 909 |
-
+ st_idx
|
| 910 |
-
)
|
| 911 |
-
llm_pos_ids_list.append(text_positions)
|
| 912 |
-
st_idx = st_idx + text_len
|
| 913 |
-
|
| 914 |
-
# Audio token segment
|
| 915 |
-
audio_positions = (
|
| 916 |
-
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
| 917 |
-
+ st_idx
|
| 918 |
-
)
|
| 919 |
-
llm_pos_ids_list.append(audio_positions)
|
| 920 |
-
|
| 921 |
-
st = offset + audio_len
|
| 922 |
-
|
| 923 |
-
# Handle remaining text (includes audio_end and any trailing text)
|
| 924 |
-
if st < seq_len:
|
| 925 |
-
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
| 926 |
-
text_len = seq_len - st
|
| 927 |
-
final_text_positions = (
|
| 928 |
-
torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
| 929 |
-
+ st_idx
|
| 930 |
-
)
|
| 931 |
-
llm_pos_ids_list.append(final_text_positions)
|
| 932 |
-
|
| 933 |
-
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
| 934 |
-
if llm_positions.shape[1] != seq_len:
|
| 935 |
-
raise RuntimeError("Position ids length mismatch with input ids length")
|
| 936 |
-
|
| 937 |
-
mrope_position_delta = (llm_positions.max() + 1 - seq_len).item()
|
| 938 |
-
return llm_positions, mrope_position_delta
|
| 939 |
-
|
| 940 |
-
def get_mm_mapping(self) -> MultiModelKeys:
|
| 941 |
-
"""
|
| 942 |
-
Get the module prefix in multimodal models
|
| 943 |
-
"""
|
| 944 |
-
return MultiModelKeys.from_string_field(
|
| 945 |
-
language_model="language_model",
|
| 946 |
-
tower_model=["audio_tower."],
|
| 947 |
-
)
|
| 948 |
-
|
| 949 |
-
@classmethod
|
| 950 |
-
def get_speech_to_text_config(
|
| 951 |
-
cls, model_config: ModelConfig, task_type: str
|
| 952 |
-
) -> SpeechToTextConfig:
|
| 953 |
-
processor = cached_processor_from_config(model_config)
|
| 954 |
-
feature_extractor: WhisperFeatureExtractor = processor.feature_extractor
|
| 955 |
-
return SpeechToTextConfig(
|
| 956 |
-
max_audio_clip_s=feature_extractor.chunk_length,
|
| 957 |
-
sample_rate=feature_extractor.sampling_rate,
|
| 958 |
-
)
|
| 959 |
-
|
| 960 |
-
@classmethod
|
| 961 |
-
def get_generation_prompt(
|
| 962 |
-
cls,
|
| 963 |
-
audio: np.ndarray,
|
| 964 |
-
model_config: ModelConfig,
|
| 965 |
-
stt_config: SpeechToTextConfig,
|
| 966 |
-
language: str | None,
|
| 967 |
-
task_type: Literal["transcribe", "translate"],
|
| 968 |
-
request_prompt: str,
|
| 969 |
-
to_language: str | None,
|
| 970 |
-
) -> PromptType:
|
| 971 |
-
"""Get the generation prompt to be used for transcription requests."""
|
| 972 |
-
tokenizer = cached_tokenizer_from_config(model_config)
|
| 973 |
-
audio_placeholder = cls.get_placeholder_str("audio", 0)
|
| 974 |
-
|
| 975 |
-
if task_type not in ("transcribe", "translate"):
|
| 976 |
-
raise ValueError(
|
| 977 |
-
f"Unsupported task_type '{task_type}'. "
|
| 978 |
-
"Supported task types are 'transcribe' and 'translate'."
|
| 979 |
-
)
|
| 980 |
-
full_lang_name_to = cls.supported_languages.get(to_language, to_language)
|
| 981 |
-
if to_language is None:
|
| 982 |
-
prompt = (
|
| 983 |
-
f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n"
|
| 984 |
-
f"<|im_start|>assistant\n"
|
| 985 |
-
)
|
| 986 |
-
else:
|
| 987 |
-
prompt = (
|
| 988 |
-
f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n"
|
| 989 |
-
f"<|im_start|>assistant\nlanguage {full_lang_name_to}<asr_text>"
|
| 990 |
-
)
|
| 991 |
-
|
| 992 |
-
prompt_token_ids = tokenizer.encode(prompt)
|
| 993 |
-
prompt_dict = {
|
| 994 |
-
"prompt_token_ids": prompt_token_ids,
|
| 995 |
-
"multi_modal_data": {"audio": audio},
|
| 996 |
-
}
|
| 997 |
-
return cast(PromptType, prompt_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_asr/inference/assets/korean_dict_jieba.dict
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
qwen_asr/inference/qwen3_asr.py
DELETED
|
@@ -1,519 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
from dataclasses import dataclass
|
| 17 |
-
from typing import Any, Dict, List, Optional, Union
|
| 18 |
-
|
| 19 |
-
import numpy as np
|
| 20 |
-
import torch
|
| 21 |
-
from qwen_asr.core.transformers_backend import (
|
| 22 |
-
Qwen3ASRConfig,
|
| 23 |
-
Qwen3ASRForConditionalGeneration,
|
| 24 |
-
Qwen3ASRProcessor,
|
| 25 |
-
)
|
| 26 |
-
from transformers import AutoConfig, AutoModel, AutoProcessor
|
| 27 |
-
|
| 28 |
-
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
|
| 29 |
-
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
|
| 30 |
-
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
|
| 31 |
-
|
| 32 |
-
from .qwen3_forced_aligner import Qwen3ForcedAligner
|
| 33 |
-
from .utils import (
|
| 34 |
-
MAX_ASR_INPUT_SECONDS,
|
| 35 |
-
MAX_FORCE_ALIGN_INPUT_SECONDS,
|
| 36 |
-
SAMPLE_RATE,
|
| 37 |
-
SUPPORTED_LANGUAGES,
|
| 38 |
-
AudioChunk,
|
| 39 |
-
AudioLike,
|
| 40 |
-
chunk_list,
|
| 41 |
-
merge_languages,
|
| 42 |
-
normalize_audios,
|
| 43 |
-
normalize_language_name,
|
| 44 |
-
parse_asr_output,
|
| 45 |
-
split_audio_into_chunks,
|
| 46 |
-
validate_language,
|
| 47 |
-
)
|
| 48 |
-
|
| 49 |
-
try:
|
| 50 |
-
from qwen_asr.core.vllm_backend import Qwen3ASRForConditionalGeneration
|
| 51 |
-
from vllm import ModelRegistry
|
| 52 |
-
ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGeneration)
|
| 53 |
-
except:
|
| 54 |
-
pass
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
@dataclass
|
| 58 |
-
class ASRTranscription:
|
| 59 |
-
"""
|
| 60 |
-
One transcription result.
|
| 61 |
-
|
| 62 |
-
Attributes:
|
| 63 |
-
language (str):
|
| 64 |
-
Merged language string for the sample, e.g. "Chinese" or "Chinese,English".
|
| 65 |
-
Empty string if unknown or silent audio.
|
| 66 |
-
text (str):
|
| 67 |
-
Transcribed text.
|
| 68 |
-
time_stamps (Optional[Any]):
|
| 69 |
-
Forced aligner output (ForcedAlignResult).
|
| 70 |
-
Present only when return_time_stamps=True.
|
| 71 |
-
"""
|
| 72 |
-
language: str
|
| 73 |
-
text: str
|
| 74 |
-
time_stamps: Optional[Any] = None
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
class Qwen3ASRModel:
|
| 78 |
-
"""
|
| 79 |
-
Unified inference wrapper for Qwen3-ASR with two backends:
|
| 80 |
-
- Transformers backend
|
| 81 |
-
- vLLM backend
|
| 82 |
-
|
| 83 |
-
It optionally supports time stamp output via Qwen3-ForcedAligner.
|
| 84 |
-
|
| 85 |
-
Notes:
|
| 86 |
-
- Each request uses a context text and exactly one audio.
|
| 87 |
-
- If language is provided, the prompt will force the output to be text-only by appending
|
| 88 |
-
"language {Language}<asr_text>" to the assistant prompt.
|
| 89 |
-
"""
|
| 90 |
-
|
| 91 |
-
def __init__(
|
| 92 |
-
self,
|
| 93 |
-
backend: str,
|
| 94 |
-
model: Any,
|
| 95 |
-
processor: Any,
|
| 96 |
-
sampling_params: Optional[Any] = None,
|
| 97 |
-
forced_aligner: Optional[Qwen3ForcedAligner] = None,
|
| 98 |
-
max_inference_batch_size: int = -1,
|
| 99 |
-
):
|
| 100 |
-
self.backend = backend # "transformers" | "vllm"
|
| 101 |
-
self.model = model
|
| 102 |
-
self.processor = processor
|
| 103 |
-
self.sampling_params = sampling_params
|
| 104 |
-
self.forced_aligner = forced_aligner
|
| 105 |
-
self.max_inference_batch_size = int(max_inference_batch_size)
|
| 106 |
-
|
| 107 |
-
if backend == "transformers":
|
| 108 |
-
self.device = getattr(model, "device", None)
|
| 109 |
-
if self.device is None:
|
| 110 |
-
try:
|
| 111 |
-
self.device = next(model.parameters()).device
|
| 112 |
-
except StopIteration:
|
| 113 |
-
self.device = torch.device("cpu")
|
| 114 |
-
self.dtype = getattr(model, "dtype", torch.float32)
|
| 115 |
-
else:
|
| 116 |
-
self.device = None
|
| 117 |
-
self.dtype = None
|
| 118 |
-
|
| 119 |
-
@classmethod
|
| 120 |
-
def from_pretrained(
|
| 121 |
-
cls,
|
| 122 |
-
pretrained_model_name_or_path: str,
|
| 123 |
-
forced_aligner: Optional[str] = None,
|
| 124 |
-
forced_aligner_kwargs: Optional[Dict[str, Any]] = None,
|
| 125 |
-
max_inference_batch_size: int = -1,
|
| 126 |
-
**kwargs,
|
| 127 |
-
) -> "Qwen3ASRModel":
|
| 128 |
-
"""
|
| 129 |
-
Initialize using Transformers backend.
|
| 130 |
-
|
| 131 |
-
Args:
|
| 132 |
-
pretrained_model_name_or_path:
|
| 133 |
-
HuggingFace repo id or local directory.
|
| 134 |
-
forced_aligner:
|
| 135 |
-
Optional forced aligner model path/repo id.
|
| 136 |
-
forced_aligner_kwargs:
|
| 137 |
-
Optional kwargs forwarded to Qwen3ForcedAligner.from_pretrained(...).
|
| 138 |
-
max_inference_batch_size:
|
| 139 |
-
Batch size limit for inference. -1 means no chunking. Small values can avoid OOM.
|
| 140 |
-
**kwargs:
|
| 141 |
-
Forwarded to AutoModel.from_pretrained(...).
|
| 142 |
-
|
| 143 |
-
Returns:
|
| 144 |
-
Qwen3ASRModel
|
| 145 |
-
"""
|
| 146 |
-
|
| 147 |
-
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 148 |
-
|
| 149 |
-
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
|
| 150 |
-
|
| 151 |
-
if forced_aligner is not None:
|
| 152 |
-
forced_aligner_model = Qwen3ForcedAligner.from_pretrained(
|
| 153 |
-
forced_aligner, **(forced_aligner_kwargs or {})
|
| 154 |
-
)
|
| 155 |
-
|
| 156 |
-
return cls(
|
| 157 |
-
backend="transformers",
|
| 158 |
-
model=model,
|
| 159 |
-
processor=processor,
|
| 160 |
-
sampling_params=None,
|
| 161 |
-
forced_aligner=forced_aligner_model,
|
| 162 |
-
max_inference_batch_size=max_inference_batch_size,
|
| 163 |
-
)
|
| 164 |
-
|
| 165 |
-
@classmethod
|
| 166 |
-
def LLM(
|
| 167 |
-
cls,
|
| 168 |
-
model: str,
|
| 169 |
-
forced_aligner: Optional[str] = None,
|
| 170 |
-
forced_aligner_kwargs: Optional[Dict[str, Any]] = None,
|
| 171 |
-
max_inference_batch_size: int = -1,
|
| 172 |
-
max_new_tokens: Optional[int] = 8192,
|
| 173 |
-
**kwargs,
|
| 174 |
-
) -> "Qwen3ASRModel":
|
| 175 |
-
"""
|
| 176 |
-
Initialize using vLLM backend.
|
| 177 |
-
|
| 178 |
-
Import is isolated to keep vLLM optional.
|
| 179 |
-
|
| 180 |
-
Args:
|
| 181 |
-
model:
|
| 182 |
-
Model path/repo for vLLM.
|
| 183 |
-
forced_aligner:
|
| 184 |
-
Optional forced aligner model path/repo id.
|
| 185 |
-
forced_aligner_kwargs:
|
| 186 |
-
Optional kwargs forwarded to Qwen3ForcedAligner.from_pretrained(...).
|
| 187 |
-
max_inference_batch_size:
|
| 188 |
-
Batch size limit for inference. -1 means no chunking. Small values can avoid OOM.
|
| 189 |
-
max_new_tokens:
|
| 190 |
-
Maximum number of tokens to generate.
|
| 191 |
-
**kwargs:
|
| 192 |
-
Forwarded to vllm.LLM(...).
|
| 193 |
-
|
| 194 |
-
Returns:
|
| 195 |
-
Qwen3ASRModel
|
| 196 |
-
|
| 197 |
-
Raises:
|
| 198 |
-
ImportError: If vLLM is not installed.
|
| 199 |
-
"""
|
| 200 |
-
try:
|
| 201 |
-
from vllm import LLM as vLLM
|
| 202 |
-
from vllm import SamplingParams
|
| 203 |
-
except Exception as e:
|
| 204 |
-
raise ImportError(
|
| 205 |
-
"vLLM is not available. Install with: pip install qwen-asr[vllm]"
|
| 206 |
-
) from e
|
| 207 |
-
|
| 208 |
-
llm = vLLM(model=model, **kwargs)
|
| 209 |
-
|
| 210 |
-
processor = Qwen3ASRProcessor.from_pretrained(model, fix_mistral_regex=True)
|
| 211 |
-
sampling_params = SamplingParams(**({"temperature": 0.0, "max_tokens": max_new_tokens}))
|
| 212 |
-
|
| 213 |
-
if forced_aligner is not None:
|
| 214 |
-
forced_aligner_model = Qwen3ForcedAligner.from_pretrained(
|
| 215 |
-
forced_aligner, **(forced_aligner_kwargs or {})
|
| 216 |
-
)
|
| 217 |
-
|
| 218 |
-
return cls(
|
| 219 |
-
backend="vllm",
|
| 220 |
-
model=llm,
|
| 221 |
-
processor=processor,
|
| 222 |
-
sampling_params=sampling_params,
|
| 223 |
-
forced_aligner=forced_aligner_model,
|
| 224 |
-
max_inference_batch_size=max_inference_batch_size,
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
-
def get_supported_languages(self) -> List[str]:
|
| 228 |
-
"""
|
| 229 |
-
Returns the supported language list.
|
| 230 |
-
|
| 231 |
-
Returns:
|
| 232 |
-
List[str]: Canonical language names.
|
| 233 |
-
"""
|
| 234 |
-
return list(SUPPORTED_LANGUAGES)
|
| 235 |
-
|
| 236 |
-
@torch.no_grad()
|
| 237 |
-
def transcribe(
|
| 238 |
-
self,
|
| 239 |
-
audio: Union[AudioLike, List[AudioLike]],
|
| 240 |
-
context: Union[str, List[str]] = "",
|
| 241 |
-
language: Optional[Union[str, List[Optional[str]]]] = None,
|
| 242 |
-
return_time_stamps: bool = False,
|
| 243 |
-
) -> List[ASRTranscription]:
|
| 244 |
-
"""
|
| 245 |
-
Transcribe audio with optional context and optional forced alignment timestamps.
|
| 246 |
-
|
| 247 |
-
Args:
|
| 248 |
-
audio:
|
| 249 |
-
Audio input(s). Supported:
|
| 250 |
-
- str: local path / URL / base64 data url
|
| 251 |
-
- (np.ndarray, sr)
|
| 252 |
-
- list of above
|
| 253 |
-
context:
|
| 254 |
-
Context string(s). If scalar, it will be broadcast to batch size.
|
| 255 |
-
language:
|
| 256 |
-
Optional language(s). If provided, it must be in supported languages.
|
| 257 |
-
If scalar, it will be broadcast to batch size.
|
| 258 |
-
If provided, the prompt will force output to be transcription text only.
|
| 259 |
-
return_time_stamps:
|
| 260 |
-
If True, timestamps are produced via forced aligner and merged across chunks.
|
| 261 |
-
This requires forced_aligner initialized.
|
| 262 |
-
|
| 263 |
-
Returns:
|
| 264 |
-
List[ASRTranscription]: One result per input audio.
|
| 265 |
-
|
| 266 |
-
Raises:
|
| 267 |
-
ValueError:
|
| 268 |
-
- If return_time_stamps=True but forced_aligner is not provided.
|
| 269 |
-
- If language is unsupported.
|
| 270 |
-
- If batch sizes mismatch for context/language.
|
| 271 |
-
"""
|
| 272 |
-
if return_time_stamps and self.forced_aligner is None:
|
| 273 |
-
raise ValueError("return_time_stamps=True requires `forced_aligner` to be provided at initialization.")
|
| 274 |
-
|
| 275 |
-
wavs = normalize_audios(audio)
|
| 276 |
-
n = len(wavs)
|
| 277 |
-
|
| 278 |
-
ctxs = context if isinstance(context, list) else [context]
|
| 279 |
-
if len(ctxs) == 1 and n > 1:
|
| 280 |
-
ctxs = ctxs * n
|
| 281 |
-
if len(ctxs) != n:
|
| 282 |
-
raise ValueError(f"Batch size mismatch: audio={n}, context={len(ctxs)}")
|
| 283 |
-
|
| 284 |
-
langs_in: List[Optional[str]]
|
| 285 |
-
if language is None:
|
| 286 |
-
langs_in = [None] * n
|
| 287 |
-
else:
|
| 288 |
-
langs_in = language if isinstance(language, list) else [language]
|
| 289 |
-
if len(langs_in) == 1 and n > 1:
|
| 290 |
-
langs_in = langs_in * n
|
| 291 |
-
if len(langs_in) != n:
|
| 292 |
-
raise ValueError(f"Batch size mismatch: audio={n}, language={len(langs_in)}")
|
| 293 |
-
|
| 294 |
-
langs_norm: List[Optional[str]] = []
|
| 295 |
-
for l in langs_in:
|
| 296 |
-
if l is None or str(l).strip() == "":
|
| 297 |
-
langs_norm.append(None)
|
| 298 |
-
else:
|
| 299 |
-
ln = normalize_language_name(str(l))
|
| 300 |
-
validate_language(ln)
|
| 301 |
-
langs_norm.append(ln)
|
| 302 |
-
|
| 303 |
-
max_chunk_sec = MAX_FORCE_ALIGN_INPUT_SECONDS if return_time_stamps else MAX_ASR_INPUT_SECONDS
|
| 304 |
-
|
| 305 |
-
# chunk audios and record mapping
|
| 306 |
-
chunks: List[AudioChunk] = []
|
| 307 |
-
for i, wav in enumerate(wavs):
|
| 308 |
-
parts = split_audio_into_chunks(
|
| 309 |
-
wav=wav,
|
| 310 |
-
sr=SAMPLE_RATE,
|
| 311 |
-
max_chunk_sec=max_chunk_sec,
|
| 312 |
-
)
|
| 313 |
-
for j, (cwav, offset_sec) in enumerate(parts):
|
| 314 |
-
chunks.append(AudioChunk(orig_index=i, chunk_index=j, wav=cwav, sr=SAMPLE_RATE, offset_sec=offset_sec))
|
| 315 |
-
|
| 316 |
-
# run ASR on chunks
|
| 317 |
-
chunk_ctx: List[str] = [ctxs[c.orig_index] for c in chunks]
|
| 318 |
-
chunk_lang: List[Optional[str]] = [langs_norm[c.orig_index] for c in chunks]
|
| 319 |
-
chunk_wavs: List[np.ndarray] = [c.wav for c in chunks]
|
| 320 |
-
raw_outputs = self._infer_asr(chunk_ctx, chunk_wavs, chunk_lang)
|
| 321 |
-
|
| 322 |
-
# parse outputs, prepare for optional alignment
|
| 323 |
-
per_chunk_lang: List[str] = []
|
| 324 |
-
per_chunk_text: List[str] = []
|
| 325 |
-
for out, forced_lang in zip(raw_outputs, chunk_lang):
|
| 326 |
-
lang, txt = parse_asr_output(out, user_language=forced_lang)
|
| 327 |
-
per_chunk_lang.append(lang)
|
| 328 |
-
per_chunk_text.append(txt)
|
| 329 |
-
|
| 330 |
-
# forced alignment (optional)
|
| 331 |
-
per_chunk_align: List[Optional[Any]] = [None] * len(chunks)
|
| 332 |
-
if return_time_stamps:
|
| 333 |
-
to_align_audio = []
|
| 334 |
-
to_align_text = []
|
| 335 |
-
to_align_lang = []
|
| 336 |
-
to_align_idx = []
|
| 337 |
-
|
| 338 |
-
for idx, (c, txt, lang_pred) in enumerate(zip(chunks, per_chunk_text, per_chunk_lang)):
|
| 339 |
-
if txt.strip() == "":
|
| 340 |
-
continue
|
| 341 |
-
to_align_audio.append((c.wav, c.sr))
|
| 342 |
-
to_align_text.append(txt)
|
| 343 |
-
to_align_lang.append(lang_pred)
|
| 344 |
-
to_align_idx.append(idx)
|
| 345 |
-
|
| 346 |
-
# batch align with max_inference_batch_size
|
| 347 |
-
aligned_results: List[Any] = []
|
| 348 |
-
for a_chunk, t_chunk, l_chunk in zip(
|
| 349 |
-
chunk_list(to_align_audio, self.max_inference_batch_size),
|
| 350 |
-
chunk_list(to_align_text, self.max_inference_batch_size),
|
| 351 |
-
chunk_list(to_align_lang, self.max_inference_batch_size),
|
| 352 |
-
):
|
| 353 |
-
aligned_results.extend(
|
| 354 |
-
self.forced_aligner.align(audio=a_chunk, text=t_chunk, language=l_chunk)
|
| 355 |
-
)
|
| 356 |
-
|
| 357 |
-
# offset fix
|
| 358 |
-
for k, idx in enumerate(to_align_idx):
|
| 359 |
-
c = chunks[idx]
|
| 360 |
-
r = aligned_results[k]
|
| 361 |
-
per_chunk_align[idx] = self._offset_align_result(r, c.offset_sec)
|
| 362 |
-
|
| 363 |
-
# merge chunks back to original samples
|
| 364 |
-
out_langs: List[List[str]] = [[] for _ in range(n)]
|
| 365 |
-
out_texts: List[List[str]] = [[] for _ in range(n)]
|
| 366 |
-
out_aligns: List[List[Any]] = [[] for _ in range(n)]
|
| 367 |
-
|
| 368 |
-
for c, lang, txt, al in zip(chunks, per_chunk_lang, per_chunk_text, per_chunk_align):
|
| 369 |
-
out_langs[c.orig_index].append(lang)
|
| 370 |
-
out_texts[c.orig_index].append(txt)
|
| 371 |
-
if return_time_stamps and al is not None:
|
| 372 |
-
out_aligns[c.orig_index].append(al)
|
| 373 |
-
|
| 374 |
-
results: List[ASRTranscription] = []
|
| 375 |
-
for i in range(n):
|
| 376 |
-
merged_text = "".join([t for t in out_texts[i] if t is not None])
|
| 377 |
-
merged_language = merge_languages(out_langs[i])
|
| 378 |
-
merged_align = None
|
| 379 |
-
if return_time_stamps:
|
| 380 |
-
merged_align = self._merge_align_results(out_aligns[i])
|
| 381 |
-
results.append(ASRTranscription(language=merged_language, text=merged_text, time_stamps=merged_align))
|
| 382 |
-
|
| 383 |
-
return results
|
| 384 |
-
|
| 385 |
-
def _build_messages(self, context: str, audio_payload: Any) -> List[Dict[str, Any]]:
|
| 386 |
-
return [
|
| 387 |
-
{"role": "system", "content": context or ""},
|
| 388 |
-
{"role": "user", "content": [{"type": "audio", "audio": audio_payload}]},
|
| 389 |
-
]
|
| 390 |
-
|
| 391 |
-
def _build_text_prompt(self, context: str, force_language: Optional[str]) -> str:
|
| 392 |
-
"""
|
| 393 |
-
Build the string prompt for one request.
|
| 394 |
-
|
| 395 |
-
If force_language is provided, "language X<asr_text>" is appended after the generation prompt
|
| 396 |
-
to request text-only output.
|
| 397 |
-
"""
|
| 398 |
-
msgs = self._build_messages(context=context, audio_payload="")
|
| 399 |
-
base = self.processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
|
| 400 |
-
if force_language:
|
| 401 |
-
base = base + f"language {force_language}{'<asr_text>'}"
|
| 402 |
-
return base
|
| 403 |
-
|
| 404 |
-
def _infer_asr(
|
| 405 |
-
self,
|
| 406 |
-
contexts: List[str],
|
| 407 |
-
wavs: List[np.ndarray],
|
| 408 |
-
languages: List[Optional[str]],
|
| 409 |
-
) -> List[str]:
|
| 410 |
-
"""
|
| 411 |
-
Run backend inference for chunk-level items.
|
| 412 |
-
|
| 413 |
-
Args:
|
| 414 |
-
contexts: List of system context strings.
|
| 415 |
-
wavs: List of mono waveforms (np.ndarray).
|
| 416 |
-
languages: List of forced languages or None.
|
| 417 |
-
|
| 418 |
-
Returns:
|
| 419 |
-
List[str]: Raw decoded strings (one per chunk).
|
| 420 |
-
"""
|
| 421 |
-
if self.backend == "transformers":
|
| 422 |
-
return self._infer_asr_transformers(contexts, wavs, languages)
|
| 423 |
-
if self.backend == "vllm":
|
| 424 |
-
return self._infer_asr_vllm(contexts, wavs, languages)
|
| 425 |
-
raise RuntimeError(f"Unknown backend: {self.backend}")
|
| 426 |
-
|
| 427 |
-
def _infer_asr_transformers(
|
| 428 |
-
self,
|
| 429 |
-
contexts: List[str],
|
| 430 |
-
wavs: List[np.ndarray],
|
| 431 |
-
languages: List[Optional[str]],
|
| 432 |
-
) -> List[str]:
|
| 433 |
-
outs: List[str] = []
|
| 434 |
-
|
| 435 |
-
texts = [self._build_text_prompt(context=c, force_language=fl) for c, fl in zip(contexts, languages)]
|
| 436 |
-
|
| 437 |
-
batch_size = self.max_inference_batch_size
|
| 438 |
-
if batch_size is None or batch_size < 0:
|
| 439 |
-
batch_size = len(texts)
|
| 440 |
-
|
| 441 |
-
for i in range(0, len(texts), batch_size):
|
| 442 |
-
sub_text = texts[i : i + batch_size]
|
| 443 |
-
sub_wavs = wavs[i : i + batch_size]
|
| 444 |
-
inputs = self.processor(text=sub_text, audio=sub_wavs, return_tensors="pt", padding=True)
|
| 445 |
-
inputs = inputs.to(self.model.device).to(self.model.dtype)
|
| 446 |
-
|
| 447 |
-
text_ids = self.model.generate(**inputs)
|
| 448 |
-
|
| 449 |
-
decoded = self.processor.batch_decode(
|
| 450 |
-
text_ids.sequences[:, inputs["input_ids"].shape[1]:],
|
| 451 |
-
skip_special_tokens=True,
|
| 452 |
-
clean_up_tokenization_spaces=False,
|
| 453 |
-
)
|
| 454 |
-
outs.extend(list(decoded))
|
| 455 |
-
|
| 456 |
-
return outs
|
| 457 |
-
|
| 458 |
-
def _infer_asr_vllm(
|
| 459 |
-
self,
|
| 460 |
-
contexts: List[str],
|
| 461 |
-
wavs: List[np.ndarray],
|
| 462 |
-
languages: List[Optional[str]],
|
| 463 |
-
) -> List[str]:
|
| 464 |
-
inputs: List[Dict[str, Any]] = []
|
| 465 |
-
for c, w, fl in zip(contexts, wavs, languages):
|
| 466 |
-
prompt = self._build_text_prompt(context=c, force_language=fl)
|
| 467 |
-
inputs.append({"prompt": prompt, "multi_modal_data": {"audio": [w]}})
|
| 468 |
-
|
| 469 |
-
outs: List[str] = []
|
| 470 |
-
for batch in chunk_list(inputs, self.max_inference_batch_size):
|
| 471 |
-
outputs = self.model.generate(batch, sampling_params=self.sampling_params, use_tqdm=False)
|
| 472 |
-
for o in outputs:
|
| 473 |
-
outs.append(o.outputs[0].text)
|
| 474 |
-
return outs
|
| 475 |
-
|
| 476 |
-
def _offset_align_result(self, result: Any, offset_sec: float) -> Any:
|
| 477 |
-
"""
|
| 478 |
-
Apply time offset to a ForcedAlignResult-like object.
|
| 479 |
-
|
| 480 |
-
This function assumes:
|
| 481 |
-
- result has attribute `.items` which is a list of items with start_time/end_time in seconds.
|
| 482 |
-
- dataclasses are frozen in upstream implementation, so we reconstruct by type.
|
| 483 |
-
|
| 484 |
-
Args:
|
| 485 |
-
result: ForcedAlignResult
|
| 486 |
-
offset_sec: Offset in seconds
|
| 487 |
-
|
| 488 |
-
Returns:
|
| 489 |
-
ForcedAlignResult: New object with shifted timestamps.
|
| 490 |
-
"""
|
| 491 |
-
if result is None:
|
| 492 |
-
return None
|
| 493 |
-
items = []
|
| 494 |
-
for it in result.items:
|
| 495 |
-
items.append(type(it)(text=it.text,
|
| 496 |
-
start_time=round(it.start_time + offset_sec, 3),
|
| 497 |
-
end_time=round(it.end_time + offset_sec, 3)))
|
| 498 |
-
return type(result)(items=items)
|
| 499 |
-
|
| 500 |
-
def _merge_align_results(self, results: List[Any]) -> Optional[Any]:
|
| 501 |
-
"""
|
| 502 |
-
Merge multiple ForcedAlignResult objects into a single one by concatenating items.
|
| 503 |
-
|
| 504 |
-
Args:
|
| 505 |
-
results: List of ForcedAlignResult
|
| 506 |
-
|
| 507 |
-
Returns:
|
| 508 |
-
ForcedAlignResult or None
|
| 509 |
-
"""
|
| 510 |
-
if not results:
|
| 511 |
-
return None
|
| 512 |
-
all_items = []
|
| 513 |
-
for r in results:
|
| 514 |
-
if r is None:
|
| 515 |
-
continue
|
| 516 |
-
all_items.extend(list(r.items))
|
| 517 |
-
if not all_items:
|
| 518 |
-
return None
|
| 519 |
-
return type(results[0])(items=all_items)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_asr/inference/qwen3_forced_aligner.py
DELETED
|
@@ -1,484 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
import os
|
| 17 |
-
import unicodedata
|
| 18 |
-
from dataclasses import dataclass
|
| 19 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 20 |
-
|
| 21 |
-
import nagisa
|
| 22 |
-
import numpy as np
|
| 23 |
-
import torch
|
| 24 |
-
from qwen_asr.core.transformers_backend import (
|
| 25 |
-
Qwen3ASRConfig,
|
| 26 |
-
Qwen3ASRForConditionalGeneration,
|
| 27 |
-
Qwen3ASRProcessor,
|
| 28 |
-
)
|
| 29 |
-
from transformers import AutoConfig, AutoModel, AutoProcessor
|
| 30 |
-
|
| 31 |
-
from .utils import (
|
| 32 |
-
AudioLike,
|
| 33 |
-
ensure_list,
|
| 34 |
-
normalize_audios,
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
class Qwen3ForceAlignProcessor():
|
| 39 |
-
def __init__(self):
|
| 40 |
-
ko_dict_path = os.path.join(os.path.dirname(__file__), "assets", "korean_dict_jieba.dict")
|
| 41 |
-
ko_scores = {}
|
| 42 |
-
with open(ko_dict_path, "r", encoding="utf-8") as f:
|
| 43 |
-
for line in f:
|
| 44 |
-
line = line.strip()
|
| 45 |
-
if not line:
|
| 46 |
-
continue
|
| 47 |
-
word = line.split()[0]
|
| 48 |
-
ko_scores[word] = 1.0
|
| 49 |
-
self.ko_score = ko_scores
|
| 50 |
-
self.ko_tokenizer = None
|
| 51 |
-
|
| 52 |
-
def is_kept_char(self, ch: str) -> bool:
|
| 53 |
-
if ch == "'":
|
| 54 |
-
return True
|
| 55 |
-
cat = unicodedata.category(ch)
|
| 56 |
-
if cat.startswith("L") or cat.startswith("N"):
|
| 57 |
-
return True
|
| 58 |
-
return False
|
| 59 |
-
|
| 60 |
-
def clean_token(self, token: str) -> str:
|
| 61 |
-
return "".join(ch for ch in token if self.is_kept_char(ch))
|
| 62 |
-
|
| 63 |
-
def is_cjk_char(self, ch: str) -> bool:
|
| 64 |
-
code = ord(ch)
|
| 65 |
-
return (
|
| 66 |
-
0x4E00 <= code <= 0x9FFF # CJK Unified Ideographs
|
| 67 |
-
or 0x3400 <= code <= 0x4DBF # Extension A
|
| 68 |
-
or 0x20000 <= code <= 0x2A6DF # Extension B
|
| 69 |
-
or 0x2A700 <= code <= 0x2B73F # Extension C
|
| 70 |
-
or 0x2B740 <= code <= 0x2B81F # Extension D
|
| 71 |
-
or 0x2B820 <= code <= 0x2CEAF # Extension E
|
| 72 |
-
or 0xF900 <= code <= 0xFAFF # Compatibility Ideographs
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
def tokenize_chinese_mixed(self, text: str) -> List[str]:
|
| 76 |
-
tokens: List[str] = []
|
| 77 |
-
current_latin: List[str] = []
|
| 78 |
-
|
| 79 |
-
def flush_latin():
|
| 80 |
-
nonlocal current_latin
|
| 81 |
-
if current_latin:
|
| 82 |
-
token = "".join(current_latin)
|
| 83 |
-
cleaned = self.clean_token(token)
|
| 84 |
-
if cleaned:
|
| 85 |
-
tokens.append(cleaned)
|
| 86 |
-
current_latin = []
|
| 87 |
-
|
| 88 |
-
for ch in text:
|
| 89 |
-
if self.is_cjk_char(ch):
|
| 90 |
-
flush_latin()
|
| 91 |
-
tokens.append(ch)
|
| 92 |
-
else:
|
| 93 |
-
if self.is_kept_char(ch):
|
| 94 |
-
current_latin.append(ch)
|
| 95 |
-
else:
|
| 96 |
-
flush_latin()
|
| 97 |
-
|
| 98 |
-
flush_latin()
|
| 99 |
-
|
| 100 |
-
return tokens
|
| 101 |
-
|
| 102 |
-
def tokenize_japanese(self, text: str) -> List[str]:
|
| 103 |
-
words = nagisa.tagging(text).words
|
| 104 |
-
tokens: List[str] = []
|
| 105 |
-
for w in words:
|
| 106 |
-
cleaned = self.clean_token(w)
|
| 107 |
-
if cleaned:
|
| 108 |
-
tokens.append(cleaned)
|
| 109 |
-
return tokens
|
| 110 |
-
|
| 111 |
-
def tokenize_korean(self, ko_tokenizer, text: str) -> List[str]:
|
| 112 |
-
raw_tokens = ko_tokenizer.tokenize(text)
|
| 113 |
-
tokens: List[str] = []
|
| 114 |
-
for w in raw_tokens:
|
| 115 |
-
w_clean = self.clean_token(w)
|
| 116 |
-
if w_clean:
|
| 117 |
-
tokens.append(w_clean)
|
| 118 |
-
return tokens
|
| 119 |
-
|
| 120 |
-
def split_segment_with_chinese(self, seg: str) -> List[str]:
|
| 121 |
-
tokens: List[str] = []
|
| 122 |
-
buf: List[str] = []
|
| 123 |
-
|
| 124 |
-
def flush_buf():
|
| 125 |
-
nonlocal buf
|
| 126 |
-
if buf:
|
| 127 |
-
tokens.append("".join(buf))
|
| 128 |
-
buf = []
|
| 129 |
-
|
| 130 |
-
for ch in seg:
|
| 131 |
-
if self.is_cjk_char(ch):
|
| 132 |
-
flush_buf()
|
| 133 |
-
tokens.append(ch)
|
| 134 |
-
else:
|
| 135 |
-
buf.append(ch)
|
| 136 |
-
|
| 137 |
-
flush_buf()
|
| 138 |
-
return tokens
|
| 139 |
-
|
| 140 |
-
def tokenize_space_lang(self, text: str) -> List[str]:
|
| 141 |
-
tokens: List[str] = []
|
| 142 |
-
for seg in text.split():
|
| 143 |
-
cleaned = self.clean_token(seg)
|
| 144 |
-
if cleaned:
|
| 145 |
-
tokens.extend(self.split_segment_with_chinese(cleaned))
|
| 146 |
-
return tokens
|
| 147 |
-
|
| 148 |
-
def fix_timestamp(self, data) -> List[int]:
|
| 149 |
-
data = data.tolist()
|
| 150 |
-
n = len(data)
|
| 151 |
-
|
| 152 |
-
dp = [1] * n
|
| 153 |
-
parent = [-1] * n
|
| 154 |
-
|
| 155 |
-
for i in range(1, n):
|
| 156 |
-
for j in range(i):
|
| 157 |
-
if data[j] <= data[i] and dp[j] + 1 > dp[i]:
|
| 158 |
-
dp[i] = dp[j] + 1
|
| 159 |
-
parent[i] = j
|
| 160 |
-
|
| 161 |
-
max_length = max(dp)
|
| 162 |
-
max_idx = dp.index(max_length)
|
| 163 |
-
|
| 164 |
-
lis_indices = []
|
| 165 |
-
idx = max_idx
|
| 166 |
-
while idx != -1:
|
| 167 |
-
lis_indices.append(idx)
|
| 168 |
-
idx = parent[idx]
|
| 169 |
-
lis_indices.reverse()
|
| 170 |
-
|
| 171 |
-
is_normal = [False] * n
|
| 172 |
-
for idx in lis_indices:
|
| 173 |
-
is_normal[idx] = True
|
| 174 |
-
|
| 175 |
-
result = data.copy()
|
| 176 |
-
i = 0
|
| 177 |
-
|
| 178 |
-
while i < n:
|
| 179 |
-
if not is_normal[i]:
|
| 180 |
-
j = i
|
| 181 |
-
while j < n and not is_normal[j]:
|
| 182 |
-
j += 1
|
| 183 |
-
|
| 184 |
-
anomaly_count = j - i
|
| 185 |
-
|
| 186 |
-
if anomaly_count <= 2:
|
| 187 |
-
left_val = None
|
| 188 |
-
for k in range(i - 1, -1, -1):
|
| 189 |
-
if is_normal[k]:
|
| 190 |
-
left_val = result[k]
|
| 191 |
-
break
|
| 192 |
-
|
| 193 |
-
right_val = None
|
| 194 |
-
for k in range(j, n):
|
| 195 |
-
if is_normal[k]:
|
| 196 |
-
right_val = result[k]
|
| 197 |
-
break
|
| 198 |
-
|
| 199 |
-
for k in range(i, j):
|
| 200 |
-
if left_val is None:
|
| 201 |
-
result[k] = right_val
|
| 202 |
-
elif right_val is None:
|
| 203 |
-
result[k] = left_val
|
| 204 |
-
else:
|
| 205 |
-
result[k] = left_val if (k - (i - 1)) <= ((j) - k) else right_val
|
| 206 |
-
|
| 207 |
-
else:
|
| 208 |
-
left_val = None
|
| 209 |
-
for k in range(i - 1, -1, -1):
|
| 210 |
-
if is_normal[k]:
|
| 211 |
-
left_val = result[k]
|
| 212 |
-
break
|
| 213 |
-
|
| 214 |
-
right_val = None
|
| 215 |
-
for k in range(j, n):
|
| 216 |
-
if is_normal[k]:
|
| 217 |
-
right_val = result[k]
|
| 218 |
-
break
|
| 219 |
-
|
| 220 |
-
if left_val is not None and right_val is not None:
|
| 221 |
-
step = (right_val - left_val) / (anomaly_count + 1)
|
| 222 |
-
for k in range(i, j):
|
| 223 |
-
result[k] = left_val + step * (k - i + 1)
|
| 224 |
-
elif left_val is not None:
|
| 225 |
-
for k in range(i, j):
|
| 226 |
-
result[k] = left_val
|
| 227 |
-
elif right_val is not None:
|
| 228 |
-
for k in range(i, j):
|
| 229 |
-
result[k] = right_val
|
| 230 |
-
|
| 231 |
-
i = j
|
| 232 |
-
else:
|
| 233 |
-
i += 1
|
| 234 |
-
|
| 235 |
-
return [int(res) for res in result]
|
| 236 |
-
|
| 237 |
-
def encode_timestamp(self, text: str, language: str) -> List[str]:
|
| 238 |
-
language = language.lower()
|
| 239 |
-
|
| 240 |
-
if language.lower() == "japanese":
|
| 241 |
-
word_list = self.tokenize_japanese(text)
|
| 242 |
-
elif language.lower() == "korean":
|
| 243 |
-
if self.ko_tokenizer is None:
|
| 244 |
-
from soynlp.tokenizer import LTokenizer
|
| 245 |
-
self.ko_tokenizer = LTokenizer(scores=self.ko_score)
|
| 246 |
-
word_list = self.tokenize_korean(self.ko_tokenizer, text)
|
| 247 |
-
else:
|
| 248 |
-
word_list = self.tokenize_space_lang(text)
|
| 249 |
-
|
| 250 |
-
input_text = "<timestamp><timestamp>".join(word_list) + "<timestamp><timestamp>"
|
| 251 |
-
input_text = "<|audio_start|><|audio_pad|><|audio_end|>" + input_text
|
| 252 |
-
|
| 253 |
-
return word_list, input_text
|
| 254 |
-
|
| 255 |
-
def parse_timestamp(self, word_list, timestamp):
|
| 256 |
-
timestamp_output = []
|
| 257 |
-
|
| 258 |
-
timestamp_fixed = self.fix_timestamp(timestamp)
|
| 259 |
-
for i, word in enumerate(word_list):
|
| 260 |
-
start_time = timestamp_fixed[i * 2]
|
| 261 |
-
end_time = timestamp_fixed[i * 2 + 1]
|
| 262 |
-
timestamp_output.append({
|
| 263 |
-
"text": word,
|
| 264 |
-
"start_time": start_time,
|
| 265 |
-
"end_time": end_time
|
| 266 |
-
})
|
| 267 |
-
|
| 268 |
-
return timestamp_output
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
@dataclass(frozen=True)
|
| 272 |
-
class ForcedAlignItem:
|
| 273 |
-
"""
|
| 274 |
-
One aligned item span.
|
| 275 |
-
|
| 276 |
-
Attributes:
|
| 277 |
-
text (str):
|
| 278 |
-
The aligned unit (cjk character or word) produced by the forced aligner processor.
|
| 279 |
-
start_time (float):
|
| 280 |
-
Start time in seconds.
|
| 281 |
-
end_time (float):
|
| 282 |
-
End time in seconds.
|
| 283 |
-
"""
|
| 284 |
-
text: str
|
| 285 |
-
start_time: int
|
| 286 |
-
end_time: int
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
@dataclass(frozen=True)
|
| 290 |
-
class ForcedAlignResult:
|
| 291 |
-
"""
|
| 292 |
-
Forced alignment output for one sample.
|
| 293 |
-
|
| 294 |
-
Attributes:
|
| 295 |
-
items (List[ForcedAlignItem]):
|
| 296 |
-
Aligned token spans.
|
| 297 |
-
"""
|
| 298 |
-
items: List[ForcedAlignItem]
|
| 299 |
-
|
| 300 |
-
def __iter__(self):
|
| 301 |
-
return iter(self.items)
|
| 302 |
-
|
| 303 |
-
def __len__(self):
|
| 304 |
-
return len(self.items)
|
| 305 |
-
|
| 306 |
-
def __getitem__(self, idx: int) -> ForcedAlignItem:
|
| 307 |
-
return self.items[idx]
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
class Qwen3ForcedAligner:
|
| 311 |
-
"""
|
| 312 |
-
A HuggingFace-style wrapper for Qwen3-ForcedAligner model inference.
|
| 313 |
-
|
| 314 |
-
This wrapper provides:
|
| 315 |
-
- `from_pretrained()` initialization via HuggingFace AutoModel/AutoProcessor
|
| 316 |
-
- audio input normalization (path/URL/base64/(np.ndarray, sr))
|
| 317 |
-
- batch and single-sample forced alignment
|
| 318 |
-
- structured output with attribute access (`.text`, `.start_time`, `.end_time`)
|
| 319 |
-
"""
|
| 320 |
-
|
| 321 |
-
def __init__(
|
| 322 |
-
self,
|
| 323 |
-
model: Qwen3ASRForConditionalGeneration,
|
| 324 |
-
processor: Qwen3ASRProcessor,
|
| 325 |
-
aligner_processor: Qwen3ForceAlignProcessor,
|
| 326 |
-
):
|
| 327 |
-
self.model = model
|
| 328 |
-
self.processor = processor
|
| 329 |
-
self.aligner_processor = aligner_processor
|
| 330 |
-
|
| 331 |
-
self.device = getattr(model, "device", None)
|
| 332 |
-
if self.device is None:
|
| 333 |
-
try:
|
| 334 |
-
self.device = next(model.parameters()).device
|
| 335 |
-
except StopIteration:
|
| 336 |
-
self.device = torch.device("cpu")
|
| 337 |
-
|
| 338 |
-
self.timestamp_token_id = int(model.config.timestamp_token_id)
|
| 339 |
-
self.timestamp_segment_time = float(model.config.timestamp_segment_time)
|
| 340 |
-
|
| 341 |
-
@classmethod
|
| 342 |
-
def from_pretrained(
|
| 343 |
-
cls,
|
| 344 |
-
pretrained_model_name_or_path: str,
|
| 345 |
-
**kwargs,
|
| 346 |
-
) -> "Qwen3ForcedAligner":
|
| 347 |
-
"""
|
| 348 |
-
Load Qwen3-ForcedAligner model and initialize processors.
|
| 349 |
-
|
| 350 |
-
This method:
|
| 351 |
-
1) Registers config/model/processor for HF auto classes.
|
| 352 |
-
2) Loads the model using `AutoModel.from_pretrained(...)`.
|
| 353 |
-
3) Initializes:
|
| 354 |
-
- HF processor (`AutoProcessor.from_pretrained(...)`)
|
| 355 |
-
- forced alignment text processor (`Qwen3ForceAlignProcessor()`)
|
| 356 |
-
|
| 357 |
-
Args:
|
| 358 |
-
pretrained_model_name_or_path (str):
|
| 359 |
-
HuggingFace repo id or local directory.
|
| 360 |
-
**kwargs:
|
| 361 |
-
Forwarded to `AutoModel.from_pretrained(...)`.
|
| 362 |
-
Typical examples: device_map="cuda:0", dtype=torch.bfloat16.
|
| 363 |
-
|
| 364 |
-
Returns:
|
| 365 |
-
Qwen3ForcedAligner:
|
| 366 |
-
Initialized wrapper instance.
|
| 367 |
-
"""
|
| 368 |
-
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
|
| 369 |
-
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
|
| 370 |
-
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
|
| 371 |
-
|
| 372 |
-
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 373 |
-
if not isinstance(model, Qwen3ASRForConditionalGeneration):
|
| 374 |
-
raise TypeError(
|
| 375 |
-
f"AutoModel returned {type(model)}, expected Qwen3ASRForConditionalGeneration."
|
| 376 |
-
)
|
| 377 |
-
|
| 378 |
-
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
|
| 379 |
-
aligner_processor = Qwen3ForceAlignProcessor()
|
| 380 |
-
|
| 381 |
-
return cls(model=model, processor=processor, aligner_processor=aligner_processor)
|
| 382 |
-
|
| 383 |
-
def _to_structured_items(self, timestamp_output: List[Dict[str, Any]]) -> ForcedAlignResult:
|
| 384 |
-
items: List[ForcedAlignItem] = []
|
| 385 |
-
for it in timestamp_output:
|
| 386 |
-
items.append(
|
| 387 |
-
ForcedAlignItem(
|
| 388 |
-
text=str(it.get("text", "")),
|
| 389 |
-
start_time=float(it.get("start_time", 0)),
|
| 390 |
-
end_time=float(it.get("end_time", 0)),
|
| 391 |
-
)
|
| 392 |
-
)
|
| 393 |
-
return ForcedAlignResult(items=items)
|
| 394 |
-
|
| 395 |
-
@torch.inference_mode()
|
| 396 |
-
def align(
|
| 397 |
-
self,
|
| 398 |
-
audio: Union[AudioLike, List[AudioLike]],
|
| 399 |
-
text: Union[str, List[str]],
|
| 400 |
-
language: Union[str, List[str]],
|
| 401 |
-
) -> List[ForcedAlignResult]:
|
| 402 |
-
"""
|
| 403 |
-
Run forced alignment for batch or single sample.
|
| 404 |
-
|
| 405 |
-
Args:
|
| 406 |
-
audio:
|
| 407 |
-
Audio input(s). Each item supports:
|
| 408 |
-
- local path / https URL / base64 string
|
| 409 |
-
- (np.ndarray, sr)
|
| 410 |
-
All audios will be converted into mono 16k float32 arrays in [-1, 1].
|
| 411 |
-
text:
|
| 412 |
-
Transcript(s) for alignment.
|
| 413 |
-
language:
|
| 414 |
-
Language(s) for each sample (e.g., "Chinese", "English").
|
| 415 |
-
|
| 416 |
-
Returns:
|
| 417 |
-
List[ForcedAlignResult]:
|
| 418 |
-
One result per sample. Each result contains `items`, and each token can be accessed via
|
| 419 |
-
`.text`, `.start_time`, `.end_time`.
|
| 420 |
-
"""
|
| 421 |
-
texts = ensure_list(text)
|
| 422 |
-
languages = ensure_list(language)
|
| 423 |
-
audios = normalize_audios(audio)
|
| 424 |
-
|
| 425 |
-
if len(languages) == 1 and len(audios) > 1:
|
| 426 |
-
languages = languages * len(audios)
|
| 427 |
-
|
| 428 |
-
if not (len(audios) == len(texts) == len(languages)):
|
| 429 |
-
raise ValueError(
|
| 430 |
-
f"Batch size mismatch: audio={len(audios)}, text={len(texts)}, language={len(languages)}"
|
| 431 |
-
)
|
| 432 |
-
|
| 433 |
-
word_lists = []
|
| 434 |
-
aligner_input_texts = []
|
| 435 |
-
for t, lang in zip(texts, languages):
|
| 436 |
-
word_list, aligner_input_text = self.aligner_processor.encode_timestamp(t, lang)
|
| 437 |
-
word_lists.append(word_list)
|
| 438 |
-
aligner_input_texts.append(aligner_input_text)
|
| 439 |
-
|
| 440 |
-
inputs = self.processor(
|
| 441 |
-
text=aligner_input_texts,
|
| 442 |
-
audio=audios,
|
| 443 |
-
return_tensors="pt",
|
| 444 |
-
padding=True,
|
| 445 |
-
)
|
| 446 |
-
inputs = inputs.to(self.model.device).to(self.model.dtype)
|
| 447 |
-
|
| 448 |
-
logits = self.model.thinker(**inputs).logits
|
| 449 |
-
output_ids = logits.argmax(dim=-1)
|
| 450 |
-
|
| 451 |
-
results: List[ForcedAlignResult] = []
|
| 452 |
-
for input_id, output_id, word_list in zip(inputs["input_ids"], output_ids, word_lists):
|
| 453 |
-
masked_output_id = output_id[input_id == self.timestamp_token_id]
|
| 454 |
-
timestamp_ms = (masked_output_id * self.timestamp_segment_time).to("cpu").numpy()
|
| 455 |
-
timestamp_output = self.aligner_processor.parse_timestamp(word_list, timestamp_ms)
|
| 456 |
-
for it in timestamp_output:
|
| 457 |
-
it['start_time'] = round(it['start_time'] / 1000.0, 3)
|
| 458 |
-
it['end_time'] = round(it['end_time'] / 1000.0, 3)
|
| 459 |
-
results.append(self._to_structured_items(timestamp_output))
|
| 460 |
-
|
| 461 |
-
return results
|
| 462 |
-
|
| 463 |
-
def get_supported_languages(self) -> Optional[List[str]]:
|
| 464 |
-
"""
|
| 465 |
-
List supported language names for the current model.
|
| 466 |
-
|
| 467 |
-
This is a thin wrapper around `self.model.get_support_languages()`.
|
| 468 |
-
If the underlying model does not expose language constraints (returns None),
|
| 469 |
-
this method also returns None.
|
| 470 |
-
|
| 471 |
-
Returns:
|
| 472 |
-
Optional[List[str]]:
|
| 473 |
-
- A sorted list of supported language names (lowercased), if available.
|
| 474 |
-
- None if the model does not provide supported languages.
|
| 475 |
-
"""
|
| 476 |
-
fn = getattr(self.model, "get_support_languages", None)
|
| 477 |
-
if not callable(fn):
|
| 478 |
-
return None
|
| 479 |
-
|
| 480 |
-
langs = fn()
|
| 481 |
-
if langs is None:
|
| 482 |
-
return None
|
| 483 |
-
|
| 484 |
-
return sorted({str(x).lower() for x in langs})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_asr/inference/utils.py
DELETED
|
@@ -1,497 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
import base64
|
| 17 |
-
import io
|
| 18 |
-
import urllib.request
|
| 19 |
-
from dataclasses import dataclass
|
| 20 |
-
from typing import Any, Iterable, List, Optional, Tuple, Union
|
| 21 |
-
from urllib.parse import urlparse
|
| 22 |
-
|
| 23 |
-
import librosa
|
| 24 |
-
import numpy as np
|
| 25 |
-
import soundfile as sf
|
| 26 |
-
|
| 27 |
-
AudioLike = Union[
|
| 28 |
-
str, # wav path / URL / base64
|
| 29 |
-
Tuple[np.ndarray, int], # (waveform, sr)
|
| 30 |
-
]
|
| 31 |
-
MaybeList = Union[Any, List[Any]]
|
| 32 |
-
|
| 33 |
-
SAMPLE_RATE = 16000
|
| 34 |
-
MAX_ASR_INPUT_SECONDS = 1200
|
| 35 |
-
MAX_FORCE_ALIGN_INPUT_SECONDS = 180
|
| 36 |
-
MIN_ASR_INPUT_SECONDS = 0.5
|
| 37 |
-
SUPPORTED_LANGUAGES: List[str] = [
|
| 38 |
-
"Chinese",
|
| 39 |
-
"English",
|
| 40 |
-
"Cantonese",
|
| 41 |
-
"Arabic",
|
| 42 |
-
"German",
|
| 43 |
-
"French",
|
| 44 |
-
"Spanish",
|
| 45 |
-
"Portuguese",
|
| 46 |
-
"Indonesian",
|
| 47 |
-
"Italian",
|
| 48 |
-
"Korean",
|
| 49 |
-
"Russian",
|
| 50 |
-
"Thai",
|
| 51 |
-
"Vietnamese",
|
| 52 |
-
"Japanese",
|
| 53 |
-
"Turkish",
|
| 54 |
-
"Hindi",
|
| 55 |
-
"Malay",
|
| 56 |
-
"Dutch",
|
| 57 |
-
"Swedish",
|
| 58 |
-
"Danish",
|
| 59 |
-
"Finnish",
|
| 60 |
-
"Polish",
|
| 61 |
-
"Czech",
|
| 62 |
-
"Filipino",
|
| 63 |
-
"Persian",
|
| 64 |
-
"Greek",
|
| 65 |
-
"Romanian",
|
| 66 |
-
"Hungarian",
|
| 67 |
-
"Macedonian"
|
| 68 |
-
]
|
| 69 |
-
_ASR_TEXT_TAG = "<asr_text>"
|
| 70 |
-
_LANG_PREFIX = "language "
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def normalize_language_name(language: str) -> str:
|
| 74 |
-
"""
|
| 75 |
-
Normalize language name to the canonical format used by Qwen3-ASR:
|
| 76 |
-
first letter uppercase, the rest lowercase (e.g., 'cHINese' -> 'Chinese').
|
| 77 |
-
|
| 78 |
-
Args:
|
| 79 |
-
language (str): Input language name.
|
| 80 |
-
|
| 81 |
-
Returns:
|
| 82 |
-
str: Normalized language name.
|
| 83 |
-
|
| 84 |
-
Raises:
|
| 85 |
-
ValueError: If language is empty.
|
| 86 |
-
"""
|
| 87 |
-
if language is None:
|
| 88 |
-
raise ValueError("language is None")
|
| 89 |
-
s = str(language).strip()
|
| 90 |
-
if not s:
|
| 91 |
-
raise ValueError("language is empty")
|
| 92 |
-
return s[:1].upper() + s[1:].lower()
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def validate_language(language: str) -> None:
    """
    Check that the given canonical language name is supported.

    Args:
        language (str): Canonical language name.

    Raises:
        ValueError: If the language is not in SUPPORTED_LANGUAGES.
    """
    if language in SUPPORTED_LANGUAGES:
        return
    raise ValueError(f"Unsupported language: {language}. Supported: {SUPPORTED_LANGUAGES}")
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def ensure_list(x: MaybeList) -> List[Any]:
    """Wrap a non-list value in a one-element list; pass lists through unchanged."""
    if isinstance(x, list):
        return x
    return [x]
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
def is_url(s: str) -> bool:
    """Return True if ``s`` parses as an http(s) URL with a network location."""
    try:
        parts = urlparse(s)
    except Exception:
        return False
    return bool(parts.netloc) and parts.scheme in ("http", "https")
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def is_probably_base64(s: str) -> bool:
    """
    Heuristically decide whether ``s`` is a base64 audio payload.

    True for ``data:audio`` URIs, or for long strings (> 256 chars) that
    contain no path separators; everything else is assumed to be a path.
    """
    if s.startswith("data:audio"):
        return True
    return len(s) > 256 and "/" not in s and "\\" not in s
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
def decode_base64_bytes(b64: str) -> bytes:
    """Decode a base64 string, stripping an optional ``data:`` URI header first."""
    is_data_uri = b64.strip().startswith("data:")
    if is_data_uri and "," in b64:
        _, b64 = b64.split(",", 1)
    return base64.b64decode(b64)
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
def load_audio_any(x: str) -> Tuple[np.ndarray, int]:
    """
    Load audio from a URL, a base64 payload, or a local file path.

    Returns:
        Tuple[np.ndarray, int]: (float32 waveform, sampling rate).
    """
    if is_url(x):
        with urllib.request.urlopen(x) as resp:
            raw = resp.read()
        with io.BytesIO(raw) as buf:
            audio, sr = sf.read(buf, dtype="float32", always_2d=False)
    elif is_probably_base64(x):
        raw = decode_base64_bytes(x)
        with io.BytesIO(raw) as buf:
            audio, sr = sf.read(buf, dtype="float32", always_2d=False)
    else:
        # Local path: let librosa keep the native rate and channel layout.
        audio, sr = librosa.load(x, sr=None, mono=False)

    return np.asarray(audio, dtype=np.float32), int(sr)
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
def to_mono(audio: np.ndarray) -> np.ndarray:
    """
    Downmix audio to mono by averaging channels.

    1-D input is returned unchanged. 2-D input may be (T, C) or (C, T);
    a small leading axis (<= 8) that is shorter than the other axis is
    treated as the channel axis and transposed before averaging.

    Raises:
        ValueError: For arrays with more than two dimensions.
    """
    if audio.ndim == 1:
        return audio
    if audio.ndim != 2:
        raise ValueError(f"Unsupported audio ndim={audio.ndim}")
    rows, cols = audio.shape
    if rows <= 8 and cols > rows:
        audio = audio.T  # (C, T) -> (T, C)
    return np.mean(audio, axis=-1).astype(np.float32)
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def float_range_normalize(audio: np.ndarray) -> np.ndarray:
    """
    Bring a waveform into [-1, 1] as float32.

    If the peak magnitude exceeds 1.0 (e.g. int-scaled data decoded as
    float), the signal is divided by its peak before clipping. Empty and
    all-zero inputs are returned unchanged.
    """
    audio = audio.astype(np.float32)
    if audio.size == 0:
        return audio
    peak = float(np.abs(audio).max())
    if peak == 0.0:
        return audio
    if peak > 1.0:
        audio = audio / peak
    return np.clip(audio, -1.0, 1.0)
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
def normalize_audio_input(a: AudioLike) -> np.ndarray:
    """
    Normalize one audio input to a mono 16 kHz float32 waveform in [-1, 1].

    Supported inputs:
        - str: local file path / https URL / base64 audio string
        - (np.ndarray, sr): waveform and sampling rate

    Returns:
        np.ndarray: Mono 16k float32 waveform in [-1, 1].

    Raises:
        TypeError: If the input is neither a string nor an (array, sr) pair.
    """
    if isinstance(a, str):
        audio, sr = load_audio_any(a)
    elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
        audio = a[0]
        sr = int(a[1])
    else:
        raise TypeError(f"Unsupported audio input type: {type(a)}")

    mono = to_mono(np.asarray(audio))
    if sr != SAMPLE_RATE:
        mono = librosa.resample(mono, orig_sr=sr, target_sr=SAMPLE_RATE).astype(np.float32)
    return float_range_normalize(mono)
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
def normalize_audios(audios: Union[AudioLike, List[AudioLike]]) -> List[np.ndarray]:
    """Normalize one or many audio inputs; always returns a list of waveforms."""
    return [normalize_audio_input(item) for item in ensure_list(audios)]
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
def chunk_list(xs: List[Any], chunk_size: int) -> Iterable[List[Any]]:
    """
    Yield consecutive slices of a list.

    A non-positive ``chunk_size`` yields the whole list as a single chunk.

    Args:
        xs (List[Any]): Input list.
        chunk_size (int): Chunk size.

    Yields:
        List[Any]: Slices of xs.
    """
    if chunk_size <= 0:
        yield xs
        return
    start = 0
    while start < len(xs):
        yield xs[start : start + chunk_size]
        start += chunk_size
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
@dataclass(frozen=True)
class AudioChunk:
    """
    One chunk cut from an original audio.

    Instances are frozen (immutable), so chunks can be shared across
    processing stages without defensive copies.

    Attributes:
        orig_index: Index of the original sample in the input batch.
        chunk_index: Index of this chunk within the original sample.
        wav: Mono float32 waveform.
        sr: Sampling rate.
        offset_sec: Start offset of this chunk in the original audio, in seconds.
    """
    orig_index: int
    chunk_index: int
    wav: np.ndarray
    sr: int
    offset_sec: float
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
def split_audio_into_chunks(
    wav: np.ndarray,
    sr: int,
    max_chunk_sec: float,
    search_expand_sec: float = 5.0,
    min_window_ms: float = 100.0,
) -> List[Tuple[np.ndarray, float]]:
    """
    Split a long audio into chunks close to max_chunk_sec, using a low-energy boundary.

    This implementation guarantees:
    - Concatenating all returned chunks reproduces the original audio exactly
      (total number of samples is identical, no overlaps, no gaps).
      Note: chunks shorter than MIN_ASR_INPUT_SECONDS are zero-padded at the
      tail below, so when padding occurs the concatenation is longer than the
      original by the padded zeros.

    Args:
        wav: Mono waveform float32.
        sr: Sampling rate.
        max_chunk_sec: Target max chunk duration in seconds.
        search_expand_sec: Boundary search half-window in seconds.
        min_window_ms: Sliding window in milliseconds for energy estimation.

    Returns:
        List[Tuple[np.ndarray, float]]: List of (chunk_wav, offset_sec).
    """
    wav = np.asarray(wav, dtype=np.float32)
    if wav.ndim > 1:
        # Defensive downmix in case a multi-channel array slips through.
        wav = np.mean(wav, axis=-1).astype(np.float32)

    total_len = int(wav.shape[0])
    total_sec = total_len / float(sr)
    if total_sec <= max_chunk_sec:
        # Already short enough: one chunk at offset 0.
        return [(wav, 0.0)]

    max_len = int(max_chunk_sec * sr)
    expand = int(search_expand_sec * sr)
    # Energy-estimation window in samples (at least 4 samples).
    win = max(4, int((min_window_ms / 1000.0) * sr))

    chunks: List[Tuple[np.ndarray, float]] = []

    start = 0
    offset_sec = 0.0

    while (total_len - start) > max_len:
        # Nominal cut point; search +/- `expand` samples around it for a quiet spot.
        cut = start + max_len

        left = max(start, cut - expand)
        right = min(total_len, cut + expand)

        if right - left <= win:
            # Search region smaller than the energy window: cut at the nominal point.
            boundary = cut
        else:
            seg = wav[left:right]
            seg_abs = np.abs(seg)

            # Sliding-window sum of |x|; the minimum marks the quietest window.
            window_sums = np.convolve(seg_abs, np.ones(win, dtype=np.float32), mode="valid")

            min_pos = int(np.argmin(window_sums))

            # Within the quietest window, cut at the single lowest-amplitude sample.
            wstart = min_pos
            wend = min_pos + win
            local = seg_abs[wstart:wend]
            inner = int(np.argmin(local))
            boundary = left + wstart + inner

        # Keep the boundary strictly after `start` and inside the signal so
        # every chunk is non-empty and progress is guaranteed.
        boundary = int(max(boundary, start + 1))
        boundary = int(min(boundary, total_len))

        chunk = wav[start:boundary]
        chunks.append((chunk, offset_sec))

        offset_sec += (boundary - start) / float(sr)
        start = boundary

    # Remaining samples form the final chunk.
    tail = wav[start:total_len]
    chunks.append((tail, offset_sec))

    # Pad too-short chunks to at least MIN_ASR_INPUT_SECONDS (zero-padding at tail)
    min_len = int(MIN_ASR_INPUT_SECONDS * sr)
    padded: List[Tuple[np.ndarray, float]] = []
    for c, off in chunks:
        if c.shape[0] < min_len:
            pad = min_len - int(c.shape[0])
            c = np.pad(c, (0, pad), mode="constant", constant_values=0.0).astype(np.float32)
        padded.append((c, off))
    chunks = padded

    return chunks
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
def detect_and_fix_repetitions(text, threshold=20):
    """
    Collapse pathological repetitions in decoded ASR text.

    First reduces runs of a single character longer than ``threshold`` to one
    occurrence, then reduces any short pattern (up to 20 chars) repeated at
    least ``threshold`` times in a row to a single occurrence.

    Args:
        text (str): Input text.
        threshold (int): Minimum repetition count treated as pathological.

    Returns:
        str: Text with repetitions collapsed.
    """
    def fix_char_repeats(s, thresh):
        # Scan runs of identical characters; keep one char for runs longer
        # than `thresh`, otherwise keep the run unchanged.
        res = []
        i = 0
        n = len(s)
        while i < n:
            count = 1
            while i + count < n and s[i + count] == s[i]:
                count += 1

            if count > thresh:
                res.append(s[i])
                i += count
            else:
                res.append(s[i:i+count])
                i += count
        return ''.join(res)

    def fix_pattern_repeats(s, thresh, max_len=20):
        # Find the first position where some pattern of length 1..max_len
        # repeats at least `thresh` times, keep one copy of the pattern,
        # and recurse on the remainder after the full run.
        n = len(s)
        min_repeat_chars = thresh * 2
        if n < min_repeat_chars:
            # Too short to contain `thresh` repetitions of any pattern.
            return s

        i = 0
        result = []
        while i <= n - min_repeat_chars:
            found = False
            for k in range(1, max_len + 1):
                if i + k * thresh > n:
                    # Not enough room for `thresh` copies of a length-k pattern.
                    break

                pattern = s[i:i+k]
                valid = True
                for rep in range(1, thresh):
                    start_idx = i + rep * k
                    if s[start_idx:start_idx+k] != pattern:
                        valid = False
                        break

                if valid:
                    # Extend beyond `thresh` copies to consume the whole run.
                    total_rep = thresh
                    end_index = i + thresh * k
                    while end_index + k <= n and s[end_index:end_index+k] == pattern:
                        total_rep += 1
                        end_index += k
                    result.append(pattern)
                    result.append(fix_pattern_repeats(s[end_index:], thresh, max_len))
                    i = n
                    found = True
                    break

            if found:
                break
            else:
                # No run starts here; keep this char and move one step forward.
                result.append(s[i])
                i += 1

        if not found:
            # Append the unscanned tail (shorter than min_repeat_chars).
            result.append(s[i:])
        return ''.join(result)

    text_raw = text
    text = fix_char_repeats(text_raw, threshold)
    text = fix_pattern_repeats(text, threshold)
    return text
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
def parse_asr_output(
    raw: str,
    user_language: Optional[str] = None,
) -> Tuple[str, str]:
    """
    Parse Qwen3-ASR raw output into (language, text).

    Cases:
    - With tag: "language Chinese<asr_text>...."
    - With newlines: "language Chinese\\n...\\n<asr_text>...."
    - No tag: treat whole string as text.
    - "language None<asr_text>": treat as empty audio -> ("", "")

    If user_language is provided, language is forced to user_language and raw
    is treated as text-only (the model is expected to output plain
    transcription without metadata).

    Args:
        raw: Raw decoded string.
        user_language: Canonical language name if user forced language.

    Returns:
        Tuple[str, str]: (language, text)
    """
    if raw is None:
        return "", ""
    s = str(raw).strip()
    if not s:
        return "", ""

    # Collapse pathological character/pattern repetitions before parsing.
    s = detect_and_fix_repetitions(s)

    if user_language:
        # User explicitly forced language => model output is treated as pure text.
        return user_language, s

    if _ASR_TEXT_TAG not in s:
        # No metadata tag => the whole string is the transcription.
        return "", s.strip()

    meta_part, text_part = s.split(_ASR_TEXT_TAG, 1)
    meta_lower = meta_part.lower()

    # Empty-audio heuristic: the model reported "language None".
    if "language none" in meta_lower:
        text = text_part.strip()
        if not text:
            return "", ""
        # The model still produced text; keep it with language unknown.
        return "", text

    # Extract "language xxx" from the metadata lines.
    lang = ""
    for meta_line in meta_part.splitlines():
        meta_line = meta_line.strip()
        if not meta_line:
            continue
        if meta_line.lower().startswith(_LANG_PREFIX):
            value = meta_line[len(_LANG_PREFIX):].strip()
            if value:
                lang = normalize_language_name(value)
                break

    return lang, text_part.strip()
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
def merge_languages(langs: List[str]) -> str:
    """
    Merge per-chunk languages into a compact comma-separated string,
    dropping blanks and collapsing consecutive duplicates while keeping order.

    Example:
        ["Chinese", "English", "English"] -> "Chinese,English"

    Args:
        langs: List of canonical language names.

    Returns:
        str: Merged language string.
    """
    merged: List[str] = []
    for raw in langs:
        name = (raw or "").strip()
        if not name:
            continue
        if merged and merged[-1] == name:
            continue
        merged.append(name)
    return ",".join(merged)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/__init__.py
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
|
| 17 |
-
"""
|
| 18 |
-
qwen_tts: Qwen-TTS package.
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
from .inference.qwen3_tts_model import Qwen3TTSModel, VoiceClonePromptItem
|
| 22 |
-
from .inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
|
| 23 |
-
|
| 24 |
-
__all__ = ["__version__"]
|
| 25 |
-
__version__ = "0.0.1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/__main__.py
DELETED
|
@@ -1,24 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
def main():
    """Print a short usage hint pointing to the qwen_tts CLI entrypoints."""
    message = (
        "qwen_tts package.\n"
        "Use CLI entrypoints:\n"
        " - qwen-tts-demo\n"
    )
    print(message)

if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/cli/demo.py
DELETED
|
@@ -1,633 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
"""
|
| 17 |
-
A gradio demo for Qwen3 TTS models.
|
| 18 |
-
"""
|
| 19 |
-
|
| 20 |
-
import argparse
|
| 21 |
-
import os
|
| 22 |
-
import tempfile
|
| 23 |
-
from dataclasses import asdict
|
| 24 |
-
from typing import Any, Dict, List, Optional, Tuple
|
| 25 |
-
|
| 26 |
-
import gradio as gr
|
| 27 |
-
import numpy as np
|
| 28 |
-
import torch
|
| 29 |
-
|
| 30 |
-
from .. import Qwen3TTSModel, VoiceClonePromptItem
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def _title_case_display(s: str) -> str:
|
| 34 |
-
s = (s or "").strip()
|
| 35 |
-
s = s.replace("_", " ")
|
| 36 |
-
return " ".join([w[:1].upper() + w[1:] if w else "" for w in s.split()])
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def _build_choices_and_map(items: Optional[List[str]]) -> Tuple[List[str], Dict[str, str]]:
|
| 40 |
-
if not items:
|
| 41 |
-
return [], {}
|
| 42 |
-
display = [_title_case_display(x) for x in items]
|
| 43 |
-
mapping = {d: r for d, r in zip(display, items)}
|
| 44 |
-
return display, mapping
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def _dtype_from_str(s: str) -> torch.dtype:
|
| 48 |
-
s = (s or "").strip().lower()
|
| 49 |
-
if s in ("bf16", "bfloat16"):
|
| 50 |
-
return torch.bfloat16
|
| 51 |
-
if s in ("fp16", "float16", "half"):
|
| 52 |
-
return torch.float16
|
| 53 |
-
if s in ("fp32", "float32"):
|
| 54 |
-
return torch.float32
|
| 55 |
-
raise ValueError(f"Unsupported torch dtype: {s}. Use bfloat16/float16/float32.")
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def _maybe(v):
|
| 59 |
-
return v if v is not None else gr.update()
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def build_parser() -> argparse.ArgumentParser:
    """
    Build the CLI argument parser for the Qwen3 TTS Gradio demo.

    Returns:
        argparse.ArgumentParser: Parser covering checkpoint selection, model
        loading options, Gradio server options, HTTPS options, and optional
        generation overrides.
    """
    parser = argparse.ArgumentParser(
        prog="qwen-tts-demo",
        description=(
            "Launch a Gradio demo for Qwen3 TTS models (CustomVoice / VoiceDesign / Base).\n\n"
            "Examples:\n"
            " qwen-tts-demo Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice\n"
            " qwen-tts-demo Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign --port 8000 --ip 127.0.0.1\n"
            " qwen-tts-demo Qwen/Qwen3-TTS-12Hz-1.7B-Base --device cuda:0\n"
            " qwen-tts-demo Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice --dtype bfloat16 --no-flash-attn\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
        add_help=True,
    )

    # Positional checkpoint (also supports -c/--checkpoint)
    parser.add_argument(
        "checkpoint_pos",
        nargs="?",
        default=None,
        help="Model checkpoint path or HuggingFace repo id (positional).",
    )
    parser.add_argument(
        "-c",
        "--checkpoint",
        default=None,
        help="Model checkpoint path or HuggingFace repo id (optional if positional is provided).",
    )

    # Model loading / from_pretrained args
    parser.add_argument(
        "--device",
        default="cuda:0",
        help="Device for device_map, e.g. cpu, cuda, cuda:0 (default: cuda:0).",
    )
    parser.add_argument(
        "--dtype",
        default="bfloat16",
        choices=["bfloat16", "bf16", "float16", "fp16", "float32", "fp32"],
        help="Torch dtype for loading the model (default: bfloat16).",
    )
    # Fix: the option string must be "--flash-attn" only. BooleanOptionalAction
    # auto-registers the matching "--no-flash-attn"; the previous literal
    # "--flash-attn/--no-flash-attn" registered the slash-joined string as the
    # flag name, so the documented --no-flash-attn did not exist.
    parser.add_argument(
        "--flash-attn",
        dest="flash_attn",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Enable FlashAttention-2 (default: enabled).",
    )

    # Gradio server args
    parser.add_argument(
        "--ip",
        default="0.0.0.0",
        help="Server bind IP for Gradio (default: 0.0.0.0).",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Server port for Gradio (default: 8000).",
    )
    # Same BooleanOptionalAction fix as --flash-attn above.
    parser.add_argument(
        "--share",
        dest="share",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Whether to create a public Gradio link (default: disabled).",
    )
    parser.add_argument(
        "--concurrency",
        type=int,
        default=16,
        help="Gradio queue concurrency (default: 16).",
    )

    # HTTPS args
    parser.add_argument(
        "--ssl-certfile",
        default=None,
        help="Path to SSL certificate file for HTTPS (optional).",
    )
    parser.add_argument(
        "--ssl-keyfile",
        default=None,
        help="Path to SSL key file for HTTPS (optional).",
    )
    parser.add_argument(
        "--ssl-verify",
        default=None,
        help="SSL verify setting for Gradio (optional).",
    )

    # Optional generation args
    parser.add_argument("--max-new-tokens", type=int, default=None, help="Max new tokens for generation (optional).")
    parser.add_argument("--temperature", type=float, default=None, help="Sampling temperature (optional).")
    parser.add_argument("--top-k", type=int, default=None, help="Top-k sampling (optional).")
    parser.add_argument("--top-p", type=float, default=None, help="Top-p sampling (optional).")
    parser.add_argument("--repetition-penalty", type=float, default=None, help="Repetition penalty (optional).")
    parser.add_argument("--subtalker-top-k", type=int, default=None, help="Subtalker top-k (optional, only for tokenizer v2).")
    parser.add_argument("--subtalker-top-p", type=float, default=None, help="Subtalker top-p (optional, only for tokenizer v2).")
    parser.add_argument(
        "--subtalker-temperature", type=float, default=None, help="Subtalker temperature (optional, only for tokenizer v2)."
    )

    return parser
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
def _resolve_checkpoint(args: argparse.Namespace) -> str:
|
| 170 |
-
ckpt = args.checkpoint or args.checkpoint_pos
|
| 171 |
-
if not ckpt:
|
| 172 |
-
raise SystemExit(0) # main() prints help
|
| 173 |
-
return ckpt
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
def _collect_gen_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
|
| 177 |
-
mapping = {
|
| 178 |
-
"max_new_tokens": args.max_new_tokens,
|
| 179 |
-
"temperature": args.temperature,
|
| 180 |
-
"top_k": args.top_k,
|
| 181 |
-
"top_p": args.top_p,
|
| 182 |
-
"repetition_penalty": args.repetition_penalty,
|
| 183 |
-
"subtalker_top_k": args.subtalker_top_k,
|
| 184 |
-
"subtalker_top_p": args.subtalker_top_p,
|
| 185 |
-
"subtalker_temperature": args.subtalker_temperature,
|
| 186 |
-
}
|
| 187 |
-
return {k: v for k, v in mapping.items() if v is not None}
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
def _normalize_audio(wav, eps=1e-12, clip=True):
|
| 191 |
-
x = np.asarray(wav)
|
| 192 |
-
|
| 193 |
-
if np.issubdtype(x.dtype, np.integer):
|
| 194 |
-
info = np.iinfo(x.dtype)
|
| 195 |
-
|
| 196 |
-
if info.min < 0:
|
| 197 |
-
y = x.astype(np.float32) / max(abs(info.min), info.max)
|
| 198 |
-
else:
|
| 199 |
-
mid = (info.max + 1) / 2.0
|
| 200 |
-
y = (x.astype(np.float32) - mid) / mid
|
| 201 |
-
|
| 202 |
-
elif np.issubdtype(x.dtype, np.floating):
|
| 203 |
-
y = x.astype(np.float32)
|
| 204 |
-
m = np.max(np.abs(y)) if y.size else 0.0
|
| 205 |
-
|
| 206 |
-
if m <= 1.0 + 1e-6:
|
| 207 |
-
pass
|
| 208 |
-
else:
|
| 209 |
-
y = y / (m + eps)
|
| 210 |
-
else:
|
| 211 |
-
raise TypeError(f"Unsupported dtype: {x.dtype}")
|
| 212 |
-
|
| 213 |
-
if clip:
|
| 214 |
-
y = np.clip(y, -1.0, 1.0)
|
| 215 |
-
|
| 216 |
-
if y.ndim > 1:
|
| 217 |
-
y = np.mean(y, axis=-1).astype(np.float32)
|
| 218 |
-
|
| 219 |
-
return y
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
def _audio_to_tuple(audio: Any) -> Optional[Tuple[np.ndarray, int]]:
|
| 223 |
-
if audio is None:
|
| 224 |
-
return None
|
| 225 |
-
|
| 226 |
-
if isinstance(audio, tuple) and len(audio) == 2 and isinstance(audio[0], int):
|
| 227 |
-
sr, wav = audio
|
| 228 |
-
wav = _normalize_audio(wav)
|
| 229 |
-
return wav, int(sr)
|
| 230 |
-
|
| 231 |
-
if isinstance(audio, dict) and "sampling_rate" in audio and "data" in audio:
|
| 232 |
-
sr = int(audio["sampling_rate"])
|
| 233 |
-
wav = _normalize_audio(audio["data"])
|
| 234 |
-
return wav, sr
|
| 235 |
-
|
| 236 |
-
return None
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
def _wav_to_gradio_audio(wav: np.ndarray, sr: int) -> Tuple[int, np.ndarray]:
|
| 240 |
-
wav = np.asarray(wav, dtype=np.float32)
|
| 241 |
-
return sr, wav
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
def _detect_model_kind(ckpt: str, tts: Qwen3TTSModel) -> str:
    """
    Read the model's declared TTS type (custom_voice / voice_design / base).

    Raises:
        ValueError: If the model reports an unrecognized type.
    """
    kind = getattr(tts.model, "tts_model_type", None)
    if kind not in ("custom_voice", "voice_design", "base"):
        raise ValueError(f"Unknown Qwen-TTS model type: {kind}")
    return kind
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
def build_demo(tts: Qwen3TTSModel, ckpt: str, gen_kwargs_default: Dict[str, Any]) -> gr.Blocks:
|
| 253 |
-
model_kind = _detect_model_kind(ckpt, tts)
|
| 254 |
-
|
| 255 |
-
supported_langs_raw = None
|
| 256 |
-
if callable(getattr(tts.model, "get_supported_languages", None)):
|
| 257 |
-
supported_langs_raw = tts.model.get_supported_languages()
|
| 258 |
-
|
| 259 |
-
supported_spks_raw = None
|
| 260 |
-
if callable(getattr(tts.model, "get_supported_speakers", None)):
|
| 261 |
-
supported_spks_raw = tts.model.get_supported_speakers()
|
| 262 |
-
|
| 263 |
-
lang_choices_disp, lang_map = _build_choices_and_map([x for x in (supported_langs_raw or [])])
|
| 264 |
-
spk_choices_disp, spk_map = _build_choices_and_map([x for x in (supported_spks_raw or [])])
|
| 265 |
-
|
| 266 |
-
def _gen_common_kwargs() -> Dict[str, Any]:
|
| 267 |
-
return dict(gen_kwargs_default)
|
| 268 |
-
|
| 269 |
-
theme = gr.themes.Soft(
|
| 270 |
-
font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
|
| 271 |
-
)
|
| 272 |
-
|
| 273 |
-
css = ".gradio-container {max-width: none !important;}"
|
| 274 |
-
|
| 275 |
-
with gr.Blocks(theme=theme, css=css) as demo:
|
| 276 |
-
gr.Markdown(
|
| 277 |
-
f"""
|
| 278 |
-
# Qwen3 TTS Demo
|
| 279 |
-
**Checkpoint:** `{ckpt}`
|
| 280 |
-
**Model Type:** `{model_kind}`
|
| 281 |
-
"""
|
| 282 |
-
)
|
| 283 |
-
|
| 284 |
-
if model_kind == "custom_voice":
|
| 285 |
-
with gr.Row():
|
| 286 |
-
with gr.Column(scale=2):
|
| 287 |
-
text_in = gr.Textbox(
|
| 288 |
-
label="Text (待合成文本)",
|
| 289 |
-
lines=4,
|
| 290 |
-
placeholder="Enter text to synthesize (输入要合成的文本).",
|
| 291 |
-
)
|
| 292 |
-
with gr.Row():
|
| 293 |
-
lang_in = gr.Dropdown(
|
| 294 |
-
label="Language (语种)",
|
| 295 |
-
choices=lang_choices_disp,
|
| 296 |
-
value="Auto",
|
| 297 |
-
interactive=True,
|
| 298 |
-
)
|
| 299 |
-
spk_in = gr.Dropdown(
|
| 300 |
-
label="Speaker (说话人)",
|
| 301 |
-
choices=spk_choices_disp,
|
| 302 |
-
value="Vivian",
|
| 303 |
-
interactive=True,
|
| 304 |
-
)
|
| 305 |
-
instruct_in = gr.Textbox(
|
| 306 |
-
label="Instruction (Optional) (控制指令,可不输入)",
|
| 307 |
-
lines=2,
|
| 308 |
-
placeholder="e.g. Say it in a very angry tone (例如:用特别伤心的语气说).",
|
| 309 |
-
)
|
| 310 |
-
btn = gr.Button("Generate (生成)", variant="primary")
|
| 311 |
-
with gr.Column(scale=3):
|
| 312 |
-
audio_out = gr.Audio(label="Output Audio (合成结果)", type="numpy")
|
| 313 |
-
err = gr.Textbox(label="Status (状态)", lines=2)
|
| 314 |
-
|
| 315 |
-
def run_instruct(text: str, lang_disp: str, spk_disp: str, instruct: str):
|
| 316 |
-
try:
|
| 317 |
-
if not text or not text.strip():
|
| 318 |
-
return None, "Text is required (必须填写文本)."
|
| 319 |
-
if not spk_disp:
|
| 320 |
-
return None, "Speaker is required (必须选择说话人)."
|
| 321 |
-
language = lang_map.get(lang_disp, "Auto")
|
| 322 |
-
speaker = spk_map.get(spk_disp, spk_disp)
|
| 323 |
-
kwargs = _gen_common_kwargs()
|
| 324 |
-
wavs, sr = tts.generate_custom_voice(
|
| 325 |
-
text=text.strip(),
|
| 326 |
-
language=language,
|
| 327 |
-
speaker=speaker,
|
| 328 |
-
instruct=(instruct or "").strip() or None,
|
| 329 |
-
**kwargs,
|
| 330 |
-
)
|
| 331 |
-
return _wav_to_gradio_audio(wavs[0], sr), "Finished. (生成完成)"
|
| 332 |
-
except Exception as e:
|
| 333 |
-
return None, f"{type(e).__name__}: {e}"
|
| 334 |
-
|
| 335 |
-
btn.click(run_instruct, inputs=[text_in, lang_in, spk_in, instruct_in], outputs=[audio_out, err])
|
| 336 |
-
|
| 337 |
-
elif model_kind == "voice_design":
|
| 338 |
-
with gr.Row():
|
| 339 |
-
with gr.Column(scale=2):
|
| 340 |
-
text_in = gr.Textbox(
|
| 341 |
-
label="Text (待合成文本)",
|
| 342 |
-
lines=4,
|
| 343 |
-
value="It's in the top drawer... wait, it's empty? No way, that's impossible! I'm sure I put it there!"
|
| 344 |
-
)
|
| 345 |
-
with gr.Row():
|
| 346 |
-
lang_in = gr.Dropdown(
|
| 347 |
-
label="Language (语种)",
|
| 348 |
-
choices=lang_choices_disp,
|
| 349 |
-
value="Auto",
|
| 350 |
-
interactive=True,
|
| 351 |
-
)
|
| 352 |
-
design_in = gr.Textbox(
|
| 353 |
-
label="Voice Design Instruction (音色描述)",
|
| 354 |
-
lines=3,
|
| 355 |
-
value="Speak in an incredulous tone, but with a hint of panic beginning to creep into your voice."
|
| 356 |
-
)
|
| 357 |
-
btn = gr.Button("Generate (生成)", variant="primary")
|
| 358 |
-
with gr.Column(scale=3):
|
| 359 |
-
audio_out = gr.Audio(label="Output Audio (合成结果)", type="numpy")
|
| 360 |
-
err = gr.Textbox(label="Status (状态)", lines=2)
|
| 361 |
-
|
| 362 |
-
def run_voice_design(text: str, lang_disp: str, design: str):
|
| 363 |
-
try:
|
| 364 |
-
if not text or not text.strip():
|
| 365 |
-
return None, "Text is required (必须填写文本)."
|
| 366 |
-
if not design or not design.strip():
|
| 367 |
-
return None, "Voice design instruction is required (必须填写音色描述)."
|
| 368 |
-
language = lang_map.get(lang_disp, "Auto")
|
| 369 |
-
kwargs = _gen_common_kwargs()
|
| 370 |
-
wavs, sr = tts.generate_voice_design(
|
| 371 |
-
text=text.strip(),
|
| 372 |
-
language=language,
|
| 373 |
-
instruct=design.strip(),
|
| 374 |
-
**kwargs,
|
| 375 |
-
)
|
| 376 |
-
return _wav_to_gradio_audio(wavs[0], sr), "Finished. (生成完成)"
|
| 377 |
-
except Exception as e:
|
| 378 |
-
return None, f"{type(e).__name__}: {e}"
|
| 379 |
-
|
| 380 |
-
btn.click(run_voice_design, inputs=[text_in, lang_in, design_in], outputs=[audio_out, err])
|
| 381 |
-
|
| 382 |
-
else: # voice_clone for base
|
| 383 |
-
with gr.Tabs():
|
| 384 |
-
with gr.Tab("Clone & Generate (克隆并合成)"):
|
| 385 |
-
with gr.Row():
|
| 386 |
-
with gr.Column(scale=2):
|
| 387 |
-
ref_audio = gr.Audio(
|
| 388 |
-
label="Reference Audio (参考音频)",
|
| 389 |
-
)
|
| 390 |
-
ref_text = gr.Textbox(
|
| 391 |
-
label="Reference Text (参考音频文本)",
|
| 392 |
-
lines=2,
|
| 393 |
-
placeholder="Required if not set use x-vector only (不勾选use x-vector only时必填).",
|
| 394 |
-
)
|
| 395 |
-
xvec_only = gr.Checkbox(
|
| 396 |
-
label="Use x-vector only (仅用说话人向量,效果有限,但不用传入参考音频文本)",
|
| 397 |
-
value=False,
|
| 398 |
-
)
|
| 399 |
-
|
| 400 |
-
with gr.Column(scale=2):
|
| 401 |
-
text_in = gr.Textbox(
|
| 402 |
-
label="Target Text (待合成文本)",
|
| 403 |
-
lines=4,
|
| 404 |
-
placeholder="Enter text to synthesize (输入要合成的文本).",
|
| 405 |
-
)
|
| 406 |
-
lang_in = gr.Dropdown(
|
| 407 |
-
label="Language (语种)",
|
| 408 |
-
choices=lang_choices_disp,
|
| 409 |
-
value="Auto",
|
| 410 |
-
interactive=True,
|
| 411 |
-
)
|
| 412 |
-
btn = gr.Button("Generate (生成)", variant="primary")
|
| 413 |
-
|
| 414 |
-
with gr.Column(scale=3):
|
| 415 |
-
audio_out = gr.Audio(label="Output Audio (合成结果)", type="numpy")
|
| 416 |
-
err = gr.Textbox(label="Status (状态)", lines=2)
|
| 417 |
-
|
| 418 |
-
def run_voice_clone(ref_aud, ref_txt: str, use_xvec: bool, text: str, lang_disp: str):
|
| 419 |
-
try:
|
| 420 |
-
if not text or not text.strip():
|
| 421 |
-
return None, "Target text is required (必须填写待合成文本)."
|
| 422 |
-
at = _audio_to_tuple(ref_aud)
|
| 423 |
-
if at is None:
|
| 424 |
-
return None, "Reference audio is required (必须上传参考音频)."
|
| 425 |
-
if (not use_xvec) and (not ref_txt or not ref_txt.strip()):
|
| 426 |
-
return None, (
|
| 427 |
-
"Reference text is required when use x-vector only is NOT enabled.\n"
|
| 428 |
-
"(未勾选 use x-vector only 时,必须提供参考音频文本;否则请勾选 use x-vector only,但效果会变差.)"
|
| 429 |
-
)
|
| 430 |
-
language = lang_map.get(lang_disp, "Auto")
|
| 431 |
-
kwargs = _gen_common_kwargs()
|
| 432 |
-
wavs, sr = tts.generate_voice_clone(
|
| 433 |
-
text=text.strip(),
|
| 434 |
-
language=language,
|
| 435 |
-
ref_audio=at,
|
| 436 |
-
ref_text=(ref_txt.strip() if ref_txt else None),
|
| 437 |
-
x_vector_only_mode=bool(use_xvec),
|
| 438 |
-
**kwargs,
|
| 439 |
-
)
|
| 440 |
-
return _wav_to_gradio_audio(wavs[0], sr), "Finished. (生成完成)"
|
| 441 |
-
except Exception as e:
|
| 442 |
-
return None, f"{type(e).__name__}: {e}"
|
| 443 |
-
|
| 444 |
-
btn.click(
|
| 445 |
-
run_voice_clone,
|
| 446 |
-
inputs=[ref_audio, ref_text, xvec_only, text_in, lang_in],
|
| 447 |
-
outputs=[audio_out, err],
|
| 448 |
-
)
|
| 449 |
-
|
| 450 |
-
with gr.Tab("Save / Load Voice (保存/加载克隆音色)"):
|
| 451 |
-
with gr.Row():
|
| 452 |
-
with gr.Column(scale=2):
|
| 453 |
-
gr.Markdown(
|
| 454 |
-
"""
|
| 455 |
-
### Save Voice (保存音色)
|
| 456 |
-
Upload reference audio and text, choose use x-vector only or not, then save a reusable voice prompt file.
|
| 457 |
-
(上传参考音频和参考文本,选择是否使用 use x-vector only 模式后保存为可复用的音色文件)
|
| 458 |
-
"""
|
| 459 |
-
)
|
| 460 |
-
ref_audio_s = gr.Audio(label="Reference Audio (参考音频)", type="numpy")
|
| 461 |
-
ref_text_s = gr.Textbox(
|
| 462 |
-
label="Reference Text (参考音频文本)",
|
| 463 |
-
lines=2,
|
| 464 |
-
placeholder="Required if not set use x-vector only (不勾选use x-vector only时必填).",
|
| 465 |
-
)
|
| 466 |
-
xvec_only_s = gr.Checkbox(
|
| 467 |
-
label="Use x-vector only (仅用说话人向量,效果有限,但不用传入参考音频文本)",
|
| 468 |
-
value=False,
|
| 469 |
-
)
|
| 470 |
-
save_btn = gr.Button("Save Voice File (保存音色文件)", variant="primary")
|
| 471 |
-
prompt_file_out = gr.File(label="Voice File (音色文件)")
|
| 472 |
-
|
| 473 |
-
with gr.Column(scale=2):
|
| 474 |
-
gr.Markdown(
|
| 475 |
-
"""
|
| 476 |
-
### Load Voice & Generate (加载音色并合成)
|
| 477 |
-
Upload a previously saved voice file, then synthesize new text.
|
| 478 |
-
(上传已保存提示文件后,输入新文本进行合成)
|
| 479 |
-
"""
|
| 480 |
-
)
|
| 481 |
-
prompt_file_in = gr.File(label="Upload Prompt File (上传提示文件)")
|
| 482 |
-
text_in2 = gr.Textbox(
|
| 483 |
-
label="Target Text (待合成文本)",
|
| 484 |
-
lines=4,
|
| 485 |
-
placeholder="Enter text to synthesize (输入要合成的文本).",
|
| 486 |
-
)
|
| 487 |
-
lang_in2 = gr.Dropdown(
|
| 488 |
-
label="Language (语种)",
|
| 489 |
-
choices=lang_choices_disp,
|
| 490 |
-
value="Auto",
|
| 491 |
-
interactive=True,
|
| 492 |
-
)
|
| 493 |
-
gen_btn2 = gr.Button("Generate (生成)", variant="primary")
|
| 494 |
-
|
| 495 |
-
with gr.Column(scale=3):
|
| 496 |
-
audio_out2 = gr.Audio(label="Output Audio (合成结果)", type="numpy")
|
| 497 |
-
err2 = gr.Textbox(label="Status (状态)", lines=2)
|
| 498 |
-
|
| 499 |
-
def save_prompt(ref_aud, ref_txt: str, use_xvec: bool):
|
| 500 |
-
try:
|
| 501 |
-
at = _audio_to_tuple(ref_aud)
|
| 502 |
-
if at is None:
|
| 503 |
-
return None, "Reference audio is required (必须上传参考音频)."
|
| 504 |
-
if (not use_xvec) and (not ref_txt or not ref_txt.strip()):
|
| 505 |
-
return None, (
|
| 506 |
-
"Reference text is required when use x-vector only is NOT enabled.\n"
|
| 507 |
-
"(未勾选 use x-vector only 时,必须提供参考音频文本;否则请勾选 use x-vector only,但效果会变差.)"
|
| 508 |
-
)
|
| 509 |
-
items = tts.create_voice_clone_prompt(
|
| 510 |
-
ref_audio=at,
|
| 511 |
-
ref_text=(ref_txt.strip() if ref_txt else None),
|
| 512 |
-
x_vector_only_mode=bool(use_xvec),
|
| 513 |
-
)
|
| 514 |
-
payload = {
|
| 515 |
-
"items": [asdict(it) for it in items],
|
| 516 |
-
}
|
| 517 |
-
fd, out_path = tempfile.mkstemp(prefix="voice_clone_prompt_", suffix=".pt")
|
| 518 |
-
os.close(fd)
|
| 519 |
-
torch.save(payload, out_path)
|
| 520 |
-
return out_path, "Finished. (生成完成)"
|
| 521 |
-
except Exception as e:
|
| 522 |
-
return None, f"{type(e).__name__}: {e}"
|
| 523 |
-
|
| 524 |
-
def load_prompt_and_gen(file_obj, text: str, lang_disp: str):
|
| 525 |
-
try:
|
| 526 |
-
if file_obj is None:
|
| 527 |
-
return None, "Voice file is required (必须上传音色文件)."
|
| 528 |
-
if not text or not text.strip():
|
| 529 |
-
return None, "Target text is required (必须填写待合成文本)."
|
| 530 |
-
|
| 531 |
-
path = getattr(file_obj, "name", None) or getattr(file_obj, "path", None) or str(file_obj)
|
| 532 |
-
payload = torch.load(path, map_location="cpu", weights_only=True)
|
| 533 |
-
if not isinstance(payload, dict) or "items" not in payload:
|
| 534 |
-
return None, "Invalid file format (文件格式不正确)."
|
| 535 |
-
|
| 536 |
-
items_raw = payload["items"]
|
| 537 |
-
if not isinstance(items_raw, list) or len(items_raw) == 0:
|
| 538 |
-
return None, "Empty voice items (音色为空)."
|
| 539 |
-
|
| 540 |
-
items: List[VoiceClonePromptItem] = []
|
| 541 |
-
for d in items_raw:
|
| 542 |
-
if not isinstance(d, dict):
|
| 543 |
-
return None, "Invalid item format in file (文件内部格式错误)."
|
| 544 |
-
ref_code = d.get("ref_code", None)
|
| 545 |
-
if ref_code is not None and not torch.is_tensor(ref_code):
|
| 546 |
-
ref_code = torch.tensor(ref_code)
|
| 547 |
-
ref_spk = d.get("ref_spk_embedding", None)
|
| 548 |
-
if ref_spk is None:
|
| 549 |
-
return None, "Missing ref_spk_embedding (缺少说话人向量)."
|
| 550 |
-
if not torch.is_tensor(ref_spk):
|
| 551 |
-
ref_spk = torch.tensor(ref_spk)
|
| 552 |
-
|
| 553 |
-
items.append(
|
| 554 |
-
VoiceClonePromptItem(
|
| 555 |
-
ref_code=ref_code,
|
| 556 |
-
ref_spk_embedding=ref_spk,
|
| 557 |
-
x_vector_only_mode=bool(d.get("x_vector_only_mode", False)),
|
| 558 |
-
icl_mode=bool(d.get("icl_mode", not bool(d.get("x_vector_only_mode", False)))),
|
| 559 |
-
ref_text=d.get("ref_text", None),
|
| 560 |
-
)
|
| 561 |
-
)
|
| 562 |
-
|
| 563 |
-
language = lang_map.get(lang_disp, "Auto")
|
| 564 |
-
kwargs = _gen_common_kwargs()
|
| 565 |
-
wavs, sr = tts.generate_voice_clone(
|
| 566 |
-
text=text.strip(),
|
| 567 |
-
language=language,
|
| 568 |
-
voice_clone_prompt=items,
|
| 569 |
-
**kwargs,
|
| 570 |
-
)
|
| 571 |
-
return _wav_to_gradio_audio(wavs[0], sr), "Finished. (生成完成)"
|
| 572 |
-
except Exception as e:
|
| 573 |
-
return None, (
|
| 574 |
-
f"Failed to read or use voice file. Check file format/content.\n"
|
| 575 |
-
f"(读取或使用音色文件失败,请检查文件格式或内容)\n"
|
| 576 |
-
f"{type(e).__name__}: {e}"
|
| 577 |
-
)
|
| 578 |
-
|
| 579 |
-
save_btn.click(save_prompt, inputs=[ref_audio_s, ref_text_s, xvec_only_s], outputs=[prompt_file_out, err2])
|
| 580 |
-
gen_btn2.click(load_prompt_and_gen, inputs=[prompt_file_in, text_in2, lang_in2], outputs=[audio_out2, err2])
|
| 581 |
-
|
| 582 |
-
gr.Markdown(
|
| 583 |
-
"""
|
| 584 |
-
**Disclaimer (免责声明)**
|
| 585 |
-
- The audio is automatically generated/synthesized by an AI model solely to demonstrate the model’s capabilities; it may be inaccurate or inappropriate, does not represent the views of the developer/operator, and does not constitute professional advice. You are solely responsible for evaluating, using, distributing, or relying on this audio; to the maximum extent permitted by applicable law, the developer/operator disclaims liability for any direct, indirect, incidental, or consequential damages arising from the use of or inability to use the audio, except where liability cannot be excluded by law. Do not use this service to intentionally generate or replicate unlawful, harmful, defamatory, fraudulent, deepfake, or privacy/publicity/copyright/trademark‑infringing content; if a user prompts, supplies materials, or otherwise facilitates any illegal or infringing conduct, the user bears all legal consequences and the developer/operator is not responsible.
|
| 586 |
-
- 音频由人工智能模型自动生成/合成,仅用于体验与展示模型效果,可能存在不准确或不当之处;其内容不代表开发者/运营方立场,亦不构成任何专业建议。用户应自行评估并承担使用、传播或依赖该音频所产生的一切风险与责任;在适用法律允许的最大范围内,开发者/运营方不对因使用或无法使用本音频造成的任何直接、间接、附带或后果性损失承担责任(法律另有强制规定的除外)。严禁利用本服务故意引导生成或复制违法、有害、诽谤、欺诈、深度伪造、侵犯隐私/肖像/著作权/商标等内容;如用户通过提示词、素材或其他方式实施或促成任何违法或侵权行为,相关法律后果由用户自行承担,与开发者/运营方无关。
|
| 587 |
-
"""
|
| 588 |
-
)
|
| 589 |
-
|
| 590 |
-
return demo
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
def main(argv=None) -> int:
|
| 594 |
-
parser = build_parser()
|
| 595 |
-
args = parser.parse_args(argv)
|
| 596 |
-
|
| 597 |
-
if not args.checkpoint and not args.checkpoint_pos:
|
| 598 |
-
parser.print_help()
|
| 599 |
-
return 0
|
| 600 |
-
|
| 601 |
-
ckpt = _resolve_checkpoint(args)
|
| 602 |
-
|
| 603 |
-
dtype = _dtype_from_str(args.dtype)
|
| 604 |
-
attn_impl = "flash_attention_2" if args.flash_attn else None
|
| 605 |
-
|
| 606 |
-
tts = Qwen3TTSModel.from_pretrained(
|
| 607 |
-
ckpt,
|
| 608 |
-
device_map=args.device,
|
| 609 |
-
dtype=dtype,
|
| 610 |
-
attn_implementation=attn_impl,
|
| 611 |
-
)
|
| 612 |
-
|
| 613 |
-
gen_kwargs_default = _collect_gen_kwargs(args)
|
| 614 |
-
demo = build_demo(tts, ckpt, gen_kwargs_default)
|
| 615 |
-
|
| 616 |
-
launch_kwargs: Dict[str, Any] = dict(
|
| 617 |
-
server_name=args.ip,
|
| 618 |
-
server_port=args.port,
|
| 619 |
-
share=args.share,
|
| 620 |
-
)
|
| 621 |
-
if args.ssl_certfile is not None:
|
| 622 |
-
launch_kwargs["ssl_certfile"] = args.ssl_certfile
|
| 623 |
-
if args.ssl_keyfile is not None:
|
| 624 |
-
launch_kwargs["ssl_keyfile"] = args.ssl_keyfile
|
| 625 |
-
if args.ssl_verify is not None:
|
| 626 |
-
launch_kwargs["ssl_verify"] = args.ssl_verify
|
| 627 |
-
|
| 628 |
-
demo.queue(default_concurrency_limit=int(args.concurrency)).launch(**launch_kwargs)
|
| 629 |
-
return 0
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
if __name__ == "__main__":
|
| 633 |
-
raise SystemExit(main())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/__init__.py
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
from .tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Config
|
| 17 |
-
from .tokenizer_25hz.modeling_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Model
|
| 18 |
-
from .tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Config
|
| 19 |
-
from .tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/models/__init__.py
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
from .configuration_qwen3_tts import Qwen3TTSConfig
|
| 17 |
-
from .modeling_qwen3_tts import Qwen3TTSForConditionalGeneration
|
| 18 |
-
from .processing_qwen3_tts import Qwen3TTSProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/models/configuration_qwen3_tts.py
DELETED
|
@@ -1,502 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
|
| 16 |
-
from transformers.modeling_rope_utils import rope_config_validation
|
| 17 |
-
from transformers.utils import logging
|
| 18 |
-
|
| 19 |
-
logger = logging.get_logger(__name__)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class Qwen3TTSSpeakerEncoderConfig(PretrainedConfig):
|
| 23 |
-
r"""
|
| 24 |
-
This is the configuration class to store the configuration of a [`Qwen3TTSSpeakerEncoder`].
|
| 25 |
-
It is used to instantiate a Qwen3TTS speaker encoder model according to the specified arguments, defining the model
|
| 26 |
-
architecture. The architecture is based on the ECAPA-TDNN model.
|
| 27 |
-
|
| 28 |
-
Args:
|
| 29 |
-
mel_dim (`int`, *optional*, defaults to 128):
|
| 30 |
-
The dimension of the input mel-spectrogram.
|
| 31 |
-
enc_dim (`int`, *optional*, defaults to 192):
|
| 32 |
-
The dimension of the final speaker embedding.
|
| 33 |
-
enc_channels (`list[int]`, *optional*, defaults to `[512, 512, 512, 512, 1536]`):
|
| 34 |
-
A list of output channels for each TDNN/SERes2Net layer in the encoder. The first channel size is for the initial TDNN layer,
|
| 35 |
-
the intermediate ones for the `SqueezeExcitationRes2NetBlock` layers, and the last one for the multi-layer feature aggregation.
|
| 36 |
-
enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
|
| 37 |
-
A list of kernel sizes for each layer in the encoder, corresponding to `enc_channels`.
|
| 38 |
-
enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
|
| 39 |
-
A list of dilations for each layer in the encoder, corresponding to `enc_channels`.
|
| 40 |
-
enc_attention_channels (`int`, *optional*, defaults to 128):
|
| 41 |
-
The number of attention channels in the `AttentiveStatisticsPooling` layer.
|
| 42 |
-
enc_res2net_scale (`int`, *optional*,defaults to 8):
|
| 43 |
-
The scale of the `Res2NetBlock` in the encoder.
|
| 44 |
-
enc_se_channels (`int`, *optional*, defaults to 128):
|
| 45 |
-
The number of channels in the squeeze part of the `SqueezeExcitationBlock`.
|
| 46 |
-
"""
|
| 47 |
-
def __init__(
|
| 48 |
-
self,
|
| 49 |
-
mel_dim=128,
|
| 50 |
-
enc_dim=1024,
|
| 51 |
-
enc_channels=[512, 512, 512, 512, 1536],
|
| 52 |
-
enc_kernel_sizes=[5, 3, 3, 3, 1],
|
| 53 |
-
enc_dilations=[1, 2, 3, 4, 1],
|
| 54 |
-
enc_attention_channels=128,
|
| 55 |
-
enc_res2net_scale=8,
|
| 56 |
-
enc_se_channels=128,
|
| 57 |
-
sample_rate=24000,
|
| 58 |
-
):
|
| 59 |
-
self.mel_dim = mel_dim
|
| 60 |
-
self.enc_dim = enc_dim
|
| 61 |
-
self.enc_channels = enc_channels
|
| 62 |
-
self.enc_kernel_sizes = enc_kernel_sizes
|
| 63 |
-
self.enc_dilations = enc_dilations
|
| 64 |
-
self.enc_attention_channels = enc_attention_channels
|
| 65 |
-
self.enc_res2net_scale = enc_res2net_scale
|
| 66 |
-
self.enc_se_channels = enc_se_channels
|
| 67 |
-
self.sample_rate = sample_rate
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
class Qwen3TTSTalkerCodePredictorConfig(PretrainedConfig):
|
| 71 |
-
r"""
|
| 72 |
-
This is the configuration class to store the configuration of a [`Qwen3TTSTalkerCodePredictorModel`]. It is used to instantiate a
|
| 73 |
-
Qwen3TTSTalkerCodePredictor model according to the specified arguments, defining the model architecture.
|
| 74 |
-
|
| 75 |
-
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 76 |
-
documentation from [`PretrainedConfig`] for more information.
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
Args:
|
| 80 |
-
vocab_size (`int`, *optional*, defaults to 151936):
|
| 81 |
-
Vocabulary size of the Qwen3TTSTalkerCodePredictor model. Defines the number of different tokens that can be represented by the
|
| 82 |
-
`inputs_ids` passed when calling [`Qwen3TTSTalkerCodePredictorModel`]
|
| 83 |
-
hidden_size (`int`, *optional*, defaults to 4096):
|
| 84 |
-
Dimension of the hidden representations.
|
| 85 |
-
intermediate_size (`int`, *optional*, defaults to 22016):
|
| 86 |
-
Dimension of the MLP representations.
|
| 87 |
-
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 88 |
-
Number of hidden layers in the Transformer encoder.
|
| 89 |
-
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 90 |
-
Number of attention heads for each attention layer in the Transformer encoder.
|
| 91 |
-
num_key_value_heads (`int`, *optional*, defaults to 32):
|
| 92 |
-
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 93 |
-
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 94 |
-
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 95 |
-
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 96 |
-
by meanpooling all the original heads within that group. For more details, check out [this
|
| 97 |
-
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
|
| 98 |
-
head_dim (`int`, *optional*, defaults to 128):
|
| 99 |
-
The attention head dimension.
|
| 100 |
-
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 101 |
-
The non-linear activation function (function or string) in the decoder.
|
| 102 |
-
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
| 103 |
-
The maximum sequence length that this model might ever be used with.
|
| 104 |
-
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 105 |
-
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 106 |
-
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 107 |
-
The epsilon used by the rms normalization layers.
|
| 108 |
-
use_cache (`bool`, *optional*, defaults to `True`):
|
| 109 |
-
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 110 |
-
relevant if `config.is_decoder=True`.
|
| 111 |
-
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 112 |
-
Whether the model's input and output word embeddings should be tied.
|
| 113 |
-
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 114 |
-
The base period of the RoPE embeddings.
|
| 115 |
-
rope_scaling (`Dict`, *optional*):
|
| 116 |
-
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 117 |
-
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 118 |
-
accordingly.
|
| 119 |
-
Expected contents:
|
| 120 |
-
`rope_type` (`str`):
|
| 121 |
-
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 122 |
-
'llama3'], with 'default' being the original RoPE implementation.
|
| 123 |
-
`factor` (`float`, *optional*):
|
| 124 |
-
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 125 |
-
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 126 |
-
original maximum pre-trained length.
|
| 127 |
-
`original_max_position_embeddings` (`int`, *optional*):
|
| 128 |
-
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
| 129 |
-
pretraining.
|
| 130 |
-
`attention_factor` (`float`, *optional*):
|
| 131 |
-
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 132 |
-
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 133 |
-
`factor` field to infer the suggested value.
|
| 134 |
-
`beta_fast` (`float`, *optional*):
|
| 135 |
-
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 136 |
-
ramp function. If unspecified, it defaults to 32.
|
| 137 |
-
`beta_slow` (`float`, *optional*):
|
| 138 |
-
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 139 |
-
ramp function. If unspecified, it defaults to 1.
|
| 140 |
-
`short_factor` (`list[float]`, *optional*):
|
| 141 |
-
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 142 |
-
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 143 |
-
size divided by the number of attention heads divided by 2
|
| 144 |
-
`long_factor` (`list[float]`, *optional*):
|
| 145 |
-
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
| 146 |
-
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 147 |
-
size divided by the number of attention heads divided by 2
|
| 148 |
-
`low_freq_factor` (`float`, *optional*):
|
| 149 |
-
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
| 150 |
-
`high_freq_factor` (`float`, *optional*):
|
| 151 |
-
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
| 152 |
-
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
| 153 |
-
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 154 |
-
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
| 155 |
-
Whether to use sliding window attention.
|
| 156 |
-
sliding_window (`int`, *optional*, defaults to 4096):
|
| 157 |
-
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
| 158 |
-
max_window_layers (`int`, *optional*, defaults to 28):
|
| 159 |
-
The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
|
| 160 |
-
additional layer afterwards will use SWA (Sliding Window Attention).
|
| 161 |
-
layer_types (`list`, *optional*):
|
| 162 |
-
Attention pattern for each layer.
|
| 163 |
-
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 164 |
-
The dropout ratio for the attention probabilities.
|
| 165 |
-
|
| 166 |
-
"""
|
| 167 |
-
|
| 168 |
-
model_type = "qwen3_tts_talker_code_predictor"
|
| 169 |
-
keys_to_ignore_at_inference = ["past_key_values"]
|
| 170 |
-
|
| 171 |
-
# Default tensor parallel plan for base model `Qwen3TTSTalkerCodePredictor`
|
| 172 |
-
base_model_tp_plan = {
|
| 173 |
-
"layers.*.self_attn.q_proj": "colwise",
|
| 174 |
-
"layers.*.self_attn.k_proj": "colwise",
|
| 175 |
-
"layers.*.self_attn.v_proj": "colwise",
|
| 176 |
-
"layers.*.self_attn.o_proj": "rowwise",
|
| 177 |
-
"layers.*.mlp.gate_proj": "colwise",
|
| 178 |
-
"layers.*.mlp.up_proj": "colwise",
|
| 179 |
-
"layers.*.mlp.down_proj": "rowwise",
|
| 180 |
-
}
|
| 181 |
-
base_model_pp_plan = {
|
| 182 |
-
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 183 |
-
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 184 |
-
"norm": (["hidden_states"], ["hidden_states"]),
|
| 185 |
-
}
|
| 186 |
-
|
| 187 |
-
def __init__(
|
| 188 |
-
self,
|
| 189 |
-
vocab_size=2048,
|
| 190 |
-
hidden_size=1024,
|
| 191 |
-
intermediate_size=3072,
|
| 192 |
-
num_hidden_layers=5,
|
| 193 |
-
num_attention_heads=16,
|
| 194 |
-
num_key_value_heads=8,
|
| 195 |
-
head_dim=128,
|
| 196 |
-
hidden_act="silu",
|
| 197 |
-
max_position_embeddings=32768,
|
| 198 |
-
initializer_range=0.02,
|
| 199 |
-
rms_norm_eps=0.000001,
|
| 200 |
-
use_cache=True,
|
| 201 |
-
tie_word_embeddings=False,
|
| 202 |
-
rope_theta=10000,
|
| 203 |
-
rope_scaling=None,
|
| 204 |
-
attention_bias=False,
|
| 205 |
-
use_sliding_window=False,
|
| 206 |
-
sliding_window=4096,
|
| 207 |
-
max_window_layers=28,
|
| 208 |
-
layer_types=None,
|
| 209 |
-
attention_dropout=0,
|
| 210 |
-
num_code_groups=32,
|
| 211 |
-
**kwargs,
|
| 212 |
-
):
|
| 213 |
-
super().__init__(
|
| 214 |
-
tie_word_embeddings=tie_word_embeddings,
|
| 215 |
-
**kwargs,
|
| 216 |
-
)
|
| 217 |
-
self.vocab_size = vocab_size
|
| 218 |
-
self.max_position_embeddings = max_position_embeddings
|
| 219 |
-
self.hidden_size = hidden_size
|
| 220 |
-
self.intermediate_size = intermediate_size
|
| 221 |
-
self.num_hidden_layers = num_hidden_layers
|
| 222 |
-
self.num_attention_heads = num_attention_heads
|
| 223 |
-
self.use_sliding_window = use_sliding_window
|
| 224 |
-
self.sliding_window = sliding_window if self.use_sliding_window else None
|
| 225 |
-
self.max_window_layers = max_window_layers
|
| 226 |
-
|
| 227 |
-
# for backward compatibility
|
| 228 |
-
if num_key_value_heads is None:
|
| 229 |
-
num_key_value_heads = num_attention_heads
|
| 230 |
-
|
| 231 |
-
self.num_key_value_heads = num_key_value_heads
|
| 232 |
-
self.head_dim = head_dim
|
| 233 |
-
self.hidden_act = hidden_act
|
| 234 |
-
self.initializer_range = initializer_range
|
| 235 |
-
self.rms_norm_eps = rms_norm_eps
|
| 236 |
-
self.use_cache = use_cache
|
| 237 |
-
self.rope_theta = rope_theta
|
| 238 |
-
self.rope_scaling = rope_scaling
|
| 239 |
-
self.attention_bias = attention_bias
|
| 240 |
-
self.attention_dropout = attention_dropout
|
| 241 |
-
# Validate the correctness of rotary position embeddings parameters
|
| 242 |
-
# BC: if there is a 'type' field, move it to 'rope_type'.
|
| 243 |
-
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 244 |
-
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 245 |
-
rope_config_validation(self)
|
| 246 |
-
|
| 247 |
-
self.layer_types = layer_types
|
| 248 |
-
if self.layer_types is None:
|
| 249 |
-
self.layer_types = [
|
| 250 |
-
"sliding_attention"
|
| 251 |
-
if self.sliding_window is not None and i >= self.max_window_layers
|
| 252 |
-
else "full_attention"
|
| 253 |
-
for i in range(self.num_hidden_layers)
|
| 254 |
-
]
|
| 255 |
-
layer_type_validation(self.layer_types)
|
| 256 |
-
self.num_code_groups = num_code_groups
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
class Qwen3TTSTalkerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSTalkerModel`]. It is used to
    instantiate a Qwen3TTSTalker model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read
    the documentation from [`PretrainedConfig`] for more information.

    Args:
        code_predictor_config (`dict` or `Qwen3TTSTalkerCodePredictorConfig`, *optional*):
            Configuration of the nested code-predictor sub-model. A plain `dict` is converted to a
            [`Qwen3TTSTalkerCodePredictorConfig`]; `None` initializes the sub-config with default values.
        vocab_size (`int`, *optional*, defaults to 3072):
            Vocabulary size of the Qwen3TTSTalker model. Defines the number of different tokens that can be
            represented by the `inputs_ids` passed when calling [`Qwen3TTSTalkerModel`].
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 20):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details, check out
            [this paper](https://huggingface.co/papers/2305.13245).
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
            Only relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new
            rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to
            update this value accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn',
                    'longrope', 'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE
                    embeddings. In most scaling types, a `factor` of x will enable the model to handle
                    sequences of length x * original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used
                    during pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to the value recommended by the implementation,
                    using the `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Boundary for extrapolation (only) in the linear ramp function.
                    If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Boundary for interpolation (only) in the linear ramp function.
                    If unspecified, it defaults to 1.
                `short_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. Scaling factor applied to short contexts
                    (< `original_max_position_embeddings`). Must be a list of numbers with the same length as
                    the hidden size divided by the number of attention heads divided by 2.
                `long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. Scaling factor applied to long contexts
                    (< `original_max_position_embeddings`). Must be a list of numbers with the same length as
                    the hidden size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size; only effective when `use_sliding_window=True`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_code_groups (`int`, *optional*, defaults to 32):
            Number of codec code groups.
        text_hidden_size (`int`, *optional*, defaults to 2048):
            Hidden size of the conditioning text stream.
        codec_eos_token_id (`int`, *optional*, defaults to 4198):
            Codec end-of-sequence token id.
        codec_think_id (`int`, *optional*, defaults to 4202):
            Codec "think" marker token id.
        codec_nothink_id (`int`, *optional*, defaults to 4203):
            Codec "no-think" marker token id.
        codec_think_bos_id (`int`, *optional*, defaults to 4204):
            Codec think begin marker token id.
        codec_think_eos_id (`int`, *optional*, defaults to 4205):
            Codec think end marker token id.
        codec_pad_id (`int`, *optional*, defaults to 4196):
            Codec padding token id.
        codec_bos_id (`int`, *optional*, defaults to 4197):
            Codec begin-of-sequence token id.
        spk_id (*optional*):
            Speaker-id mapping (stored as-is; presumably name -> id — verify against callers).
        spk_is_dialect (*optional*):
            Per-speaker dialect information (stored as-is).
        codec_language_id (*optional*):
            Language-id mapping for the codec (stored as-is).
    """

    model_type = "qwen3_tts_talker"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3TTSTalker`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    sub_configs = {"code_predictor_config": Qwen3TTSTalkerCodePredictorConfig}

    def __init__(
        self,
        code_predictor_config=None,
        vocab_size=3072,
        hidden_size=1024,
        intermediate_size=2048,
        num_hidden_layers=20,
        num_attention_heads=16,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=0.000001,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,
        rope_scaling=None,
        attention_bias=False,
        use_sliding_window=False,
        sliding_window=4096,
        attention_dropout=0,
        num_code_groups=32,
        text_hidden_size=2048,
        codec_eos_token_id=4198,
        codec_think_id=4202,
        codec_nothink_id=4203,
        codec_think_bos_id=4204,
        codec_think_eos_id=4205,
        codec_pad_id=4196,
        codec_bos_id=4197,
        spk_id=None,
        spk_is_dialect=None,
        codec_language_id=None,
        **kwargs,
    ):
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        # A window size is only meaningful when sliding-window attention is enabled.
        self.sliding_window = sliding_window if use_sliding_window else None

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # BC for rotary position embeddings: mirror a legacy 'type' field into 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]

        # Accept a ready-made sub-config instance, a plain dict, or None (defaults).
        if code_predictor_config is None:
            self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig()
            logger.info("code_predictor_config is None. Initializing code_predictor model with default values")
        elif isinstance(code_predictor_config, Qwen3TTSTalkerCodePredictorConfig):
            self.code_predictor_config = code_predictor_config
        else:
            self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig(**code_predictor_config)
        self.num_code_groups = num_code_groups
        self.text_hidden_size = text_hidden_size
        self.codec_eos_token_id = codec_eos_token_id
        self.codec_think_id = codec_think_id
        self.codec_language_id = codec_language_id
        self.codec_nothink_id = codec_nothink_id
        self.codec_think_bos_id = codec_think_bos_id
        self.codec_think_eos_id = codec_think_eos_id
        self.codec_pad_id = codec_pad_id
        self.codec_bos_id = codec_bos_id
        self.spk_id = spk_id
        self.spk_is_dialect = spk_is_dialect
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
class Qwen3TTSConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`Qwen3TTSForConditionalGeneration`].

    Aggregates the talker and speaker-encoder sub-configurations plus the special token ids used by the
    TTS text stream.

    Args:
        talker_config (`dict` or `Qwen3TTSTalkerConfig`, *optional*):
            Configuration of the talker sub-model; `None` initializes it with default values.
        speaker_encoder_config (`dict` or `Qwen3TTSSpeakerEncoderConfig`, *optional*):
            Configuration of the speaker-encoder sub-model; `None` initializes it with default values.
        tokenizer_type (`str`, *optional*):
            Codec tokenizer variant identifier (stored as-is).
        tts_model_size (*optional*):
            Model size tag (stored as-is).
        tts_model_type (*optional*):
            Model type tag (stored as-is).
        im_start_token_id (`int`, *optional*, defaults to 151644):
            Chat-template start-of-message token id.
        im_end_token_id (`int`, *optional*, defaults to 151645):
            Chat-template end-of-message token id.
        tts_pad_token_id (`int`, *optional*, defaults to 151671):
            TTS padding token id.
        tts_bos_token_id (`int`, *optional*, defaults to 151672):
            TTS begin-of-sequence token id.
        tts_eos_token_id (`int`, *optional*, defaults to 151673):
            TTS end-of-sequence token id.
    """

    model_type = "qwen3_tts"
    sub_configs = {
        "talker_config": Qwen3TTSTalkerConfig,
        "speaker_encoder_config": Qwen3TTSSpeakerEncoderConfig,
    }

    def __init__(
        self,
        talker_config=None,
        speaker_encoder_config=None,
        tokenizer_type=None,
        tts_model_size=None,
        tts_model_type=None,
        im_start_token_id=151644,
        im_end_token_id=151645,
        tts_pad_token_id=151671,
        tts_bos_token_id=151672,
        tts_eos_token_id=151673,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if talker_config is None:
            talker_config = {}
            logger.info("talker_config is None. Initializing talker model with default values")
        if speaker_encoder_config is None:
            speaker_encoder_config = {}
            logger.info("speaker_encoder_config is None. Initializing speaker_encoder model with default values")

        # Accept either a ready-made sub-config instance or a plain dict, consistent with how
        # Qwen3TTSTalkerConfig handles its nested code_predictor_config.
        if isinstance(talker_config, Qwen3TTSTalkerConfig):
            self.talker_config = talker_config
        else:
            self.talker_config = Qwen3TTSTalkerConfig(**talker_config)
        if isinstance(speaker_encoder_config, Qwen3TTSSpeakerEncoderConfig):
            self.speaker_encoder_config = speaker_encoder_config
        else:
            self.speaker_encoder_config = Qwen3TTSSpeakerEncoderConfig(**speaker_encoder_config)

        self.tokenizer_type = tokenizer_type
        self.tts_model_size = tts_model_size
        self.tts_model_type = tts_model_type

        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id
        self.tts_pad_token_id = tts_pad_token_id
        self.tts_bos_token_id = tts_bos_token_id
        self.tts_eos_token_id = tts_eos_token_id
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
# Public API of the configuration module.
__all__ = ["Qwen3TTSConfig", "Qwen3TTSTalkerConfig", "Qwen3TTSSpeakerEncoderConfig"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/models/modeling_qwen3_tts.py
DELETED
|
@@ -1,2246 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
"""PyTorch Qwen3TTS model."""
|
| 16 |
-
|
| 17 |
-
import json
|
| 18 |
-
import os
|
| 19 |
-
from dataclasses import dataclass
|
| 20 |
-
from typing import Callable, Optional
|
| 21 |
-
|
| 22 |
-
import torch
|
| 23 |
-
from librosa.filters import mel as librosa_mel_fn
|
| 24 |
-
from torch import nn
|
| 25 |
-
from torch.nn import functional as F
|
| 26 |
-
from transformers.activations import ACT2FN
|
| 27 |
-
from transformers.cache_utils import Cache, DynamicCache
|
| 28 |
-
from transformers.generation import GenerationMixin
|
| 29 |
-
from transformers.integrations import use_kernel_forward_from_hub
|
| 30 |
-
from transformers.masking_utils import (
|
| 31 |
-
create_causal_mask,
|
| 32 |
-
create_sliding_window_causal_mask,
|
| 33 |
-
)
|
| 34 |
-
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 35 |
-
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 36 |
-
from transformers.modeling_outputs import (
|
| 37 |
-
BaseModelOutputWithPast,
|
| 38 |
-
CausalLMOutputWithPast,
|
| 39 |
-
ModelOutput,
|
| 40 |
-
)
|
| 41 |
-
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 42 |
-
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 43 |
-
from transformers.processing_utils import Unpack
|
| 44 |
-
from transformers.utils import can_return_tuple, logging
|
| 45 |
-
from transformers.utils.hub import cached_file
|
| 46 |
-
|
| 47 |
-
from ...inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
|
| 48 |
-
from .configuration_qwen3_tts import (
|
| 49 |
-
Qwen3TTSConfig,
|
| 50 |
-
Qwen3TTSSpeakerEncoderConfig,
|
| 51 |
-
Qwen3TTSTalkerCodePredictorConfig,
|
| 52 |
-
Qwen3TTSTalkerConfig,
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
logger = logging.get_logger(__name__)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
class Res2NetBlock(torch.nn.Module):
    """Res2Net block: split channels into `scale` groups and process them hierarchically.

    Group 0 passes through untouched; every later group goes through its own TDNN branch,
    with each branch (from the third group on) also receiving the previous branch's output.
    """

    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
        super().__init__()

        branch_in = in_channels // scale
        branch_out = out_channels // scale

        # One TDNN branch per group, except the identity-mapped first group.
        self.blocks = nn.ModuleList(
            [
                TimeDelayNetBlock(
                    branch_in,
                    branch_out,
                    kernel_size=kernel_size,
                    dilation=dilation,
                )
                for _ in range(scale - 1)
            ]
        )
        self.scale = scale

    def forward(self, hidden_states):
        chunks = torch.chunk(hidden_states, self.scale, dim=1)
        processed = [chunks[0]]  # first group is an identity mapping
        previous = None
        for idx in range(1, self.scale):
            branch = self.blocks[idx - 1]
            if idx == 1:
                previous = branch(chunks[idx])
            else:
                # Hierarchical connection: fold in the previous branch's output.
                previous = branch(chunks[idx] + previous)
            processed.append(previous)
        return torch.cat(processed, dim=1)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
class SqueezeExcitationBlock(nn.Module):
    """Squeeze-and-excitation: rescale each channel by a gate computed from its time-mean."""

    def __init__(self, in_channels, se_channels, out_channels):
        super().__init__()

        self.conv1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=se_channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(
            in_channels=se_channels,
            out_channels=out_channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states):
        # Squeeze: average over the time axis, keeping a length-1 time dimension.
        pooled = hidden_states.mean(dim=2, keepdim=True)

        # Excitation: bottleneck 1x1 convs producing per-channel gates in (0, 1).
        gates = self.relu(self.conv1(pooled))
        gates = self.sigmoid(self.conv2(gates))

        return hidden_states * gates
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
class AttentiveStatisticsPooling(nn.Module):
    """Attentive statistics pooling over the time axis.

    For every channel, computes an attention-weighted mean and standard deviation of the
    input sequence and returns them concatenated along the channel axis, shape [N, 2*C, 1].
    """

    def __init__(self, channels, attention_channels=128):
        super().__init__()

        self.eps = 1e-12
        self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = nn.Conv1d(
            in_channels=attention_channels,
            out_channels=channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )

    def _length_to_mask(self, length, max_len=None, dtype=None, device=None):
        """Create a [len(length), max_len] binary mask from per-sequence lengths.

        Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3

        Arguments
        ---------
        length : torch.LongTensor
            Length of each sequence in the batch. Must be 1D.
        max_len : int
            Max length for the mask, also the size of the second dimension.
        dtype : torch.dtype, default: None
            The dtype of the generated mask.
        device: torch.device, default: None
            The device to put the mask variable.

        Returns
        -------
        mask : tensor
            The binary mask.
        """
        if max_len is None:
            max_len = length.max().long().item()
        positions = torch.arange(max_len, device=length.device, dtype=length.dtype)
        mask = positions.expand(len(length), max_len) < length.unsqueeze(1)
        return torch.as_tensor(mask, dtype=dtype, device=device)

    def _compute_statistics(self, x, m, dim=2):
        """Weighted mean/std of `x` along `dim` with weights `m` (std clamped by self.eps)."""
        mean = (m * x).sum(dim)
        variance = (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)
        return mean, torch.sqrt(variance)

    def forward(self, hidden_states):
        time_steps = hidden_states.shape[-1]
        lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device)

        # Binary mask of shape [N, 1, L]; all-ones here since every sequence is full length.
        mask = self._length_to_mask(
            lengths * time_steps, max_len=time_steps, dtype=hidden_states.dtype, device=hidden_states.device
        ).unsqueeze(1)

        # Give the attention a global view of the utterance by appending
        # utterance-level mean and std to every frame.
        valid_counts = mask.sum(dim=2, keepdim=True)
        mean, std = self._compute_statistics(hidden_states, mask / valid_counts)
        global_context = torch.cat(
            [
                hidden_states,
                mean.unsqueeze(2).repeat(1, 1, time_steps),
                std.unsqueeze(2).repeat(1, 1, time_steps),
            ],
            dim=1,
        )

        # Attention scores per channel and frame.
        attention = self.conv(self.tanh(self.tdnn(global_context)))

        # Exclude zero-padded frames from the softmax.
        attention = attention.masked_fill(mask == 0, float("-inf"))
        attention = F.softmax(attention, dim=2)

        mean, std = self._compute_statistics(hidden_states, attention)
        pooled_stats = torch.cat((mean, std), dim=1)
        return pooled_stats.unsqueeze(2)
|
| 209 |
-
|
| 210 |
-
class TimeDelayNetBlock(nn.Module):
    """Single TDNN layer: a 1D convolution ('same' reflect padding) followed by ReLU."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding="same",
            padding_mode="reflect",
        )
        self.activation = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor):
        """Convolve then activate; time length is preserved by 'same' padding."""
        convolved = self.conv(hidden_states)
        return self.activation(convolved)
|
| 231 |
-
|
| 232 |
-
class SqueezeExcitationRes2NetBlock(nn.Module):
    """ECAPA-TDNN building block: TDNN -> Res2Net -> TDNN -> SE, wrapped in a residual."""

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TimeDelayNetBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
        )
        self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
        self.tdnn2 = TimeDelayNetBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
        )
        self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels)

    def forward(self, hidden_state):
        transformed = self.tdnn1(hidden_state)
        transformed = self.res2net_block(transformed)
        transformed = self.tdnn2(transformed)
        transformed = self.se_block(transformed)
        # Residual connection around the whole sub-stack.
        return transformed + hidden_state
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
class Qwen3TTSSpeakerEncoder(torch.nn.Module):
|
| 275 |
-
"""An implementation of the speaker embedding model in a paper.
|
| 276 |
-
"ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
|
| 277 |
-
TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143).
|
| 278 |
-
Use for Qwen3TTS extract speaker embedding.
|
| 279 |
-
"""
|
| 280 |
-
|
| 281 |
-
def __init__(self, config: Qwen3TTSSpeakerEncoderConfig):
|
| 282 |
-
super().__init__()
|
| 283 |
-
if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len(
|
| 284 |
-
config.enc_dilations
|
| 285 |
-
):
|
| 286 |
-
raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length")
|
| 287 |
-
self.channels = config.enc_channels
|
| 288 |
-
self.blocks = nn.ModuleList()
|
| 289 |
-
|
| 290 |
-
# The initial TDNN layer
|
| 291 |
-
self.blocks.append(
|
| 292 |
-
TimeDelayNetBlock(
|
| 293 |
-
config.mel_dim,
|
| 294 |
-
config.enc_channels[0],
|
| 295 |
-
config.enc_kernel_sizes[0],
|
| 296 |
-
config.enc_dilations[0],
|
| 297 |
-
)
|
| 298 |
-
)
|
| 299 |
-
|
| 300 |
-
# SE-Res2Net layers
|
| 301 |
-
for i in range(1, len(config.enc_channels) - 1):
|
| 302 |
-
self.blocks.append(
|
| 303 |
-
SqueezeExcitationRes2NetBlock(
|
| 304 |
-
config.enc_channels[i - 1],
|
| 305 |
-
config.enc_channels[i],
|
| 306 |
-
res2net_scale=config.enc_res2net_scale,
|
| 307 |
-
se_channels=config.enc_se_channels,
|
| 308 |
-
kernel_size=config.enc_kernel_sizes[i],
|
| 309 |
-
dilation=config.enc_dilations[i],
|
| 310 |
-
)
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
# Multi-layer feature aggregation
|
| 314 |
-
self.mfa = TimeDelayNetBlock(
|
| 315 |
-
config.enc_channels[-1],
|
| 316 |
-
config.enc_channels[-1],
|
| 317 |
-
config.enc_kernel_sizes[-1],
|
| 318 |
-
config.enc_dilations[-1],
|
| 319 |
-
)
|
| 320 |
-
|
| 321 |
-
# Attentive Statistical Pooling
|
| 322 |
-
self.asp = AttentiveStatisticsPooling(
|
| 323 |
-
config.enc_channels[-1],
|
| 324 |
-
attention_channels=config.enc_attention_channels,
|
| 325 |
-
)
|
| 326 |
-
|
| 327 |
-
# Final linear transformation
|
| 328 |
-
self.fc = nn.Conv1d(
|
| 329 |
-
in_channels=config.enc_channels[-1] * 2,
|
| 330 |
-
out_channels=config.enc_dim,
|
| 331 |
-
kernel_size=1,
|
| 332 |
-
padding="same",
|
| 333 |
-
padding_mode="reflect",
|
| 334 |
-
)
|
| 335 |
-
|
| 336 |
-
def forward(self, hidden_states):
    """Encode a mel-feature sequence into a fixed-size speaker embedding.

    The input is transposed to channel-first for the Conv1d-based blocks
    (assumes hidden_states is (batch, time, mel_dim) — TODO confirm with
    callers), then aggregated across layers, pooled over time, projected,
    and squeezed to (batch, enc_dim).
    """
    # Minimize transpose for efficiency
    hidden_states = hidden_states.transpose(1, 2)

    hidden_states_list = []
    for layer in self.blocks:
        hidden_states = layer(hidden_states)
        hidden_states_list.append(hidden_states)

    # Multi-layer feature aggregation: concatenate every block output
    # except the initial TDNN layer's, then fuse with `mfa`.
    hidden_states = torch.cat(hidden_states_list[1:], dim=1)
    hidden_states = self.mfa(hidden_states)

    # Attentive Statistical Pooling (collapses the time dimension)
    hidden_states = self.asp(hidden_states)

    # Final linear transformation
    hidden_states = self.fc(hidden_states)

    # Drop the singleton time axis left by pooling.
    hidden_states = hidden_states.squeeze(-1)
    return hidden_states
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress a spectrogram's dynamic range: log(clamp(x, clip_val) * C).

    `clip_val` floors the input so the logarithm never sees zero or
    negative magnitudes; `C` is an optional pre-log gain.
    """
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
|
| 361 |
-
|
| 362 |
-
def mel_spectrogram(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sampling_rate: int,
    hop_size: int,
    win_size: int,
    fmin: int,
    fmax: int = None,
    center: bool = False,
) -> torch.Tensor:
    """
    Calculate the mel spectrogram of an input signal.
    This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).

    Args:
        y (torch.Tensor): Input signal.
        n_fft (int): FFT size.
        num_mels (int): Number of mel bins.
        sampling_rate (int): Sampling rate of the input signal.
        hop_size (int): Hop size for STFT.
        win_size (int): Window size for STFT.
        fmin (int): Minimum frequency for mel filterbank.
        fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
        center (bool): Whether to pad the input to center the frames. Default is False.

    Returns:
        torch.Tensor: Mel spectrogram.
    """
    # Waveform is expected in [-1, 1]; out-of-range samples are only warned
    # about, not clipped.
    if torch.min(y) < -1.0:
        print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
    if torch.max(y) > 1.0:
        print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")

    device = y.device

    # NOTE(review): the mel filterbank and Hann window are rebuilt on every
    # call — acceptable for occasional use, wasteful in a hot loop.
    mel = librosa_mel_fn(
        sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
    )

    mel_basis = torch.from_numpy(mel).float().to(device)
    hann_window = torch.hann_window(win_size).to(device)

    # Manual reflect padding replicates center=True frame alignment while
    # keeping torch.stft's `center` as passed (default False).
    padding = (n_fft - hop_size) // 2
    y = torch.nn.functional.pad(
        y.unsqueeze(1), (padding, padding), mode="reflect"
    ).squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # Magnitude with a small epsilon for numerical stability of the sqrt.
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)

    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = dynamic_range_compression_torch(mel_spec)

    return mel_spec
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
class Qwen3TTSPreTrainedModel(PreTrainedModel):
    """Base class for Qwen3-TTS models: config/class wiring and weight init.

    Declares the attention/cache capability flags transformers inspects
    (flash-attn 2, SDPA, cache classes) and keeps `Qwen3TTSDecoderLayer`
    unsplit across devices.
    """

    config_class = Qwen3TTSConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3TTSDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = False
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Minimal normal/zero/ones init for linear, conv, embedding and LayerNorm modules."""
        # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed
        std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02

        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
class Qwen3TTSTalkerTextPreTrainedModel(PreTrainedModel):
    """Base class for the talker's text sub-model.

    Unlike `Qwen3TTSPreTrainedModel`, this variant additionally advertises
    flash-attn 3, flex attention and quantized-cache support, and initializes
    `Qwen3TTSRMSNorm` weights explicitly.
    """

    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = False
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Normal init for linear/embedding weights, ones for RMSNorm scales."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Qwen3TTSRMSNorm):
            module.weight.data.fill_(1.0)
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
class Qwen3TTSTalkerRotaryEmbedding(nn.Module):
    """Multimodal (3-section) rotary position embedding for the talker.

    Produces cos/sin tables expanded over three position-id grids
    (temporal / height / width style mrope), as used by
    `apply_multimodal_rotary_pos_emb`.
    """

    def __init__(self, config: Qwen3TTSTalkerConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: inv_freq is recomputed from config, never serialized.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return (cos, sin), each shaped for the 3-grid mrope layout, in x's dtype."""
        # In contrast to other models, Qwen3TTSThinkerText has different position ids for the grids
        # So we expand the inv_freq to shape (3, ...)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        # MPS lacks float32 autocast control here, hence the "cpu" fallback tag.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 524 |
-
|
| 525 |
-
class Qwen3TTSRotaryEmbedding(nn.Module):
    """Standard 1D rotary position embedding (single position-id grid)."""

    def __init__(self, config: Qwen3TTSConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: inv_freq is recomputed from config, never serialized.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return (cos, sin) tables for `position_ids`, cast to x's dtype."""
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # MPS lacks float32 autocast control here, hence the "cpu" fallback tag.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3TTSRMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen3TTSRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize in float32 for numerical stability, then cast back.
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(-1, keepdim=True)
        normalized = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 578 |
-
|
| 579 |
-
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    if n_rep == 1:
        # GQA degenerate case: nothing to repeat, return the tensor as-is.
        return hidden_states
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    # expand() creates a broadcast view; reshape() then materializes the copies.
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) scaled-dot-product attention with GQA support.

    Expands grouped K/V heads via `repeat_kv`, applies the additive causal
    mask if given, and returns `(attn_output, attn_weights)` with the output
    transposed back to (batch, seq, heads, head_dim) layout.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Mask is sliced to the current KV length (important with caching).
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # Softmax in float32 for stability, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, mrope_interleaved=False, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).

    Explanation:
        Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
        sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
        vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
        Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
        For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
        height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
        difference with modern LLMs.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        mrope_section(`List(int)`):
            Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
        mrope_interleaved (`bool`, *optional*, defaults to False):
            When True, the modality sections are interleaved across channels
            (strided) instead of laid out as contiguous chunks.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    if mrope_interleaved:

        def apply_interleaved_rope(x, modality_num):
            # Start from the temporal grid (x[0]) and overwrite every
            # modality_num-th channel with the matching h/w grid values.
            x_t = x[0].clone()
            index_ranges = []
            for i, n in enumerate(mrope_section[1:], 1):
                beg_idx = i
                end_idx = n * modality_num
                index_ranges.append((beg_idx, end_idx))
            for beg_idx, end_idx in index_ranges:
                x_t[..., beg_idx:end_idx:modality_num] = x[beg_idx, ..., beg_idx:end_idx:modality_num]
            return x_t

        dim = cos.shape[-1]
        modality_num = len(mrope_section)
        # Only the first half of the channels is interleaved; it is then
        # duplicated to restore the full rotary dimension.
        cos = torch.cat([apply_interleaved_rope(cos[..., : dim // 2], modality_num)] * 2, dim=-1).unsqueeze(
            unsqueeze_dim
        )
        sin = torch.cat([apply_interleaved_rope(sin[..., : dim // 2], modality_num)] * 2, dim=-1).unsqueeze(
            unsqueeze_dim
        )
    else:
        # Contiguous layout: split channels by section and pick the grid
        # (t/h/w, cycling via i % 3) each chunk belongs to.
        mrope_section = mrope_section * 2
        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
            unsqueeze_dim
        )
        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
            unsqueeze_dim
        )

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
class Qwen3TTSTalkerAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config, layer_idx):
        """Talker self-attention: GQA projections, per-head-dim QK RMSNorm,
        and multimodal (mrope) rotary embeddings."""
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3TTSRMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3TTSRMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # thus post q_norm does not need reshape
        self.sliding_window = getattr(config, "sliding_window", None)
        self.rope_scaling = config.rope_scaling

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Run self-attention; returns (attn_output, attn_weights)."""
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, normalize per head_dim, then move heads to dim 1.
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"], self.rope_scaling["interleaved"]
        )

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured backend (flash/sdpa/...); eager is the fallback.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
class Qwen3TTSTalkerResizeMLP(nn.Module):
    """Two-layer MLP used to resize between hidden dimensions:
    fc1 -> activation -> fc2."""

    def __init__(self, input_size: int, intermediate_size: int, output_size: int, act: str, bias=False):
        super().__init__()
        self.linear_fc1 = nn.Linear(input_size, intermediate_size, bias=bias)
        self.linear_fc2 = nn.Linear(intermediate_size, output_size, bias=bias)
        # `act` is a key into transformers' ACT2FN registry (e.g. "gelu", "silu").
        self.act_fn = ACT2FN[act]

    def forward(self, hidden_state):
        return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
@dataclass
class Qwen3TTSTalkerCodePredictorOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    """

    loss: Optional[torch.FloatTensor] = None
    # Annotated Optional to match the None default used by all fields here.
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    # NOTE(review): presumably the number of autoregressive steps taken by
    # the code predictor — confirm against the producing forward().
    generation_steps: Optional[int] = None
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
class Qwen3TTSTalkerTextMLP(nn.Module):
    """Gated MLP: down_proj(act(gate_proj(x)) * up_proj(x)), all bias-free."""

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        # Callers may override the intermediate width (e.g. for resized layers).
        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
|
| 820 |
-
|
| 821 |
-
|
| 822 |
-
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos`/`sin` are unsqueezed so they broadcast
            against `q`/`k`: use 1 for [batch, heads, seq, head_dim] layouts
            and 2 for [batch, seq, heads, head_dim] layouts.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos_b + rotate_half(q) * sin_b
    rotated_k = k * cos_b + rotate_half(k) * sin_b
    return rotated_q, rotated_k
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
class Qwen3TTSAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3TTSConfig, layer_idx: int):
        """Code-predictor self-attention: GQA projections, per-head-dim QK
        RMSNorm, standard 1D RoPE, and per-layer optional sliding window."""
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3TTSRMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # thus post q_norm does not need reshape
        # Sliding window applies only to layers configured as sliding_attention.
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Run self-attention; returns (attn_output, attn_weights)."""
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, normalize per head_dim, then move heads to dim 1.
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured backend (flash/sdpa/...); eager is the fallback.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
class Qwen3TTSDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer decoder layer: RMSNorm -> self-attention ->
    residual, then RMSNorm -> gated MLP -> residual."""

    def __init__(self, config: Qwen3TTSConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen3TTSAttention(config=config, layer_idx=layer_idx)

        self.mlp = Qwen3TTSTalkerTextMLP(config)
        self.input_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Records whether this layer uses full or sliding attention.
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply one decoder layer; returns (hidden_states,) plus attention
        weights when `output_attentions` is set."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
|
| 977 |
-
|
| 978 |
-
|
| 979 |
-
class Qwen3TTSTalkerCodePredictorModel(Qwen3TTSPreTrainedModel):
    """Decoder-only transformer backbone for the talker's code predictor.

    Holds a stack of ``Qwen3TTSDecoderLayer`` blocks plus one embedding table
    per residual codec group (``num_code_groups - 1`` tables). The forward pass
    accepts pre-built ``inputs_embeds`` only — ``input_ids`` must be ``None``.
    """

    config_class = Qwen3TTSTalkerCodePredictorConfig
    base_model_prefix = "talker.code_predictor.model"

    def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig, embedding_dim: int):
        """Build the layer stack and per-group codec embeddings.

        Args:
            config: Code-predictor sub-config (layer count, sizes, cache flags).
            embedding_dim: Width of each codec embedding table; the caller
                passes the talker's hidden size here.
        """
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.layers = nn.ModuleList(
            [Qwen3TTSDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3TTSRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Whether any layer uses sliding-window attention (drives mask creation below).
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        # One embedding table per residual codec group (group 0 is embedded by the talker).
        self.codec_embedding = nn.ModuleList(
            [nn.Embedding(config.vocab_size, embedding_dim) for _ in range(config.num_code_groups - 1)]
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Returns the ModuleList of per-group codec embedding tables.
        return self.codec_embedding

    def set_input_embeddings(self, value):
        # NOTE(review): asymmetric with get_input_embeddings() — this assigns
        # `embed_tokens` (used only in an unreachable branch of forward) while
        # the getter returns `codec_embedding`. Confirm this is intentional.
        self.embed_tokens = value

    @can_return_tuple
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        cache_position=None,
        generation_steps=None,
        **flash_attn_kwargs,
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack over pre-computed embeddings.

        ``input_ids`` must be ``None``; callers supply ``inputs_embeds``
        directly. ``generation_steps`` is accepted for generate()-loop
        compatibility but unused here.
        """
        if input_ids is not None:
            raise ValueError("`input_ids` is expected to be `None`")
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Given the check above, this effectively requires inputs_embeds to be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        # NOTE(review): unreachable — the checks above force input_ids=None and
        # inputs_embeds not None. Kept for parity with the upstream template.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Each layer picks the mask matching its attention type
            # ("full_attention" or "sliding_attention").
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
| 1118 |
-
|
| 1119 |
-
|
| 1120 |
-
class Qwen3TTSTalkerCodePredictorModelForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin):
    """Code predictor with one LM head per residual codec group.

    Wraps ``Qwen3TTSTalkerCodePredictorModel`` and, at each generation step,
    selects the head/embedding pair matching the current codec group via the
    ``generation_steps`` counter threaded through the generate() loop.
    """

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config_class = Qwen3TTSTalkerCodePredictorConfig
    base_model_prefix = "talker.code_predictor"

    def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig, talker_config: Qwen3TTSTalkerConfig):
        """Build the backbone, the per-group heads, and an optional width projection.

        Args:
            config: Code-predictor sub-config.
            talker_config: Parent talker config; its hidden size determines
                whether a projection into this model's hidden size is needed.
        """
        super().__init__(config)
        self.model = Qwen3TTSTalkerCodePredictorModel(config, talker_config.hidden_size)
        self.vocab_size = config.vocab_size
        # One head per residual codec group (group 0 is predicted by the talker).
        self.lm_head = nn.ModuleList(
            [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_code_groups - 1)]
        )

        # Project talker-width embeddings into this model's width when they differ.
        if config.hidden_size != talker_config.hidden_size:
            self.small_to_mtp_projection = torch.nn.Linear(talker_config.hidden_size, config.hidden_size, bias=True)
        else:
            self.small_to_mtp_projection = torch.nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Delegates to the backbone (a ModuleList of per-group embeddings).
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward_finetune(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        cache_position=None,
        generation_steps=None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Teacher-forced pass: score every residual group in one forward.

        Position ``i`` of the sequence (for ``i >= 1``) is scored by head
        ``i - 1``, i.e. the head for codec group ``i``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        # Apply head i-1 to sequence position i for each residual group.
        logits = []
        for i in range(1, self.config.num_code_groups):
            logits.append(self.lm_head[i-1](hidden_states[:, i]))
        logits = torch.stack(logits, dim=1)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return Qwen3TTSTalkerCodePredictorOutputWithPast(
            loss=loss,
            logits=logits
        )

    @can_return_tuple
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        cache_position=None,
        generation_steps=None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # Prefill stage: the seeded prompt is [talker hidden, group-0 embed, ...];
        # the counter excludes those two positions.
        if inputs_embeds is not None and inputs_embeds.shape[1] > 1:
            generation_steps = inputs_embeds.shape[1] - 2  # hidden & layer 0
        # Generation stage: embed the previous group's token with that group's
        # table, then project to this model's width.
        else:
            inputs_embeds = self.model.get_input_embeddings()[generation_steps - 1](input_ids)
            inputs_embeds = self.small_to_mtp_projection(inputs_embeds)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # NOTE(review): indexes lm_head by the running step counter; assumes
        # generate() never calls this with generation_steps >= num_code_groups - 1
        # (max_new_tokens is capped accordingly by the caller) — confirm.
        logits = self.lm_head[generation_steps](hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return Qwen3TTSTalkerCodePredictorOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            # Advance the per-group step counter for the next generate() iteration.
            generation_steps=generation_steps + 1,
        )

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1):
        # Thread the custom `generation_steps` counter through generate()'s kwargs.
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )
        model_kwargs["generation_steps"] = outputs.generation_steps
        return model_kwargs
|
| 1282 |
-
|
| 1283 |
-
|
| 1284 |
-
@dataclass
class Qwen3TTSTalkerOutputWithPast(ModelOutput):
    r"""
    Output container for the talker's forward pass, extending the standard
    causal-LM fields with the generation-loop state the talker threads
    between steps.

    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    # Last-position hidden state, fed back as the code predictor's prompt next step.
    past_hidden: Optional[torch.FloatTensor] = None
    # Running step counter within the talker's generation loop.
    generation_step: Optional[int] = None
    # Remaining text hidden states mixed into future steps' inputs.
    trailing_text_hidden: Optional[torch.FloatTensor] = None
    # Pad embedding added once the trailing text is exhausted.
    tts_pad_embed: Optional[torch.FloatTensor] = None
|
| 1308 |
-
|
| 1309 |
-
|
| 1310 |
-
class Qwen3TTSTalkerDecoderLayer(GradientCheckpointingLayer):
    """Single pre-norm decoder layer of the talker: attention + MLP with residuals."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3TTSTalkerAttention(config, layer_idx)

        self.mlp = Qwen3TTSTalkerTextMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)

        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
|
| 1387 |
-
|
| 1388 |
-
|
| 1389 |
-
class Qwen3TTSTalkerModel(Qwen3TTSTalkerTextPreTrainedModel):
    """Talker decoder backbone with separate codec and text embedding tables.

    Uses a multi-axis rotary embedding: position ids are carried as a
    ``(3, batch, seq)`` tensor (temporal, height, width rows), with an
    optional leading text row when four rows are supplied.
    """

    config_class = Qwen3TTSTalkerConfig
    base_model_prefix = "talker.model"

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.layers = nn.ModuleList(
            [Qwen3TTSTalkerDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3TTSTalkerRotaryEmbedding(config)
        self.gradient_checkpointing = False
        # Separate tables: codec tokens live in the talker width, text tokens in the text width.
        self.codec_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.text_embedding = nn.Embedding(config.text_vocab_size, config.text_hidden_size)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.codec_embedding

    def get_text_embeddings(self):
        return self.text_embedding

    def set_input_embeddings(self, value):
        # NOTE(review): asymmetric with get_input_embeddings() (which returns
        # `codec_embedding`); `embed_tokens` is used only when inputs_embeds
        # is absent in forward(). Confirm intended.
        self.embed_tokens = value

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the talker decoder stack and return a `BaseModelOutputWithPast`."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        # A 4-row position tensor carries an extra text row in front; split it
        # off for the mask, keeping the 3 rotary rows for position embeddings.
        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
            text_position_ids = position_ids[0]
            position_ids = position_ids[1:]
        else:
            text_position_ids = position_ids[0]

        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
        causal_mask = mask_function(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=text_position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=text_position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
| 1524 |
-
|
| 1525 |
-
|
| 1526 |
-
class Qwen3TTSTalkerForConditionalGeneration(Qwen3TTSTalkerTextPreTrainedModel, GenerationMixin):
|
| 1527 |
-
_tied_weights_keys = ["lm_head.weight"]
|
| 1528 |
-
_tp_plan = {"lm_head": "colwise_rep"}
|
| 1529 |
-
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 1530 |
-
config_class = Qwen3TTSTalkerConfig
|
| 1531 |
-
base_model_prefix = "talker"
|
| 1532 |
-
|
| 1533 |
-
def __init__(self, config: Qwen3TTSTalkerConfig):
    """Assemble the talker: backbone, text projection, codec head, and code predictor.

    Args:
        config: Talker config; its ``code_predictor_config`` sub-config
            parameterizes the nested code predictor.
    """
    super().__init__(config)
    self.model = Qwen3TTSTalkerModel(config)
    self.vocab_size = config.vocab_size
    # Maps text-width hidden states into the talker's hidden width.
    self.text_projection = Qwen3TTSTalkerResizeMLP(
        config.text_hidden_size, config.text_hidden_size, config.hidden_size, config.hidden_act, bias=True
    )

    # Head for codec group 0; the remaining groups are handled by the code predictor.
    self.codec_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    self.code_predictor = Qwen3TTSTalkerCodePredictorModelForConditionalGeneration(
        config=config.code_predictor_config,
        talker_config=config
    )
    # Cached rope offset, (re)computed at prefill in forward().
    self.rope_deltas = None

    # Initialize weights and apply final processing
    self.post_init()
|
| 1550 |
-
|
| 1551 |
-
# TODO: hack, modular cannot inherit multiple classes
|
| 1552 |
-
|
| 1553 |
-
def get_input_embeddings(self):
    # Codec embedding table of the underlying talker model.
    return self.model.get_input_embeddings()

def get_text_embeddings(self):
    # Text embedding table of the underlying talker model.
    return self.model.get_text_embeddings()

def set_input_embeddings(self, value):
    self.model.embed_tokens = value

def get_output_embeddings(self):
    # NOTE(review): this class's __init__ defines `codec_head`, not `lm_head`;
    # calling this as written would raise AttributeError — confirm whether it
    # should return `self.codec_head`.
    return self.lm_head

def set_output_embeddings(self, new_embeddings):
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    self.model = decoder

def get_decoder(self):
    return self.model
|
| 1573 |
-
|
| 1574 |
-
def forward_sub_talker_finetune(self, codec_ids, talker_hidden_states):
    """Teacher-forced fine-tuning pass through the nested code predictor.

    Builds the predictor's input sequence as
    ``[talker hidden, embed(group 0), embed(group 1), ...]`` — group 0 uses the
    talker's own codec table, later groups use the predictor's per-group
    tables — and trains against groups 1..N-1 as labels.

    Args:
        codec_ids: `(batch, num_code_groups)` codec token ids.
        talker_hidden_states: `(batch, hidden_size)` talker hidden states.

    Returns:
        Tuple of (logits, loss) from the predictor's forward_finetune.
    """
    assert len(codec_ids.shape) == 2
    assert len(talker_hidden_states.shape) == 2
    assert codec_ids.shape[0] == talker_hidden_states.shape[0]
    assert talker_hidden_states.shape[1] == self.config.hidden_size
    assert codec_ids.shape[1] == self.config.num_code_groups

    sub_talker_inputs_embeds = [talker_hidden_states.unsqueeze(1)]

    for i in range(self.config.num_code_groups - 1):
        if i == 0:
            # Group 0 is embedded with the talker's codec table.
            sub_talker_inputs_embeds.append(self.get_input_embeddings()(codec_ids[:, :1]))
        else:
            # Groups >= 1 use the predictor's per-group tables (table i-1 embeds group i).
            sub_talker_inputs_embeds.append(self.code_predictor.get_input_embeddings()[i-1](codec_ids[:, i:i+1]))
    sub_talker_inputs_embeds = torch.cat(sub_talker_inputs_embeds, dim=1)

    sub_talker_outputs = self.code_predictor.forward_finetune(inputs_embeds=sub_talker_inputs_embeds,
                                                              labels=codec_ids[:, 1:])

    sub_talker_logits = sub_talker_outputs.logits
    sub_talker_loss = sub_talker_outputs.loss
    return sub_talker_logits, sub_talker_loss
|
| 1596 |
-
|
| 1597 |
-
@can_return_tuple
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    cache_position=None,
    past_hidden=None,
    trailing_text_hidden=None,
    tts_pad_embed=None,
    generation_step=None,
    subtalker_dosample=None,
    subtalker_top_p=None,
    subtalker_top_k=None,
    subtalker_temperature=None,
    **kwargs,
) -> CausalLMOutputWithPast:
    r"""
    Talker forward pass used inside generate(): at each decode step it first
    runs the nested code predictor to expand the previous group-0 token into
    all codec groups, sums the group embeddings (plus trailing text or pad
    embedding) into the next input, then runs the talker backbone.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    """
    # Prefill
    if inputs_embeds is not None and inputs_embeds.shape[1] > 1:
        generation_step = -1
        codec_ids = None
    # Generate
    else:
        # Embed the previously sampled group-0 token with the talker's codec table.
        last_id_hidden = self.get_input_embeddings()(input_ids)
        # Let the code predictor sample the remaining residual groups,
        # prompted by [previous talker hidden, group-0 embedding].
        predictor_result = self.code_predictor.generate(
            inputs_embeds=torch.cat((past_hidden, last_id_hidden), dim=1),
            max_new_tokens=self.config.num_code_groups - 1,
            do_sample=subtalker_dosample,
            top_p=subtalker_top_p,
            top_k=subtalker_top_k,
            temperature=subtalker_temperature,
            output_hidden_states=True,
            return_dict_in_generate=True,
        )
        codec_ids = torch.cat((input_ids, predictor_result.sequences), dim=-1)
        # Sum the embeddings of all groups into a single next-step input vector.
        codec_hiddens = torch.cat(
            [last_id_hidden]
            + [self.code_predictor.get_input_embeddings()[i](predictor_result.sequences[..., i:i+1]) for i in range(self.config.num_code_groups - 1)],
            dim=1,
        )
        inputs_embeds = codec_hiddens.sum(1, keepdim=True)

        # Mix in the next trailing-text hidden state while any remain; else pad.
        if generation_step < trailing_text_hidden.shape[1]:
            inputs_embeds = inputs_embeds + trailing_text_hidden[:, generation_step].unsqueeze(1)
        else:
            inputs_embeds = inputs_embeds + tts_pad_embed
    if attention_mask is not None:
        # Recompute rope deltas at prefill (or when uninitialized); afterwards
        # derive positions incrementally from cache_position + cached deltas.
        if (
            cache_position is None
            or (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
        ):
            # Offset by the number of left-padding positions per sample.
            delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
            position_ids, rope_deltas = self.get_rope_index(
                attention_mask,
            )
            rope_deltas = rope_deltas - delta0
            self.rope_deltas = rope_deltas
        else:
            batch_size, seq_length = input_ids.shape
            delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
            position_ids = torch.arange(seq_length, device=input_ids.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            position_ids = position_ids.add(delta)
            # Replicate across the 3 rotary axes (temporal, height, width).
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    outputs: BaseModelOutputWithPast = self.model(
        input_ids=None,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    logits = self.codec_head(hidden_states)

    loss = None
    if labels is not None:
        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)


    return Qwen3TTSTalkerOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        # hidden_states doubles as a carrier for the full codec ids of this step.
        hidden_states=(outputs.hidden_states, codec_ids),
        attentions=outputs.attentions,
        past_hidden=hidden_states[:, -1:, :],
        generation_step=generation_step + 1,
        trailing_text_hidden=trailing_text_hidden,
        tts_pad_embed=tts_pad_embed,
    )
|
| 1707 |
-
|
| 1708 |
-
def get_rope_index(
|
| 1709 |
-
self,
|
| 1710 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 1711 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 1712 |
-
"""
|
| 1713 |
-
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
|
| 1714 |
-
|
| 1715 |
-
Explanation:
|
| 1716 |
-
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
|
| 1717 |
-
|
| 1718 |
-
For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
|
| 1719 |
-
Examples:
|
| 1720 |
-
input_ids: [T T T T T], here T is for text.
|
| 1721 |
-
temporal position_ids: [0, 1, 2, 3, 4]
|
| 1722 |
-
height position_ids: [0, 1, 2, 3, 4]
|
| 1723 |
-
width position_ids: [0, 1, 2, 3, 4]
|
| 1724 |
-
|
| 1725 |
-
For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
|
| 1726 |
-
and 1D rotary position embedding for text part.
|
| 1727 |
-
Examples:
|
| 1728 |
-
Temporal (Time): 3 patches, representing different segments of the video in time.
|
| 1729 |
-
Height: 2 patches, dividing each frame vertically.
|
| 1730 |
-
Width: 2 patches, dividing each frame horizontally.
|
| 1731 |
-
We also have some important parameters:
|
| 1732 |
-
fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
|
| 1733 |
-
interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs.
|
| 1734 |
-
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
|
| 1735 |
-
text temporal position_ids: [101, 102, 103, 104, 105]
|
| 1736 |
-
text height position_ids: [101, 102, 103, 104, 105]
|
| 1737 |
-
text width position_ids: [101, 102, 103, 104, 105]
|
| 1738 |
-
Here we calculate the text start position_ids as the max vision position_ids plus 1.
|
| 1739 |
-
|
| 1740 |
-
Args:
|
| 1741 |
-
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1742 |
-
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 1743 |
-
it.
|
| 1744 |
-
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1745 |
-
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 1746 |
-
|
| 1747 |
-
- 1 for tokens that are **not masked**,
|
| 1748 |
-
- 0 for tokens that are **masked**.
|
| 1749 |
-
|
| 1750 |
-
Returns:
|
| 1751 |
-
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
|
| 1752 |
-
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
|
| 1753 |
-
"""
|
| 1754 |
-
mrope_position_deltas = []
|
| 1755 |
-
|
| 1756 |
-
position_ids = attention_mask.float().cumsum(-1) - 1
|
| 1757 |
-
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1758 |
-
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
|
| 1759 |
-
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
|
| 1760 |
-
mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)
|
| 1761 |
-
|
| 1762 |
-
return position_ids, mrope_position_deltas
|
| 1763 |
-
|
| 1764 |
-
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1):
|
| 1765 |
-
model_kwargs = super()._update_model_kwargs_for_generation(
|
| 1766 |
-
outputs, model_kwargs, is_encoder_decoder, num_new_tokens
|
| 1767 |
-
)
|
| 1768 |
-
model_kwargs["past_hidden"] = outputs.past_hidden
|
| 1769 |
-
model_kwargs["generation_step"] = outputs.generation_step
|
| 1770 |
-
model_kwargs["trailing_text_hidden"] = outputs.trailing_text_hidden
|
| 1771 |
-
model_kwargs["tts_pad_embed"] = outputs.tts_pad_embed
|
| 1772 |
-
return model_kwargs
|
| 1773 |
-
|
| 1774 |
-
|
| 1775 |
-
class Qwen3TTSForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin):
|
| 1776 |
-
config_class = Qwen3TTSConfig
|
| 1777 |
-
|
| 1778 |
-
def __init__(self, config: Qwen3TTSConfig):
|
| 1779 |
-
super().__init__(config)
|
| 1780 |
-
self.config = config
|
| 1781 |
-
|
| 1782 |
-
self.talker = Qwen3TTSTalkerForConditionalGeneration(self.config.talker_config)
|
| 1783 |
-
|
| 1784 |
-
if config.tts_model_type == "base":
|
| 1785 |
-
self.speaker_encoder = Qwen3TTSSpeakerEncoder(self.config.speaker_encoder_config)
|
| 1786 |
-
else:
|
| 1787 |
-
self.speaker_encoder = None
|
| 1788 |
-
|
| 1789 |
-
self.speech_tokenizer = None
|
| 1790 |
-
self.generate_config = None
|
| 1791 |
-
|
| 1792 |
-
self.supported_speakers = self.config.talker_config.spk_id.keys()
|
| 1793 |
-
self.supported_languages = ["auto"]
|
| 1794 |
-
for language_id in self.config.talker_config.codec_language_id.keys():
|
| 1795 |
-
if "dialect" not in language_id:
|
| 1796 |
-
self.supported_languages.append(language_id)
|
| 1797 |
-
|
| 1798 |
-
self.speaker_encoder_sample_rate = self.config.speaker_encoder_config.sample_rate
|
| 1799 |
-
self.tokenizer_type = self.config.tokenizer_type
|
| 1800 |
-
self.tts_model_size = self.config.tts_model_size
|
| 1801 |
-
self.tts_model_type = self.config.tts_model_type
|
| 1802 |
-
|
| 1803 |
-
self.post_init()
|
| 1804 |
-
|
| 1805 |
-
def load_speech_tokenizer(self, speech_tokenizer):
|
| 1806 |
-
self.speech_tokenizer = speech_tokenizer
|
| 1807 |
-
|
| 1808 |
-
def load_generate_config(self, generate_config):
|
| 1809 |
-
self.generate_config = generate_config
|
| 1810 |
-
|
| 1811 |
-
def get_supported_speakers(self):
|
| 1812 |
-
return self.supported_speakers
|
| 1813 |
-
|
| 1814 |
-
def get_supported_languages(self):
|
| 1815 |
-
return self.supported_languages
|
| 1816 |
-
|
| 1817 |
-
@classmethod
|
| 1818 |
-
def from_pretrained(
|
| 1819 |
-
cls,
|
| 1820 |
-
pretrained_model_name_or_path,
|
| 1821 |
-
*model_args,
|
| 1822 |
-
config=None,
|
| 1823 |
-
cache_dir=None,
|
| 1824 |
-
ignore_mismatched_sizes=False,
|
| 1825 |
-
force_download=False,
|
| 1826 |
-
local_files_only=False,
|
| 1827 |
-
token=None,
|
| 1828 |
-
revision="main",
|
| 1829 |
-
use_safetensors=None,
|
| 1830 |
-
weights_only=True,
|
| 1831 |
-
**kwargs,
|
| 1832 |
-
):
|
| 1833 |
-
model = super().from_pretrained(
|
| 1834 |
-
pretrained_model_name_or_path,
|
| 1835 |
-
*model_args,
|
| 1836 |
-
config=config,
|
| 1837 |
-
cache_dir=cache_dir,
|
| 1838 |
-
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
| 1839 |
-
force_download=force_download,
|
| 1840 |
-
local_files_only=local_files_only,
|
| 1841 |
-
token=token,
|
| 1842 |
-
revision=revision,
|
| 1843 |
-
use_safetensors=use_safetensors,
|
| 1844 |
-
weights_only=weights_only,
|
| 1845 |
-
**kwargs,
|
| 1846 |
-
)
|
| 1847 |
-
speech_tokenizer_path = cached_file(
|
| 1848 |
-
pretrained_model_name_or_path,
|
| 1849 |
-
"speech_tokenizer/config.json",
|
| 1850 |
-
subfolder=kwargs.pop("subfolder", None),
|
| 1851 |
-
cache_dir=kwargs.pop("cache_dir", None),
|
| 1852 |
-
force_download=kwargs.pop("force_download", False),
|
| 1853 |
-
proxies=kwargs.pop("proxies", None),
|
| 1854 |
-
resume_download=kwargs.pop("resume_download", None),
|
| 1855 |
-
local_files_only=kwargs.pop("local_files_only", False),
|
| 1856 |
-
token=kwargs.pop("use_auth_token", None),
|
| 1857 |
-
revision=kwargs.pop("revision", None),
|
| 1858 |
-
)
|
| 1859 |
-
if speech_tokenizer_path is None:
|
| 1860 |
-
raise ValueError(f"""{pretrained_model_name_or_path}/{speech_tokenizer_path} not exists""")
|
| 1861 |
-
speech_tokenizer_dir = os.path.dirname(speech_tokenizer_path)
|
| 1862 |
-
speech_tokenizer = Qwen3TTSTokenizer.from_pretrained(
|
| 1863 |
-
speech_tokenizer_dir,
|
| 1864 |
-
*model_args,
|
| 1865 |
-
**kwargs,
|
| 1866 |
-
)
|
| 1867 |
-
model.load_speech_tokenizer(speech_tokenizer)
|
| 1868 |
-
|
| 1869 |
-
generate_config_path = cached_file(
|
| 1870 |
-
pretrained_model_name_or_path,
|
| 1871 |
-
"generation_config.json",
|
| 1872 |
-
subfolder=kwargs.pop("subfolder", None),
|
| 1873 |
-
cache_dir=kwargs.pop("cache_dir", None),
|
| 1874 |
-
force_download=kwargs.pop("force_download", False),
|
| 1875 |
-
proxies=kwargs.pop("proxies", None),
|
| 1876 |
-
resume_download=kwargs.pop("resume_download", None),
|
| 1877 |
-
local_files_only=kwargs.pop("local_files_only", False),
|
| 1878 |
-
token=kwargs.pop("use_auth_token", None),
|
| 1879 |
-
revision=kwargs.pop("revision", None),
|
| 1880 |
-
)
|
| 1881 |
-
with open(generate_config_path, "r", encoding="utf-8") as f:
|
| 1882 |
-
generate_config = json.load(f)
|
| 1883 |
-
model.load_generate_config(generate_config)
|
| 1884 |
-
|
| 1885 |
-
return model
|
| 1886 |
-
|
| 1887 |
-
@torch.inference_mode()
|
| 1888 |
-
def extract_speaker_embedding(self, audio, sr):
|
| 1889 |
-
assert sr == 24000, "Only support 24kHz audio"
|
| 1890 |
-
mels = mel_spectrogram(
|
| 1891 |
-
torch.from_numpy(audio).unsqueeze(0),
|
| 1892 |
-
n_fft=1024,
|
| 1893 |
-
num_mels=128,
|
| 1894 |
-
sampling_rate=24000,
|
| 1895 |
-
hop_size=256,
|
| 1896 |
-
win_size=1024,
|
| 1897 |
-
fmin=0,
|
| 1898 |
-
fmax=12000
|
| 1899 |
-
).transpose(1, 2)
|
| 1900 |
-
speaker_embedding = self.speaker_encoder(mels.to(self.device).to(self.dtype))[0]
|
| 1901 |
-
return speaker_embedding
|
| 1902 |
-
|
| 1903 |
-
@torch.inference_mode()
|
| 1904 |
-
def generate_speaker_prompt(
|
| 1905 |
-
self,
|
| 1906 |
-
voice_clone_prompt: list[dict]
|
| 1907 |
-
):
|
| 1908 |
-
voice_clone_spk_embeds = []
|
| 1909 |
-
for index in range(len(voice_clone_prompt['ref_spk_embedding'])):
|
| 1910 |
-
ref_spk_embedding = voice_clone_prompt["ref_spk_embedding"][index].to(self.talker.device).to(self.talker.dtype)
|
| 1911 |
-
voice_clone_spk_embeds.append(ref_spk_embedding)
|
| 1912 |
-
|
| 1913 |
-
return voice_clone_spk_embeds
|
| 1914 |
-
|
| 1915 |
-
def generate_icl_prompt(
|
| 1916 |
-
self,
|
| 1917 |
-
text_id: torch.Tensor,
|
| 1918 |
-
ref_id: torch.Tensor,
|
| 1919 |
-
ref_code: torch.Tensor,
|
| 1920 |
-
tts_pad_embed: torch.Tensor,
|
| 1921 |
-
tts_eos_embed: torch.Tensor,
|
| 1922 |
-
non_streaming_mode: bool,
|
| 1923 |
-
):
|
| 1924 |
-
# text embed (ref id + text id + eos) 1 T1 D
|
| 1925 |
-
text_embed = self.talker.text_projection(
|
| 1926 |
-
self.talker.get_text_embeddings()(torch.cat([ref_id, text_id],
|
| 1927 |
-
dim=-1)))
|
| 1928 |
-
text_embed = torch.cat([text_embed, tts_eos_embed], dim=1)
|
| 1929 |
-
# codec embed (codec bos + codec) 1 T2 D
|
| 1930 |
-
codec_embed = []
|
| 1931 |
-
for i in range(self.talker.config.num_code_groups):
|
| 1932 |
-
if i == 0:
|
| 1933 |
-
codec_embed.append(self.talker.get_input_embeddings()(ref_code[:, :1]))
|
| 1934 |
-
else:
|
| 1935 |
-
codec_embed.append(self.talker.code_predictor.get_input_embeddings()[i-1](ref_code[:, i:i+1]))
|
| 1936 |
-
codec_embed = torch.cat(codec_embed, dim=1).sum(1).unsqueeze(0)
|
| 1937 |
-
codec_embed = torch.cat([self.talker.get_input_embeddings()(
|
| 1938 |
-
torch.tensor(
|
| 1939 |
-
[[
|
| 1940 |
-
self.config.talker_config.codec_bos_id,
|
| 1941 |
-
]],
|
| 1942 |
-
device=self.talker.device,
|
| 1943 |
-
dtype=text_id.dtype,
|
| 1944 |
-
)
|
| 1945 |
-
), codec_embed], dim=1)
|
| 1946 |
-
# compute lens
|
| 1947 |
-
text_lens = text_embed.shape[1]
|
| 1948 |
-
codec_lens = codec_embed.shape[1]
|
| 1949 |
-
if non_streaming_mode:
|
| 1950 |
-
icl_input_embed = text_embed + self.talker.get_input_embeddings()(
|
| 1951 |
-
torch.tensor(
|
| 1952 |
-
[[
|
| 1953 |
-
self.config.talker_config.codec_pad_id,
|
| 1954 |
-
] * text_lens],
|
| 1955 |
-
device=self.talker.device,
|
| 1956 |
-
dtype=text_id.dtype,
|
| 1957 |
-
)
|
| 1958 |
-
)
|
| 1959 |
-
icl_input_embed = torch.cat([icl_input_embed, codec_embed + tts_pad_embed], dim=1)
|
| 1960 |
-
return icl_input_embed, tts_pad_embed
|
| 1961 |
-
else:
|
| 1962 |
-
if text_lens > codec_lens:
|
| 1963 |
-
return text_embed[:, :codec_lens] + codec_embed, text_embed[:, codec_lens:]
|
| 1964 |
-
else:
|
| 1965 |
-
text_embed = torch.cat([text_embed] + [tts_pad_embed] * (codec_lens - text_lens), dim=1)
|
| 1966 |
-
return text_embed + codec_embed, tts_pad_embed
|
| 1967 |
-
|
| 1968 |
-
@torch.no_grad()
|
| 1969 |
-
def generate(
|
| 1970 |
-
self,
|
| 1971 |
-
input_ids: Optional[list[torch.Tensor]] = None,
|
| 1972 |
-
instruct_ids: Optional[list[torch.Tensor]] = None,
|
| 1973 |
-
ref_ids: Optional[list[torch.Tensor]] = None,
|
| 1974 |
-
voice_clone_prompt: list[dict] = None,
|
| 1975 |
-
languages: list[str] = None,
|
| 1976 |
-
speakers: list[str] = None,
|
| 1977 |
-
non_streaming_mode = False,
|
| 1978 |
-
max_new_tokens: int = 4096,
|
| 1979 |
-
do_sample: bool = True,
|
| 1980 |
-
top_k: int = 50,
|
| 1981 |
-
top_p: float = 1.0,
|
| 1982 |
-
temperature: float = 0.9,
|
| 1983 |
-
subtalker_dosample: bool = True,
|
| 1984 |
-
subtalker_top_k: int = 50,
|
| 1985 |
-
subtalker_top_p: float = 1.0,
|
| 1986 |
-
subtalker_temperature: float = 0.9,
|
| 1987 |
-
eos_token_id: Optional[int] = None,
|
| 1988 |
-
repetition_penalty: float = 1.05,
|
| 1989 |
-
**kwargs,
|
| 1990 |
-
):
|
| 1991 |
-
talker_kwargs = {
|
| 1992 |
-
"max_new_tokens": max_new_tokens,
|
| 1993 |
-
"min_new_tokens": 2,
|
| 1994 |
-
"do_sample": do_sample,
|
| 1995 |
-
"top_k": top_k,
|
| 1996 |
-
"top_p": top_p,
|
| 1997 |
-
"temperature": temperature,
|
| 1998 |
-
"subtalker_dosample": subtalker_dosample,
|
| 1999 |
-
"subtalker_top_k": subtalker_top_k,
|
| 2000 |
-
"subtalker_top_p": subtalker_top_p,
|
| 2001 |
-
"subtalker_temperature": subtalker_temperature,
|
| 2002 |
-
"eos_token_id": eos_token_id
|
| 2003 |
-
if eos_token_id is not None
|
| 2004 |
-
else self.config.talker_config.codec_eos_token_id,
|
| 2005 |
-
"repetition_penalty": repetition_penalty,
|
| 2006 |
-
"suppress_tokens": [
|
| 2007 |
-
i
|
| 2008 |
-
for i in range(self.config.talker_config.vocab_size - 1024, self.config.talker_config.vocab_size)
|
| 2009 |
-
if i not in (self.config.talker_config.codec_eos_token_id,)
|
| 2010 |
-
],
|
| 2011 |
-
"output_hidden_states": getattr(kwargs, "output_hidden_states", True),
|
| 2012 |
-
"return_dict_in_generate": getattr(kwargs, "return_dict_in_generate", True)
|
| 2013 |
-
}
|
| 2014 |
-
|
| 2015 |
-
talker_input_embeds = [[] for _ in range(len(input_ids))]
|
| 2016 |
-
|
| 2017 |
-
voice_clone_spk_embeds = None
|
| 2018 |
-
# voice clone speaker prompt generate
|
| 2019 |
-
if voice_clone_prompt is not None:
|
| 2020 |
-
voice_clone_spk_embeds = self.generate_speaker_prompt(voice_clone_prompt)
|
| 2021 |
-
|
| 2022 |
-
# instruct text prompt generate
|
| 2023 |
-
if instruct_ids is not None:
|
| 2024 |
-
for index, instruct_id in enumerate(instruct_ids):
|
| 2025 |
-
if instruct_id is not None:
|
| 2026 |
-
talker_input_embeds[index].append(self.talker.text_projection(
|
| 2027 |
-
self.talker.get_text_embeddings()(instruct_id)))
|
| 2028 |
-
|
| 2029 |
-
# tts text prompt generate
|
| 2030 |
-
trailing_text_hiddens = []
|
| 2031 |
-
if speakers is None:
|
| 2032 |
-
speakers = [None] * len(input_ids)
|
| 2033 |
-
for index, (input_id, language, speaker) in enumerate(zip(input_ids, languages, speakers)):
|
| 2034 |
-
if voice_clone_spk_embeds is None:
|
| 2035 |
-
if speaker == "" or speaker == None: # Instruct create speaker
|
| 2036 |
-
speaker_embed = None
|
| 2037 |
-
else:
|
| 2038 |
-
if speaker.lower() not in self.config.talker_config.spk_id:
|
| 2039 |
-
raise NotImplementedError(f"Speaker {speaker} not implemented")
|
| 2040 |
-
else:
|
| 2041 |
-
spk_id = self.config.talker_config.spk_id[speaker.lower()]
|
| 2042 |
-
speaker_embed = self.talker.get_input_embeddings()(
|
| 2043 |
-
torch.tensor(
|
| 2044 |
-
spk_id,
|
| 2045 |
-
device=self.talker.device,
|
| 2046 |
-
dtype=input_id.dtype,
|
| 2047 |
-
)
|
| 2048 |
-
)
|
| 2049 |
-
else:
|
| 2050 |
-
if voice_clone_prompt["x_vector_only_mode"][index] or voice_clone_prompt["icl_mode"][index]:
|
| 2051 |
-
speaker_embed = voice_clone_spk_embeds[index]
|
| 2052 |
-
else:
|
| 2053 |
-
speaker_embed = None
|
| 2054 |
-
|
| 2055 |
-
assert language is not None
|
| 2056 |
-
|
| 2057 |
-
if language.lower() == "auto":
|
| 2058 |
-
language_id = None
|
| 2059 |
-
else:
|
| 2060 |
-
if language.lower() not in self.config.talker_config.codec_language_id:
|
| 2061 |
-
raise NotImplementedError(f"Language {language} not implemented")
|
| 2062 |
-
else:
|
| 2063 |
-
language_id = self.config.talker_config.codec_language_id[language.lower()]
|
| 2064 |
-
|
| 2065 |
-
if (language.lower() in ["chinese", "auto"] and \
|
| 2066 |
-
speaker != "" and speaker is not None and \
|
| 2067 |
-
self.config.talker_config.spk_is_dialect[speaker.lower()] != False):
|
| 2068 |
-
dialect = self.config.talker_config.spk_is_dialect[speaker.lower()]
|
| 2069 |
-
language_id = self.config.talker_config.codec_language_id[dialect]
|
| 2070 |
-
|
| 2071 |
-
tts_bos_embed, tts_eos_embed, tts_pad_embed = self.talker.text_projection(
|
| 2072 |
-
self.talker.get_text_embeddings()(
|
| 2073 |
-
torch.tensor(
|
| 2074 |
-
[[self.config.tts_bos_token_id, self.config.tts_eos_token_id, self.config.tts_pad_token_id]],
|
| 2075 |
-
device=self.talker.device,
|
| 2076 |
-
dtype=input_id.dtype,
|
| 2077 |
-
)
|
| 2078 |
-
)
|
| 2079 |
-
).chunk(3, dim=1) # 3 * [1 1 d]
|
| 2080 |
-
|
| 2081 |
-
# codec: tag and speaker
|
| 2082 |
-
if language_id is None:
|
| 2083 |
-
codec_prefill_list = [[
|
| 2084 |
-
self.config.talker_config.codec_nothink_id,
|
| 2085 |
-
self.config.talker_config.codec_think_bos_id,
|
| 2086 |
-
self.config.talker_config.codec_think_eos_id,
|
| 2087 |
-
]]
|
| 2088 |
-
else:
|
| 2089 |
-
codec_prefill_list = [[
|
| 2090 |
-
self.config.talker_config.codec_think_id,
|
| 2091 |
-
self.config.talker_config.codec_think_bos_id,
|
| 2092 |
-
language_id,
|
| 2093 |
-
self.config.talker_config.codec_think_eos_id,
|
| 2094 |
-
]]
|
| 2095 |
-
|
| 2096 |
-
codec_input_emebdding_0 = self.talker.get_input_embeddings()(
|
| 2097 |
-
torch.tensor(
|
| 2098 |
-
codec_prefill_list,
|
| 2099 |
-
device=self.talker.device,
|
| 2100 |
-
dtype=input_id.dtype,
|
| 2101 |
-
)
|
| 2102 |
-
)
|
| 2103 |
-
codec_input_emebdding_1 = self.talker.get_input_embeddings()(
|
| 2104 |
-
torch.tensor(
|
| 2105 |
-
[[
|
| 2106 |
-
self.config.talker_config.codec_pad_id,
|
| 2107 |
-
self.config.talker_config.codec_bos_id,
|
| 2108 |
-
]],
|
| 2109 |
-
device=self.talker.device,
|
| 2110 |
-
dtype=input_id.dtype,
|
| 2111 |
-
)
|
| 2112 |
-
)
|
| 2113 |
-
if speaker_embed is None:
|
| 2114 |
-
codec_input_emebdding = torch.cat([codec_input_emebdding_0,
|
| 2115 |
-
codec_input_emebdding_1], dim=1)
|
| 2116 |
-
else:
|
| 2117 |
-
codec_input_emebdding = torch.cat([codec_input_emebdding_0,
|
| 2118 |
-
speaker_embed.view(1, 1, -1),
|
| 2119 |
-
codec_input_emebdding_1], dim=1)
|
| 2120 |
-
|
| 2121 |
-
# '<|im_start|>assistant\n我叫通义千问,是阿里云的开源大模型。<|im_end|>\n<|im_start|>assistant\n'
|
| 2122 |
-
|
| 2123 |
-
# <|im_start|>assistant\n
|
| 2124 |
-
_talker_input_embed_role = self.talker.text_projection(
|
| 2125 |
-
self.talker.get_text_embeddings()(input_id[:, :3])
|
| 2126 |
-
)
|
| 2127 |
-
|
| 2128 |
-
# tts_pad * 4 + tts_bos
|
| 2129 |
-
_talker_input_embed = torch.cat((tts_pad_embed.expand(-1, codec_input_emebdding.shape[1] - 2, -1),
|
| 2130 |
-
tts_bos_embed,
|
| 2131 |
-
), dim=1) + codec_input_emebdding[:, :-1]
|
| 2132 |
-
|
| 2133 |
-
talker_input_embed = torch.cat((_talker_input_embed_role, _talker_input_embed), dim=1)
|
| 2134 |
-
|
| 2135 |
-
if voice_clone_prompt is not None and voice_clone_prompt["ref_code"] is not None and voice_clone_prompt["icl_mode"][index]:
|
| 2136 |
-
icl_input_embed, trailing_text_hidden = self.generate_icl_prompt(
|
| 2137 |
-
text_id=input_id[:, 3:-5],
|
| 2138 |
-
ref_id=ref_ids[index][:, 3:-2],
|
| 2139 |
-
ref_code=voice_clone_prompt["ref_code"][index].to(self.talker.device),
|
| 2140 |
-
tts_pad_embed=tts_pad_embed,
|
| 2141 |
-
tts_eos_embed=tts_eos_embed,
|
| 2142 |
-
non_streaming_mode=non_streaming_mode,
|
| 2143 |
-
)
|
| 2144 |
-
talker_input_embed = torch.cat([talker_input_embed, icl_input_embed], dim=1)
|
| 2145 |
-
else:
|
| 2146 |
-
# tts_text_first_token
|
| 2147 |
-
talker_input_embed = torch.cat([talker_input_embed,
|
| 2148 |
-
self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 3:4])) + codec_input_emebdding[:, -1:]],
|
| 2149 |
-
dim=1)
|
| 2150 |
-
if non_streaming_mode:
|
| 2151 |
-
talker_input_embed = talker_input_embed[:, :-1] # 去掉原本放进去的text
|
| 2152 |
-
talker_input_embed = torch.cat([talker_input_embed,
|
| 2153 |
-
torch.cat((self.talker.text_projection(
|
| 2154 |
-
self.talker.get_text_embeddings()(input_id[:, 3:-5])
|
| 2155 |
-
), tts_eos_embed), dim=1) + self.talker.get_input_embeddings()(
|
| 2156 |
-
torch.tensor(
|
| 2157 |
-
[[
|
| 2158 |
-
self.config.talker_config.codec_pad_id,
|
| 2159 |
-
] * (input_id[:, 3:-5].shape[1] + 1)],
|
| 2160 |
-
device=self.talker.device,
|
| 2161 |
-
dtype=input_id.dtype,
|
| 2162 |
-
)
|
| 2163 |
-
),
|
| 2164 |
-
tts_pad_embed + self.talker.get_input_embeddings()(
|
| 2165 |
-
torch.tensor(
|
| 2166 |
-
[[
|
| 2167 |
-
self.config.talker_config.codec_bos_id,
|
| 2168 |
-
]],
|
| 2169 |
-
device=self.talker.device,
|
| 2170 |
-
dtype=input_id.dtype,
|
| 2171 |
-
)
|
| 2172 |
-
)
|
| 2173 |
-
], dim=1)
|
| 2174 |
-
trailing_text_hidden = tts_pad_embed
|
| 2175 |
-
else:
|
| 2176 |
-
# 叫通义千问,是阿里云的开源大模型。
|
| 2177 |
-
trailing_text_hidden = torch.cat((self.talker.text_projection(
|
| 2178 |
-
self.talker.get_text_embeddings()(input_id[:, 4:-5])
|
| 2179 |
-
), tts_eos_embed), dim=1)
|
| 2180 |
-
talker_input_embeds[index].append(talker_input_embed)
|
| 2181 |
-
trailing_text_hiddens.append(trailing_text_hidden)
|
| 2182 |
-
|
| 2183 |
-
for index, talker_input_embed in enumerate(talker_input_embeds):
|
| 2184 |
-
talker_input_embeds[index] = torch.cat([item for item in talker_input_embed if item is not None], dim=1)
|
| 2185 |
-
|
| 2186 |
-
# for batch inferquence
|
| 2187 |
-
original_lengths = torch.tensor([t.shape[1] for t in talker_input_embeds])
|
| 2188 |
-
# left padding for talker input embeds
|
| 2189 |
-
sequences = [t.squeeze(0) for t in talker_input_embeds]
|
| 2190 |
-
sequences_reversed = [t.flip(dims=[0]) for t in sequences]
|
| 2191 |
-
padded_reversed = torch.nn.utils.rnn.pad_sequence(
|
| 2192 |
-
sequences_reversed,
|
| 2193 |
-
batch_first=True,
|
| 2194 |
-
padding_value=0.0
|
| 2195 |
-
)
|
| 2196 |
-
talker_input_embeds = padded_reversed.flip(dims=[1])
|
| 2197 |
-
# generate mask
|
| 2198 |
-
batch_size, max_len = talker_input_embeds.shape[0], talker_input_embeds.shape[1]
|
| 2199 |
-
indices = torch.arange(max_len).expand(batch_size, -1)
|
| 2200 |
-
num_pads = max_len - original_lengths
|
| 2201 |
-
talker_attention_mask = (indices >= num_pads.unsqueeze(1)).long().to(talker_input_embeds.device)
|
| 2202 |
-
# padding trailing text hiddens
|
| 2203 |
-
pad_embedding_vector = tts_pad_embed.squeeze()
|
| 2204 |
-
sequences_to_pad = [t.squeeze(0) for t in trailing_text_hiddens]
|
| 2205 |
-
trailing_text_original_lengths = [s.shape[0] for s in sequences_to_pad]
|
| 2206 |
-
padded_hiddens = torch.nn.utils.rnn.pad_sequence(
|
| 2207 |
-
sequences_to_pad,
|
| 2208 |
-
batch_first=True,
|
| 2209 |
-
padding_value=0.0
|
| 2210 |
-
)
|
| 2211 |
-
arange_tensor = torch.arange(max(trailing_text_original_lengths),
|
| 2212 |
-
device=padded_hiddens.device).expand(len(trailing_text_original_lengths), -1)
|
| 2213 |
-
lengths_tensor = torch.tensor(trailing_text_original_lengths, device=padded_hiddens.device).unsqueeze(1)
|
| 2214 |
-
padding_mask = arange_tensor >= lengths_tensor
|
| 2215 |
-
padded_hiddens[padding_mask] = pad_embedding_vector
|
| 2216 |
-
trailing_text_hiddens = padded_hiddens
|
| 2217 |
-
|
| 2218 |
-
# forward
|
| 2219 |
-
talker_result = self.talker.generate(
|
| 2220 |
-
inputs_embeds=talker_input_embeds,
|
| 2221 |
-
attention_mask=talker_attention_mask,
|
| 2222 |
-
trailing_text_hidden=trailing_text_hiddens,
|
| 2223 |
-
tts_pad_embed=tts_pad_embed,
|
| 2224 |
-
**talker_kwargs,
|
| 2225 |
-
)
|
| 2226 |
-
|
| 2227 |
-
talker_codes = torch.stack([hid[-1] for hid in talker_result.hidden_states if hid[-1] is not None], dim=1)
|
| 2228 |
-
talker_hidden_states = torch.cat([hid[0][-1][:, -1:] for hid in talker_result.hidden_states], dim=1)[:, :-1]
|
| 2229 |
-
|
| 2230 |
-
first_codebook = talker_codes[:, :, 0]
|
| 2231 |
-
is_stop_token = (first_codebook == self.config.talker_config.codec_eos_token_id)
|
| 2232 |
-
stop_indices = torch.argmax(is_stop_token.int(), dim=1)
|
| 2233 |
-
has_stop_token = is_stop_token.any(dim=1)
|
| 2234 |
-
effective_lengths = torch.where(has_stop_token, stop_indices, talker_codes.shape[1])
|
| 2235 |
-
|
| 2236 |
-
talker_codes_list = [talker_codes[i, :length, ] for i, length in enumerate(effective_lengths)]
|
| 2237 |
-
talker_hidden_states_list = [talker_hidden_states[i, :length, :] for i, length in enumerate(effective_lengths)]
|
| 2238 |
-
|
| 2239 |
-
return talker_codes_list, talker_hidden_states_list
|
| 2240 |
-
|
| 2241 |
-
__all__ = [
|
| 2242 |
-
"Qwen3TTSForConditionalGeneration",
|
| 2243 |
-
"Qwen3TTSTalkerForConditionalGeneration",
|
| 2244 |
-
"Qwen3TTSPreTrainedModel",
|
| 2245 |
-
"Qwen3TTSTalkerModel",
|
| 2246 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/models/processing_qwen3_tts.py
DELETED
|
@@ -1,106 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
from transformers.feature_extraction_utils import BatchFeature
|
| 16 |
-
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class Qwen3TTSProcessorKwargs(ProcessingKwargs, total=False):
|
| 20 |
-
_defaults = {
|
| 21 |
-
"text_kwargs": {
|
| 22 |
-
"padding": False,
|
| 23 |
-
"padding_side": "left",
|
| 24 |
-
}
|
| 25 |
-
}
|
| 26 |
-
|
| 27 |
-
class Qwen3TTSProcessor(ProcessorMixin):
|
| 28 |
-
r"""
|
| 29 |
-
Constructs a Qwen3TTS processor.
|
| 30 |
-
|
| 31 |
-
Args:
|
| 32 |
-
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
| 33 |
-
The text tokenizer.
|
| 34 |
-
chat_template (`Optional[str]`, *optional*):
|
| 35 |
-
The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
|
| 36 |
-
"""
|
| 37 |
-
|
| 38 |
-
attributes = ["tokenizer"]
|
| 39 |
-
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
| 40 |
-
|
| 41 |
-
def __init__(
|
| 42 |
-
self, tokenizer=None, chat_template=None
|
| 43 |
-
):
|
| 44 |
-
super().__init__(tokenizer, chat_template=chat_template)
|
| 45 |
-
|
| 46 |
-
def __call__(self, text=None, **kwargs) -> BatchFeature:
|
| 47 |
-
"""
|
| 48 |
-
Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
|
| 49 |
-
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
| 50 |
-
the text.
|
| 51 |
-
|
| 52 |
-
Args:
|
| 53 |
-
text (`str`, `List[str]`, `List[List[str]]`):
|
| 54 |
-
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 55 |
-
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 56 |
-
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 57 |
-
"""
|
| 58 |
-
|
| 59 |
-
if text is None:
|
| 60 |
-
raise ValueError("You need to specify either a `text` input to process.")
|
| 61 |
-
|
| 62 |
-
output_kwargs = self._merge_kwargs(
|
| 63 |
-
Qwen3TTSProcessorKwargs,
|
| 64 |
-
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 65 |
-
**kwargs,
|
| 66 |
-
)
|
| 67 |
-
if not isinstance(text, list):
|
| 68 |
-
text = [text]
|
| 69 |
-
|
| 70 |
-
texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
| 71 |
-
|
| 72 |
-
return BatchFeature(
|
| 73 |
-
data={**texts_inputs},
|
| 74 |
-
tensor_type=kwargs.get("return_tensors"),
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
-
def batch_decode(self, *args, **kwargs):
|
| 78 |
-
"""
|
| 79 |
-
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 80 |
-
refer to the docstring of this method for more information.
|
| 81 |
-
"""
|
| 82 |
-
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 83 |
-
|
| 84 |
-
def decode(self, *args, **kwargs):
|
| 85 |
-
"""
|
| 86 |
-
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 87 |
-
the docstring of this method for more information.
|
| 88 |
-
"""
|
| 89 |
-
return self.tokenizer.decode(*args, **kwargs)
|
| 90 |
-
|
| 91 |
-
def apply_chat_template(self, conversations, chat_template=None, **kwargs):
|
| 92 |
-
if isinstance(conversations[0], dict):
|
| 93 |
-
conversations = [conversations]
|
| 94 |
-
return super().apply_chat_template(conversations, chat_template, **kwargs)
|
| 95 |
-
|
| 96 |
-
@property
|
| 97 |
-
def model_input_names(self):
|
| 98 |
-
tokenizer_input_names = self.tokenizer.model_input_names
|
| 99 |
-
return list(
|
| 100 |
-
dict.fromkeys(
|
| 101 |
-
tokenizer_input_names
|
| 102 |
-
)
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
__all__ = ["Qwen3TTSProcessor"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/tokenizer_12hz/configuration_qwen3_tts_tokenizer_v2.py
DELETED
|
@@ -1,172 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
"""Qwen3TTSTokenizerV2 model configuration"""
|
| 16 |
-
|
| 17 |
-
from transformers.configuration_utils import PretrainedConfig
|
| 18 |
-
from transformers.utils import logging
|
| 19 |
-
|
| 20 |
-
from transformers import MimiConfig
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
logger = logging.get_logger(__name__)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class Qwen3TTSTokenizerV2DecoderConfig(PretrainedConfig):
|
| 27 |
-
r"""
|
| 28 |
-
This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV2DecoderConfig`].
|
| 29 |
-
|
| 30 |
-
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 31 |
-
documentation from [`PretrainedConfig`] for more information.
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
codebook_size (`int`, *optional*, defaults to 2048):
|
| 35 |
-
Number of entries in each residual codebook used for acoustic token quantization.
|
| 36 |
-
hidden_size (`int`, *optional*, defaults to 1024):
|
| 37 |
-
Dimensionality of the hidden states and embeddings in the autoregressive transformer decoder.
|
| 38 |
-
max_position_embeddings (`int`, *optional*, defaults to 8000):
|
| 39 |
-
Maximum sequence length that the autoregressive decoder can handle. Determines positional embedding size.
|
| 40 |
-
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 41 |
-
The base period for rotary position embeddings (RoPE) applied to attention layers.
|
| 42 |
-
num_attention_heads (`int`, *optional*, defaults to 16):
|
| 43 |
-
Number of attention heads for each attention layer in the decoder.
|
| 44 |
-
num_key_value_heads (`int`, *optional*, defaults to 16):
|
| 45 |
-
Number of key and value attention heads used in grouped-query attention (if applicable).
|
| 46 |
-
attention_bias (`bool`, *optional*, defaults to `False`):
|
| 47 |
-
Whether to use bias in the attention projection layers.
|
| 48 |
-
sliding_window (`int`, *optional*, defaults to 72):
|
| 49 |
-
Window size for local attention mechanism, limiting attention context to improve efficiency.
|
| 50 |
-
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 51 |
-
Dimensionality of the feed-forward (intermediate) layer in each transformer block.
|
| 52 |
-
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 53 |
-
The non-linear activation function used in the feed-forward layers. Supports `"silu"`, `"relu"`, `"gelu"`, etc.
|
| 54 |
-
layer_scale_initial_scale (`float`, *optional*, defaults to 0.01):
|
| 55 |
-
Initial value for LayerScale applied in transformer blocks, helping stabilize training.
|
| 56 |
-
rms_norm_eps (`float`, *optional*, defaults to 1e-5):
|
| 57 |
-
Epsilon value for RMS normalization layers to prevent division by zero.
|
| 58 |
-
num_hidden_layers (`int`, *optional*, defaults to 8):
|
| 59 |
-
Number of transformer blocks in the autoregressive decoder.
|
| 60 |
-
num_quantizers (`int`, *optional*, defaults to 16):
|
| 61 |
-
Number of residual vector quantizers used in the vocoder for fine-grained audio reconstruction.
|
| 62 |
-
upsample_rates (`Tuple[int]`, *optional*, defaults to `(8, 5, 4, 3)`):
|
| 63 |
-
Rate at which features are upsampled in the final waveform synthesis stage.
|
| 64 |
-
upsampling_ratios (`Tuple[int]`, *optional*, defaults to `(2, 2)`):
|
| 65 |
-
Ratios used in transposed convolutional layers to progressively upsample feature maps to waveform.
|
| 66 |
-
decoder_dim (`int`, *optional*, defaults to 1536):
|
| 67 |
-
Final dimensionality of the decoder's output before waveform generation.
|
| 68 |
-
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 69 |
-
Dropout probability applied to attention weights in the decoder.
|
| 70 |
-
"""
|
| 71 |
-
|
| 72 |
-
def __init__(
|
| 73 |
-
self,
|
| 74 |
-
codebook_size=2048,
|
| 75 |
-
hidden_size=1024,
|
| 76 |
-
latent_dim=1024,
|
| 77 |
-
max_position_embeddings=8000,
|
| 78 |
-
rope_theta=10000,
|
| 79 |
-
num_attention_heads=16,
|
| 80 |
-
num_key_value_heads=16,
|
| 81 |
-
attention_bias=False,
|
| 82 |
-
sliding_window=72,
|
| 83 |
-
intermediate_size=3072,
|
| 84 |
-
hidden_act="silu",
|
| 85 |
-
layer_scale_initial_scale=0.01,
|
| 86 |
-
rms_norm_eps=1e-5,
|
| 87 |
-
num_hidden_layers=8,
|
| 88 |
-
num_quantizers=16,
|
| 89 |
-
upsample_rates=(8, 5, 4, 3),
|
| 90 |
-
upsampling_ratios=(2, 2),
|
| 91 |
-
decoder_dim=1536,
|
| 92 |
-
attention_dropout=0.0,
|
| 93 |
-
**kwargs,
|
| 94 |
-
):
|
| 95 |
-
super().__init__(**kwargs)
|
| 96 |
-
self.codebook_size = codebook_size
|
| 97 |
-
self.hidden_size = hidden_size
|
| 98 |
-
self.latent_dim = latent_dim
|
| 99 |
-
self.max_position_embeddings = max_position_embeddings
|
| 100 |
-
self.rope_theta = rope_theta
|
| 101 |
-
self.num_attention_heads = num_attention_heads
|
| 102 |
-
self.num_key_value_heads = num_key_value_heads
|
| 103 |
-
self.attention_bias = attention_bias
|
| 104 |
-
self.sliding_window = sliding_window
|
| 105 |
-
self.intermediate_size = intermediate_size
|
| 106 |
-
self.hidden_act = hidden_act
|
| 107 |
-
self.layer_scale_initial_scale = layer_scale_initial_scale
|
| 108 |
-
self.rms_norm_eps = rms_norm_eps
|
| 109 |
-
self.num_hidden_layers = num_hidden_layers
|
| 110 |
-
self.num_quantizers = num_quantizers
|
| 111 |
-
self.upsample_rates = upsample_rates
|
| 112 |
-
self.upsampling_ratios = upsampling_ratios
|
| 113 |
-
self.decoder_dim = decoder_dim
|
| 114 |
-
self.attention_dropout = attention_dropout
|
| 115 |
-
|
| 116 |
-
@property
|
| 117 |
-
def layer_types(self):
|
| 118 |
-
"""
|
| 119 |
-
All layer in code2wav should be sliding attention
|
| 120 |
-
"""
|
| 121 |
-
return ["sliding_attention"] * self.num_hidden_layers
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
class Qwen3TTSTokenizerV2Config(PretrainedConfig):
|
| 125 |
-
"""
|
| 126 |
-
This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV2Config`]. It is used to instantiate a Qwen3TTSTokenizerV2Model
|
| 127 |
-
model according to the specified sub-models configurations, defining the model architecture.
|
| 128 |
-
|
| 129 |
-
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 130 |
-
documentation from [`PretrainedConfig`] for more information.
|
| 131 |
-
|
| 132 |
-
Args:
|
| 133 |
-
encoder_config (`dict`, *optional*): Configuration of the underlying encoder sub-model.
|
| 134 |
-
decoder_config (`dict`, *optional*): Configuration of the underlying decoder sub-model.
|
| 135 |
-
"""
|
| 136 |
-
|
| 137 |
-
model_type = "qwen3_tts_tokenizer_12hz"
|
| 138 |
-
sub_configs = {
|
| 139 |
-
"encoder_config": MimiConfig,
|
| 140 |
-
"decoder_config": Qwen3TTSTokenizerV2DecoderConfig,
|
| 141 |
-
}
|
| 142 |
-
|
| 143 |
-
def __init__(
|
| 144 |
-
self,
|
| 145 |
-
encoder_config=None,
|
| 146 |
-
decoder_config=None,
|
| 147 |
-
encoder_valid_num_quantizers=16,
|
| 148 |
-
input_sample_rate=24000,
|
| 149 |
-
output_sample_rate=24000,
|
| 150 |
-
decode_upsample_rate=1920,
|
| 151 |
-
encode_downsample_rate=1920,
|
| 152 |
-
**kwargs,
|
| 153 |
-
):
|
| 154 |
-
super().__init__(**kwargs)
|
| 155 |
-
if encoder_config is None:
|
| 156 |
-
encoder_config = {}
|
| 157 |
-
logger.info("encoder_config is None. Initializing encoder with default values")
|
| 158 |
-
if decoder_config is None:
|
| 159 |
-
decoder_config = {}
|
| 160 |
-
logger.info("decoder_config is None. Initializing decoder with default values")
|
| 161 |
-
|
| 162 |
-
self.encoder_config = MimiConfig(**encoder_config)
|
| 163 |
-
self.decoder_config = Qwen3TTSTokenizerV2DecoderConfig(**decoder_config)
|
| 164 |
-
|
| 165 |
-
self.encoder_valid_num_quantizers = encoder_valid_num_quantizers
|
| 166 |
-
self.input_sample_rate = input_sample_rate
|
| 167 |
-
self.output_sample_rate = output_sample_rate
|
| 168 |
-
self.decode_upsample_rate = decode_upsample_rate
|
| 169 |
-
self.encode_downsample_rate = encode_downsample_rate
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
__all__ = ["Qwen3TTSTokenizerV2Config", "Qwen3TTSTokenizerV2DecoderConfig"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py
DELETED
|
@@ -1,1025 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
"""PyTorch Qwen3TTSTokenizerV2 model."""
|
| 16 |
-
|
| 17 |
-
import math
|
| 18 |
-
from dataclasses import dataclass
|
| 19 |
-
from typing import Callable, Optional, Union, List
|
| 20 |
-
|
| 21 |
-
import numpy as np
|
| 22 |
-
import torch
|
| 23 |
-
from torch import nn
|
| 24 |
-
from torch.nn import Parameter
|
| 25 |
-
from torch.nn import functional as F
|
| 26 |
-
from transformers import MimiConfig, MimiModel
|
| 27 |
-
from transformers.activations import ACT2FN
|
| 28 |
-
from transformers.cache_utils import Cache, DynamicCache
|
| 29 |
-
from transformers.integrations import use_kernel_forward_from_hub
|
| 30 |
-
from transformers.masking_utils import (
|
| 31 |
-
create_causal_mask,
|
| 32 |
-
create_sliding_window_causal_mask,
|
| 33 |
-
)
|
| 34 |
-
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 35 |
-
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 36 |
-
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 37 |
-
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 38 |
-
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 39 |
-
from transformers.processing_utils import Unpack
|
| 40 |
-
from transformers.utils import ModelOutput, auto_docstring, logging
|
| 41 |
-
from transformers.utils.deprecation import deprecate_kwarg
|
| 42 |
-
from transformers.utils.generic import check_model_inputs
|
| 43 |
-
|
| 44 |
-
from .configuration_qwen3_tts_tokenizer_v2 import (
|
| 45 |
-
Qwen3TTSTokenizerV2Config,
|
| 46 |
-
Qwen3TTSTokenizerV2DecoderConfig,
|
| 47 |
-
)
|
| 48 |
-
|
| 49 |
-
logger = logging.get_logger(__name__)
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
@dataclass
|
| 53 |
-
@auto_docstring
|
| 54 |
-
class Qwen3TTSTokenizerV2EncoderOutput(ModelOutput):
|
| 55 |
-
r"""
|
| 56 |
-
audio_codes (`List[torch.LongTensor]`):
|
| 57 |
-
Discret code embeddings computed using `model.encode`, each tensor has shape (codes_length_i, num_quantizers).
|
| 58 |
-
"""
|
| 59 |
-
|
| 60 |
-
audio_codes: List[torch.LongTensor] = None
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
@dataclass
|
| 64 |
-
@auto_docstring
|
| 65 |
-
class Qwen3TTSTokenizerV2DecoderOutput(ModelOutput):
|
| 66 |
-
r"""
|
| 67 |
-
audio_values (`List[torch.FloatTensor]`):
|
| 68 |
-
Decoded audio values, obtained using the decoder part of Qwen3TTSTokenizerV1.
|
| 69 |
-
Each tensor has shape (segment_length_i).
|
| 70 |
-
"""
|
| 71 |
-
|
| 72 |
-
audio_values: List[torch.FloatTensor] = None
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def rotate_half(x):
|
| 76 |
-
"""Rotates half the hidden dims of the input."""
|
| 77 |
-
x1 = x[..., : x.shape[-1] // 2]
|
| 78 |
-
x2 = x[..., x.shape[-1] // 2 :]
|
| 79 |
-
return torch.cat((-x2, x1), dim=-1)
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 83 |
-
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 84 |
-
|
| 85 |
-
Args:
|
| 86 |
-
q (`torch.Tensor`): The query tensor.
|
| 87 |
-
k (`torch.Tensor`): The key tensor.
|
| 88 |
-
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 89 |
-
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 90 |
-
position_ids (`torch.Tensor`, *optional*):
|
| 91 |
-
Deprecated and unused.
|
| 92 |
-
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 93 |
-
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 94 |
-
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 95 |
-
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 96 |
-
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 97 |
-
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 98 |
-
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 99 |
-
Returns:
|
| 100 |
-
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 101 |
-
"""
|
| 102 |
-
cos = cos.unsqueeze(unsqueeze_dim)
|
| 103 |
-
sin = sin.unsqueeze(unsqueeze_dim)
|
| 104 |
-
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 105 |
-
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 106 |
-
return q_embed, k_embed
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 110 |
-
"""
|
| 111 |
-
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 112 |
-
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 113 |
-
"""
|
| 114 |
-
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 115 |
-
if n_rep == 1:
|
| 116 |
-
return hidden_states
|
| 117 |
-
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 118 |
-
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def eager_attention_forward(
|
| 122 |
-
module: nn.Module,
|
| 123 |
-
query: torch.Tensor,
|
| 124 |
-
key: torch.Tensor,
|
| 125 |
-
value: torch.Tensor,
|
| 126 |
-
attention_mask: Optional[torch.Tensor],
|
| 127 |
-
scaling: float,
|
| 128 |
-
dropout: float = 0.0,
|
| 129 |
-
**kwargs,
|
| 130 |
-
):
|
| 131 |
-
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 132 |
-
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 133 |
-
|
| 134 |
-
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 135 |
-
if attention_mask is not None:
|
| 136 |
-
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 137 |
-
attn_weights = attn_weights + causal_mask
|
| 138 |
-
|
| 139 |
-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 140 |
-
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 141 |
-
attn_output = torch.matmul(attn_weights, value_states)
|
| 142 |
-
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 143 |
-
|
| 144 |
-
return attn_output, attn_weights
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
@auto_docstring
|
| 148 |
-
class Qwen3TTSTokenizerV2DecoderPreTrainedModel(PreTrainedModel):
|
| 149 |
-
config: Qwen3TTSTokenizerV2DecoderConfig
|
| 150 |
-
base_model_prefix = "model"
|
| 151 |
-
supports_gradient_checkpointing = True
|
| 152 |
-
_skip_keys_device_placement = "past_key_values"
|
| 153 |
-
_supports_flash_attn = True
|
| 154 |
-
_supports_sdpa = True
|
| 155 |
-
_can_compile_fullgraph = False
|
| 156 |
-
_supports_attention_backend = True
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
class Qwen3TTSTokenizerV2CausalConvNet(nn.Module):
|
| 160 |
-
def __init__(
|
| 161 |
-
self,
|
| 162 |
-
in_channels,
|
| 163 |
-
out_channels,
|
| 164 |
-
kernel_size,
|
| 165 |
-
dilation=1,
|
| 166 |
-
stride=1,
|
| 167 |
-
groups=1,
|
| 168 |
-
):
|
| 169 |
-
super().__init__()
|
| 170 |
-
self.conv = nn.Conv1d(
|
| 171 |
-
in_channels,
|
| 172 |
-
out_channels,
|
| 173 |
-
kernel_size,
|
| 174 |
-
stride=stride,
|
| 175 |
-
dilation=dilation,
|
| 176 |
-
groups=groups,
|
| 177 |
-
)
|
| 178 |
-
self.stride = stride
|
| 179 |
-
self.kernel_size = (kernel_size - 1) * dilation + 1
|
| 180 |
-
self.dilation = dilation
|
| 181 |
-
self.padding = self.kernel_size - self.stride
|
| 182 |
-
|
| 183 |
-
def _get_extra_padding_for_conv1d(self, hidden_state: torch.Tensor) -> int:
|
| 184 |
-
length = hidden_state.shape[-1]
|
| 185 |
-
n_frames = (length - self.kernel_size + self.padding) / self.stride + 1
|
| 186 |
-
ideal_length = (math.ceil(n_frames) - 1) * self.stride + (self.kernel_size - self.padding)
|
| 187 |
-
return ideal_length - length
|
| 188 |
-
|
| 189 |
-
def forward(self, hidden_state):
|
| 190 |
-
extra_padding = self._get_extra_padding_for_conv1d(hidden_state)
|
| 191 |
-
hidden_state = F.pad(hidden_state, (self.padding, extra_padding), mode="constant", value=0)
|
| 192 |
-
return self.conv(hidden_state).contiguous()
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
class Qwen3TTSTokenizerV2CausalTransConvNet(nn.Module):
|
| 196 |
-
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
|
| 197 |
-
super().__init__()
|
| 198 |
-
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride)
|
| 199 |
-
|
| 200 |
-
pad = kernel_size - stride
|
| 201 |
-
self.left_pad = math.ceil(pad)
|
| 202 |
-
self.right_pad = pad = self.left_pad
|
| 203 |
-
|
| 204 |
-
def forward(self, hidden_state):
|
| 205 |
-
hidden_state = self.conv(hidden_state)
|
| 206 |
-
hidden_state = hidden_state[..., self.left_pad : hidden_state.shape[-1] - self.right_pad]
|
| 207 |
-
return hidden_state.contiguous()
|
| 208 |
-
|
| 209 |
-
|
class Qwen3TTSTokenizerV2ConvNeXtBlock(nn.Module):
    """ConvNeXt-style residual block on (B, C, T) tensors.

    Depthwise causal conv -> LayerNorm -> pointwise MLP (4x expansion, GELU)
    -> learnt per-channel gamma scaling -> residual connection.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Depthwise (groups=dim) causal convolution over the time axis.
        self.dwconv = Qwen3TTSTokenizerV2CausalConvNet(
            dim,
            dim,
            kernel_size=7,
            groups=dim,
            dilation=1,
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Residual-branch scaling, initialised near zero.
        self.gamma = nn.Parameter(1e-6 * torch.ones(dim))

    def forward(self, hidden_states):
        residual = hidden_states

        out = self.dwconv(hidden_states)
        # LayerNorm and the pointwise MLP operate channel-last.
        out = out.permute(0, 2, 1)
        out = self.pwconv2(self.act(self.pwconv1(self.norm(out))))
        out = self.gamma * out
        out = out.permute(0, 2, 1)

        return residual + out
| 243 |
-
|
| 244 |
-
|
class Qwen3TTSTokenizerV2DecoderRotatoryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) for the tokenizer decoder transformer.

    Produces the (cos, sin) tables consumed by ``apply_rotary_pos_emb`` in the
    attention layers.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # Frequency-initialisation function chosen by rope type (transformers helper).
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic-RoPE updates can restore the initial frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # x is used only for its device/dtype; position_ids is indexed as (batch, seq).
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # autocast is disabled on mps-incompatible device strings by falling back to "cpu".
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate frequencies so cos/sin cover the full head dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
| 279 |
-
|
| 280 |
-
|
class Qwen3TTSTokenizerV2DecoderAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx  # identifies this layer's slot in the KV cache
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Grouped-query attention: each KV head serves this many query heads.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # No Q/K normalisation in this decoder; identities keep call sites uniform.
        self.q_norm = nn.Identity()
        self.k_norm = nn.Identity()
        self.sliding_window = config.sliding_window

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply self-attention with RoPE and optional KV caching.

        Returns the projected attention output plus the attention weights
        (the latter depending on the selected attention backend).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, split into heads, and move the head axis before the time axis.
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured attention backend (eager / sdpa / flash...).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        # Merge heads back and project to the model width.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
| 354 |
-
|
| 355 |
-
|
class Qwen3TTSTokenizerV2DecoderMlp(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Activation looked up by name from the transformers registry.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
| 370 |
-
|
| 371 |
-
|
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3TTSTokenizerV2DecoderRMSNorm(nn.Module):
    """Root-mean-square layer normalisation (equivalent to T5LayerNorm).

    Scales the input by the inverse RMS of its last dimension, then by a
    learnt per-channel weight.
    """

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Normalise in float32 for stability, then cast back to the input dtype.
        orig_dtype = hidden_states.dtype
        hs = hidden_states.to(torch.float32)
        variance = hs.pow(2).mean(-1, keepdim=True)
        hs = hs * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hs.to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
| 391 |
-
|
| 392 |
-
|
class Qwen3TTSTokenizerV2DecoderLayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://huggingface.co/papers/2103.17239).
    This rescales diagonally the residual outputs close to 0, with a learnt scale.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        init_value = config.layer_scale_initial_scale
        # One learnt scale per channel, broadcast over batch/time.
        self.scale = nn.Parameter(torch.full((width,), init_value, requires_grad=True))

    def forward(self, x: torch.Tensor):
        return self.scale * x
| 406 |
-
|
| 407 |
-
|
class Qwen3TTSTokenizerV2DecoderTransformerLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: sliding-window self-attention and a gated
    MLP, each followed by LayerScale before the residual add."""

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3TTSTokenizerV2DecoderAttention(config, layer_idx)
        self.mlp = Qwen3TTSTokenizerV2DecoderMlp(config)
        self.input_layernorm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, config.rms_norm_eps)
        # Learnt scaling of each residual branch (LayerScale).
        self.self_attn_layer_scale = Qwen3TTSTokenizerV2DecoderLayerScale(config)
        self.mlp_layer_scale = Qwen3TTSTokenizerV2DecoderLayerScale(config)
        # Every layer selects the sliding-window mask from the model's mask mapping.
        self.attention_type = "sliding_attention"

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        # LayerScale on the attention branch before the residual add.
        hidden_states = residual + self.self_attn_layer_scale(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.mlp_layer_scale(hidden_states)

        return hidden_states
| 472 |
-
|
| 473 |
-
|
@auto_docstring
class Qwen3TTSTokenizerV2DecoderTransformerModel(Qwen3TTSTokenizerV2DecoderPreTrainedModel):
    """Small causal transformer applied to continuous latents, with linear
    projections between ``latent_dim`` and the transformer ``hidden_size``."""

    _can_record_outputs = {
        "hidden_states": Qwen3TTSTokenizerV2DecoderTransformerLayer,
        "attentions": Qwen3TTSTokenizerV2DecoderAttention,
    }

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Qwen3TTSTokenizerV2DecoderTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3TTSTokenizerV2DecoderRotatoryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        self.window_size = config.sliding_window

        # Bridge between the quantiser latent width and the transformer width.
        self.input_proj = nn.Linear(config.latent_dim, config.hidden_size)
        self.output_proj = nn.Linear(config.hidden_size, config.latent_dim)

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        cache_position=None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        # This model consumes continuous latents only; token ids are rejected.
        if input_ids is not None:
            raise ValueError("input_ids is not expected")
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # NOTE(review): unreachable given the checks above (input_ids is always
        # None and inputs_embeds is required); `embed_tokens` is not defined on
        # this model, so this branch would fail if it were ever reached.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        inputs_embeds = self.input_proj(inputs_embeds)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            # Continue positions from however many tokens the cache already holds.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        # Project back to the latent width expected by the conv decoder.
        hidden_states = self.output_proj(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
| 575 |
-
|
| 576 |
-
|
class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://huggingface.co/papers/2006.08195
    """

    def __init__(self, in_features, alpha=1.0):
        super().__init__()
        self.in_features = in_features

        # Log-scale parameters initialised to zero (so exp(.) == 1).
        # NOTE(review): `torch.zeros(...) * alpha` is still all zeros, so the
        # `alpha` argument has no effect on initialisation; kept as-is for
        # checkpoint and interface parity.
        self.alpha = Parameter(torch.zeros(in_features) * alpha)
        self.beta = Parameter(torch.zeros(in_features) * alpha)

        self.no_div_by_zero = 0.000000001

    def forward(self, hidden_states):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        # Reshape the per-channel parameters to broadcast over (B, C, T).
        alpha = torch.exp(self.alpha.unsqueeze(0).unsqueeze(-1))
        beta = torch.exp(self.beta.unsqueeze(0).unsqueeze(-1))
        periodic = torch.pow(torch.sin(hidden_states * alpha), 2)
        return hidden_states + (1.0 / (beta + self.no_div_by_zero)) * periodic
| 616 |
-
|
| 617 |
-
|
class Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(nn.Module):
    """Residual unit: SnakeBeta -> dilated causal conv (k=7) -> SnakeBeta ->
    pointwise causal conv (k=1), with a skip connection around the whole path."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()

        self.act1 = SnakeBeta(dim)
        self.conv1 = Qwen3TTSTokenizerV2CausalConvNet(dim, dim, kernel_size=7, dilation=dilation)
        self.act2 = SnakeBeta(dim)
        self.conv2 = Qwen3TTSTokenizerV2CausalConvNet(dim, dim, kernel_size=1)

    def forward(self, hidden_state):
        out = self.conv1(self.act1(hidden_state))
        out = self.conv2(self.act2(out))
        return out + hidden_state
| 635 |
-
|
| 636 |
-
|
class Qwen3TTSTokenizerV2DecoderDecoderBlock(Qwen3TTSTokenizerV2DecoderPreTrainedModel):
    """One upsampling stage of the waveform decoder.

    SnakeBeta activation, a transposed causal conv that upsamples by
    ``config.upsample_rates[layer_idx]`` while halving the channel width,
    followed by three residual units with dilations 1, 3 and 9.
    """

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx):
        super().__init__(config)
        # Channel width halves at every stage.
        in_dim = config.decoder_dim // 2**layer_idx
        out_dim = config.decoder_dim // 2 ** (layer_idx + 1)
        upsample_rate = config.upsample_rates[layer_idx]

        modules = [
            SnakeBeta(in_dim),
            Qwen3TTSTokenizerV2CausalTransConvNet(in_dim, out_dim, 2 * upsample_rate, upsample_rate),
        ]
        modules.extend(
            Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(out_dim, dilation) for dilation in (1, 3, 9)
        )
        self.block = nn.ModuleList(modules)

    def forward(self, hidden):
        for module in self.block:
            hidden = module(hidden)
        return hidden
| 658 |
-
|
| 659 |
-
|
class EuclideanCodebook(nn.Module):
    """Codebook stored as running sums: the effective embedding table is
    ``embedding_sum / cluster_usage`` with usage clamped by ``epsilon``."""

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        epsilon: float = 1e-5,
    ):
        super().__init__()
        self.dim = dim
        self.codebook_size = codebook_size
        self.epsilon = epsilon

        self.cluster_usage = nn.Parameter(torch.ones(codebook_size))
        self.embedding_sum = nn.Parameter(torch.zeros(codebook_size, dim))

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        # Normalise the running sums into the actual embedding table, then
        # look the codes up.
        usage = self.cluster_usage.clamp(min=self.epsilon)
        embedding = self.embedding_sum / usage[:, None]
        return F.embedding(codes, embedding)
| 679 |
-
|
| 680 |
-
|
class VectorQuantization(nn.Module):
    """Single vector-quantisation layer.

    Decodes codebook indices back to embeddings, projecting from
    ``codebook_dim`` to ``dim`` only when the two differ.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: Optional[int] = None,
        epsilon: float = 1e-5,
    ):
        super().__init__()
        codebook_dim = dim if codebook_dim is None else codebook_dim

        # A projection is only needed when codebook and model dims differ.
        if codebook_dim != dim:
            self.project_out = nn.Linear(codebook_dim, dim)
        else:
            self.project_out = nn.Identity()
        self.epsilon = epsilon
        self._codebook = EuclideanCodebook(
            dim=codebook_dim,
            codebook_size=codebook_size,
            epsilon=epsilon,
        )
        self.codebook_size = codebook_size

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        embedded = self._codebook.decode(codes)
        projected = self.project_out(embedded)
        # Back to channel-first (B, C, T) layout.
        return projected.transpose(1, 2)
| 711 |
-
|
| 712 |
-
|
class ResidualVectorQuantization(nn.Module):
    """Stack of VQ layers applied residually; decoding sums every layer's
    contribution."""

    def __init__(self, *, num_quantizers: int, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            VectorQuantization(**kwargs) for _ in range(num_quantizers)
        )

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        # `codes` is iterated along dim 0: one index tensor per quantizer.
        quantized = torch.zeros([1], device=codes.device)[0]
        for idx, layer_codes in enumerate(codes):
            layer = self.layers[idx]
            assert isinstance(layer, VectorQuantization)
            quantized = quantized + layer.decode(layer_codes)
        return quantized
| 727 |
-
|
| 728 |
-
|
class ResidualVectorQuantizer(nn.Module):
    """Residual vector quantizer: optional 1x1-conv projections on the input
    and output around a stack of ``n_q`` VQ layers."""

    def __init__(
        self,
        dimension: int = 128,
        input_dimension: Optional[int] = None,
        output_dimension: Optional[int] = None,
        n_q: int = 8,
        q_dropout: bool = False,
        no_quantization_rate: float = 0.0,
        bins: int = 1024,
        decay: float = 0.99,
        force_projection: bool = False,
    ):
        super().__init__()
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.no_quantization_rate = no_quantization_rate
        self.dimension = dimension
        self.input_dimension = input_dimension or dimension
        self.output_dimension = output_dimension or dimension
        self.bins = bins
        self.decay = decay
        self.input_proj: torch.nn.Module
        self.output_proj: torch.nn.Module
        # 1x1 convolutions bridge external widths to the quantizer width; they
        # collapse to Identity when widths match and projection is not forced.
        if self.input_dimension == self.dimension and not force_projection:
            self.input_proj = torch.nn.Identity()
        else:
            self.input_proj = torch.nn.Conv1d(
                self.input_dimension, self.dimension, 1, bias=False
            )
        if self.output_dimension == self.dimension and not force_projection:
            self.output_proj = torch.nn.Identity()
        else:
            self.output_proj = torch.nn.Conv1d(
                self.dimension, self.output_dimension, 1, bias=False
            )
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
        )

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        # (B, K, T) -> (K, B, T) so each VQ layer receives its own code row.
        per_layer_codes = codes.transpose(0, 1)
        quantized = self.vq.decode(per_layer_codes)
        return self.output_proj(quantized)
| 777 |
-
|
| 778 |
-
|
class SplitResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer with separate projections for the first quantizer and the rest.

    Args:
        n_q (int): Number of residual vector quantizers used.
        n_semantic_q (int): Number of residual vector quantizers used for the semantic quantizer.
        **kwargs: Arguments to the constructor of `ResidualVectorQuantizer` that are shared between both.
    """

    def __init__(self, *, n_q: int = 8, n_q_semantic: int = 1, **kwargs):
        super().__init__()
        assert n_q > n_q_semantic, (
            f"Number of quantizers {n_q} must be larger "
            f"than the number of semantic quantizers {n_q_semantic}."
        )
        self.max_n_q = n_q
        self.n_q_semantic = n_q_semantic
        self.n_q_acoustic = n_q - n_q_semantic
        # Quantizer dropout only ever applies to the acoustic quantizer.
        acoustic_dropout = kwargs.pop("q_dropout", False)
        self.rvq_first = ResidualVectorQuantizer(
            n_q=n_q_semantic, force_projection=True, q_dropout=False, **kwargs
        )
        self.rvq_rest = ResidualVectorQuantizer(
            n_q=n_q - n_q_semantic,
            force_projection=True,
            q_dropout=acoustic_dropout,
            **kwargs,
        )

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T], with T frames, K nb of codebooks.
        quantized = self.rvq_first.decode(codes[:, : self.n_q_semantic])
        if codes.shape[1] > self.n_q_semantic:
            quantized += self.rvq_rest.decode(codes[:, self.n_q_semantic :])
        return quantized
| 821 |
-
|
| 822 |
-
|
class Qwen3TTSTokenizerV2Decoder(Qwen3TTSTokenizerV2DecoderPreTrainedModel):
    """Decodes RVQ codes into a waveform: dequantise, pre-conv, causal
    pre-transformer, ConvNeXt upsampling stages, then the conv decoder stack."""

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig):
        super().__init__(config)
        # Total samples produced per input code frame; note the `+` here
        # concatenates the two rate lists before taking the product.
        self.total_upsample = np.prod(config.upsample_rates + config.upsampling_ratios)
        self.pre_transformer = Qwen3TTSTokenizerV2DecoderTransformerModel._from_config(config)

        # 1 semantic + (num_quantizers - 1) acoustic codebooks.
        self.quantizer = SplitResidualVectorQuantizer(
            dimension=config.codebook_dim // 2,
            n_q=config.num_quantizers,
            n_q_semantic=1,
            bins=config.codebook_size,
            input_dimension=config.codebook_dim,
            output_dimension=config.codebook_dim,
        )

        self.pre_conv = Qwen3TTSTokenizerV2CausalConvNet(
            config.codebook_dim,
            config.latent_dim,
            kernel_size=3,
        )

        # One (transposed conv + ConvNeXt block) stage per upsampling ratio.
        upsample = []
        for factor in config.upsampling_ratios:
            upsample.append(
                nn.ModuleList(
                    [
                        Qwen3TTSTokenizerV2CausalTransConvNet(config.latent_dim, config.latent_dim, factor, factor),
                        Qwen3TTSTokenizerV2ConvNeXtBlock(config.latent_dim),
                    ]
                )
            )
        self.upsample = nn.ModuleList(upsample)

        # Decoder head: input conv, channel-halving upsample blocks, then
        # SnakeBeta + a final conv down to a single waveform channel.
        decoder = [Qwen3TTSTokenizerV2CausalConvNet(config.latent_dim, config.decoder_dim, 7)]
        for i in range(len(config.upsample_rates)):
            decoder.append(Qwen3TTSTokenizerV2DecoderDecoderBlock(config, i))
        output_dim = config.decoder_dim // 2 ** len(config.upsample_rates)
        decoder += [
            SnakeBeta(output_dim),
            Qwen3TTSTokenizerV2CausalConvNet(output_dim, 1, 7),
        ]
        self.decoder = nn.ModuleList(decoder)

        self.post_init()

    def forward(self, codes):
        """Decode codes shaped (batch, num_quantizers, frames) into a waveform
        clamped to [-1, 1]."""
        if codes.shape[1] != self.config.num_quantizers:
            raise ValueError(f"Expected {self.config.num_quantizers} layer of codes, got {codes.shape[1]}")

        hidden = self.quantizer.decode(codes)
        # Conv stacks run channel-first; the transformer runs channel-last.
        hidden = self.pre_conv(hidden).transpose(1, 2)

        hidden = self.pre_transformer(inputs_embeds=hidden).last_hidden_state
        hidden = hidden.permute(0, 2, 1)
        for blocks in self.upsample:
            for block in blocks:
                hidden = block(hidden)
        wav = hidden
        for block in self.decoder:
            wav = block(wav)
        return wav.clamp(min=-1, max=1)

    def chunked_decode(self, codes, chunk_size=300, left_context_size=25):
        """Decode long code sequences chunk by chunk to bound memory use.

        Each chunk is decoded with up to ``left_context_size`` frames of left
        context, and the context's worth of output samples is trimmed before
        concatenation so chunk boundaries line up with a full decode.
        """
        wavs = []
        start_index = 0
        while start_index < codes.shape[-1]:
            end_index = min(start_index + chunk_size, codes.shape[-1])
            # Shrink the context at the very beginning of the sequence.
            context_size = left_context_size if start_index - left_context_size > 0 else start_index
            codes_chunk = codes[..., start_index - context_size : end_index]
            wav_chunk = self(codes_chunk)
            # Drop the samples that correspond to the context frames.
            wavs.append(wav_chunk[..., context_size * self.total_upsample :])
            start_index = end_index
        return torch.cat(wavs, dim=-1)
| 896 |
-
|
| 897 |
-
|
class Qwen3TTSTokenizerV2Encoder(MimiModel):
    """Encoder-only wrapper around ``MimiModel``: the decoder-side submodules
    are dropped so only the encode path remains."""

    def __init__(self, config: MimiConfig):
        super().__init__(config)
        self.config = config

        # Discard the Mimi decoder branch; this model is only used for encoding.
        self.upsample = None
        self.decoder_transformer = None
        self.decoder = None

        self.post_init()
| 908 |
-
|
| 909 |
-
|
@auto_docstring
class Qwen3TTSTokenizerV2PreTrainedModel(PreTrainedModel):
    """Base class declaring weight-init and attention-backend capabilities for
    the Qwen3 TTS tokenizer (v2) models."""

    config: Qwen3TTSTokenizerV2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Cache objects manage their own device placement.
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
| 920 |
-
|
| 921 |
-
|
| 922 |
-
@auto_docstring(
|
| 923 |
-
custom_intro="""
|
| 924 |
-
The Qwen3TTSTokenizerV2 model.
|
| 925 |
-
"""
|
| 926 |
-
)
|
| 927 |
-
class Qwen3TTSTokenizerV2Model(Qwen3TTSTokenizerV2PreTrainedModel):
|
| 928 |
-
def __init__(self, config: Qwen3TTSTokenizerV2Config):
|
| 929 |
-
super().__init__(config)
|
| 930 |
-
self.config = config
|
| 931 |
-
|
| 932 |
-
self.encoder_valid_num_quantizers = config.encoder_valid_num_quantizers
|
| 933 |
-
|
| 934 |
-
self.input_sample_rate = config.input_sample_rate
|
| 935 |
-
self.output_sample_rate = config.output_sample_rate
|
| 936 |
-
|
| 937 |
-
self.decode_upsample_rate = config.decode_upsample_rate
|
| 938 |
-
self.encode_downsample_rate = config.encode_downsample_rate
|
| 939 |
-
|
| 940 |
-
self.encoder = Qwen3TTSTokenizerV2Encoder._from_config(self.config.encoder_config)
|
| 941 |
-
self.decoder = Qwen3TTSTokenizerV2Decoder._from_config(self.config.decoder_config)
|
| 942 |
-
|
| 943 |
-
self.post_init()
|
| 944 |
-
|
| 945 |
-
def get_model_type(self):
|
| 946 |
-
return self.config.model_type
|
| 947 |
-
|
| 948 |
-
def get_input_sample_rate(self):
|
| 949 |
-
return self.input_sample_rate
|
| 950 |
-
|
| 951 |
-
def get_output_sample_rate(self):
|
| 952 |
-
return self.output_sample_rate
|
| 953 |
-
|
| 954 |
-
def get_encode_downsample_rate(self):
|
| 955 |
-
return self.encode_downsample_rate
|
| 956 |
-
|
| 957 |
-
def get_decode_upsample_rate(self):
|
| 958 |
-
return self.decode_upsample_rate
|
| 959 |
-
|
| 960 |
-
def encode(
|
| 961 |
-
self,
|
| 962 |
-
input_values: torch.Tensor,
|
| 963 |
-
padding_mask: Optional[torch.Tensor] = None,
|
| 964 |
-
return_dict: Optional[bool] = None,
|
| 965 |
-
) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], Qwen3TTSTokenizerV2EncoderOutput]:
|
| 966 |
-
"""
|
| 967 |
-
Encodes the input audio waveform into discrete codes.
|
| 968 |
-
|
| 969 |
-
Args:
|
| 970 |
-
input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
|
| 971 |
-
Float values of the input audio waveform.
|
| 972 |
-
padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
|
| 973 |
-
Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
|
| 974 |
-
for *masked*.
|
| 975 |
-
return_dict (`bool`, *optional*):
|
| 976 |
-
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 977 |
-
"""
|
| 978 |
-
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 979 |
-
|
| 980 |
-
encoded_frames = self.encoder.encode(input_values=input_values.unsqueeze(1),
|
| 981 |
-
return_dict=True)
|
| 982 |
-
audio_codes = encoded_frames.audio_codes[:, :self.encoder_valid_num_quantizers]
|
| 983 |
-
audio_codes = [code[..., :-(-mask.sum() // self.encode_downsample_rate)].transpose(0, 1) for code, mask in zip(audio_codes, padding_mask)]
|
| 984 |
-
|
| 985 |
-
if not return_dict:
|
| 986 |
-
return (
|
| 987 |
-
audio_codes,
|
| 988 |
-
)
|
| 989 |
-
|
| 990 |
-
return Qwen3TTSTokenizerV2EncoderOutput(audio_codes)
|
| 991 |
-
|
| 992 |
-
def decode(
|
| 993 |
-
self,
|
| 994 |
-
audio_codes: torch.Tensor,
|
| 995 |
-
return_dict: Optional[bool] = None,
|
| 996 |
-
) -> Union[tuple[torch.Tensor, torch.Tensor], Qwen3TTSTokenizerV2DecoderOutput]:
|
| 997 |
-
"""
|
| 998 |
-
Decodes the given frames into an output audio waveform.
|
| 999 |
-
|
| 1000 |
-
Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
|
| 1001 |
-
trimmed.
|
| 1002 |
-
|
| 1003 |
-
Args:
|
| 1004 |
-
audio_codes (`torch.LongTensor` of shape `(batch_size, codes_length, num_quantizers)`, *optional*):
|
| 1005 |
-
Discret code embeddings computed using `model.encode`.
|
| 1006 |
-
return_dict (`bool`, *optional*):
|
| 1007 |
-
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1008 |
-
|
| 1009 |
-
"""
|
| 1010 |
-
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 1011 |
-
|
| 1012 |
-
audio_values = self.decoder.chunked_decode(audio_codes.transpose(1, 2)).squeeze(1)
|
| 1013 |
-
|
| 1014 |
-
audio_lengths = (audio_codes[..., 0] > 0).sum(1) * self.decode_upsample_rate
|
| 1015 |
-
audio_values = [a[:l] for a, l in zip(audio_values, audio_lengths)]
|
| 1016 |
-
|
| 1017 |
-
if not return_dict:
|
| 1018 |
-
return (
|
| 1019 |
-
audio_values,
|
| 1020 |
-
)
|
| 1021 |
-
|
| 1022 |
-
return Qwen3TTSTokenizerV2DecoderOutput(audio_values)
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
__all__ = ["Qwen3TTSTokenizerV2Model", "Qwen3TTSTokenizerV2PreTrainedModel"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py
DELETED
|
@@ -1,332 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
"""Qwen3TTSTokenizerV1 model configuration"""
|
| 16 |
-
|
| 17 |
-
from transformers.configuration_utils import PretrainedConfig
|
| 18 |
-
from transformers.utils import logging
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
logger = logging.get_logger(__name__)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class Qwen3TTSTokenizerV1DecoderDiTConfig(PretrainedConfig):
|
| 25 |
-
r"""
|
| 26 |
-
This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavDiT.
|
| 27 |
-
It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens.
|
| 28 |
-
|
| 29 |
-
Args:
|
| 30 |
-
hidden_size (`int`, *optional*, defaults to 1024):
|
| 31 |
-
The dimension of the model.
|
| 32 |
-
num_hidden_layers (`int`, *optional*, defaults to 22):
|
| 33 |
-
The number of transformer blocks in the DiT model.
|
| 34 |
-
num_attention_heads (`int`, *optional*, defaults to 16):
|
| 35 |
-
The number of attention heads in each transformer block.
|
| 36 |
-
ff_mult (`int`, *optional*, defaults to 2):
|
| 37 |
-
The multiplier for the feedforward layer in each transformer block.
|
| 38 |
-
emb_dim (`int`, *optional*, defaults to 512):
|
| 39 |
-
The dimension of the embedding layer.
|
| 40 |
-
head_dim (`int`, *optional*, defaults to 64):
|
| 41 |
-
The dimension of each attention head.
|
| 42 |
-
repeats (`int`, *optional*, defaults to 2):
|
| 43 |
-
The number of times the codec embeddings are repeated.
|
| 44 |
-
num_embeds (`int`, *optional*, defaults to 8193):
|
| 45 |
-
The number of unique embeddings in the codec.
|
| 46 |
-
mel_dim (`int`, *optional*, defaults to 80):
|
| 47 |
-
The dimension of the mel-spectrogram.
|
| 48 |
-
dropout (`float`, *optional*, defaults to 0.1):
|
| 49 |
-
The dropout rate for the transformer blocks.
|
| 50 |
-
|
| 51 |
-
enc_emb_dim (`int`, *optional*, defaults to 192):
|
| 52 |
-
The dimension of the pre-trained speaker embedding.
|
| 53 |
-
enc_dim (`int`, *optional*, defaults to 128):
|
| 54 |
-
The dimension of the encoder output.
|
| 55 |
-
enc_channels (`list[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`):
|
| 56 |
-
A list of output channels for each TDNN/SERes2Net layer in the encoder.
|
| 57 |
-
enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
|
| 58 |
-
A list of kernel sizes for each layer in the encoder.
|
| 59 |
-
enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
|
| 60 |
-
A list of dilations for each layer in the encoder.
|
| 61 |
-
enc_attention_channels (`int`, *optional*, defaults to 64):
|
| 62 |
-
The number of attention channels in the SqueezeExcitationBlock.
|
| 63 |
-
enc_res2net_scale (`int`, *optional*, defaults to 2):
|
| 64 |
-
The scale of the Res2Net block in the encoder.
|
| 65 |
-
enc_se_channels (`int`, *optional*, defaults to 64):
|
| 66 |
-
The number of output channels after squeeze in the SqueezeExcitationBlock.
|
| 67 |
-
"""
|
| 68 |
-
|
| 69 |
-
model_type = "qwen3_tts_tokenizer_v1_decoder_dit"
|
| 70 |
-
|
| 71 |
-
def __init__(
|
| 72 |
-
self,
|
| 73 |
-
hidden_size=1024,
|
| 74 |
-
num_hidden_layers=22,
|
| 75 |
-
num_attention_heads=16,
|
| 76 |
-
ff_mult=2,
|
| 77 |
-
emb_dim=512,
|
| 78 |
-
head_dim=64,
|
| 79 |
-
rope_theta=10000.0,
|
| 80 |
-
max_position_embeddings=32768,
|
| 81 |
-
block_size=24,
|
| 82 |
-
look_ahead_layers=[10],
|
| 83 |
-
look_backward_layers=[0, 20],
|
| 84 |
-
repeats=2,
|
| 85 |
-
num_embeds=8193,
|
| 86 |
-
mel_dim=80,
|
| 87 |
-
dropout=0.1,
|
| 88 |
-
enc_emb_dim=192,
|
| 89 |
-
enc_dim=128,
|
| 90 |
-
enc_channels=[256, 256, 256, 256, 768],
|
| 91 |
-
enc_kernel_sizes=[5, 3, 3, 3, 1],
|
| 92 |
-
enc_dilations=[1, 2, 3, 4, 1],
|
| 93 |
-
enc_attention_channels=64,
|
| 94 |
-
enc_res2net_scale=2,
|
| 95 |
-
enc_se_channels=64,
|
| 96 |
-
**kwargs,
|
| 97 |
-
):
|
| 98 |
-
self.hidden_size = hidden_size
|
| 99 |
-
self.num_hidden_layers = num_hidden_layers
|
| 100 |
-
self.num_attention_heads = num_attention_heads
|
| 101 |
-
self.ff_mult = ff_mult
|
| 102 |
-
self.emb_dim = emb_dim
|
| 103 |
-
self.head_dim = head_dim
|
| 104 |
-
self.rope_theta = rope_theta
|
| 105 |
-
self.max_position_embeddings = max_position_embeddings
|
| 106 |
-
self.block_size = block_size
|
| 107 |
-
self.look_ahead_layers = look_ahead_layers
|
| 108 |
-
self.look_backward_layers = look_backward_layers
|
| 109 |
-
self.repeats = repeats
|
| 110 |
-
self.num_embeds = num_embeds
|
| 111 |
-
self.mel_dim = mel_dim
|
| 112 |
-
self.dropout = dropout
|
| 113 |
-
self.enc_emb_dim = enc_emb_dim
|
| 114 |
-
self.enc_dim = enc_dim
|
| 115 |
-
self.enc_channels = enc_channels
|
| 116 |
-
self.enc_kernel_sizes = enc_kernel_sizes
|
| 117 |
-
self.enc_dilations = enc_dilations
|
| 118 |
-
self.enc_attention_channels = enc_attention_channels
|
| 119 |
-
self.enc_res2net_scale = enc_res2net_scale
|
| 120 |
-
self.enc_se_channels = enc_se_channels
|
| 121 |
-
super().__init__(**kwargs)
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
class Qwen3TTSTokenizerV1DecoderBigVGANConfig(PretrainedConfig):
|
| 125 |
-
r"""
|
| 126 |
-
This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavBigVGAN module.
|
| 127 |
-
It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms.
|
| 128 |
-
|
| 129 |
-
Args:
|
| 130 |
-
mel_dim (`int`, *optional*, defaults to 80):
|
| 131 |
-
The dimension of the mel-spectrogram.
|
| 132 |
-
upsample_initial_channel (`int`, *optional*, defaults to 1536):
|
| 133 |
-
The number of channels in the initial upsampling layer.
|
| 134 |
-
resblock_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`):
|
| 135 |
-
A list of kernel sizes for each residual block.
|
| 136 |
-
resblock_dilation_sizes (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
|
| 137 |
-
A list of dilation sizes for each residual block.
|
| 138 |
-
upsample_rates (`list[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`):
|
| 139 |
-
A list of upsampling rates for each upsampling layer.
|
| 140 |
-
upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`):
|
| 141 |
-
A list of kernel sizes for each upsampling layer.
|
| 142 |
-
"""
|
| 143 |
-
|
| 144 |
-
model_type = "qwen3_tts_tokenizer_v1_decoder_bigvgan"
|
| 145 |
-
|
| 146 |
-
def __init__(
|
| 147 |
-
self,
|
| 148 |
-
mel_dim=80,
|
| 149 |
-
upsample_initial_channel=1536,
|
| 150 |
-
resblock_kernel_sizes=[3, 7, 11],
|
| 151 |
-
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
| 152 |
-
upsample_rates=[5, 3, 2, 2, 2, 2],
|
| 153 |
-
upsample_kernel_sizes=[11, 7, 4, 4, 4, 4],
|
| 154 |
-
**kwargs,
|
| 155 |
-
):
|
| 156 |
-
self.mel_dim = mel_dim
|
| 157 |
-
self.upsample_initial_channel = upsample_initial_channel
|
| 158 |
-
self.resblock_kernel_sizes = resblock_kernel_sizes
|
| 159 |
-
self.resblock_dilation_sizes = resblock_dilation_sizes
|
| 160 |
-
self.upsample_rates = upsample_rates
|
| 161 |
-
self.upsample_kernel_sizes = upsample_kernel_sizes
|
| 162 |
-
super().__init__(**kwargs)
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
class Qwen3TTSTokenizerV1DecoderConfig(PretrainedConfig):
|
| 166 |
-
r"""
|
| 167 |
-
This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV1DecoderConfig`].
|
| 168 |
-
|
| 169 |
-
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 170 |
-
documentation from [`PretrainedConfig`] for more information.
|
| 171 |
-
|
| 172 |
-
Args:
|
| 173 |
-
dit_config ([`DiT_Args`], *optional*):
|
| 174 |
-
Configuration class for the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms.
|
| 175 |
-
bigvgan_config ([`BigVGAN_Args`], *optional*):
|
| 176 |
-
Configuration class for the BigVGAN module responsible for converting mel-spectrograms to waveforms.
|
| 177 |
-
"""
|
| 178 |
-
|
| 179 |
-
model_type = "qwen3_tts_tokenizer_v1_decoder"
|
| 180 |
-
sub_configs = {
|
| 181 |
-
"dit_config": Qwen3TTSTokenizerV1DecoderDiTConfig,
|
| 182 |
-
"bigvgan_config": Qwen3TTSTokenizerV1DecoderBigVGANConfig,
|
| 183 |
-
}
|
| 184 |
-
|
| 185 |
-
def __init__(self, dit_config=None, bigvgan_config=None, **kwargs):
|
| 186 |
-
if dit_config is None:
|
| 187 |
-
dit_config = {}
|
| 188 |
-
if bigvgan_config is None:
|
| 189 |
-
bigvgan_config = {}
|
| 190 |
-
self.dit_config = Qwen3TTSTokenizerV1DecoderDiTConfig(**dit_config)
|
| 191 |
-
self.bigvgan_config = Qwen3TTSTokenizerV1DecoderBigVGANConfig(**bigvgan_config)
|
| 192 |
-
super().__init__(**kwargs)
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
class Qwen3TTSTokenizerV1EncoderConfig(PretrainedConfig):
|
| 196 |
-
r"""
|
| 197 |
-
This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1 Encoder.
|
| 198 |
-
|
| 199 |
-
The encoder typically takes mel-spectrogram features and produces high-level audio representations, then (optionally)
|
| 200 |
-
applies an Audio-VQ module (e.g., GRVQ) to discretize continuous representations into codes.
|
| 201 |
-
|
| 202 |
-
Args:
|
| 203 |
-
n_mels (`int`, *optional*, defaults to 128):
|
| 204 |
-
Number of mel bins in the input mel-spectrogram.
|
| 205 |
-
n_ctx (`int`, *optional*, defaults to 1500):
|
| 206 |
-
Maximum input sequence length (in frames/tokens) for the encoder.
|
| 207 |
-
n_state (`int`, *optional*, defaults to 1280):
|
| 208 |
-
Hidden size (model dimension) of the encoder transformer.
|
| 209 |
-
n_head (`int`, *optional*, defaults to 20):
|
| 210 |
-
Number of attention heads in each transformer layer.
|
| 211 |
-
n_layer (`int`, *optional*, defaults to 32):
|
| 212 |
-
Number of transformer layers.
|
| 213 |
-
n_window (`int`, *optional*, defaults to 100):
|
| 214 |
-
Window size used by the model for local attention / chunking (implementation-dependent).
|
| 215 |
-
output_dim (`int`, *optional*, defaults to 3584):
|
| 216 |
-
Output feature dimension produced by the encoder head (before/after projection, implementation-dependent).
|
| 217 |
-
|
| 218 |
-
grad_checkpointing (`bool`, *optional*, defaults to `False`):
|
| 219 |
-
Whether to enable gradient checkpointing to reduce memory usage during training.
|
| 220 |
-
enable_mp (`bool`, *optional*, defaults to `False`):
|
| 221 |
-
Whether to enable model parallel features (implementation-dependent).
|
| 222 |
-
audio_sequence_parallel (`bool`, *optional*, defaults to `False`):
|
| 223 |
-
Whether to enable sequence parallelism for audio branch (implementation-dependent).
|
| 224 |
-
|
| 225 |
-
audio_vq_type (`str`, *optional*, defaults to `"GRVQ"`):
|
| 226 |
-
Type of audio vector-quantization module. Common choices: `"GRVQ"`, `"RVQ"`, etc.
|
| 227 |
-
audio_vq_layers (`int`, *optional*, defaults to 6):
|
| 228 |
-
Number of VQ layers / quantizers (e.g., number of residual quantizers for RVQ/GRVQ-like designs).
|
| 229 |
-
audio_vq_codebook_size (`int`, *optional*, defaults to 32768):
|
| 230 |
-
Size of each codebook (number of entries).
|
| 231 |
-
audio_vq_codebook_dim (`int`, *optional*, defaults to 1280):
|
| 232 |
-
Dimension of codebook vectors (often equals encoder hidden size).
|
| 233 |
-
audio_vq_pe (`bool`, *optional*, defaults to `True`):
|
| 234 |
-
Whether to use positional encoding (or position embeddings) inside the VQ module.
|
| 235 |
-
audio_vq_ds_rate (`int`, *optional*, defaults to 2):
|
| 236 |
-
Downsampling rate applied before VQ (e.g., temporal downsample factor).
|
| 237 |
-
"""
|
| 238 |
-
|
| 239 |
-
model_type = "qwen3_tts_tokenizer_v1_encoder"
|
| 240 |
-
|
| 241 |
-
def __init__(
|
| 242 |
-
self,
|
| 243 |
-
n_mels=128,
|
| 244 |
-
n_ctx=1500,
|
| 245 |
-
n_state=1280,
|
| 246 |
-
n_head=20,
|
| 247 |
-
n_layer=32,
|
| 248 |
-
n_window=100,
|
| 249 |
-
output_dim=3584,
|
| 250 |
-
grad_checkpointing=False,
|
| 251 |
-
enable_mp=False,
|
| 252 |
-
audio_sequence_parallel=False,
|
| 253 |
-
audio_vq_type="GRVQ",
|
| 254 |
-
audio_vq_layers=6,
|
| 255 |
-
audio_vq_codebook_size=32768,
|
| 256 |
-
audio_vq_codebook_dim=1280,
|
| 257 |
-
audio_vq_pe=True,
|
| 258 |
-
audio_vq_ds_rate=2,
|
| 259 |
-
**kwargs,
|
| 260 |
-
):
|
| 261 |
-
super().__init__(**kwargs)
|
| 262 |
-
self.n_mels = n_mels
|
| 263 |
-
self.n_ctx = n_ctx
|
| 264 |
-
self.n_state = n_state
|
| 265 |
-
self.n_head = n_head
|
| 266 |
-
self.n_layer = n_layer
|
| 267 |
-
self.n_window = n_window
|
| 268 |
-
self.output_dim = output_dim
|
| 269 |
-
self.grad_checkpointing = grad_checkpointing
|
| 270 |
-
self.enable_mp = enable_mp
|
| 271 |
-
self.audio_sequence_parallel = audio_sequence_parallel
|
| 272 |
-
self.audio_vq_type = audio_vq_type
|
| 273 |
-
self.audio_vq_layers = audio_vq_layers
|
| 274 |
-
self.audio_vq_codebook_size = audio_vq_codebook_size
|
| 275 |
-
self.audio_vq_codebook_dim = audio_vq_codebook_dim
|
| 276 |
-
self.audio_vq_pe = audio_vq_pe
|
| 277 |
-
self.audio_vq_ds_rate = audio_vq_ds_rate
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
class Qwen3TTSTokenizerV1Config(PretrainedConfig):
|
| 281 |
-
"""
|
| 282 |
-
This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV1Config`]. It is used to instantiate a Qwen3TTSTokenizerV1Model
|
| 283 |
-
model according to the specified sub-models configurations, defining the model architecture.
|
| 284 |
-
|
| 285 |
-
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 286 |
-
documentation from [`PretrainedConfig`] for more information.
|
| 287 |
-
|
| 288 |
-
Args:
|
| 289 |
-
encoder_config (`dict`, *optional*): Configuration of the underlying encoder sub-model.
|
| 290 |
-
decoder_config (`dict`, *optional*): Configuration of the underlying decoder sub-model.
|
| 291 |
-
"""
|
| 292 |
-
|
| 293 |
-
model_type = "qwen3_tts_tokenizer_25hz"
|
| 294 |
-
sub_configs = {
|
| 295 |
-
"encoder_config": Qwen3TTSTokenizerV1EncoderConfig,
|
| 296 |
-
"decoder_config": Qwen3TTSTokenizerV1DecoderConfig,
|
| 297 |
-
}
|
| 298 |
-
|
| 299 |
-
def __init__(
|
| 300 |
-
self,
|
| 301 |
-
encoder_config=None,
|
| 302 |
-
decoder_config=None,
|
| 303 |
-
input_sample_rate=24000,
|
| 304 |
-
output_sample_rate=24000,
|
| 305 |
-
decode_upsample_rate=1920,
|
| 306 |
-
encode_downsample_rate=1920,
|
| 307 |
-
**kwargs,
|
| 308 |
-
):
|
| 309 |
-
super().__init__(**kwargs)
|
| 310 |
-
if encoder_config is None:
|
| 311 |
-
encoder_config = {}
|
| 312 |
-
logger.info("encoder_config is None. Initializing encoder with default values")
|
| 313 |
-
if decoder_config is None:
|
| 314 |
-
decoder_config = {}
|
| 315 |
-
logger.info("decoder_config is None. Initializing decoder with default values")
|
| 316 |
-
|
| 317 |
-
self.encoder_config = Qwen3TTSTokenizerV1EncoderConfig(**encoder_config)
|
| 318 |
-
self.decoder_config = Qwen3TTSTokenizerV1DecoderConfig(**decoder_config)
|
| 319 |
-
|
| 320 |
-
self.input_sample_rate = input_sample_rate
|
| 321 |
-
self.output_sample_rate = output_sample_rate
|
| 322 |
-
self.decode_upsample_rate = decode_upsample_rate
|
| 323 |
-
self.encode_downsample_rate = encode_downsample_rate
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
__all__ = [
|
| 327 |
-
"Qwen3TTSTokenizerV1Config",
|
| 328 |
-
"Qwen3TTSTokenizerV1EncoderConfig",
|
| 329 |
-
"Qwen3TTSTokenizerV1DecoderConfig",
|
| 330 |
-
"Qwen3TTSTokenizerV1DecoderBigVGANConfig",
|
| 331 |
-
"Qwen3TTSTokenizerV1DecoderDiTConfig"
|
| 332 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py
DELETED
|
@@ -1,1528 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
"""PyTorch Qwen3TTSTokenizerV1 model."""
|
| 16 |
-
|
| 17 |
-
import math
|
| 18 |
-
from dataclasses import dataclass
|
| 19 |
-
from typing import Optional, Union, List
|
| 20 |
-
|
| 21 |
-
import numpy as np
|
| 22 |
-
import torch
|
| 23 |
-
from torch import nn
|
| 24 |
-
from torch.nn import Parameter
|
| 25 |
-
from torch.nn import functional as F
|
| 26 |
-
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 27 |
-
from transformers.utils import ModelOutput, auto_docstring, logging
|
| 28 |
-
from transformers.utils.hub import cached_file
|
| 29 |
-
|
| 30 |
-
from torch.nn.utils.rnn import pad_sequence
|
| 31 |
-
|
| 32 |
-
from .vq.whisper_encoder import get_mel_audio, get_T_after_cnn
|
| 33 |
-
from .vq.speech_vq import WhisperEncoderVQ, XVectorExtractor
|
| 34 |
-
|
| 35 |
-
from .configuration_qwen3_tts_tokenizer_v1 import (
|
| 36 |
-
Qwen3TTSTokenizerV1Config,
|
| 37 |
-
Qwen3TTSTokenizerV1EncoderConfig,
|
| 38 |
-
Qwen3TTSTokenizerV1DecoderConfig,
|
| 39 |
-
Qwen3TTSTokenizerV1DecoderBigVGANConfig,
|
| 40 |
-
Qwen3TTSTokenizerV1DecoderDiTConfig
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
logger = logging.get_logger(__name__)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
@dataclass
|
| 47 |
-
@auto_docstring
|
| 48 |
-
class Qwen3TTSTokenizerV1EncoderOutput(ModelOutput):
|
| 49 |
-
r"""
|
| 50 |
-
audio_codes (`List[torch.LongTensor]`):
|
| 51 |
-
Discret code embeddings computed using `model.encode`, each tensor has shape (codes_length_i,).
|
| 52 |
-
xvectors (`List[torch.FloatTensor]`):
|
| 53 |
-
X-vector embeddings computed using `model.encode`, each tensor has shape (xvector_dim,).
|
| 54 |
-
ref_mels (`List[torch.FloatTensor]`):
|
| 55 |
-
Reference mel spectrogram computed using `model.encode`, each tensor has shape (mel_length_i, mel_dim,).
|
| 56 |
-
"""
|
| 57 |
-
|
| 58 |
-
audio_codes: List[torch.LongTensor] = None
|
| 59 |
-
xvectors: List[torch.FloatTensor] = None
|
| 60 |
-
ref_mels: List[torch.FloatTensor] = None
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV1DecoderOutput(ModelOutput):
    r"""
    audio_values (`List[torch.FloatTensor]`):
        Decoded audio values, obtained using the decoder part of Qwen3TTSTokenizerV1.
        Each tensor has shape (segment_length_i).
    """

    # One waveform tensor per decoded segment; None until populated.
    audio_values: Optional[List[torch.FloatTensor]] = None
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
@auto_docstring
class Qwen3TTSTokenizerV1DecoderPreTrainedModel(PreTrainedModel):
    # Shared Transformers plumbing (weight init, attention backend selection,
    # gradient checkpointing) for the tokenizer decoder modules.
    config: Qwen3TTSTokenizerV1DecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Leave cached past_key_values wherever the caller placed them during
    # device dispatch.
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    # Graph breaks are expected; torch.compile must not assume a full graph.
    _can_compile_fullgraph = False
    _supports_attention_backend = True
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
@auto_docstring
class Qwen3TTSTokenizerV1EncoderPreTrainedModel(PreTrainedModel):
    # Mirror of the decoder base class for the encoder side: same init,
    # attention-backend and checkpointing flags.
    config: Qwen3TTSTokenizerV1EncoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Leave cached past_key_values wherever the caller placed them during
    # device dispatch.
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    # Graph breaks are expected; torch.compile must not assume a full graph.
    _can_compile_fullgraph = False
    _supports_attention_backend = True
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
class Qwen3TTSTokenizerV1DecoderDiTRotaryEmbedding(nn.Module):
    """Rotary position embedding tables for the DiT decoder.

    Produces `(cos, sin)` tensors of shape (batch, seq_len, dim) where each
    inverse frequency is duplicated pairwise along the last dimension.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, dim, base=10000):
        super().__init__()
        self.register_buffer(
            "inv_freq", 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        )

    def forward(self, x):
        batch_size, seq_len = x.shape[0], x.shape[1]
        positions = torch.arange(seq_len, device=x.device)
        # autocast is unsupported on mps; fall back to the cpu autocast context.
        autocast_device = "cpu" if x.device.type == "mps" else x.device.type
        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = positions.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
            # Duplicate each frequency pairwise: (seq, dim/2) -> (seq, dim).
            angles = torch.stack((angles, angles), dim=-1)
            angles = angles.reshape(*angles.shape[:-2], -1)
            # Broadcast the table across the batch dimension.
            angles = angles.repeat(batch_size, *([1] * angles.dim()))
            cos = angles.cos()
            sin = angles.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
class TimeDelayNetBlock(nn.Module):
    """Basic TDNN unit: a dilated 1-D convolution followed by ReLU.

    Uses 'same' padding with reflection so the time length is preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding="same",
            padding_mode="reflect",
        )
        self.activation = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor):
        convolved = self.conv(hidden_states)
        return self.activation(convolved)
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
class Res2NetBlock(torch.nn.Module):
    """Res2Net multi-scale block.

    The input is split into `scale` channel groups; group 0 passes through
    unchanged, group 1 goes through a TDNN block, and every later group is
    summed with the previous group's output before its TDNN block (the
    hierarchical residual connection of Res2Net).
    """

    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
        super().__init__()

        # Per-group channel widths; in_channels/out_channels are assumed to be
        # divisible by scale.
        in_channel = in_channels // scale
        hidden_channel = out_channels // scale

        # One TDNN per group, except the pass-through first group.
        self.blocks = nn.ModuleList(
            [
                TimeDelayNetBlock(
                    in_channel,
                    hidden_channel,
                    kernel_size=kernel_size,
                    dilation=dilation,
                )
                for i in range(scale - 1)
            ]
        )
        self.scale = scale

    def forward(self, hidden_states):
        outputs = []
        for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)):
            if i == 0:
                # Identity shortcut for the first split.
                output_part = hidden_part
            elif i == 1:
                output_part = self.blocks[i - 1](hidden_part)
            else:
                # Hierarchical residual: reuse the previous group's output.
                output_part = self.blocks[i - 1](hidden_part + output_part)
            outputs.append(output_part)
        output = torch.cat(outputs, dim=1)
        return output
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
class SqueezeExcitationBlock(nn.Module):
    """Squeeze-and-Excitation: rescale channels by gates computed from the
    time-averaged signal through a bottleneck (conv1 -> ReLU -> conv2 -> sigmoid).
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=se_channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(
            in_channels=se_channels,
            out_channels=out_channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states):
        # Squeeze: global average over the time axis.
        gates = hidden_states.mean(dim=2, keepdim=True)
        # Excite: bottleneck MLP producing per-channel gates in (0, 1).
        gates = self.relu(self.conv1(gates))
        gates = self.sigmoid(self.conv2(gates))
        return hidden_states * gates
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
class AttentiveStatisticsPooling(nn.Module):
    """This class implements an attentive statistic pooling layer for each channel.
    It returns the concatenated mean and std of the input tensor.
    """

    def __init__(self, channels, attention_channels=128):
        super().__init__()

        # Numerical floor inside the sqrt of the weighted std.
        self.eps = 1e-12
        # The attention scorer sees [x, global mean, global std] stacked on the
        # channel axis, hence channels * 3 inputs.
        self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = nn.Conv1d(
            in_channels=attention_channels,
            out_channels=channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )

    def _length_to_mask(self, length, max_len=None, dtype=None, device=None):
        """Creates a binary mask for each sequence.

        Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3

        Arguments
        ---------
        length : torch.LongTensor
            Containing the length of each sequence in the batch. Must be 1D.
        max_len : int
            Max length for the mask, also the size of the second dimension.
        dtype : torch.dtype, default: None
            The dtype of the generated mask.
        device: torch.device, default: None
            The device to put the mask variable.

        Returns
        -------
        mask : tensor
            The binary mask.
        """

        if max_len is None:
            max_len = length.max().long().item()  # using arange to generate mask
        mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
            len(length), max_len
        ) < length.unsqueeze(1)

        mask = torch.as_tensor(mask, dtype=dtype, device=device)
        return mask

    def _compute_statistics(self, x, m, dim=2):
        # Weighted mean/std over `dim`; weights `m` are expected to sum to 1
        # along that dim (uniform mask/total or a softmax).
        mean = (m * x).sum(dim)
        std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps))
        return mean, std

    def forward(self, hidden_states):
        seq_length = hidden_states.shape[-1]
        # NOTE(review): lengths are hard-coded to the full sequence, so the mask
        # below is all-ones; padded batches appear not to be expected here -
        # confirm against callers.
        lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device)

        # Make binary mask of shape [N, 1, L]
        mask = self._length_to_mask(
            lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device
        )
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        total = mask.sum(dim=2, keepdim=True)

        # First pass: uniform (masked) global statistics fed to the scorer.
        mean, std = self._compute_statistics(hidden_states, mask / total)
        mean = mean.unsqueeze(2).repeat(1, 1, seq_length)
        std = std.unsqueeze(2).repeat(1, 1, seq_length)
        attention = torch.cat([hidden_states, mean, std], dim=1)

        # Apply layers
        attention = self.conv(self.tanh(self.tdnn(attention)))

        # Filter out zero-paddings
        attention = attention.masked_fill(mask == 0, float("-inf"))

        # Second pass: attention-weighted statistics.
        attention = F.softmax(attention, dim=2)
        mean, std = self._compute_statistics(hidden_states, attention)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
class SqueezeExcitationRes2NetBlock(nn.Module):
    """An implementation of building block in ECAPA-TDNN, i.e.,
    TDNN-Res2Net-TDNN-SqueezeExcitationBlock.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
    ):
        super().__init__()
        self.out_channels = out_channels
        # The outer TDNNs are fixed 1x1 convolutions; the configured
        # kernel_size/dilation apply only inside the Res2Net block.
        self.tdnn1 = TimeDelayNetBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
        )
        self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
        self.tdnn2 = TimeDelayNetBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
        )
        self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels)

    def forward(self, hidden_state):
        # Residual around the whole stack; requires in_channels == out_channels.
        residual = hidden_state

        hidden_state = self.tdnn1(hidden_state)
        hidden_state = self.res2net_block(hidden_state)
        hidden_state = self.tdnn2(hidden_state)
        hidden_state = self.se_block(hidden_state)

        return hidden_state + residual
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
class ECAPA_TimeDelayNet(torch.nn.Module):
    """An implementation of the speaker embedding model in a paper.
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143).
    """

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__()
        # The three per-layer config lists must describe the same stack.
        if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len(
            config.enc_dilations
        ):
            raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length")
        self.channels = config.enc_channels
        self.blocks = nn.ModuleList()

        # The initial TDNN layer
        self.blocks.append(
            TimeDelayNetBlock(
                config.mel_dim,
                config.enc_channels[0],
                config.enc_kernel_sizes[0],
                config.enc_dilations[0],
            )
        )

        # SE-Res2Net layers
        for i in range(1, len(config.enc_channels) - 1):
            self.blocks.append(
                SqueezeExcitationRes2NetBlock(
                    config.enc_channels[i - 1],
                    config.enc_channels[i],
                    res2net_scale=config.enc_res2net_scale,
                    se_channels=config.enc_se_channels,
                    kernel_size=config.enc_kernel_sizes[i],
                    dilation=config.enc_dilations[i],
                )
            )

        # Multi-layer feature aggregation
        self.mfa = TimeDelayNetBlock(
            config.enc_channels[-1],
            config.enc_channels[-1],
            config.enc_kernel_sizes[-1],
            config.enc_dilations[-1],
        )

        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            config.enc_channels[-1],
            attention_channels=config.enc_attention_channels,
        )

        # Final linear transformation (1x1 conv; ASP doubles the channel count
        # by concatenating mean and std).
        self.fc = nn.Conv1d(
            in_channels=config.enc_channels[-1] * 2,
            out_channels=config.enc_dim,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )

    def forward(self, hidden_states):
        # Input arrives time-major (B, T, mel); conv layers want (B, C, T).
        # Minimize transpose for efficiency
        hidden_states = hidden_states.transpose(1, 2)

        hidden_states_list = []
        for layer in self.blocks:
            hidden_states = layer(hidden_states)
            hidden_states_list.append(hidden_states)

        # Multi-layer feature aggregation: concatenate every SE-Res2Net output
        # (skip the initial TDNN output at index 0).
        # NOTE(review): self.mfa expects enc_channels[-1] input channels, so the
        # concatenated width must equal that - holds for the shipped config;
        # confirm for custom configs.
        hidden_states = torch.cat(hidden_states_list[1:], dim=1)
        hidden_states = self.mfa(hidden_states)

        # Attentive Statistical Pooling
        hidden_states = self.asp(hidden_states)

        # Final linear transformation
        hidden_states = self.fc(hidden_states)

        # Drop the singleton time axis left by pooling -> (B, enc_dim).
        hidden_states = hidden_states.squeeze(-1)
        return hidden_states
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
class DiTInputEmbedding(nn.Module):
    # Projects the concatenation of [noised mel, speaker-encoder output,
    # codec embedding, speaker embedding] into the DiT hidden size.
    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__()
        self.proj = nn.Linear(
            config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim,
            config.hidden_size,
        )
        self.spk_encoder = ECAPA_TimeDelayNet(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        speaker_embedding: torch.Tensor,
        condition_vector: torch.Tensor,
        code_embed: torch.Tensor,
        drop_audio_cond: Optional[bool] = False,
        code_embed_uncond: Optional[bool] = None,
        apply_cfg: Optional[bool] = True,
    ):
        # Classifier-free guidance: run conditional and unconditional passes in
        # one doubled batch; the second half gets zeroed conditions and the
        # caller-supplied unconditional codec embedding.
        if apply_cfg:
            hidden_states = torch.cat([hidden_states, hidden_states], dim=0)
            speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0)
            condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0)
            code_embed = torch.cat([code_embed, code_embed_uncond], dim=0)
        elif drop_audio_cond:  # cfg for cond audio
            condition_vector = torch.zeros_like(condition_vector)
            speaker_embedding = torch.zeros_like(speaker_embedding)
        # Reference mel -> single speaker vector, broadcast across the time axis.
        condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1)
        hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1))

        return hidden_states
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
# Transformer backbone using DiT blocks
class DiTCodecEmbedding(nn.Module):
    """Embed discrete codec ids and upsample them along time by `repeats`."""

    def __init__(self, codec_num_embeds, codec_dim, repeats):
        super().__init__()
        self.repeats = repeats
        # One extra embedding row so id `codec_num_embeds` is addressable.
        self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim)

    def forward(self, code, drop_code=False):
        # Classifier-free guidance path: replace every code with id 0.
        lookup = torch.zeros_like(code) if drop_code else code
        embedded = self.codec_embed(lookup)
        # Stretch each code embedding `repeats` times along the time axis.
        return torch.repeat_interleave(embedded, repeats=self.repeats, dim=1)
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNormZero(nn.Module):
    """LayerNorm whose shift/scale and residual gates are regressed from a
    conditioning embedding (six chunks: msa shift/scale/gate, mlp shift/scale/gate).
    """

    def __init__(self, dim):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, hidden_states, emb=None):
        modulation = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)

        # Modulate the normalized input for the attention branch; the remaining
        # parameters are handed back for the MLP branch.
        modulated = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
# AdaLayerNormZero for final layer
# return only with modulated x for attn input, cuz no more mlp modulation
class AdaLayerNormZero_Final(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, hidden_states, emb):
        emb = self.linear(self.silu(emb))
        # NOTE(review): chunk order here is (scale, shift) - the reverse of
        # AdaLayerNormZero above. This matches the trained checkpoint layout,
        # so do not "fix" it without retraining/converting weights.
        scale, shift = torch.chunk(emb, 2, dim=1)

        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
        return hidden_states
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
# FeedForward
class DiTMLP(nn.Module):
    """Two-layer feed-forward with tanh-approximated GELU and dropout.

    Hidden width is `dim * mult`; input and output widths are both `dim`.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        self.ff = nn.ModuleList(
            [
                nn.Linear(dim, inner_dim),
                nn.GELU(approximate="tanh"),
                nn.Dropout(dropout),
                nn.Linear(inner_dim, dim),
            ]
        )

    def forward(self, hidden_states):
        out = hidden_states
        for module in self.ff:
            out = module(out)
        return out
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
# Modified from Llama with a different rotate function, will fixed in next release
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos`/`sin` are unsqueezed so they broadcast
            against `q`/`k` (1 for [b, h, s, d] layouts, 2 for [b, s, h, d]).

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated
        using the Rotary Position Embedding.
    """

    def rotate_half_codec(x):
        # Pairwise rotation: view (..., d) as (..., d/2, 2) pairs (x1, x2),
        # emit (-x2, x1) interleaved back to (..., d).
        pairs = x.reshape(*x.shape[:-1], -1, 2)
        first, second = pairs.unbind(dim=-1)
        rotated = torch.stack((-second, first), dim=-1)
        return rotated.reshape(*rotated.shape[:-2], -1)

    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos_b + rotate_half_codec(q) * sin_b
    k_embed = k * cos_b + rotate_half_codec(k) * sin_b
    return q_embed, k_embed
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
class DiTAttention(nn.Module):
    # Bidirectional multi-head self-attention for the DiT backbone, with RoPE
    # applied to q/k and a pluggable transformers attention backend.
    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__()

        self.config = config
        self.dim = config.hidden_size
        self.heads = config.num_attention_heads
        self.inner_dim = config.head_dim * config.num_attention_heads
        self.dropout = config.dropout
        # Non-causal: the diffusion transformer attends over the full sequence.
        self.is_causal = False

        self.to_q = nn.Linear(config.hidden_size, self.inner_dim)
        self.to_k = nn.Linear(config.hidden_size, self.inner_dim)
        self.to_v = nn.Linear(config.hidden_size, self.inner_dim)

        # Output projection followed by dropout.
        self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)])

    def forward(
        self,
        hidden_states,  # noised input x
        position_embeddings=None,  # rotary position embedding for x
        attention_mask=None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]

        # `sample` projections.
        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)

        # attention: reshape to (batch, heads, seq, head_dim).
        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        # apply rotary position embedding
        # Due to training process, only first head is applied with RoPE, will be fixed at next release
        cos, sin = position_embeddings
        query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # Dispatch to the configured backend (eager / sdpa / flash-attention).
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attention_weights, _ = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=attention_mask,
            is_causal=False,
        )

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim)
        attention_weights = attention_weights.to(query.dtype)

        # linear proj
        attention_output = self.to_out[0](attention_weights)
        attention_output = self.to_out[1](attention_output)

        return attention_output
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
# time step conditioning embedding
class SinusPositionEmbedding(nn.Module):
    """Sinusoidal embedding of a scalar timestep: first half sin, second half cos."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, hidden_states, scale=1000):
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(-step * torch.arange(half_dim, device=hidden_states.device).float())
        angles = scale * hidden_states.unsqueeze(1) * freqs.unsqueeze(0)
        embedding = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return embedding.type_as(hidden_states)
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
class DiTTimestepEmbedding(nn.Module):
    # Sinusoidal timestep features followed by a two-layer SiLU MLP -> (b, dim).
    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)])

    def forward(self, timestep):
        time_hidden = self.time_embed(timestep)
        # Match the incoming dtype (e.g. half precision) before the MLP.
        time_hidden = time_hidden.to(timestep.dtype)
        for layer in self.time_mlp:
            time_hidden = layer(time_hidden)  # b d
        return time_hidden
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
class DiTDecoderLayer(nn.Module):
    # One DiT block: AdaLN-modulated self-attention plus a gated feed-forward,
    # with a banded block attention mask (look_backward/look_ahead in blocks).
    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig, look_ahead_block=0, look_backward_block=0):
        super().__init__()
        self.attn_norm = AdaLayerNormZero(config.hidden_size)

        self.attn = DiTAttention(config)
        self.look_ahead_block = look_ahead_block
        self.look_backward_block = look_backward_block
        self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout)

    def forward(
        self, hidden_states, timestep, position_embeddings=None, block_diff=None
    ):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep)

        # attention: a position may attend within [-look_backward, +look_ahead]
        # blocks of itself.
        # NOTE(review): assumes block_diff = key_block - query_block - confirm
        # against the caller that builds block_diff.
        attn_output = self.attn(
            hidden_states=norm,
            position_embeddings=position_embeddings,
            attention_mask=(block_diff >= -float(self.look_backward_block))
            & (block_diff <= float(self.look_ahead_block)),
        )

        # process attention output for input x (gated residual).
        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output

        # Feed-forward branch with the MLP shift/scale from the AdaLN above.
        norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(norm)
        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output

        return hidden_states
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://huggingface.co/papers/2006.08195
    """

    def __init__(self, in_features, alpha=1.0):
        super().__init__()
        self.in_features = in_features

        # initialize alpha
        # NOTE(review): zeros * alpha is still zeros, so both parameters start
        # at 0 (exp(0) = 1 in forward) regardless of `alpha`; the multiply looks
        # vestigial from a log-scale init variant - confirm before changing.
        self.alpha = Parameter(torch.zeros(in_features) * alpha)
        self.beta = Parameter(torch.zeros(in_features) * alpha)

        # Guards the 1/beta division below.
        self.no_div_by_zero = 0.000000001

    def forward(self, hidden_states):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        # Parameters are stored in log scale: exp keeps them strictly positive.
        alpha = torch.exp(alpha)
        beta = torch.exp(beta)
        hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
            torch.sin(hidden_states * alpha), 2
        )

        return hidden_states
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    """Generates a 1D Kaiser-windowed sinc filter.

    Args:
        cutoff (float): Normalized cutoff frequency (0 to 0.5).
        half_width (float): Transition bandwidth.
        kernel_size (int): Number of filter taps.

    Returns:
        torch.Tensor: A tensor of shape (1, 1, kernel_size) representing the filter.
    """
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # Estimate stop-band attenuation, then pick the Kaiser beta from the
    # standard piecewise empirical formula.
    delta_f = 4 * half_width
    attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if attenuation > 50.0:
        beta = 0.1102 * (attenuation - 8.7)
    elif attenuation >= 21.0:
        beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0)
    else:
        beta = 0.0

    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32)

    # Tap positions: half-sample offset for even lengths, centered otherwise.
    if even:
        taps = torch.arange(-half_size, half_size) + 0.5
    else:
        taps = torch.arange(kernel_size) - half_size

    if cutoff == 0:
        return torch.zeros((1, 1, kernel_size), dtype=torch.float32)  # Ensures correct shape

    filt = 2 * cutoff * window * torch.sinc(2 * cutoff * taps)
    # Normalize to unity DC gain (avoid leakage of the constant component).
    filt = filt / filt.sum()
    return filt.view(1, 1, kernel_size)
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
class UpSample1d(nn.Module):
    # Anti-aliased upsampling by `ratio`: transposed convolution with a
    # Kaiser-windowed low-pass sinc kernel, then trimming the filter transients.
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        # Default kernel length is derived from the ratio (always even).
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        # Asymmetric trim so the output length is exactly input_length * ratio.
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2

        filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
        self.register_buffer("filter", filter, persistent=False)

    def forward(self, hidden_states):
        channels = hidden_states.shape[1]

        hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate")
        # Depthwise (grouped) transposed conv; scale by ratio to preserve
        # amplitude after zero-stuffing.
        hidden_states = self.ratio * F.conv_transpose1d(
            hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels
        )
        hidden_states = hidden_states[..., self.pad_left : -self.pad_right]

        return hidden_states
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
class DownSample1d(nn.Module):
    """Anti-aliased downsampling by `ratio`: low-pass filter with a
    Kaiser-windowed sinc kernel, then stride-`ratio` decimation.

    Args:
        ratio (int): Decimation factor.
        kernel_size (int, optional): Filter length; derived from `ratio`
            when omitted (same rule as `UpSample1d`).
    """

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        cutoff = 0.5 / ratio
        half_width = 0.6 / ratio

        if cutoff < 0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        # Bug fix: the documented default kernel_size=None previously crashed
        # on `kernel_size % 2`; derive it from the ratio as UpSample1d does.
        if kernel_size is None:
            kernel_size = int(6 * ratio // 2) * 2

        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = ratio
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter, persistent=False)

    def forward(self, hidden_states):
        """Filter and decimate a (B, C, T) signal; returns (B, C, T // ratio)."""
        channels = hidden_states.shape[1]
        hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate")
        # Depthwise (grouped) convolution with the shared low-pass kernel.
        out = F.conv1d(hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels)
        return out
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
class TorchActivation1d(nn.Module):
    """Alias-suppressed pointwise activation for 1D signals.

    Runs ``activation`` at a higher sample rate (upsample -> activate ->
    downsample) so the harmonics introduced by the nonlinearity are filtered
    out instead of folding back into the band.
    """

    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        if not callable(activation):
            raise TypeError("Activation function must be callable")
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, hidden_states):
        # upsample -> pointwise nonlinearity -> downsample
        return self.downsample(self.act(self.upsample(hidden_states)))
|
| 859 |
-
class CausalConv1d(nn.Conv1d):
    """``nn.Conv1d`` with left-only ("causal") padding.

    Pads ``dilation * (kernel_size - 1)`` zeros on the left so each output
    frame depends only on current and past input frames; with stride 1 the
    output length equals the input length.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Amount of left padding required to keep the convolution causal.
        taps = self.kernel_size[0]
        self.causal_padding = self.dilation[0] * (taps - 1)

    def forward(self, x):
        padded = F.pad(x, [self.causal_padding, 0])
        return self._conv_forward(padded, self.weight, self.bias)
|
| 868 |
-
class AMPBlock(torch.nn.Module):
    """BigVGAN-style residual block of (anti-aliased Snake activation, conv)
    pairs: three dilated causal convolutions (`convs1`) each followed by a
    dilation-1 convolution (`convs2`).

    `causal_type` selects the flavor of the second conv stack:
      - '1': `convs2` are plain center-padded `nn.Conv1d`s; no pre-stage.
      - otherwise (e.g. '2'): `convs2` are causal, and a center-padded
        `pre_conv` plus activation runs before the residual pairs.
    """

    def __init__(
        self,
        channels,
        kernel_size=3,
        dilation=(1, 3, 5),
        causal_type='1',
    ):
        super().__init__()

        # First conv of each pair: causal, with increasing dilation.
        self.convs1 = nn.ModuleList(
            [
                CausalConv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[0],
                ),
                CausalConv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[1],
                ),
                CausalConv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[2],
                ),
            ]
        )

        if causal_type == '1':
            # Second conv of each pair: non-causal, "same"-padded, dilation 1.
            self.convs2 = nn.ModuleList(
                [
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self._get_padding(kernel_size, 1),
                    ),
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self._get_padding(kernel_size, 1),
                    ),
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self._get_padding(kernel_size, 1),
                    ),
                ]
            )
        else:
            # Fully causal variant of the second conv stack.
            self.convs2 = nn.ModuleList(
                [
                    CausalConv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                    ),
                    CausalConv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                    ),
                    CausalConv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                    ),
                ]
            )

        self.num_layers = len(self.convs1) + len(self.convs2)  # total number of conv layers

        # One anti-aliased SnakeBeta activation per conv layer.
        self.activations = nn.ModuleList(
            [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)]
        )

        if causal_type == '2':
            # Optional center-padded pre-stage applied before the residual pairs.
            self.pre_conv = nn.Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                padding=self._get_padding(kernel_size, 1),
            )
            self.pre_act = TorchActivation1d(activation=SnakeBeta(channels))
        else:
            self.pre_conv = nn.Identity()
            self.pre_act = nn.Identity()

    def _get_padding(self, kernel_size, dilation=1):
        # "Same" padding for an odd kernel at the given dilation.
        return int((kernel_size * dilation - dilation) / 2)

    def forward(self, x):
        hidden_states = self.pre_conv(x)
        hidden_states = self.pre_act(hidden_states)
        # Even-indexed activations precede convs1, odd-indexed precede convs2.
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2):
            hidden_states = act1(hidden_states)
            hidden_states = conv1(hidden_states)
            hidden_states = act2(hidden_states)
            hidden_states = conv2(hidden_states)
        # NOTE(review): the residual is added once after all three pairs rather
        # than per pair (HiFi-GAN style) — presumably intentional for this
        # checkpoint; confirm against the released weights' training code.
        x = x + hidden_states
        return x
| 995 |
-
@auto_docstring
class Qwen3TTSTokenizerV1DecoderBigVGANModel(Qwen3TTSTokenizerV1DecoderPreTrainedModel):
    """BigVGAN-style vocoder: log-mel spectrogram -> mono waveform in [-1, 1].

    Pipeline: mel preprocessing (exp -> dB -> normalize), `conv_pre`, a stack
    of transposed-conv upsamplers each followed by averaged AMP residual
    blocks, an anti-aliased activation, and a final 1-channel `conv_post`.
    """

    config: Qwen3TTSTokenizerV1DecoderBigVGANConfig

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__(config)
        self.num_residual_blocks = len(config.resblock_kernel_sizes)
        self.num_upsample_layers = len(config.upsample_rates)

        self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 5, 1, padding=2)

        # Removing extra ModuleList breaks official state dict
        ups = [
            nn.ModuleList(
                [
                    nn.ConvTranspose1d(
                        config.upsample_initial_channel // (2**layer_idx),
                        config.upsample_initial_channel // (2 ** (layer_idx + 1)),
                        kernel_size,
                        stride,
                        padding=(kernel_size - stride) // 2,
                    )
                ]
            )
            for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes))
        ]
        self.ups = nn.ModuleList(ups)

        # One AMPBlock per (upsample layer, resblock kernel size) combination;
        # the first two upsample layers use causal_type '2', the rest '1'.
        self.resblocks = nn.ModuleList(
            [
                AMPBlock(config.upsample_initial_channel // (2 ** (layer_idx + 1)), kernel_size, dilation, '1' if layer_idx > 1 else '2')
                for layer_idx in range(self.num_upsample_layers)
                for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes)
            ]
        )

        self.activation_post = TorchActivation1d(
            activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers))
        )
        self.conv_post = nn.Conv1d(
            config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False
        )

    def normalize_spectrogram(self, spectrogram, max_value, min_db):
        # Linearly map [min_db, 0] dB to [-max_value, max_value], clamped.
        return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value)

    def amplitude_to_db(self, amplitude, min_db_level):
        # Floor the amplitude at the linear equivalent of `min_db_level` dB
        # before taking the log, so the dB value never drops below that floor.
        min_level = torch.exp(
            torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype)
        )
        return 20 * torch.log10(torch.clamp(amplitude, min=min_level))

    def process_mel_spectrogram(self, mel_spectrogram):
        # Input is a log-amplitude mel; exp() recovers linear amplitude, then
        # convert to dB (with a 20 dB reference offset) and normalize to [-1, 1].
        amplitude_spectrum = torch.exp(mel_spectrogram)
        decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20
        return self.normalize_spectrogram(decibel_spectrum, 1, -115)

    def forward(self, mel_spectrogram):
        processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram)
        hidden_representation = self.conv_pre(processed_spectrogram)

        for layer_index in range(self.num_upsample_layers):
            # Single transposed conv (wrapped in a one-element ModuleList for
            # state-dict compatibility), then the mean of this layer's resblocks.
            hidden_representation = self.ups[layer_index][0](hidden_representation)
            residual_output = sum(
                self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation)
                for block_index in range(self.num_residual_blocks)
            )
            residual_output = residual_output / self.num_residual_blocks
            hidden_representation = residual_output

        hidden_representation = self.activation_post(hidden_representation)
        output_waveform = self.conv_post(hidden_representation)
        # Clamp to the valid audio range and drop the channel dimension.
        return torch.clamp(output_waveform, min=-1.0, max=1.0).squeeze(1)
| 1070 |
-
@auto_docstring
class Qwen3TTSTokenizerV1DecoderDiTModel(Qwen3TTSTokenizerV1DecoderPreTrainedModel):
    """Diffusion-transformer (DiT) that maps discrete codec codes plus speaker
    conditioning to a mel spectrogram, sampled by Euler-integrating a learned
    velocity field with optional classifier-free guidance (CFG).
    """

    config: Qwen3TTSTokenizerV1DecoderDiTConfig
    _no_split_modules = ["DiTDecoderLayer"]

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderDiTConfig):
        super().__init__(config)
        self.mel_dim = config.mel_dim
        # Each code frame expands to `repeats` mel frames.
        self.repeats = config.repeats
        self.time_embed = DiTTimestepEmbedding(config.hidden_size)

        self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats)
        self.input_embed = DiTInputEmbedding(config)

        self.rotary_embed = Qwen3TTSTokenizerV1DecoderDiTRotaryEmbedding(config.head_dim)

        self.hidden_size = config.hidden_size
        self.layers = config.num_hidden_layers
        self.block_size = config.block_size
        self.num_attention_heads = config.num_attention_heads

        self.transformer_blocks = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            self.transformer_blocks.append(
                DiTDecoderLayer(
                    config,
                    look_ahead_block=1 if i in config.look_ahead_layers else 0,
                    look_backward_block=1 if i in config.look_backward_layers else 0,
                )
            )

        self.norm_out = AdaLayerNormZero_Final(config.hidden_size)  # final modulation
        self.proj_out = nn.Linear(config.hidden_size, config.mel_dim)

    def _create_block_diff(self, hidden_states):
        """Per-position block-index difference matrix, broadcast to
        (batch, heads, seq, seq); consumed as `block_diff` by each layer."""
        batch, seq_len = hidden_states.shape[0], hidden_states.shape[1]
        block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size # [seq_length]

        block_i = block_indices.unsqueeze(1) # [seq_length, 1]
        block_j = block_indices.unsqueeze(0) # [1, seq_length]
        block_diff = block_j - block_i # (n, n)

        return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len)

    def forward(
        self,
        hidden_states,
        condition_vector,
        speaker_embedding,
        quantized_code,
        time_step,
        drop_audio_conditioning=False,
        drop_code=False,
        apply_cfg=True,
    ):
        """Predict the velocity/output for one ODE step; returns a tensor
        projected to `mel_dim` per position.

        With `apply_cfg=True` the conditional and unconditional passes are
        stacked into one doubled batch (assembled by `input_embed`).
        """
        # NOTE(review): time_step is repeated to 2*B unconditionally here,
        # which assumes the CFG-doubled batch; confirm behavior for external
        # callers that pass apply_cfg=False.
        batch_size = hidden_states.shape[0] * 2
        if time_step.ndim == 0:
            time_step = time_step.repeat(batch_size)

        # Compute embeddings
        time_embedding = self.time_embed(time_step)
        text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code)
        text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None

        hidden_states = self.input_embed(
            hidden_states,
            speaker_embedding,
            condition_vector,
            text_embedding,
            drop_audio_cond=drop_audio_conditioning,
            code_embed_uncond=text_embedding_unconditioned,
            apply_cfg=apply_cfg,
        )

        # Compute positional encodings
        position_embeddings = self.rotary_embed(hidden_states)
        blockwise_difference = self._create_block_diff(hidden_states)

        # Transformer blocks
        for transformer_block in self.transformer_blocks:
            hidden_states = transformer_block(
                hidden_states,
                time_embedding,
                position_embeddings=position_embeddings,
                block_diff=blockwise_difference,
            )

        hidden_states = self.norm_out(hidden_states, time_embedding)
        output = self.proj_out(hidden_states)

        return output

    def optimized_scale(self, positive_flat, negative_flat):
        """Projection coefficient of the conditional onto the unconditional
        prediction. NOTE(review): not referenced by any code visible here."""
        # Calculate dot production
        dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
        # Squared norm of uncondition
        squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
        # st_star = v_cond^T * v_uncond / ||v_uncond||^2
        st_star = dot_product / squared_norm
        return st_star

    @torch.no_grad()
    def sample(
        self,
        conditioning_vector,
        reference_mel_spectrogram,
        quantized_code,
        num_steps=10,
        guidance_scale=0.5,
        sway_coefficient=-1.0,
    ):
        """Generate a mel spectrogram from noise via explicit Euler steps.

        Returns a tensor of shape (batch, mel_dim, frames), where frames =
        code length * `repeats` (hard-capped at 30000 by the noise buffer).
        """
        noise_initialization = torch.randn([quantized_code.shape[0], 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype)
        maximum_duration = quantized_code.shape[1] * self.repeats
        initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device)
        # Broadcast the speaker vector to every frame position.
        conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1)

        def ode_function(time_step, hidden_states):
            # Without guidance: a single conditional forward pass.
            if guidance_scale < 1e-5:
                prediction = self(
                    hidden_states=hidden_states,
                    speaker_embedding=conditioning_vector,
                    condition_vector=reference_mel_spectrogram,
                    quantized_code=quantized_code,
                    time_step=time_step,
                    drop_audio_conditioning=False,
                    drop_code=False,
                )
                return prediction

            # CFG: conditional and unconditional halves come back stacked in
            # one doubled batch; recombine with the guidance scale.
            model_output = self(
                hidden_states=hidden_states,
                quantized_code=quantized_code,
                speaker_embedding=conditioning_vector,
                condition_vector=reference_mel_spectrogram,
                time_step=time_step,
                apply_cfg=True,
            )
            guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0)

            return guided_prediction + (guided_prediction - null_prediction) * guidance_scale

        initial_time = 0
        time_embedding = torch.linspace(
            initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype
        )

        # Sway sampling: warp the time grid so more steps land early.
        if sway_coefficient is not None:
            time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding)

        # Explicit Euler integration from t=0 (noise) toward t=1.
        values = initial_state.clone()
        for t0, t1 in zip(time_embedding[:-1], time_embedding[1:]):
            dt = t1 - t0
            vt = ode_function(t0, values)
            values = values + vt * dt

        # (batch, frames, mel_dim) -> (batch, mel_dim, frames) for the vocoder.
        generated_mel_spectrogram = values.permute(0, 2, 1)
        return generated_mel_spectrogram
| 1229 |
-
@auto_docstring
class Qwen3TTSTokenizerV1Decoder(Qwen3TTSTokenizerV1DecoderPreTrainedModel):
    """Two-stage codec decoder: a DiT samples a mel spectrogram from codes,
    then a BigVGAN vocoder turns that mel into a waveform."""

    config: Qwen3TTSTokenizerV1DecoderConfig
    base_model_prefix = "model"
    _no_split_modules = ["Qwen3TTSTokenizerV1DecoderDiTModel", "Qwen3TTSTokenizerV1DecoderBigVGANModel"]

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderConfig):
        super().__init__(config)
        # This decoder runs in fp32, so fp16/bf16-only and unsupported attention
        # implementations are downgraded to sdpa with a one-time warning.
        attn_impl = config._attn_implementation
        if config._attn_implementation == "flash_attention_2":
            logger.warning_once(
                "Qwen3TTSTokenizerV1Decoder must inference with fp32, but flash_attention_2 only supports fp16 and bf16, "
                "attention implementation of Qwen3TTSTokenizerV1Decoder will fallback to sdpa."
            )
            attn_impl = "sdpa"
        elif config._attn_implementation == "eager":
            logger.warning_once(
                "Qwen3TTSTokenizerV1Decoder does not support eager attention implementation, fall back to sdpa"
            )
            attn_impl = "sdpa"
        self.dit = Qwen3TTSTokenizerV1DecoderDiTModel._from_config(
            config.dit_config, attn_implementation=attn_impl
        )
        self.bigvgan = Qwen3TTSTokenizerV1DecoderBigVGANModel._from_config(
            config.bigvgan_config, attn_implementation=attn_impl
        )

    def forward(
        self,
        code,
        conditioning,
        reference_mel,
        num_steps=10,
        guidance_scale=0.5,
        sway_coefficient=-1.0,
        **kwargs,
    ):
        """Generates a waveform from input code and conditioning parameters."""

        # Stage 1: sample a mel spectrogram from the discrete codes.
        mel_spectrogram = self.dit.sample(
            conditioning,
            reference_mel,
            code,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            sway_coefficient=sway_coefficient,
        )

        # Stage 2: vocode the mel spectrogram into audio samples.
        waveform = self.bigvgan(mel_spectrogram)

        return waveform
| 1282 |
-
class Qwen3TTSTokenizerV1Encoder(Qwen3TTSTokenizerV1EncoderPreTrainedModel):
    """Speech -> discrete codes: a Whisper-style encoder with vector
    quantization, driven through speech2mel -> mel2code."""

    config: Qwen3TTSTokenizerV1EncoderConfig
    def __init__(self, config: Qwen3TTSTokenizerV1EncoderConfig):
        super().__init__(config)

        self.tokenizer = WhisperEncoderVQ(
            n_mels=config.n_mels,
            n_ctx=config.n_ctx,
            n_state=config.n_state,
            n_head=config.n_head,
            n_layer=config.n_layer,
            n_window=config.n_window,
            output_dim=config.output_dim,
            grad_checkpointing=config.grad_checkpointing,
            enable_mp=config.enable_mp,
            audio_sequence_parallel=config.audio_sequence_parallel,
            audio_vq_type=config.audio_vq_type,
            audio_vq_layers=config.audio_vq_layers,
            audio_vq_codebook_size=config.audio_vq_codebook_size,
            audio_vq_codebook_dim=config.audio_vq_codebook_dim,
            audio_vq_pe=config.audio_vq_pe,
            audio_vq_ds_rate=config.audio_vq_ds_rate,
        )

        self.padding = True
        self.audio_vq_ds_rate = self.tokenizer.audio_vq_ds_rate

    def speech2mel(self, speechs):
        """Convert a list of waveforms to mel features, each moved to the
        encoder's device and matching its input dtype."""
        mels = [
            get_mel_audio(
                speech, padding = self.padding, audio_vq_ds_rate = self.audio_vq_ds_rate
            ).to(speech.dtype).to(self.tokenizer.conv1.weight.device)
            for speech in speechs
        ]
        return mels

    def mel2code(self, mels):
        """Quantize mel features into code indices.

        Returns (indices, indice_lens): a zero-padded (batch, T_code) tensor
        and the per-item valid lengths.
        """
        audio_mellens = [mel.size(-1) for mel in mels]
        # Frame counts after the encoder's conv front-end; +2 presumably
        # accounts for prepended/appended tokens — TODO confirm in WhisperEncoderVQ.
        audio_aftercnnlens = [get_T_after_cnn(T) for T in audio_mellens]
        audio_seqlens = [T + 2 for T in audio_aftercnnlens]

        with torch.no_grad():
            _, indices = self.tokenizer(
                x_list = mels,
                audio_mellens = audio_mellens,
                audio_aftercnnlens = audio_aftercnnlens,
                audio_seqlens = audio_seqlens,
                return_indices=True,
            )

        # Codes are downsampled relative to post-CNN frames; split the flat
        # index stream back per item and right-pad with 0.
        indice_lens = [T // self.tokenizer.audio_vq_ds_rate for T in audio_aftercnnlens]
        indices = pad_sequence(torch.split(indices, indice_lens), batch_first=True, padding_value=0)

        return indices, indice_lens

    def quantize_speech(self, speechs):
        """End-to-end: waveforms -> mel -> padded code indices + lengths."""
        mels = self.speech2mel(speechs)
        indices, indice_lens = self.mel2code(mels)
        return indices, indice_lens
| 1343 |
-
@auto_docstring
class Qwen3TTSTokenizerV1PreTrainedModel(PreTrainedModel):
    """Base class binding `Qwen3TTSTokenizerV1Config` to the Transformers
    `PreTrainedModel` machinery (loading, device placement, attention backends)."""

    config: Qwen3TTSTokenizerV1Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep cache objects out of accelerate's automatic device placement.
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    # Full-graph torch.compile is not supported for this model family.
    _can_compile_fullgraph = False
    _supports_attention_backend = True
| 1355 |
-
@auto_docstring(
    custom_intro="""
    The Qwen3TTSTokenizerV1 model.
    """
)
class Qwen3TTSTokenizerV1Model(Qwen3TTSTokenizerV1PreTrainedModel):
    def __init__(self, config: Qwen3TTSTokenizerV1Config):
        super().__init__(config)
        self.config = config

        self.input_sample_rate = config.input_sample_rate
        self.output_sample_rate = config.output_sample_rate

        self.decode_upsample_rate = config.decode_upsample_rate
        self.encode_downsample_rate = config.encode_downsample_rate

        self.encoder = Qwen3TTSTokenizerV1Encoder._from_config(self.config.encoder_config)
        self.decoder = Qwen3TTSTokenizerV1Decoder._from_config(self.config.decoder_config)

        # Loaded lazily by `load_encoder_xvector_extractor` (an ONNX model
        # shipped alongside the weights, see `from_pretrained`).
        self.encoder_xvector_extractor = None

        self.post_init()

    def load_encoder_xvector_extractor(self, model_path):
        """Attach the speaker-embedding (x-vector) extractor from an ONNX file."""
        self.encoder_xvector_extractor = XVectorExtractor(model_path)

    def get_model_type(self):
        return self.config.model_type

    def get_input_sample_rate(self):
        return self.input_sample_rate

    def get_output_sample_rate(self):
        return self.output_sample_rate

    def get_encode_downsample_rate(self):
        return self.encode_downsample_rate

    def get_decode_upsample_rate(self):
        return self.decode_upsample_rate

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config=None,
        cache_dir=None,
        ignore_mismatched_sizes=False,
        force_download=False,
        local_files_only=False,
        token=None,
        revision="main",
        use_safetensors=None,
        weights_only=True,
        **kwargs,
    ):
        """Standard `from_pretrained`, then additionally fetches the bundled
        `campplus.onnx` x-vector extractor and attaches it to the model."""
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )
        # NOTE(review): these kwargs.pop calls read `kwargs`, but cache_dir /
        # force_download / local_files_only / revision were already captured by
        # the named parameters above, so the user's values are NOT forwarded
        # here (the pops return the defaults) — confirm and align if needed.
        encoder_xvector_extractor_path = cached_file(
            pretrained_model_name_or_path,
            "campplus.onnx",
            subfolder=kwargs.pop("subfolder", None),
            cache_dir=kwargs.pop("cache_dir", None),
            force_download=kwargs.pop("force_download", False),
            proxies=kwargs.pop("proxies", None),
            resume_download=kwargs.pop("resume_download", None),
            local_files_only=kwargs.pop("local_files_only", False),
            token=kwargs.pop("use_auth_token", None),
            revision=kwargs.pop("revision", None),
        )
        # NOTE(review): when this triggers, the path interpolated below is
        # None, so the message reads ".../None not exists".
        if encoder_xvector_extractor_path is None:
            raise ValueError(f"""{pretrained_model_name_or_path}/{encoder_xvector_extractor_path} not exists""")
        model.load_encoder_xvector_extractor(encoder_xvector_extractor_path)

        return model

    def encode(
        self,
        input_values: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], Qwen3TTSTokenizerV1EncoderOutput]:
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
                for *masked*.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Trim each waveform to its unpadded length (mask sum).
        wavs = [value[:mask.sum()] for value, mask in zip(input_values, padding_mask)]

        codes, codes_lens = self.encoder.quantize_speech(wavs)
        codes = [c[:l] for c, l in zip(codes, codes_lens)]

        # Per-utterance speaker embedding and reference mel from the ONNX
        # extractor (runs on CPU numpy, results moved back to the wav's device).
        xvectors = []
        ref_mels = []
        for wav in wavs:
            xvector, ref_mel = self.encoder_xvector_extractor.extract_code(wav.cpu().numpy())
            xvector = torch.tensor(xvector).to(wav.dtype).to(wav.device)
            ref_mel = torch.tensor(ref_mel).to(wav.dtype).to(wav.device)
            xvectors.append(xvector)
            ref_mels.append(ref_mel)

        if not return_dict:
            return (
                codes,
                xvectors,
                ref_mels
            )

        return Qwen3TTSTokenizerV1EncoderOutput(codes, xvectors, ref_mels)

    def decode(
        self,
        audio_codes: torch.Tensor,
        xvectors: torch.Tensor,
        ref_mels: torch.Tensor,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, torch.Tensor], Qwen3TTSTokenizerV1DecoderOutput]:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.LongTensor` of shape `(batch_size, codes_length)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            xvectors (`torch.FloatTensor` of shape `(batch_size, xvector_dim)`, *optional*):
                X-vector embeddings computed using `model.encode`.
            ref_mels (`torch.FloatTensor` of shape `(batch_size, mel_length, mel_dim)`, *optional*):
                Reference mel spectrogram computed using `model.encode`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        audio_values = self.decoder(code=audio_codes,
                                    reference_mel=ref_mels,
                                    conditioning=xvectors)

        # Trim each waveform to its valid length: nonzero codes * upsample
        # rate. NOTE(review): assumes 0 is only ever the padding code.
        audio_lengths = (audio_codes > 0).sum(1) * self.decode_upsample_rate
        audio_values = [a[:l] for a, l in zip(audio_values, audio_lengths)]

        if not return_dict:
            return (
                audio_values,
            )

        return Qwen3TTSTokenizerV1DecoderOutput(audio_values)
| 1527 |
-
|
| 1528 |
-
# Public API of this module.
__all__ = ["Qwen3TTSTokenizerV1Model", "Qwen3TTSTokenizerV1PreTrainedModel"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/tokenizer_25hz/vq/assets/mel_filters.npz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:7450ae70723a5ef9d341e3cee628c7cb0177f36ce42c44b7ed2bf3325f0f6d4c
|
| 3 |
-
size 4271
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/tokenizer_25hz/vq/core_vq.py
DELETED
|
@@ -1,523 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
#
|
| 7 |
-
# This implementation is inspired from
|
| 8 |
-
# https://github.com/lucidrains/vector-quantize-pytorch
|
| 9 |
-
# which is released under MIT License. Hereafter, the original license:
|
| 10 |
-
# MIT License
|
| 11 |
-
#
|
| 12 |
-
# Copyright (c) 2020 Phil Wang
|
| 13 |
-
#
|
| 14 |
-
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 15 |
-
# of this software and associated documentation files (the "Software"), to deal
|
| 16 |
-
# in the Software without restriction, including without limitation the rights
|
| 17 |
-
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 18 |
-
# copies of the Software, and to permit persons to whom the Software is
|
| 19 |
-
# furnished to do so, subject to the following conditions:
|
| 20 |
-
#
|
| 21 |
-
# The above copyright notice and this permission notice shall be included in all
|
| 22 |
-
# copies or substantial portions of the Software.
|
| 23 |
-
#
|
| 24 |
-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 25 |
-
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 26 |
-
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 27 |
-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 28 |
-
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 29 |
-
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 30 |
-
# SOFTWARE.
|
| 31 |
-
|
| 32 |
-
"""Core vector quantization implementation."""
|
| 33 |
-
import random
|
| 34 |
-
import typing as tp
|
| 35 |
-
from random import randrange
|
| 36 |
-
|
| 37 |
-
import numpy as np
|
| 38 |
-
from einops import rearrange, repeat
|
| 39 |
-
from math import ceil
|
| 40 |
-
import torch
|
| 41 |
-
from torch import nn
|
| 42 |
-
import torch.nn.functional as F
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def round_up_multiple(num, mult):
|
| 46 |
-
return ceil(num / mult) * mult
|
| 47 |
-
|
| 48 |
-
def default(val: tp.Any, d: tp.Any) -> tp.Any:
|
| 49 |
-
return val if val is not None else d
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def ema_inplace(moving_avg, new, decay: float):
|
| 53 |
-
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
| 57 |
-
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
def uniform_init(*shape: int):
|
| 61 |
-
t = torch.empty(shape)
|
| 62 |
-
nn.init.kaiming_uniform_(t)
|
| 63 |
-
return t
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def sample_vectors(samples, num: int):
|
| 67 |
-
num_samples, device = samples.shape[0], samples.device
|
| 68 |
-
|
| 69 |
-
if num_samples >= num:
|
| 70 |
-
indices = torch.randperm(num_samples, device=device)[:num]
|
| 71 |
-
else:
|
| 72 |
-
indices = torch.randint(0, num_samples, (num,), device=device)
|
| 73 |
-
|
| 74 |
-
return samples[indices]
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
@torch.no_grad()
|
| 78 |
-
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
| 79 |
-
dim, dtype = samples.shape[-1], samples.dtype
|
| 80 |
-
|
| 81 |
-
means = sample_vectors(samples, num_clusters)
|
| 82 |
-
|
| 83 |
-
for _ in range(num_iters):
|
| 84 |
-
dists = -(
|
| 85 |
-
samples.pow(2).sum(1, keepdim=True)
|
| 86 |
-
- 2 * torch.matmul(samples, means.t())
|
| 87 |
-
+ means.t().pow(2).sum(0, keepdim=True)
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
buckets = dists.max(dim=-1).indices
|
| 91 |
-
del dists
|
| 92 |
-
bins = torch.bincount(buckets, minlength=num_clusters)
|
| 93 |
-
zero_mask = bins == 0
|
| 94 |
-
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
| 95 |
-
|
| 96 |
-
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
| 97 |
-
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
| 98 |
-
new_means = new_means / bins_min_clamped[..., None]
|
| 99 |
-
|
| 100 |
-
means = torch.where(zero_mask[..., None], means, new_means)
|
| 101 |
-
return means, bins
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
def preprocess(x):
|
| 105 |
-
x = rearrange(x, "... d -> (...) d")
|
| 106 |
-
return x
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def postprocess_emb(embed_ind, shape):
|
| 110 |
-
return embed_ind.view(*shape[:-1])
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
class EuclideanCodebook(nn.Module):
|
| 114 |
-
"""Codebook with Euclidean distance.
|
| 115 |
-
Args:
|
| 116 |
-
dim (int): Dimension.
|
| 117 |
-
codebook_size (int): Codebook size.
|
| 118 |
-
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
| 119 |
-
If set to true, run the k-means algorithm on the first training batch and use
|
| 120 |
-
the learned centroids as initialization.
|
| 121 |
-
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
| 122 |
-
decay (float): Decay for exponential moving average over the codebooks.
|
| 123 |
-
epsilon (float): Epsilon value for numerical stability.
|
| 124 |
-
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 125 |
-
that have an exponential moving average cluster size less than the specified threshold with
|
| 126 |
-
randomly selected vector from the current batch.
|
| 127 |
-
"""
|
| 128 |
-
|
| 129 |
-
def __init__(
|
| 130 |
-
self,
|
| 131 |
-
dim: int,
|
| 132 |
-
codebook_size: int,
|
| 133 |
-
kmeans_init: int = False,
|
| 134 |
-
kmeans_iters: int = 10,
|
| 135 |
-
decay: float = 0.99,
|
| 136 |
-
epsilon: float = 1e-5,
|
| 137 |
-
threshold_ema_dead_code: float = 2.0,
|
| 138 |
-
):
|
| 139 |
-
super().__init__()
|
| 140 |
-
self.decay = decay
|
| 141 |
-
self.codebook_size = codebook_size
|
| 142 |
-
self.kmeans_iters = kmeans_iters
|
| 143 |
-
self.epsilon = epsilon
|
| 144 |
-
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 145 |
-
|
| 146 |
-
self.inited = None
|
| 147 |
-
self.cluster_size = None
|
| 148 |
-
self.embed = None
|
| 149 |
-
self.embed_avg = None
|
| 150 |
-
self.training = True
|
| 151 |
-
|
| 152 |
-
def init_embed_(self, data):
|
| 153 |
-
if self.inited:
|
| 154 |
-
return
|
| 155 |
-
|
| 156 |
-
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
| 157 |
-
self.embed.data.copy_(embed)
|
| 158 |
-
self.embed_avg.data.copy_(embed.clone())
|
| 159 |
-
self.cluster_size.data.copy_(cluster_size)
|
| 160 |
-
self.inited.data.copy_(torch.Tensor([True]))
|
| 161 |
-
# Make sure all buffers across workers are in sync after initialization
|
| 162 |
-
# distrib.broadcast_tensors([self.embed, self.embed_avg, self.cluster_size, self.inited])
|
| 163 |
-
|
| 164 |
-
def replace_(self, samples, mask):
|
| 165 |
-
modified_codebook = torch.where(
|
| 166 |
-
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
| 167 |
-
)
|
| 168 |
-
self.embed.data.copy_(modified_codebook)
|
| 169 |
-
|
| 170 |
-
def expire_codes_(self, batch_samples):
|
| 171 |
-
if self.threshold_ema_dead_code == 0:
|
| 172 |
-
return
|
| 173 |
-
|
| 174 |
-
cluster_size = self.cluster_size / sum(self.cluster_size) * self.codebook_size
|
| 175 |
-
expired_codes = cluster_size < self.threshold_ema_dead_code
|
| 176 |
-
if not torch.any(expired_codes):
|
| 177 |
-
return
|
| 178 |
-
else:
|
| 179 |
-
print(f"VQ expire infos: num_expire={sum(expired_codes)}, cluster_size[:5]={cluster_size[:5]}")
|
| 180 |
-
|
| 181 |
-
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
| 182 |
-
self.replace_(batch_samples, mask=expired_codes)
|
| 183 |
-
# sync buffers outside for efficiency
|
| 184 |
-
# distrib.broadcast_tensors(self.buffers())
|
| 185 |
-
|
| 186 |
-
def quantize(self, x):
|
| 187 |
-
embed = self.embed.t()
|
| 188 |
-
dist = -(
|
| 189 |
-
x.pow(2).sum(1, keepdim=True)
|
| 190 |
-
- 2 * x @ embed
|
| 191 |
-
+ embed.pow(2).sum(0, keepdim=True)
|
| 192 |
-
)
|
| 193 |
-
embed_ind = dist.max(dim=-1).indices
|
| 194 |
-
return embed_ind
|
| 195 |
-
|
| 196 |
-
def dequantize(self, embed_ind):
|
| 197 |
-
quantize = F.embedding(embed_ind, self.embed)
|
| 198 |
-
return quantize
|
| 199 |
-
|
| 200 |
-
def encode(self, x, buffers):
|
| 201 |
-
self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
|
| 202 |
-
|
| 203 |
-
shape = x.shape
|
| 204 |
-
# pre-process
|
| 205 |
-
x = preprocess(x)
|
| 206 |
-
# quantize
|
| 207 |
-
embed_ind = self.quantize(x)
|
| 208 |
-
# post-process
|
| 209 |
-
embed_ind = postprocess_emb(embed_ind, shape)
|
| 210 |
-
return embed_ind
|
| 211 |
-
|
| 212 |
-
def decode(self, embed_ind, buffers):
|
| 213 |
-
self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
|
| 214 |
-
|
| 215 |
-
quantize = self.dequantize(embed_ind)
|
| 216 |
-
return quantize
|
| 217 |
-
|
| 218 |
-
def forward(self, x, buffers):
|
| 219 |
-
self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
|
| 220 |
-
|
| 221 |
-
shape, dtype = x.shape, x.dtype
|
| 222 |
-
x = preprocess(x)
|
| 223 |
-
|
| 224 |
-
self.init_embed_(x)
|
| 225 |
-
if self.training:
|
| 226 |
-
# We do the expiry of code at that point as buffers are in sync
|
| 227 |
-
# and all the workers will take the same decision.
|
| 228 |
-
self.expire_codes_(x)
|
| 229 |
-
|
| 230 |
-
embed_ind = self.quantize(x)
|
| 231 |
-
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
| 232 |
-
embed_ind = postprocess_emb(embed_ind, shape)
|
| 233 |
-
quantize = self.dequantize(embed_ind)
|
| 234 |
-
|
| 235 |
-
if self.training:
|
| 236 |
-
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
| 237 |
-
embed_sum = x.t() @ embed_onehot
|
| 238 |
-
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
| 239 |
-
cluster_size = (
|
| 240 |
-
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
| 241 |
-
* self.cluster_size.sum()
|
| 242 |
-
)
|
| 243 |
-
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
| 244 |
-
self.embed.data.copy_(embed_normalized)
|
| 245 |
-
# Note: after ema update, there is a very small difference between codebooks on GPUs.
|
| 246 |
-
# The impact can be very small, ignore it.
|
| 247 |
-
|
| 248 |
-
return quantize, embed_ind
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
class VectorQuantization(nn.Module):
|
| 252 |
-
"""Vector quantization implementation.
|
| 253 |
-
Currently, supports only euclidean distance.
|
| 254 |
-
Args:
|
| 255 |
-
dim (int): Dimension
|
| 256 |
-
codebook_size (int): Codebook size
|
| 257 |
-
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
| 258 |
-
decay (float): Decay for exponential moving average over the codebooks.
|
| 259 |
-
epsilon (float): Epsilon value for numerical stability.
|
| 260 |
-
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
| 261 |
-
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
| 262 |
-
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 263 |
-
that have an exponential moving average cluster size less than the specified threshold with
|
| 264 |
-
randomly selected vector from the current batch.
|
| 265 |
-
commitment_weight (float): Weight for commitment loss.
|
| 266 |
-
"""
|
| 267 |
-
def __init__(
|
| 268 |
-
self,
|
| 269 |
-
dim: int,
|
| 270 |
-
codebook_size: int,
|
| 271 |
-
codebook_dim: tp.Optional[int] = None,
|
| 272 |
-
decay: float = 0.99,
|
| 273 |
-
epsilon: float = 1e-5,
|
| 274 |
-
kmeans_init: bool = True,
|
| 275 |
-
kmeans_iters: int = 50,
|
| 276 |
-
threshold_ema_dead_code: float = 2.0,
|
| 277 |
-
commitment_weight: float = 1.,
|
| 278 |
-
):
|
| 279 |
-
super().__init__()
|
| 280 |
-
_codebook_dim: int = default(codebook_dim, dim)
|
| 281 |
-
|
| 282 |
-
requires_projection = _codebook_dim != dim
|
| 283 |
-
self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
|
| 284 |
-
self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())
|
| 285 |
-
|
| 286 |
-
self.epsilon = epsilon
|
| 287 |
-
self.commitment_weight = commitment_weight
|
| 288 |
-
|
| 289 |
-
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
|
| 290 |
-
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
|
| 291 |
-
decay=decay, epsilon=epsilon,
|
| 292 |
-
threshold_ema_dead_code=threshold_ema_dead_code)
|
| 293 |
-
self.codebook_size = codebook_size
|
| 294 |
-
self.training = True
|
| 295 |
-
|
| 296 |
-
@property
|
| 297 |
-
def codebook(self):
|
| 298 |
-
return self._codebook.embed
|
| 299 |
-
|
| 300 |
-
def encode(self, x, buffers):
|
| 301 |
-
# x = rearrange(x, "b d n -> b n d")
|
| 302 |
-
x = self.project_in(x)
|
| 303 |
-
embed_in = self._codebook.encode(x, buffers)
|
| 304 |
-
return embed_in
|
| 305 |
-
|
| 306 |
-
def decode(self, embed_ind, buffers):
|
| 307 |
-
quantize = self._codebook.decode(embed_ind, buffers)
|
| 308 |
-
quantize = self.project_out(quantize)
|
| 309 |
-
# quantize = rearrange(quantize, "b n d -> b d n")
|
| 310 |
-
return quantize
|
| 311 |
-
|
| 312 |
-
def forward(self, x, buffers):
|
| 313 |
-
device = x.device
|
| 314 |
-
# x = rearrange(x, "b d n -> b n d")
|
| 315 |
-
x = self.project_in(x)
|
| 316 |
-
|
| 317 |
-
quantize, embed_ind = self._codebook(x, buffers)
|
| 318 |
-
|
| 319 |
-
if self.training:
|
| 320 |
-
quantize = x + (quantize - x).detach()
|
| 321 |
-
|
| 322 |
-
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
| 323 |
-
|
| 324 |
-
if self.training:
|
| 325 |
-
if self.commitment_weight > 0:
|
| 326 |
-
commit_loss = F.mse_loss(quantize.detach(), x)
|
| 327 |
-
loss = loss + commit_loss * self.commitment_weight
|
| 328 |
-
|
| 329 |
-
quantize = self.project_out(quantize)
|
| 330 |
-
# quantize = rearrange(quantize, "b n d -> b d n")
|
| 331 |
-
return quantize, embed_ind, loss
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
class DistributedResidualVectorQuantization(nn.Module):
|
| 335 |
-
"""Efficient distributed residual vector quantization implementation.
|
| 336 |
-
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
| 337 |
-
"""
|
| 338 |
-
def __init__(self, *,
|
| 339 |
-
num_quantizers,
|
| 340 |
-
quantize_dropout: bool = False,
|
| 341 |
-
rand_num_quant: tp.Optional[tp.List] = None,
|
| 342 |
-
**kwargs):
|
| 343 |
-
super().__init__()
|
| 344 |
-
"""
|
| 345 |
-
dim: int,
|
| 346 |
-
codebook_size: int,
|
| 347 |
-
codebook_dim: tp.Optional[int] = None,
|
| 348 |
-
"""
|
| 349 |
-
codebook_size, codebook_dim = kwargs["codebook_size"], kwargs["codebook_dim"] if kwargs["codebook_dim"] else kwargs["dim"]
|
| 350 |
-
kmeans_init = kwargs["kmeans_init"]
|
| 351 |
-
if isinstance(kmeans_init, bool):
|
| 352 |
-
if not kwargs["kmeans_init"]:
|
| 353 |
-
# use uniform init
|
| 354 |
-
embed = uniform_init(num_quantizers, codebook_size, codebook_dim)
|
| 355 |
-
inited = True
|
| 356 |
-
else:
|
| 357 |
-
# to perform kmeans init on first batch
|
| 358 |
-
embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
|
| 359 |
-
inited = False
|
| 360 |
-
elif isinstance(kmeans_init, str):
|
| 361 |
-
# use prepared kmeans init
|
| 362 |
-
embed = np.load(kmeans_init)
|
| 363 |
-
embed = torch.from_numpy(embed)
|
| 364 |
-
if embed.dim() == 2:
|
| 365 |
-
embed = embed.unsqueeze(0)
|
| 366 |
-
inited = True
|
| 367 |
-
else:
|
| 368 |
-
raise TypeError("kmeans_init should be either a bool or string path to init weights.")
|
| 369 |
-
|
| 370 |
-
self.register_buffer("inited", torch.Tensor([[inited] for _ in range(num_quantizers)]))
|
| 371 |
-
self.register_buffer("cluster_size", torch.zeros(num_quantizers, codebook_size))
|
| 372 |
-
self.register_buffer("embed", embed)
|
| 373 |
-
self.register_buffer("embed_avg", embed.clone())
|
| 374 |
-
|
| 375 |
-
self.q0_ds_ratio = 1
|
| 376 |
-
if "q0_ds_ratio" in kwargs:
|
| 377 |
-
self.q0_ds_ratio = kwargs.pop("q0_ds_ratio")
|
| 378 |
-
|
| 379 |
-
self.layers = nn.ModuleList()
|
| 380 |
-
for i in range(num_quantizers):
|
| 381 |
-
vq_args = dict(**kwargs)
|
| 382 |
-
vq = VectorQuantization(**vq_args)
|
| 383 |
-
self.layers.append(vq)
|
| 384 |
-
|
| 385 |
-
self.quantize_dropout = quantize_dropout
|
| 386 |
-
self.rand_num_quant = rand_num_quant
|
| 387 |
-
|
| 388 |
-
def forward(self, x, n_q: tp.Optional[int] = None):
|
| 389 |
-
quantized_out = torch.zeros_like(x)
|
| 390 |
-
residual = x
|
| 391 |
-
bb, cc, tt = x.shape
|
| 392 |
-
device = x.device
|
| 393 |
-
|
| 394 |
-
all_losses = []
|
| 395 |
-
all_indices = []
|
| 396 |
-
all_sub_quants = []
|
| 397 |
-
n_q = n_q or len(self.layers)
|
| 398 |
-
|
| 399 |
-
should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
|
| 400 |
-
if should_quantize_dropout:
|
| 401 |
-
rand_quantize_dropout_index = random.choice(self.rand_num_quant)
|
| 402 |
-
|
| 403 |
-
null_indices_shape = (x.shape[0], x.shape[2])
|
| 404 |
-
null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
|
| 405 |
-
null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
|
| 406 |
-
null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)
|
| 407 |
-
|
| 408 |
-
for quantizer_index, layer in enumerate(self.layers[:n_q]):
|
| 409 |
-
# dropout except the first quantizer
|
| 410 |
-
if should_quantize_dropout and quantizer_index >= rand_quantize_dropout_index:
|
| 411 |
-
all_indices.append(null_indices)
|
| 412 |
-
all_losses.append(null_loss)
|
| 413 |
-
all_sub_quants.append(null_sub_quant)
|
| 414 |
-
continue
|
| 415 |
-
|
| 416 |
-
quant_in = residual
|
| 417 |
-
if self.q0_ds_ratio > 1 and quantizer_index == 0:
|
| 418 |
-
quant_in = F.interpolate(quant_in, size=[tt//2])
|
| 419 |
-
quantized, indices, loss = layer(quant_in, [
|
| 420 |
-
self.inited[quantizer_index],
|
| 421 |
-
self.cluster_size[quantizer_index],
|
| 422 |
-
self.embed[quantizer_index],
|
| 423 |
-
self.embed_avg[quantizer_index]
|
| 424 |
-
])
|
| 425 |
-
if self.q0_ds_ratio > 1 and quantizer_index == 0:
|
| 426 |
-
quantized = F.interpolate(quantized, size=[tt])
|
| 427 |
-
indices = F.interpolate(indices.unsqueeze(1).float(), size=[tt]).squeeze(1).long()
|
| 428 |
-
residual = residual - quantized
|
| 429 |
-
quantized_out = quantized_out + quantized
|
| 430 |
-
|
| 431 |
-
all_indices.append(indices)
|
| 432 |
-
all_losses.append(loss)
|
| 433 |
-
all_sub_quants.append(quantized)
|
| 434 |
-
|
| 435 |
-
# sync buffers after one forward step
|
| 436 |
-
# distrib.broadcast_tensors(self.buffers())
|
| 437 |
-
out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))
|
| 438 |
-
|
| 439 |
-
return quantized_out, out_indices, out_losses
|
| 440 |
-
|
| 441 |
-
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
|
| 442 |
-
residual = x
|
| 443 |
-
all_indices = []
|
| 444 |
-
n_q = n_q or len(self.layers)
|
| 445 |
-
for i, layer in enumerate(self.layers[:n_q]):
|
| 446 |
-
indices = layer.encode(residual, [
|
| 447 |
-
self.inited[i],
|
| 448 |
-
self.cluster_size[i],
|
| 449 |
-
self.embed[i],
|
| 450 |
-
self.embed_avg[i]
|
| 451 |
-
])
|
| 452 |
-
quantized = layer.decode(indices, [
|
| 453 |
-
self.inited[i],
|
| 454 |
-
self.cluster_size[i],
|
| 455 |
-
self.embed[i],
|
| 456 |
-
self.embed_avg[i]
|
| 457 |
-
])
|
| 458 |
-
residual = residual - quantized
|
| 459 |
-
all_indices.append(indices)
|
| 460 |
-
out_indices = torch.stack(all_indices)
|
| 461 |
-
return out_indices
|
| 462 |
-
|
| 463 |
-
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
|
| 464 |
-
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
| 465 |
-
for i, indices in enumerate(q_indices):
|
| 466 |
-
layer = self.layers[i]
|
| 467 |
-
quantized = layer.decode(indices, [
|
| 468 |
-
self.inited[i],
|
| 469 |
-
self.cluster_size[i],
|
| 470 |
-
self.embed[i],
|
| 471 |
-
self.embed_avg[i]
|
| 472 |
-
])
|
| 473 |
-
quantized_out = quantized_out + quantized
|
| 474 |
-
return quantized_out
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
class DistributedGroupResidualVectorQuantization(nn.Module):
|
| 478 |
-
"""Efficient distributed group residual vector quantization implementation.
|
| 479 |
-
Follows Algorithm 1. in https://arxiv.org/abs/2305.02765
|
| 480 |
-
Group Then rvq
|
| 481 |
-
"""
|
| 482 |
-
def __init__(self, *,
|
| 483 |
-
num_groups,
|
| 484 |
-
num_quantizers,
|
| 485 |
-
quantize_dropout: bool = False,
|
| 486 |
-
rand_num_quant: tp.Optional[tp.List] = None,
|
| 487 |
-
**kwargs):
|
| 488 |
-
super().__init__()
|
| 489 |
-
self.rvqs = nn.ModuleList(
|
| 490 |
-
[
|
| 491 |
-
DistributedResidualVectorQuantization(
|
| 492 |
-
num_quantizers=num_quantizers,
|
| 493 |
-
quantize_dropout=quantize_dropout,
|
| 494 |
-
rand_num_quant=rand_num_quant,
|
| 495 |
-
**kwargs
|
| 496 |
-
)
|
| 497 |
-
for _ in range(num_groups)
|
| 498 |
-
]
|
| 499 |
-
)
|
| 500 |
-
self.num_groups = num_groups
|
| 501 |
-
|
| 502 |
-
def forward(self, x, n_q: tp.Optional[int] = None):
|
| 503 |
-
x_lst = torch.chunk(x, chunks=self.num_groups, dim=1)
|
| 504 |
-
all_quantized_out = []
|
| 505 |
-
all_indices = []
|
| 506 |
-
all_losses = []
|
| 507 |
-
for mod, item in zip(self.rvqs, x_lst):
|
| 508 |
-
quantized_out, out_indices, out_losses = mod(item, n_q)
|
| 509 |
-
all_quantized_out.append(quantized_out)
|
| 510 |
-
all_indices.append(out_indices)
|
| 511 |
-
all_losses.append(out_losses)
|
| 512 |
-
|
| 513 |
-
out_losses = torch.stack(all_losses, dim=1).mean(dim=1)
|
| 514 |
-
|
| 515 |
-
return torch.cat(all_quantized_out, dim=1), torch.stack(all_indices, dim=1), out_losses
|
| 516 |
-
|
| 517 |
-
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
|
| 518 |
-
x_lst = torch.chunk(x, chunks=self.num_groups, dim=1)
|
| 519 |
-
return torch.stack([mod.encode(item, n_q) for mod, item in zip(self.rvqs, x_lst)], dim=1)
|
| 520 |
-
|
| 521 |
-
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
|
| 522 |
-
q_indices_lst = torch.chunk(q_indices, chunks=self.num_groups, dim=1)
|
| 523 |
-
return torch.cat([mod.decode(item.squeeze(1)) for mod, item in zip(self.rvqs, q_indices_lst)], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/tokenizer_25hz/vq/speech_vq.py
DELETED
|
@@ -1,357 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
import sox
|
| 17 |
-
import copy
|
| 18 |
-
import torch
|
| 19 |
-
import operator
|
| 20 |
-
import onnxruntime
|
| 21 |
-
|
| 22 |
-
import torch.nn as nn
|
| 23 |
-
import torch.nn.functional as F
|
| 24 |
-
import torchaudio.compliance.kaldi as kaldi
|
| 25 |
-
|
| 26 |
-
from librosa.filters import mel as librosa_mel_fn
|
| 27 |
-
from itertools import accumulate
|
| 28 |
-
from typing import List
|
| 29 |
-
from torch import Tensor
|
| 30 |
-
|
| 31 |
-
from .core_vq import DistributedGroupResidualVectorQuantization
|
| 32 |
-
from .whisper_encoder import WhisperEncoder, Conv1d, ConvTranspose1d
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
| 36 |
-
return torch.log(torch.clamp(x, min=clip_val) * C)
|
| 37 |
-
|
| 38 |
-
def spectral_normalize_torch(magnitudes):
|
| 39 |
-
output = dynamic_range_compression_torch(magnitudes)
|
| 40 |
-
return output
|
| 41 |
-
|
| 42 |
-
class MelSpectrogramFeatures(nn.Module):
|
| 43 |
-
"""
|
| 44 |
-
Calculate the BigVGAN style mel spectrogram of an input signal.
|
| 45 |
-
Args:
|
| 46 |
-
filter_length (int): The number of samples in the filter window, used for the Fourier Transform. Default is 1024.
|
| 47 |
-
hop_length (int): The number of samples between successive frames (stride of the STFT). Default is 160.
|
| 48 |
-
win_length (int): The length of the window function applied to each frame, usually less than or equal to the filter length. Default is 640.
|
| 49 |
-
n_mel_channels (int): The number of Mel-frequency channels to output from the Mel-scale spectrogram. Default is 80.
|
| 50 |
-
mel_fmin (int): The minimum frequency (in Hz) of the Mel-scale spectrogram. Default is 0.
|
| 51 |
-
mel_fmax (int): The maximum frequency (in Hz) of the Mel-scale spectrogram. Default is 8000.
|
| 52 |
-
sampling_rate (int): The sampling rate of the audio data (in Hz). Default is 16000.
|
| 53 |
-
sampling_rate_org (int, optional): The original sampling rate of the audio data before any resampling (in Hz), if applicable. Default is None.
|
| 54 |
-
padding (str): The padding mode for the input signal. 'center' pads the signal symmetrically around its center. Default is 'center'.
|
| 55 |
-
|
| 56 |
-
Returns:
|
| 57 |
-
torch.Tensor: Mel spectrogram.
|
| 58 |
-
"""
|
| 59 |
-
def __init__(self,
|
| 60 |
-
filter_length=1024,
|
| 61 |
-
hop_length=160,
|
| 62 |
-
win_length=640,
|
| 63 |
-
n_mel_channels=80,
|
| 64 |
-
mel_fmin=0,
|
| 65 |
-
mel_fmax=8000,
|
| 66 |
-
sampling_rate=16000,
|
| 67 |
-
sampling_rate_org=None,
|
| 68 |
-
padding='center',
|
| 69 |
-
use_db = False,
|
| 70 |
-
):
|
| 71 |
-
super().__init__()
|
| 72 |
-
if padding not in ["center", "same"]:
|
| 73 |
-
raise ValueError("Padding must be 'center' or 'same'.")
|
| 74 |
-
self.padding = padding
|
| 75 |
-
|
| 76 |
-
self.filter_length = filter_length
|
| 77 |
-
self.hop_length = hop_length
|
| 78 |
-
self.win_length = win_length
|
| 79 |
-
self.n_mel_channels = n_mel_channels
|
| 80 |
-
self.mel_fmin = mel_fmin
|
| 81 |
-
self.mel_fmax = mel_fmax
|
| 82 |
-
self.sampling_rate = sampling_rate
|
| 83 |
-
self.sampling_rate_org = sampling_rate_org if sampling_rate_org is not None else sampling_rate
|
| 84 |
-
self.mel_basis = {}
|
| 85 |
-
self.hann_window = {}
|
| 86 |
-
|
| 87 |
-
def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 88 |
-
with torch.no_grad():
|
| 89 |
-
feats = self.extract(audio, **kwargs)
|
| 90 |
-
return feats
|
| 91 |
-
|
| 92 |
-
def extract(self, audio, **kwargs):
|
| 93 |
-
|
| 94 |
-
if len(audio.shape) == 3:
|
| 95 |
-
audio = audio.squeeze(1) if audio.shape[1] == 1 else audio.squeeze(2)
|
| 96 |
-
assert len(audio.shape) == 2
|
| 97 |
-
|
| 98 |
-
y = audio
|
| 99 |
-
if len(list(self.mel_basis.keys())) == 0:
|
| 100 |
-
mel = librosa_mel_fn(sr=self.sampling_rate, n_fft=self.filter_length, n_mels=self.n_mel_channels, fmin=self.mel_fmin, fmax=self.mel_fmax)
|
| 101 |
-
self.mel_basis[str(self.mel_fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
| 102 |
-
self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(y.device)
|
| 103 |
-
|
| 104 |
-
y = torch.nn.functional.pad(y.unsqueeze(1), (int((self.filter_length-self.hop_length)/2), int((self.filter_length-self.hop_length)/2)), mode='reflect')
|
| 105 |
-
y = y.squeeze(1)
|
| 106 |
-
|
| 107 |
-
spec = torch.stft(y, self.filter_length, hop_length=self.hop_length, win_length=self.win_length, window=self.hann_window[str(y.device)],
|
| 108 |
-
center=False, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
|
| 109 |
-
spec = torch.view_as_real(spec)
|
| 110 |
-
spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
|
| 111 |
-
|
| 112 |
-
spec = torch.matmul(self.mel_basis[str(self.mel_fmax)+'_'+str(y.device)], spec)
|
| 113 |
-
spec = spectral_normalize_torch(spec)
|
| 114 |
-
|
| 115 |
-
return spec
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
class XVectorExtractor(nn.Module):
|
| 119 |
-
def __init__(self, audio_codec_with_xvector):
|
| 120 |
-
super().__init__()
|
| 121 |
-
option = onnxruntime.SessionOptions()
|
| 122 |
-
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 123 |
-
option.intra_op_num_threads = 1
|
| 124 |
-
providers = ["CPUExecutionProvider"]
|
| 125 |
-
self.ort_session = onnxruntime.InferenceSession(audio_codec_with_xvector, sess_options=option, providers=providers)
|
| 126 |
-
|
| 127 |
-
self.tfm = sox.Transformer()
|
| 128 |
-
self.tfm.norm(db_level=-6)
|
| 129 |
-
|
| 130 |
-
self.mel_ext = MelSpectrogramFeatures(
|
| 131 |
-
filter_length=1024,
|
| 132 |
-
hop_length=160,
|
| 133 |
-
win_length=640,
|
| 134 |
-
n_mel_channels=80,
|
| 135 |
-
mel_fmin=0,
|
| 136 |
-
mel_fmax=8000,
|
| 137 |
-
sampling_rate=16000
|
| 138 |
-
)
|
| 139 |
-
|
| 140 |
-
def extract_code(self, audio):
|
| 141 |
-
with torch.no_grad():
|
| 142 |
-
norm_audio = self.sox_norm(audio)
|
| 143 |
-
|
| 144 |
-
norm_audio = torch.from_numpy(copy.deepcopy(norm_audio)).unsqueeze(0)
|
| 145 |
-
feat = kaldi.fbank(norm_audio,
|
| 146 |
-
num_mel_bins=80,
|
| 147 |
-
dither=0,
|
| 148 |
-
sample_frequency=16000)
|
| 149 |
-
feat = feat - feat.mean(dim=0, keepdim=True)
|
| 150 |
-
norm_embedding = self.ort_session.run(None, {self.ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten()
|
| 151 |
-
norm_embedding = F.normalize(torch.from_numpy(norm_embedding), dim=0)
|
| 152 |
-
|
| 153 |
-
ref_mel = self.mel_ext.extract(audio=norm_audio)
|
| 154 |
-
|
| 155 |
-
return norm_embedding.numpy(), ref_mel.permute(0,2,1).squeeze(0).numpy()
|
| 156 |
-
|
| 157 |
-
def sox_norm(self, audio):
|
| 158 |
-
wav_norm = self.tfm.build_array(input_array=audio, sample_rate_in=16000)
|
| 159 |
-
return wav_norm
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
class WhisperEncoderVQ(WhisperEncoder):
|
| 163 |
-
def __init__(
|
| 164 |
-
self,
|
| 165 |
-
n_mels: int,
|
| 166 |
-
n_ctx: int,
|
| 167 |
-
n_state: int,
|
| 168 |
-
n_head: int,
|
| 169 |
-
n_layer: int,
|
| 170 |
-
n_window: int = 1500,
|
| 171 |
-
output_dim: int = 512,
|
| 172 |
-
grad_checkpointing: bool = False,
|
| 173 |
-
enable_mp: bool = False,
|
| 174 |
-
audio_sequence_parallel: bool = False,
|
| 175 |
-
audio_vq_layers: int = -1,
|
| 176 |
-
audio_vq_type: str = "NULL",
|
| 177 |
-
audio_vq_codebook_size: int = 4096,
|
| 178 |
-
audio_vq_pe: bool = False,
|
| 179 |
-
audio_vq_commit_loss: float = 0.0,
|
| 180 |
-
audio_vq_out_commit_loss: float = 0.0,
|
| 181 |
-
audio_vq_no_quantize: bool = False,
|
| 182 |
-
audio_vq_ff_layer: int = 0,
|
| 183 |
-
audio_vq_threshold_ema_dead_code: float = 0.1,
|
| 184 |
-
audio_vq_codebook_dim: int = None,
|
| 185 |
-
audio_vq_ds_rate: int = None,
|
| 186 |
-
):
|
| 187 |
-
super().__init__(n_mels, n_ctx, n_state, n_head, n_layer, n_window, output_dim, grad_checkpointing, enable_mp, audio_sequence_parallel)
|
| 188 |
-
|
| 189 |
-
self.audio_vq_layers = audio_vq_layers
|
| 190 |
-
self.audio_vq_type = audio_vq_type
|
| 191 |
-
self.audio_vq_codebook_size = audio_vq_codebook_size
|
| 192 |
-
self.audio_vq_pe = audio_vq_pe
|
| 193 |
-
self.audio_vq_commit_loss = audio_vq_commit_loss
|
| 194 |
-
self.audio_vq_out_commit_loss = audio_vq_out_commit_loss
|
| 195 |
-
self.audio_vq_no_quantize = audio_vq_no_quantize
|
| 196 |
-
self.audio_vq_ff_layer = audio_vq_ff_layer
|
| 197 |
-
|
| 198 |
-
if audio_vq_layers > 0:
|
| 199 |
-
self.vq_feature_dim = self.n_state
|
| 200 |
-
self.audio_vq_ds_rate = 1
|
| 201 |
-
else:
|
| 202 |
-
raise NotImplementedError(f"Unsupported audio_vq_layers: {audio_vq_layers}")
|
| 203 |
-
|
| 204 |
-
if self.audio_vq_ds_rate == audio_vq_ds_rate:
|
| 205 |
-
self.audio_vq_downsample = nn.Identity()
|
| 206 |
-
self.audio_vq_upsample = nn.Identity()
|
| 207 |
-
else:
|
| 208 |
-
assert audio_vq_ds_rate % self.audio_vq_ds_rate == 0
|
| 209 |
-
stride = audio_vq_ds_rate // self.audio_vq_ds_rate
|
| 210 |
-
self.audio_vq_downsample = Conv1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
|
| 211 |
-
self.audio_vq_upsample = ConvTranspose1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
|
| 212 |
-
self.audio_vq_ds_rate = audio_vq_ds_rate
|
| 213 |
-
|
| 214 |
-
if audio_vq_type == "GRVQ":
|
| 215 |
-
self.audio_quantizer = DistributedGroupResidualVectorQuantization(
|
| 216 |
-
codebook_size = audio_vq_codebook_size,
|
| 217 |
-
dim = self.vq_feature_dim,
|
| 218 |
-
codebook_dim = self.vq_codebook_dim if audio_vq_codebook_dim is None else audio_vq_codebook_dim,
|
| 219 |
-
num_groups=1,
|
| 220 |
-
num_quantizers=1,
|
| 221 |
-
kmeans_init=False,
|
| 222 |
-
threshold_ema_dead_code = audio_vq_threshold_ema_dead_code
|
| 223 |
-
)
|
| 224 |
-
else:
|
| 225 |
-
raise NotImplementedError(f"Unsupported audio_vq_type: {audio_vq_type}")
|
| 226 |
-
|
| 227 |
-
if self.audio_vq_pe:
|
| 228 |
-
self.project_after_vq_pe = nn.Linear(self.n_state, self.n_state)
|
| 229 |
-
|
| 230 |
-
def _calc_quantize_activities(self, indices):
|
| 231 |
-
indices_onehot = F.one_hot(indices.long().flatten(), self.audio_vq_codebook_size).sum(dim=0)
|
| 232 |
-
vq_num_activities = sum(indices_onehot>0)
|
| 233 |
-
vq_num_tokens = sum(indices_onehot)
|
| 234 |
-
return {
|
| 235 |
-
"vq_num_activities": vq_num_activities,
|
| 236 |
-
"vq_num_tokens": vq_num_tokens,
|
| 237 |
-
}
|
| 238 |
-
|
| 239 |
-
def _do_quantize(self, x, pe=None, y=None):
|
| 240 |
-
"""
|
| 241 |
-
x: torch.Tensor, shape = (T, D)
|
| 242 |
-
q: torch.Tensor, shape = (T, D)
|
| 243 |
-
i: torch.Tensor, shape = (T)
|
| 244 |
-
"""
|
| 245 |
-
if self.audio_vq_out_commit_loss > 0:
|
| 246 |
-
x_teacher = x.clone()
|
| 247 |
-
x = x.unsqueeze(0)
|
| 248 |
-
|
| 249 |
-
x = self.audio_vq_downsample(x.transpose(1, 2))
|
| 250 |
-
x = x.transpose(1, 2)
|
| 251 |
-
|
| 252 |
-
vq_stats = {}
|
| 253 |
-
|
| 254 |
-
if self.audio_vq_type == "GRVQ":
|
| 255 |
-
if self.training:
|
| 256 |
-
raise NotImplementedError
|
| 257 |
-
else:
|
| 258 |
-
indices = self.audio_quantizer.encode(x)
|
| 259 |
-
x = self.audio_quantizer.decode(indices)
|
| 260 |
-
indices = indices.squeeze(2).squeeze(1)
|
| 261 |
-
|
| 262 |
-
vq_stats.update(self._calc_quantize_activities(indices))
|
| 263 |
-
|
| 264 |
-
x, indices = x.squeeze(0), indices.squeeze(0)
|
| 265 |
-
if self.audio_vq_pe:
|
| 266 |
-
x = x + pe
|
| 267 |
-
x = self.project_after_vq_pe(x)
|
| 268 |
-
|
| 269 |
-
x = self.audio_vq_upsample(x.unsqueeze(0).transpose(1, 2))
|
| 270 |
-
x = x.transpose(1, 2).squeeze(0)
|
| 271 |
-
|
| 272 |
-
if self.audio_vq_out_commit_loss > 0:
|
| 273 |
-
vq_out_commit_loss = F.mse_loss(x_teacher.detach(), x)
|
| 274 |
-
vq_stats["vq_out_commit_loss"] = vq_out_commit_loss * self.audio_vq_out_commit_loss
|
| 275 |
-
|
| 276 |
-
return x, indices, vq_stats
|
| 277 |
-
|
| 278 |
-
def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int], return_indices=False, audio_pitchs=None):
|
| 279 |
-
"""
|
| 280 |
-
x : torch.Tensor, shape = (n_mels, n_ctx)
|
| 281 |
-
the mel spectrogram of the audio
|
| 282 |
-
"""
|
| 283 |
-
|
| 284 |
-
aftercnn_x_list = []
|
| 285 |
-
pe_for_vq_list = []
|
| 286 |
-
for each_x in x_list:
|
| 287 |
-
each_x_split_list = each_x.split(self.n_window * 2, dim=1)
|
| 288 |
-
for each_x_split in each_x_split_list:
|
| 289 |
-
each_x_split = F.gelu(self.conv1(each_x_split))
|
| 290 |
-
each_x_split = F.gelu(self.conv2(each_x_split))
|
| 291 |
-
each_x_split = each_x_split.permute(1, 0) # L,D
|
| 292 |
-
|
| 293 |
-
each_positional_embedding_split = self.positional_embedding[:each_x_split.shape[0]]
|
| 294 |
-
aftercnn_x_list.append(each_x_split+each_positional_embedding_split.to(each_x_split.dtype))
|
| 295 |
-
|
| 296 |
-
pe_for_vq_split = self.positional_embedding[:each_x_split.shape[0] // self.audio_vq_ds_rate]
|
| 297 |
-
pe_for_vq_list.append(pe_for_vq_split.to(each_x_split.dtype))
|
| 298 |
-
|
| 299 |
-
pe_for_vq = torch.cat(pe_for_vq_list, dim=0)
|
| 300 |
-
x = torch.cat(aftercnn_x_list, dim=0)
|
| 301 |
-
src_len = x.size(0)
|
| 302 |
-
|
| 303 |
-
output_list = []
|
| 304 |
-
for item in audio_aftercnnlens:
|
| 305 |
-
while item > self.n_window:
|
| 306 |
-
output_list.append(self.n_window)
|
| 307 |
-
item -= self.n_window
|
| 308 |
-
output_list.append(item)
|
| 309 |
-
|
| 310 |
-
cu_seqlens = list(accumulate(output_list, func=operator.add,initial=0))
|
| 311 |
-
cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32)
|
| 312 |
-
|
| 313 |
-
layer_id = 0
|
| 314 |
-
|
| 315 |
-
for block in self.blocks:
|
| 316 |
-
layer_id+=1
|
| 317 |
-
|
| 318 |
-
x = block(x, cu_seqlens=cu_seqlens)
|
| 319 |
-
|
| 320 |
-
if self.audio_vq_layers == layer_id: # vq inside encoder
|
| 321 |
-
x, indices, vq_stats = self._do_quantize(x, pe_for_vq)
|
| 322 |
-
if return_indices:
|
| 323 |
-
return x, indices
|
| 324 |
-
|
| 325 |
-
if self.avg_pooler:
|
| 326 |
-
x_list = x.split(audio_aftercnnlens, dim=0)
|
| 327 |
-
token_x_list = []
|
| 328 |
-
for x in x_list:
|
| 329 |
-
x = x.permute(1, 0)
|
| 330 |
-
x = self.avg_pooler(x)
|
| 331 |
-
x = x.permute(1, 0)
|
| 332 |
-
token_x_list.append(x)
|
| 333 |
-
x = torch.cat(token_x_list, dim=0)
|
| 334 |
-
|
| 335 |
-
x = self.ln_post(x)
|
| 336 |
-
|
| 337 |
-
x = self.proj(x)
|
| 338 |
-
|
| 339 |
-
output = torch.zeros(
|
| 340 |
-
(x.size(0) + len(audio_seqlens) * 2, x.size(1)),
|
| 341 |
-
device=x.device, dtype=x.dtype
|
| 342 |
-
)
|
| 343 |
-
|
| 344 |
-
audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
|
| 345 |
-
start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
|
| 346 |
-
end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1
|
| 347 |
-
|
| 348 |
-
audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
|
| 349 |
-
audio_tokens_mask[start_ids] = False
|
| 350 |
-
audio_tokens_mask[end_ids] = False
|
| 351 |
-
output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
|
| 352 |
-
output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
|
| 353 |
-
output[audio_tokens_mask] = x
|
| 354 |
-
|
| 355 |
-
if self.audio_vq_type != "NULL":
|
| 356 |
-
return output, vq_stats
|
| 357 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/core/tokenizer_25hz/vq/whisper_encoder.py
DELETED
|
@@ -1,406 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
import os
|
| 17 |
-
import math
|
| 18 |
-
import torch
|
| 19 |
-
import operator
|
| 20 |
-
|
| 21 |
-
import numpy as np
|
| 22 |
-
import torch.nn.functional as F
|
| 23 |
-
|
| 24 |
-
from functools import lru_cache
|
| 25 |
-
from typing import Optional, Union, List
|
| 26 |
-
from torch import nn, Tensor
|
| 27 |
-
from itertools import accumulate
|
| 28 |
-
|
| 29 |
-
try:
|
| 30 |
-
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
|
| 31 |
-
except ImportError:
|
| 32 |
-
try:
|
| 33 |
-
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_varlen_func
|
| 34 |
-
except ImportError:
|
| 35 |
-
print("\n********\nWarning: flash-attn is not installed. Will only run the manual PyTorch version. Please install flash-attn for faster inference.\n********\n ")
|
| 36 |
-
flash_attn_varlen_func = None
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
N_FFT = 400
|
| 40 |
-
HOP_LENGTH = 160
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
@lru_cache(maxsize=None)
|
| 44 |
-
def mel_filters(device, n_mels: int) -> torch.Tensor:
|
| 45 |
-
"""
|
| 46 |
-
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
| 47 |
-
Allows decoupling librosa dependency; saved using:
|
| 48 |
-
|
| 49 |
-
np.savez_compressed(
|
| 50 |
-
"mel_filters.npz",
|
| 51 |
-
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
| 52 |
-
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
|
| 53 |
-
)
|
| 54 |
-
"""
|
| 55 |
-
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
|
| 56 |
-
|
| 57 |
-
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
| 58 |
-
with np.load(filters_path, allow_pickle=False) as f:
|
| 59 |
-
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def log_mel_spectrogram(
|
| 63 |
-
audio: Union[str, np.ndarray, torch.Tensor],
|
| 64 |
-
n_mels: int = 80,
|
| 65 |
-
padding: int = 0,
|
| 66 |
-
device: Optional[Union[str, torch.device]] = None,
|
| 67 |
-
):
|
| 68 |
-
"""
|
| 69 |
-
Compute the log-Mel spectrogram of
|
| 70 |
-
|
| 71 |
-
Parameters
|
| 72 |
-
----------
|
| 73 |
-
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
|
| 74 |
-
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
| 75 |
-
|
| 76 |
-
n_mels: int
|
| 77 |
-
The number of Mel-frequency filters, only 80 is supported
|
| 78 |
-
|
| 79 |
-
padding: int
|
| 80 |
-
Number of zero samples to pad to the right
|
| 81 |
-
|
| 82 |
-
device: Optional[Union[str, torch.device]]
|
| 83 |
-
If given, the audio tensor is moved to this device before STFT
|
| 84 |
-
|
| 85 |
-
Returns
|
| 86 |
-
-------
|
| 87 |
-
torch.Tensor, shape = (80, n_frames)
|
| 88 |
-
A Tensor that contains the Mel spectrogram
|
| 89 |
-
"""
|
| 90 |
-
if not torch.is_tensor(audio):
|
| 91 |
-
audio = torch.from_numpy(audio)
|
| 92 |
-
|
| 93 |
-
if device is not None:
|
| 94 |
-
audio = audio.to(device)
|
| 95 |
-
if padding > 0:
|
| 96 |
-
audio = F.pad(audio, (0, padding))
|
| 97 |
-
window = torch.hann_window(N_FFT).to(audio.device)
|
| 98 |
-
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
| 99 |
-
magnitudes = stft[..., :-1].abs() ** 2
|
| 100 |
-
|
| 101 |
-
filters = mel_filters(audio.device, n_mels)
|
| 102 |
-
mel_spec = filters @ magnitudes
|
| 103 |
-
|
| 104 |
-
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
| 105 |
-
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
| 106 |
-
log_spec = (log_spec + 4.0) / 4.0
|
| 107 |
-
return log_spec
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
def get_T_after_cnn(L_in, dilation=1):
|
| 111 |
-
for (padding, kernel_size, stride) in eval("[(1,3,1)] + [(1,3,2)] "):
|
| 112 |
-
L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
|
| 113 |
-
L_out = 1 + L_out // stride
|
| 114 |
-
L_in = L_out
|
| 115 |
-
return L_out
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def get_mel_audio(audio, padding=False, audio_vq_ds_rate = 1, n_mels = 128):
|
| 119 |
-
audio_len = len(audio)
|
| 120 |
-
if padding:
|
| 121 |
-
reduction = 160 * 2 * audio_vq_ds_rate
|
| 122 |
-
audio_pad = math.ceil(audio_len / reduction) * reduction - audio_len
|
| 123 |
-
mel = log_mel_spectrogram(audio, n_mels=n_mels, padding=audio_pad)
|
| 124 |
-
else:
|
| 125 |
-
mel = log_mel_spectrogram(audio, n_mels=n_mels) # [F,T]
|
| 126 |
-
return mel
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
def sinusoids(length, channels, max_timescale=10000):
|
| 130 |
-
"""Returns sinusoids for positional embedding"""
|
| 131 |
-
assert channels % 2 == 0
|
| 132 |
-
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
| 133 |
-
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
| 134 |
-
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
| 135 |
-
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
class Conv1d(nn.Conv1d):
|
| 139 |
-
def _conv_forward(
|
| 140 |
-
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
| 141 |
-
) -> Tensor:
|
| 142 |
-
return super()._conv_forward(
|
| 143 |
-
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
| 144 |
-
)
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
class ConvTranspose1d(nn.ConvTranspose1d):
|
| 148 |
-
def _conv_forward(
|
| 149 |
-
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
| 150 |
-
) -> Tensor:
|
| 151 |
-
return super()._conv_forward(
|
| 152 |
-
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
| 153 |
-
)
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
class Linear(nn.Linear):
|
| 157 |
-
def forward(self, x: Tensor) -> Tensor:
|
| 158 |
-
return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype) )
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
class MultiHeadAttention(nn.Module):
|
| 162 |
-
def __init__(self, n_state: int, n_head: int):
|
| 163 |
-
super().__init__()
|
| 164 |
-
self.n_head = n_head
|
| 165 |
-
self.query = Linear(n_state, n_state)
|
| 166 |
-
self.key = Linear(n_state, n_state, bias=False)
|
| 167 |
-
self.value = Linear(n_state, n_state)
|
| 168 |
-
self.out = Linear(n_state, n_state)
|
| 169 |
-
|
| 170 |
-
self.use_flash_attention = True
|
| 171 |
-
|
| 172 |
-
def forward(
|
| 173 |
-
self,
|
| 174 |
-
x: Tensor,
|
| 175 |
-
cu_seqlens = None,
|
| 176 |
-
):
|
| 177 |
-
q = self.query(x)
|
| 178 |
-
k = self.key(x)
|
| 179 |
-
v = self.value(x)
|
| 180 |
-
|
| 181 |
-
if self.use_flash_attention:
|
| 182 |
-
if flash_attn_varlen_func is None:
|
| 183 |
-
x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
|
| 184 |
-
else:
|
| 185 |
-
if q.dtype not in [torch.float16, torch.bfloat16]:
|
| 186 |
-
x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
|
| 187 |
-
self.use_flash_attention = False
|
| 188 |
-
else:
|
| 189 |
-
x = self.qkv_flash_attention(q, k, v, cu_seqlens=cu_seqlens)
|
| 190 |
-
else:
|
| 191 |
-
x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
|
| 192 |
-
|
| 193 |
-
output = self.out(x)
|
| 194 |
-
return output
|
| 195 |
-
|
| 196 |
-
def qkv_flash_attention(
|
| 197 |
-
self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens=None
|
| 198 |
-
):
|
| 199 |
-
n_ctx, n_state = q.shape
|
| 200 |
-
# scale = (n_state // self.n_head) ** -0.25
|
| 201 |
-
q = q.view(n_ctx, self.n_head, -1)# (batch_size, seqlen, nheads, headdim)
|
| 202 |
-
k = k.view(n_ctx, self.n_head, -1)
|
| 203 |
-
v = v.view(n_ctx, self.n_head, -1)
|
| 204 |
-
|
| 205 |
-
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
x = flash_attn_varlen_func(
|
| 209 |
-
q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
|
| 210 |
-
)
|
| 211 |
-
x = x.reshape(n_ctx, n_state)
|
| 212 |
-
return x
|
| 213 |
-
|
| 214 |
-
def qkv_attention_manual(
|
| 215 |
-
self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
|
| 216 |
-
):
|
| 217 |
-
n_ctx, n_state = q.shape
|
| 218 |
-
head_dim = n_state // self.n_head
|
| 219 |
-
scale = head_dim ** -0.5
|
| 220 |
-
|
| 221 |
-
q = q.view(n_ctx, self.n_head, head_dim)
|
| 222 |
-
k = k.view(n_ctx, self.n_head, head_dim)
|
| 223 |
-
v = v.view(n_ctx, self.n_head, head_dim)
|
| 224 |
-
|
| 225 |
-
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
| 226 |
-
batch_size = len(seqlens)
|
| 227 |
-
max_seqlen = max(seqlens)
|
| 228 |
-
|
| 229 |
-
q_padded = torch.zeros(batch_size, max_seqlen, self.n_head, head_dim, dtype=q.dtype, device=q.device)
|
| 230 |
-
k_padded = torch.zeros_like(q_padded)
|
| 231 |
-
v_padded = torch.zeros_like(q_padded)
|
| 232 |
-
|
| 233 |
-
for i in range(batch_size):
|
| 234 |
-
start_idx = cu_seqlens[i]
|
| 235 |
-
end_idx = cu_seqlens[i+1]
|
| 236 |
-
seq_len = seqlens[i]
|
| 237 |
-
q_padded[i, :seq_len] = q[start_idx:end_idx]
|
| 238 |
-
k_padded[i, :seq_len] = k[start_idx:end_idx]
|
| 239 |
-
v_padded[i, :seq_len] = v[start_idx:end_idx]
|
| 240 |
-
|
| 241 |
-
q_padded = q_padded.transpose(1, 2)
|
| 242 |
-
k_padded = k_padded.transpose(1, 2)
|
| 243 |
-
v_padded = v_padded.transpose(1, 2)
|
| 244 |
-
|
| 245 |
-
attn_mask = torch.arange(max_seqlen, device=q.device)[None, :] < torch.tensor(seqlens, device=q.device)[:, None]
|
| 246 |
-
attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
|
| 247 |
-
|
| 248 |
-
attn_mask = attn_mask.masked_fill(attn_mask == 0, -torch.finfo(q.dtype).max)
|
| 249 |
-
|
| 250 |
-
attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale
|
| 251 |
-
attn_scores = attn_scores + attn_mask
|
| 252 |
-
attn_weights = F.softmax(attn_scores, dim=-1)
|
| 253 |
-
|
| 254 |
-
context = torch.matmul(attn_weights, v_padded)
|
| 255 |
-
|
| 256 |
-
context = context.transpose(1, 2).contiguous().view(batch_size, max_seqlen, n_state)
|
| 257 |
-
|
| 258 |
-
output_packed = torch.cat([context[i, :seqlens[i]] for i in range(batch_size)], dim=0)
|
| 259 |
-
|
| 260 |
-
assert output_packed.shape == (n_ctx, n_state)
|
| 261 |
-
|
| 262 |
-
return output_packed
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
class ResidualAttentionBlock(nn.Module):
|
| 266 |
-
def __init__(self, n_state: int, n_head: int,
|
| 267 |
-
enable_mp: bool = False, sequence_parallel: bool = False):
|
| 268 |
-
super().__init__()
|
| 269 |
-
n_mlp = n_state * 4
|
| 270 |
-
self.attn_ln = nn.LayerNorm(n_state)
|
| 271 |
-
self.mlp_ln = nn.LayerNorm(n_state)
|
| 272 |
-
|
| 273 |
-
self.attn = MultiHeadAttention(n_state, n_head)
|
| 274 |
-
self.mlp = nn.Sequential(
|
| 275 |
-
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
|
| 276 |
-
)
|
| 277 |
-
|
| 278 |
-
def forward(
|
| 279 |
-
self,
|
| 280 |
-
x: Tensor,
|
| 281 |
-
cu_seqlens = None
|
| 282 |
-
):
|
| 283 |
-
x = x + self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens)
|
| 284 |
-
x = x + self.mlp(self.mlp_ln(x))
|
| 285 |
-
return x
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
class WhisperEncoder(nn.Module):
|
| 289 |
-
def __init__(
|
| 290 |
-
self,
|
| 291 |
-
n_mels: int,
|
| 292 |
-
n_ctx: int,
|
| 293 |
-
n_state: int,
|
| 294 |
-
n_head: int,
|
| 295 |
-
n_layer: int,
|
| 296 |
-
n_window: int = 1500,
|
| 297 |
-
output_dim: int = 512,
|
| 298 |
-
grad_checkpointing: bool = False,
|
| 299 |
-
enable_mp: bool = False,
|
| 300 |
-
audio_sequence_parallel: bool = False,
|
| 301 |
-
):
|
| 302 |
-
super().__init__()
|
| 303 |
-
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
| 304 |
-
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
| 305 |
-
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
| 306 |
-
self.n_layer = n_layer
|
| 307 |
-
self.n_mels = n_mels
|
| 308 |
-
|
| 309 |
-
self.blocks = nn.ModuleList(
|
| 310 |
-
[ResidualAttentionBlock(n_state, n_head, enable_mp=enable_mp, sequence_parallel=audio_sequence_parallel)
|
| 311 |
-
for _ in range(n_layer)]
|
| 312 |
-
)
|
| 313 |
-
self.ln_post = nn.LayerNorm(n_state)
|
| 314 |
-
self.avg_pooler = nn.AvgPool1d(2, stride=2)
|
| 315 |
-
|
| 316 |
-
self.proj = torch.nn.Linear(n_state, output_dim)
|
| 317 |
-
|
| 318 |
-
self.audio_bos_eos_token = nn.Embedding(2, output_dim)
|
| 319 |
-
|
| 320 |
-
self.output_dim = output_dim
|
| 321 |
-
self.grad_checkpointing = grad_checkpointing
|
| 322 |
-
self.enable_mp = enable_mp
|
| 323 |
-
self.n_head = n_head
|
| 324 |
-
self.n_state = n_state
|
| 325 |
-
self.n_window = n_window
|
| 326 |
-
|
| 327 |
-
self.audio_sequence_parallel = audio_sequence_parallel
|
| 328 |
-
|
| 329 |
-
self.tp_world_size = 1
|
| 330 |
-
|
| 331 |
-
self.set_audio_sync()
|
| 332 |
-
|
| 333 |
-
def set_audio_sync(self):
|
| 334 |
-
for name, param in self.named_parameters():
|
| 335 |
-
if not name.startswith("blocks"):
|
| 336 |
-
setattr(param, "audio_sync", True)
|
| 337 |
-
|
| 338 |
-
def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int]):
|
| 339 |
-
"""
|
| 340 |
-
x : torch.Tensor, shape = (n_mels, n_ctx)
|
| 341 |
-
the mel spectrogram of the audio
|
| 342 |
-
"""
|
| 343 |
-
|
| 344 |
-
aftercnn_x_list = []
|
| 345 |
-
for each_x in x_list:
|
| 346 |
-
each_x_split_list = each_x.split(self.n_window * 2, dim=1)
|
| 347 |
-
for each_x_split in each_x_split_list:
|
| 348 |
-
each_x_split = F.gelu(self.conv1(each_x_split))
|
| 349 |
-
each_x_split = F.gelu(self.conv2(each_x_split))
|
| 350 |
-
each_x_split = each_x_split.permute(1, 0) # L,D
|
| 351 |
-
each_positional_embedding_split = self.positional_embedding[:each_x_split.shape[0]]
|
| 352 |
-
aftercnn_x_list.append(each_x_split+each_positional_embedding_split.to(each_x_split.dtype))
|
| 353 |
-
|
| 354 |
-
x = torch.cat(aftercnn_x_list, dim=0)
|
| 355 |
-
src_len = x.size(0)
|
| 356 |
-
|
| 357 |
-
output_list = []
|
| 358 |
-
for item in audio_aftercnnlens:
|
| 359 |
-
while item > self.n_window:
|
| 360 |
-
output_list.append(self.n_window)
|
| 361 |
-
item -= self.n_window
|
| 362 |
-
output_list.append(item)
|
| 363 |
-
|
| 364 |
-
cu_seqlens = list(accumulate(output_list, func=operator.add,initial=0))
|
| 365 |
-
cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32)
|
| 366 |
-
|
| 367 |
-
layer_id = 0
|
| 368 |
-
for block in self.blocks:
|
| 369 |
-
layer_id+=1
|
| 370 |
-
x = block(x, cu_seqlens=cu_seqlens)
|
| 371 |
-
|
| 372 |
-
if self.avg_pooler:
|
| 373 |
-
x_list = x.split(audio_aftercnnlens, dim=0)
|
| 374 |
-
token_x_list = []
|
| 375 |
-
for x in x_list:
|
| 376 |
-
x = x.permute(1, 0)
|
| 377 |
-
x = self.avg_pooler(x)
|
| 378 |
-
x = x.permute(1, 0)
|
| 379 |
-
token_x_list.append(x)
|
| 380 |
-
x = torch.cat(token_x_list, dim=0)
|
| 381 |
-
|
| 382 |
-
x = self.ln_post(x)
|
| 383 |
-
x = self.proj(x)
|
| 384 |
-
|
| 385 |
-
output = torch.zeros(
|
| 386 |
-
(x.size(0) + len(audio_seqlens) * 2, x.size(1)),
|
| 387 |
-
device=x.device, dtype=x.dtype
|
| 388 |
-
)
|
| 389 |
-
|
| 390 |
-
audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
|
| 391 |
-
start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
|
| 392 |
-
end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1
|
| 393 |
-
|
| 394 |
-
audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
|
| 395 |
-
audio_tokens_mask[start_ids] = False
|
| 396 |
-
audio_tokens_mask[end_ids] = False
|
| 397 |
-
output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
|
| 398 |
-
output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
|
| 399 |
-
output[audio_tokens_mask] = x
|
| 400 |
-
return output
|
| 401 |
-
|
| 402 |
-
def lock(self, layers: int):
|
| 403 |
-
self.conv1.requires_grad_(False)
|
| 404 |
-
self.conv2.requires_grad_(False)
|
| 405 |
-
for i in range(min(layers, len(self.blocks))):
|
| 406 |
-
self.blocks[i].requires_grad_(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/inference/qwen3_tts_model.py
DELETED
|
@@ -1,874 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
import base64
|
| 17 |
-
import io
|
| 18 |
-
import urllib.request
|
| 19 |
-
from dataclasses import dataclass
|
| 20 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 21 |
-
from urllib.parse import urlparse
|
| 22 |
-
|
| 23 |
-
import librosa
|
| 24 |
-
import numpy as np
|
| 25 |
-
import soundfile as sf
|
| 26 |
-
import torch
|
| 27 |
-
from transformers import AutoConfig, AutoModel, AutoProcessor
|
| 28 |
-
|
| 29 |
-
from ..core.models import Qwen3TTSConfig, Qwen3TTSForConditionalGeneration, Qwen3TTSProcessor
|
| 30 |
-
|
| 31 |
-
AudioLike = Union[
|
| 32 |
-
str, # wav path, URL, base64
|
| 33 |
-
np.ndarray, # waveform (requires sr)
|
| 34 |
-
Tuple[np.ndarray, int], # (waveform, sr)
|
| 35 |
-
]
|
| 36 |
-
|
| 37 |
-
MaybeList = Union[Any, List[Any]]
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
@dataclass
|
| 41 |
-
class VoiceClonePromptItem:
|
| 42 |
-
"""
|
| 43 |
-
Container for one sample's voice-clone prompt information that can be fed to the model.
|
| 44 |
-
|
| 45 |
-
Fields are aligned with `Qwen3TTSForConditionalGeneration.generate(..., voice_clone_prompt=...)`.
|
| 46 |
-
"""
|
| 47 |
-
ref_code: Optional[torch.Tensor] # (T, Q) or (T,) depending on tokenizer 25Hz/12Hz
|
| 48 |
-
ref_spk_embedding: torch.Tensor # (D,)
|
| 49 |
-
x_vector_only_mode: bool
|
| 50 |
-
icl_mode: bool
|
| 51 |
-
ref_text: Optional[str] = None
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
class Qwen3TTSModel:
|
| 55 |
-
"""
|
| 56 |
-
A HuggingFace-style wrapper for Qwen3 TTS models (CustomVoice/VoiceDesign/Base) that provides:
|
| 57 |
-
- from_pretrained() initialization via AutoModel/AutoProcessor
|
| 58 |
-
- generation APIs for:
|
| 59 |
-
* CustomVoice: generate_custom_voice()
|
| 60 |
-
* VoiceDesign: generate_voice_design()
|
| 61 |
-
* Base: generate_voice_clone() + create_voice_clone_prompt()
|
| 62 |
-
- consistent output: (wavs: List[np.ndarray], sample_rate: int)
|
| 63 |
-
|
| 64 |
-
Notes:
|
| 65 |
-
- This wrapper expects the underlying model class to be `Qwen3TTSForConditionalGeneration`
|
| 66 |
-
- Language / speaker validation is done via model methods:
|
| 67 |
-
model.get_supported_languages(), model.get_supported_speakers()
|
| 68 |
-
"""
|
| 69 |
-
|
| 70 |
-
def __init__(self, model: Qwen3TTSForConditionalGeneration, processor, generate_defaults: Optional[Dict[str, Any]] = None):
|
| 71 |
-
self.model = model
|
| 72 |
-
self.processor = processor
|
| 73 |
-
self.generate_defaults = generate_defaults or {}
|
| 74 |
-
|
| 75 |
-
self.device = getattr(model, "device", None)
|
| 76 |
-
if self.device is None:
|
| 77 |
-
try:
|
| 78 |
-
self.device = next(model.parameters()).device
|
| 79 |
-
except StopIteration:
|
| 80 |
-
self.device = torch.device("cpu")
|
| 81 |
-
|
| 82 |
-
@classmethod
|
| 83 |
-
def from_pretrained(
|
| 84 |
-
cls,
|
| 85 |
-
pretrained_model_name_or_path: str,
|
| 86 |
-
**kwargs,
|
| 87 |
-
) -> "Qwen3TTSModel":
|
| 88 |
-
"""
|
| 89 |
-
Load a Qwen3 TTS model and its processor in HuggingFace `from_pretrained` style.
|
| 90 |
-
|
| 91 |
-
This method:
|
| 92 |
-
1) Loads config via AutoConfig (so your side can register model_type -> config/model).
|
| 93 |
-
2) Loads the model via AutoModel.from_pretrained(...), forwarding `kwargs` unchanged.
|
| 94 |
-
3) Loads the processor via AutoProcessor.from_pretrained(model_path).
|
| 95 |
-
4) Loads optional `generate_config.json` from the model directory/repo snapshot if present.
|
| 96 |
-
|
| 97 |
-
Args:
|
| 98 |
-
pretrained_model_name_or_path (str):
|
| 99 |
-
HuggingFace repo id or local directory of the model.
|
| 100 |
-
**kwargs:
|
| 101 |
-
Forwarded as-is into `AutoModel.from_pretrained(...)`.
|
| 102 |
-
Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="flash_attention_2".
|
| 103 |
-
|
| 104 |
-
Returns:
|
| 105 |
-
Qwen3TTSModel:
|
| 106 |
-
Wrapper instance containing `model`, `processor`, and generation defaults.
|
| 107 |
-
"""
|
| 108 |
-
AutoConfig.register("qwen3_tts", Qwen3TTSConfig)
|
| 109 |
-
AutoModel.register(Qwen3TTSConfig, Qwen3TTSForConditionalGeneration)
|
| 110 |
-
AutoProcessor.register(Qwen3TTSConfig, Qwen3TTSProcessor)
|
| 111 |
-
|
| 112 |
-
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 113 |
-
if not isinstance(model, Qwen3TTSForConditionalGeneration):
|
| 114 |
-
raise TypeError(
|
| 115 |
-
f"AutoModel returned {type(model)}, expected Qwen3TTSForConditionalGeneration. "
|
| 116 |
-
)
|
| 117 |
-
|
| 118 |
-
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True,)
|
| 119 |
-
|
| 120 |
-
generate_defaults = model.generate_config
|
| 121 |
-
return cls(model=model, processor=processor, generate_defaults=generate_defaults)
|
| 122 |
-
|
| 123 |
-
def _supported_languages_set(self) -> Optional[set]:
|
| 124 |
-
langs = getattr(self.model, "get_supported_languages", None)
|
| 125 |
-
if callable(langs):
|
| 126 |
-
v = langs()
|
| 127 |
-
if v is None:
|
| 128 |
-
return None
|
| 129 |
-
return set([str(x).lower() for x in v])
|
| 130 |
-
return None
|
| 131 |
-
|
| 132 |
-
def _supported_speakers_set(self) -> Optional[set]:
|
| 133 |
-
spks = getattr(self.model, "get_supported_speakers", None)
|
| 134 |
-
if callable(spks):
|
| 135 |
-
v = spks()
|
| 136 |
-
if v is None:
|
| 137 |
-
return None
|
| 138 |
-
return set([str(x).lower() for x in v])
|
| 139 |
-
return None
|
| 140 |
-
|
| 141 |
-
def _validate_languages(self, languages: List[str]) -> None:
|
| 142 |
-
"""
|
| 143 |
-
Validate that requested languages are supported by the model.
|
| 144 |
-
|
| 145 |
-
Args:
|
| 146 |
-
languages (List[str]): Language names for each sample.
|
| 147 |
-
|
| 148 |
-
Raises:
|
| 149 |
-
ValueError: If any language is not supported.
|
| 150 |
-
"""
|
| 151 |
-
supported = self._supported_languages_set()
|
| 152 |
-
if supported is None:
|
| 153 |
-
return
|
| 154 |
-
|
| 155 |
-
bad = []
|
| 156 |
-
for lang in languages:
|
| 157 |
-
if lang is None:
|
| 158 |
-
bad.append(lang)
|
| 159 |
-
continue
|
| 160 |
-
if str(lang).lower() not in supported:
|
| 161 |
-
bad.append(lang)
|
| 162 |
-
if bad:
|
| 163 |
-
raise ValueError(f"Unsupported languages: {bad}. Supported: {sorted(supported)}")
|
| 164 |
-
|
| 165 |
-
def _validate_speakers(self, speakers: List[Optional[str]]) -> None:
|
| 166 |
-
"""
|
| 167 |
-
Validate that requested speakers are supported by the Instruct model.
|
| 168 |
-
|
| 169 |
-
Args:
|
| 170 |
-
speakers (List[Optional[str]]): Speaker names for each sample.
|
| 171 |
-
|
| 172 |
-
Raises:
|
| 173 |
-
ValueError: If any speaker is not supported.
|
| 174 |
-
"""
|
| 175 |
-
supported = self._supported_speakers_set()
|
| 176 |
-
if supported is None:
|
| 177 |
-
return
|
| 178 |
-
|
| 179 |
-
bad = []
|
| 180 |
-
for spk in speakers:
|
| 181 |
-
if spk is None or spk == "":
|
| 182 |
-
continue
|
| 183 |
-
if str(spk).lower() not in supported:
|
| 184 |
-
bad.append(spk)
|
| 185 |
-
if bad:
|
| 186 |
-
raise ValueError(f"Unsupported speakers: {bad}. Supported: {sorted(supported)}")
|
| 187 |
-
|
| 188 |
-
def _is_probably_base64(self, s: str) -> bool:
|
| 189 |
-
if s.startswith("data:audio"):
|
| 190 |
-
return True
|
| 191 |
-
if ("/" not in s and "\\" not in s) and len(s) > 256:
|
| 192 |
-
return True
|
| 193 |
-
return False
|
| 194 |
-
|
| 195 |
-
def _is_url(self, s: str) -> bool:
|
| 196 |
-
try:
|
| 197 |
-
u = urlparse(s)
|
| 198 |
-
return u.scheme in ("http", "https") and bool(u.netloc)
|
| 199 |
-
except Exception:
|
| 200 |
-
return False
|
| 201 |
-
|
| 202 |
-
def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
|
| 203 |
-
if "," in b64 and b64.strip().startswith("data:"):
|
| 204 |
-
b64 = b64.split(",", 1)[1]
|
| 205 |
-
return base64.b64decode(b64)
|
| 206 |
-
|
| 207 |
-
def _load_audio_to_np(self, x: str) -> Tuple[np.ndarray, int]:
|
| 208 |
-
if self._is_url(x):
|
| 209 |
-
with urllib.request.urlopen(x) as resp:
|
| 210 |
-
audio_bytes = resp.read()
|
| 211 |
-
with io.BytesIO(audio_bytes) as f:
|
| 212 |
-
audio, sr = sf.read(f, dtype="float32", always_2d=False)
|
| 213 |
-
elif self._is_probably_base64(x):
|
| 214 |
-
wav_bytes = self._decode_base64_to_wav_bytes(x)
|
| 215 |
-
with io.BytesIO(wav_bytes) as f:
|
| 216 |
-
audio, sr = sf.read(f, dtype="float32", always_2d=False)
|
| 217 |
-
else:
|
| 218 |
-
audio, sr = librosa.load(x, sr=None, mono=True)
|
| 219 |
-
|
| 220 |
-
if audio.ndim > 1:
|
| 221 |
-
audio = np.mean(audio, axis=-1)
|
| 222 |
-
|
| 223 |
-
return audio.astype(np.float32), int(sr)
|
| 224 |
-
|
| 225 |
-
def _normalize_audio_inputs(self, audios: Union[AudioLike, List[AudioLike]]) -> List[Tuple[np.ndarray, int]]:
|
| 226 |
-
"""
|
| 227 |
-
Normalize audio inputs into a list of (waveform, sr).
|
| 228 |
-
|
| 229 |
-
Supported forms:
|
| 230 |
-
- str: wav path / URL / base64 audio string
|
| 231 |
-
- (np.ndarray, sr): waveform + sampling rate
|
| 232 |
-
- list of the above
|
| 233 |
-
|
| 234 |
-
Args:
|
| 235 |
-
audios:
|
| 236 |
-
Audio input(s).
|
| 237 |
-
|
| 238 |
-
Returns:
|
| 239 |
-
List[Tuple[np.ndarray, int]]:
|
| 240 |
-
List of (float32 waveform, original sr).
|
| 241 |
-
|
| 242 |
-
Raises:
|
| 243 |
-
ValueError: If a numpy waveform is provided without sr.
|
| 244 |
-
"""
|
| 245 |
-
if isinstance(audios, list):
|
| 246 |
-
items = audios
|
| 247 |
-
else:
|
| 248 |
-
items = [audios]
|
| 249 |
-
|
| 250 |
-
out: List[Tuple[np.ndarray, int]] = []
|
| 251 |
-
for a in items:
|
| 252 |
-
if isinstance(a, str):
|
| 253 |
-
out.append(self._load_audio_to_np(a))
|
| 254 |
-
elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
|
| 255 |
-
out.append((a[0].astype(np.float32), int(a[1])))
|
| 256 |
-
elif isinstance(a, np.ndarray):
|
| 257 |
-
raise ValueError("For numpy waveform input, pass a tuple (audio, sr).")
|
| 258 |
-
else:
|
| 259 |
-
raise TypeError(f"Unsupported audio input type: {type(a)}")
|
| 260 |
-
for i, a in enumerate(out):
|
| 261 |
-
if a[0].ndim > 1:
|
| 262 |
-
a[0] = np.mean(a[0], axis=-1).astype(np.float32)
|
| 263 |
-
out[i] = (a[0], a[1])
|
| 264 |
-
return out
|
| 265 |
-
|
| 266 |
-
def _ensure_list(self, x: MaybeList) -> List[Any]:
|
| 267 |
-
return x if isinstance(x, list) else [x]
|
| 268 |
-
|
| 269 |
-
def _build_assistant_text(self, text: str) -> str:
|
| 270 |
-
return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
|
| 271 |
-
|
| 272 |
-
def _build_ref_text(self, text: str) -> str:
|
| 273 |
-
return f"<|im_start|>assistant\n{text}<|im_end|>\n"
|
| 274 |
-
|
| 275 |
-
def _build_instruct_text(self, instruct: str) -> str:
|
| 276 |
-
return f"<|im_start|>user\n{instruct}<|im_end|>\n"
|
| 277 |
-
|
| 278 |
-
def _tokenize_texts(self, texts: List[str]) -> List[torch.Tensor]:
|
| 279 |
-
input_ids = []
|
| 280 |
-
for text in texts:
|
| 281 |
-
input = self.processor(text=text, return_tensors="pt", padding=True)
|
| 282 |
-
input_id = input["input_ids"].to(self.device)
|
| 283 |
-
input_id = input_id.unsqueeze(0) if input_id.dim() == 1 else input_id
|
| 284 |
-
input_ids.append(input_id)
|
| 285 |
-
return input_ids
|
| 286 |
-
|
| 287 |
-
def _merge_generate_kwargs(
|
| 288 |
-
self,
|
| 289 |
-
non_streaming_mode: Optional[bool] = None,
|
| 290 |
-
do_sample: Optional[bool] = None,
|
| 291 |
-
top_k: Optional[int] = None,
|
| 292 |
-
top_p: Optional[float] = None,
|
| 293 |
-
temperature: Optional[float] = None,
|
| 294 |
-
repetition_penalty: Optional[float] = None,
|
| 295 |
-
subtalker_dosample: Optional[bool] = None,
|
| 296 |
-
subtalker_top_k: Optional[int] = None,
|
| 297 |
-
subtalker_top_p: Optional[float] = None,
|
| 298 |
-
subtalker_temperature: Optional[float] = None,
|
| 299 |
-
max_new_tokens: Optional[int] = None,
|
| 300 |
-
**kwargs,
|
| 301 |
-
) -> Dict[str, Any]:
|
| 302 |
-
"""
|
| 303 |
-
Merge user-provided generation arguments with defaults from `generate_config.json`.
|
| 304 |
-
|
| 305 |
-
Rule:
|
| 306 |
-
- If the user explicitly passes a value (not None), use it.
|
| 307 |
-
- Otherwise, use the value from generate_config.json if present.
|
| 308 |
-
- Otherwise, fall back to the hard defaults.
|
| 309 |
-
|
| 310 |
-
Args:
|
| 311 |
-
non_streaming_mode, do_sample, top_k, top_p, temperature, repetition_penalty,
|
| 312 |
-
subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature, max_new_tokens:
|
| 313 |
-
Common generation parameters.
|
| 314 |
-
**kwargs:
|
| 315 |
-
Other arguments forwarded to model.generate().
|
| 316 |
-
|
| 317 |
-
Returns:
|
| 318 |
-
Dict[str, Any]: Final kwargs to pass into model.generate().
|
| 319 |
-
"""
|
| 320 |
-
hard_defaults = dict(
|
| 321 |
-
non_streaming_mode=False,
|
| 322 |
-
do_sample=True,
|
| 323 |
-
top_k=50,
|
| 324 |
-
top_p=1.0,
|
| 325 |
-
temperature=0.9,
|
| 326 |
-
repetition_penalty=1.05,
|
| 327 |
-
subtalker_dosample=True,
|
| 328 |
-
subtalker_top_k=50,
|
| 329 |
-
subtalker_top_p=1.0,
|
| 330 |
-
subtalker_temperature=0.9,
|
| 331 |
-
max_new_tokens=2048,
|
| 332 |
-
)
|
| 333 |
-
|
| 334 |
-
def pick(name: str, user_val: Any) -> Any:
|
| 335 |
-
if user_val is not None:
|
| 336 |
-
return user_val
|
| 337 |
-
if name in self.generate_defaults:
|
| 338 |
-
return self.generate_defaults[name]
|
| 339 |
-
return hard_defaults[name]
|
| 340 |
-
|
| 341 |
-
merged = dict(kwargs)
|
| 342 |
-
merged.update(
|
| 343 |
-
non_streaming_mode=pick("non_streaming_mode", non_streaming_mode),
|
| 344 |
-
do_sample=pick("do_sample", do_sample),
|
| 345 |
-
top_k=pick("top_k", top_k),
|
| 346 |
-
top_p=pick("top_p", top_p),
|
| 347 |
-
temperature=pick("temperature", temperature),
|
| 348 |
-
repetition_penalty=pick("repetition_penalty", repetition_penalty),
|
| 349 |
-
subtalker_dosample=pick("subtalker_dosample", subtalker_dosample),
|
| 350 |
-
subtalker_top_k=pick("subtalker_top_k", subtalker_top_k),
|
| 351 |
-
subtalker_top_p=pick("subtalker_top_p", subtalker_top_p),
|
| 352 |
-
subtalker_temperature=pick("subtalker_temperature", subtalker_temperature),
|
| 353 |
-
max_new_tokens=pick("max_new_tokens", max_new_tokens),
|
| 354 |
-
)
|
| 355 |
-
return merged
|
| 356 |
-
|
| 357 |
-
# voice clone model
|
| 358 |
-
@torch.inference_mode()
|
| 359 |
-
def create_voice_clone_prompt(
|
| 360 |
-
self,
|
| 361 |
-
ref_audio: Union[AudioLike, List[AudioLike]],
|
| 362 |
-
ref_text: Optional[Union[str, List[Optional[str]]]] = None,
|
| 363 |
-
x_vector_only_mode: Union[bool, List[bool]] = False,
|
| 364 |
-
) -> List[VoiceClonePromptItem]:
|
| 365 |
-
"""
|
| 366 |
-
Build voice-clone prompt items from reference audio (and optionally reference text) using Base model.
|
| 367 |
-
|
| 368 |
-
Modes:
|
| 369 |
-
- x_vector_only_mode=True:
|
| 370 |
-
Only speaker embedding is used to clone voice; ref_text/ref_code are ignored.
|
| 371 |
-
This is mutually exclusive with ICL.
|
| 372 |
-
- x_vector_only_mode=False:
|
| 373 |
-
ICL mode is enabled automatically (icl_mode=True). In this case ref_text is required,
|
| 374 |
-
because the model continues/conditions on the reference text + reference speech codes.
|
| 375 |
-
|
| 376 |
-
Batch behavior:
|
| 377 |
-
- ref_audio can be a single item or a list.
|
| 378 |
-
- ref_text and x_vector_only_mode can be scalars or lists.
|
| 379 |
-
- If any of them are lists with length > 1, lengths must match.
|
| 380 |
-
|
| 381 |
-
Audio input:
|
| 382 |
-
- str: local wav path / URL / base64
|
| 383 |
-
- (np.ndarray, sr): waveform + sampling rate
|
| 384 |
-
|
| 385 |
-
Args:
|
| 386 |
-
ref_audio:
|
| 387 |
-
Reference audio(s) used to extract:
|
| 388 |
-
- ref_code via `model.speech_tokenizer.encode(...)`
|
| 389 |
-
- ref_spk_embedding via `model.extract_speaker_embedding(...)` (resampled to 24k)
|
| 390 |
-
ref_text:
|
| 391 |
-
Reference transcript(s). Required when x_vector_only_mode=False (ICL mode).
|
| 392 |
-
x_vector_only_mode:
|
| 393 |
-
Whether to use speaker embedding only. If False, ICL mode will be used.
|
| 394 |
-
|
| 395 |
-
Returns:
|
| 396 |
-
List[VoiceClonePromptItem]:
|
| 397 |
-
List of prompt items that can be converted into `voice_clone_prompt` dict.
|
| 398 |
-
|
| 399 |
-
Raises:
|
| 400 |
-
ValueError:
|
| 401 |
-
- If x_vector_only_mode=False but ref_text is missing.
|
| 402 |
-
- If batch lengths mismatch.
|
| 403 |
-
"""
|
| 404 |
-
if self.model.tts_model_type != "base":
|
| 405 |
-
raise ValueError(
|
| 406 |
-
f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
|
| 407 |
-
f"tts_model_size: {self.model.tts_model_size}\n"
|
| 408 |
-
f"tts_model_type: {self.model.tts_model_type}\n"
|
| 409 |
-
"does not support create_voice_clone_prompt, Please check Model Card or Readme for more details."
|
| 410 |
-
)
|
| 411 |
-
|
| 412 |
-
ref_audio_list = self._ensure_list(ref_audio)
|
| 413 |
-
ref_text_list = self._ensure_list(ref_text) if isinstance(ref_text, list) else ([ref_text] * len(ref_audio_list))
|
| 414 |
-
xvec_list = self._ensure_list(x_vector_only_mode) if isinstance(x_vector_only_mode, list) else ([x_vector_only_mode] * len(ref_audio_list))
|
| 415 |
-
|
| 416 |
-
if len(ref_text_list) != len(ref_audio_list) or len(xvec_list) != len(ref_audio_list):
|
| 417 |
-
raise ValueError(
|
| 418 |
-
f"Batch size mismatch: ref_audio={len(ref_audio_list)}, ref_text={len(ref_text_list)}, x_vector_only_mode={len(xvec_list)}"
|
| 419 |
-
)
|
| 420 |
-
|
| 421 |
-
normalized = self._normalize_audio_inputs(ref_audio_list)
|
| 422 |
-
|
| 423 |
-
ref_wavs_for_code: List[np.ndarray] = []
|
| 424 |
-
ref_sr_for_code: List[int] = []
|
| 425 |
-
for wav, sr in normalized:
|
| 426 |
-
ref_wavs_for_code.append(wav)
|
| 427 |
-
ref_sr_for_code.append(sr)
|
| 428 |
-
|
| 429 |
-
if len(set(ref_sr_for_code)) == 1:
|
| 430 |
-
enc = self.model.speech_tokenizer.encode(ref_wavs_for_code, sr=ref_sr_for_code[0])
|
| 431 |
-
ref_codes = enc.audio_codes
|
| 432 |
-
else:
|
| 433 |
-
ref_codes = []
|
| 434 |
-
for wav, sr in normalized:
|
| 435 |
-
ref_codes.append(self.model.speech_tokenizer.encode(wav, sr=sr).audio_codes[0])
|
| 436 |
-
|
| 437 |
-
items: List[VoiceClonePromptItem] = []
|
| 438 |
-
for i, ((wav, sr), code, rtext, xvec_only) in enumerate(zip(normalized, ref_codes, ref_text_list, xvec_list)):
|
| 439 |
-
if not xvec_only:
|
| 440 |
-
if rtext is None or rtext == "":
|
| 441 |
-
raise ValueError(f"ref_text is required when x_vector_only_mode=False (ICL mode). Bad index={i}")
|
| 442 |
-
|
| 443 |
-
wav_resample = wav
|
| 444 |
-
if sr != self.model.speaker_encoder_sample_rate:
|
| 445 |
-
wav_resample = librosa.resample(y=wav_resample.astype(np.float32),
|
| 446 |
-
orig_sr=int(sr),
|
| 447 |
-
target_sr=self.model.speaker_encoder_sample_rate)
|
| 448 |
-
|
| 449 |
-
spk_emb = self.model.extract_speaker_embedding(audio=wav_resample,
|
| 450 |
-
sr=self.model.speaker_encoder_sample_rate)
|
| 451 |
-
|
| 452 |
-
items.append(
|
| 453 |
-
VoiceClonePromptItem(
|
| 454 |
-
ref_code=None if xvec_only else code,
|
| 455 |
-
ref_spk_embedding=spk_emb,
|
| 456 |
-
x_vector_only_mode=bool(xvec_only),
|
| 457 |
-
icl_mode=bool(not xvec_only),
|
| 458 |
-
ref_text=rtext,
|
| 459 |
-
)
|
| 460 |
-
)
|
| 461 |
-
return items
|
| 462 |
-
|
| 463 |
-
def _prompt_items_to_voice_clone_prompt(self, items: List[VoiceClonePromptItem]) -> Dict[str, Any]:
|
| 464 |
-
return dict(
|
| 465 |
-
ref_code=[it.ref_code for it in items],
|
| 466 |
-
ref_spk_embedding=[it.ref_spk_embedding for it in items],
|
| 467 |
-
x_vector_only_mode=[it.x_vector_only_mode for it in items],
|
| 468 |
-
icl_mode=[it.icl_mode for it in items],
|
| 469 |
-
)
|
| 470 |
-
|
| 471 |
-
# voice clone model
|
| 472 |
-
@torch.no_grad()
|
| 473 |
-
def generate_voice_clone(
|
| 474 |
-
self,
|
| 475 |
-
text: Union[str, List[str]],
|
| 476 |
-
language: Union[str, List[str]] = None,
|
| 477 |
-
ref_audio: Optional[Union[AudioLike, List[AudioLike]]] = None,
|
| 478 |
-
ref_text: Optional[Union[str, List[Optional[str]]]] = None,
|
| 479 |
-
x_vector_only_mode: Union[bool, List[bool]] = False,
|
| 480 |
-
voice_clone_prompt: Optional[Union[Dict[str, Any], List[VoiceClonePromptItem]]] = None,
|
| 481 |
-
**kwargs,
|
| 482 |
-
) -> Tuple[List[np.ndarray], int]:
|
| 483 |
-
"""
|
| 484 |
-
Voice clone speech using the Base model.
|
| 485 |
-
|
| 486 |
-
You can provide either:
|
| 487 |
-
- (ref_audio, ref_text, x_vector_only_mode) and let this method build the prompt, OR
|
| 488 |
-
- `VoiceClonePromptItem` returned by `create_voice_clone_prompt`, OR
|
| 489 |
-
- a list of `VoiceClonePromptItem` returned by `create_voice_clone_prompt`.
|
| 490 |
-
|
| 491 |
-
`ref_audio` Supported forms:
|
| 492 |
-
- str: wav path / URL / base64 audio string
|
| 493 |
-
- (np.ndarray, sr): waveform + sampling rate
|
| 494 |
-
- list of the above
|
| 495 |
-
|
| 496 |
-
Input flexibility:
|
| 497 |
-
- text/language can be scalar or list.
|
| 498 |
-
- prompt can be single or batch.
|
| 499 |
-
- If batch mode (len(text)>1), lengths must match.
|
| 500 |
-
|
| 501 |
-
Args:
|
| 502 |
-
text:
|
| 503 |
-
Text(s) to synthesize.
|
| 504 |
-
language:
|
| 505 |
-
Language(s) for each sample.
|
| 506 |
-
ref_audio:
|
| 507 |
-
Reference audio(s) for prompt building. Required if voice_clone_prompt is not provided.
|
| 508 |
-
ref_text:
|
| 509 |
-
Reference text(s) used for ICL mode (required when x_vector_only_mode=False).
|
| 510 |
-
x_vector_only_mode:
|
| 511 |
-
If True, only speaker embedding is used (ignores ref_text/ref_code).
|
| 512 |
-
If False, ICL mode is used automatically.
|
| 513 |
-
voice_clone_prompt:
|
| 514 |
-
list[VoiceClonePromptItem] from `create_voice_clone_prompt`.
|
| 515 |
-
non_streaming_mode:
|
| 516 |
-
Using non-streaming text input, this option currently only simulates streaming text input when set to `false`,
|
| 517 |
-
rather than enabling true streaming input or streaming generation.
|
| 518 |
-
do_sample:
|
| 519 |
-
Whether to use sampling, recommended to be set to `true` for most use cases.
|
| 520 |
-
top_k:
|
| 521 |
-
Top-k sampling parameter.
|
| 522 |
-
top_p:
|
| 523 |
-
Top-p sampling parameter.
|
| 524 |
-
temperature:
|
| 525 |
-
Sampling temperature; higher => more random.
|
| 526 |
-
repetition_penalty:
|
| 527 |
-
Penalty to reduce repeated tokens/codes.
|
| 528 |
-
subtalker_dosample:
|
| 529 |
-
Sampling switch for the sub-talker (only valid for qwen3-tts-tokenizer-v2) if applicable.
|
| 530 |
-
subtalker_top_k:
|
| 531 |
-
Top-k for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
| 532 |
-
subtalker_top_p:
|
| 533 |
-
Top-p for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
| 534 |
-
subtalker_temperature:
|
| 535 |
-
Temperature for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
| 536 |
-
max_new_tokens:
|
| 537 |
-
Maximum number of new codec tokens to generate.
|
| 538 |
-
**kwargs:
|
| 539 |
-
Any other keyword arguments supported by HuggingFace Transformers `generate()` can be passed.
|
| 540 |
-
They will be forwarded to the underlying `Qwen3TTSForConditionalGeneration.generate(...)`.
|
| 541 |
-
|
| 542 |
-
Returns:
|
| 543 |
-
Tuple[List[np.ndarray], int]:
|
| 544 |
-
(wavs, sample_rate)
|
| 545 |
-
|
| 546 |
-
Raises:
|
| 547 |
-
ValueError:
|
| 548 |
-
If batch sizes mismatch or required prompt inputs are missing.
|
| 549 |
-
"""
|
| 550 |
-
if self.model.tts_model_type != "base":
|
| 551 |
-
raise ValueError(
|
| 552 |
-
f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
|
| 553 |
-
f"tts_model_size: {self.model.tts_model_size}\n"
|
| 554 |
-
f"tts_model_type: {self.model.tts_model_type}\n"
|
| 555 |
-
"does not support generate_voice_clone, Please check Model Card or Readme for more details."
|
| 556 |
-
)
|
| 557 |
-
|
| 558 |
-
texts = self._ensure_list(text)
|
| 559 |
-
languages = self._ensure_list(language) if isinstance(language, list) else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
|
| 560 |
-
if len(languages) == 1 and len(texts) > 1:
|
| 561 |
-
languages = languages * len(texts)
|
| 562 |
-
if len(texts) != len(languages):
|
| 563 |
-
raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}")
|
| 564 |
-
|
| 565 |
-
self._validate_languages(languages)
|
| 566 |
-
|
| 567 |
-
if voice_clone_prompt is None:
|
| 568 |
-
if ref_audio is None:
|
| 569 |
-
raise ValueError("Either `voice_clone_prompt` or `ref_audio` must be provided.")
|
| 570 |
-
prompt_items = self.create_voice_clone_prompt(ref_audio=ref_audio, ref_text=ref_text, x_vector_only_mode=x_vector_only_mode)
|
| 571 |
-
if len(prompt_items) == 1 and len(texts) > 1:
|
| 572 |
-
prompt_items = prompt_items * len(texts)
|
| 573 |
-
if len(prompt_items) != len(texts):
|
| 574 |
-
raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}")
|
| 575 |
-
voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items)
|
| 576 |
-
ref_texts_for_ids = [it.ref_text for it in prompt_items]
|
| 577 |
-
else:
|
| 578 |
-
if isinstance(voice_clone_prompt, list):
|
| 579 |
-
prompt_items = voice_clone_prompt
|
| 580 |
-
if len(prompt_items) == 1 and len(texts) > 1:
|
| 581 |
-
prompt_items = prompt_items * len(texts)
|
| 582 |
-
if len(prompt_items) != len(texts):
|
| 583 |
-
raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}")
|
| 584 |
-
voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items)
|
| 585 |
-
ref_texts_for_ids = [it.ref_text for it in prompt_items]
|
| 586 |
-
else:
|
| 587 |
-
voice_clone_prompt_dict = voice_clone_prompt
|
| 588 |
-
ref_texts_for_ids = None
|
| 589 |
-
|
| 590 |
-
input_texts = [self._build_assistant_text(t) for t in texts]
|
| 591 |
-
input_ids = self._tokenize_texts(input_texts)
|
| 592 |
-
|
| 593 |
-
ref_ids = None
|
| 594 |
-
if ref_texts_for_ids is not None:
|
| 595 |
-
ref_ids = []
|
| 596 |
-
for i, rt in enumerate(ref_texts_for_ids):
|
| 597 |
-
if rt is None or rt == "":
|
| 598 |
-
ref_ids.append(None)
|
| 599 |
-
else:
|
| 600 |
-
ref_tok = self._tokenize_texts([self._build_ref_text(rt)])[0]
|
| 601 |
-
ref_ids.append(ref_tok)
|
| 602 |
-
|
| 603 |
-
gen_kwargs = self._merge_generate_kwargs(**kwargs)
|
| 604 |
-
|
| 605 |
-
talker_codes_list, _ = self.model.generate(
|
| 606 |
-
input_ids=input_ids,
|
| 607 |
-
ref_ids=ref_ids,
|
| 608 |
-
voice_clone_prompt=voice_clone_prompt_dict,
|
| 609 |
-
languages=languages,
|
| 610 |
-
**gen_kwargs,
|
| 611 |
-
)
|
| 612 |
-
|
| 613 |
-
codes_for_decode = []
|
| 614 |
-
for i, codes in enumerate(talker_codes_list):
|
| 615 |
-
ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
|
| 616 |
-
if ref_code_list is not None and ref_code_list[i] is not None:
|
| 617 |
-
codes_for_decode.append(torch.cat([ref_code_list[i].to(codes.device), codes], dim=0))
|
| 618 |
-
else:
|
| 619 |
-
codes_for_decode.append(codes)
|
| 620 |
-
|
| 621 |
-
wavs_all, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in codes_for_decode])
|
| 622 |
-
|
| 623 |
-
wavs_out: List[np.ndarray] = []
|
| 624 |
-
for i, wav in enumerate(wavs_all):
|
| 625 |
-
ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
|
| 626 |
-
if ref_code_list is not None and ref_code_list[i] is not None:
|
| 627 |
-
ref_len = int(ref_code_list[i].shape[0])
|
| 628 |
-
total_len = int(codes_for_decode[i].shape[0])
|
| 629 |
-
cut = int(ref_len / max(total_len, 1) * wav.shape[0])
|
| 630 |
-
wavs_out.append(wav[cut:])
|
| 631 |
-
else:
|
| 632 |
-
wavs_out.append(wav)
|
| 633 |
-
|
| 634 |
-
return wavs_out, fs
|
| 635 |
-
|
| 636 |
-
    # voice design model
    @torch.no_grad()
    def generate_voice_design(
        self,
        text: Union[str, List[str]],
        instruct: Union[str, List[str]],
        language: Union[str, List[str]] = None,
        **kwargs,
    ) -> Tuple[List[np.ndarray], int]:
        """
        Generate speech with the VoiceDesign model using natural-language style instructions.

        Args:
            text:
                Text(s) to synthesize.
            language:
                Language(s) for each sample.
            instruct:
                Instruction(s) describing desired voice/style. Empty string is allowed (treated as no instruction).
            non_streaming_mode:
                Using non-streaming text input, this option currently only simulates streaming text input when set to `false`,
                rather than enabling true streaming input or streaming generation.
            do_sample:
                Whether to use sampling, recommended to be set to `true` for most use cases.
            top_k:
                Top-k sampling parameter.
            top_p:
                Top-p sampling parameter.
            temperature:
                Sampling temperature; higher => more random.
            repetition_penalty:
                Penalty to reduce repeated tokens/codes.
            subtalker_dosample:
                Sampling switch for the sub-talker (only valid for qwen3-tts-tokenizer-v2) if applicable.
            subtalker_top_k:
                Top-k for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
            subtalker_top_p:
                Top-p for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
            subtalker_temperature:
                Temperature for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
            max_new_tokens:
                Maximum number of new codec tokens to generate.
            **kwargs:
                Any other keyword arguments supported by HuggingFace Transformers `generate()` can be passed.
                They will be forwarded to the underlying `Qwen3TTSForConditionalGeneration.generate(...)`.

        Returns:
            Tuple[List[np.ndarray], int]:
                (wavs, sample_rate)

        Raises:
            ValueError:
                If the loaded checkpoint is not a voice_design model, or batch sizes mismatch,
                or a language is unsupported.
        """
        # Guard: this entry point is only valid for the voice_design checkpoint family.
        if self.model.tts_model_type != "voice_design":
            raise ValueError(
                f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
                f"tts_model_size: {self.model.tts_model_size}\n"
                f"tts_model_type: {self.model.tts_model_type}\n"
                "does not support generate_voice_design, Please check Model Card or Readme for more details."
            )

        texts = self._ensure_list(text)
        # A scalar language broadcasts to the whole batch; None defaults to "Auto" detection.
        languages = self._ensure_list(language) if isinstance(language, list) else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
        instructs = self._ensure_list(instruct)

        # Broadcast singleton language/instruct lists across the text batch.
        if len(languages) == 1 and len(texts) > 1:
            languages = languages * len(texts)
        if len(instructs) == 1 and len(texts) > 1:
            instructs = instructs * len(texts)

        if not (len(texts) == len(languages) == len(instructs)):
            raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}, instruct={len(instructs)}")

        self._validate_languages(languages)

        input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])

        # None/empty instruction becomes None so the model skips instruction conditioning.
        instruct_ids: List[Optional[torch.Tensor]] = []
        for ins in instructs:
            if ins is None or ins == "":
                instruct_ids.append(None)
            else:
                instruct_ids.append(self._tokenize_texts([self._build_instruct_text(ins)])[0])

        gen_kwargs = self._merge_generate_kwargs(**kwargs)

        talker_codes_list, _ = self.model.generate(
            input_ids=input_ids,
            instruct_ids=instruct_ids,
            languages=languages,
            **gen_kwargs,
        )

        # Decode each sample's codec tokens back to waveform; fs is the output sample rate.
        wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list])
        return wavs, fs
|
| 728 |
-
|
| 729 |
-
# custom voice model
|
| 730 |
-
@torch.no_grad()
|
| 731 |
-
def generate_custom_voice(
|
| 732 |
-
self,
|
| 733 |
-
text: Union[str, List[str]],
|
| 734 |
-
speaker: Union[str, List[str]],
|
| 735 |
-
language: Union[str, List[str]] = None,
|
| 736 |
-
instruct: Optional[Union[str, List[str]]] = None,
|
| 737 |
-
**kwargs,
|
| 738 |
-
) -> Tuple[List[np.ndarray], int]:
|
| 739 |
-
"""
|
| 740 |
-
Generate speech with the CustomVoice model using a predefined speaker id, optionally controlled by instruction text.
|
| 741 |
-
|
| 742 |
-
Args:
|
| 743 |
-
text:
|
| 744 |
-
Text(s) to synthesize.
|
| 745 |
-
language:
|
| 746 |
-
Language(s) for each sample.
|
| 747 |
-
speaker:
|
| 748 |
-
Speaker name(s). Will be validated against `model.get_supported_speakers()` (case-insensitive).
|
| 749 |
-
instruct:
|
| 750 |
-
Optional instruction(s). If None, treated as empty (no instruction).
|
| 751 |
-
non_streaming_mode:
|
| 752 |
-
Using non-streaming text input, this option currently only simulates streaming text input when set to `false`,
|
| 753 |
-
rather than enabling true streaming input or streaming generation.
|
| 754 |
-
do_sample:
|
| 755 |
-
Whether to use sampling, recommended to be set to `true` for most use cases.
|
| 756 |
-
top_k:
|
| 757 |
-
Top-k sampling parameter.
|
| 758 |
-
top_p:
|
| 759 |
-
Top-p sampling parameter.
|
| 760 |
-
temperature:
|
| 761 |
-
Sampling temperature; higher => more random.
|
| 762 |
-
repetition_penalty:
|
| 763 |
-
Penalty to reduce repeated tokens/codes.
|
| 764 |
-
subtalker_dosample:
|
| 765 |
-
Sampling switch for the sub-talker (only valid for qwen3-tts-tokenizer-v2) if applicable.
|
| 766 |
-
subtalker_top_k:
|
| 767 |
-
Top-k for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
| 768 |
-
subtalker_top_p:
|
| 769 |
-
Top-p for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
| 770 |
-
subtalker_temperature:
|
| 771 |
-
Temperature for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
| 772 |
-
max_new_tokens:
|
| 773 |
-
Maximum number of new codec tokens to generate.
|
| 774 |
-
**kwargs:
|
| 775 |
-
Any other keyword arguments supported by HuggingFace Transformers `generate()` can be passed.
|
| 776 |
-
They will be forwarded to the underlying `Qwen3TTSForConditionalGeneration.generate(...)`.
|
| 777 |
-
|
| 778 |
-
Returns:
|
| 779 |
-
Tuple[List[np.ndarray], int]:
|
| 780 |
-
(wavs, sample_rate)
|
| 781 |
-
|
| 782 |
-
Raises:
|
| 783 |
-
ValueError:
|
| 784 |
-
If any speaker/language is unsupported or batch sizes mismatch.
|
| 785 |
-
"""
|
| 786 |
-
if self.model.tts_model_type != "custom_voice":
|
| 787 |
-
raise ValueError(
|
| 788 |
-
f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
|
| 789 |
-
f"tts_model_size: {self.model.tts_model_size}\n"
|
| 790 |
-
f"tts_model_type: {self.model.tts_model_type}\n"
|
| 791 |
-
"does not support generate_custom_voice, Please check Model Card or Readme for more details."
|
| 792 |
-
)
|
| 793 |
-
|
| 794 |
-
texts = self._ensure_list(text)
|
| 795 |
-
languages = self._ensure_list(language) if isinstance(language, list) else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
|
| 796 |
-
speakers = self._ensure_list(speaker)
|
| 797 |
-
if self.model.tts_model_size in "0b6": # for 0b6 model, instruct is not supported
|
| 798 |
-
instruct = None
|
| 799 |
-
instructs = self._ensure_list(instruct) if isinstance(instruct, list) else ([instruct] * len(texts) if instruct is not None else [""] * len(texts))
|
| 800 |
-
|
| 801 |
-
if len(languages) == 1 and len(texts) > 1:
|
| 802 |
-
languages = languages * len(texts)
|
| 803 |
-
if len(speakers) == 1 and len(texts) > 1:
|
| 804 |
-
speakers = speakers * len(texts)
|
| 805 |
-
if len(instructs) == 1 and len(texts) > 1:
|
| 806 |
-
instructs = instructs * len(texts)
|
| 807 |
-
|
| 808 |
-
if not (len(texts) == len(languages) == len(speakers) == len(instructs)):
|
| 809 |
-
raise ValueError(
|
| 810 |
-
f"Batch size mismatch: text={len(texts)}, language={len(languages)}, speaker={len(speakers)}, instruct={len(instructs)}"
|
| 811 |
-
)
|
| 812 |
-
|
| 813 |
-
self._validate_languages(languages)
|
| 814 |
-
self._validate_speakers(speakers)
|
| 815 |
-
|
| 816 |
-
input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])
|
| 817 |
-
|
| 818 |
-
instruct_ids: List[Optional[torch.Tensor]] = []
|
| 819 |
-
for ins in instructs:
|
| 820 |
-
if ins is None or ins == "":
|
| 821 |
-
instruct_ids.append(None)
|
| 822 |
-
else:
|
| 823 |
-
instruct_ids.append(self._tokenize_texts([self._build_instruct_text(ins)])[0])
|
| 824 |
-
|
| 825 |
-
gen_kwargs = self._merge_generate_kwargs(**kwargs)
|
| 826 |
-
|
| 827 |
-
talker_codes_list, _ = self.model.generate(
|
| 828 |
-
input_ids=input_ids,
|
| 829 |
-
instruct_ids=instruct_ids,
|
| 830 |
-
languages=languages,
|
| 831 |
-
speakers=speakers,
|
| 832 |
-
**gen_kwargs,
|
| 833 |
-
)
|
| 834 |
-
|
| 835 |
-
wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list])
|
| 836 |
-
return wavs, fs
|
| 837 |
-
|
| 838 |
-
|
| 839 |
-
def get_supported_speakers(self) -> Optional[List[str]]:
|
| 840 |
-
"""
|
| 841 |
-
List supported speaker names for the current model.
|
| 842 |
-
|
| 843 |
-
This is a convenience wrapper around `model.get_supported_speakers()`.
|
| 844 |
-
If the underlying model does not expose speaker constraints (returns None),
|
| 845 |
-
this method also returns None.
|
| 846 |
-
|
| 847 |
-
Returns:
|
| 848 |
-
Optional[List[str]]:
|
| 849 |
-
- A sorted list of supported speaker names (lowercased), if available.
|
| 850 |
-
- None if the model does not provide supported speakers.
|
| 851 |
-
"""
|
| 852 |
-
supported = self._supported_speakers_set()
|
| 853 |
-
if supported is None:
|
| 854 |
-
return None
|
| 855 |
-
return sorted(supported)
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
def get_supported_languages(self) -> Optional[List[str]]:
|
| 859 |
-
"""
|
| 860 |
-
List supported language names for the current model.
|
| 861 |
-
|
| 862 |
-
This is a convenience wrapper around `model.get_supported_languages()`.
|
| 863 |
-
If the underlying model does not expose language constraints (returns None),
|
| 864 |
-
this method also returns None.
|
| 865 |
-
|
| 866 |
-
Returns:
|
| 867 |
-
Optional[List[str]]:
|
| 868 |
-
- A sorted list of supported language names (lowercased), if available.
|
| 869 |
-
- None if the model does not provide supported languages.
|
| 870 |
-
"""
|
| 871 |
-
supported = self._supported_languages_set()
|
| 872 |
-
if supported is None:
|
| 873 |
-
return None
|
| 874 |
-
return sorted(supported)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwen_tts/inference/qwen3_tts_tokenizer.py
DELETED
|
@@ -1,411 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
import base64
|
| 17 |
-
import io
|
| 18 |
-
import urllib.request
|
| 19 |
-
from typing import List, Optional, Tuple, Union
|
| 20 |
-
from urllib.parse import urlparse
|
| 21 |
-
|
| 22 |
-
import librosa
|
| 23 |
-
import numpy as np
|
| 24 |
-
import soundfile as sf
|
| 25 |
-
import torch
|
| 26 |
-
from torch.nn.utils.rnn import pad_sequence
|
| 27 |
-
from transformers import AutoConfig, AutoFeatureExtractor, AutoModel
|
| 28 |
-
|
| 29 |
-
from ..core import (
|
| 30 |
-
Qwen3TTSTokenizerV1Config,
|
| 31 |
-
Qwen3TTSTokenizerV1Model,
|
| 32 |
-
Qwen3TTSTokenizerV2Config,
|
| 33 |
-
Qwen3TTSTokenizerV2Model,
|
| 34 |
-
)
|
| 35 |
-
|
| 36 |
-
# Accepted audio argument forms for Qwen3TTSTokenizer.encode()/load_audio().
AudioInput = Union[
    str,  # wav file path, URL, or base64-encoded audio string
    np.ndarray,  # 1-D float waveform (requires an explicit `sr`)
    List[str],
    List[np.ndarray],
]
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
class Qwen3TTSTokenizer:
    """
    A wrapper for Qwen3 TTS Tokenizer 25Hz/12Hz with HuggingFace-style loading.

    - from_pretrained(): loads speech tokenizer model via AutoModel and feature_extractor via AutoFeatureExtractor.
    - encode(): supports wav path(s), base64 audio string(s), numpy array(s).
    - decode(): accepts either the raw model encode output, or a minimal dict/list-of-dicts.

    Notes:
        - For numpy array input, you must pass `sr` so the audio can be resampled to model sample rate.
        - Returned audio is float32 numpy arrays and the output sample rate.
    """

    def __init__(self):
        # All attributes are populated by from_pretrained(); a bare instance is unusable.
        self.model = None              # AutoModel speech tokenizer (V1 25Hz or V2 12Hz)
        self.feature_extractor = None  # AutoFeatureExtractor matching the checkpoint
        self.config = None             # convenience alias for self.model.config
        self.device = None             # device the model parameters live on

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3TTSTokenizer":
        """
        Initialize tokenizer with HuggingFace `from_pretrained` style.

        Args:
            pretrained_model_name_or_path (str):
                HuggingFace repo id or local directory.
            **kwargs (Any):
                Forwarded to `AutoModel.from_pretrained(...)` directly.
                Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="eager".

        Returns:
            Qwen3TTSTokenizer:
                Initialized instance with `model`, `feature_extractor`, `config`.
        """
        inst = cls()

        # Register both tokenizer variants so AutoConfig/AutoModel can resolve
        # whichever model_type the checkpoint declares.
        AutoConfig.register("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config)
        AutoModel.register(Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model)

        AutoConfig.register("qwen3_tts_tokenizer_12hz", Qwen3TTSTokenizerV2Config)
        AutoModel.register(Qwen3TTSTokenizerV2Config, Qwen3TTSTokenizerV2Model)

        inst.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
        inst.model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
        inst.config = inst.model.config

        inst.device = getattr(inst.model, "device", None)
        if inst.device is None:
            # fallback: infer from first parameter device
            try:
                inst.device = next(inst.model.parameters()).device
            except StopIteration:
                # parameter-less model: default to CPU
                inst.device = torch.device("cpu")

        return inst

    def _is_probably_base64(self, s: str) -> bool:
        # Data-URL prefix is an unambiguous base64 marker.
        if s.startswith("data:audio"):
            return True
        # Heuristic: no filesystem path separators and long enough.
        if ("/" not in s and "\\" not in s) and len(s) > 256:
            return True
        return False

    def _is_url(self, s: str) -> bool:
        # Only http(s) URLs with a host are treated as remote audio.
        try:
            u = urlparse(s)
            return u.scheme in ("http", "https") and bool(u.netloc)
        except Exception:
            return False

    def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
        # Accept both "data:audio/wav;base64,...." and raw base64
        if "," in b64 and b64.strip().startswith("data:"):
            b64 = b64.split(",", 1)[1]
        return base64.b64decode(b64)

    def load_audio(
        self,
        x: str,
        target_sr: int,
    ) -> np.ndarray:
        """
        Load audio from wav path or base64 string, then resample to target_sr.

        Args:
            x (str):
                A wav file path, a URL, or a base64 audio string (raw or data URL).
            target_sr (int):
                Target sampling rate.

        Returns:
            np.ndarray:
                1-D float32 waveform at target_sr.
        """
        if self._is_url(x):
            # Remote audio: fetch bytes then decode in-memory with soundfile.
            with urllib.request.urlopen(x) as resp:
                audio_bytes = resp.read()
            with io.BytesIO(audio_bytes) as f:
                audio, sr = sf.read(f, dtype="float32", always_2d=False)
        elif self._is_probably_base64(x):
            wav_bytes = self._decode_base64_to_wav_bytes(x)
            with io.BytesIO(wav_bytes) as f:
                audio, sr = sf.read(f, dtype="float32", always_2d=False)
        else:
            # Local path: librosa keeps the native rate (sr=None) and downmixes.
            audio, sr = librosa.load(x, sr=None, mono=True)

        # soundfile may return (T, C) for multi-channel input; downmix to mono.
        if audio.ndim > 1:
            audio = np.mean(audio, axis=-1)

        if sr != target_sr:
            audio = librosa.resample(y=audio, orig_sr=sr, target_sr=target_sr)

        return audio.astype(np.float32)

    def _normalize_audio_inputs(
        self,
        audios: AudioInput,
        sr: Optional[int],
    ) -> List[np.ndarray]:
        """
        Normalize all supported input types into a list of 1-D numpy float32 waveforms
        at `self.feature_extractor.sampling_rate`.

        Args:
            audios (AudioInput):
                - str: wav path OR base64 audio string
                - np.ndarray: raw waveform (sr must be provided)
                - list[str] / list[np.ndarray]
            sr (Optional[int]):
                Sampling rate for raw numpy input. Required if input is np.ndarray or list[np.ndarray].

        Returns:
            List[np.ndarray]:
                List of float32 waveforms resampled to model input SR.
        """
        target_sr = int(self.feature_extractor.sampling_rate)

        # Promote a single sample to a batch of one.
        if isinstance(audios, (str, np.ndarray)):
            audios = [audios]

        if len(audios) == 0:
            return []

        # Input type is decided by the first element; mixing is rejected below.
        if isinstance(audios[0], str):
            # wav path list or base64 list
            return [self.load_audio(x, target_sr=target_sr) for x in audios]  # type: ignore[arg-type]

        # numpy list
        if sr is None:
            raise ValueError("For numpy waveform input, you must provide `sr` (original sampling rate).")

        out: List[np.ndarray] = []
        for a in audios:  # type: ignore[assignment]
            if not isinstance(a, np.ndarray):
                raise TypeError("Mixed input types are not supported. Use all paths/base64 or all numpy arrays.")
            if a.ndim > 1:
                # Downmix (T, C) to mono.
                a = np.mean(a, axis=-1)
            if int(sr) != target_sr:
                a = librosa.resample(y=a.astype(np.float32), orig_sr=int(sr), target_sr=target_sr)
            out.append(a.astype(np.float32))
        return out

    def encode(
        self,
        audios: AudioInput,
        sr: Optional[int] = None,
        return_dict: bool = True,
    ):
        """
        Batch-encode audio into discrete codes (and optional conditioning, depending on 25Hz/12Hz).

        Args:
            audios (AudioInput):
                Supported forms:
                - np.ndarray: waveform (requires sr)
                - list[np.ndarray]: waveforms (requires sr)
                - str: wav path OR base64 audio string
                - list[str]: wav paths and/or base64 strings
            sr (Optional[int], default=None):
                Original sampling rate for numpy waveform input.
            return_dict (bool, default=True):
                Forwarded to model.encode(...). If True, returns ModelOutput.

        Returns:
            25Hz:
                Qwen3TTSTokenizerV1EncoderOutput (if return_dict=True) with fields:
                - audio_codes: List[torch.LongTensor] each (codes_len,)
                - xvectors: List[torch.FloatTensor] each (xvector_dim,)
                - ref_mels: List[torch.FloatTensor] each (mel_len, mel_dim)
            12Hz:
                Qwen3TTSTokenizerV2EncoderOutput (if return_dict=True) with fields:
                - audio_codes: List[torch.LongTensor] each (codes_len, num_quantizers)

            If return_dict=False, returns the raw tuple from model.encode.
        """
        wavs = self._normalize_audio_inputs(audios, sr=sr)

        inputs = self.feature_extractor(
            raw_audio=wavs,
            sampling_rate=int(self.feature_extractor.sampling_rate),
            return_tensors="pt",
        )
        # Move batch to the model's device and dtype before encoding.
        inputs = inputs.to(self.device).to(self.model.dtype)

        with torch.inference_mode():
            # model.encode expects (B, T) and (B, T)
            enc = self.model.encode(
                inputs["input_values"].squeeze(1),
                inputs["padding_mask"].squeeze(1),
                return_dict=return_dict,
            )
        return enc

    def decode(
        self,
        encoded,
    ) -> Tuple[List[np.ndarray], int]:
        """
        Decode back to waveform.

        Usage:
            1) Pass the raw output of `encode(...)` directly (recommended).
               - 25Hz: expects fields audio_codes, xvectors, ref_mels
               - 12Hz: expects field audio_codes
            2) Pass a dict or list[dict] (minimal form) for custom pipelines:
               - 25Hz dict keys: {"audio_codes", "xvectors", "ref_mels"}
               - 12Hz dict keys: {"audio_codes"}
               Values can be torch tensors or numpy arrays.

        Args:
            encoded (Any):
                - ModelOutput returned by `encode()`, OR
                - dict, OR
                - list[dict]

        Returns:
            Tuple[List[np.ndarray], int]:
                - wavs: list of 1-D float32 numpy arrays
                - sample_rate: int, model output sampling rate
        """
        model_type = self.model.get_model_type()

        def _to_tensor(x, dtype=None):
            # Accept torch tensors as-is; convert numpy/sequence inputs.
            if isinstance(x, torch.Tensor):
                return x
            x = np.asarray(x)
            t = torch.from_numpy(x)
            if dtype is not None:
                t = t.to(dtype)
            return t

        # Normalize `encoded` into the same shapes as the official demo uses.
        if hasattr(encoded, "audio_codes"):
            # ModelOutput from encode()
            audio_codes_list = encoded.audio_codes
            xvectors_list = getattr(encoded, "xvectors", None)
            ref_mels_list = getattr(encoded, "ref_mels", None)
        elif isinstance(encoded, dict):
            audio_codes_list = encoded["audio_codes"]
            xvectors_list = encoded.get("xvectors", None)
            ref_mels_list = encoded.get("ref_mels", None)
        elif isinstance(encoded, list):
            # list of dicts
            # NOTE(review): presence of optional keys is probed on the first dict
            # only — assumes a homogeneous list.
            audio_codes_list = [e["audio_codes"] for e in encoded]
            xvectors_list = [e["xvectors"] for e in encoded] if ("xvectors" in encoded[0]) else None
            ref_mels_list = [e["ref_mels"] for e in encoded] if ("ref_mels" in encoded[0]) else None
        else:
            raise TypeError("`encoded` must be an encode output, a dict, or a list of dicts.")

        # Ensure list form for per-sample tensors
        if isinstance(audio_codes_list, torch.Tensor):
            # Could be a single sample tensor or an already padded batch tensor.
            t = audio_codes_list
            if t.dim() == 1:
                # 25Hz single sample: (C,) -> (1, C)
                t = t.unsqueeze(0)
            elif t.dim() == 2:
                # 12Hz single sample: (C, Q) -> (1, C, Q)
                t = t.unsqueeze(0)
            audio_codes_padded = t.to(self.device)
        else:
            # List[Tensor/np]
            audio_codes_list = [_to_tensor(c, dtype=torch.long) for c in audio_codes_list]
            audio_codes_padded = pad_sequence(audio_codes_list, batch_first=True, padding_value=0).to(self.device)

        with torch.inference_mode():
            if model_type == "qwen3_tts_tokenizer_25hz":
                # The 25Hz decoder is speaker-conditioned and needs both inputs.
                if xvectors_list is None or ref_mels_list is None:
                    raise ValueError("25Hz decode requires `xvectors` and `ref_mels`.")

                if isinstance(xvectors_list, torch.Tensor):
                    xvectors_batch = xvectors_list
                    if xvectors_batch.dim() == 1:  # (D,) -> (1, D)
                        xvectors_batch = xvectors_batch.unsqueeze(0)
                    xvectors_batch = xvectors_batch.to(self.device).to(self.model.dtype)
                else:
                    xvectors_list = [_to_tensor(x, dtype=torch.float32) for x in xvectors_list]
                    xvectors_batch = torch.stack(xvectors_list, dim=0).to(self.device).to(self.model.dtype)

                if isinstance(ref_mels_list, torch.Tensor):
                    ref_mels_padded = ref_mels_list
                    if ref_mels_padded.dim() == 2:  # (T, M) -> (1, T, M)
                        ref_mels_padded = ref_mels_padded.unsqueeze(0)
                    ref_mels_padded = ref_mels_padded.to(self.device).to(self.model.dtype)
                else:
                    ref_mels_list = [_to_tensor(m, dtype=torch.float32) for m in ref_mels_list]
                    ref_mels_padded = pad_sequence(ref_mels_list, batch_first=True, padding_value=0).to(self.device).to(self.model.dtype)

                dec = self.model.decode(audio_codes_padded, xvectors_batch, ref_mels_padded, return_dict=True)
                wav_tensors = dec.audio_values

            elif model_type == "qwen3_tts_tokenizer_12hz":
                # The 12Hz decoder needs only the codes.
                dec = self.model.decode(audio_codes_padded, return_dict=True)
                wav_tensors = dec.audio_values

            else:
                raise ValueError(f"Unknown model type: {model_type}")

        # Detach to CPU float32 numpy for downstream consumers.
        wavs = [w.to(torch.float32).detach().cpu().numpy() for w in wav_tensors]
        return wavs, int(self.model.get_output_sample_rate())

    def get_model_type(self) -> str:
        """
        Get the underlying tokenizer model type.

        Returns:
            str: Model type string from `self.model.config.model_type`
                 (e.g. "qwen3_tts_tokenizer_25hz" / "qwen3_tts_tokenizer_12hz").
        """
        return self.model.get_model_type()

    def get_input_sample_rate(self) -> int:
        """
        Get the expected input sample rate for encoding.

        Returns:
            int: Input sample rate (Hz).
        """
        return int(self.model.get_input_sample_rate())

    def get_output_sample_rate(self) -> int:
        """
        Get the output sample rate for decoded waveforms.

        Returns:
            int: Output sample rate (Hz).
        """
        return int(self.model.get_output_sample_rate())

    def get_encode_downsample_rate(self) -> int:
        """
        Get the encoder downsample rate (waveform samples per code step).

        Returns:
            int: Encode downsample rate.
        """
        return int(self.model.get_encode_downsample_rate())

    def get_decode_upsample_rate(self) -> int:
        """
        Get the decoder upsample rate (waveform samples per code step).

        Returns:
            int: Decode upsample rate.
        """
        return int(self.model.get_decode_upsample_rate())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
DELETED
|
@@ -1,17 +0,0 @@
|
|
| 1 |
-
daggr==0.5.2
|
| 2 |
-
torch
|
| 3 |
-
transformers==4.57.3
|
| 4 |
-
accelerate==1.12.0
|
| 5 |
-
einops
|
| 6 |
-
gradio
|
| 7 |
-
librosa
|
| 8 |
-
torchaudio
|
| 9 |
-
soundfile
|
| 10 |
-
sox
|
| 11 |
-
nagisa==0.2.11
|
| 12 |
-
soynlp==0.0.493
|
| 13 |
-
onnxruntime
|
| 14 |
-
scipy
|
| 15 |
-
torch
|
| 16 |
-
numpy
|
| 17 |
-
huggingface_hub
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|