Spaces:
Build error

Michael Hu committed on
Commit · d77f8ff · 1 Parent(s): dda48cd

remove vibevoice
Browse files
- README.md +0 -1
- app.py +2 -300
- src/vibevoice/__init__.py +0 -0
- src/vibevoice/configs/qwen2.5_1.5b_64k.json +0 -112
- src/vibevoice/configs/qwen2.5_7b_32k.json +0 -113
- src/vibevoice/modular/__init__.py +0 -0
- src/vibevoice/modular/configuration_vibevoice.py +0 -248
- src/vibevoice/modular/modeling_vibevoice.py +0 -487
- src/vibevoice/modular/modeling_vibevoice_inference.py +0 -716
- src/vibevoice/modular/modular_vibevoice_diffusion_head.py +0 -287
- src/vibevoice/modular/modular_vibevoice_text_tokenizer.py +0 -214
- src/vibevoice/modular/modular_vibevoice_tokenizer.py +0 -1195
- src/vibevoice/modular/streamer.py +0 -264
- src/vibevoice/processor/__init__.py +0 -0
- src/vibevoice/processor/vibevoice_processor.py +0 -701
- src/vibevoice/processor/vibevoice_tokenizer_processor.py +0 -483
- src/vibevoice/schedule/__init__.py +0 -0
- src/vibevoice/schedule/dpm_solver.py +0 -1065
- src/vibevoice/schedule/timestep_sampler.py +0 -19
- src/vibevoice/scripts/__init__.py +0 -0
- src/vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py +0 -166
- src/voices/vibe_voices/en-Alice_woman.wav +0 -3
- src/voices/vibe_voices/en-Carter_man.wav +0 -3
- src/voices/vibe_voices/en-Frank_man.wav +0 -3
- src/voices/vibe_voices/en-Mary_woman_bgm.wav +0 -3
- src/voices/vibe_voices/en-Maya_woman.wav +0 -3
- src/voices/vibe_voices/in-Samuel_man.wav +0 -3
- src/voices/vibe_voices/zh-Anchen_man_bgm.wav +0 -3
- src/voices/vibe_voices/zh-Bowen_man.wav +0 -3
- src/voices/vibe_voices/zh-Xinran_woman.wav +0 -3

README.md CHANGED

@@ -47,7 +47,6 @@ This demo showcases the multilingual capabilities of multiple TTS models, suppor
 - **Chatterbox**: Industrial-grade multilingual TTS solution
 - **KittenTTS**: High-quality TTS with voice cloning capabilities
 - **Piper**: Local on-device TTS with multiple voice options
-- **VibeVoice 1.5B**: Microsoft's advanced seq2seq TTS model

 ## Examples

app.py CHANGED

@@ -20,7 +20,6 @@ MODEL_DESCRIPTIONS = {
     "ResembleAI/chatterbox": "Industrial-grade TTS solution with multilingual support",
     "KittenML/KittenTTS": "High-quality TTS with voice cloning capabilities using reference audio",
     "piper-tts": "Local on-device TTS with dynamic English and Chinese voice selection from Piper models",
-    "microsoft/VibeVoice-1.5B": "Microsoft's advanced seq2seq TTS model with high-quality speech synthesis",
 }

 # Models dictionary
@@ -28,7 +27,6 @@ MODELS = {
     "ResembleAI/chatterbox": "Chatterbox",
     "KittenML/KittenTTS": "KittenTTS",
     "piper-tts": "Piper (no voice cloning)",
-    "microsoft/VibeVoice-1.5B": "VibeVoice 1.5B",
 }

 original_torch_load = torch.load
@@ -53,130 +51,6 @@ except RuntimeError as e:
 # Initialize KittenTTS model
 kittentts_model = KittenTTS("KittenML/kitten-tts-nano-0.2")

-# Initialize VibeVoice model
-vibevoice_model = None
-vibevoice_processor = None
-vibevoice_voices = {}
-
-def initialize_vibevoice():
-    """Initialize VibeVoice model using the proper VibeVoice classes"""
-    global vibevoice_model, vibevoice_processor, vibevoice_voices
-    try:
-        # Add the src directory to Python path to make vibe-voice importable
-        src_path = os.path.join(os.path.dirname(__file__), 'src')
-        if src_path not in sys.path:
-            sys.path.insert(0, src_path)
-
-        # Import VibeVoice specific classes from src/vibe-voice directory
-        # Use underscore import since hyphens aren't valid in Python module names
-        vibe_voice_path = os.path.join(src_path, 'vibevoice')
-        if vibe_voice_path not in sys.path:
-            sys.path.insert(0, vibe_voice_path)
-
-        # Now import using the actual module structure
-        from modular.configuration_vibevoice import VibeVoiceConfig
-        from modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
-        from processor.vibevoice_processor import VibeVoiceProcessor
-
-        # Determine device
-        device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
-
-        # Load processor
-        print("Loading VibeVoice processor...")
-        vibevoice_processor = VibeVoiceProcessor.from_pretrained("microsoft/VibeVoice-1.5B")
-
-        # Determine dtype and attention implementation based on device
-        if device == "mps":
-            load_dtype = torch.float32
-            attn_impl_primary = "sdpa"
-        elif device == "cuda":
-            load_dtype = torch.bfloat16
-            attn_impl_primary = "flash_attention_2"
-        else:
-            load_dtype = torch.float32
-            attn_impl_primary = "sdpa"
-
-        print(f"Using device: {device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
-
-        # Load model
-        print("Loading VibeVoice model...")
-        try:
-            if device == "mps":
-                vibevoice_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-                    "microsoft/VibeVoice-1.5B",
-                    torch_dtype=load_dtype,
-                    attn_implementation=attn_impl_primary,
-                    device_map=None,
-                )
-                vibevoice_model.to("mps")
-            elif device == "cuda":
-                vibevoice_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-                    "microsoft/VibeVoice-1.5B",
-                    torch_dtype=load_dtype,
-                    device_map="cuda",
-                    attn_implementation=attn_impl_primary,
-                )
-            else:
-                vibevoice_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-                    "microsoft/VibeVoice-1.5B",
-                    torch_dtype=load_dtype,
-                    device_map="cpu",
-                    attn_implementation=attn_impl_primary,
-                )
-        except Exception as e:
-            if attn_impl_primary == 'flash_attention_2':
-                print(f"[ERROR] : {type(e).__name__}: {e}")
-                print("Falling back to attention implementation: sdpa")
-                vibevoice_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-                    "microsoft/VibeVoice-1.5B",
-                    torch_dtype=load_dtype,
-                    device_map=(device if device in ("cuda", "cpu") else None),
-                    attn_implementation="sdpa",
-                )
-                if device == "mps":
-                    vibevoice_model.to("mps")
-            else:
-                raise e
-
-        vibevoice_model.eval()
-
-        # Setup noise scheduler for SDE solver
-        vibevoice_model.model.noise_scheduler = vibevoice_model.model.noise_scheduler.from_config(
-            vibevoice_model.model.noise_scheduler.config,
-            algorithm_type='sde-dpmsolver++',
-            beta_schedule='squaredcos_cap_v2'
-        )
-        vibevoice_model.set_ddpm_inference_steps(num_steps=10)
-
-        # Load voice presets
-        voices_dir = "src/voices/vibe_voices"
-        if os.path.exists(voices_dir):
-            wav_files = [f for f in os.listdir(voices_dir)
-                         if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')) and os.path.isfile(os.path.join(voices_dir, f))]
-
-            for wav_file in wav_files:
-                name = os.path.splitext(wav_file)[0]
-                full_path = os.path.join(voices_dir, wav_file)
-                vibevoice_voices[name] = full_path
-
-            vibevoice_voices = dict(sorted(vibevoice_voices.items()))
-            print(f"Found {len(vibevoice_voices)} voice files in {voices_dir}")
-            print(f"Available voices: {', '.join(vibevoice_voices.keys())}")
-        else:
-            print(f"Warning: Voices directory not found at {voices_dir}")
-            vibevoice_voices = {}
-
-        print("VibeVoice model loaded successfully")
-
-    except Exception as e:
-        print(f"Error loading VibeVoice: {str(e)}")
-        import traceback
-        traceback.print_exc()
-        raise e
-
-# Initialize VibeVoice on startup
-initialize_vibevoice()
-
 # Scan Piper voices
 def scan_piper_voices():
     voices_dir = "src/voices/piper_voices"
@@ -306,147 +180,6 @@ def generate_piper_speech(text, lang, voice):
     except Exception as e:
         return None, f"Error synthesizing speech: {str(e)}"

-def generate_vibevoice_speech(text, voice_name=None):
-    """
-    Generate speech from text using VibeVoice 1.5B with proper API
-
-    Args:
-        text (str): Text to convert to speech
-        voice_name (str, optional): Name of voice preset to use
-
-    Returns:
-        str: Path to the generated audio file
-    """
-    if not vibevoice_model or not vibevoice_processor:
-        raise RuntimeError("VibeVoice model not initialized")
-
-    if not text.strip():
-        raise ValueError("Please enter text to synthesize")
-
-    try:
-        # Select voice preset
-        if voice_name and voice_name in vibevoice_voices:
-            voice_path = vibevoice_voices[voice_name]
-            print(f"Using voice preset: {voice_name}")
-        else:
-            # Use first available voice or default
-            if vibevoice_voices:
-                voice_name = list(vibevoice_voices.keys())[0]
-                voice_path = vibevoice_voices[voice_name]
-                print(f"Using default voice preset: {voice_name}")
-            else:
-                # Generate without voice preset (may not work well)
-                voice_path = None
-                print("No voice presets available, generating without voice reference")
-
-        # Read voice sample if available
-        voice_samples = []
-        if voice_path:
-            try:
-                wav, sr = sf.read(voice_path)
-                if len(wav.shape) > 1:
-                    wav = np.mean(wav, axis=1)
-                if sr != 24000:
-                    wav = librosa.resample(wav, orig_sr=sr, target_sr=24000)
-                voice_samples.append(wav)
-                print(f"Loaded voice sample: {voice_path}, duration: {len(wav)/24000:.2f}s")
-            except Exception as e:
-                print(f"Error loading voice sample {voice_path}: {e}")
-                voice_samples = []
-
-        # Prepare input for VibeVoice - format text as single-speaker script
-        formatted_script = f"Speaker 1: {text}"
-
-        voice_samples_input = [voice_samples] if voice_samples else None
-
-        inputs = vibevoice_processor(
-            text=[formatted_script],
-            voice_samples=voice_samples_input,
-            padding=True,
-            return_tensors="pt",
-            return_attention_mask=True,
-        )
-
-        # Ensure voice samples are properly typed before processor
-        if voice_samples_input and voice_samples_input[0]:
-            voice_samples_input[0] = torch.tensor(voice_samples_input[0], dtype=torch.float32)
-
-        # Move tensors to device and match model's data type
-        device = next(vibevoice_model.parameters()).device
-        model_dtype = next(vibevoice_model.parameters()).dtype
-
-        for k, v in inputs.items():
-            if torch.is_tensor(v):
-                # Convert to model's data type before moving to device
-                inputs[k] = v.to(dtype=model_dtype).to(device)
-
-        # Generate speech using VibeVoice
-        with torch.no_grad():
-            outputs = vibevoice_model.generate(
-                **inputs,
-                cfg_scale=1.3,
-                tokenizer=vibevoice_processor.tokenizer,
-                generation_config={
-                    'do_sample': False,
-                },
-                verbose=False,
-                refresh_negative=True,
-            )
-
-        # Extract audio from outputs
-        if hasattr(outputs, 'waveform'):
-            audio = outputs.waveform
-        elif hasattr(outputs, 'audio'):
-            audio = outputs.audio
-        elif isinstance(outputs, dict) and 'audio' in outputs:
-            audio = outputs['audio']
-        elif isinstance(outputs, torch.Tensor):
-            audio = outputs
-        else:
-            # Try to get audio from the model output
-            audio = vibevoice_model.model.generate_audio(outputs)
-
-        # Ensure audio is in correct format
-        if torch.is_tensor(audio):
-            audio = audio.cpu().numpy()
-
-        # Ensure audio is 1D and properly normalized
-        if len(audio.shape) > 1:
-            audio = np.mean(audio, axis=1) if audio.shape[0] < audio.shape[1] else np.mean(audio, axis=0)
-
-        # Normalize to [-1, 1] range
-        if np.max(np.abs(audio)) > 1.0:
-            audio = audio / np.max(np.abs(audio))
-
-        # Convert to 16-bit for saving
-        audio_16bit = (audio * 32767).astype(np.int16)
-
-        # Save to temporary file
-        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
-            sf.write(tmp_file.name, audio_16bit, 24000)
-            print(f"Generated audio saved to: {tmp_file.name}")
-            return tmp_file.name
-
-    except Exception as e:
-        print(f"Error in VibeVoice generation: {str(e)}")
-        import traceback
-        traceback.print_exc()
-        # Fallback to simple audio generation if model inference fails
-        try:
-            sample_rate = 22050
-            duration = 2.0
-            t = torch.linspace(0, duration, int(sample_rate * duration))
-            frequency = 440  # A4 note
-            audio = torch.sin(2 * torch.pi * frequency * t).unsqueeze(0) * 0.3
-
-            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
-                # Convert tensor to numpy array before saving
-                audio_np = audio.squeeze().numpy()  # Remove extra dimensions
-                sf.write(tmp_file.name, audio_np, sample_rate)
-                return tmp_file.name
-        except Exception as fallback_error:
-            raise RuntimeError(f"Error generating speech with VibeVoice: {str(e)} - Fallback also failed: {str(fallback_error)}")
-
 def update_piper_voices(lang):
     choices = list(voices_by_lang.get(lang, {}).keys())
     value = choices[0] if choices else None
@@ -550,33 +283,7 @@ with gr.Blocks(css=custom_css, title="🎙️ TTS Model Gallery", theme=gr.theme
     piper_audio_output = gr.Audio(label="Generated Speech", type="filepath")
     piper_status = gr.Textbox(label="Status", interactive=False)

-    # VibeVoice
-    vibevoice_model_info = gr.HTML(create_model_card("microsoft/VibeVoice-1.5B"))
-
-    with gr.Row():
-        with gr.Column():
-            vibevoice_voice_selection = gr.Dropdown(
-                choices=list(vibevoice_voices.keys()) if vibevoice_voices else [],
-                value=list(vibevoice_voices.keys())[0] if vibevoice_voices else None,
-                label="Voice Preset"
-            )
-            vibevoice_generate_btn = gr.Button("Generate Speech")
-
-        with gr.Column():
-            vibevoice_audio_output = gr.Audio(label="Generated Speech", type="filepath")
-
-    # Examples for VibeVoice
-    gr.Examples(
-        examples=[
-            ["Hello, this is a test of VibeVoice 1.5B from Microsoft.", list(vibevoice_voices.keys())[0] if vibevoice_voices else None],
-            ["The quick brown fox jumps over the lazy dog.", list(vibevoice_voices.keys())[0] if vibevoice_voices else None],
-            ["Artificial intelligence is transforming the world.", list(vibevoice_voices.keys())[0] if vibevoice_voices else None]
-        ],
-        inputs=[text_input, vibevoice_voice_selection],
-        outputs=vibevoice_audio_output,
-        fn=generate_vibevoice_speech,
-        cache_examples=False
-    )
+    # VibeVoice section removed

     # Examples for Chatterbox
     gr.Examples(
@@ -597,12 +304,7 @@ with gr.Blocks(css=custom_css, title="🎙️ TTS Model Gallery", theme=gr.theme
         outputs=audio_output
     )

-    #
-    vibevoice_generate_btn.click(
-        fn=generate_vibevoice_speech,
-        inputs=[text_input, vibevoice_voice_selection],
-        outputs=vibevoice_audio_output
-    )
+    # VibeVoice button connection removed

     # Connect the KittenTTS generate button to the function
     kittentts_generate_btn.click(

src/vibevoice/__init__.py DELETED
File without changes

src/vibevoice/configs/qwen2.5_1.5b_64k.json DELETED

@@ -1,112 +0,0 @@
{
  "_attn_implementation_autoset": true,
  "acoustic_vae_dim": 64,
  "acoustic_tokenizer_config": {
    "causal": true,
    "channels": 1,
    "conv_bias": true,
    "conv_norm": "none",
    "corpus_normalize": 0.0,
    "decoder_depths": null,
    "decoder_n_filters": 32,
    "decoder_ratios": [8, 5, 5, 4, 2, 2],
    "disable_last_norm": true,
    "encoder_depths": "3-3-3-3-3-3-8",
    "encoder_n_filters": 32,
    "encoder_ratios": [8, 5, 5, 4, 2, 2],
    "fix_std": 0.5,
    "layer_scale_init_value": 1e-06,
    "layernorm": "RMSNorm",
    "layernorm_elementwise_affine": true,
    "layernorm_eps": 1e-05,
    "mixer_layer": "depthwise_conv",
    "model_type": "vibepod_acoustic_tokenizer",
    "pad_mode": "constant",
    "std_dist_type": "gaussian",
    "vae_dim": 64,
    "weight_init_value": 0.01
  },
  "decoder_config": {
    "attention_dropout": 0.0,
    "hidden_act": "silu",
    "hidden_size": 1536,
    "initializer_range": 0.02,
    "intermediate_size": 8960,
    "max_position_embeddings": 65536,
    "max_window_layers": 28,
    "model_type": "qwen2",
    "num_attention_heads": 12,
    "num_hidden_layers": 28,
    "num_key_value_heads": 2,
    "rms_norm_eps": 1e-06,
    "rope_scaling": null,
    "rope_theta": 1000000.0,
    "sliding_window": null,
    "tie_word_embeddings": true,
    "torch_dtype": "bfloat16",
    "use_cache": true,
    "use_sliding_window": false,
    "vocab_size": 151936
  },
  "diffusion_head_config": {
    "ddpm_batch_mul": 4,
    "ddpm_beta_schedule": "cosine",
    "ddpm_num_inference_steps": 20,
    "ddpm_num_steps": 1000,
    "diffusion_type": "ddpm",
    "head_ffn_ratio": 3.0,
    "head_layers": 4,
    "hidden_size": 1536,
    "latent_size": 64,
    "model_type": "vibepod_diffusion_head",
    "prediction_type": "v_prediction",
    "rms_norm_eps": 1e-05,
    "speech_vae_dim": 64
  },
  "model_type": "vibepod",
  "semantic_tokenizer_config": {
    "causal": true,
    "channels": 1,
    "conv_bias": true,
    "conv_norm": "none",
    "corpus_normalize": 0.0,
    "disable_last_norm": true,
    "encoder_depths": "3-3-3-3-3-3-8",
    "encoder_n_filters": 32,
    "encoder_ratios": [8, 5, 5, 4, 2, 2],
    "fix_std": 0,
    "layer_scale_init_value": 1e-06,
    "layernorm": "RMSNorm",
    "layernorm_elementwise_affine": true,
    "layernorm_eps": 1e-05,
    "mixer_layer": "depthwise_conv",
    "model_type": "vibepod_semantic_tokenizer",
    "pad_mode": "constant",
    "std_dist_type": "none",
    "vae_dim": 128,
    "weight_init_value": 0.01
  },
  "semantic_vae_dim": 128,
  "torch_dtype": "bfloat16"
}

src/vibevoice/configs/qwen2.5_7b_32k.json DELETED

@@ -1,113 +0,0 @@
{
  "_attn_implementation_autoset": true,
  "acoustic_vae_dim": 64,
  "acoustic_tokenizer_config": {
    "causal": true,
    "channels": 1,
    "conv_bias": true,
    "conv_norm": "none",
    "corpus_normalize": 0.0,
    "decoder_depths": null,
    "decoder_n_filters": 32,
    "decoder_ratios": [8, 5, 5, 4, 2, 2],
    "disable_last_norm": true,
    "encoder_depths": "3-3-3-3-3-3-8",
    "encoder_n_filters": 32,
    "encoder_ratios": [8, 5, 5, 4, 2, 2],
    "fix_std": 0.5,
    "layer_scale_init_value": 1e-06,
    "layernorm": "RMSNorm",
    "layernorm_elementwise_affine": true,
    "layernorm_eps": 1e-05,
    "mixer_layer": "depthwise_conv",
    "model_type": "vibepod_acoustic_tokenizer",
    "pad_mode": "constant",
    "std_dist_type": "gaussian",
    "vae_dim": 64,
    "weight_init_value": 0.01
  },
  "decoder_config": {
    "attention_dropout": 0.0,
    "hidden_act": "silu",
    "hidden_size": 3584,
    "initializer_range": 0.02,
    "intermediate_size": 18944,
    "max_position_embeddings": 32768,
    "max_window_layers": 28,
    "model_type": "qwen2",
    "num_attention_heads": 28,
    "num_hidden_layers": 28,
    "num_key_value_heads": 4,
    "rms_norm_eps": 1e-06,
    "rope_theta": 1000000.0,
    "sliding_window": null,
    "tie_word_embeddings": false,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.40.1",
    "use_cache": true,
    "use_mrope": false,
    "use_sliding_window": false,
    "vocab_size": 152064
  },
  "diffusion_head_config": {
    "ddpm_batch_mul": 4,
    "ddpm_beta_schedule": "cosine",
    "ddpm_num_inference_steps": 20,
    "ddpm_num_steps": 1000,
    "diffusion_type": "ddpm",
    "head_ffn_ratio": 3.0,
    "head_layers": 4,
    "hidden_size": 3584,
    "latent_size": 64,
    "model_type": "vibepod_diffusion_head",
    "prediction_type": "v_prediction",
    "rms_norm_eps": 1e-05,
    "speech_vae_dim": 64
  },
  "model_type": "vibepod",
  "semantic_tokenizer_config": {
    "causal": true,
    "channels": 1,
    "conv_bias": true,
    "conv_norm": "none",
    "corpus_normalize": 0.0,
    "disable_last_norm": true,
    "encoder_depths": "3-3-3-3-3-3-8",
    "encoder_n_filters": 32,
    "encoder_ratios": [8, 5, 5, 4, 2, 2],
    "fix_std": 0,
    "layer_scale_init_value": 1e-06,
    "layernorm": "RMSNorm",
    "layernorm_elementwise_affine": true,
    "layernorm_eps": 1e-05,
    "mixer_layer": "depthwise_conv",
    "model_type": "vibepod_semantic_tokenizer",
    "pad_mode": "constant",
    "std_dist_type": "none",
    "vae_dim": 128,
    "weight_init_value": 0.01
  },
  "semantic_vae_dim": 128,
  "torch_dtype": "bfloat16"
}

src/vibevoice/modular/__init__.py DELETED
File without changes

src/vibevoice/modular/configuration_vibevoice.py DELETED

@@ -1,248 +0,0 @@
""" VibeVoice_AcousticTokenizer model configuration"""

from typing import Dict, List, Optional, Tuple

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

logger = logging.get_logger(__name__)


class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
    model_type = "vibevoice_acoustic_tokenizer"

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0.5,
        std_dist_type: str = 'gaussian',
        # common
        mixer_layer: str = 'depthwise_conv',
        conv_norm: str = 'none',
        pad_mode: str = 'constant',
        disable_last_norm: bool = True,
        layernorm: str = 'RMSNorm',
        layernorm_eps: float = 1e-5,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-6,
        weight_init_value: float = 1e-2,
        # encoder specific
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = [8, 5, 5, 4, 2, 2],
        encoder_depths: str = "3-3-3-3-3-3-8",
        # decoder specific
        decoder_n_filters: int = 32,
        decoder_ratios: Optional[List[int]] = None,  # if None, same as encoder
        decoder_depths: Optional[str] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type

        # common parameters
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer

        # encoder specific parameters
        self.encoder_n_filters = encoder_n_filters
        self.encoder_ratios = encoder_ratios
        self.encoder_depths = encoder_depths

        # decoder specific parameters
        self.decoder_ratios = decoder_ratios if decoder_ratios is not None else encoder_ratios
        self.decoder_n_filters = decoder_n_filters
        self.decoder_depths = decoder_depths


class VibeVoiceSemanticTokenizerConfig(PretrainedConfig):
    model_type = "vibevoice_semantic_tokenizer"

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0,
        std_dist_type: str = 'none',
        # common
        mixer_layer: str = 'depthwise_conv',
        conv_norm: str = 'none',
        pad_mode: str = 'constant',
        disable_last_norm: bool = True,
        layernorm: str = 'RMSNorm',
        layernorm_eps: float = 1e-5,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-6,
        weight_init_value: float = 1e-2,
        # encoder specific
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = [8, 5, 5, 4, 2, 2],
        encoder_depths: str = "3-3-3-3-3-3-8",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type

        # common parameters
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer

        # encoder specific parameters
        self.encoder_n_filters = encoder_n_filters
        self.encoder_ratios = encoder_ratios
        self.encoder_depths = encoder_depths


class VibeVoiceDiffusionHeadConfig(PretrainedConfig):
    model_type = "vibevoice_diffusion_head"

    def __init__(
        self,
        hidden_size=768,
        head_layers=4,
        head_ffn_ratio=3.0,
        rms_norm_eps=1e-5,
        latent_size=64,
        speech_vae_dim=None,
        prediction_type="v_prediction",
        diffusion_type="ddpm",
        ddpm_num_steps=1000,
        ddpm_num_inference_steps=20,
        ddpm_beta_schedule="cosine",
        ddpm_batch_mul=4,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.head_layers = head_layers
        self.head_ffn_ratio = head_ffn_ratio
        self.rms_norm_eps = rms_norm_eps
        self.latent_size = latent_size
        self.speech_vae_dim = speech_vae_dim
        self.prediction_type = prediction_type
        self.diffusion_type = diffusion_type
        self.ddpm_num_steps = ddpm_num_steps
        self.ddpm_num_inference_steps = ddpm_num_inference_steps
        self.ddpm_beta_schedule = ddpm_beta_schedule
        self.ddpm_batch_mul = ddpm_batch_mul

        super().__init__(**kwargs)

class VibeVoiceConfig(PretrainedConfig):
    model_type = "vibevoice"
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
        "semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
        "diffusion_head_config": VibeVoiceDiffusionHeadConfig,
    }
    # keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Qwen2`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        diffusion_head_config=None,
        **kwargs
    ):

        # kwargs["_attn_implementation"] = "flash_attention_2"
        kwargs["_attn_implementation_autoset"] = False

        if acoustic_tokenizer_config is None:
            self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
        elif isinstance(acoustic_tokenizer_config, dict):
            acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer"
            self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config)
        elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
            # If an instance of the config class is provided
            self.acoustic_tokenizer_config = acoustic_tokenizer_config

        if semantic_tokenizer_config is None:
            self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
        elif isinstance(semantic_tokenizer_config, dict):
            semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer"
            self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config)
        elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
            # If an instance of the config class is provided
            self.semantic_tokenizer_config = semantic_tokenizer_config

        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            # If a dictionary is provided, instantiate the config class with it
            # self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
            if decoder_config.get("model_type", '') == "qwen2":
                self.decoder_config = Qwen2Config(**decoder_config)
            else:
                raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
        elif isinstance(decoder_config, (Qwen2Config,)):
            # If an instance of the config class is provided
            self.decoder_config = decoder_config

        if diffusion_head_config is None:
            self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
        elif isinstance(diffusion_head_config, dict):
            diffusion_head_config["model_type"] = "vibevoice_diffusion_head"
            self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config)
        elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig):
            # If an instance of the config class is provided
            self.diffusion_head_config = diffusion_head_config

        # other parameters
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)

        super().__init__(**kwargs)

__all__ = [
    "VibeVoiceAcousticTokenizerConfig",
    "VibeVoiceSemanticTokenizerConfig",
    "VibeVoiceDiffusionHeadConfig",
    "VibeVoiceConfig"
]
src/vibevoice/modular/modeling_vibevoice.py
DELETED
|
@@ -1,487 +0,0 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
from typing import Dict, List, Optional, Tuple, Union, Callable
|
| 3 |
-
from tqdm import tqdm
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
import torch.distributed as dist
|
| 8 |
-
|
| 9 |
-
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
| 10 |
-
|
| 11 |
-
from transformers.activations import ACT2FN
|
| 12 |
-
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
|
| 13 |
-
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
| 14 |
-
from transformers import modeling_utils
|
| 15 |
-
from transformers.modeling_utils import PreTrainedModel
|
| 16 |
-
from transformers.utils import logging
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
|
| 20 |
-
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
|
| 21 |
-
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
|
| 22 |
-
|
| 23 |
-
from .configuration_vibevoice import VibeVoiceConfig
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
logger = logging.get_logger(__name__)
|
| 27 |
-
|
| 28 |
-
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
|
| 29 |
-
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
| 30 |
-
|
| 31 |
-
@dataclass
|
| 32 |
-
class VibeVoiceCausalLMOutputWithPast(ModelOutput):
|
| 33 |
-
loss: Optional[torch.FloatTensor] = None
|
| 34 |
-
diffusion_loss: Optional[torch.FloatTensor] = None
|
| 35 |
-
speech_token_num: Optional[int] = None
|
| 36 |
-
logits: torch.FloatTensor = None
|
| 37 |
-
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
| 38 |
-
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 39 |
-
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
@dataclass
|
| 43 |
-
class VibeVoiceGenerationOutput(ModelOutput):
|
| 44 |
-
"""
|
| 45 |
-
Output type for VibeVoice generation.
|
| 46 |
-
|
| 47 |
-
Args:
|
| 48 |
-
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 49 |
-
The generated sequences.
|
| 50 |
-
speech_outputs (`List[torch.FloatTensor]`, *optional*):
|
| 51 |
-
List of generated speech waveforms or latents for each speech segment.
|
| 52 |
-
"""
|
| 53 |
-
sequences: torch.LongTensor = None
|
| 54 |
-
speech_outputs: Optional[List[torch.FloatTensor]] = None
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
class SpeechConnector(nn.Module):
|
| 58 |
-
def __init__(self, input_dim, output_dim):
|
| 59 |
-
super().__init__()
|
| 60 |
-
self.fc1 = nn.Linear(input_dim, output_dim)
|
| 61 |
-
self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
|
| 62 |
-
self.fc2 = nn.Linear(output_dim, output_dim)
|
| 63 |
-
|
| 64 |
-
def forward(self, features, **kwargs):
|
| 65 |
-
x = self.fc1(features)
|
| 66 |
-
x = self.norm(x)
|
| 67 |
-
x = self.fc2(x)
|
| 68 |
-
return x
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
# @auto_docstring
|
| 72 |
-
class VibeVoicePreTrainedModel(PreTrainedModel):
|
| 73 |
-
config_class = VibeVoiceConfig
|
| 74 |
-
base_model_prefix = "model"
|
| 75 |
-
supports_gradient_checkpointing = True
|
| 76 |
-
_skip_keys_device_placement = "past_key_values"
|
| 77 |
-
_supports_cache_class = True
|
| 78 |
-
_supports_flash_attn_2 = True
|
| 79 |
-
_supports_sdpa = True
|
| 80 |
-
_supports_quantized_cache = True
|
| 81 |
-
_supports_static_cache = True
|
| 82 |
-
_supports_attention_backend = True
|
| 83 |
-
|
| 84 |
-
def _init_weights(self, module):
|
| 85 |
-
if isinstance(module, VibeVoiceDiffusionHead):
|
| 86 |
-
module.initialize_weights()
|
| 87 |
-
return
|
| 88 |
-
|
| 89 |
-
# Use the language model's initializer_range if available
|
| 90 |
-
if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
|
| 91 |
-
std = self.config.language_model_config.initializer_range
|
| 92 |
-
elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
|
| 93 |
-
std = self.config.decoder_config.initializer_range
|
| 94 |
-
else:
|
| 95 |
-
std = 0.02 # Default value
|
| 96 |
-
|
| 97 |
-
if isinstance(module, nn.Linear):
|
| 98 |
-
module.weight.data.normal_(mean=0.0, std=std)
|
| 99 |
-
if module.bias is not None:
|
| 100 |
-
module.bias.data.zero_()
|
| 101 |
-
elif isinstance(module, nn.LayerNorm):
|
| 102 |
-
module.weight.data.fill_(1.0)
|
| 103 |
-
module.bias.data.zero_()
|
| 104 |
-
|
| 105 |
-
# @auto_docstring
|
| 106 |
-
class VibeVoiceModel(VibeVoicePreTrainedModel):
|
| 107 |
-
def __init__(self, config):
|
| 108 |
-
super().__init__(config)
|
| 109 |
-
|
| 110 |
-
if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
|
| 111 |
-
if isinstance(config.torch_dtype, str):
|
| 112 |
-
dtype = getattr(torch, config.torch_dtype)
|
| 113 |
-
else:
|
| 114 |
-
dtype = config.torch_dtype
|
| 115 |
-
else:
|
| 116 |
-
dtype = torch.float32
|
| 117 |
-
|
| 118 |
-
# Initialize Qwen2 model for language modeling
|
| 119 |
-
lm_config = config.decoder_config
|
| 120 |
-
self.language_model = AutoModel.from_config(lm_config)
|
| 121 |
-
|
| 122 |
-
# Initialize speech components if needed
|
| 123 |
-
self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
|
| 124 |
-
self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)
|
| 125 |
-
|
| 126 |
-
self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
|
| 127 |
-
self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)
|
| 128 |
-
|
| 129 |
-
# Register scaling factors as buffers - use 1D tensors for FSDP compatibility
|
| 130 |
-
self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
|
| 131 |
-
self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))
|
| 132 |
-
|
| 133 |
-
# Initialize prediction head for speech generation
|
| 134 |
-
self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)
|
| 135 |
-
|
| 136 |
-
# Initialize noise scheduler
|
| 137 |
-
self.noise_scheduler = DPMSolverMultistepScheduler(
|
| 138 |
-
num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
|
| 139 |
-
beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
|
| 140 |
-
prediction_type=config.diffusion_head_config.prediction_type
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
def get_input_embeddings(self):
|
| 144 |
-
if hasattr(self.language_model, 'embed_tokens'):
|
| 145 |
-
# If the language model has an embed_tokens attribute, return it
|
| 146 |
-
return self.language_model.embed_tokens
|
| 147 |
-
|
| 148 |
-
for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed
|
| 149 |
-
if attr.orig_name == 'embed_tokens.weight':
|
| 150 |
-
return getattr(self.language_model, name)
|
| 151 |
-
assert False, 'should not arrive here'
|
| 152 |
-
|
| 153 |
-
def set_input_embeddings(self, value):
|
| 154 |
-
self.language_model.embed_tokens = value
|
| 155 |
-
|
| 156 |
-
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
|
| 157 |
-
"""Set the speech tokenizers used for encoding and decoding speech."""
|
| 158 |
-
self.acoustic_tokenizer = acoustic_tokenizer
|
| 159 |
-
self.semantic_tokenizer = semantic_tokenizer
|
| 160 |
-
|
| 161 |
-
# Reset the encoder to evaluation mode
|
| 162 |
-
if self.acoustic_tokenizer is not None:
|
| 163 |
-
self.acoustic_tokenizer.eval()
|
| 164 |
-
|
| 165 |
-
if self.semantic_tokenizer is not None:
|
| 166 |
-
self.semantic_tokenizer.eval()
|
| 167 |
-
|
| 168 |
-
def forward(
|
| 169 |
-
self,
|
| 170 |
-
input_ids: torch.LongTensor = None,
|
| 171 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 172 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 173 |
-
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 174 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 175 |
-
use_cache: Optional[bool] = None,
|
| 176 |
-
output_attentions: Optional[bool] = None,
|
| 177 |
-
output_hidden_states: Optional[bool] = None,
|
| 178 |
-
return_dict: Optional[bool] = None,
|
| 179 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 180 |
-
**kwargs,
|
| 181 |
-
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 182 |
-
|
| 183 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 184 |
-
|
| 185 |
-
# Forward through language model
|
| 186 |
-
outputs = self.language_model(
|
| 187 |
-
input_ids=input_ids,
|
| 188 |
-
attention_mask=attention_mask,
|
| 189 |
-
position_ids=position_ids,
|
| 190 |
-
past_key_values=past_key_values,
|
| 191 |
-
inputs_embeds=inputs_embeds,
|
| 192 |
-
use_cache=use_cache,
|
| 193 |
-
output_attentions=output_attentions,
|
| 194 |
-
output_hidden_states=output_hidden_states,
|
| 195 |
-
return_dict=return_dict,
|
| 196 |
-
cache_position=cache_position,
|
| 197 |
-
**kwargs,
|
| 198 |
-
)
|
| 199 |
-
|
| 200 |
-
if not return_dict:
|
| 201 |
-
return outputs
|
| 202 |
-
|
| 203 |
-
return BaseModelOutputWithPast(
|
| 204 |
-
last_hidden_state=outputs.last_hidden_state,
|
| 205 |
-
past_key_values=outputs.past_key_values,
|
| 206 |
-
hidden_states=outputs.hidden_states,
|
| 207 |
-
attentions=outputs.attentions,
|
| 208 |
-
)
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
|
| 212 |
-
_tied_weights_keys = ["lm_head.weight"]
|
| 213 |
-
_tp_plan = {"lm_head": "colwise_rep"}
|
| 214 |
-
|
| 215 |
-
def __init__(self, config):
|
| 216 |
-
super().__init__(config)
|
| 217 |
-
self.model = VibeVoiceModel(config)
|
| 218 |
-
self.vocab_size = config.decoder_config.vocab_size
|
| 219 |
-
self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)
|
| 220 |
-
|
| 221 |
-
self.post_init()
|
| 222 |
-
|
| 223 |
-
def get_input_embeddings(self):
|
| 224 |
-
return self.model.get_input_embeddings()
|
| 225 |
-
|
| 226 |
-
def set_input_embeddings(self, value):
|
| 227 |
-
self.model.set_input_embeddings(value)
|
| 228 |
-
|
| 229 |
-
def get_output_embeddings(self):
|
| 230 |
-
return self.lm_head
|
| 231 |
-
|
| 232 |
-
def set_decoder(self, decoder):
|
| 233 |
-
self.model.language_model = decoder
|
| 234 |
-
|
| 235 |
-
def get_decoder(self):
|
| 236 |
-
return self.model.language_model
|
| 237 |
-
|
| 238 |
-
def tie_weights(self):
|
| 239 |
-
"""
|
| 240 |
-
Tie the weights between the input embeddings and the output embeddings.
|
| 241 |
-
"""
|
| 242 |
-
if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
|
| 243 |
-
# The standard PreTrainedModel method will handle the tying.
|
| 244 |
-
# It typically does a simple parameter object assignment, which is
|
| 245 |
-
# CORRECT to do BEFORE FSDP wraps the model.
|
| 246 |
-
output_embeddings = self.get_output_embeddings()
|
| 247 |
-
input_embeddings = self.get_input_embeddings()
|
| 248 |
-
if hasattr(input_embeddings, 'weight'):
|
| 249 |
-
output_embeddings.weight = input_embeddings.weight
|
| 250 |
-
else:
|
| 251 |
-
# maybe returned input_embeddings a tensor directly
|
| 252 |
-
output_embeddings.weight = input_embeddings
|
| 253 |
-
|
| 254 |
-
if getattr(output_embeddings, "bias", None) is not None:
|
| 255 |
-
output_embeddings.bias.data = nn.functional.pad(
|
| 256 |
-
output_embeddings.bias.data,
|
| 257 |
-
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
|
| 258 |
-
"constant",
|
| 259 |
-
0,
|
| 260 |
-
)
|
| 261 |
-
print("✅ Tied input and output embeddings using standard assignment.")
|
| 262 |
-
else:
|
| 263 |
-
print("ℹ️ tie_word_embeddings is False, not tying weights.")
|
| 264 |
-
|
| 265 |
-
# Also, ensure set_output_embeddings is safe, though your implementation looks okay.
|
| 266 |
-
# The key is to avoid calling it after accelerator.prepare().
|
| 267 |
-
def set_output_embeddings(self, new_embeddings):
|
| 268 |
-
# Your current implementation using data.copy_ is good practice,
|
| 269 |
-
# but the best way is to not call this after prepare().
|
| 270 |
-
self.lm_head = new_embeddings
|
| 271 |
-
|
| 272 |
-
def forward_speech_features(
|
| 273 |
-
self,
|
| 274 |
-
speech_tensors=None,
|
| 275 |
-
speech_masks=None,
|
| 276 |
-
speech_type="audio",
|
| 277 |
-
return_unmask=False
|
| 278 |
-
):
|
| 279 |
-
if speech_tensors is None:
|
| 280 |
-
# Use config to get vae_dim instead of non-existent self.args
|
| 281 |
-
vae_dim = self.config.acoustic_tokenizer_config.vae_dim
|
| 282 |
-
audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
|
| 283 |
-
connect_features = self.model.acoustic_connector(audio_features)
|
| 284 |
-
return audio_features, connect_features
|
| 285 |
-
else:
|
| 286 |
-
with torch.no_grad():
|
| 287 |
-
if speech_type == "audio":
|
| 288 |
-
with torch.no_grad():
|
| 289 |
-
frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
|
| 290 |
-
audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]
|
| 291 |
-
|
| 292 |
-
elif speech_type == "vae":
|
| 293 |
-
# Use config to get vae_dim instead of non-existent self.args
|
| 294 |
-
vae_dim = self.config.acoustic_tokenizer_config.vae_dim
|
| 295 |
-
speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)
|
| 296 |
-
|
| 297 |
-
# gaussian sample from the speech_mode
|
| 298 |
-
batch_size = speech_mode.size(0)
|
| 299 |
-
value = self.model.acoustic_tokenizer.fix_std / 0.8
|
| 300 |
-
std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
|
| 301 |
-
std = std.view(-1, *[1] * (speech_mode.dim() - 1))
|
| 302 |
-
audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
|
| 303 |
-
else:
|
| 304 |
-
raise NotImplementedError(f"Speech type {speech_type} not implemented")
|
| 305 |
-
|
| 306 |
-
if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
|
| 307 |
-
scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
|
| 308 |
-
bias_factor = -audio_tokens[speech_masks].flatten().mean()
|
| 309 |
-
|
| 310 |
-
# Only use distributed operations if the process group is initialized
|
| 311 |
-
if dist.is_available() and dist.is_initialized():
|
| 312 |
-
dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
|
| 313 |
-
dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
|
| 314 |
-
world_size = dist.get_world_size()
|
| 315 |
-
self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
|
| 316 |
-
self.model.speech_bias_factor.copy_(bias_factor / world_size)
|
| 317 |
-
print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
|
| 318 |
-
else:
|
| 319 |
-
# Single process case
|
| 320 |
-
self.model.speech_scaling_factor.copy_(scaling_factor)
|
| 321 |
-
self.model.speech_bias_factor.copy_(bias_factor)
|
| 322 |
-
print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
|
| 323 |
-
|
| 324 |
-
audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor
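# Note: (audio_tokens + bias) * scale is equivalent to (audio_tokens - mean) / std,
# since scale = 1/std and bias = -mean were computed from the masked latents above,
# so the acoustic latents reach the connector with roughly zero mean and unit variance.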
|
| 325 |
-
|
| 326 |
-
connect_features = self.model.acoustic_connector(audio_features)
|
| 327 |
-
if return_unmask:
|
| 328 |
-
return audio_features, connect_features
|
| 329 |
-
return audio_features[speech_masks], connect_features[speech_masks]
|
| 330 |
-
|
| 331 |
-
def forward(
|
| 332 |
-
self,
|
| 333 |
-
input_ids: torch.LongTensor = None,
|
| 334 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 335 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 336 |
-
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 337 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 338 |
-
labels: Optional[torch.LongTensor] = None,
|
| 339 |
-
use_cache: Optional[bool] = False,
|
| 340 |
-
output_attentions: Optional[bool] = None,
|
| 341 |
-
output_hidden_states: Optional[bool] = None,
|
| 342 |
-
return_dict: Optional[bool] = None,
|
| 343 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 344 |
-
# New arguments for speech processing and loss calculation
|
| 345 |
-
speech_tensors: Optional[torch.FloatTensor] = None,
|
| 346 |
-
speech_masks: Optional[torch.BoolTensor] = None,
|
| 347 |
-
speeches_loss_input: Optional[torch.FloatTensor] = None,
|
| 348 |
-
speech_semantic_tensors: Optional[torch.FloatTensor] = None,
|
| 349 |
-
acoustic_input_mask: Optional[torch.BoolTensor] = None,
|
| 350 |
-
acoustic_loss_mask: Optional[torch.BoolTensor] = None,
|
| 351 |
-
ddpm_batch_mul: int = 1,
|
| 352 |
-
**kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
|
| 353 |
-
) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
|
| 354 |
-
|
| 355 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 356 |
-
|
| 357 |
-
x = self.get_input_embeddings()(input_ids)
|
| 358 |
-
|
| 359 |
-
semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)
|
| 360 |
-
if speeches_loss_input is not None:
|
| 361 |
-
# only a subset of the audio tokens needs the diffusion loss
|
| 362 |
-
speech_all_features, speech_all_connect_features = self.forward_speech_features(
|
| 363 |
-
speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
|
| 364 |
-
speech_masks=speech_masks,
|
| 365 |
-
speech_type=kwargs.get("speech_type", "audio"),
|
| 366 |
-
return_unmask=True
|
| 367 |
-
)
|
| 368 |
-
if speech_tensors is not None:
|
| 369 |
-
if semantic_speech_all_connect_features is not None:
|
| 370 |
-
x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + semantic_speech_all_connect_features[speech_masks]
|
| 371 |
-
else:
|
| 372 |
-
x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
|
| 373 |
-
speech_features = speech_all_features[speeches_loss_input.unsqueeze(-1) & speech_masks] # only the tokens selected for diffusion loss
|
| 374 |
-
speech_connect_features = speech_all_connect_features[speeches_loss_input.unsqueeze(-1) & speech_masks]
|
| 375 |
-
else:
|
| 376 |
-
speech_features, speech_connect_features = self.forward_speech_features(
|
| 377 |
-
speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
|
| 378 |
-
speech_masks=speech_masks,
|
| 379 |
-
speech_type=kwargs.get("speech_type", "audio"),
|
| 380 |
-
)
|
| 381 |
-
if speech_tensors is not None:
|
| 382 |
-
x[acoustic_input_mask] = speech_connect_features
|
| 383 |
-
|
| 384 |
-
outputs = self.model(
|
| 385 |
-
input_ids=None,
|
| 386 |
-
attention_mask=attention_mask,
|
| 387 |
-
position_ids=position_ids,
|
| 388 |
-
past_key_values=past_key_values,
|
| 389 |
-
inputs_embeds=x,
|
| 390 |
-
use_cache=use_cache,
|
| 391 |
-
output_attentions=output_attentions,
|
| 392 |
-
output_hidden_states=False,
|
| 393 |
-
return_dict=return_dict,
|
| 394 |
-
cache_position=cache_position,
|
| 395 |
-
)
|
| 396 |
-
|
| 397 |
-
hidden_states = outputs.last_hidden_state
|
| 398 |
-
logits = self.lm_head(hidden_states)
|
| 399 |
-
# logits = logits.float()
|
| 400 |
-
|
| 401 |
-
loss = None
|
| 402 |
-
if labels is not None:
|
| 403 |
-
# The custom CE loss with masking is calculated in the training script.
|
| 404 |
-
# We leave the standard loss calculation here as None.
|
| 405 |
-
pass
|
| 406 |
-
|
| 407 |
-
# --- Diffusion Loss Calculation ---
|
| 408 |
-
diffusion_loss = None
|
| 409 |
-
# This block is executed only if we are in a context that involves speech.
|
| 410 |
-
if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
|
| 411 |
-
condition_features = hidden_states[acoustic_loss_mask]
|
| 412 |
-
|
| 413 |
-
speech_len, latent_size = speech_features.shape
|
| 414 |
-
|
| 415 |
-
noise = torch.randn(
|
| 416 |
-
(speech_len * ddpm_batch_mul, latent_size),
|
| 417 |
-
device=hidden_states.device,
|
| 418 |
-
dtype=hidden_states.dtype
|
| 419 |
-
)
|
| 420 |
-
|
| 421 |
-
timesteps = torch.multinomial(
|
| 422 |
-
torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
|
| 423 |
-
speech_len * ddpm_batch_mul,
|
| 424 |
-
replacement=True,
|
| 425 |
-
).to(hidden_states.device)
|
| 426 |
-
|
| 427 |
-
speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
|
| 428 |
-
condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)
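# ddpm_batch_mul replicates each speech latent and its conditioning hidden state so
# that several independent (noise, timestep) pairs are sampled per token, presumably
# to reduce the variance of the diffusion loss estimate.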
|
| 429 |
-
|
| 430 |
-
noisy_speech_features = self.model.noise_scheduler.add_noise(
|
| 431 |
-
speech_features_repeated, noise, timesteps
|
| 432 |
-
)
|
| 433 |
-
|
| 434 |
-
model_output = self.model.prediction_head(
|
| 435 |
-
noisy_speech_features,
|
| 436 |
-
timesteps.type_as(x),
|
| 437 |
-
condition_features_repeated
|
| 438 |
-
)
|
| 439 |
-
|
| 440 |
-
prediction_type = self.config.diffusion_head_config.prediction_type
|
| 441 |
-
if prediction_type == "epsilon":
|
| 442 |
-
target_for_loss = noise
|
| 443 |
-
elif prediction_type == "v_prediction":
|
| 444 |
-
target_for_loss = self.model.noise_scheduler.get_velocity(
|
| 445 |
-
speech_features_repeated, noise, timesteps
|
| 446 |
-
)
|
| 447 |
-
else:
|
| 448 |
-
raise NotImplementedError(f"Prediction type {prediction_type} not implemented")
|
| 449 |
-
|
| 450 |
-
diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
|
| 451 |
-
if latent_size > 0 and ddpm_batch_mul > 0:
|
| 452 |
-
diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
|
| 453 |
-
else:
|
| 454 |
-
diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
|
| 455 |
-
|
| 456 |
-
else:
|
| 457 |
-
# Dummy loss for DDP to work when there are no speech samples in a batch,
|
| 458 |
-
# but we are in a speech context.
|
| 459 |
-
diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
|
| 460 |
-
diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
|
| 461 |
-
diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
|
| 462 |
-
# --- End Diffusion Loss Calculation ---
|
| 463 |
-
|
| 464 |
-
if not return_dict:
|
| 465 |
-
output = (logits, speech_len) + outputs.to_tuple()[1:]
|
| 466 |
-
return (loss, diffusion_loss) + output
|
| 467 |
-
|
| 468 |
-
return VibeVoiceCausalLMOutputWithPast(
|
| 469 |
-
loss=loss,
|
| 470 |
-
diffusion_loss=diffusion_loss,
|
| 471 |
-
speech_token_num=speech_len if speech_tensors is not None else 0,
|
| 472 |
-
logits=logits,
|
| 473 |
-
past_key_values=outputs.past_key_values,
|
| 474 |
-
hidden_states=outputs.hidden_states,
|
| 475 |
-
attentions=outputs.attentions,
|
| 476 |
-
)
|
| 477 |
-
|
| 478 |
-
AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
|
| 479 |
-
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)
|
| 480 |
-
|
| 481 |
-
__all__ = [
|
| 482 |
-
"VibeVoiceModel",
|
| 483 |
-
"VibeVoicePreTrainedModel",
|
| 484 |
-
"VibeVoiceForConditionalGeneration",
|
| 485 |
-
"VibeVoiceCausalLMOutputWithPast",
|
| 486 |
-
"VibeVoiceGenerationOutput",
|
| 487 |
-
]
|
src/vibevoice/modular/modeling_vibevoice_inference.py
DELETED
@@ -1,716 +0,0 @@
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
from typing import Dict, List, Optional, Tuple, Union, Callable
|
| 3 |
-
from tqdm import tqdm
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
|
| 7 |
-
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
| 8 |
-
|
| 9 |
-
from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
|
| 10 |
-
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
| 11 |
-
from transformers import modeling_utils
|
| 12 |
-
from transformers.modeling_utils import PreTrainedModel
|
| 13 |
-
from transformers.utils import logging
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
# from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
|
| 17 |
-
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceTokenizerEncoderOutput
|
| 18 |
-
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
|
| 19 |
-
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
|
| 20 |
-
|
| 21 |
-
from .configuration_vibevoice import VibeVoiceConfig
|
| 22 |
-
|
| 23 |
-
from .modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizer, VibeVoiceTextTokenizerFast
|
| 24 |
-
|
| 25 |
-
from .modeling_vibevoice import VibeVoiceModel, VibeVoicePreTrainedModel
|
| 26 |
-
from .streamer import AudioStreamer, AsyncAudioStreamer
|
| 27 |
-
|
| 28 |
-
logger = logging.get_logger(__name__)
|
| 29 |
-
|
| 30 |
-
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
|
| 31 |
-
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
| 32 |
-
|
| 33 |
-
@dataclass
|
| 34 |
-
class VibeVoiceCausalLMOutputWithPast(BaseModelOutputWithPast):
|
| 35 |
-
logits: Optional[torch.FloatTensor] = None
|
| 36 |
-
|
| 37 |
-
@dataclass
|
| 38 |
-
class VibeVoiceGenerationOutput(ModelOutput):
|
| 39 |
-
"""
|
| 40 |
-
Output type for VibeVoice generation.
|
| 41 |
-
|
| 42 |
-
Args:
|
| 43 |
-
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 44 |
-
The generated sequences.
|
| 45 |
-
speech_outputs (`List[torch.FloatTensor]`, *optional*):
|
| 46 |
-
List of generated speech waveforms or latents for each speech segment.
|
| 47 |
-
"""
|
| 48 |
-
sequences: torch.LongTensor = None
|
| 49 |
-
speech_outputs: Optional[List[torch.FloatTensor]] = None
|
| 50 |
-
reach_max_step_sample: Optional[torch.BoolTensor] = None
|
| 51 |
-
|
| 52 |
-
class VibeVoiceTokenConstraintProcessor(LogitsProcessor):
|
| 53 |
-
"""Constrains token generation to only valid tokens during speech generation."""
|
| 54 |
-
|
| 55 |
-
def __init__(self, valid_token_ids: List[int], device: torch.device = None):
|
| 56 |
-
self.valid_token_ids = torch.tensor(valid_token_ids, dtype=torch.long, device=device)
|
| 57 |
-
|
| 58 |
-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 59 |
-
# Create a mask for valid tokens
|
| 60 |
-
mask = torch.full_like(scores, float('-inf'))
|
| 61 |
-
mask[:, self.valid_token_ids] = 0
|
| 62 |
-
|
| 63 |
-
# Apply mask to scores
|
| 64 |
-
scores = scores + mask
|
| 65 |
-
return scores
|
| 66 |
-
|
| 67 |
-
class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, GenerationMixin):
|
| 68 |
-
_tied_weights_keys = ["lm_head.weight"]
|
| 69 |
-
_tp_plan = {"lm_head": "colwise_rep"}
|
| 70 |
-
|
| 71 |
-
def __init__(self, config):
|
| 72 |
-
super().__init__(config)
|
| 73 |
-
|
| 74 |
-
# Initialize the base model
|
| 75 |
-
self.model = VibeVoiceModel(config)
|
| 76 |
-
|
| 77 |
-
# LM head for text generation
|
| 78 |
-
self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.decoder_config.vocab_size, bias=False)
|
| 79 |
-
|
| 80 |
-
# inference configuration
|
| 81 |
-
self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps
|
| 82 |
-
|
| 83 |
-
# Initialize weights and apply final processing
|
| 84 |
-
self.post_init()
|
| 85 |
-
|
| 86 |
-
@property
|
| 87 |
-
def noise_scheduler(self):
|
| 88 |
-
return self.model.noise_scheduler
|
| 89 |
-
|
| 90 |
-
@property
|
| 91 |
-
def prediction_head(self):
|
| 92 |
-
return self.model.prediction_head
|
| 93 |
-
|
| 94 |
-
@property
|
| 95 |
-
def speech_scaling_factor(self):
|
| 96 |
-
return self.model.speech_scaling_factor
|
| 97 |
-
|
| 98 |
-
@property
|
| 99 |
-
def speech_bias_factor(self):
|
| 100 |
-
return self.model.speech_bias_factor
|
| 101 |
-
|
| 102 |
-
@property
|
| 103 |
-
def acoustic_tokenizer(self):
|
| 104 |
-
return self.model.acoustic_tokenizer
|
| 105 |
-
|
| 106 |
-
@property
|
| 107 |
-
def semantic_tokenizer(self):
|
| 108 |
-
return self.model.semantic_tokenizer
|
| 109 |
-
|
| 110 |
-
@property
|
| 111 |
-
def acoustic_connector(self):
|
| 112 |
-
return self.model.acoustic_connector
|
| 113 |
-
|
| 114 |
-
@property
|
| 115 |
-
def semantic_connector(self):
|
| 116 |
-
return self.model.semantic_connector
|
| 117 |
-
|
| 118 |
-
def tie_weights(self):
|
| 119 |
-
"""
|
| 120 |
-
Tie the weights between the input embeddings and the output embeddings.
|
| 121 |
-
"""
|
| 122 |
-
# Tie lm_head.weight to language_model.embed_tokens.weight
|
| 123 |
-
if not getattr(self.config, 'tie_word_embeddings', False):
|
| 124 |
-
return
|
| 125 |
-
|
| 126 |
-
if hasattr(self, 'lm_head') and hasattr(self.model.language_model, 'embed_tokens'):
|
| 127 |
-
self.lm_head.weight = self.model.language_model.embed_tokens.weight
|
| 128 |
-
|
| 129 |
-
def get_input_embeddings(self):
|
| 130 |
-
return self.model.get_input_embeddings()
|
| 131 |
-
|
| 132 |
-
def set_input_embeddings(self, value):
|
| 133 |
-
self.model.set_input_embeddings(value)
|
| 134 |
-
|
| 135 |
-
def get_output_embeddings(self):
|
| 136 |
-
return self.lm_head
|
| 137 |
-
|
| 138 |
-
def set_output_embeddings(self, new_embeddings):
|
| 139 |
-
self.lm_head = new_embeddings
|
| 140 |
-
|
| 141 |
-
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
|
| 142 |
-
"""Set the speech tokenizers used for encoding and decoding speech."""
|
| 143 |
-
self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer)
|
| 144 |
-
|
| 145 |
-
def set_ddpm_inference_steps(self, num_steps=None):
|
| 146 |
-
self.ddpm_inference_steps = num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
|
| 147 |
-
|
| 148 |
-
def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"):
|
| 149 |
-
"""Process speech inputs through tokenizers and connectors."""
|
| 150 |
-
with torch.no_grad():
|
| 151 |
-
if speech_type == "audio":
|
| 152 |
-
# Encode audio to acoustic latents
|
| 153 |
-
encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
|
| 154 |
-
acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
|
| 155 |
-
|
| 156 |
-
# Apply scaling and bias
|
| 157 |
-
acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
|
| 158 |
-
|
| 159 |
-
# Connect to language model space
|
| 160 |
-
acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
|
| 161 |
-
|
| 162 |
-
return acoustic_features, acoustic_connected
|
| 163 |
-
elif speech_type == "pt":
|
| 164 |
-
encoder_output = VibeVoiceTokenizerEncoderOutput(mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std)
|
| 165 |
-
acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
|
| 166 |
-
|
| 167 |
-
# Apply scaling and bias
|
| 168 |
-
acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
|
| 169 |
-
|
| 170 |
-
# Connect to language model space
|
| 171 |
-
acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
|
| 172 |
-
|
| 173 |
-
return acoustic_features, acoustic_connected
|
| 174 |
-
else:
|
| 175 |
-
raise NotImplementedError(f"Speech type {speech_type} not implemented")
|
| 176 |
-
|
| 177 |
-
# @can_return_tuple
|
| 178 |
-
def forward(
|
| 179 |
-
self,
|
| 180 |
-
input_ids: torch.LongTensor = None,
|
| 181 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 182 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 183 |
-
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 184 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 185 |
-
labels: Optional[torch.LongTensor] = None,
|
| 186 |
-
use_cache: Optional[bool] = None,
|
| 187 |
-
output_attentions: Optional[bool] = None,
|
| 188 |
-
output_hidden_states: Optional[bool] = None,
|
| 189 |
-
return_dict: Optional[bool] = None,
|
| 190 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 191 |
-
speech_tensors: Optional[torch.FloatTensor] = None,
|
| 192 |
-
speech_masks: Optional[torch.BoolTensor] = None,
|
| 193 |
-
speech_input_mask: Optional[torch.BoolTensor] = None,
|
| 194 |
-
logits_to_keep: Union[int, slice] = 0,
|
| 195 |
-
**kwargs,
|
| 196 |
-
) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
|
| 197 |
-
"""
|
| 198 |
-
Args:
|
| 199 |
-
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 200 |
-
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 201 |
-
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 202 |
-
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 203 |
-
speech_tensors (`torch.FloatTensor`, *optional*):
|
| 204 |
-
Input speech waveforms for voice cloning or speech understanding.
|
| 205 |
-
speech_masks (`torch.BoolTensor`, *optional*):
|
| 206 |
-
Masks indicating valid speech frames.
|
| 207 |
-
speech_input_mask (`torch.BoolTensor`, *optional*):
|
| 208 |
-
Positions in the input sequence where speech embeddings should be inserted.
|
| 209 |
-
|
| 210 |
-
Returns:
|
| 211 |
-
`VibeVoiceCausalLMOutputWithPast` or tuple
|
| 212 |
-
"""
|
| 213 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 214 |
-
|
| 215 |
-
# Get embeddings
|
| 216 |
-
if inputs_embeds is None:
|
| 217 |
-
inputs_embeds = self.model.get_input_embeddings()(input_ids)
|
| 218 |
-
|
| 219 |
-
# Process speech inputs if provided
|
| 220 |
-
if speech_tensors is not None and speech_masks is not None:
|
| 221 |
-
# Ensure speech tensors match model's data type
|
| 222 |
-
speech_tensors = speech_tensors.to(self.dtype)
|
| 223 |
-
acoustic_features, speech_embeds = self._process_speech_inputs(speech_tensors, speech_masks)
|
| 224 |
-
if speech_input_mask is not None:
|
| 225 |
-
inputs_embeds[speech_input_mask] = speech_embeds
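# Voice-prompt embeddings are spliced into the token embedding sequence at the
# positions flagged by speech_input_mask before the decoder forward pass.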
|
| 226 |
-
|
| 227 |
-
outputs = self.model(
|
| 228 |
-
inputs_embeds=inputs_embeds,
|
| 229 |
-
attention_mask=attention_mask,
|
| 230 |
-
position_ids=position_ids,
|
| 231 |
-
past_key_values=past_key_values,
|
| 232 |
-
use_cache=use_cache,
|
| 233 |
-
output_attentions=output_attentions,
|
| 234 |
-
output_hidden_states=output_hidden_states,
|
| 235 |
-
return_dict=return_dict,
|
| 236 |
-
cache_position=cache_position,
|
| 237 |
-
**kwargs,
|
| 238 |
-
)
|
| 239 |
-
|
| 240 |
-
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
|
| 241 |
-
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 242 |
-
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 243 |
-
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 244 |
-
|
| 245 |
-
if labels is not None:
|
| 246 |
-
raise NotImplementedError("Loss computation is not implemented in this version.")
|
| 247 |
-
|
| 248 |
-
return VibeVoiceCausalLMOutputWithPast(
|
| 249 |
-
logits=logits,
|
| 250 |
-
past_key_values=outputs.past_key_values,
|
| 251 |
-
last_hidden_state=hidden_states,
|
| 252 |
-
attentions=outputs.attentions,
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
def _build_generate_config_model_kwargs(self, generation_config, inputs, tokenizer, return_processors=False, **kwargs):
|
| 256 |
-
if generation_config is None:
|
| 257 |
-
generation_config = GenerationConfig(
|
| 258 |
-
bos_token_id=tokenizer.bos_token_id,
|
| 259 |
-
eos_token_id=tokenizer.eos_token_id,
|
| 260 |
-
pad_token_id=tokenizer.pad_token_id
|
| 261 |
-
)
|
| 262 |
-
else:
|
| 263 |
-
generation_config = GenerationConfig(
|
| 264 |
-
**generation_config,
|
| 265 |
-
bos_token_id=tokenizer.bos_token_id,
|
| 266 |
-
eos_token_id=tokenizer.eos_token_id,
|
| 267 |
-
pad_token_id=tokenizer.pad_token_id
|
| 268 |
-
)
|
| 269 |
-
|
| 270 |
-
generation_config, model_kwargs = self._prepare_generation_config(
|
| 271 |
-
generation_config,
|
| 272 |
-
use_cache=True,
|
| 273 |
-
speech_start_id=tokenizer.speech_start_id,
|
| 274 |
-
speech_end_id=tokenizer.speech_end_id,
|
| 275 |
-
speech_diffusion_id=tokenizer.speech_diffusion_id,
|
| 276 |
-
**kwargs
|
| 277 |
-
)
|
| 278 |
-
generation_config.speech_start_id = tokenizer.speech_start_id
|
| 279 |
-
generation_config.speech_end_id = tokenizer.speech_end_id
|
| 280 |
-
generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
|
| 281 |
-
|
| 282 |
-
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)
|
| 283 |
-
batch_size = inputs_tensor.shape[0]
|
| 284 |
-
device = self.device
|
| 285 |
-
|
| 286 |
-
self._prepare_special_tokens(generation_config, True, device=device)
|
| 287 |
-
generation_config.use_cache = True
|
| 288 |
-
model_kwargs["use_cache"] = generation_config.use_cache
|
| 289 |
-
input_ids = inputs_tensor.to(self.device)
|
| 290 |
-
|
| 291 |
-
input_ids_length = input_ids.shape[1]
|
| 292 |
-
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
| 293 |
-
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
|
| 294 |
-
generation_config = self._prepare_generated_length(
|
| 295 |
-
generation_config=generation_config,
|
| 296 |
-
has_default_max_length=has_default_max_length,
|
| 297 |
-
has_default_min_length=has_default_min_length,
|
| 298 |
-
model_input_name=model_input_name,
|
| 299 |
-
inputs_tensor=inputs_tensor,
|
| 300 |
-
input_ids_length=input_ids_length,
|
| 301 |
-
)
|
| 302 |
-
|
| 303 |
-
max_cache_length = generation_config.max_length - 1
|
| 304 |
-
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
|
| 305 |
-
model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
|
| 306 |
-
for k, v in model_kwargs.items():
|
| 307 |
-
if isinstance(v, torch.Tensor):
|
| 308 |
-
model_kwargs[k] = v.to(device=device)
|
| 309 |
-
|
| 310 |
-
if return_processors:
|
| 311 |
-
logits_processor = self._get_logits_processor(
|
| 312 |
-
generation_config=generation_config,
|
| 313 |
-
input_ids_seq_length=input_ids_length,
|
| 314 |
-
encoder_input_ids=inputs_tensor,
|
| 315 |
-
prefix_allowed_tokens_fn=None,
|
| 316 |
-
logits_processor=LogitsProcessorList(),
|
| 317 |
-
device=inputs_tensor.device,
|
| 318 |
-
model_kwargs=model_kwargs,
|
| 319 |
-
)
|
| 320 |
-
|
| 321 |
-
stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=StoppingCriteriaList())
|
| 322 |
-
|
| 323 |
-
return generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria
|
| 324 |
-
else:
|
| 325 |
-
return generation_config, model_kwargs, input_ids
|
| 326 |
-
|
| 327 |
-
@torch.no_grad()
|
| 328 |
-
def generate(
|
| 329 |
-
self,
|
| 330 |
-
inputs: Optional[torch.Tensor] = None,
|
| 331 |
-
generation_config: Optional[GenerationConfig] = None,
|
| 332 |
-
logits_processor: Optional[LogitsProcessorList] = None,
|
| 333 |
-
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
| 334 |
-
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
| 335 |
-
synced_gpus: Optional[bool] = None,
|
| 336 |
-
assistant_model: Optional["PreTrainedModel"] = None,
|
| 337 |
-
audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
|
| 338 |
-
negative_prompt_ids: Optional[torch.Tensor] = None,
|
| 339 |
-
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 340 |
-
speech_tensors: Optional[torch.FloatTensor] = None,
|
| 341 |
-
speech_masks: Optional[torch.BoolTensor] = None,
|
| 342 |
-
speech_input_mask: Optional[torch.BoolTensor] = None,
|
| 343 |
-
return_speech: bool = True,
|
| 344 |
-
cfg_scale: float = 1.0,
|
| 345 |
-
stop_check_fn: Optional[Callable[[], bool]] = None,
|
| 346 |
-
**kwargs,
|
| 347 |
-
) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
|
| 348 |
-
"""
|
| 349 |
-
Generates sequences of token ids and optionally speech outputs.
|
| 350 |
-
|
| 351 |
-
Args:
|
| 352 |
-
All standard generation arguments from GenerationMixin
|
| 353 |
-
negative_prompt_ids: Negative prompt for CFG in speech generation
|
| 354 |
-
negative_prompt_attention_mask: Attention mask for negative prompt
|
| 355 |
-
speech_tensors: Input speech for voice cloning
|
| 356 |
-
speech_masks: Masks for speech tensors
|
| 357 |
-
speech_input_mask: Positions to insert speech embeddings
|
| 358 |
-
return_speech: Whether to decode and return speech outputs
|
| 359 |
-
cfg_scale: CFG scale for speech generation
|
| 360 |
-
stop_check_fn: Optional callable that returns True if generation should stop
|
| 361 |
-
|
| 362 |
-
Returns:
|
| 363 |
-
Generated token sequences and optionally speech outputs
|
| 364 |
-
"""
|
| 365 |
-
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
| 366 |
-
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first; we only use it for stopping criteria
|
| 367 |
-
parsed_scripts = kwargs.pop("parsed_scripts", None)
|
| 368 |
-
all_speakers_list = kwargs.pop("all_speakers_list", None)
|
| 369 |
-
max_length_times = kwargs.pop("max_length_times", 2)
|
| 370 |
-
|
| 371 |
-
if kwargs.get('max_new_tokens', None) is None:
|
| 372 |
-
kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - kwargs['input_ids'].shape[-1]
|
| 373 |
-
|
| 374 |
-
generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs(
|
| 375 |
-
generation_config, inputs, tokenizer, return_processors=True, **kwargs
|
| 376 |
-
)
|
| 377 |
-
|
| 378 |
-
negative_kwargs = {
|
| 379 |
-
'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), tokenizer.speech_start_id, dtype=torch.long, device=kwargs['input_ids'].device),
|
| 380 |
-
'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device),
|
| 381 |
-
'max_new_tokens': kwargs.get('max_new_tokens', 100)
|
| 382 |
-
}
|
| 383 |
-
negative_generation_config, negative_model_kwargs, negative_input_ids = self._build_generate_config_model_kwargs(
|
| 384 |
-
None, None, tokenizer, return_processors=False, **negative_kwargs
|
| 385 |
-
)
|
| 386 |
-
|
| 387 |
-
acoustic_cache = VibeVoiceTokenizerStreamingCache()
|
| 388 |
-
semantic_cache = VibeVoiceTokenizerStreamingCache()
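# Separate streaming caches keep per-sample state between steps, so the acoustic
# tokenizer can decode and the semantic tokenizer can re-encode each newly generated
# latent chunk incrementally during generation.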
|
| 389 |
-
|
| 390 |
-
batch_size = input_ids.shape[0]
|
| 391 |
-
device = input_ids.device
|
| 392 |
-
finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
| 393 |
-
correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
|
| 394 |
-
is_prefill = True
|
| 395 |
-
inputs_embeds = None
|
| 396 |
-
verbose = kwargs.get("verbose", False)
|
| 397 |
-
|
| 398 |
-
# Initialize audio chunks storage for each sample
|
| 399 |
-
audio_chunks = [[] for _ in range(batch_size)]
|
| 400 |
-
|
| 401 |
-
initial_length = input_ids.shape[-1]
|
| 402 |
-
initial_length_per_sample = model_kwargs['attention_mask'].sum(dim=-1)
|
| 403 |
-
|
| 404 |
-
# Define all valid tokens that can be generated
|
| 405 |
-
valid_tokens = [
|
| 406 |
-
generation_config.speech_start_id,
|
| 407 |
-
generation_config.speech_end_id,
|
| 408 |
-
generation_config.speech_diffusion_id,
|
| 409 |
-
generation_config.eos_token_id
|
| 410 |
-
]
|
| 411 |
-
# Add bos_token_id if it exists
|
| 412 |
-
if hasattr(generation_config, 'bos_token_id') and generation_config.bos_token_id is not None:
|
| 413 |
-
valid_tokens.append(generation_config.bos_token_id)
|
| 414 |
-
|
| 415 |
-
# Add custom processor to constrain token generation
|
| 416 |
-
token_constraint_processor = VibeVoiceTokenConstraintProcessor(valid_tokens, device=device)
|
| 417 |
-
if logits_processor is None:
|
| 418 |
-
logits_processor = LogitsProcessorList()
|
| 419 |
-
logits_processor.append(token_constraint_processor)
|
| 420 |
-
|
| 421 |
-
max_steps = min(generation_config.max_length - initial_length, int(max_length_times * initial_length))
|
| 422 |
-
max_step_per_sample = torch.min(generation_config.max_length - initial_length_per_sample, (max_length_times * initial_length_per_sample).long())
|
| 423 |
-
reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
| 424 |
-
|
| 425 |
-
# Create progress iterator if verbose
|
| 426 |
-
if kwargs.get("show_progress_bar", True):
|
| 427 |
-
progress_bar = tqdm(range(max_steps), desc="Generating", leave=False)
|
| 428 |
-
else:
|
| 429 |
-
progress_bar = range(max_steps)
|
| 430 |
-
|
| 431 |
-
for step in progress_bar:
|
| 432 |
-
# Check for external stop signal
|
| 433 |
-
if stop_check_fn is not None and stop_check_fn():
|
| 434 |
-
if verbose:
|
| 435 |
-
print(f"Generation stopped externally at step {step + 1}")
|
| 436 |
-
# End the audio streamer if it exists
|
| 437 |
-
if audio_streamer is not None:
|
| 438 |
-
audio_streamer.end()
|
| 439 |
-
break
|
| 440 |
-
|
| 441 |
-
# Check if audio_streamer has been ended (stopped externally)
|
| 442 |
-
if audio_streamer is not None and hasattr(audio_streamer, 'finished_flags'):
|
| 443 |
-
if any(audio_streamer.finished_flags):
|
| 444 |
-
if verbose:
|
| 445 |
-
print(f"Audio generation stopped externally at step {step + 1}")
|
| 446 |
-
break
|
| 447 |
-
|
| 448 |
-
if finished_tags.all():
|
| 449 |
-
if hasattr(progress_bar, 'set_description'):
|
| 450 |
-
progress_bar.set_description("Generation complete")
|
| 451 |
-
break
|
| 452 |
-
|
| 453 |
-
if input_ids.shape[-1] >= generation_config.max_length:
|
| 454 |
-
print(f"Reached maximum generation length {generation_config.max_length}, stopped it.")
|
| 455 |
-
reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
|
| 456 |
-
if reached_samples.numel() > 0:
|
| 457 |
-
reach_max_step_sample[reached_samples] = True
|
| 458 |
-
break
|
| 459 |
-
|
| 460 |
-
# Update progress bar description with active samples
|
| 461 |
-
if hasattr(progress_bar, 'set_description'):
|
| 462 |
-
active_samples = (~finished_tags).sum().item()
|
| 463 |
-
progress_bar.set_description(f"Generating (active: {active_samples}/{batch_size})")
|
| 464 |
-
|
| 465 |
-
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
| 466 |
-
if is_prefill:
|
| 467 |
-
# we process the speech inputs only during the first generation step
|
| 468 |
-
prefill_inputs = {
|
| 469 |
-
"speech_tensors": speech_tensors.to(device=device),
|
| 470 |
-
"speech_masks": speech_masks.to(device),
|
| 471 |
-
"speech_input_mask": speech_input_mask.to(device),
|
| 472 |
-
}
|
| 473 |
-
is_prefill = False
|
| 474 |
-
else:
|
| 475 |
-
_ = model_inputs.pop('inputs_embeds', None)
|
| 476 |
-
prefill_inputs = {'inputs_embeds': inputs_embeds}
|
| 477 |
-
|
| 478 |
-
# Forward pass through the model
|
| 479 |
-
outputs = self(
|
| 480 |
-
**model_inputs, **prefill_inputs, logits_to_keep=1, return_dict=True, output_attentions=False, output_hidden_states=False,
|
| 481 |
-
)
|
| 482 |
-
model_kwargs = self._update_model_kwargs_for_generation(
|
| 483 |
-
outputs, model_kwargs, is_encoder_decoder=False,
|
| 484 |
-
)
|
| 485 |
-
|
| 486 |
-
# Get logits and apply logits processor
|
| 487 |
-
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
|
| 488 |
-
# next_token_logits = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device)
|
| 489 |
-
next_token_scores = logits_processor(input_ids, next_token_logits)
|
| 490 |
-
|
| 491 |
-
# token selection
|
| 492 |
-
if generation_config.do_sample:
|
| 493 |
-
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
| 494 |
-
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
|
| 495 |
-
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 496 |
-
else:
|
| 497 |
-
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
| 498 |
-
|
| 499 |
-
next_tokens[finished_tags] = generation_config.eos_token_id
|
| 500 |
-
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
| 501 |
-
|
| 502 |
-
if not kwargs.get('refresh_negative', True):
|
| 503 |
-
negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
|
| 504 |
-
# Forward negative pass through the model
|
| 505 |
-
if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
|
| 506 |
-
negative_model_inputs['inputs_embeds'] = inputs_embeds
|
| 507 |
-
negative_model_inputs['input_ids'] = None
|
| 508 |
-
|
| 509 |
-
negative_outputs = self(
|
| 510 |
-
**negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
|
| 511 |
-
)
|
| 512 |
-
negative_model_kwargs = self._update_model_kwargs_for_generation(
|
| 513 |
-
negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
|
| 514 |
-
)
|
| 515 |
-
negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
|
| 516 |
-
|
| 517 |
-
# reached end of generation
|
| 518 |
-
if (next_tokens == generation_config.eos_token_id).any():
|
| 519 |
-
eos_indices = (next_tokens == generation_config.eos_token_id).nonzero(as_tuple=False).squeeze(1)
|
| 520 |
-
# Only print for samples that are newly finished (not already marked as finished)
|
| 521 |
-
new_eos_indices = eos_indices[~finished_tags[eos_indices]]
|
| 522 |
-
if new_eos_indices.numel() > 0:
|
| 523 |
-
finished_tags[new_eos_indices] = True
|
| 524 |
-
if verbose:
|
| 525 |
-
print(f"Samples {new_eos_indices.tolist()} reached EOS token at step {step + 1}.", flush=True)
|
| 526 |
-
if audio_streamer is not None:
|
| 527 |
-
audio_streamer.end(new_eos_indices)
|
| 528 |
-
|
| 529 |
-
# Check if any sample reached its maximum generation length
|
| 530 |
-
max_length_reached = step >= max_step_per_sample
|
| 531 |
-
new_max_length_indices = torch.nonzero(max_length_reached & ~finished_tags, as_tuple=False).squeeze(1)
|
| 532 |
-
if new_max_length_indices.numel() > 0:
|
| 533 |
-
finished_tags[new_max_length_indices] = True
|
| 534 |
-
reach_max_step_sample[new_max_length_indices] = True
|
| 535 |
-
if verbose:
|
| 536 |
-
print(f"Samples {new_max_length_indices.tolist()} reached max generation length at step {step + 1}.", flush=True)
|
| 537 |
-
if audio_streamer is not None:
|
| 538 |
-
audio_streamer.end(new_max_length_indices)
|
| 539 |
-
|
| 540 |
-
# speech_end
|
| 541 |
-
diffusion_end_indices = (next_tokens == generation_config.speech_end_id).nonzero(as_tuple=False).squeeze(1)
|
| 542 |
-
if diffusion_end_indices.numel() > 0:
|
| 543 |
-
# Clear tokenizer caches for samples that reached speech end
|
| 544 |
-
acoustic_cache.set_to_zero(diffusion_end_indices)
|
| 545 |
-
semantic_cache.set_to_zero(diffusion_end_indices)
|
| 546 |
-
|
| 547 |
-
# speech_begin
|
| 548 |
-
diffusion_start_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_start_id)]
|
| 549 |
-
if diffusion_start_indices.numel() > 0 and kwargs.get('refresh_negative', True):
|
| 550 |
-
# update attention mask
|
| 551 |
-
for i, sample_idx in enumerate(diffusion_start_indices.tolist()):
|
| 552 |
-
negative_model_kwargs['attention_mask'][sample_idx, :] = 0
|
| 553 |
-
negative_model_kwargs['attention_mask'][sample_idx, -1] = 1
|
| 554 |
-
# update past key values
|
| 555 |
-
for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
|
| 556 |
-
negative_model_kwargs['past_key_values'].value_cache)):
|
| 557 |
-
# Process each non-diffusion sample
|
| 558 |
-
for sample_idx in diffusion_start_indices.tolist():
|
| 559 |
-
# Shift cache for this sample
|
| 560 |
-
k_cache[sample_idx, :, -1, :] = k_cache[sample_idx, :, 0, :].clone()
|
| 561 |
-
v_cache[sample_idx, :, -1, :] = v_cache[sample_idx, :, 0, :].clone()
|
| 562 |
-
# update negative_input_ids
|
| 563 |
-
for sample_idx in diffusion_start_indices.tolist():
|
| 564 |
-
negative_input_ids[sample_idx, -1] = generation_config.speech_start_id
|
| 565 |
-
|
| 566 |
-
# Prepare inputs_embeds for next iteration
|
| 567 |
-
# Initialize with default embeddings for all tokens
|
| 568 |
-
next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1) # [batch_size, 1, hidden_size]
|
| 569 |
-
|
| 570 |
-
# forward diffusion
|
| 571 |
-
# Diffusion indices are those that are not finished and not special tokens
|
| 572 |
-
diffusion_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_diffusion_id)]
|
| 573 |
-
|
| 574 |
-
if diffusion_indices.numel() > 0:
|
| 575 |
-
if kwargs.get('refresh_negative', True):
|
| 576 |
-
negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
|
| 577 |
-
# Forward negative pass through the model
|
| 578 |
-
if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
|
| 579 |
-
negative_model_inputs['inputs_embeds'] = inputs_embeds
|
| 580 |
-
negative_model_inputs['input_ids'] = None
|
| 581 |
-
|
| 582 |
-
negative_outputs = self(
|
| 583 |
-
**negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
|
| 584 |
-
)
|
| 585 |
-
negative_model_kwargs = self._update_model_kwargs_for_generation(
|
| 586 |
-
negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
|
| 587 |
-
)
|
| 588 |
-
negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
|
| 589 |
-
# correct the non-diffusion indices
|
| 590 |
-
# we forward all samples' negative outputs even if
|
| 591 |
-
# they are not in diffusion mode to keep the cache consistent
|
| 592 |
-
# So we need to correct the kv cache of non-diffusion samples
|
| 593 |
-
non_diffusion_mask = ~finished_tags & (next_tokens != generation_config.speech_diffusion_id)
|
| 594 |
-
if non_diffusion_mask.any():
|
| 595 |
-
non_diffusion_indices = torch.arange(batch_size, device=device)[non_diffusion_mask]
|
| 596 |
-
start_indices = correct_cnt[non_diffusion_indices]
|
| 597 |
-
|
| 598 |
-
# 1. Update attention_mask - need to handle each sample separately
|
| 599 |
-
seq_len = negative_model_kwargs['attention_mask'].shape[1]
|
| 600 |
-
for i, (sample_idx, start_idx) in enumerate(zip(non_diffusion_indices.tolist(), start_indices.tolist())):
|
| 601 |
-
# Shift the attention mask for this sample
|
| 602 |
-
if start_idx + 1 < seq_len - 1:
|
| 603 |
-
negative_model_kwargs['attention_mask'][sample_idx, start_idx+1:] = \
|
| 604 |
-
negative_model_kwargs['attention_mask'][sample_idx, start_idx:-1].clone()
|
| 605 |
-
negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0
|
| 606 |
-
|
| 607 |
-
# 2. Update past_key_values
|
| 608 |
-
for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
|
| 609 |
-
negative_model_kwargs['past_key_values'].value_cache)):
|
| 610 |
-
# Process each non-diffusion sample
|
| 611 |
-
for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
|
| 612 |
-
if start_idx + 1 < k_cache.shape[2] - 1:
|
| 613 |
-
# Shift cache for this sample
|
| 614 |
-
k_cache[sample_idx, :, start_idx+1:, :] = k_cache[sample_idx, :, start_idx:-1, :].clone()
|
| 615 |
-
v_cache[sample_idx, :, start_idx+1:, :] = v_cache[sample_idx, :, start_idx:-1, :].clone()
|
| 616 |
-
|
| 617 |
-
# 3. Update negative_input_ids
|
| 618 |
-
for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
|
| 619 |
-
if start_idx + 1 < negative_input_ids.shape[1] - 1:
|
| 620 |
-
negative_input_ids[sample_idx, start_idx+1:] = \
|
| 621 |
-
negative_input_ids[sample_idx, start_idx:-1].clone()
|
| 622 |
-
|
| 623 |
-
correct_cnt[non_diffusion_indices] += 1
|
| 624 |
-
|
| 625 |
-
positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
|
| 626 |
-
negative_condition = negative_outputs.last_hidden_state[diffusion_indices, -1, :]
|
| 627 |
-
|
| 628 |
-
speech_latent = self.sample_speech_tokens(
|
| 629 |
-
positive_condition,
|
| 630 |
-
negative_condition,
|
| 631 |
-
cfg_scale=cfg_scale,
|
| 632 |
-
).unsqueeze(1)
|
| 633 |
-
|
| 634 |
-
# Decode acoustic latent to audio using acoustic streaming cache
|
| 635 |
-
scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device)
|
| 636 |
-
audio_chunk = self.model.acoustic_tokenizer.decode(
|
| 637 |
-
scaled_latent.to(self.model.acoustic_tokenizer.device),
|
| 638 |
-
cache=acoustic_cache, # Use acoustic-specific cache
|
| 639 |
-
sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
|
| 640 |
-
use_cache=True,
|
| 641 |
-
debug=False
|
| 642 |
-
)
|
| 643 |
-
|
| 644 |
-
# Store audio chunks for each sample
|
| 645 |
-
for i, sample_idx in enumerate(diffusion_indices):
|
| 646 |
-
idx = sample_idx.item()
|
| 647 |
-
# Only append audio chunk if the sample is not finished
|
| 648 |
-
if not finished_tags[idx]:
|
| 649 |
-
audio_chunks[idx].append(audio_chunk[i])
|
| 650 |
-
|
| 651 |
-
# Add streaming support here
|
| 652 |
-
if audio_streamer is not None:
|
| 653 |
-
# Stream the audio chunks immediately
|
| 654 |
-
audio_streamer.put(audio_chunk, diffusion_indices)
|
| 655 |
-
|
| 656 |
-
# Encode audio to semantic features using semantic streaming cache
|
| 657 |
-
semantic_features = self.model.semantic_tokenizer.encode(
|
| 658 |
-
audio_chunk,
|
| 659 |
-
cache=semantic_cache, # Use semantic-specific cache
|
| 660 |
-
sample_indices=diffusion_indices,
|
| 661 |
-
use_cache=True,
|
| 662 |
-
debug=False
|
| 663 |
-
).mean # semantic tokenizer has no VAE.
|
| 664 |
-
|
| 665 |
-
# Combine acoustic and semantic features for next input
|
| 666 |
-
acoustic_embed = self.model.acoustic_connector(speech_latent)
|
| 667 |
-
semantic_embed = self.model.semantic_connector(semantic_features)
|
| 668 |
-
diffusion_embeds = acoustic_embed + semantic_embed
|
| 669 |
-
|
| 670 |
-
# Update embeddings for diffusion indices
|
| 671 |
-
next_inputs_embeds[diffusion_indices] = diffusion_embeds
|
| 672 |
-
|
| 673 |
-
# Set inputs_embeds for next iteration
|
| 674 |
-
inputs_embeds = next_inputs_embeds
|
| 675 |
-
|
| 676 |
-
if audio_streamer is not None:
|
| 677 |
-
audio_streamer.end()
|
| 678 |
-
|
| 679 |
-
# Concatenate audio chunks for each sample
|
| 680 |
-
final_audio_outputs = []
|
| 681 |
-
for sample_chunks in audio_chunks:
|
| 682 |
-
if sample_chunks:
|
| 683 |
-
# Concatenate all chunks along the time dimension (assumed to be the last dimension)
|
| 684 |
-
concatenated_audio = torch.cat(sample_chunks, dim=-1)
|
| 685 |
-
final_audio_outputs.append(concatenated_audio)
|
| 686 |
-
else:
|
| 687 |
-
# If no audio was generated for this sample, append None
|
| 688 |
-
final_audio_outputs.append(None)
|
| 689 |
-
|
| 690 |
-
return VibeVoiceGenerationOutput(
|
| 691 |
-
sequences=input_ids,
|
| 692 |
-
speech_outputs=final_audio_outputs if return_speech else None,
|
| 693 |
-
reach_max_step_sample=reach_max_step_sample,
|
| 694 |
-
)
|
| 695 |
-
|
| 696 |
-
@torch.no_grad()
|
| 697 |
-
def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
|
| 698 |
-
self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps)
|
| 699 |
-
condition = torch.cat([condition, neg_condition], dim=0).to(self.model.prediction_head.device)
|
| 700 |
-
speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(condition)
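# Classifier-free guidance: positive and negative conditions were concatenated above,
# the same noisy latent is duplicated into both halves at each step, and the guided
# noise estimate below is uncond_eps + cfg_scale * (cond_eps - uncond_eps).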
|
| 701 |
-
for t in self.model.noise_scheduler.timesteps:
|
| 702 |
-
half = speech[: len(speech) // 2]
|
| 703 |
-
combined = torch.cat([half, half], dim=0)
|
| 704 |
-
eps = self.model.prediction_head(combined, t.repeat(combined.shape[0]).to(combined), condition=condition)
|
| 705 |
-
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
| 706 |
-
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
| 707 |
-
eps = torch.cat([half_eps, half_eps], dim=0)
|
| 708 |
-
speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample
|
| 709 |
-
return speech[: len(speech) // 2]
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGenerationInference)
|
| 713 |
-
|
| 714 |
-
__all__ = [
|
| 715 |
-
"VibeVoiceForConditionalGenerationInference",
|
| 716 |
-
]
|
src/vibevoice/modular/modular_vibevoice_diffusion_head.py
DELETED
@@ -1,287 +0,0 @@
| 1 |
-
import math
|
| 2 |
-
from typing import Optional, Tuple, Union
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
|
| 8 |
-
from transformers.models.auto import AutoModel
|
| 9 |
-
from transformers.modeling_utils import PreTrainedModel
|
| 10 |
-
# from transformers.modeling_layers import GradientCheckpointingLayer
|
| 11 |
-
from transformers.activations import ACT2FN
|
| 12 |
-
from transformers.utils import logging
|
| 13 |
-
|
| 14 |
-
from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
logger = logging.get_logger(__name__)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class RMSNorm(nn.Module):
|
| 21 |
-
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
|
| 22 |
-
super().__init__()
|
| 23 |
-
self.dim = dim
|
| 24 |
-
self.eps = eps
|
| 25 |
-
self.elementwise_affine = elementwise_affine
|
| 26 |
-
if self.elementwise_affine:
|
| 27 |
-
self.weight = nn.Parameter(torch.ones(dim))
|
| 28 |
-
else:
|
| 29 |
-
self.register_parameter('weight', None)
|
| 30 |
-
|
| 31 |
-
def _norm(self, x):
|
| 32 |
-
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 33 |
-
|
| 34 |
-
def forward(self, x):
|
| 35 |
-
output = self._norm(x.float()).type_as(x)
|
| 36 |
-
if self.weight is not None:
|
| 37 |
-
output = output * self.weight
|
| 38 |
-
return output
|
| 39 |
-
|
| 40 |
-
def extra_repr(self) -> str:
|
| 41 |
-
return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
| 42 |
-
|
| 43 |
-
def modulate(x, shift, scale):
|
| 44 |
-
"""Apply modulation to input tensor."""
|
| 45 |
-
return x * (1 + scale) + shift
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
class TimestepEmbedder(nn.Module):
|
| 49 |
-
"""
|
| 50 |
-
Embeds scalar timesteps into vector representations.
|
| 51 |
-
|
| 52 |
-
Args:
|
| 53 |
-
hidden_size (`int`): Size of the output embedding
|
| 54 |
-
frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
|
| 55 |
-
"""
|
| 56 |
-
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 57 |
-
super().__init__()
|
| 58 |
-
self.mlp = nn.Sequential(
|
| 59 |
-
nn.Linear(frequency_embedding_size, hidden_size, bias=False),
|
| 60 |
-
# nn.SiLU(),
|
| 61 |
-
ACT2FN['silu'],
|
| 62 |
-
nn.Linear(hidden_size, hidden_size, bias=False),
|
| 63 |
-
)
|
| 64 |
-
self.frequency_embedding_size = frequency_embedding_size
|
| 65 |
-
|
| 66 |
-
@staticmethod
|
| 67 |
-
def timestep_embedding(t, dim, max_period=10000):
|
| 68 |
-
"""
|
| 69 |
-
Create sinusoidal timestep embeddings.
|
| 70 |
-
|
| 71 |
-
Args:
|
| 72 |
-
t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
|
| 73 |
-
These may be fractional.
|
| 74 |
-
dim (`int`): The dimension of the output.
|
| 75 |
-
max_period (`int`, optional): Controls the minimum frequency of the embeddings.
|
| 76 |
-
|
| 77 |
-
Returns:
|
| 78 |
-
`torch.Tensor`: An [N, D] Tensor of positional embeddings.
|
| 79 |
-
"""
|
| 80 |
-
half = dim // 2
|
| 81 |
-
freqs = torch.exp(
|
| 82 |
-
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 83 |
-
).to(t.device)
|
| 84 |
-
args = t[:, None].float() * freqs[None]
|
| 85 |
-
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 86 |
-
if dim % 2:
|
| 87 |
-
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 88 |
-
return embedding.to(t.dtype)
|
| 89 |
-
|
| 90 |
-
def forward(self, t):
|
| 91 |
-
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 92 |
-
t_emb = self.mlp(t_freq)
|
| 93 |
-
return t_emb
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
class FeedForwardNetwork(nn.Module):
|
| 97 |
-
"""
|
| 98 |
-
Standard feed-forward network with SwiGLU activation.
|
| 99 |
-
|
| 100 |
-
Args:
|
| 101 |
-
embed_dim (`int`): Input dimension
|
| 102 |
-
ffn_dim (`int`): Hidden dimension
|
| 103 |
-
"""
|
| 104 |
-
def __init__(
|
| 105 |
-
self,
|
| 106 |
-
embed_dim,
|
| 107 |
-
ffn_dim,
|
| 108 |
-
):
|
| 109 |
-
super().__init__()
|
| 110 |
-
self.embed_dim = embed_dim
|
| 111 |
-
self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
|
| 112 |
-
self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
|
| 113 |
-
self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
|
| 114 |
-
self.act_fn = ACT2FN['silu'] # Using SiLU as the activation function
|
| 115 |
-
|
| 116 |
-
def forward(self, x):
|
| 117 |
-
gate = self.gate_proj(x)
|
| 118 |
-
up = self.up_proj(x)
|
| 119 |
-
|
| 120 |
-
# SwiGLU activation
|
| 121 |
-
# gate = F.silu(gate)
|
| 122 |
-
gate = self.act_fn(gate)
|
| 123 |
-
return self.down_proj(gate * up)
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
class HeadLayer(nn.Module):
|
| 127 |
-
"""
|
| 128 |
-
A layer in the diffusion head.
|
| 129 |
-
|
| 130 |
-
Args:
|
| 131 |
-
embed_dim (`int`): Input dimension
|
| 132 |
-
ffn_dim (`int`): Hidden dimension
|
| 133 |
-
cond_dim (`int`): Condition embedding dimension
|
| 134 |
-
norm_eps (`float`, optional): Epsilon for normalization
|
| 135 |
-
"""
|
| 136 |
-
def __init__(
|
| 137 |
-
self,
|
| 138 |
-
embed_dim,
|
| 139 |
-
ffn_dim,
|
| 140 |
-
cond_dim,
|
| 141 |
-
norm_eps=1e-5,
|
| 142 |
-
):
|
| 143 |
-
super().__init__()
|
| 144 |
-
self.embed_dim = embed_dim
|
| 145 |
-
self.cond_dim = cond_dim
|
| 146 |
-
self.ffn_dim = ffn_dim
|
| 147 |
-
self.ffn = FeedForwardNetwork(
|
| 148 |
-
self.embed_dim,
|
| 149 |
-
self.ffn_dim,
|
| 150 |
-
)
|
| 151 |
-
self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
|
| 152 |
-
self.adaLN_modulation = nn.Sequential(
|
| 153 |
-
# nn.SiLU(),
|
| 154 |
-
ACT2FN['silu'],
|
| 155 |
-
nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
|
| 156 |
-
)
|
| 157 |
-
|
| 158 |
-
def forward(self, x, c):
|
| 159 |
-
shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
|
| 160 |
-
x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
|
| 161 |
-
return x
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
class FinalLayer(nn.Module):
|
| 165 |
-
"""
|
| 166 |
-
Final layer in the diffusion head.
|
| 167 |
-
|
| 168 |
-
Args:
|
| 169 |
-
hidden_size (`int`): Input dimension
|
| 170 |
-
output_size (`int`): Output dimension
|
| 171 |
-
cond_size (`int`): Condition embedding dimension
|
| 172 |
-
norm_eps (`float`, optional): Epsilon for normalization
|
| 173 |
-
"""
|
| 174 |
-
def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
|
| 175 |
-
super().__init__()
|
| 176 |
-
self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
|
| 177 |
-
self.linear = nn.Linear(hidden_size, output_size, bias=False)
|
| 178 |
-
self.adaLN_modulation = nn.Sequential(
|
| 179 |
-
# nn.SiLU(),
|
| 180 |
-
ACT2FN['silu'],
|
| 181 |
-
nn.Linear(cond_size, 2 * hidden_size, bias=False)
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
def forward(self, x, c):
|
| 185 |
-
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
| 186 |
-
x = modulate(self.norm_final(x), shift, scale)
|
| 187 |
-
x = self.linear(x)
|
| 188 |
-
return x
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
class VibeVoiceDiffusionHead(PreTrainedModel):
|
| 192 |
-
"""
|
| 193 |
-
Diffusion head model for vibevoice.
|
| 194 |
-
|
| 195 |
-
Args:
|
| 196 |
-
config (`VibeVoiceDiffusionHeadConfig`): Model configuration
|
| 197 |
-
latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`.
|
| 198 |
-
"""
|
| 199 |
-
config_class = VibeVoiceDiffusionHeadConfig
|
| 200 |
-
supports_gradient_checkpointing = True
|
| 201 |
-
_supports_flash_attn_2 = True
|
| 202 |
-
_supports_sdpa = True
|
| 203 |
-
|
| 204 |
-
def __init__(
|
| 205 |
-
self,
|
| 206 |
-
config,
|
| 207 |
-
):
|
| 208 |
-
super().__init__(config)
|
| 209 |
-
self.config = config
|
| 210 |
-
self.cond_dim = config.hidden_size
|
| 211 |
-
latent_size = config.latent_size
|
| 212 |
-
|
| 213 |
-
self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
|
| 214 |
-
self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
|
| 215 |
-
self.t_embedder = TimestepEmbedder(self.cond_dim)
|
| 216 |
-
|
| 217 |
-
ffn_dim = int(config.hidden_size * config.head_ffn_ratio)
|
| 218 |
-
|
| 219 |
-
# Create the intermediate layers
|
| 220 |
-
self.layers = nn.ModuleList([
|
| 221 |
-
HeadLayer(
|
| 222 |
-
embed_dim=config.hidden_size,
|
| 223 |
-
ffn_dim=ffn_dim,
|
| 224 |
-
cond_dim=self.cond_dim,
|
| 225 |
-
norm_eps=config.rms_norm_eps
|
| 226 |
-
)
|
| 227 |
-
for _ in range(config.head_layers)
|
| 228 |
-
])
|
| 229 |
-
|
| 230 |
-
# Final layer for output
|
| 231 |
-
self.final_layer = FinalLayer(
|
| 232 |
-
hidden_size=config.hidden_size,
|
| 233 |
-
output_size=latent_size,
|
| 234 |
-
cond_size=self.cond_dim,
|
| 235 |
-
norm_eps=config.rms_norm_eps
|
| 236 |
-
)
|
| 237 |
-
|
| 238 |
-
self.initialize_weights()
|
| 239 |
-
|
| 240 |
-
def initialize_weights(self):
|
| 241 |
-
"""Initialize the weights of the model."""
|
| 242 |
-
# Initialize timestep embedder
|
| 243 |
-
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
| 244 |
-
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
| 245 |
-
|
| 246 |
-
# Zero-out adaLN modulation layers
|
| 247 |
-
for layer in self.layers:
|
| 248 |
-
nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)
|
| 249 |
-
|
| 250 |
-
# Zero-out output layers
|
| 251 |
-
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
| 252 |
-
nn.init.constant_(self.final_layer.linear.weight, 0)
|
| 253 |
-
|
| 254 |
-
def forward(
|
| 255 |
-
self,
|
| 256 |
-
noisy_images,
|
| 257 |
-
timesteps,
|
| 258 |
-
condition,
|
| 259 |
-
):
|
| 260 |
-
"""
|
| 261 |
-
Forward pass of the prediction head.
|
| 262 |
-
|
| 263 |
-
Args:
|
| 264 |
-
noisy_images (`torch.Tensor`): Noisy images/latents to denoise
|
| 265 |
-
timesteps (`torch.Tensor`): Timesteps for diffusion
|
| 266 |
-
condition (`torch.Tensor`): Conditioning information
|
| 267 |
-
|
| 268 |
-
Returns:
|
| 269 |
-
`torch.Tensor`: The predicted noise/velocity
|
| 270 |
-
"""
|
| 271 |
-
x = self.noisy_images_proj(noisy_images)
|
| 272 |
-
t = self.t_embedder(timesteps)
|
| 273 |
-
condition = self.cond_proj(condition)
|
| 274 |
-
c = condition + t
|
| 275 |
-
|
| 276 |
-
for layer in self.layers:
|
| 277 |
-
x = layer(x, c)
|
| 278 |
-
|
| 279 |
-
x = self.final_layer(x, c)
|
| 280 |
-
return x
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)
|
| 284 |
-
|
| 285 |
-
__all__ = [
|
| 286 |
-
"VibeVoiceDiffusionHead",
|
| 287 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/vibevoice/modular/modular_vibevoice_text_tokenizer.py
DELETED
|
@@ -1,214 +0,0 @@
|
|
| 1 |
-
"""Tokenization classes for vibevoice."""
|
| 2 |
-
|
| 3 |
-
from typing import List, Optional, Union
|
| 4 |
-
|
| 5 |
-
from transformers.utils import logging
|
| 6 |
-
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
|
| 7 |
-
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
|
| 8 |
-
|
| 9 |
-
logger = logging.get_logger(__name__)
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class VibeVoiceTextTokenizer(Qwen2Tokenizer):
|
| 13 |
-
"""
|
| 14 |
-
Construct a VibeVoice tokenizer. Based on the Qwen2 tokenizer with additional special tokens for speech.
|
| 15 |
-
|
| 16 |
-
Args:
|
| 17 |
-
vocab_file (`str`):
|
| 18 |
-
Path to the vocabulary file.
|
| 19 |
-
merges_file (`str`):
|
| 20 |
-
Path to the merges file.
|
| 21 |
-
errors (`str`, *optional*, defaults to `"replace"`):
|
| 22 |
-
Paradigm to follow when decoding bytes to UTF-8.
|
| 23 |
-
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 24 |
-
The unknown token.
|
| 25 |
-
bos_token (`str`, *optional*):
|
| 26 |
-
The beginning of sequence token. Not used for vibevoice.
|
| 27 |
-
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 28 |
-
The end of sequence token.
|
| 29 |
-
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 30 |
-
The token used for padding.
|
| 31 |
-
add_special_tokens (`bool`, *optional*, defaults to `True`):
|
| 32 |
-
Whether or not to add special tokens when encoding.
|
| 33 |
-
"""
|
| 34 |
-
|
| 35 |
-
model_input_names = ["input_ids", "attention_mask"]
|
| 36 |
-
|
| 37 |
-
def __init__(
|
| 38 |
-
self,
|
| 39 |
-
vocab_file,
|
| 40 |
-
merges_file,
|
| 41 |
-
errors="replace",
|
| 42 |
-
unk_token="<|endoftext|>",
|
| 43 |
-
bos_token=None,
|
| 44 |
-
eos_token="<|endoftext|>",
|
| 45 |
-
pad_token="<|endoftext|>",
|
| 46 |
-
add_prefix_space=False,
|
| 47 |
-
add_special_tokens=True,
|
| 48 |
-
**kwargs,
|
| 49 |
-
):
|
| 50 |
-
super().__init__(
|
| 51 |
-
vocab_file=vocab_file,
|
| 52 |
-
merges_file=merges_file,
|
| 53 |
-
errors=errors,
|
| 54 |
-
unk_token=unk_token,
|
| 55 |
-
bos_token=bos_token,
|
| 56 |
-
eos_token=eos_token,
|
| 57 |
-
pad_token=pad_token,
|
| 58 |
-
add_prefix_space=add_prefix_space,
|
| 59 |
-
add_special_tokens=add_special_tokens,
|
| 60 |
-
**kwargs,
|
| 61 |
-
)
|
| 62 |
-
|
| 63 |
-
# Add VibeVoice-specific special tokens
|
| 64 |
-
self._add_vibevoice_special_tokens()
|
| 65 |
-
|
| 66 |
-
def _add_vibevoice_special_tokens(self):
|
| 67 |
-
"""Add VibeVoice-specific special tokens."""
|
| 68 |
-
special_tokens = {
|
| 69 |
-
"additional_special_tokens": [
|
| 70 |
-
"<|vision_start|>", # Speech start (reusing vision tokens)
|
| 71 |
-
"<|vision_end|>", # Speech end
|
| 72 |
-
"<|vision_pad|>", # Speech diffusion pad
|
| 73 |
-
]
|
| 74 |
-
}
|
| 75 |
-
num_added = self.add_special_tokens(special_tokens)
|
| 76 |
-
|
| 77 |
-
# Cache special token IDs
|
| 78 |
-
self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
|
| 79 |
-
self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
|
| 80 |
-
self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
|
| 81 |
-
|
| 82 |
-
self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
|
| 83 |
-
|
| 84 |
-
return num_added
|
| 85 |
-
|
| 86 |
-
@property
|
| 87 |
-
def eos_id(self) -> int:
|
| 88 |
-
"""Id of the end of sequence token."""
|
| 89 |
-
return self._eos_id
|
| 90 |
-
|
| 91 |
-
@property
|
| 92 |
-
def speech_start_id(self) -> int:
|
| 93 |
-
"""Id of the speech start token."""
|
| 94 |
-
return self._speech_start_id
|
| 95 |
-
|
| 96 |
-
@property
|
| 97 |
-
def speech_end_id(self) -> int:
|
| 98 |
-
"""Id of the speech end token."""
|
| 99 |
-
return self._speech_end_id
|
| 100 |
-
|
| 101 |
-
@property
|
| 102 |
-
def speech_diffusion_id(self) -> int:
|
| 103 |
-
"""Id of the speech diffusion token."""
|
| 104 |
-
return self._speech_diffusion_id
|
| 105 |
-
|
| 106 |
-
@property
|
| 107 |
-
def pad_id(self) -> int:
|
| 108 |
-
"""Id used for padding (returns -100 for loss masking)."""
|
| 109 |
-
return -100
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
|
| 113 |
-
"""
|
| 114 |
-
Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library).
|
| 115 |
-
Based on the Qwen2 tokenizer with additional special tokens for speech.
|
| 116 |
-
|
| 117 |
-
Args:
|
| 118 |
-
vocab_file (`str`, *optional*):
|
| 119 |
-
Path to the vocabulary file.
|
| 120 |
-
merges_file (`str`, *optional*):
|
| 121 |
-
Path to the merges file.
|
| 122 |
-
tokenizer_file (`str`, *optional*):
|
| 123 |
-
Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
|
| 124 |
-
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 125 |
-
The unknown token.
|
| 126 |
-
bos_token (`str`, *optional*):
|
| 127 |
-
The beginning of sequence token. Not used for vibevoice.
|
| 128 |
-
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 129 |
-
The end of sequence token.
|
| 130 |
-
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 131 |
-
The token used for padding.
|
| 132 |
-
"""
|
| 133 |
-
|
| 134 |
-
model_input_names = ["input_ids", "attention_mask"]
|
| 135 |
-
|
| 136 |
-
def __init__(
|
| 137 |
-
self,
|
| 138 |
-
vocab_file=None,
|
| 139 |
-
merges_file=None,
|
| 140 |
-
tokenizer_file=None,
|
| 141 |
-
unk_token="<|endoftext|>",
|
| 142 |
-
bos_token=None,
|
| 143 |
-
eos_token="<|endoftext|>",
|
| 144 |
-
pad_token="<|endoftext|>",
|
| 145 |
-
add_prefix_space=False,
|
| 146 |
-
**kwargs,
|
| 147 |
-
):
|
| 148 |
-
super().__init__(
|
| 149 |
-
vocab_file=vocab_file,
|
| 150 |
-
merges_file=merges_file,
|
| 151 |
-
tokenizer_file=tokenizer_file,
|
| 152 |
-
unk_token=unk_token,
|
| 153 |
-
bos_token=bos_token,
|
| 154 |
-
eos_token=eos_token,
|
| 155 |
-
pad_token=pad_token,
|
| 156 |
-
add_prefix_space=add_prefix_space,
|
| 157 |
-
**kwargs,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
# Add VibeVoice-specific special tokens
|
| 161 |
-
self._add_vibevoice_special_tokens()
|
| 162 |
-
|
| 163 |
-
def _add_vibevoice_special_tokens(self):
|
| 164 |
-
"""Add VibeVoice-specific special tokens."""
|
| 165 |
-
special_tokens = {
|
| 166 |
-
"additional_special_tokens": [
|
| 167 |
-
"<|vision_start|>", # Speech start (reusing vision tokens)
|
| 168 |
-
"<|vision_end|>", # Speech end
|
| 169 |
-
"<|vision_pad|>", # Speech diffusion pad
|
| 170 |
-
]
|
| 171 |
-
}
|
| 172 |
-
num_added = self.add_special_tokens(special_tokens)
|
| 173 |
-
|
| 174 |
-
# Cache special token IDs
|
| 175 |
-
self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
|
| 176 |
-
self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
|
| 177 |
-
self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
|
| 178 |
-
|
| 179 |
-
# self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
|
| 180 |
-
self._eos_id = self.eos_token_id # qwen2 / qwen3
|
| 181 |
-
self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')
|
| 182 |
-
|
| 183 |
-
return num_added
|
| 184 |
-
|
| 185 |
-
@property
|
| 186 |
-
def eos_id(self) -> int:
|
| 187 |
-
"""Id of the end of sequence token."""
|
| 188 |
-
return self._eos_id
|
| 189 |
-
|
| 190 |
-
@property
|
| 191 |
-
def speech_start_id(self) -> int:
|
| 192 |
-
"""Id of the speech start token."""
|
| 193 |
-
return self._speech_start_id
|
| 194 |
-
|
| 195 |
-
@property
|
| 196 |
-
def speech_end_id(self) -> int:
|
| 197 |
-
"""Id of the speech end token."""
|
| 198 |
-
return self._speech_end_id
|
| 199 |
-
|
| 200 |
-
@property
|
| 201 |
-
def speech_diffusion_id(self) -> int:
|
| 202 |
-
"""Id of the speech diffusion token."""
|
| 203 |
-
return self._speech_diffusion_id
|
| 204 |
-
|
| 205 |
-
@property
|
| 206 |
-
def pad_id(self) -> int:
|
| 207 |
-
"""Id used for padding (returns -100 for loss masking)."""
|
| 208 |
-
return self._pad_id
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
__all__ = [
|
| 212 |
-
"VibeVoiceTextTokenizer",
|
| 213 |
-
"VibeVoiceTextTokenizerFast",
|
| 214 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/vibevoice/modular/modular_vibevoice_tokenizer.py
DELETED
|
@@ -1,1195 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
import typing as tp
|
| 3 |
-
from functools import partial
|
| 4 |
-
from dataclasses import dataclass, field
|
| 5 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 6 |
-
import copy
|
| 7 |
-
|
| 8 |
-
import numpy as np
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
import torch.nn.functional as F
|
| 12 |
-
|
| 13 |
-
from transformers.models.auto import AutoModel
|
| 14 |
-
|
| 15 |
-
from transformers.configuration_utils import PretrainedConfig
|
| 16 |
-
from transformers.utils import logging
|
| 17 |
-
from transformers.modeling_utils import PreTrainedModel
|
| 18 |
-
from transformers.activations import ACT2FN
|
| 19 |
-
|
| 20 |
-
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig
|
| 21 |
-
|
| 22 |
-
logger = logging.get_logger(__name__)
|
| 23 |
-
|
| 24 |
-
import os
|
| 25 |
-
# Try to import APEX FusedRMSNorm
|
| 26 |
-
try:
|
| 27 |
-
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
|
| 28 |
-
APEX_AVAILABLE = True
|
| 29 |
-
logger.info("APEX FusedRMSNorm is available and will be used for optimization")
|
| 30 |
-
if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
|
| 31 |
-
APEX_AVAILABLE = False
|
| 32 |
-
logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
|
| 33 |
-
except ImportError:
|
| 34 |
-
APEX_AVAILABLE = False
|
| 35 |
-
logger.warning("APEX FusedRMSNorm not available, using native implementation")
|
| 36 |
-
# APEX_AVAILABLE=False
|
| 37 |
-
|
| 38 |
-
# Normalization modules
|
| 39 |
-
class ConvLayerNorm(nn.LayerNorm):
|
| 40 |
-
"""
|
| 41 |
-
Convolution-friendly LayerNorm that moves channels to last dimensions
|
| 42 |
-
before running the normalization and moves them back to original position right after.
|
| 43 |
-
"""
|
| 44 |
-
def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
|
| 45 |
-
super().__init__(normalized_shape, **kwargs)
|
| 46 |
-
|
| 47 |
-
def forward(self, x):
|
| 48 |
-
x = x.transpose(1, 2) # b ... t -> b t ...
|
| 49 |
-
x = nn.functional.layer_norm(x.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(x)
|
| 50 |
-
x = x.transpose(1, 2) # b t ... -> b ... t
|
| 51 |
-
return x
|
| 52 |
-
|
| 53 |
-
class RMSNorm(nn.Module):
|
| 54 |
-
def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
|
| 55 |
-
super().__init__()
|
| 56 |
-
self.dim = dim
|
| 57 |
-
self.eps = eps
|
| 58 |
-
self.elementwise_affine = elementwise_affine
|
| 59 |
-
if self.elementwise_affine:
|
| 60 |
-
weight_shape = (dim,) if weight_shape is None else weight_shape
|
| 61 |
-
self.weight = nn.Parameter(torch.ones(weight_shape))
|
| 62 |
-
else:
|
| 63 |
-
self.register_parameter('weight', None)
|
| 64 |
-
|
| 65 |
-
def _norm(self, x):
|
| 66 |
-
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 67 |
-
|
| 68 |
-
def forward(self, x):
|
| 69 |
-
output = self._norm(x.float()).type_as(x)
|
| 70 |
-
if self.weight is not None:
|
| 71 |
-
output = output * self.weight
|
| 72 |
-
return output
|
| 73 |
-
|
| 74 |
-
def extra_repr(self) -> str:
|
| 75 |
-
return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
| 76 |
-
|
| 77 |
-
class ConvRMSNorm(RMSNorm):
|
| 78 |
-
def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
|
| 79 |
-
super().__init__(dim, eps, elementwise_affine, weight_shape)
|
| 80 |
-
|
| 81 |
-
def forward(self, x):
|
| 82 |
-
x = x.transpose(1, 2) # b ... t -> b t ...
|
| 83 |
-
if (not APEX_AVAILABLE) or (not self.elementwise_affine):
|
| 84 |
-
# Fallback to native implementation
|
| 85 |
-
output = self._norm(x.float()).type_as(x)
|
| 86 |
-
if self.weight is not None:
|
| 87 |
-
output = output * self.weight
|
| 88 |
-
else:
|
| 89 |
-
output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps)
|
| 90 |
-
output = output.transpose(1, 2) # b t ... -> b ... t
|
| 91 |
-
return output
|
| 92 |
-
|
| 93 |
-
# Convolutional layers and utilities
|
| 94 |
-
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
|
| 95 |
-
'time_layer_norm', 'layer_norm', 'time_group_norm'])
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
|
| 99 |
-
assert norm in CONV_NORMALIZATIONS
|
| 100 |
-
if norm == 'weight_norm':
|
| 101 |
-
return nn.utils.weight_norm(module)
|
| 102 |
-
elif norm == 'spectral_norm':
|
| 103 |
-
return nn.utils.spectral_norm(module)
|
| 104 |
-
else:
|
| 105 |
-
# We already check was in CONV_NORMALIZATION, so any other choice
|
| 106 |
-
# doesn't need reparametrization.
|
| 107 |
-
return module
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
|
| 111 |
-
"""Return the proper normalization module. If causal is True, this will ensure the returned
|
| 112 |
-
module is causal, or return an error if the normalization doesn't support causal evaluation.
|
| 113 |
-
"""
|
| 114 |
-
assert norm in CONV_NORMALIZATIONS
|
| 115 |
-
if norm == 'layer_norm':
|
| 116 |
-
assert isinstance(module, nn.modules.conv._ConvNd)
|
| 117 |
-
return ConvLayerNorm(module.out_channels, **norm_kwargs)
|
| 118 |
-
elif norm == 'time_group_norm':
|
| 119 |
-
if causal:
|
| 120 |
-
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
| 121 |
-
assert isinstance(module, nn.modules.conv._ConvNd)
|
| 122 |
-
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
| 123 |
-
else:
|
| 124 |
-
return nn.Identity()
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
|
| 128 |
-
padding_total: int = 0) -> int:
|
| 129 |
-
"""Calculate extra padding needed for convolution to have the same output length"""
|
| 130 |
-
length = x.shape[-1]
|
| 131 |
-
n_frames = (length - kernel_size + padding_total) / stride + 1
|
| 132 |
-
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
| 133 |
-
return ideal_length - length
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
|
| 137 |
-
"""Pad 1D input with handling for small inputs in reflect mode"""
|
| 138 |
-
length = x.shape[-1]
|
| 139 |
-
padding_left, padding_right = paddings
|
| 140 |
-
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 141 |
-
if mode == 'reflect':
|
| 142 |
-
max_pad = max(padding_left, padding_right)
|
| 143 |
-
extra_pad = 0
|
| 144 |
-
if length <= max_pad:
|
| 145 |
-
extra_pad = max_pad - length + 1
|
| 146 |
-
x = F.pad(x, (0, extra_pad))
|
| 147 |
-
padded = F.pad(x, paddings, mode, value)
|
| 148 |
-
end = padded.shape[-1] - extra_pad
|
| 149 |
-
return padded[..., :end]
|
| 150 |
-
else:
|
| 151 |
-
return F.pad(x, paddings, mode, value)
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
| 155 |
-
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
| 156 |
-
padding_left, padding_right = paddings
|
| 157 |
-
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 158 |
-
assert (padding_left + padding_right) <= x.shape[-1]
|
| 159 |
-
end = x.shape[-1] - padding_right
|
| 160 |
-
return x[..., padding_left: end]
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
class NormConv1d(nn.Module):
|
| 164 |
-
"""Wrapper around Conv1d and normalization applied to this conv"""
|
| 165 |
-
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
| 166 |
-
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
| 167 |
-
super().__init__()
|
| 168 |
-
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
|
| 169 |
-
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
|
| 170 |
-
self.norm_type = norm
|
| 171 |
-
|
| 172 |
-
def forward(self, x):
|
| 173 |
-
x = self.conv(x)
|
| 174 |
-
x = self.norm(x)
|
| 175 |
-
return x
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
class NormConvTranspose1d(nn.Module):
|
| 179 |
-
"""Wrapper around ConvTranspose1d and normalization applied to this conv"""
|
| 180 |
-
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
| 181 |
-
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
| 182 |
-
super().__init__()
|
| 183 |
-
self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
|
| 184 |
-
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
|
| 185 |
-
self.norm_type = norm
|
| 186 |
-
|
| 187 |
-
def forward(self, x):
|
| 188 |
-
x = self.convtr(x)
|
| 189 |
-
x = self.norm(x)
|
| 190 |
-
return x
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
class VibeVoiceTokenizerStreamingCache:
|
| 194 |
-
"""Cache for streaming convolution, similar to KV cache in attention"""
|
| 195 |
-
def __init__(self):
|
| 196 |
-
self.cache = {} # Dict mapping (layer_id, sample_idx) to state tensor
|
| 197 |
-
|
| 198 |
-
def get(self, layer_id: str, sample_indices: torch.Tensor) -> Optional[torch.Tensor]:
|
| 199 |
-
"""Get cached states for given layer and sample indices"""
|
| 200 |
-
states = []
|
| 201 |
-
max_length = 0
|
| 202 |
-
|
| 203 |
-
# First pass: collect states and find max length
|
| 204 |
-
for idx in sample_indices.tolist():
|
| 205 |
-
key = (layer_id, idx)
|
| 206 |
-
if key not in self.cache:
|
| 207 |
-
return None # If any sample is missing, return None
|
| 208 |
-
state = self.cache[key]
|
| 209 |
-
states.append(state)
|
| 210 |
-
max_length = max(max_length, state.shape[-1])
|
| 211 |
-
|
| 212 |
-
# Second pass: pad states to max length if needed
|
| 213 |
-
if len(states) > 0 and states[0].dim() >= 2:
|
| 214 |
-
padded_states = []
|
| 215 |
-
for state in states:
|
| 216 |
-
if state.shape[-1] < max_length:
|
| 217 |
-
# Pad on the time dimension (last dimension)
|
| 218 |
-
pad_size = max_length - state.shape[-1]
|
| 219 |
-
# Pad with zeros on the LEFT to align the most recent samples
|
| 220 |
-
padded_state = F.pad(state, (pad_size, 0), mode='constant', value=0)
|
| 221 |
-
padded_states.append(padded_state)
|
| 222 |
-
else:
|
| 223 |
-
padded_states.append(state)
|
| 224 |
-
return torch.stack(padded_states, dim=0)
|
| 225 |
-
else:
|
| 226 |
-
return torch.stack(states, dim=0)
|
| 227 |
-
|
| 228 |
-
def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
|
| 229 |
-
"""Set cached states for given layer and sample indices"""
|
| 230 |
-
for i, idx in enumerate(sample_indices.tolist()):
|
| 231 |
-
key = (layer_id, idx)
|
| 232 |
-
self.cache[key] = states[i].detach()
|
| 233 |
-
|
| 234 |
-
def set_to_zero(self, sample_indices: torch.Tensor):
|
| 235 |
-
"""Set all cached states to zero for given sample indices"""
|
| 236 |
-
for key in list(self.cache.keys()):
|
| 237 |
-
layer_id, sample_idx = key
|
| 238 |
-
if sample_idx in sample_indices.tolist():
|
| 239 |
-
# Create zero tensor with same shape and dtype as cached tensor
|
| 240 |
-
cached_tensor = self.cache[key]
|
| 241 |
-
self.cache[key] = torch.zeros_like(cached_tensor)
|
| 242 |
-
|
| 243 |
-
def clear(self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None):
|
| 244 |
-
"""Clear cache for specific layer/samples or everything"""
|
| 245 |
-
if layer_id is None and sample_indices is None:
|
| 246 |
-
self.cache.clear()
|
| 247 |
-
elif layer_id is not None and sample_indices is None:
|
| 248 |
-
# Clear all samples for a specific layer
|
| 249 |
-
keys_to_remove = [k for k in self.cache.keys() if k[0] == layer_id]
|
| 250 |
-
for k in keys_to_remove:
|
| 251 |
-
del self.cache[k]
|
| 252 |
-
elif layer_id is not None and sample_indices is not None:
|
| 253 |
-
# Clear specific samples for a specific layer
|
| 254 |
-
for idx in sample_indices.tolist():
|
| 255 |
-
key = (layer_id, idx)
|
| 256 |
-
self.cache.pop(key, None)
|
| 257 |
-
|
| 258 |
-
class SConv1d(nn.Module):
|
| 259 |
-
"""Conv1d with built-in handling of asymmetric or causal padding and normalization."""
|
| 260 |
-
def __init__(self, in_channels: int, out_channels: int,
|
| 261 |
-
kernel_size: int, stride: int = 1, dilation: int = 1,
|
| 262 |
-
groups: int = 1, bias: bool = True, causal: bool = False,
|
| 263 |
-
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
|
| 264 |
-
pad_mode: str = 'reflect'):
|
| 265 |
-
super().__init__()
|
| 266 |
-
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
|
| 267 |
-
dilation=dilation, groups=groups, bias=bias, causal=causal,
|
| 268 |
-
norm=norm, norm_kwargs=norm_kwargs)
|
| 269 |
-
self.causal = causal
|
| 270 |
-
self.pad_mode = pad_mode
|
| 271 |
-
|
| 272 |
-
# Store configuration
|
| 273 |
-
self.kernel_size = kernel_size
|
| 274 |
-
self.dilation = dilation
|
| 275 |
-
self.stride = stride
|
| 276 |
-
self.in_channels = in_channels
|
| 277 |
-
self.out_channels = out_channels
|
| 278 |
-
|
| 279 |
-
# For causal convolution, we need to maintain kernel_size - 1 samples as context
|
| 280 |
-
# need to check use which context_size is more suitable
|
| 281 |
-
# self.context_size = (kernel_size - 1) * dilation
|
| 282 |
-
self.context_size = (kernel_size - 1) * dilation - (stride - 1)
|
| 283 |
-
|
| 284 |
-
# For non-streaming mode, calculate padding
|
| 285 |
-
self.padding_total = (kernel_size - 1) * dilation - (stride - 1)
|
| 286 |
-
|
| 287 |
-
# Create a unique layer ID for cache management
|
| 288 |
-
self._layer_id = None
|
| 289 |
-
|
| 290 |
-
@property
|
| 291 |
-
def layer_id(self):
|
| 292 |
-
if self._layer_id is None:
|
| 293 |
-
self._layer_id = f"sconv1d_{id(self)}"
|
| 294 |
-
return self._layer_id
|
| 295 |
-
|
| 296 |
-
def forward(self, x: torch.Tensor,
|
| 297 |
-
cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
|
| 298 |
-
sample_indices: Optional[torch.Tensor] = None,
|
| 299 |
-
use_cache: bool = False,
|
| 300 |
-
debug: bool = False) -> torch.Tensor:
|
| 301 |
-
"""
|
| 302 |
-
Forward pass with optional streaming support via cache.
|
| 303 |
-
|
| 304 |
-
Args:
|
| 305 |
-
x: Input tensor [batch_size, channels, time]
|
| 306 |
-
cache: VibeVoiceTokenizerStreamingCache object for maintaining states
|
| 307 |
-
sample_indices: Indices identifying each sample for cache management
|
| 308 |
-
use_cache: Whether to use cached states for streaming
|
| 309 |
-
debug: Whether to print debug information
|
| 310 |
-
|
| 311 |
-
Returns:
|
| 312 |
-
Output tensor
|
| 313 |
-
"""
|
| 314 |
-
B, C, T = x.shape
|
| 315 |
-
|
| 316 |
-
# Non-streaming mode
|
| 317 |
-
if not use_cache or cache is None:
|
| 318 |
-
return self._forward_non_streaming(x, debug=debug)
|
| 319 |
-
|
| 320 |
-
# Streaming mode
|
| 321 |
-
assert self.causal, "Streaming mode is only supported for causal convolutions"
|
| 322 |
-
assert sample_indices is not None, "sample_indices must be provided for streaming mode"
|
| 323 |
-
assert len(sample_indices) == B, "sample_indices must match batch size"
|
| 324 |
-
|
| 325 |
-
return self._forward_streaming(x, cache, sample_indices, debug)
|
| 326 |
-
|
| 327 |
-
def _forward_streaming(self, x: torch.Tensor,
|
| 328 |
-
cache: VibeVoiceTokenizerStreamingCache,
|
| 329 |
-
sample_indices: torch.Tensor,
|
| 330 |
-
debug: bool = False) -> torch.Tensor:
|
| 331 |
-
"""Streaming forward pass with cache operations kept separate from compiled code"""
|
| 332 |
-
B, C, T = x.shape
|
| 333 |
-
|
| 334 |
-
# Cache operations (not compiled)
|
| 335 |
-
cached_states = cache.get(self.layer_id, sample_indices)
|
| 336 |
-
|
| 337 |
-
if cached_states is None:
|
| 338 |
-
# First chunk - initialize with zeros for context
|
| 339 |
-
if self.context_size > 0:
|
| 340 |
-
cached_states = torch.zeros(B, C, self.context_size, device=x.device, dtype=x.dtype)
|
| 341 |
-
if debug:
|
| 342 |
-
print(f"[DEBUG] Initialized cache with shape: {cached_states.shape}, context_size={self.context_size}")
|
| 343 |
-
else:
|
| 344 |
-
cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
|
| 345 |
-
if debug:
|
| 346 |
-
print(f"[DEBUG] No context needed (kernel_size=stride)")
|
| 347 |
-
|
| 348 |
-
# Concatenate cached states with input
|
| 349 |
-
if cached_states.shape[2] > 0:
|
| 350 |
-
input_with_context = torch.cat([cached_states, x], dim=2)
|
| 351 |
-
else:
|
| 352 |
-
input_with_context = x
|
| 353 |
-
|
| 354 |
-
if debug:
|
| 355 |
-
print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")
|
| 356 |
-
|
| 357 |
-
# Apply convolution directly - no extra padding in streaming mode
|
| 358 |
-
# The conv layer will handle its own padding internally
|
| 359 |
-
output = self.conv(input_with_context)
|
| 360 |
-
|
| 361 |
-
if debug:
|
| 362 |
-
print(f"[DEBUG] Output shape: {output.shape}")
|
| 363 |
-
|
| 364 |
-
# Update cache for next chunk
|
| 365 |
-
if self.context_size > 0:
|
| 366 |
-
# Calculate how many samples to keep
|
| 367 |
-
total_input_length = input_with_context.shape[2]
|
| 368 |
-
|
| 369 |
-
# Keep the last context_size samples
|
| 370 |
-
if total_input_length >= self.context_size:
|
| 371 |
-
new_cache_start = total_input_length - self.context_size
|
| 372 |
-
new_cache = input_with_context[:, :, new_cache_start:]
|
| 373 |
-
else:
|
| 374 |
-
# If we have less than context_size samples, keep everything
|
| 375 |
-
new_cache = input_with_context
|
| 376 |
-
|
| 377 |
-
if debug:
|
| 378 |
-
print(f"[DEBUG] New cache shape: {new_cache.shape}")
|
| 379 |
-
|
| 380 |
-
cache.set(self.layer_id, sample_indices, new_cache)
|
| 381 |
-
|
| 382 |
-
return output
|
| 383 |
-
|
| 384 |
-
def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
|
| 385 |
-
"""Standard forward pass without streaming"""
|
| 386 |
-
B, C, T = x.shape
|
| 387 |
-
kernel_size = self.kernel_size
|
| 388 |
-
stride = self.stride
|
| 389 |
-
dilation = self.dilation
|
| 390 |
-
padding_total = self.padding_total
|
| 391 |
-
|
| 392 |
-
# Compute extra padding for stride alignment
|
| 393 |
-
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
| 394 |
-
|
| 395 |
-
if debug:
|
| 396 |
-
print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}")
|
| 397 |
-
|
| 398 |
-
if self.causal:
|
| 399 |
-
# Left padding for causal
|
| 400 |
-
if self.pad_mode == 'constant':
|
| 401 |
-
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, value=0)
|
| 402 |
-
else:
|
| 403 |
-
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
| 404 |
-
else:
|
| 405 |
-
# Symmetric padding for non-causal
|
| 406 |
-
padding_right = padding_total // 2
|
| 407 |
-
padding_left = padding_total - padding_right
|
| 408 |
-
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
|
| 409 |
-
|
| 410 |
-
if debug:
|
| 411 |
-
print(f"[DEBUG NON-STREAMING] After padding: {x.shape}")
|
| 412 |
-
|
| 413 |
-
output = self.conv(x)
|
| 414 |
-
|
| 415 |
-
if debug:
|
| 416 |
-
print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}")
|
| 417 |
-
|
| 418 |
-
return output
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
class SConvTranspose1d(nn.Module):
|
| 422 |
-
"""ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization."""
|
| 423 |
-
def __init__(self, in_channels: int, out_channels: int,
|
| 424 |
-
kernel_size: int, stride: int = 1, causal: bool = False,
|
| 425 |
-
norm: str = 'none', trim_right_ratio: float = 1.,
|
| 426 |
-
norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True):
|
| 427 |
-
super().__init__()
|
| 428 |
-
self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
|
| 429 |
-
causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
|
| 430 |
-
self.causal = causal
|
| 431 |
-
self.trim_right_ratio = trim_right_ratio
|
| 432 |
-
assert self.causal or self.trim_right_ratio == 1., \
|
| 433 |
-
"`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
|
| 434 |
-
assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
|
| 435 |
-
|
| 436 |
-
# Store configuration
|
| 437 |
-
self.kernel_size = kernel_size
|
| 438 |
-
self.stride = stride
|
| 439 |
-
self.in_channels = in_channels
|
| 440 |
-
self.out_channels = out_channels
|
| 441 |
-
|
| 442 |
-
# For transposed convolution, padding calculation is different
|
| 443 |
-
self.padding_total = kernel_size - stride
|
| 444 |
-
|
| 445 |
-
# For streaming, we need to keep track of input history
|
| 446 |
-
# Transposed conv needs to see multiple input samples to produce correct output
|
| 447 |
-
self.context_size = kernel_size - 1
|
| 448 |
-
|
| 449 |
-
# Create a unique layer ID for cache management
|
| 450 |
-
self._layer_id = None
|
| 451 |
-
|
| 452 |
-
@property
|
| 453 |
-
def layer_id(self):
|
| 454 |
-
if self._layer_id is None:
|
| 455 |
-
self._layer_id = f"sconvtr1d_{id(self)}"
|
| 456 |
-
return self._layer_id
|
| 457 |
-
|
| 458 |
-
def forward(self, x: torch.Tensor,
|
| 459 |
-
cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
|
| 460 |
-
sample_indices: Optional[torch.Tensor] = None,
|
| 461 |
-
use_cache: bool = False,
|
| 462 |
-
debug: bool = False) -> torch.Tensor:
|
| 463 |
-
"""
|
| 464 |
-
Forward pass with optional streaming support via cache.
|
| 465 |
-
"""
|
| 466 |
-
B, C, T = x.shape
|
| 467 |
-
|
| 468 |
-
# Non-streaming mode
|
| 469 |
-
if not use_cache or cache is None:
|
| 470 |
-
return self._forward_non_streaming(x, debug=debug)
|
| 471 |
-
|
| 472 |
-
# Streaming mode
|
| 473 |
-
assert sample_indices is not None, "sample_indices must be provided for streaming mode"
|
| 474 |
-
assert len(sample_indices) == B, "sample_indices must match batch size"
|
| 475 |
-
|
| 476 |
-
return self._forward_streaming(x, cache, sample_indices, debug)
|
| 477 |
-
|
| 478 |
-
def _forward_streaming(self, x: torch.Tensor,
|
| 479 |
-
cache: VibeVoiceTokenizerStreamingCache,
|
| 480 |
-
sample_indices: torch.Tensor,
|
| 481 |
-
debug: bool = False) -> torch.Tensor:
|
| 482 |
-
"""Streaming forward pass with cache operations kept separate from compiled code"""
|
| 483 |
-
B, C, T = x.shape
|
| 484 |
-
|
| 485 |
-
# Cache operations (not compiled)
|
| 486 |
-
cached_input = cache.get(self.layer_id, sample_indices)
|
| 487 |
-
|
| 488 |
-
if cached_input is None:
|
| 489 |
-
# First chunk - no history yet
|
| 490 |
-
cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
|
| 491 |
-
if debug:
|
| 492 |
-
print(f"[DEBUG] Initialized empty cache for transposed conv")
|
| 493 |
-
|
| 494 |
-
# Concatenate cached input with new input
|
| 495 |
-
full_input = torch.cat([cached_input, x], dim=2)
|
| 496 |
-
|
| 497 |
-
if debug:
|
| 498 |
-
print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}")
|
| 499 |
-
|
| 500 |
-
# First chunk or debug mode - use uncompiled version
|
| 501 |
-
full_output = self.convtr(full_input)
|
| 502 |
-
|
| 503 |
-
if debug:
|
| 504 |
-
print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}")
|
| 505 |
-
|
| 506 |
-
# Calculate padding to remove
|
| 507 |
-
if self.causal:
|
| 508 |
-
padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
|
| 509 |
-
padding_left = self.padding_total - padding_right
|
| 510 |
-
else:
|
| 511 |
-
padding_right = self.padding_total // 2
|
| 512 |
-
padding_left = self.padding_total - padding_right
|
| 513 |
-
|
| 514 |
-
# Remove padding
|
| 515 |
-
if padding_left + padding_right > 0:
|
| 516 |
-
full_output = unpad1d(full_output, (padding_left, padding_right))
|
| 517 |
-
|
| 518 |
-
if debug:
|
| 519 |
-
print(f"[DEBUG] After unpadding: {full_output.shape}")
|
| 520 |
-
|
| 521 |
-
# Determine which part of the output corresponds to the new input
|
| 522 |
-
if cached_input.shape[2] == 0:
|
| 523 |
-
# First chunk - return all output
|
| 524 |
-
output = full_output
|
| 525 |
-
else:
|
| 526 |
-
# Subsequent chunks - return only the new output
|
| 527 |
-
expected_new_output = T * self.stride
|
| 528 |
-
|
| 529 |
-
# Take the last expected_new_output samples
|
| 530 |
-
if full_output.shape[2] >= expected_new_output:
|
| 531 |
-
output = full_output[:, :, -expected_new_output:]
|
| 532 |
-
else:
|
| 533 |
-
output = full_output
|
| 534 |
-
|
| 535 |
-
if debug:
|
| 536 |
-
print(f"[DEBUG] Final streaming output shape: {output.shape}")
|
| 537 |
-
|
| 538 |
-
# Update cache
|
| 539 |
-
if full_input.shape[2] > self.context_size:
|
| 540 |
-
new_cache = full_input[:, :, -self.context_size:]
|
| 541 |
-
else:
|
| 542 |
-
new_cache = full_input
|
| 543 |
-
|
| 544 |
-
if debug:
|
| 545 |
-
print(f"[DEBUG] New cache shape: {new_cache.shape}")
|
| 546 |
-
|
| 547 |
-
cache.set(self.layer_id, sample_indices, new_cache)
|
| 548 |
-
|
| 549 |
-
return output
|
| 550 |
-
|
| 551 |
-
def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
|
| 552 |
-
"""Standard forward pass without streaming"""
|
| 553 |
-
if debug:
|
| 554 |
-
print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}")
|
| 555 |
-
|
| 556 |
-
# Apply transposed convolution
|
| 557 |
-
y = self.convtr(x)
|
| 558 |
-
|
| 559 |
-
if debug:
|
| 560 |
-
print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}")
|
| 561 |
-
|
| 562 |
-
# Calculate and remove padding
|
| 563 |
-
if self.causal:
|
| 564 |
-
padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
|
| 565 |
-
padding_left = self.padding_total - padding_right
|
| 566 |
-
else:
|
| 567 |
-
padding_right = self.padding_total // 2
|
| 568 |
-
padding_left = self.padding_total - padding_right
|
| 569 |
-
|
| 570 |
-
if padding_left + padding_right > 0:
|
| 571 |
-
y = unpad1d(y, (padding_left, padding_right))
|
| 572 |
-
|
| 573 |
-
if debug:
|
| 574 |
-
print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}")
|
| 575 |
-
|
| 576 |
-
return y
|
| 577 |
-
|
| 578 |
-
# FFN
|
| 579 |
-
class FFN(nn.Module):
|
| 580 |
-
def __init__(
|
| 581 |
-
self,
|
| 582 |
-
embed_dim,
|
| 583 |
-
ffn_dim,
|
| 584 |
-
bias=False,
|
| 585 |
-
):
|
| 586 |
-
super().__init__()
|
| 587 |
-
self.embed_dim = embed_dim
|
| 588 |
-
self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias)
|
| 589 |
-
self.gelu = ACT2FN["gelu"]
|
| 590 |
-
self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias)
|
| 591 |
-
|
| 592 |
-
def forward(self, x):
|
| 593 |
-
x = self.linear1(x)
|
| 594 |
-
x = self.gelu(x)
|
| 595 |
-
x = self.linear2(x)
|
| 596 |
-
return x
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
class Convlayer(nn.Module):
|
| 600 |
-
def __init__(
|
| 601 |
-
self,
|
| 602 |
-
in_channels,
|
| 603 |
-
out_channels,
|
| 604 |
-
kernel_size,
|
| 605 |
-
stride=1,
|
| 606 |
-
dilation=1,
|
| 607 |
-
groups=1,
|
| 608 |
-
bias=True,
|
| 609 |
-
pad_mode='zeros',
|
| 610 |
-
norm='weight_norm',
|
| 611 |
-
causal=True,
|
| 612 |
-
):
|
| 613 |
-
super().__init__()
|
| 614 |
-
self.conv = SConv1d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
|
| 615 |
-
groups=groups, bias=bias, pad_mode=pad_mode, norm=norm, causal=causal)
|
| 616 |
-
|
| 617 |
-
def forward(self, x):
|
| 618 |
-
return self.conv(x)
|
| 619 |
-
|
| 620 |
-
class Block1D(nn.Module):
|
| 621 |
-
def __init__(self, dim, kernel_size=7, drop_path=0., mixer_layer='conv',
|
| 622 |
-
layer_scale_init_value=1e-6, **kwargs):
|
| 623 |
-
super().__init__()
|
| 624 |
-
|
| 625 |
-
if kwargs.get('layernorm', 'LN') == 'LN':
|
| 626 |
-
self.norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
|
| 627 |
-
self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
|
| 628 |
-
elif kwargs.get('layernorm', 'RMSNorm') == 'RMSNorm':
|
| 629 |
-
self.norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
|
| 630 |
-
self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
|
| 631 |
-
|
| 632 |
-
if mixer_layer == 'conv':
|
| 633 |
-
self.mixer = Convlayer(dim, dim, groups=kwargs.get('groups', 1),
|
| 634 |
-
kernel_size=kernel_size,
|
| 635 |
-
pad_mode=kwargs.get('pad_mode', 'reflect'),
|
| 636 |
-
norm=kwargs.get('norm', 'none'),
|
| 637 |
-
causal=kwargs.get('causal', True),
|
| 638 |
-
bias=kwargs.get('bias', True),
|
| 639 |
-
)
|
| 640 |
-
elif mixer_layer == 'depthwise_conv':
|
| 641 |
-
self.mixer = Convlayer(dim, dim, groups=dim,
|
| 642 |
-
kernel_size=kernel_size,
|
| 643 |
-
pad_mode=kwargs.get('pad_mode', 'reflect'),
|
| 644 |
-
norm=kwargs.get('norm', 'none'),
|
| 645 |
-
causal=kwargs.get('causal', True),
|
| 646 |
-
bias=kwargs.get('bias', True),
|
| 647 |
-
)
|
| 648 |
-
else:
|
| 649 |
-
raise ValueError(f"Unsupported mixer layer: {mixer_layer}")
|
| 650 |
-
|
| 651 |
-
self.ffn = FFN(
|
| 652 |
-
dim,
|
| 653 |
-
kwargs.get('ffn_expansion', 4) * dim,
|
| 654 |
-
bias=kwargs.get('bias', False),
|
| 655 |
-
)
|
| 656 |
-
self.drop_path = nn.Identity() if drop_path <= 0. else nn.modules.DropPath(drop_path)
|
| 657 |
-
|
| 658 |
-
if layer_scale_init_value > 0:
|
| 659 |
-
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
| 660 |
-
self.ffn_gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
| 661 |
-
else:
|
| 662 |
-
self.gamma = None
|
| 663 |
-
self.ffn_gamma = None
|
| 664 |
-
|
| 665 |
-
def forward(self, x):
|
| 666 |
-
# mixer
|
| 667 |
-
residual = x
|
| 668 |
-
x = self.norm(x)
|
| 669 |
-
x = self.mixer(x)
|
| 670 |
-
if self.gamma is not None:
|
| 671 |
-
x = x * self.gamma.unsqueeze(-1)
|
| 672 |
-
x = residual + self.drop_path(x)
|
| 673 |
-
|
| 674 |
-
# ffn
|
| 675 |
-
residual = x
|
| 676 |
-
x = self.ffn_norm(x)
|
| 677 |
-
x = x.permute(0, 2, 1)
|
| 678 |
-
x = self.ffn(x)
|
| 679 |
-
x = x.permute(0, 2, 1)
|
| 680 |
-
if self.ffn_gamma is not None:
|
| 681 |
-
x = x * self.ffn_gamma.unsqueeze(-1)
|
| 682 |
-
x = residual + self.drop_path(x)
|
| 683 |
-
|
| 684 |
-
return x
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
class TokenizerEncoder(nn.Module):
|
| 688 |
-
"""
|
| 689 |
-
Encoder component for the VibeVoice tokenizer that converts audio to latent representations.
|
| 690 |
-
|
| 691 |
-
Args:
|
| 692 |
-
config: Configuration object with model parameters
|
| 693 |
-
"""
|
| 694 |
-
def __init__(self, config):
|
| 695 |
-
super().__init__()
|
| 696 |
-
|
| 697 |
-
# Extract parameters from config
|
| 698 |
-
self.channels = config.channels
|
| 699 |
-
self.dimension = config.dimension
|
| 700 |
-
self.n_filters = config.n_filters
|
| 701 |
-
self.ratios = list(reversed(config.ratios))
|
| 702 |
-
self.depths = config.depths
|
| 703 |
-
self.n_residual_layers = getattr(config, "n_residual_layers", 1)
|
| 704 |
-
self.hop_length = np.prod(self.ratios)
|
| 705 |
-
self.causal = config.causal
|
| 706 |
-
|
| 707 |
-
# Additional config parameters with defaults
|
| 708 |
-
kernel_size = getattr(config, "kernel_size", 7)
|
| 709 |
-
last_kernel_size = getattr(config, "last_kernel_size", 7)
|
| 710 |
-
norm = getattr(config, "norm", "none")
|
| 711 |
-
norm_params = getattr(config, "norm_params", {})
|
| 712 |
-
pad_mode = getattr(config, "pad_mode", "reflect")
|
| 713 |
-
bias = getattr(config, "bias", True)
|
| 714 |
-
layernorm = getattr(config, "layernorm", "LN")
|
| 715 |
-
layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
|
| 716 |
-
layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
|
| 717 |
-
drop_path_rate = getattr(config, "drop_path_rate", 0.0)
|
| 718 |
-
mixer_layer = getattr(config, "mixer_layer", "conv")
|
| 719 |
-
layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
|
| 720 |
-
disable_last_norm = getattr(config, "disable_last_norm", False)
|
| 721 |
-
|
| 722 |
-
# determine the norm type based on layernorm
|
| 723 |
-
if layernorm == 'LN':
|
| 724 |
-
norm_type = ConvLayerNorm
|
| 725 |
-
elif layernorm == 'RMSNorm':
|
| 726 |
-
norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
|
| 727 |
-
else:
|
| 728 |
-
raise ValueError(f"Unsupported norm type: {layernorm}")
|
| 729 |
-
|
| 730 |
-
# stem and intermediate downsampling conv layers
|
| 731 |
-
stem = nn.Sequential(
|
| 732 |
-
SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
|
| 733 |
-
)
|
| 734 |
-
|
| 735 |
-
self.downsample_layers = nn.ModuleList()
|
| 736 |
-
self.downsample_layers.append(stem)
|
| 737 |
-
for i in range(len(self.ratios)):
|
| 738 |
-
in_ch = self.n_filters * (2 ** i)
|
| 739 |
-
out_ch = self.n_filters * (2 ** (i + 1))
|
| 740 |
-
downsample_layer = nn.Sequential(
|
| 741 |
-
SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
|
| 742 |
-
)
|
| 743 |
-
self.downsample_layers.append(downsample_layer)
|
| 744 |
-
|
| 745 |
-
# configure the transformer blocks
|
| 746 |
-
layer_type = partial(
|
| 747 |
-
Block1D,
|
| 748 |
-
mixer_layer=mixer_layer,
|
| 749 |
-
layernorm=layernorm,
|
| 750 |
-
eps=layernorm_eps,
|
| 751 |
-
causal=self.causal,
|
| 752 |
-
pad_mode=pad_mode,
|
| 753 |
-
norm=norm,
|
| 754 |
-
bias=bias,
|
| 755 |
-
layer_scale_init_value=layer_scale_init_value,
|
| 756 |
-
)
|
| 757 |
-
|
| 758 |
-
self.stages = nn.ModuleList()
|
| 759 |
-
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
|
| 760 |
-
cur = 0
|
| 761 |
-
|
| 762 |
-
for i in range(len(self.depths)):
|
| 763 |
-
in_ch = self.n_filters * (2 ** i)
|
| 764 |
-
stage = nn.Sequential(
|
| 765 |
-
*[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
|
| 766 |
-
)
|
| 767 |
-
self.stages.append(stage)
|
| 768 |
-
cur += self.depths[i]
|
| 769 |
-
|
| 770 |
-
if not disable_last_norm:
|
| 771 |
-
self.norm = norm_type(in_ch, eps=layernorm_eps)
|
| 772 |
-
else:
|
| 773 |
-
self.norm = nn.Identity()
|
| 774 |
-
self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
|
| 775 |
-
|
| 776 |
-
def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
|
| 777 |
-
for i in range(len(self.depths)):
|
| 778 |
-
# Apply downsampling
|
| 779 |
-
for layer in self.downsample_layers[i]:
|
| 780 |
-
if isinstance(layer, SConv1d):
|
| 781 |
-
x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
| 782 |
-
else:
|
| 783 |
-
x = layer(x)
|
| 784 |
-
|
| 785 |
-
# Apply stage (Block1D contains Convlayer which contains SConv1d)
|
| 786 |
-
for block in self.stages[i]:
|
| 787 |
-
if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
|
| 788 |
-
# Block1D forward with cache support
|
| 789 |
-
residual = x
|
| 790 |
-
x = block.norm(x)
|
| 791 |
-
x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
                    if block.gamma is not None:
                        x = x * block.gamma.unsqueeze(-1)
                    x = residual + x

                    # FFN part
                    residual = x
                    x = block.ffn_norm(x)
                    x = x.permute(0, 2, 1)
                    x = block.ffn(x)
                    x = x.permute(0, 2, 1)
                    if block.ffn_gamma is not None:
                        x = x * block.ffn_gamma.unsqueeze(-1)
                    x = residual + x
                else:
                    x = block(x)

        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return x


class TokenizerDecoder(nn.Module):
    """
    Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.

    Args:
        config: Configuration object with model parameters
    """
    def __init__(self, config):
        super().__init__()

        # Extract parameters from config
        self.dimension = config.dimension
        self.channels = config.channels
        self.n_filters = config.n_filters
        self.ratios = config.ratios

        # IMPORTANT CHANGE: Don't reverse depths again since they're already reversed in VibeVoiceAcousticTokenizerModel
        self.depths = config.depths  # Changed from list(reversed(config.depths))

        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        # Additional config parameters with defaults
        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
        trim_right_ratio = getattr(config, "trim_right_ratio", 1.0)
        layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        # determine the norm type based on layernorm
        if layernorm == 'LN':
            norm_type = ConvLayerNorm
        elif layernorm == 'RMSNorm':
            norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        # stem and upsampling layers
        stem = nn.Sequential(
            SConv1d(self.dimension, self.n_filters * 2 ** (len(self.depths) - 1), kernel_size, norm=norm,
                    norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
        )

        self.upsample_layers = nn.ModuleList()
        self.upsample_layers.append(stem)
        for i in range(len(self.ratios)):
            in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
            out_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i - 1))
            upsample_layer = nn.Sequential(
                SConvTranspose1d(in_ch, out_ch,
                                 kernel_size=self.ratios[i] * 2, stride=self.ratios[i],
                                 norm=norm, norm_kwargs=norm_params, bias=bias,
                                 causal=self.causal, trim_right_ratio=trim_right_ratio),
            )
            self.upsample_layers.append(upsample_layer)

        # configure transformer blocks
        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )

        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0

        # Create stages in the same order as the original model
        for i in range(len(self.depths)):
            in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
            stage = nn.Sequential(
                *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
            )
            self.stages.append(stage)
            cur += self.depths[i]

        if not disable_last_norm:
            self.norm = norm_type(in_ch, eps=layernorm_eps)
        else:
            self.norm = nn.Identity()
        self.head = SConv1d(in_ch, self.channels, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)

    def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        for i in range(len(self.depths)):
            # Apply upsampling
            for layer in self.upsample_layers[i]:
                if isinstance(layer, (SConv1d, SConvTranspose1d)):
                    x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                else:
                    x = layer(x)

            # Apply stage (Block1D contains Convlayer which contains SConv1d)
            for block in self.stages[i]:
                if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
                    # Block1D forward with cache support
                    residual = x
                    x = block.norm(x)
                    x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                    if block.gamma is not None:
                        x = x * block.gamma.unsqueeze(-1)
                    x = residual + x

                    # FFN part
                    residual = x
                    x = block.ffn_norm(x)
                    x = x.permute(0, 2, 1)
                    x = block.ffn(x)
                    x = x.permute(0, 2, 1)
                    if block.ffn_gamma is not None:
                        x = x * block.ffn_gamma.unsqueeze(-1)
                    x = residual + x
                else:
                    x = block(x)

        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return x


@dataclass
class VibeVoiceTokenizerEncoderOutput:
    """
    Output of VibeVoice tokenizer encoder, representing a Gaussian distribution with fixed variance.

    Args:
        mean (`torch.FloatTensor`): The mean parameters of the distribution.
        std (`float` or `torch.FloatTensor`): Fixed standard deviation value.
    """
    mean: torch.Tensor
    std: Optional[Union[float, torch.Tensor]] = None

    def sample(self, dist_type='fix'):
        """
        Sample from the distribution.

        Args:
            dist_type (`str`): Sampling method, either 'fix' or 'gaussian'.

        Returns:
            `torch.FloatTensor`: Sampled values.
            `torch.FloatTensor` (optional): Standard deviation used (only when dist_type='gaussian').
        """
        if dist_type == 'fix':
            x = self.mean + self.std * torch.randn_like(self.mean)
            return x, self.std
        elif dist_type == 'gaussian':
            batch_size = self.mean.size(0)
            value = self.std / 0.8
            std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value

            while std.dim() < self.mean.dim():
                std = std.unsqueeze(-1)

            x = self.mean + std * torch.randn_like(self.mean)
            return x, std
        else:
            return self.mean, self.std

    def kl(self):
        """Compute KL divergence between this distribution and a standard normal."""
        target = torch.zeros_like(self.mean)
        return F.mse_loss(self.mean, target, reduction='none')

    def mode(self):
        """Return the distribution mode (which is the mean for Gaussian)."""
        return self.mean

class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
    """VibeVoice speech tokenizer model combining encoder and decoder for acoustic tokens"""

    config_class = VibeVoiceAcousticTokenizerConfig
    base_model_prefix = "vibevoice_acoustic_tokenizer"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]

    def __init__(self, config):
        super().__init__(config)

        self.register_buffer('fix_std', torch.tensor(config.fix_std), persistent=False)
        self.std_dist_type = getattr(config, "std_dist_type", "fix")

        # Parse encoder depths
        if isinstance(config.encoder_depths, str):
            encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
        else:
            encoder_depths = config.encoder_depths

        # Parse decoder depths if provided
        if config.decoder_depths is not None and isinstance(config.decoder_depths, str):
            decoder_depths = [int(d) for d in config.decoder_depths.split('-')]
        else:
            # Default: use reversed encoder depths if decoder_depths is None
            decoder_depths = list(reversed(encoder_depths))

        # Create encoder config
        encoder_config = copy.deepcopy(config)
        encoder_config.dimension = config.vae_dim
        encoder_config.n_filters = config.encoder_n_filters
        encoder_config.ratios = config.encoder_ratios
        encoder_config.depths = encoder_depths
        encoder_config.norm = config.conv_norm
        encoder_config.pad_mode = config.pad_mode
        encoder_config.bias = config.conv_bias
        encoder_config.layernorm_eps = config.layernorm_eps
        encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        encoder_config.mixer_layer = config.mixer_layer
        encoder_config.layer_scale_init_value = config.layer_scale_init_value
        encoder_config.disable_last_norm = config.disable_last_norm

        # Create decoder config
        decoder_config = copy.deepcopy(config)
        decoder_config.dimension = config.vae_dim
        decoder_config.n_filters = config.decoder_n_filters
        decoder_config.ratios = config.decoder_ratios
        decoder_config.depths = decoder_depths
        decoder_config.norm = config.conv_norm
        decoder_config.pad_mode = config.pad_mode
        decoder_config.bias = config.conv_bias
        decoder_config.layernorm_eps = config.layernorm_eps
        decoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        decoder_config.mixer_layer = config.mixer_layer
        decoder_config.layer_scale_init_value = config.layer_scale_init_value
        decoder_config.disable_last_norm = config.disable_last_norm

        # Initialize encoder and decoder
        self.encoder = TokenizerEncoder(encoder_config)
        self.decoder = TokenizerDecoder(decoder_config)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights for the model"""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert audio to latent representations"""
        latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """Sample from the encoder output distribution"""
        dist_type = dist_type or self.std_dist_type

        if dist_type == 'fix':
            return encoder_output.sample(dist_type='fix')
        elif dist_type == 'gaussian':
            return encoder_output.sample(dist_type='gaussian')
        else:
            raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'")

    @torch.no_grad()
    def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert latent representations back to audio"""
        if latents.shape[1] == self.config.vae_dim:
            pass
        else:
            latents = latents.permute(0, 2, 1)

        audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return audio

    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Full forward pass: encode audio to latents, then decode back to audio"""
        encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        sampled_latents, _ = self.sampling(encoder_output)
        reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return reconstructed, sampled_latents


class VibeVoiceSemanticTokenizerModel(PreTrainedModel):
    """VibeVoice speech tokenizer model with only encoder for semantic tokens"""

    config_class = VibeVoiceSemanticTokenizerConfig
    base_model_prefix = "vibevoice_semantic_tokenizer"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder"]

    def __init__(self, config):
        super().__init__(config)

        # Parse encoder depths
        if isinstance(config.encoder_depths, str):
            encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
        else:
            encoder_depths = config.encoder_depths

        # Create encoder config
        encoder_config = copy.deepcopy(config)
        encoder_config.dimension = config.vae_dim
        encoder_config.n_filters = config.encoder_n_filters
        encoder_config.ratios = config.encoder_ratios
        encoder_config.depths = encoder_depths
        encoder_config.norm = config.conv_norm
        encoder_config.pad_mode = config.pad_mode
        encoder_config.bias = config.conv_bias
        encoder_config.layernorm_eps = config.layernorm_eps
        encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        encoder_config.mixer_layer = config.mixer_layer
        encoder_config.layer_scale_init_value = config.layer_scale_init_value
        encoder_config.disable_last_norm = config.disable_last_norm

        # Initialize encoder and decoder
        self.encoder = TokenizerEncoder(encoder_config)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights for the model"""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert audio to latent representations"""
        latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """Sample from the encoder output distribution"""
        return encoder_output.sample(dist_type='none')

    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Full forward pass: encode audio to latents, then decode back to audio"""
        encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
        return None, sampled_latents

AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
AutoModel.register(VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel)

__all__ = [
    "VibeVoiceTokenizerStreamingCache",
    "VibeVoiceAcousticTokenizerModel",
    "VibeVoiceSemanticTokenizerModel",
]
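For reference, the deleted module exposed an encode → sampling → decode round trip on the acoustic tokenizer. A minimal sketch of how that API could be driven is below; it is illustrative only and assumes the pre-removal package layout (`vibevoice.modular.*`), that `VibeVoiceAcousticTokenizerConfig` is exported from `configuration_vibevoice.py` (also deleted in this commit), a default-constructed config with random weights, and a (batch, channels, samples) input shape at 24 kHz. A real run would load pretrained weights instead.

import torch
from vibevoice.modular.configuration_vibevoice import VibeVoiceAcousticTokenizerConfig
from vibevoice.modular.modular_vibevoice_tokenizer import VibeVoiceAcousticTokenizerModel

# Assumption: default hyperparameters; an actual checkpoint would be loaded via from_pretrained.
config = VibeVoiceAcousticTokenizerConfig()
model = VibeVoiceAcousticTokenizerModel(config).eval()

audio = torch.randn(1, 1, 24000)            # assumed (batch, channels, samples) layout

enc = model.encode(audio)                   # VibeVoiceTokenizerEncoderOutput(mean, std=fix_std)
latents, _ = model.sampling(enc)            # 'fix' sampling: mean + fix_std * noise
reconstruction = model.decode(latents)      # decode() permutes latents back to (B, vae_dim, T) if needed
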
src/vibevoice/modular/streamer.py
DELETED
@@ -1,264 +0,0 @@
from __future__ import annotations

import torch

import asyncio
from queue import Queue
from typing import TYPE_CHECKING, Optional


from transformers.generation.streamers import BaseStreamer


class AudioStreamer(BaseStreamer):
    """
    Audio streamer that stores audio chunks in queues for each sample in the batch.
    This allows streaming audio generation for multiple samples simultaneously.

    Parameters:
        batch_size (`int`):
            The batch size for generation
        stop_signal (`any`, *optional*):
            The signal to put in the queue when generation ends. Defaults to None.
        timeout (`float`, *optional*):
            The timeout for the audio queue. If `None`, the queue will block indefinitely.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[any] = None,
        timeout: Optional[float] = None,
    ):
        self.batch_size = batch_size
        self.stop_signal = stop_signal
        self.timeout = timeout

        # Create a queue for each sample in the batch
        self.audio_queues = [Queue() for _ in range(batch_size)]
        self.finished_flags = [False for _ in range(batch_size)]
        self.sample_indices_map = {}  # Maps from sample index to queue index

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """
        Receives audio chunks and puts them in the appropriate queues.

        Args:
            audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks
            sample_indices: Tensor indicating which samples these chunks belong to
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and not self.finished_flags[idx]:
                # Convert to numpy or keep as tensor based on preference
                audio_chunk = audio_chunks[i].detach().cpu()
                self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """
        Signals the end of generation for specified samples or all samples.

        Args:
            sample_indices: Optional tensor of sample indices to end. If None, ends all.
        """
        if sample_indices is None:
            # End all samples
            for idx in range(self.batch_size):
                if not self.finished_flags[idx]:
                    self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                    self.finished_flags[idx] = True
        else:
            # End specific samples
            for sample_idx in sample_indices:
                idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
                if idx < self.batch_size and not self.finished_flags[idx]:
                    self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                    self.finished_flags[idx] = True

    def __iter__(self):
        """Returns an iterator over the batch of audio streams."""
        return AudioBatchIterator(self)

    def get_stream(self, sample_idx: int):
        """Get the audio stream for a specific sample."""
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
        return AudioSampleIterator(self, sample_idx)


class AudioSampleIterator:
    """Iterator for a single audio stream from the batch."""

    def __init__(self, streamer: AudioStreamer, sample_idx: int):
        self.streamer = streamer
        self.sample_idx = sample_idx

    def __iter__(self):
        return self

    def __next__(self):
        value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
        if value == self.streamer.stop_signal:
            raise StopIteration()
        return value


class AudioBatchIterator:
    """Iterator that yields audio chunks for all samples in the batch."""

    def __init__(self, streamer: AudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __iter__(self):
        return self

    def __next__(self):
        if not self.active_samples:
            raise StopIteration()

        batch_chunks = {}
        samples_to_remove = set()

        # Try to get chunks from all active samples
        for idx in self.active_samples:
            try:
                value = self.streamer.audio_queues[idx].get(block=False)
                if value == self.streamer.stop_signal:
                    samples_to_remove.add(idx)
                else:
                    batch_chunks[idx] = value
            except:
                # Queue is empty for this sample, skip it this iteration
                pass

        # Remove finished samples
        self.active_samples -= samples_to_remove

        if batch_chunks:
            return batch_chunks
        elif self.active_samples:
            # If no chunks were ready but we still have active samples,
            # wait a bit and try again
            import time
            time.sleep(0.01)
            return self.__next__()
        else:
            raise StopIteration()


class AsyncAudioStreamer(AudioStreamer):
    """
    Async version of AudioStreamer for use in async contexts.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[any] = None,
        timeout: Optional[float] = None,
    ):
        super().__init__(batch_size, stop_signal, timeout)
        # Replace regular queues with async queues
        self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
        self.loop = asyncio.get_running_loop()

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Put audio chunks in the appropriate async queues."""
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and not self.finished_flags[idx]:
                audio_chunk = audio_chunks[i].detach().cpu()
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, audio_chunk
                )

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Signal the end of generation for specified samples."""
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]

        for idx in indices_to_end:
            if idx < self.batch_size and not self.finished_flags[idx]:
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, self.stop_signal
                )
                self.finished_flags[idx] = True

    async def get_stream(self, sample_idx: int):
        """Get async iterator for a specific sample's audio stream."""
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")

        while True:
            value = await self.audio_queues[sample_idx].get()
            if value == self.stop_signal:
                break
            yield value

    def __aiter__(self):
        """Returns an async iterator over all audio streams."""
        return AsyncAudioBatchIterator(self)


class AsyncAudioBatchIterator:
    """Async iterator for batch audio streaming."""

    def __init__(self, streamer: AsyncAudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.active_samples:
            raise StopAsyncIteration()

        batch_chunks = {}
        samples_to_remove = set()

        # Create tasks for all active samples
        tasks = {
            idx: asyncio.create_task(self._get_chunk(idx))
            for idx in self.active_samples
        }

        # Wait for at least one chunk to be ready
        done, pending = await asyncio.wait(
            tasks.values(),
            return_when=asyncio.FIRST_COMPLETED,
            timeout=self.streamer.timeout
        )

        # Cancel pending tasks
        for task in pending:
            task.cancel()

        # Process completed tasks
        for idx, task in tasks.items():
            if task in done:
                try:
                    value = await task
                    if value == self.streamer.stop_signal:
                        samples_to_remove.add(idx)
                    else:
                        batch_chunks[idx] = value
                except asyncio.CancelledError:
                    pass

        self.active_samples -= samples_to_remove

        if batch_chunks:
            return batch_chunks
        elif self.active_samples:
            # Try again if we still have active samples
            return await self.__anext__()
        else:
            raise StopAsyncIteration()

    async def _get_chunk(self, idx):
        """Helper to get a chunk from a specific queue."""
        return await self.streamer.audio_queues[idx].get()
src/vibevoice/processor/__init__.py
DELETED
File without changes
src/vibevoice/processor/vibevoice_processor.py
DELETED
@@ -1,701 +0,0 @@
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re

import numpy as np
import torch

from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType, logging
from .vibevoice_tokenizer_processor import AudioNormalizer

logger = logging.get_logger(__name__)


class VibeVoiceProcessor:
    r"""
    Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor.

    [`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`].
    See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information.

    Args:
        tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
            The tokenizer for text processing.
        audio_processor (`VibeVoiceTokenizerProcessor`):
            The audio processor for speech processing.
        speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
            The compression ratio for speech tokenization.
        db_normalize (`bool`, *optional*, defaults to True):
            Whether to apply decibel normalization to audio inputs.
    """

    def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.db_normalize = db_normalize
        self.audio_normalizer = AudioNormalizer() if db_normalize else None
        self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:
                - a string, the *model id* of a pretrained model
                - a path to a *directory* containing processor config

        Returns:
            [`VibeVoiceProcessor`]: The processor object instantiated from pretrained model.
        """
        import os
        import json
        from transformers.utils import cached_file
        from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
        from vibevoice.modular.modular_vibevoice_text_tokenizer import (
            VibeVoiceTextTokenizer,
            VibeVoiceTextTokenizerFast
        )

        # Try to load from local path first, then from HF hub
        config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
        config = None

        if os.path.exists(config_path):
            # Local path exists
            with open(config_path, 'r') as f:
                config = json.load(f)
        else:
            # Try to load from HF hub
            try:
                config_file = cached_file(
                    pretrained_model_name_or_path,
                    "preprocessor_config.json",
                    **kwargs
                )
                with open(config_file, 'r') as f:
                    config = json.load(f)
            except Exception as e:
                logger.warning(f"Could not load preprocessor_config.json from {pretrained_model_name_or_path}: {e}")
                logger.warning("Using default configuration")
                config = {
                    "speech_tok_compress_ratio": 3200,
                    "db_normalize": True,
                }

        # Extract main processor parameters
        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        db_normalize = config.get("db_normalize", True)

        # Load tokenizer - try from model path first, then fallback to Qwen
        language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
        logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
        if 'qwen' in language_model_pretrained_name.lower():
            tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
                language_model_pretrained_name,
                **kwargs
            )
        else:
            raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen, Llama, Gemma.")

        # Load audio processor
        if "audio_processor" in config:
            # Create audio processor from config
            audio_config = config["audio_processor"]
            audio_processor = VibeVoiceTokenizerProcessor(
                sampling_rate=audio_config.get("sampling_rate", 24000),
                normalize_audio=audio_config.get("normalize_audio", True),
                target_dB_FS=audio_config.get("target_dB_FS", -25),
                eps=audio_config.get("eps", 1e-6),
            )
        else:
            # Create default audio processor
            audio_processor = VibeVoiceTokenizerProcessor()

        # Create and return the processor
        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            db_normalize=db_normalize,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save a processor to a directory, so that it can be re-loaded using the
        [`~VibeVoiceProcessor.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the processor will be saved.
        """
        import os
        import json

        os.makedirs(save_directory, exist_ok=True)

        # Save processor configuration
        processor_config = {
            "processor_class": "VibeVoiceProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "db_normalize": self.db_normalize,
            "audio_processor": {
                "feature_extractor_type": "VibeVoiceTokenizerProcessor",
                "sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
                "normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
                "target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
                "eps": getattr(self.audio_processor, 'eps', 1e-6),
            }
        }

        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, 'w') as f:
            json.dump(processor_config, f, indent=2)

        logger.info(f"Processor configuration saved in {config_path}")

    def __call__(
        self,
        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Main method to process one or more podcast scripts with optional voice samples.

        Args:
            text (`str`, `List[str]`):
                The input text(s) to process. Can be:
                - A single script string
                - A list of script strings for batch processing
                - A path to a .json or .txt file
                - A list of paths
            voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):
                Voice samples for each script. Can be:
                - A list of samples for a single script
                - A list of lists for batch processing
            padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
                Whether to pad sequences to the same length
            truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
                Whether to truncate sequences
            max_length (`int`, *optional*):
                Maximum length of the returned sequences
            return_tensors (`str` or `TensorType`, *optional*):
                If set, will return tensors of a particular framework
            return_attention_mask (`bool`, defaults to `True`):
                Whether to return the attention mask

        Returns:
            `BatchEncoding`: A BatchEncoding with the following fields:
                - **input_ids** -- List of token id sequences or tensor
                - **attention_mask** -- List of attention masks or tensor
                - **speech_tensors** -- Padded speech inputs (if voice_samples provided)
                - **speech_masks** -- Speech masks (if voice_samples provided)
                - **speech_input_mask** -- Boolean masks indicating speech token positions
        """
        # Handle single vs batch input
        if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
            # Single input
            texts = [text]
            is_batched = False
        else:
            # Batch input
            texts = text
            is_batched = True

        # Handle voice samples
        if voice_samples is not None:
            if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
                # Single set of voice samples
                voice_samples_list = [voice_samples]
            else:
                # Batch of voice samples
                voice_samples_list = voice_samples
        else:
            voice_samples_list = [None] * len(texts)

        # Process each input
        all_encodings = []
        for text_input, voice_input in zip(texts, voice_samples_list):
            encoding = self._process_single(text_input, voice_input)
            all_encodings.append(encoding)

        # Combine batch
        batch_encoding = self._batch_encode(
            all_encodings,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=return_attention_mask,
        )

        return batch_encoding

    def _process_single(
        self,
        text: Union[str, TextInput],
        voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
    ) -> Dict[str, Any]:
        """Process a single podcast script."""
        # Determine if text is a file path or direct script
        script = None
        if isinstance(text, str):
            # Check if it's a file path
            if text.endswith('.json') and os.path.exists(text):
                script = self._convert_json_to_script(text)
            elif text.endswith('.txt') and os.path.exists(text):
                script = self._convert_text_to_script(text)
            else:
                # Assume it's the script content directly
                script = text

        if script is None:
            raise ValueError(f"Could not process input text: {text}")

        # Parse the script
        parsed_lines = self._parse_script(script)
        all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))

        # Create system prompt
        # system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
        system_tokens = self.tokenizer.encode(self.system_prompt)

        # Process voice samples if provided
        if voice_samples:
            voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
        else:
            voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []

        # Build full token sequence
        full_tokens = system_tokens + voice_tokens
        speech_input_mask = [False] * len(system_tokens) + voice_speech_masks

        # Add text input section
        full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
        speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))

        for speaker_id, speaker_text in parsed_lines:
            speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
            full_tokens += speaker_text_tokens
            speech_input_mask += [False] * len(speaker_text_tokens)

        # Add speech output section
        full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
        speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)

        return {
            "input_ids": full_tokens,
            "speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
            "speech_input_mask": speech_input_mask,
            "parsed_script": parsed_lines,
            "all_speakers": all_speakers,
        }

    def _batch_encode(
        self,
        encodings: List[Dict[str, Any]],
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
    ) -> BatchEncoding:
        """Combine multiple encodings into a batch with padding."""
        # Extract input_ids and create attention_mask
        input_ids_list = [enc["input_ids"] for enc in encodings]
        speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]

        # Determine padding strategy
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
        elif isinstance(padding, str):
            padding_strategy = PaddingStrategy(padding)
        else:
            padding_strategy = padding

        # Apply padding to input_ids
        if padding_strategy != PaddingStrategy.DO_NOT_PAD:
            if padding_strategy == PaddingStrategy.LONGEST:
                max_len = max(len(ids) for ids in input_ids_list)
            elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in input_ids_list)

            # Pad sequences
            padded_input_ids = []
            attention_masks = []
            padded_speech_input_masks = []

            for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
                # Truncate if needed
                if truncation and len(input_ids) > max_len:
                    input_ids = input_ids[:max_len]
                    speech_mask = speech_mask[:max_len]

                # Pad
                padding_length = max_len - len(input_ids)
                # padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids
                padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids
                attention_mask = [0] * padding_length + [1] * len(input_ids)
                padded_speech_mask = [False] * padding_length + speech_mask

                padded_input_ids.append(padded_ids)
                attention_masks.append(attention_mask)
                padded_speech_input_masks.append(padded_speech_mask)

            input_ids_list = padded_input_ids
            speech_input_masks_list = padded_speech_input_masks
        else:
            # No padding, just create attention masks
            attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None

        # Process speech inputs
        all_speech_inputs = []
        has_speech = False
        for enc in encodings:
            if enc["speech_inputs"] is not None:
                all_speech_inputs.extend(enc["speech_inputs"])
                has_speech = True

        # Prepare batch encoding
        batch_encoding = BatchEncoding()

        # Handle tensor conversion
        if return_tensors is not None:
            batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
            batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
        else:
            batch_encoding["input_ids"] = input_ids_list
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = attention_masks
            batch_encoding["speech_input_mask"] = speech_input_masks_list

        # Process speech tensors if present
        if has_speech:
            speech_dict = self.prepare_speech_inputs(
                all_speech_inputs,
                return_tensors=return_tensors,
            )
            batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
            batch_encoding["speech_masks"] = speech_dict["speech_masks"]
        else:
            batch_encoding["speech_tensors"] = None
            batch_encoding["speech_masks"] = None

        # Add metadata
        batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
        batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]

        return batch_encoding

    def _create_voice_prompt(
        self,
        speaker_samples: List[Union[str, np.ndarray]]
    ) -> Tuple[List[int], List[np.ndarray], List[bool]]:
        """
        Create voice prompt tokens and process audio samples.

        Returns:
            tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
        """
        vae_token_id = self.tokenizer.speech_diffusion_id

        voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
        voice_speech_inputs = []
        voice_speech_masks = [False] * len(voice_full_tokens)

        for speaker_id, speaker_audio in enumerate(speaker_samples):
            prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)

            # Process audio
            if isinstance(speaker_audio, str):
                # Load audio from file
                wav = self.audio_processor._load_audio_from_path(speaker_audio)
            else:
                wav = np.array(speaker_audio, dtype=np.float32)

            # Apply normalization if needed
            if self.db_normalize and self.audio_normalizer:
                wav = self.audio_normalizer(wav)

            # Calculate token length based on compression ratio
            # if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
            #     vae_tok_len = wav.shape[0]
            # else:
            vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)

            # Build tokens and masks
            speaker_tokens = (prefix_tokens +
                              [self.tokenizer.speech_start_id] +
                              [vae_token_id] * vae_tok_len +
                              [self.tokenizer.speech_end_id] +
                              self.tokenizer.encode('\n', add_special_tokens=False))

            vae_input_mask = ([False] * len(prefix_tokens) +
                              [False] +
                              [True] * vae_tok_len +
                              [False] +
                              [False])

            voice_full_tokens.extend(speaker_tokens)
            voice_speech_masks.extend(vae_input_mask)
            voice_speech_inputs.append(wav)

        return voice_full_tokens, voice_speech_inputs, voice_speech_masks

    def prepare_speech_inputs(
        self,
        speech_inputs: List[np.ndarray],
        return_tensors: Optional[Union[str, TensorType]] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """
        Prepare speech inputs for model consumption.

        Args:
            speech_inputs: List of speech arrays
            return_tensors: Output tensor type
            device: Device to place tensors on
            dtype: Data type for tensors

        Returns:
            Dictionary with padded_speeches and speech_masks
        """
        if not speech_inputs:
            return {"padded_speeches": None, "speech_masks": None}

        # Calculate sequence lengths
        vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
        # vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
        max_speech_length = max(s.shape[0] for s in speech_inputs)

        # Pad speeches
        if speech_inputs[0].ndim == 1:
            padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
        else:
            padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
        speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)

        for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
            padded_speeches[i, :len(speech)] = speech
            speech_masks[i, :vae_tok_length] = True

        result = {
            "padded_speeches": padded_speeches,
            "speech_masks": speech_masks,
        }

        # Convert to tensors if requested
        if return_tensors == "pt":
            result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
            result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)

        return result

    def _convert_json_to_script(self, json_file: str) -> str:
        """
        Convert JSON format to script format.
        Expected JSON format:
        [
            {"speaker": "1", "text": "Hello everyone..."},
            {"speaker": "2", "text": "Great to be here..."}
        ]
        """
        import json

        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if not isinstance(data, list):
            raise ValueError("JSON file must contain a list of speaker entries")

        script_lines = []
        for item in data:
            if not isinstance(item, dict):
                logger.warning(f"Skipping non-dict entry: {item}")
                continue

            speaker = item.get('speaker')
            text = item.get('text')

            if speaker is None or text is None:
                logger.warning(f"Skipping entry missing speaker or text: {item}")
                continue

            # Ensure speaker ID is valid
            try:
                speaker_id = int(speaker)
            except (ValueError, TypeError):
                logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
                continue

            # Clean up text
            text = text.strip()
            if text:
                script_lines.append(f"Speaker {speaker_id}: {text}")

        if not script_lines:
            raise ValueError("No valid entries found in JSON file")

        return "\n".join(script_lines)

    def _convert_text_to_script(self, text_file: str) -> str:
        """
        Convert text file to script format.
        Handles multiple formats:
| 562 |
-
1. Already formatted as "Speaker X: text"
|
| 563 |
-
2. Plain text (assigns to Speaker 1)
|
| 564 |
-
|
| 565 |
-
Handles edge cases like multiple colons in a line.
|
| 566 |
-
"""
|
| 567 |
-
with open(text_file, 'r', encoding='utf-8') as f:
|
| 568 |
-
lines = f.readlines()
|
| 569 |
-
|
| 570 |
-
script_lines = []
|
| 571 |
-
current_speaker = 1
|
| 572 |
-
|
| 573 |
-
for line in lines:
|
| 574 |
-
line = line.strip()
|
| 575 |
-
if not line:
|
| 576 |
-
continue
|
| 577 |
-
|
| 578 |
-
# Try to parse as "Speaker X: text" format
|
| 579 |
-
# Use regex to be more robust
|
| 580 |
-
speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
|
| 581 |
-
|
| 582 |
-
if speaker_match:
|
| 583 |
-
speaker_id = int(speaker_match.group(1))
|
| 584 |
-
text = speaker_match.group(2).strip()
|
| 585 |
-
if text:
|
| 586 |
-
script_lines.append(f"Speaker {speaker_id}: {text}")
|
| 587 |
-
else:
|
| 588 |
-
# Treat as plain text - assign to current speaker
|
| 589 |
-
script_lines.append(f"Speaker {current_speaker}: {line}")
|
| 590 |
-
|
| 591 |
-
if not script_lines:
|
| 592 |
-
raise ValueError("No valid content found in text file")
|
| 593 |
-
|
| 594 |
-
return "\n".join(script_lines)
|
| 595 |
-
|
| 596 |
-
def _parse_script(self, script: str) -> List[Tuple[int, str]]:
|
| 597 |
-
"""Parse script into list of (speaker_id, text) tuples."""
|
| 598 |
-
lines = script.strip().split("\n")
|
| 599 |
-
parsed_lines = []
|
| 600 |
-
speaker_ids = []
|
| 601 |
-
|
| 602 |
-
# First pass: parse all lines and collect speaker IDs
|
| 603 |
-
for line in lines:
|
| 604 |
-
if not line.strip():
|
| 605 |
-
continue
|
| 606 |
-
|
| 607 |
-
# Use regex to handle edge cases like multiple colons
|
| 608 |
-
match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
|
| 609 |
-
|
| 610 |
-
if match:
|
| 611 |
-
speaker_id = int(match.group(1))
|
| 612 |
-
text = ' ' + match.group(2).strip()
|
| 613 |
-
parsed_lines.append((speaker_id, text))
|
| 614 |
-
speaker_ids.append(speaker_id)
|
| 615 |
-
else:
|
| 616 |
-
logger.warning(f"Could not parse line: '{line}'")
|
| 617 |
-
|
| 618 |
-
if not parsed_lines:
|
| 619 |
-
if script.strip():
|
| 620 |
-
# Treat the entire script as a single line with default speaker
|
| 621 |
-
parsed_lines.append({'speaker': 'Narrator', 'text': script.strip()})
|
| 622 |
-
else:
|
| 623 |
-
if script.strip():
|
| 624 |
-
# Treat the entire script as a single line with default speaker
|
| 625 |
-
parsed_lines.append({'speaker': 'Narrator', 'text': script.strip()})
|
| 626 |
-
return parsed_lines
|
| 627 |
-
else:
|
| 628 |
-
raise ValueError("No valid speaker lines found in script")
|
| 629 |
-
|
| 630 |
-
# Check if we need to normalize speaker IDs (only if all are > 0)
|
| 631 |
-
min_speaker_id = min(speaker_ids)
|
| 632 |
-
if min_speaker_id > 0:
|
| 633 |
-
# Normalize to start from 0
|
| 634 |
-
normalized_lines = []
|
| 635 |
-
for speaker_id, text in parsed_lines:
|
| 636 |
-
normalized_lines.append((speaker_id - 1, text))
|
| 637 |
-
return normalized_lines
|
| 638 |
-
else:
|
| 639 |
-
# Keep original IDs
|
| 640 |
-
return parsed_lines
|
| 641 |
-
|
| 642 |
-
def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
|
| 643 |
-
"""Merge text and audio inputs into a single BatchEncoding."""
|
| 644 |
-
# Start with text inputs
|
| 645 |
-
merged = BatchEncoding(text_inputs)
|
| 646 |
-
|
| 647 |
-
# Add audio-specific fields
|
| 648 |
-
if "audio" in audio_inputs:
|
| 649 |
-
merged["speech_inputs"] = audio_inputs["audio"]
|
| 650 |
-
if "streaming" in audio_inputs:
|
| 651 |
-
merged["streaming"] = audio_inputs["streaming"]
|
| 652 |
-
|
| 653 |
-
return merged
|
| 654 |
-
|
| 655 |
-
def batch_decode(self, *args, **kwargs):
|
| 656 |
-
"""
|
| 657 |
-
This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].
|
| 658 |
-
Please refer to the docstring of this method for more information.
|
| 659 |
-
"""
|
| 660 |
-
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 661 |
-
|
| 662 |
-
def decode(self, *args, **kwargs):
|
| 663 |
-
"""
|
| 664 |
-
This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].
|
| 665 |
-
Please refer to the docstring of this method for more information.
|
| 666 |
-
"""
|
| 667 |
-
return self.tokenizer.decode(*args, **kwargs)
|
| 668 |
-
|
| 669 |
-
@property
|
| 670 |
-
def model_input_names(self):
|
| 671 |
-
"""
|
| 672 |
-
Return the list of inputs accepted by the model.
|
| 673 |
-
"""
|
| 674 |
-
tokenizer_input_names = self.tokenizer.model_input_names
|
| 675 |
-
audio_processor_input_names = self.audio_processor.model_input_names
|
| 676 |
-
return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))
|
| 677 |
-
|
| 678 |
-
def save_audio(self,
|
| 679 |
-
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
|
| 680 |
-
output_path: str = "output.wav",
|
| 681 |
-
sampling_rate: Optional[int] = None,
|
| 682 |
-
normalize: bool = False,
|
| 683 |
-
batch_prefix: str = "audio_",
|
| 684 |
-
) -> str:
|
| 685 |
-
"""
|
| 686 |
-
Save audio data to a file.
|
| 687 |
-
Args:
|
| 688 |
-
audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
|
| 689 |
-
The audio data to save. Can be a single tensor/array or a list of them.
|
| 690 |
-
output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
|
| 691 |
-
sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
|
| 692 |
-
normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
|
| 693 |
-
batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".
|
| 694 |
-
Returns:
|
| 695 |
-
str: The path to the saved audio file.
|
| 696 |
-
"""
|
| 697 |
-
return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix)
|
| 698 |
-
|
| 699 |
-
__all__ = [
|
| 700 |
-
"VibeVoiceProcessor",
|
| 701 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/vibevoice/processor/vibevoice_tokenizer_processor.py
DELETED
|
@@ -1,483 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Processor class for VibeVoice models.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import os
|
| 6 |
-
import json
|
| 7 |
-
import warnings
|
| 8 |
-
from typing import List, Optional, Union, Dict, Any
|
| 9 |
-
|
| 10 |
-
import numpy as np
|
| 11 |
-
import torch
|
| 12 |
-
|
| 13 |
-
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
| 14 |
-
from transformers.utils import logging
|
| 15 |
-
|
| 16 |
-
logger = logging.get_logger(__name__)
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class AudioNormalizer:
|
| 20 |
-
"""
|
| 21 |
-
Audio normalization class for VibeVoice tokenizer.
|
| 22 |
-
|
| 23 |
-
This class provides audio normalization to ensure consistent input levels
|
| 24 |
-
for the VibeVoice tokenizer while maintaining audio quality.
|
| 25 |
-
"""
|
| 26 |
-
|
| 27 |
-
def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
|
| 28 |
-
"""
|
| 29 |
-
Initialize the audio normalizer.
|
| 30 |
-
|
| 31 |
-
Args:
|
| 32 |
-
target_dB_FS (float): Target dB FS level for the audio. Default: -25
|
| 33 |
-
eps (float): Small value to avoid division by zero. Default: 1e-6
|
| 34 |
-
"""
|
| 35 |
-
self.target_dB_FS = target_dB_FS
|
| 36 |
-
self.eps = eps
|
| 37 |
-
|
| 38 |
-
def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
|
| 39 |
-
"""
|
| 40 |
-
Adjust the audio to the target dB FS level.
|
| 41 |
-
|
| 42 |
-
Args:
|
| 43 |
-
audio (np.ndarray): Input audio signal
|
| 44 |
-
|
| 45 |
-
Returns:
|
| 46 |
-
tuple: (normalized_audio, rms, scalar)
|
| 47 |
-
"""
|
| 48 |
-
rms = np.sqrt(np.mean(audio**2))
|
| 49 |
-
scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
|
| 50 |
-
normalized_audio = audio * scalar
|
| 51 |
-
return normalized_audio, rms, scalar
|
| 52 |
-
|
| 53 |
-
def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
|
| 54 |
-
"""
|
| 55 |
-
Avoid clipping by scaling down if necessary.
|
| 56 |
-
|
| 57 |
-
Args:
|
| 58 |
-
audio (np.ndarray): Input audio signal
|
| 59 |
-
scalar (float, optional): Explicit scaling factor
|
| 60 |
-
|
| 61 |
-
Returns:
|
| 62 |
-
tuple: (normalized_audio, scalar)
|
| 63 |
-
"""
|
| 64 |
-
if scalar is None:
|
| 65 |
-
max_val = np.max(np.abs(audio))
|
| 66 |
-
if max_val > 1.0:
|
| 67 |
-
scalar = max_val + self.eps
|
| 68 |
-
else:
|
| 69 |
-
scalar = 1.0
|
| 70 |
-
|
| 71 |
-
return audio / scalar, scalar
|
| 72 |
-
|
| 73 |
-
def __call__(self, audio: np.ndarray) -> np.ndarray:
|
| 74 |
-
"""
|
| 75 |
-
Normalize the audio by adjusting to target dB FS and avoiding clipping.
|
| 76 |
-
|
| 77 |
-
Args:
|
| 78 |
-
audio (np.ndarray): Input audio signal
|
| 79 |
-
|
| 80 |
-
Returns:
|
| 81 |
-
np.ndarray: Normalized audio signal
|
| 82 |
-
"""
|
| 83 |
-
# First adjust to target dB FS
|
| 84 |
-
audio, _, _ = self.tailor_dB_FS(audio)
|
| 85 |
-
# Then avoid clipping
|
| 86 |
-
audio, _ = self.avoid_clipping(audio)
|
| 87 |
-
return audio
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
# Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
|
| 91 |
-
class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
|
| 92 |
-
"""
|
| 93 |
-
Processor for VibeVoice acoustic tokenizer models.
|
| 94 |
-
|
| 95 |
-
This processor handles audio preprocessing for VibeVoice models, including:
|
| 96 |
-
- Audio format conversion (stereo to mono)
|
| 97 |
-
- Optional audio normalization
|
| 98 |
-
- Streaming support for infinite-length audio
|
| 99 |
-
|
| 100 |
-
Args:
|
| 101 |
-
sampling_rate (int, optional): Expected sampling rate. Defaults to 24000.
|
| 102 |
-
normalize_audio (bool, optional): Whether to normalize audio. Defaults to True.
|
| 103 |
-
target_dB_FS (float, optional): Target dB FS for normalization. Defaults to -25.
|
| 104 |
-
eps (float, optional): Small value for numerical stability. Defaults to 1e-6.
|
| 105 |
-
"""
|
| 106 |
-
model_input_names = ["input_features"]
|
| 107 |
-
|
| 108 |
-
def __init__(
|
| 109 |
-
self,
|
| 110 |
-
sampling_rate: int = 24000,
|
| 111 |
-
normalize_audio: bool = True,
|
| 112 |
-
target_dB_FS: float = -25,
|
| 113 |
-
eps: float = 1e-6,
|
| 114 |
-
**kwargs,
|
| 115 |
-
):
|
| 116 |
-
super().__init__(**kwargs)
|
| 117 |
-
|
| 118 |
-
self.sampling_rate = sampling_rate
|
| 119 |
-
self.normalize_audio = normalize_audio
|
| 120 |
-
|
| 121 |
-
# Initialize audio normalizer if needed
|
| 122 |
-
if self.normalize_audio:
|
| 123 |
-
self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps)
|
| 124 |
-
else:
|
| 125 |
-
self.normalizer = None
|
| 126 |
-
|
| 127 |
-
# Save config
|
| 128 |
-
self.feature_extractor_dict = {
|
| 129 |
-
"sampling_rate": sampling_rate,
|
| 130 |
-
"normalize_audio": normalize_audio,
|
| 131 |
-
"target_dB_FS": target_dB_FS,
|
| 132 |
-
"eps": eps,
|
| 133 |
-
}
|
| 134 |
-
|
| 135 |
-
def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
|
| 136 |
-
"""
|
| 137 |
-
Convert stereo audio to mono if needed.
|
| 138 |
-
|
| 139 |
-
Args:
|
| 140 |
-
audio (np.ndarray): Input audio array
|
| 141 |
-
|
| 142 |
-
Returns:
|
| 143 |
-
np.ndarray: Mono audio array
|
| 144 |
-
"""
|
| 145 |
-
if len(audio.shape) == 1:
|
| 146 |
-
return audio
|
| 147 |
-
elif len(audio.shape) == 2:
|
| 148 |
-
if audio.shape[0] == 2: # (2, time)
|
| 149 |
-
return np.mean(audio, axis=0)
|
| 150 |
-
elif audio.shape[1] == 2: # (time, 2)
|
| 151 |
-
return np.mean(audio, axis=1)
|
| 152 |
-
else:
|
| 153 |
-
# If one dimension is 1, squeeze it
|
| 154 |
-
if audio.shape[0] == 1:
|
| 155 |
-
return audio.squeeze(0)
|
| 156 |
-
elif audio.shape[1] == 1:
|
| 157 |
-
return audio.squeeze(1)
|
| 158 |
-
else:
|
| 159 |
-
raise ValueError(f"Unexpected audio shape: {audio.shape}")
|
| 160 |
-
else:
|
| 161 |
-
raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}")
|
| 162 |
-
|
| 163 |
-
def _process_single_audio(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray:
|
| 164 |
-
"""
|
| 165 |
-
Process a single audio array.
|
| 166 |
-
|
| 167 |
-
Args:
|
| 168 |
-
audio: Single audio input
|
| 169 |
-
|
| 170 |
-
Returns:
|
| 171 |
-
np.ndarray: Processed audio
|
| 172 |
-
"""
|
| 173 |
-
# Convert to numpy array
|
| 174 |
-
if not isinstance(audio, np.ndarray):
|
| 175 |
-
audio = np.array(audio, dtype=np.float32)
|
| 176 |
-
else:
|
| 177 |
-
audio = audio.astype(np.float32)
|
| 178 |
-
|
| 179 |
-
# Ensure mono
|
| 180 |
-
audio = self._ensure_mono(audio)
|
| 181 |
-
|
| 182 |
-
# Normalize if requested
|
| 183 |
-
if self.normalize_audio and self.normalizer is not None:
|
| 184 |
-
audio = self.normalizer(audio)
|
| 185 |
-
|
| 186 |
-
return audio
|
| 187 |
-
|
| 188 |
-
def __call__(
|
| 189 |
-
self,
|
| 190 |
-
audio: Union[str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str]] = None,
|
| 191 |
-
sampling_rate: Optional[int] = None,
|
| 192 |
-
return_tensors: Optional[str] = None,
|
| 193 |
-
**kwargs,
|
| 194 |
-
):
|
| 195 |
-
"""
|
| 196 |
-
Process audio for VibeVoice models.
|
| 197 |
-
|
| 198 |
-
Args:
|
| 199 |
-
audio: Audio input(s) to process. Can be:
|
| 200 |
-
- str: Path to audio file
|
| 201 |
-
- np.ndarray: Audio array
|
| 202 |
-
- List[float]: Audio as list of floats
|
| 203 |
-
- List[np.ndarray]: Batch of audio arrays
|
| 204 |
-
- List[str]: Batch of audio file paths
|
| 205 |
-
sampling_rate (int, optional): Sampling rate of the input audio
|
| 206 |
-
return_tensors (str, optional): Return format ('pt' for PyTorch, 'np' for NumPy)
|
| 207 |
-
|
| 208 |
-
Returns:
|
| 209 |
-
dict: Processed audio inputs with keys:
|
| 210 |
-
- input_features: Audio tensor(s) ready for the model
|
| 211 |
-
"""
|
| 212 |
-
if audio is None:
|
| 213 |
-
raise ValueError("Audio input is required")
|
| 214 |
-
|
| 215 |
-
# Validate sampling rate
|
| 216 |
-
if sampling_rate is not None and sampling_rate != self.sampling_rate:
|
| 217 |
-
logger.warning(
|
| 218 |
-
f"Input sampling rate ({sampling_rate}) differs from expected "
|
| 219 |
-
f"sampling rate ({self.sampling_rate}). Please resample your audio."
|
| 220 |
-
)
|
| 221 |
-
|
| 222 |
-
# Handle different input types
|
| 223 |
-
if isinstance(audio, str):
|
| 224 |
-
# Single audio file path
|
| 225 |
-
audio = self._load_audio_from_path(audio)
|
| 226 |
-
is_batched = False
|
| 227 |
-
elif isinstance(audio, list):
|
| 228 |
-
if len(audio) == 0:
|
| 229 |
-
raise ValueError("Empty audio list provided")
|
| 230 |
-
|
| 231 |
-
# Check if it's a list of file paths
|
| 232 |
-
if all(isinstance(item, str) for item in audio):
|
| 233 |
-
# Batch of audio file paths
|
| 234 |
-
audio = [self._load_audio_from_path(path) for path in audio]
|
| 235 |
-
is_batched = True
|
| 236 |
-
else:
|
| 237 |
-
# Check if it's batched audio arrays
|
| 238 |
-
is_batched = isinstance(audio[0], (np.ndarray, list))
|
| 239 |
-
else:
|
| 240 |
-
# Single audio array or list
|
| 241 |
-
is_batched = False
|
| 242 |
-
|
| 243 |
-
# Process audio
|
| 244 |
-
if is_batched:
|
| 245 |
-
processed_audio = [self._process_single_audio(a) for a in audio]
|
| 246 |
-
else:
|
| 247 |
-
processed_audio = [self._process_single_audio(audio)]
|
| 248 |
-
|
| 249 |
-
# Convert to tensors if requested
|
| 250 |
-
if return_tensors == "pt":
|
| 251 |
-
if len(processed_audio) == 1:
|
| 252 |
-
# Create a proper batch dimension (B, T)
|
| 253 |
-
input_features = torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1)
|
| 254 |
-
else:
|
| 255 |
-
# For batched input with different lengths, create a batch properly
|
| 256 |
-
input_features = torch.stack([torch.from_numpy(a) for a in processed_audio]).unsqueeze(1)
|
| 257 |
-
elif return_tensors == "np":
|
| 258 |
-
if len(processed_audio) == 1:
|
| 259 |
-
input_features = processed_audio[0][np.newaxis, np.newaxis, :]
|
| 260 |
-
else:
|
| 261 |
-
input_features = np.stack(processed_audio)[:, np.newaxis, :]
|
| 262 |
-
else:
|
| 263 |
-
input_features = processed_audio[0] if len(processed_audio) == 1 else processed_audio
|
| 264 |
-
|
| 265 |
-
outputs = {
|
| 266 |
-
"audio": input_features, # Use "audio" instead of "input_features"
|
| 267 |
-
}
|
| 268 |
-
|
| 269 |
-
return outputs
|
| 270 |
-
|
| 271 |
-
def _load_audio_from_path(self, audio_path: str) -> np.ndarray:
|
| 272 |
-
"""
|
| 273 |
-
Load audio from file path.
|
| 274 |
-
|
| 275 |
-
Args:
|
| 276 |
-
audio_path (str): Path to audio file
|
| 277 |
-
|
| 278 |
-
Returns:
|
| 279 |
-
np.ndarray: Loaded audio array
|
| 280 |
-
"""
|
| 281 |
-
# Get file extension to determine loading method
|
| 282 |
-
file_ext = os.path.splitext(audio_path)[1].lower()
|
| 283 |
-
|
| 284 |
-
if file_ext in ['.wav', '.mp3', '.flac', '.m4a', '.ogg']:
|
| 285 |
-
# Audio file - use librosa
|
| 286 |
-
import librosa
|
| 287 |
-
audio_array, sr = librosa.load(
|
| 288 |
-
audio_path,
|
| 289 |
-
sr=self.sampling_rate,
|
| 290 |
-
mono=True
|
| 291 |
-
)
|
| 292 |
-
return audio_array
|
| 293 |
-
elif file_ext == '.pt':
|
| 294 |
-
# PyTorch tensor file
|
| 295 |
-
audio_tensor = torch.load(audio_path, map_location='cpu').squeeze()
|
| 296 |
-
if isinstance(audio_tensor, torch.Tensor):
|
| 297 |
-
audio_array = audio_tensor.numpy()
|
| 298 |
-
else:
|
| 299 |
-
audio_array = np.array(audio_tensor)
|
| 300 |
-
return audio_array.astype(np.float32)
|
| 301 |
-
elif file_ext == '.npy':
|
| 302 |
-
# NumPy file
|
| 303 |
-
audio_array = np.load(audio_path)
|
| 304 |
-
return audio_array.astype(np.float32)
|
| 305 |
-
else:
|
| 306 |
-
raise ValueError(
|
| 307 |
-
f"Unsupported file format: {file_ext}. "
|
| 308 |
-
f"Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy, .npz"
|
| 309 |
-
)
|
| 310 |
-
|
| 311 |
-
def preprocess_audio(
|
| 312 |
-
self,
|
| 313 |
-
audio_path_or_array: Union[str, np.ndarray],
|
| 314 |
-
normalize: Optional[bool] = None,
|
| 315 |
-
) -> np.ndarray:
|
| 316 |
-
"""
|
| 317 |
-
Convenience method to preprocess audio from file path or array.
|
| 318 |
-
This method is kept for backward compatibility but __call__ is recommended.
|
| 319 |
-
|
| 320 |
-
Args:
|
| 321 |
-
audio_path_or_array: Path to audio file or numpy array
|
| 322 |
-
normalize: Whether to normalize (overrides default setting)
|
| 323 |
-
|
| 324 |
-
Returns:
|
| 325 |
-
np.ndarray: Preprocessed audio array
|
| 326 |
-
"""
|
| 327 |
-
if isinstance(audio_path_or_array, str):
|
| 328 |
-
audio_array = self._load_audio_from_path(audio_path_or_array)
|
| 329 |
-
else:
|
| 330 |
-
audio_array = np.array(audio_path_or_array, dtype=np.float32)
|
| 331 |
-
|
| 332 |
-
# Override normalization setting if specified
|
| 333 |
-
original_normalize = self.normalize_audio
|
| 334 |
-
if normalize is not None:
|
| 335 |
-
self.normalize_audio = normalize
|
| 336 |
-
|
| 337 |
-
try:
|
| 338 |
-
processed = self._process_single_audio(audio_array)
|
| 339 |
-
finally:
|
| 340 |
-
# Restore original setting
|
| 341 |
-
self.normalize_audio = original_normalize
|
| 342 |
-
|
| 343 |
-
return processed
|
| 344 |
-
|
| 345 |
-
# Override to_dict method for configuration saving
|
| 346 |
-
def to_dict(self) -> Dict[str, Any]:
|
| 347 |
-
"""
|
| 348 |
-
Convert the object to a dict containing all attributes needed for serialization.
|
| 349 |
-
"""
|
| 350 |
-
return self.feature_extractor_dict
|
| 351 |
-
|
| 352 |
-
def save_audio(
|
| 353 |
-
self,
|
| 354 |
-
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
|
| 355 |
-
output_path: str = "output.wav",
|
| 356 |
-
sampling_rate: Optional[int] = None,
|
| 357 |
-
normalize: bool = False,
|
| 358 |
-
batch_prefix: str = "audio_",
|
| 359 |
-
):
|
| 360 |
-
"""
|
| 361 |
-
Save audio data to WAV file(s).
|
| 362 |
-
|
| 363 |
-
Args:
|
| 364 |
-
audio: Audio data to save. Can be:
|
| 365 |
-
- torch.Tensor: PyTorch tensor with shape (B, C, T) or (B, T) or (T)
|
| 366 |
-
- np.ndarray: NumPy array with shape (B, C, T) or (B, T) or (T)
|
| 367 |
-
- List of tensors or arrays
|
| 368 |
-
output_path: Path where to save the audio. If saving multiple files,
|
| 369 |
-
this is treated as a directory and individual files will be saved inside.
|
| 370 |
-
sampling_rate: Sampling rate for the saved audio. Defaults to the processor's rate.
|
| 371 |
-
normalize: Whether to normalize audio before saving.
|
| 372 |
-
batch_prefix: Prefix for batch files when saving multiple audios.
|
| 373 |
-
|
| 374 |
-
Returns:
|
| 375 |
-
List[str]: Paths to the saved audio files.
|
| 376 |
-
"""
|
| 377 |
-
if sampling_rate is None:
|
| 378 |
-
sampling_rate = self.sampling_rate
|
| 379 |
-
|
| 380 |
-
try:
|
| 381 |
-
import soundfile as sf
|
| 382 |
-
except ImportError:
|
| 383 |
-
raise ImportError(
|
| 384 |
-
"soundfile is required to save audio files. "
|
| 385 |
-
"Install it with: pip install soundfile"
|
| 386 |
-
)
|
| 387 |
-
|
| 388 |
-
# Ensure audio is in the right format
|
| 389 |
-
if isinstance(audio, torch.Tensor):
|
| 390 |
-
# Convert PyTorch tensor to numpy
|
| 391 |
-
audio_np = audio.float().detach().cpu().numpy()
|
| 392 |
-
elif isinstance(audio, np.ndarray):
|
| 393 |
-
audio_np = audio
|
| 394 |
-
elif isinstance(audio, list):
|
| 395 |
-
# Handle list of tensors or arrays
|
| 396 |
-
if all(isinstance(a, torch.Tensor) for a in audio):
|
| 397 |
-
audio_np = [a.float().detach().cpu().numpy() for a in audio]
|
| 398 |
-
else:
|
| 399 |
-
audio_np = audio
|
| 400 |
-
else:
|
| 401 |
-
raise ValueError(f"Unsupported audio type: {type(audio)}")
|
| 402 |
-
|
| 403 |
-
saved_paths = []
|
| 404 |
-
|
| 405 |
-
# Handle based on shape or type
|
| 406 |
-
if isinstance(audio_np, list):
|
| 407 |
-
# Multiple separate audios to save
|
| 408 |
-
output_dir = output_path
|
| 409 |
-
|
| 410 |
-
# Ensure output directory exists
|
| 411 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 412 |
-
|
| 413 |
-
# Save each audio
|
| 414 |
-
for i, audio_item in enumerate(audio_np):
|
| 415 |
-
audio_item = self._prepare_audio_for_save(audio_item, normalize)
|
| 416 |
-
file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
|
| 417 |
-
sf.write(file_path, audio_item, sampling_rate)
|
| 418 |
-
saved_paths.append(file_path)
|
| 419 |
-
|
| 420 |
-
else:
|
| 421 |
-
# Handle different dimensions
|
| 422 |
-
if len(audio_np.shape) >= 3: # (B, C, T) or similar
|
| 423 |
-
# Get batch size
|
| 424 |
-
batch_size = audio_np.shape[0]
|
| 425 |
-
|
| 426 |
-
if batch_size > 1:
|
| 427 |
-
# Multiple audios in a batch
|
| 428 |
-
output_dir = output_path
|
| 429 |
-
|
| 430 |
-
# Ensure output directory exists
|
| 431 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 432 |
-
|
| 433 |
-
# Save each audio in the batch
|
| 434 |
-
for i in range(batch_size):
|
| 435 |
-
# Extract single audio and remove channel dim if present
|
| 436 |
-
single_audio = audio_np[i]
|
| 437 |
-
if len(single_audio.shape) > 1:
|
| 438 |
-
if single_audio.shape[0] == 1: # (1, T)
|
| 439 |
-
single_audio = single_audio.squeeze(0)
|
| 440 |
-
|
| 441 |
-
single_audio = self._prepare_audio_for_save(single_audio, normalize)
|
| 442 |
-
file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
|
| 443 |
-
sf.write(file_path, single_audio, sampling_rate)
|
| 444 |
-
saved_paths.append(file_path)
|
| 445 |
-
else:
|
| 446 |
-
# Single audio with batch and channel dims
|
| 447 |
-
audio_item = audio_np.squeeze() # Remove batch and channel dimensions
|
| 448 |
-
audio_item = self._prepare_audio_for_save(audio_item, normalize)
|
| 449 |
-
sf.write(output_path, audio_item, sampling_rate)
|
| 450 |
-
saved_paths.append(output_path)
|
| 451 |
-
else:
|
| 452 |
-
# Single audio without batch dimension
|
| 453 |
-
audio_item = self._prepare_audio_for_save(audio_np, normalize)
|
| 454 |
-
sf.write(output_path, audio_item, sampling_rate)
|
| 455 |
-
saved_paths.append(output_path)
|
| 456 |
-
|
| 457 |
-
return saved_paths
|
| 458 |
-
|
| 459 |
-
def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
|
| 460 |
-
"""
|
| 461 |
-
Prepare audio for saving by ensuring it's the right shape and optionally normalizing.
|
| 462 |
-
|
| 463 |
-
Args:
|
| 464 |
-
audio: Audio data as numpy array
|
| 465 |
-
normalize: Whether to normalize audio
|
| 466 |
-
|
| 467 |
-
Returns:
|
| 468 |
-
np.ndarray: Processed audio ready for saving
|
| 469 |
-
"""
|
| 470 |
-
# Ensure right dimensionality
|
| 471 |
-
if len(audio.shape) > 1 and audio.shape[0] == 1: # (1, T)
|
| 472 |
-
audio = audio.squeeze(0)
|
| 473 |
-
|
| 474 |
-
# Normalize if requested
|
| 475 |
-
if normalize:
|
| 476 |
-
max_val = np.abs(audio).max()
|
| 477 |
-
if max_val > 0:
|
| 478 |
-
audio = audio / max_val
|
| 479 |
-
|
| 480 |
-
return audio
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
__all__ = ["VibeVoiceTokenizerProcessor", "AudioNormalizer"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/vibevoice/schedule/__init__.py
DELETED
|
File without changes
|
src/vibevoice/schedule/dpm_solver.py
DELETED
|
@@ -1,1065 +0,0 @@
|
|
| 1 |
-
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
|
| 16 |
-
|
| 17 |
-
import math
|
| 18 |
-
from typing import List, Optional, Tuple, Union
|
| 19 |
-
|
| 20 |
-
import numpy as np
|
| 21 |
-
import torch
|
| 22 |
-
|
| 23 |
-
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
-
from diffusers.utils import deprecate
|
| 25 |
-
from diffusers.utils.torch_utils import randn_tensor
|
| 26 |
-
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
| 27 |
-
|
| 28 |
-
def betas_for_alpha_bar(
|
| 29 |
-
num_diffusion_timesteps,
|
| 30 |
-
max_beta=0.999,
|
| 31 |
-
alpha_transform_type="cosine",
|
| 32 |
-
):
|
| 33 |
-
"""
|
| 34 |
-
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
| 35 |
-
(1-beta) over time from t = [0,1].
|
| 36 |
-
|
| 37 |
-
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
| 38 |
-
to that part of the diffusion process.
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
Args:
|
| 42 |
-
num_diffusion_timesteps (`int`): the number of betas to produce.
|
| 43 |
-
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
| 44 |
-
prevent singularities.
|
| 45 |
-
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
| 46 |
-
Choose from `cosine` or `exp`
|
| 47 |
-
|
| 48 |
-
Returns:
|
| 49 |
-
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
| 50 |
-
"""
|
| 51 |
-
if alpha_transform_type == "cosine":
|
| 52 |
-
|
| 53 |
-
def alpha_bar_fn(t):
|
| 54 |
-
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
| 55 |
-
# return math.cos(t * math.pi / 2 * 0.95) ** 2
|
| 56 |
-
|
| 57 |
-
elif alpha_transform_type == "exp":
|
| 58 |
-
|
| 59 |
-
def alpha_bar_fn(t):
|
| 60 |
-
return math.exp(t * -12.0)
|
| 61 |
-
|
| 62 |
-
elif alpha_transform_type == "cauchy":
|
| 63 |
-
# µ + γ tan (π (0.5 - x)) γ = 1, µ = 3
|
| 64 |
-
# alpha^2 = 1-1/(exp(λ)+1)
|
| 65 |
-
def alpha_bar_fn(t, gamma=1, mu=3):
|
| 66 |
-
snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9)
|
| 67 |
-
return 1 - 1 / (math.exp(snr) + 1.1)
|
| 68 |
-
|
| 69 |
-
elif alpha_transform_type == "laplace":
|
| 70 |
-
# µ − bsgn(0.5 − t) log(1 − 2|t − 0.5|) µ = 0, b = 1
|
| 71 |
-
def alpha_bar_fn(t, mu=0, b=1):
|
| 72 |
-
snr = mu - b * math.copysign(1, 0.5 - t) * math.log(1 - 2 * abs(t - 0.5) * 0.98)
|
| 73 |
-
return 1 - 1 / (math.exp(snr) + 1.02)
|
| 74 |
-
|
| 75 |
-
else:
|
| 76 |
-
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
| 77 |
-
|
| 78 |
-
betas = []
|
| 79 |
-
for i in range(num_diffusion_timesteps):
|
| 80 |
-
t1 = i / num_diffusion_timesteps
|
| 81 |
-
t2 = (i + 1) / num_diffusion_timesteps
|
| 82 |
-
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
| 83 |
-
return torch.tensor(betas, dtype=torch.float32)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
| 87 |
-
def rescale_zero_terminal_snr(betas):
|
| 88 |
-
"""
|
| 89 |
-
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
Args:
|
| 93 |
-
betas (`torch.Tensor`):
|
| 94 |
-
the betas that the scheduler is being initialized with.
|
| 95 |
-
|
| 96 |
-
Returns:
|
| 97 |
-
`torch.Tensor`: rescaled betas with zero terminal SNR
|
| 98 |
-
"""
|
| 99 |
-
# Convert betas to alphas_bar_sqrt
|
| 100 |
-
alphas = 1.0 - betas
|
| 101 |
-
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 102 |
-
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
| 103 |
-
|
| 104 |
-
# Store old values.
|
| 105 |
-
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
| 106 |
-
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
| 107 |
-
|
| 108 |
-
# Shift so the last timestep is zero.
|
| 109 |
-
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
| 110 |
-
|
| 111 |
-
# Scale so the first timestep is back to the old value.
|
| 112 |
-
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
| 113 |
-
|
| 114 |
-
# Convert alphas_bar_sqrt to betas
|
| 115 |
-
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
| 116 |
-
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
| 117 |
-
alphas = torch.cat([alphas_bar[0:1], alphas])
|
| 118 |
-
betas = 1 - alphas
|
| 119 |
-
|
| 120 |
-
return betas
|
| 121 |
-
|
| 122 |
-
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
| 123 |
-
"""
|
| 124 |
-
`DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
|
| 125 |
-
|
| 126 |
-
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 127 |
-
methods the library implements for all schedulers such as loading and saving.
|
| 128 |
-
|
| 129 |
-
Args:
|
| 130 |
-
num_train_timesteps (`int`, defaults to 1000):
|
| 131 |
-
The number of diffusion steps to train the model.
|
| 132 |
-
beta_start (`float`, defaults to 0.0001):
|
| 133 |
-
The starting `beta` value of inference.
|
| 134 |
-
beta_end (`float`, defaults to 0.02):
|
| 135 |
-
The final `beta` value.
|
| 136 |
-
beta_schedule (`str`, defaults to `"linear"`):
|
| 137 |
-
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
| 138 |
-
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
| 139 |
-
trained_betas (`np.ndarray`, *optional*):
|
| 140 |
-
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
| 141 |
-
solver_order (`int`, defaults to 2):
|
| 142 |
-
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
|
| 143 |
-
sampling, and `solver_order=3` for unconditional sampling.
|
| 144 |
-
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
| 145 |
-
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
| 146 |
-
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
| 147 |
-
Video](https://imagen.research.google/video/paper.pdf) paper).
|
| 148 |
-
thresholding (`bool`, defaults to `False`):
|
| 149 |
-
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
| 150 |
-
as Stable Diffusion.
|
| 151 |
-
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
| 152 |
-
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
| 153 |
-
sample_max_value (`float`, defaults to 1.0):
|
| 154 |
-
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
| 155 |
-
`algorithm_type="dpmsolver++"`.
|
| 156 |
-
algorithm_type (`str`, defaults to `dpmsolver++`):
|
| 157 |
-
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
|
| 158 |
-
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
|
| 159 |
-
paper, and the `dpmsolver++` type implements the algorithms in the
|
| 160 |
-
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
|
| 161 |
-
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
| 162 |
-
solver_type (`str`, defaults to `midpoint`):
|
| 163 |
-
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
| 164 |
-
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
| 165 |
-
lower_order_final (`bool`, defaults to `True`):
|
| 166 |
-
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
| 167 |
-
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
| 168 |
-
euler_at_final (`bool`, defaults to `False`):
|
| 169 |
-
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
|
| 170 |
-
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
|
| 171 |
-
steps, but sometimes may result in blurring.
|
| 172 |
-
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
| 173 |
-
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
| 174 |
-
the sigmas are determined according to a sequence of noise levels {σi}.
|
| 175 |
-
use_lu_lambdas (`bool`, *optional*, defaults to `False`):
|
| 176 |
-
Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
|
| 177 |
-
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
|
| 178 |
-
`lambda(t)`.
|
| 179 |
-
final_sigmas_type (`str`, defaults to `"zero"`):
|
| 180 |
-
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
| 181 |
-
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
| 182 |
-
lambda_min_clipped (`float`, defaults to `-inf`):
|
| 183 |
-
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
| 184 |
-
cosine (`squaredcos_cap_v2`) noise schedule.
|
| 185 |
-
variance_type (`str`, *optional*):
|
| 186 |
-
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
|
| 187 |
-
contains the predicted Gaussian variance.
|
| 188 |
-
timestep_spacing (`str`, defaults to `"linspace"`):
|
| 189 |
-
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
| 190 |
-
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
| 191 |
-
steps_offset (`int`, defaults to 0):
|
| 192 |
-
An offset added to the inference steps, as required by some model families.
|
| 193 |
-
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
| 194 |
-
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
| 195 |
-
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
| 196 |
-
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
| 197 |
-
"""
|
| 198 |
-
|
| 199 |
-
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
| 200 |
-
order = 1
|
| 201 |
-
|
| 202 |
-
@register_to_config
|
| 203 |
-
def __init__(
|
| 204 |
-
self,
|
| 205 |
-
num_train_timesteps: int = 1000,
|
| 206 |
-
beta_start: float = 0.0001,
|
| 207 |
-
beta_end: float = 0.02,
|
| 208 |
-
beta_schedule: str = "linear",
|
| 209 |
-
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
| 210 |
-
solver_order: int = 2,
|
| 211 |
-
prediction_type: str = "epsilon",
|
| 212 |
-
thresholding: bool = False,
|
| 213 |
-
dynamic_thresholding_ratio: float = 0.995,
|
| 214 |
-
sample_max_value: float = 1.0,
|
| 215 |
-
algorithm_type: str = "dpmsolver++",
|
| 216 |
-
solver_type: str = "midpoint",
|
| 217 |
-
lower_order_final: bool = True,
|
| 218 |
-
euler_at_final: bool = False,
|
| 219 |
-
use_karras_sigmas: Optional[bool] = False,
|
| 220 |
-
use_lu_lambdas: Optional[bool] = False,
|
| 221 |
-
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
| 222 |
-
lambda_min_clipped: float = -float("inf"),
|
| 223 |
-
variance_type: Optional[str] = None,
|
| 224 |
-
timestep_spacing: str = "linspace",
|
| 225 |
-
steps_offset: int = 0,
|
| 226 |
-
rescale_betas_zero_snr: bool = False,
|
| 227 |
-
):
|
| 228 |
-
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
| 229 |
-
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
| 230 |
-
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
| 231 |
-
|
| 232 |
-
if trained_betas is not None:
|
| 233 |
-
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
| 234 |
-
elif beta_schedule == "linear":
|
| 235 |
-
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
| 236 |
-
elif beta_schedule == "scaled_linear":
|
| 237 |
-
# this schedule is very specific to the latent diffusion model.
|
| 238 |
-
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
| 239 |
-
elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine":
|
| 240 |
-
# Glide cosine schedule
|
| 241 |
-
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
|
| 242 |
-
elif beta_schedule == "cauchy":
|
| 243 |
-
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cauchy")
|
| 244 |
-
elif beta_schedule == "laplace":
|
| 245 |
-
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
|
| 246 |
-
else:
|
| 247 |
-
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
| 248 |
-
|
| 249 |
-
if rescale_betas_zero_snr:
|
| 250 |
-
self.betas = rescale_zero_terminal_snr(self.betas)
|
| 251 |
-
|
| 252 |
-
self.alphas = 1.0 - self.betas
|
| 253 |
-
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
| 254 |
-
|
| 255 |
-
if rescale_betas_zero_snr:
|
| 256 |
-
# Close to 0 without being 0 so first sigma is not inf
|
| 257 |
-
# FP16 smallest positive subnormal works well here
|
| 258 |
-
self.alphas_cumprod[-1] = 2**-24
|
| 259 |
-
|
| 260 |
-
# Currently we only support VP-type noise schedule
|
| 261 |
-
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
| 262 |
-
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
| 263 |
-
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
| 264 |
-
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
|
| 265 |
-
|
| 266 |
-
# standard deviation of the initial noise distribution
|
| 267 |
-
self.init_noise_sigma = 1.0
|
| 268 |
-
|
| 269 |
-
# settings for DPM-Solver
|
| 270 |
-
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
|
| 271 |
-
if algorithm_type == "deis":
|
| 272 |
-
self.register_to_config(algorithm_type="dpmsolver++")
|
| 273 |
-
else:
|
| 274 |
-
raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
|
| 275 |
-
|
| 276 |
-
if solver_type not in ["midpoint", "heun"]:
|
| 277 |
-
if solver_type in ["logrho", "bh1", "bh2"]:
|
| 278 |
-
self.register_to_config(solver_type="midpoint")
|
| 279 |
-
else:
|
| 280 |
-
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
| 281 |
-
|
| 282 |
-
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
|
| 283 |
-
raise ValueError(
|
| 284 |
-
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
|
| 285 |
-
)
|
| 286 |
-
|
| 287 |
-
# setable values
|
| 288 |
-
self.num_inference_steps = None
|
| 289 |
-
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
| 290 |
-
self.timesteps = torch.from_numpy(timesteps)
|
| 291 |
-
self.model_outputs = [None] * solver_order
|
| 292 |
-
self.lower_order_nums = 0
|
| 293 |
-
self._step_index = None
|
| 294 |
-
self._begin_index = None
|
| 295 |
-
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
| 296 |
-
|
| 297 |
-
@property
|
| 298 |
-
def step_index(self):
|
| 299 |
-
"""
|
| 300 |
-
The index counter for current timestep. It will increase 1 after each scheduler step.
|
| 301 |
-
"""
|
| 302 |
-
return self._step_index
|
| 303 |
-
|
| 304 |
-
@property
|
| 305 |
-
def begin_index(self):
|
| 306 |
-
"""
|
| 307 |
-
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
| 308 |
-
"""
|
| 309 |
-
return self._begin_index
|
| 310 |
-
|
| 311 |
-
def set_begin_index(self, begin_index: int = 0):
|
| 312 |
-
"""
|
| 313 |
-
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
| 314 |
-
|
| 315 |
-
Args:
|
| 316 |
-
begin_index (`int`):
|
| 317 |
-
The begin index for the scheduler.
|
| 318 |
-
"""
|
| 319 |
-
self._begin_index = begin_index
|
| 320 |
-
|
| 321 |
-
def set_timesteps(
|
| 322 |
-
self,
|
| 323 |
-
num_inference_steps: int = None,
|
| 324 |
-
device: Union[str, torch.device] = None,
|
| 325 |
-
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
                based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
                must be `None`, and `timestep_spacing` attribute will be ignored.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
        if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
        if timesteps is not None and self.config.use_lu_lambdas:
            raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")

        if timesteps is not None:
            timesteps = np.array(timesteps).astype(np.int64)
        else:
            # Clipping the minimum of all lambda(t) for numerical stability.
            # This is critical for cosine (squaredcos_cap_v2) noise schedule.
            clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
            last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
            if self.config.timestep_spacing == "linspace":
                timesteps = (
                    np.linspace(0, last_timestep - 1, num_inference_steps + 1)
                    .round()[::-1][:-1]
                    .copy()
                    .astype(np.int64)
                )
            elif self.config.timestep_spacing == "leading":
                step_ratio = last_timestep // (num_inference_steps + 1)
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (
                    (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
                )
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                step_ratio = self.config.num_train_timesteps / num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
                timesteps -= 1
            else:
                raise ValueError(
                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)

        if self.config.use_karras_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        elif self.config.use_lu_lambdas:
            lambdas = np.flip(log_sigmas.copy())
            lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
            sigmas = np.exp(lambdas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
            )

        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    def _sigma_to_alpha_sigma_t(self, sigma):
        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
        sigma_t = sigma * alpha_t

        return alpha_t, sigma_t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Lu et al. (2022)."""

        lambda_min: float = in_lambdas[-1].item()
        lambda_max: float = in_lambdas[0].item()

        rho = 1.0  # 1.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = lambda_min ** (1 / rho)
        max_inv_rho = lambda_max ** (1 / rho)
        lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return lambdas

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
        integral of the data prediction model.

        <Tip>

        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
        prediction and data prediction models.

        </Tip>

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyward argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # DPM-Solver++ needs to solve an integral of the data prediction model.
        if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    model_output = model_output[:, :3]
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = alpha_t * sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        # DPM-Solver needs to solve an integral of the noise prediction model.
        elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    epsilon = model_output[:, :3]
                else:
                    epsilon = model_output
            elif self.config.prediction_type == "sample":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = (sample - alpha_t * model_output) / sigma_t
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = alpha_t * model_output + sigma_t * sample
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * epsilon) / alpha_t
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = (sample - alpha_t * x0_pred) / sigma_t

            return epsilon

    def dpm_solver_first_order_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the first-order DPMSolver (equivalent to DDIM).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError(" missing `sample` as a required keyward argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

        h = lambda_t - lambda_s
        if self.config.algorithm_type == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
        elif self.config.algorithm_type == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            x_t = (
                (sigma_t / sigma_s * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            x_t = (
                (alpha_t / alpha_s) * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
        return x_t

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the second-order multistep DPMSolver.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError(" missing `sample` as a required keyward argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]

        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                )
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                    + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * (torch.exp(h) - 1.0)) * D1
                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
                )
        return x_t

    def multistep_dpm_solver_third_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the third-order multistep DPMSolver.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """

        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError(" missing`sample` as a required keyward argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
            self.sigmas[self.step_index - 2],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
        lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m0
        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
            )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
            )
        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        index_candidates = (schedule_timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            step_index = len(self.timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

        return step_index

    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
        """

        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep DPMSolver.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`LEdits++`].
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Improve numerical stability for small number of steps
        lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
            self.config.euler_at_final
            or (self.config.lower_order_final and len(self.timesteps) < 15)
            or self.config.final_sigmas_type == "zero"
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
        )

        model_output = self.convert_model_output(model_output, sample=sample)
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
            noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
            )
        elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
        else:
            noise = None

        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
            prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
        else:
            prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # Cast sample back to expected dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
        # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
        alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
        sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)
        alpha_t = alpha_t[timesteps].flatten()
        while len(alpha_t.shape) < len(original_samples.shape):
            alpha_t = alpha_t.unsqueeze(-1)

        sigma_t = sigma_t[timesteps].flatten()
        while len(sigma_t.shape) < len(original_samples.shape):
            sigma_t = sigma_t.unsqueeze(-1)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def get_velocity(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
        # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
        alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
        sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)

        timesteps = timesteps.to(original_samples.device)
        alpha_t = alpha_t[timesteps].flatten()
        while len(alpha_t.shape) < len(original_samples.shape):
            alpha_t = alpha_t.unsqueeze(-1)

        sigma_t = sigma_t[timesteps].flatten()
        while len(sigma_t.shape) < len(original_samples.shape):
            sigma_t = sigma_t.unsqueeze(-1)

        velocity = alpha_t * noise - sigma_t * original_samples
        return velocity

    def __len__(self):
        return self.config.num_train_timesteps

src/vibevoice/schedule/timestep_sampler.py
DELETED
@@ -1,19 +0,0 @@
import math
import torch


class UniformSampler:
    def __init__(self, timesteps = 1000):
        self.timesteps = timesteps
    def sample(self, batch_size, device):
        return torch.randint(0, self.timesteps, (batch_size,), device=device)

class LogitNormalSampler:
    def __init__(self, timesteps = 1000, m = 0, s = 1):
        self.timesteps = timesteps
        timesteps = torch.linspace(0, 1, timesteps)
        logit = torch.log(timesteps / (1 - timesteps))
        self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi))
    def sample(self, batch_size, device):
        return torch.multinomial(self.prob, batch_size, replacement=True).to(device)

src/vibevoice/scripts/__init__.py
DELETED
File without changes
src/vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py
DELETED
@@ -1,166 +0,0 @@
#!/usr/bin/env python
# coding=utf-8

import argparse
import json
import os
from pathlib import Path
import re
import torch
from typing import Dict, List, Tuple

from vibevoice.modular.configuration_vibevoice import (
    VibeVoiceConfig
)
from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
from transformers.utils import logging

logger = logging.get_logger(__name__)

def convert_vibevoice_nnscaler_checkpoint_to_hf(
    checkpoint_path: str,
    pytorch_dump_folder_path: str,
    config_path: str = None,
):
    """
    Convert a nnscaler VibeVoice checkpoint to HuggingFace format.
    Supports both regular checkpoints and tensor parallel checkpoints.
    """

    # Load regular checkpoint
    logger.info(f"Loading regular checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")  # ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader']

    # config = checkpoint['train_args']
    init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path']
    pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path']

    init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1]
    if init_config_path.exists():
        logger.info(f"Loading initial config from {init_config_path}")
        with open(init_config_path, 'r') as f:
            init_config = json.load(f)
    else:
        raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.")

    tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True)
    logger.info(f"Tie word embeddings: {tie_word_embeddings}")

    init_config['decoder_config']['use_cache'] = True
    config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings)

    # # Extract the model state dict
    model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')}
    if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys():
        # If not tying weights, we need to add the lm_head weight separately
        model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight']

    # Override with provided config if available
    if config_path:
        logger.info(f"Loading config from {config_path}")
        with open(config_path, 'r') as f:
            config_dict = json.load(f)
        config = VibeVoiceConfig.from_dict(config_dict)

    # Set the default dtype to bfloat16 before creating the model
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)

    # Create the HuggingFace model
    logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model")
    model = VibeVoiceForConditionalGeneration(config)

    # Restore original dtype
    torch.set_default_dtype(original_dtype)

    # Load the state dict
    logger.info("Loading weights into model")
    missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)

    if missing_keys:
        logger.warning(f"Missing keys: {missing_keys}")
    if unexpected_keys:
        logger.warning(f"Unexpected keys: {unexpected_keys}")

    # Create output directory
    os.makedirs(pytorch_dump_folder_path, exist_ok=True)

    # Save the model and config
    logger.info(f"Saving model to {pytorch_dump_folder_path}")

    # Save config
    config.save_pretrained(pytorch_dump_folder_path)

    # Save VibeVoiceProcessor configuration
    logger.info("Saving VibeVoiceProcessor configuration")
    processor_config = {
        "processor_class": "VibeVoiceProcessor",
        "speech_tok_compress_ratio": 3200,
        "db_normalize": True,
        # Audio processor configuration
        "audio_processor": {
            "feature_extractor_type": "VibeVoiceTokenizerProcessor",
            "sampling_rate": 24000,
            "normalize_audio": True,
            "target_dB_FS": -25,
            "eps": 1e-6,
        },
        "language_model_pretrained_name": pretrained_name,
    }

    processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json")
    with open(processor_config_path, 'w') as f:
        json.dump(processor_config, f, indent=2)
    logger.info(f"Saved processor config to {processor_config_path}")

    # Save model with sharding
    # save_pretrained handles tied weights automatically
    logger.info("Saving model weights with sharding...")
    model.save_pretrained(
        pytorch_dump_folder_path,
        max_shard_size="2GB",  # Set maximum size for each shard
        safe_serialization=True  # Ensure saving in .safetensors format
    )
    logger.info(f"Model weights saved to {pytorch_dump_folder_path}")

    logger.info("Conversion complete!")

    # Verify the saved model can be loaded
    logger.info("Verifying saved model...")
    loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path)
    logger.info("Model successfully loaded from saved checkpoint!")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--nnscaler_checkpoint_path",
        type=str,
        required=True,
        help="Path to the fairseq checkpoint (.pt file). For tensor parallel checkpoints, "
        "provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), "
        "and the script will automatically detect and merge all parts.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        type=str,
        required=True,
        help="Path to the output PyTorch model directory",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default=None,
        help="Optional path to a config JSON file to override extracted config",
    )

    args = parser.parse_args()

    convert_vibevoice_nnscaler_checkpoint_to_hf(
        args.nnscaler_checkpoint_path,
        args.pytorch_dump_folder_path,
        args.config_path,
    )


if __name__ == "__main__":
    main()

src/voices/vibe_voices/en-Alice_woman.wav
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c27ae47421436287a6bd2c3062de2dc2a2855b78c0bb626d472202c359704203
size 296684

src/voices/vibe_voices/en-Carter_man.wav
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9dd7b12f25bf279d878a9f7a3125f64bff2b312a189959090acff9138a55e8dd
size 1331244

src/voices/vibe_voices/en-Frank_man.wav
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa77c4794a005c4b05a52bbce5f30e77f0d28987b9a9e737401a5d30fd1ebcb5
size 1158444

src/voices/vibe_voices/en-Mary_woman_bgm.wav
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c421eeab1af5b3ddae8d14cfcf6b65e496047ad2228325d61d1b6967fca11700
size 1292878

src/voices/vibe_voices/en-Maya_woman.wav
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eb1288bc02546c7f1117698fb78e994f060e623af148be8ccbf93dd0bea79e32
size 1305644

src/voices/vibe_voices/in-Samuel_man.wav
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:76b07b5a12ca0b24a1e4a88100c4e2e47a2552ebb96807d52f116cf05fc46b50
size 1273644

src/voices/vibe_voices/zh-Anchen_man_bgm.wav
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f71aeb33ed66c449dedb75d8a505478d86d47ec49e0e4c33c1fd0f8324d781fb
size 1177644

src/voices/vibe_voices/zh-Bowen_man.wav
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0cef6c018e73e9fb6a1269fd61ded08144ae6380cdec242eebb1cc8aca49fed1
size 1419940

src/voices/vibe_voices/zh-Xinran_woman.wav
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dbcb9e28bcc544675ef75a8ba12528bf09e713eb53a8c0c819dec3daf2d486d3
size 1337644