Michael Hu committed
Commit d77f8ff · 1 Parent(s): dda48cd

remove vibevoice
README.md CHANGED
@@ -47,7 +47,6 @@ This demo showcases the multilingual capabilities of multiple TTS models, suppor
 - **Chatterbox**: Industrial-grade multilingual TTS solution
 - **KittenTTS**: High-quality TTS with voice cloning capabilities
 - **Piper**: Local on-device TTS with multiple voice options
-- **VibeVoice 1.5B**: Microsoft's advanced seq2seq TTS model
 
 ## Examples
 
 
app.py CHANGED
@@ -20,7 +20,6 @@ MODEL_DESCRIPTIONS = {
     "ResembleAI/chatterbox": "Industrial-grade TTS solution with multilingual support",
     "KittenML/KittenTTS": "High-quality TTS with voice cloning capabilities using reference audio",
     "piper-tts": "Local on-device TTS with dynamic English and Chinese voice selection from Piper models",
-    "microsoft/VibeVoice-1.5B": "Microsoft's advanced seq2seq TTS model with high-quality speech synthesis",
 }
 
 # Models dictionary
@@ -28,7 +27,6 @@ MODELS = {
     "ResembleAI/chatterbox": "Chatterbox",
     "KittenML/KittenTTS": "KittenTTS",
     "piper-tts": "Piper (no voice cloning)",
-    "microsoft/VibeVoice-1.5B": "VibeVoice 1.5B",
 }
 
 original_torch_load = torch.load
@@ -53,130 +51,6 @@ except RuntimeError as e:
 # Initialize KittenTTS model
 kittentts_model = KittenTTS("KittenML/kitten-tts-nano-0.2")
 
-# Initialize VibeVoice model
-vibevoice_model = None
-vibevoice_processor = None
-vibevoice_voices = {}
-
-def initialize_vibevoice():
-    """Initialize VibeVoice model using the proper VibeVoice classes"""
-    global vibevoice_model, vibevoice_processor, vibevoice_voices
-    try:
-        # Add the src directory to Python path to make vibe-voice importable
-        src_path = os.path.join(os.path.dirname(__file__), 'src')
-        if src_path not in sys.path:
-            sys.path.insert(0, src_path)
-
-        # Import VibeVoice specific classes from src/vibe-voice directory
-        # Use underscore import since hyphens aren't valid in Python module names
-        vibe_voice_path = os.path.join(src_path, 'vibevoice')
-        if vibe_voice_path not in sys.path:
-            sys.path.insert(0, vibe_voice_path)
-
-        # Now import using the actual module structure
-        from modular.configuration_vibevoice import VibeVoiceConfig
-        from modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
-        from processor.vibevoice_processor import VibeVoiceProcessor
-
-        # Determine device
-        device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
-
-        # Load processor
-        print("Loading VibeVoice processor...")
-        vibevoice_processor = VibeVoiceProcessor.from_pretrained("microsoft/VibeVoice-1.5B")
-
-        # Determine dtype and attention implementation based on device
-        if device == "mps":
-            load_dtype = torch.float32
-            attn_impl_primary = "sdpa"
-        elif device == "cuda":
-            load_dtype = torch.bfloat16
-            attn_impl_primary = "flash_attention_2"
-        else:
-            load_dtype = torch.float32
-            attn_impl_primary = "sdpa"
-
-        print(f"Using device: {device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
-
-        # Load model
-        print("Loading VibeVoice model...")
-        try:
-            if device == "mps":
-                vibevoice_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-                    "microsoft/VibeVoice-1.5B",
-                    torch_dtype=load_dtype,
-                    attn_implementation=attn_impl_primary,
-                    device_map=None,
-                )
-                vibevoice_model.to("mps")
-            elif device == "cuda":
-                vibevoice_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-                    "microsoft/VibeVoice-1.5B",
-                    torch_dtype=load_dtype,
-                    device_map="cuda",
-                    attn_implementation=attn_impl_primary,
-                )
-            else:
-                vibevoice_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-                    "microsoft/VibeVoice-1.5B",
-                    torch_dtype=load_dtype,
-                    device_map="cpu",
-                    attn_implementation=attn_impl_primary,
-                )
-        except Exception as e:
-            if attn_impl_primary == 'flash_attention_2':
-                print(f"[ERROR] : {type(e).__name__}: {e}")
-                print("Falling back to attention implementation: sdpa")
-                vibevoice_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-                    "microsoft/VibeVoice-1.5B",
-                    torch_dtype=load_dtype,
-                    device_map=(device if device in ("cuda", "cpu") else None),
-                    attn_implementation="sdpa",
-                )
-                if device == "mps":
-                    vibevoice_model.to("mps")
-            else:
-                raise e
-
-        vibevoice_model.eval()
-
-        # Setup noise scheduler for SDE solver
-        vibevoice_model.model.noise_scheduler = vibevoice_model.model.noise_scheduler.from_config(
-            vibevoice_model.model.noise_scheduler.config,
-            algorithm_type='sde-dpmsolver++',
-            beta_schedule='squaredcos_cap_v2'
-        )
-        vibevoice_model.set_ddpm_inference_steps(num_steps=10)
-
-        # Load voice presets
-        voices_dir = "src/voices/vibe_voices"
-        if os.path.exists(voices_dir):
-            wav_files = [f for f in os.listdir(voices_dir)
-                         if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')) and os.path.isfile(os.path.join(voices_dir, f))]
-
-            for wav_file in wav_files:
-                name = os.path.splitext(wav_file)[0]
-                full_path = os.path.join(voices_dir, wav_file)
-                vibevoice_voices[name] = full_path
-
-            vibevoice_voices = dict(sorted(vibevoice_voices.items()))
-            print(f"Found {len(vibevoice_voices)} voice files in {voices_dir}")
-            print(f"Available voices: {', '.join(vibevoice_voices.keys())}")
-        else:
-            print(f"Warning: Voices directory not found at {voices_dir}")
-            vibevoice_voices = {}
-
-        print("VibeVoice model loaded successfully")
-
-    except Exception as e:
-        print(f"Error loading VibeVoice: {str(e)}")
-        import traceback
-        traceback.print_exc()
-        raise e
-
-# Initialize VibeVoice on startup
-initialize_vibevoice()
-
 # Scan Piper voices
 def scan_piper_voices():
     voices_dir = "src/voices/piper_voices"
@@ -306,147 +180,6 @@ def generate_piper_speech(text, lang, voice):
     except Exception as e:
         return None, f"Error synthesizing speech: {str(e)}"
 
-def generate_vibevoice_speech(text, voice_name=None):
-    """
-    Generate speech from text using VibeVoice 1.5B with proper API
-
-    Args:
-        text (str): Text to convert to speech
-        voice_name (str, optional): Name of voice preset to use
-
-    Returns:
-        str: Path to the generated audio file
-    """
-    if not vibevoice_model or not vibevoice_processor:
-        raise RuntimeError("VibeVoice model not initialized")
-
-    if not text.strip():
-        raise ValueError("Please enter text to synthesize")
-
-    try:
-        # Select voice preset
-        if voice_name and voice_name in vibevoice_voices:
-            voice_path = vibevoice_voices[voice_name]
-            print(f"Using voice preset: {voice_name}")
-        else:
-            # Use first available voice or default
-            if vibevoice_voices:
-                voice_name = list(vibevoice_voices.keys())[0]
-                voice_path = vibevoice_voices[voice_name]
-                print(f"Using default voice preset: {voice_name}")
-            else:
-                # Generate without voice preset (may not work well)
-                voice_path = None
-                print("No voice presets available, generating without voice reference")
-
-        # Read voice sample if available
-        voice_samples = []
-        if voice_path:
-            try:
-                wav, sr = sf.read(voice_path)
-                if len(wav.shape) > 1:
-                    wav = np.mean(wav, axis=1)
-                if sr != 24000:
-                    wav = librosa.resample(wav, orig_sr=sr, target_sr=24000)
-                voice_samples.append(wav)
-                print(f"Loaded voice sample: {voice_path}, duration: {len(wav)/24000:.2f}s")
-            except Exception as e:
-                print(f"Error loading voice sample {voice_path}: {e}")
-                voice_samples = []
-
-        # Prepare input for VibeVoice - format text as single-speaker script
-        formatted_script = f"Speaker 1: {text}"
-
-        voice_samples_input = [voice_samples] if voice_samples else None
-
-        inputs = vibevoice_processor(
-            text=[formatted_script],
-            voice_samples=voice_samples_input,
-            padding=True,
-            return_tensors="pt",
-            return_attention_mask=True,
-        )
-
-        # Ensure voice samples are properly typed before processor
-        if voice_samples_input and voice_samples_input[0]:
-            voice_samples_input[0] = torch.tensor(voice_samples_input[0], dtype=torch.float32)
-
-        # Move tensors to device and match model's data type
-        device = next(vibevoice_model.parameters()).device
-        model_dtype = next(vibevoice_model.parameters()).dtype
-
-        for k, v in inputs.items():
-            if torch.is_tensor(v):
-                # Convert to model's data type before moving to device
-                inputs[k] = v.to(dtype=model_dtype).to(device)
-
-        # Generate speech using VibeVoice
-        with torch.no_grad():
-            outputs = vibevoice_model.generate(
-                **inputs,
-                cfg_scale=1.3,
-                tokenizer=vibevoice_processor.tokenizer,
-                generation_config={
-                    'do_sample': False,
-                },
-                verbose=False,
-                refresh_negative=True,
-            )
-
-        # Extract audio from outputs
-        if hasattr(outputs, 'waveform'):
-            audio = outputs.waveform
-        elif hasattr(outputs, 'audio'):
-            audio = outputs.audio
-        elif isinstance(outputs, dict) and 'audio' in outputs:
-            audio = outputs['audio']
-        elif isinstance(outputs, torch.Tensor):
-            audio = outputs
-        else:
-            # Try to get audio from the model output
-            audio = vibevoice_model.model.generate_audio(outputs)
-
-        # Ensure audio is in correct format
-        if torch.is_tensor(audio):
-            audio = audio.cpu().numpy()
-
-        # Ensure audio is 1D and properly normalized
-        if len(audio.shape) > 1:
-            audio = np.mean(audio, axis=1) if audio.shape[0] < audio.shape[1] else np.mean(audio, axis=0)
-
-        # Normalize to [-1, 1] range
-        if np.max(np.abs(audio)) > 1.0:
-            audio = audio / np.max(np.abs(audio))
-
-        # Convert to 16-bit for saving
-        audio_16bit = (audio * 32767).astype(np.int16)
-
-        # Save to temporary file
-        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
-            sf.write(tmp_file.name, audio_16bit, 24000)
-            print(f"Generated audio saved to: {tmp_file.name}")
-            return tmp_file.name
-
-    except Exception as e:
-        print(f"Error in VibeVoice generation: {str(e)}")
-        import traceback
-        traceback.print_exc()
-        # Fallback to simple audio generation if model inference fails
-        try:
-            sample_rate = 22050
-            duration = 2.0
-            t = torch.linspace(0, duration, int(sample_rate * duration))
-            frequency = 440  # A4 note
-            audio = torch.sin(2 * torch.pi * frequency * t).unsqueeze(0) * 0.3
-
-            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
-                # Convert tensor to numpy array before saving
-                audio_np = audio.squeeze().numpy()  # Remove extra dimensions
-                sf.write(tmp_file.name, audio_np, sample_rate)
-                return tmp_file.name
-        except Exception as fallback_error:
-            raise RuntimeError(f"Error generating speech with VibeVoice: {str(e)} - Fallback also failed: {str(fallback_error)}")
-
 def update_piper_voices(lang):
     choices = list(voices_by_lang.get(lang, {}).keys())
     value = choices[0] if choices else None
@@ -550,33 +283,7 @@ with gr.Blocks(css=custom_css, title="🎙️ TTS Model Gallery", theme=gr.theme
             piper_audio_output = gr.Audio(label="Generated Speech", type="filepath")
             piper_status = gr.Textbox(label="Status", interactive=False)
 
-    # VibeVoice Model Section
-    vibevoice_model_info = gr.HTML(create_model_card("microsoft/VibeVoice-1.5B"))
-
-    with gr.Row():
-        with gr.Column():
-            vibevoice_voice_selection = gr.Dropdown(
-                choices=list(vibevoice_voices.keys()) if vibevoice_voices else [],
-                value=list(vibevoice_voices.keys())[0] if vibevoice_voices else None,
-                label="Voice Preset"
-            )
-            vibevoice_generate_btn = gr.Button("Generate Speech")
-
-        with gr.Column():
-            vibevoice_audio_output = gr.Audio(label="Generated Speech", type="filepath")
-
-    # Examples for VibeVoice
-    gr.Examples(
-        examples=[
-            ["Hello, this is a test of VibeVoice 1.5B from Microsoft.", list(vibevoice_voices.keys())[0] if vibevoice_voices else None],
-            ["The quick brown fox jumps over the lazy dog.", list(vibevoice_voices.keys())[0] if vibevoice_voices else None],
-            ["Artificial intelligence is transforming the world.", list(vibevoice_voices.keys())[0] if vibevoice_voices else None]
-        ],
-        inputs=[text_input, vibevoice_voice_selection],
-        outputs=vibevoice_audio_output,
-        fn=generate_vibevoice_speech,
-        cache_examples=False
-    )
+    # VibeVoice section removed
 
     # Examples for Chatterbox
     gr.Examples(
@@ -597,12 +304,7 @@ with gr.Blocks(css=custom_css, title="🎙️ TTS Model Gallery", theme=gr.theme
         outputs=audio_output
     )
 
-    # Connect the VibeVoice generate button to the function
-    vibevoice_generate_btn.click(
-        fn=generate_vibevoice_speech,
-        inputs=[text_input, vibevoice_voice_selection],
-        outputs=vibevoice_audio_output
-    )
+    # VibeVoice button connection removed
 
     # Connect the KittenTTS generate button to the function
     kittentts_generate_btn.click(
 
src/vibevoice/__init__.py DELETED
File without changes
src/vibevoice/configs/qwen2.5_1.5b_64k.json DELETED
@@ -1,112 +0,0 @@
-{
-  "_attn_implementation_autoset": true,
-  "acoustic_vae_dim": 64,
-  "acoustic_tokenizer_config": {
-    "causal": true,
-    "channels": 1,
-    "conv_bias": true,
-    "conv_norm": "none",
-    "corpus_normalize": 0.0,
-    "decoder_depths": null,
-    "decoder_n_filters": 32,
-    "decoder_ratios": [
-      8,
-      5,
-      5,
-      4,
-      2,
-      2
-    ],
-    "disable_last_norm": true,
-    "encoder_depths": "3-3-3-3-3-3-8",
-    "encoder_n_filters": 32,
-    "encoder_ratios": [
-      8,
-      5,
-      5,
-      4,
-      2,
-      2
-    ],
-    "fix_std": 0.5,
-    "layer_scale_init_value": 1e-06,
-    "layernorm": "RMSNorm",
-    "layernorm_elementwise_affine": true,
-    "layernorm_eps": 1e-05,
-    "mixer_layer": "depthwise_conv",
-    "model_type": "vibepod_acoustic_tokenizer",
-    "pad_mode": "constant",
-    "std_dist_type": "gaussian",
-    "vae_dim": 64,
-    "weight_init_value": 0.01
-  },
-  "decoder_config": {
-    "attention_dropout": 0.0,
-    "hidden_act": "silu",
-    "hidden_size": 1536,
-    "initializer_range": 0.02,
-    "intermediate_size": 8960,
-    "max_position_embeddings": 65536,
-    "max_window_layers": 28,
-    "model_type": "qwen2",
-    "num_attention_heads": 12,
-    "num_hidden_layers": 28,
-    "num_key_value_heads": 2,
-    "rms_norm_eps": 1e-06,
-    "rope_scaling": null,
-    "rope_theta": 1000000.0,
-    "sliding_window": null,
-    "tie_word_embeddings": true,
-    "torch_dtype": "bfloat16",
-    "use_cache": true,
-    "use_sliding_window": false,
-    "vocab_size": 151936
-  },
-  "diffusion_head_config": {
-    "ddpm_batch_mul": 4,
-    "ddpm_beta_schedule": "cosine",
-    "ddpm_num_inference_steps": 20,
-    "ddpm_num_steps": 1000,
-    "diffusion_type": "ddpm",
-    "head_ffn_ratio": 3.0,
-    "head_layers": 4,
-    "hidden_size": 1536,
-    "latent_size": 64,
-    "model_type": "vibepod_diffusion_head",
-    "prediction_type": "v_prediction",
-    "rms_norm_eps": 1e-05,
-    "speech_vae_dim": 64
-  },
-  "model_type": "vibepod",
-  "semantic_tokenizer_config": {
-    "causal": true,
-    "channels": 1,
-    "conv_bias": true,
-    "conv_norm": "none",
-    "corpus_normalize": 0.0,
-    "disable_last_norm": true,
-    "encoder_depths": "3-3-3-3-3-3-8",
-    "encoder_n_filters": 32,
-    "encoder_ratios": [
-      8,
-      5,
-      5,
-      4,
-      2,
-      2
-    ],
-    "fix_std": 0,
-    "layer_scale_init_value": 1e-06,
-    "layernorm": "RMSNorm",
-    "layernorm_elementwise_affine": true,
-    "layernorm_eps": 1e-05,
-    "mixer_layer": "depthwise_conv",
-    "model_type": "vibepod_semantic_tokenizer",
-    "pad_mode": "constant",
-    "std_dist_type": "none",
-    "vae_dim": 128,
-    "weight_init_value": 0.01
-  },
-  "semantic_vae_dim": 128,
-  "torch_dtype": "bfloat16"
-}
src/vibevoice/configs/qwen2.5_7b_32k.json DELETED
@@ -1,113 +0,0 @@
-{
-  "_attn_implementation_autoset": true,
-  "acoustic_vae_dim": 64,
-  "acoustic_tokenizer_config": {
-    "causal": true,
-    "channels": 1,
-    "conv_bias": true,
-    "conv_norm": "none",
-    "corpus_normalize": 0.0,
-    "decoder_depths": null,
-    "decoder_n_filters": 32,
-    "decoder_ratios": [
-      8,
-      5,
-      5,
-      4,
-      2,
-      2
-    ],
-    "disable_last_norm": true,
-    "encoder_depths": "3-3-3-3-3-3-8",
-    "encoder_n_filters": 32,
-    "encoder_ratios": [
-      8,
-      5,
-      5,
-      4,
-      2,
-      2
-    ],
-    "fix_std": 0.5,
-    "layer_scale_init_value": 1e-06,
-    "layernorm": "RMSNorm",
-    "layernorm_elementwise_affine": true,
-    "layernorm_eps": 1e-05,
-    "mixer_layer": "depthwise_conv",
-    "model_type": "vibepod_acoustic_tokenizer",
-    "pad_mode": "constant",
-    "std_dist_type": "gaussian",
-    "vae_dim": 64,
-    "weight_init_value": 0.01
-  },
-  "decoder_config": {
-    "attention_dropout": 0.0,
-    "hidden_act": "silu",
-    "hidden_size": 3584,
-    "initializer_range": 0.02,
-    "intermediate_size": 18944,
-    "max_position_embeddings": 32768,
-    "max_window_layers": 28,
-    "model_type": "qwen2",
-    "num_attention_heads": 28,
-    "num_hidden_layers": 28,
-    "num_key_value_heads": 4,
-    "rms_norm_eps": 1e-06,
-    "rope_theta": 1000000.0,
-    "sliding_window": null,
-    "tie_word_embeddings": false,
-    "torch_dtype": "bfloat16",
-    "transformers_version": "4.40.1",
-    "use_cache": true,
-    "use_mrope": false,
-    "use_sliding_window": false,
-    "vocab_size": 152064
-  },
-  "diffusion_head_config": {
-    "ddpm_batch_mul": 4,
-    "ddpm_beta_schedule": "cosine",
-    "ddpm_num_inference_steps": 20,
-    "ddpm_num_steps": 1000,
-    "diffusion_type": "ddpm",
-    "head_ffn_ratio": 3.0,
-    "head_layers": 4,
-    "hidden_size": 3584,
-    "latent_size": 64,
-    "model_type": "vibepod_diffusion_head",
-    "prediction_type": "v_prediction",
-    "rms_norm_eps": 1e-05,
-    "speech_vae_dim": 64
-  },
-  "model_type": "vibepod",
-  "semantic_tokenizer_config": {
-    "causal": true,
-    "channels": 1,
-    "conv_bias": true,
-    "conv_norm": "none",
-    "corpus_normalize": 0.0,
-    "disable_last_norm": true,
-    "encoder_depths": "3-3-3-3-3-3-8",
-    "encoder_n_filters": 32,
-    "encoder_ratios": [
-      8,
-      5,
-      5,
-      4,
-      2,
-      2
-    ],
-    "fix_std": 0,
-    "layer_scale_init_value": 1e-06,
-    "layernorm": "RMSNorm",
-    "layernorm_elementwise_affine": true,
-    "layernorm_eps": 1e-05,
-    "mixer_layer": "depthwise_conv",
-    "model_type": "vibepod_semantic_tokenizer",
-    "pad_mode": "constant",
-    "std_dist_type": "none",
-    "vae_dim": 128,
-    "weight_init_value": 0.01
-  },
-  "semantic_vae_dim": 128,
-  "torch_dtype": "bfloat16"
-}
src/vibevoice/modular/__init__.py DELETED
File without changes
src/vibevoice/modular/configuration_vibevoice.py DELETED
@@ -1,248 +0,0 @@
-""" VibeVoice_AcousticTokenizer model configuration"""
-
-from typing import Dict, List, Optional, Tuple
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
-
-logger = logging.get_logger(__name__)
-
-
-class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
-    model_type = "vibevoice_acoustic_tokenizer"
-
-    def __init__(
-        self,
-        channels: int = 1,
-        corpus_normalize: float = 0.0,
-        causal: bool = True,
-        vae_dim: int = 64,
-        fix_std: float = 0.5,
-        std_dist_type: str = 'gaussian',
-        # common
-        mixer_layer: str = 'depthwise_conv',
-        conv_norm: str = 'none',
-        pad_mode: str = 'constant',
-        disable_last_norm: bool = True,
-        layernorm: str = 'RMSNorm',
-        layernorm_eps: float = 1e-5,
-        layernorm_elementwise_affine: bool = True,
-        conv_bias: bool = True,
-        layer_scale_init_value: float = 1e-6,
-        weight_init_value: float = 1e-2,
-        # encoder specific
-        encoder_n_filters: int = 32,
-        encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
-        encoder_depths: str = "3-3-3-3-3-3-8",
-        # decoder specific
-        decoder_n_filters: int = 32,
-        decoder_ratios: Optional[List[int]] = None, # if None, same as encoder
-        decoder_depths: Optional[str] = None,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.channels = channels
-        self.corpus_normalize = corpus_normalize
-        self.causal = causal
-        self.vae_dim = vae_dim
-        self.fix_std = fix_std
-        self.std_dist_type = std_dist_type
-
-        # common parameters
-        self.conv_norm = conv_norm
-        self.pad_mode = pad_mode
-        self.layernorm_eps = layernorm_eps
-        self.disable_last_norm = disable_last_norm
-        self.layernorm = layernorm
-        self.layernorm_elementwise_affine = layernorm_elementwise_affine
-        self.conv_bias = conv_bias
-        self.layer_scale_init_value = layer_scale_init_value
-        self.weight_init_value = weight_init_value
-        self.mixer_layer = mixer_layer
-
-        # encoder specific parameters
-        self.encoder_n_filters = encoder_n_filters
-        self.encoder_ratios = encoder_ratios
-        self.encoder_depths = encoder_depths
-
-        # decoder specific parameters
-        self.decoder_ratios = decoder_ratios if decoder_ratios is not None else encoder_ratios
-        self.decoder_n_filters = decoder_n_filters
-        self.decoder_depths = decoder_depths
-
-
-class VibeVoiceSemanticTokenizerConfig(PretrainedConfig):
-    model_type = "vibevoice_semantic_tokenizer"
-
-    def __init__(
-        self,
-        channels: int = 1,
-        corpus_normalize: float = 0.0,
-        causal: bool = True,
-        vae_dim: int = 64,
-        fix_std: float = 0,
-        std_dist_type: str = 'none',
-        # common
-        mixer_layer: str = 'depthwise_conv',
-        conv_norm: str = 'none',
-        pad_mode: str = 'constant',
-        disable_last_norm: bool = True,
-        layernorm: str = 'RMSNorm',
-        layernorm_eps: float = 1e-5,
-        layernorm_elementwise_affine: bool = True,
-        conv_bias: bool = True,
-        layer_scale_init_value: float = 1e-6,
-        weight_init_value: float = 1e-2,
-        # encoder specific
-        encoder_n_filters: int = 32,
-        encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
-        encoder_depths: str = "3-3-3-3-3-3-8",
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.channels = channels
-        self.corpus_normalize = corpus_normalize
-        self.causal = causal
-        self.vae_dim = vae_dim
-        self.fix_std = fix_std
-        self.std_dist_type = std_dist_type
-
-        # common parameters
-        self.conv_norm = conv_norm
-        self.pad_mode = pad_mode
-        self.layernorm_eps = layernorm_eps
-        self.disable_last_norm = disable_last_norm
-        self.layernorm = layernorm
-        self.layernorm_elementwise_affine = layernorm_elementwise_affine
-        self.conv_bias = conv_bias
-        self.layer_scale_init_value = layer_scale_init_value
-        self.weight_init_value = weight_init_value
-        self.mixer_layer = mixer_layer
-
-        # encoder specific parameters
-        self.encoder_n_filters = encoder_n_filters
-        self.encoder_ratios = encoder_ratios
-        self.encoder_depths = encoder_depths
-
-
-class VibeVoiceDiffusionHeadConfig(PretrainedConfig):
-    model_type = "vibevoice_diffusion_head"
-
-    def __init__(
-        self,
-        hidden_size=768,
-        head_layers=4,
-        head_ffn_ratio=3.0,
-        rms_norm_eps=1e-5,
-        latent_size=64,
-        speech_vae_dim=None,
-        prediction_type="v_prediction",
-        diffusion_type="ddpm",
-        ddpm_num_steps=1000,
-        ddpm_num_inference_steps=20,
-        ddpm_beta_schedule="cosine",
-        ddpm_batch_mul=4,
-        **kwargs
-    ):
-        self.hidden_size = hidden_size
-        self.head_layers = head_layers
-        self.head_ffn_ratio = head_ffn_ratio
-        self.rms_norm_eps = rms_norm_eps
-        self.latent_size = latent_size
-        self.speech_vae_dim = speech_vae_dim
-        self.prediction_type = prediction_type
-        self.diffusion_type = diffusion_type
-        self.ddpm_num_steps = ddpm_num_steps
-        self.ddpm_num_inference_steps = ddpm_num_inference_steps
-        self.ddpm_beta_schedule = ddpm_beta_schedule
-        self.ddpm_batch_mul = ddpm_batch_mul
-
-        super().__init__(**kwargs)
-
-class VibeVoiceConfig(PretrainedConfig):
-    model_type = "vibevoice"
-    is_composition = True
-    sub_configs = {
-        "acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
-        "semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
-        "decoder_config": Qwen2Config,
-        "diffusion_head_config": VibeVoiceDiffusionHeadConfig,
-    }
-    # keys_to_ignore_at_inference = ["past_key_values"]
-    # Default tensor parallel plan for base model `Qwen2`
-    base_model_tp_plan = {
-        "layers.*.self_attn.q_proj": "colwise",
-        "layers.*.self_attn.k_proj": "colwise",
-        "layers.*.self_attn.v_proj": "colwise",
-        "layers.*.self_attn.o_proj": "rowwise",
-        "layers.*.mlp.gate_proj": "colwise",
-        "layers.*.mlp.up_proj": "colwise",
-        "layers.*.mlp.down_proj": "rowwise",
-    }
-
-    def __init__(
-        self,
-        acoustic_tokenizer_config=None,
-        semantic_tokenizer_config=None,
-        decoder_config=None,
-        diffusion_head_config=None,
-        **kwargs
-    ):
-
-        # kwargs["_attn_implementation"] = "flash_attention_2"
-        kwargs["_attn_implementation_autoset"] = False
-
-        if acoustic_tokenizer_config is None:
-            self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
-        elif isinstance(acoustic_tokenizer_config, dict):
-            acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer"
-            self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config)
-        elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
-            # If an instance of the config class is provided
-            self.acoustic_tokenizer_config = acoustic_tokenizer_config
-
-        if semantic_tokenizer_config is None:
-            self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
-        elif isinstance(semantic_tokenizer_config, dict):
-            semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer"
-            self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config)
-        elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
-            # If an instance of the config class is provided
-            self.semantic_tokenizer_config = semantic_tokenizer_config
-
-        if decoder_config is None:
-            self.decoder_config = self.sub_configs["decoder_config"]()
-        elif isinstance(decoder_config, dict):
-            # If a dictionary is provided, instantiate the config class with it
-            # self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
-            if decoder_config.get("model_type", '') == "qwen2":
-                self.decoder_config = Qwen2Config(**decoder_config)
-            else:
-                raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
-        elif isinstance(decoder_config, (Qwen2Config,)):
-            # If an instance of the config class is provided
-            self.decoder_config = decoder_config
-
-        if diffusion_head_config is None:
-            self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
-        elif isinstance(diffusion_head_config, dict):
-            diffusion_head_config["model_type"] = "vibevoice_diffusion_head"
-            self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config)
-        elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig):
-            # If an instance of the config class is provided
-            self.diffusion_head_config = diffusion_head_config
-
-        # other parameters
-        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
-        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)
-
-        super().__init__(**kwargs)
-
-__all__ = [
-    "VibeVoiceAcousticTokenizerConfig",
-    "VibeVoiceSemanticTokenizerConfig",
-    "VibeVoiceDiffusionHeadConfig",
-    "VibeVoiceConfig"
-]
src/vibevoice/modular/modeling_vibevoice.py DELETED
@@ -1,487 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Union, Callable
-from tqdm import tqdm
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.distributed as dist
-
-from transformers.models.auto import AutoModel, AutoModelForCausalLM
-
-from transformers.activations import ACT2FN
-from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
-from transformers.models.llama.modeling_llama import LlamaRMSNorm
-from transformers import modeling_utils
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
-
-
-from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
-from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
-from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
-
-from .configuration_vibevoice import VibeVoiceConfig
-
-
-logger = logging.get_logger(__name__)
-
-if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
-    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
-
-@dataclass
-class VibeVoiceCausalLMOutputWithPast(ModelOutput):
-    loss: Optional[torch.FloatTensor] = None
-    diffusion_loss: Optional[torch.FloatTensor] = None
-    speech_token_num: Optional[int] = None
-    logits: torch.FloatTensor = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
-
-
-@dataclass
-class VibeVoiceGenerationOutput(ModelOutput):
-    """
-    Output type for VibeVoice generation.
-
-    Args:
-        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            The generated sequences.
-        speech_outputs (`List[torch.FloatTensor]`, *optional*):
-            List of generated speech waveforms or latents for each speech segment.
-    """
-    sequences: torch.LongTensor = None
-    speech_outputs: Optional[List[torch.FloatTensor]] = None
-
-
-class SpeechConnector(nn.Module):
-    def __init__(self, input_dim, output_dim):
-        super().__init__()
-        self.fc1 = nn.Linear(input_dim, output_dim)
-        self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
-        self.fc2 = nn.Linear(output_dim, output_dim)
-
-    def forward(self, features, **kwargs):
-        x = self.fc1(features)
-        x = self.norm(x)
-        x = self.fc2(x)
-        return x
-
-
-# @auto_docstring
-class VibeVoicePreTrainedModel(PreTrainedModel):
-    config_class = VibeVoiceConfig
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _skip_keys_device_placement = "past_key_values"
-    _supports_cache_class = True
-    _supports_flash_attn_2 = True
-    _supports_sdpa = True
-    _supports_quantized_cache = True
-    _supports_static_cache = True
-    _supports_attention_backend = True
-
-    def _init_weights(self, module):
-        if isinstance(module, VibeVoiceDiffusionHead):
-            module.initialize_weights()
-            return
-
-        # Use the language model's initializer_range if available
-        if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
-            std = self.config.language_model_config.initializer_range
-        elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
-            std = self.config.decoder_config.initializer_range
-        else:
-            std = 0.02 # Default value
-
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.weight.data.fill_(1.0)
-            module.bias.data.zero_()
-
-# @auto_docstring
-class VibeVoiceModel(VibeVoicePreTrainedModel):
-    def __init__(self, config):
-        super().__init__(config)
-
-        if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
-            if isinstance(config.torch_dtype, str):
-                dtype = getattr(torch, config.torch_dtype)
-            else:
-                dtype = config.torch_dtype
-        else:
-            dtype = torch.float32
-
-        # Initialize Qwen2 model for language modeling
-        lm_config = config.decoder_config
-        self.language_model = AutoModel.from_config(lm_config)
-
-        # Initialize speech components if needed
-        self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
-        self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)
-
-        self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
-        self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)
-
-        # Register scaling factors as buffers - use 1D tensors for FSDP compatibility
-        self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
-        self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))
-
-        # Initialize prediction head for speech generation
-        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)
-
-        # Initialize noise scheduler
-        self.noise_scheduler = DPMSolverMultistepScheduler(
-            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
-            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
-            prediction_type=config.diffusion_head_config.prediction_type
-        )
-
-    def get_input_embeddings(self):
-        if hasattr(self.language_model, 'embed_tokens'):
-            # If the language model has an embed_tokens attribute, return it
-            return self.language_model.embed_tokens
-
-        for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed
-            if attr.orig_name == 'embed_tokens.weight':
-                return getattr(self.language_model, name)
-        assert False, 'should not arrive here'
-
-    def set_input_embeddings(self, value):
-        self.language_model.embed_tokens = value
-
-    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
-        """Set the speech tokenizers used for encoding and decoding speech."""
-        self.acoustic_tokenizer = acoustic_tokenizer
-        self.semantic_tokenizer = semantic_tokenizer
-
-        # Reset the encoder to evaluation mode
-        if self.acoustic_tokenizer is not None:
-            self.acoustic_tokenizer.eval()
-
-        if self.semantic_tokenizer is not None:
-            self.semantic_tokenizer.eval()
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # Forward through language model
-        outputs = self.language_model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            cache_position=cache_position,
-            **kwargs,
-        )
-
-        if not return_dict:
-            return outputs
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=outputs.last_hidden_state,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-
-class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
-    _tied_weights_keys = ["lm_head.weight"]
-    _tp_plan = {"lm_head": "colwise_rep"}
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.model = VibeVoiceModel(config)
-        self.vocab_size = config.decoder_config.vocab_size
-        self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)
-
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.model.get_input_embeddings()
-
-    def set_input_embeddings(self, value):
-        self.model.set_input_embeddings(value)
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_decoder(self, decoder):
-        self.model.language_model = decoder
-
-    def get_decoder(self):
-        return self.model.language_model
-
-    def tie_weights(self):
-        """
-        Tie the weights between the input embeddings and the output embeddings.
-        """
-        if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
-            # The standard PreTrainedModel method will handle the tying.
-            # It typically does a simple parameter object assignment, which is
-            # CORRECT to do BEFORE FSDP wraps the model.
-            output_embeddings = self.get_output_embeddings()
-            input_embeddings = self.get_input_embeddings()
-            if hasattr(input_embeddings, 'weight'):
-                output_embeddings.weight = input_embeddings.weight
-            else:
-                # maybe returned input_embeddings a tensor directly
-                output_embeddings.weight = input_embeddings
-
-            if getattr(output_embeddings, "bias", None) is not None:
-                output_embeddings.bias.data = nn.functional.pad(
-                    output_embeddings.bias.data,
-                    (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
-                    "constant",
-                    0,
-                )
-            print("✅ Tied input and output embeddings using standard assignment.")
-        else:
-            print("ℹ️ tie_word_embeddings is False, not tying weights.")
-
-    # Also, ensure set_output_embeddings is safe, though your implementation looks okay.
-    # The key is to avoid calling it after accelerator.prepare().
-    def set_output_embeddings(self, new_embeddings):
-        # Your current implementation using data.copy_ is good practice,
-        # but the best way is to not call this after prepare().
-        self.lm_head = new_embeddings
-
-    def forward_speech_features(
-        self,
-        speech_tensors=None,
-        speech_masks=None,
-        speech_type="audio",
-        return_unmask=False
-    ):
-        if speech_tensors is None:
-            # Use config to get vae_dim instead of non-existent self.args
-            vae_dim = self.config.acoustic_tokenizer_config.vae_dim
-            audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
-            connect_features = self.model.acoustic_connector(audio_features)
-            return audio_features, connect_features
-        else:
-            with torch.no_grad():
-                if speech_type == "audio":
-                    with torch.no_grad():
-                        frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
-                        audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]
-
-                elif speech_type == "vae":
-                    # Use config to get vae_dim instead of non-existent self.args
-                    vae_dim = self.config.acoustic_tokenizer_config.vae_dim
-                    speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)
-
-                    # gaussian sample from the speech_mode
-                    batch_size = speech_mode.size(0)
-                    value = self.model.acoustic_tokenizer.fix_std / 0.8
-                    std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
-                    std = std.view(-1, *[1] * (speech_mode.dim() - 1))
-                    audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
-                else:
-                    raise NotImplementedError(f"Speech type {speech_type} not implemented")
-
-            if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
-                scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
-                bias_factor = -audio_tokens[speech_masks].flatten().mean()
-
-                # Only use distributed operations if the process group is initialized
-                if dist.is_available() and dist.is_initialized():
-                    dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
-                    dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
-                    world_size = dist.get_world_size()
-                    self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
-                    self.model.speech_bias_factor.copy_(bias_factor / world_size)
-                    print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
-                else:
-                    # Single process case
-                    self.model.speech_scaling_factor.copy_(scaling_factor)
-                    self.model.speech_bias_factor.copy_(bias_factor)
-                    print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
-
-            audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor
-
-            connect_features = self.model.acoustic_connector(audio_features)
-            if return_unmask:
-                return audio_features, connect_features
-            return audio_features[speech_masks], connect_features[speech_masks]
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        # New arguments for speech processing and loss calculation
-        speech_tensors: Optional[torch.FloatTensor] = None,
-        speech_masks: Optional[torch.BoolTensor] = None,
-        speeches_loss_input: Optional[torch.FloatTensor] = None,
-        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
-        acoustic_input_mask: Optional[torch.BoolTensor] = None,
-        acoustic_loss_mask: Optional[torch.BoolTensor] = None,
-        ddpm_batch_mul: int = 1,
-        **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
-    ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        x = self.get_input_embeddings()(input_ids)
-
-        semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)
-        if speeches_loss_input is not None:
-            # only part audio need diffuse
-            speech_all_features, speech_all_connect_features = self.forward_speech_features(
-                speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
-                speech_masks=speech_masks,
-                speech_type=kwargs.get("speech_type", "audio"),
-                return_unmask=True
-            )
-            if speech_tensors is not None:
-                if semantic_speech_all_connect_features is not None:
-                    x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + semantic_speech_all_connect_features[speech_masks]
-                else:
-                    x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
-            speech_features = speech_all_features[speeches_loss_input.unsqueeze(-1) & speech_masks] # only part audio need diffuse
-            speech_connect_features = speech_all_connect_features[speeches_loss_input.unsqueeze(-1) & speech_masks]
-        else:
-            speech_features, speech_connect_features = self.forward_speech_features(
-                speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
-                speech_masks=speech_masks,
-                speech_type=kwargs.get("speech_type", "audio"),
-            )
-            if speech_tensors is not None:
-                x[acoustic_input_mask] = speech_connect_features
-
-        outputs = self.model(
-            input_ids=None,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=x,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=False,
-            return_dict=return_dict,
-            cache_position=cache_position,
-        )
-
-        hidden_states = outputs.last_hidden_state
-        logits = self.lm_head(hidden_states)
-        # logits = logits.float()
-
-        loss = None
-        if labels is not None:
-            # The custom CE loss with masking is calculated in the training script.
-            # We leave the standard loss calculation here as None.
-            pass
-
-        # --- Diffusion Loss Calculation ---
-        diffusion_loss = None
-        # This block is executed only if we are in a context that involves speech.
-        if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
-            condition_features = hidden_states[acoustic_loss_mask]
-
-            speech_len, latent_size = speech_features.shape
-
-            noise = torch.randn(
-                (speech_len * ddpm_batch_mul, latent_size),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype
-            )
-
-            timesteps = torch.multinomial(
-                torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
-                speech_len * ddpm_batch_mul,
-                replacement=True,
-            ).to(hidden_states.device)
-
-            speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
-            condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)
-
-            noisy_speech_features = self.model.noise_scheduler.add_noise(
-                speech_features_repeated, noise, timesteps
-            )
-
-            model_output = self.model.prediction_head(
-                noisy_speech_features,
-                timesteps.type_as(x),
-                condition_features_repeated
-            )
-
-            prediction_type = self.config.diffusion_head_config.prediction_type
-            if prediction_type == "epsilon":
-                target_for_loss = noise
-            elif prediction_type == "v_prediction":
-                target_for_loss = self.model.noise_scheduler.get_velocity(
-                    speech_features_repeated, noise, timesteps
-                )
-            else:
-                raise NotImplementedError(f"Prediction type {prediction_type} not implemented")
-
-            diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
-            if latent_size > 0 and ddpm_batch_mul > 0:
-                diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
-            else:
-                diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
-
-        else:
-            # Dummy loss for DDP to work when there are no speech samples in a batch,
-            # but we are in a speech context.
-            diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
-            diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
-            diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
-        # --- End Diffusion Loss Calculation ---
-
-        if not return_dict:
-            output = (logits, speech_len) + outputs.to_tuple()[1:]
-            return (loss, diffusion_loss) + output
-
-        return VibeVoiceCausalLMOutputWithPast(
-            loss=loss,
-            diffusion_loss=diffusion_loss,
-            speech_token_num=speech_len if speech_tensors is not None else 0,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
-AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)
-
-__all__ = [
-    "VibeVoiceModel",
-    "VibeVoicePreTrainedModel",
-    "VibeVoiceForConditionalGeneration",
-    "VibeVoiceCausalLMOutputWithPast",
-    "VibeVoiceGenerationOutput",
-]
src/vibevoice/modular/modeling_vibevoice_inference.py DELETED
@@ -1,716 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Union, Callable
-from tqdm import tqdm
-import torch
-import torch.nn as nn
-
-from transformers.models.auto import AutoModel, AutoModelForCausalLM
-
-from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
-from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
-from transformers import modeling_utils
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
-
-
-# from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
-from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceTokenizerEncoderOutput
-from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
-from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
-
-from .configuration_vibevoice import VibeVoiceConfig
-
-from .modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizer, VibeVoiceTextTokenizerFast
-
-from .modeling_vibevoice import VibeVoiceModel, VibeVoicePreTrainedModel
-from .streamer import AudioStreamer, AsyncAudioStreamer
-
-logger = logging.get_logger(__name__)
-
-if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
-    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
-
-@dataclass
-class VibeVoiceCausalLMOutputWithPast(BaseModelOutputWithPast):
-    logits: Optional[torch.FloatTensor] = None
-
-@dataclass
-class VibeVoiceGenerationOutput(ModelOutput):
-    """
-    Output type for VibeVoice generation.
-
-    Args:
-        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            The generated sequences.
-        speech_outputs (`List[torch.FloatTensor]`, *optional*):
-            List of generated speech waveforms or latents for each speech segment.
-    """
-    sequences: torch.LongTensor = None
-    speech_outputs: Optional[List[torch.FloatTensor]] = None
-    reach_max_step_sample: Optional[torch.BoolTensor] = None
-
-class VibeVoiceTokenConstraintProcessor(LogitsProcessor):
-    """Constrains token generation to only valid tokens during speech generation."""
-
-    def __init__(self, valid_token_ids: List[int], device: torch.device = None):
-        self.valid_token_ids = torch.tensor(valid_token_ids, dtype=torch.long, device=device)
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        # Create a mask for valid tokens
-        mask = torch.full_like(scores, float('-inf'))
-        mask[:, self.valid_token_ids] = 0
-
-        # Apply mask to scores
-        scores = scores + mask
-        return scores
-
-class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["lm_head.weight"]
-    _tp_plan = {"lm_head": "colwise_rep"}
-
-    def __init__(self, config):
-        super().__init__(config)
-
-        # Initialize the base model
-        self.model = VibeVoiceModel(config)
-
-        # LM head for text generation
-        self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.decoder_config.vocab_size, bias=False)
-
-        # inference configuration
-        self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    @property
-    def noise_scheduler(self):
-        return self.model.noise_scheduler
-
-    @property
-    def prediction_head(self):
-        return self.model.prediction_head
-
-    @property
-    def speech_scaling_factor(self):
-        return self.model.speech_scaling_factor
-
-    @property
-    def speech_bias_factor(self):
-        return self.model.speech_bias_factor
-
-    @property
-    def acoustic_tokenizer(self):
-        return self.model.acoustic_tokenizer
-
-    @property
-    def semantic_tokenizer(self):
-        return self.model.semantic_tokenizer
-
-    @property
-    def acoustic_connector(self):
-        return self.model.acoustic_connector
-
-    @property
-    def semantic_connector(self):
-        return self.model.semantic_connector
-
-    def tie_weights(self):
-        """
-        Tie the weights between the input embeddings and the output embeddings.
-        """
-        # Tie lm_head.weight to language_model.embed_tokens.weight
-        if not getattr(self.config, 'tie_word_embeddings', False):
-            return
-
-        if hasattr(self, 'lm_head') and hasattr(self.model.language_model, 'embed_tokens'):
-            self.lm_head.weight = self.model.language_model.embed_tokens.weight
-
-    def get_input_embeddings(self):
-        return self.model.get_input_embeddings()
-
-    def set_input_embeddings(self, value):
-        self.model.set_input_embeddings(value)
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
-        """Set the speech tokenizers used for encoding and decoding speech."""
-        self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer)
-
-    def set_ddpm_inference_steps(self, num_steps=None):
-        self.ddpm_inference_steps = num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
-
-    def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"):
-        """Process speech inputs through tokenizers and connectors."""
-        with torch.no_grad():
-            if speech_type == "audio":
-                # Encode audio to acoustic latents
-                encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
-                acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
-
-                # Apply scaling and bias
-                acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
-
-                # Connect to language model space
-                acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
-
-                return acoustic_features, acoustic_connected
-            elif speech_type == "pt":
-                encoder_output = VibeVoiceTokenizerEncoderOutput(mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std)
-                acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
-
-                # Apply scaling and bias
-                acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
-
-                # Connect to language model space
-                acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
-
-                return acoustic_features, acoustic_connected
-            else:
-                raise NotImplementedError(f"Speech type {speech_type} not implemented")
-
-    # @can_return_tuple
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        speech_tensors: Optional[torch.FloatTensor] = None,
-        speech_masks: Optional[torch.BoolTensor] = None,
-        speech_input_mask: Optional[torch.BoolTensor] = None,
-        logits_to_keep: Union[int, slice] = 0,
-        **kwargs,
-    ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
-        """
-        Args:
-            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-            speech_tensors (`torch.FloatTensor`, *optional*):
-                Input speech waveforms for voice cloning or speech understanding.
-            speech_masks (`torch.BoolTensor`, *optional*):
-                Masks indicating valid speech frames.
-            speech_input_mask (`torch.BoolTensor`, *optional*):
-                Positions in the input sequence where speech embeddings should be inserted.
-
-        Returns:
-            `VibeVoiceCausalLMOutputWithPast` or tuple
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # Get embeddings
-        if inputs_embeds is None:
-            inputs_embeds = self.model.get_input_embeddings()(input_ids)
-
-        # Process speech inputs if provided
-        if speech_tensors is not None and speech_masks is not None:
-            # Ensure speech tensors match model's data type
-            speech_tensors = speech_tensors.to(self.dtype)
-            acoustic_features, speech_embeds = self._process_speech_inputs(speech_tensors, speech_masks)
-            if speech_input_mask is not None:
-                inputs_embeds[speech_input_mask] = speech_embeds
-
-        outputs = self.model(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            cache_position=cache_position,
-            **kwargs,
-        )
-
-        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
-        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
-
-        if labels is not None:
-            raise NotImplementedError("Loss computation is not implemented in this version.")
-
-        return VibeVoiceCausalLMOutputWithPast(
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            last_hidden_state=hidden_states,
-            attentions=outputs.attentions,
-        )
-
-    def _build_generate_config_model_kwargs(self, generation_config, inputs, tokenizer, return_processors=False, **kwargs):
-        if generation_config is None:
-            generation_config = GenerationConfig(
-                bos_token_id=tokenizer.bos_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                pad_token_id=tokenizer.pad_token_id
-            )
-        else:
-            generation_config = GenerationConfig(
-                **generation_config,
-                bos_token_id=tokenizer.bos_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                pad_token_id=tokenizer.pad_token_id
-            )
-
-        generation_config, model_kwargs = self._prepare_generation_config(
-            generation_config,
-            use_cache=True,
-            speech_start_id=tokenizer.speech_start_id,
-            speech_end_id=tokenizer.speech_end_id,
-            speech_diffusion_id=tokenizer.speech_diffusion_id,
-            **kwargs
-        )
-        generation_config.speech_start_id = tokenizer.speech_start_id
-        generation_config.speech_end_id = tokenizer.speech_end_id
-        generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
-
-        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)
-        batch_size = inputs_tensor.shape[0]
-        device = self.device
-
-        self._prepare_special_tokens(generation_config, True, device=device)
-        generation_config.use_cache = True
-        model_kwargs["use_cache"] = generation_config.use_cache
-        input_ids = inputs_tensor.to(self.device)
-
-        input_ids_length = input_ids.shape[1]
-        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
-        generation_config = self._prepare_generated_length(
-            generation_config=generation_config,
-            has_default_max_length=has_default_max_length,
-            has_default_min_length=has_default_min_length,
-            model_input_name=model_input_name,
-            inputs_tensor=inputs_tensor,
-            input_ids_length=input_ids_length,
-        )
-
-        max_cache_length = generation_config.max_length - 1
-        self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
-        model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
-        for k, v in model_kwargs.items():
-            if isinstance(v, torch.Tensor):
-                model_kwargs[k] = v.to(device=device)
-
-        if return_processors:
-            logits_processor = self._get_logits_processor(
-                generation_config=generation_config,
-                input_ids_seq_length=input_ids_length,
-                encoder_input_ids=inputs_tensor,
-                prefix_allowed_tokens_fn=None,
-                logits_processor=LogitsProcessorList(),
-                device=inputs_tensor.device,
-                model_kwargs=model_kwargs,
-            )
-
-            stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=StoppingCriteriaList())
-
-            return generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria
-        else:
-            return generation_config, model_kwargs, input_ids
-
-    @torch.no_grad()
-    def generate(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        generation_config: Optional[GenerationConfig] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
-        synced_gpus: Optional[bool] = None,
-        assistant_model: Optional["PreTrainedModel"] = None,
-        audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
-        negative_prompt_ids: Optional[torch.Tensor] = None,
-        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
-        speech_tensors: Optional[torch.FloatTensor] = None,
-        speech_masks: Optional[torch.BoolTensor] = None,
-        speech_input_mask: Optional[torch.BoolTensor] = None,
-        return_speech: bool = True,
-        cfg_scale: float = 1.0,
-        stop_check_fn: Optional[Callable[[], bool]] = None,
-        **kwargs,
-    ) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
-        """
-        Generates sequences of token ids and optionally speech outputs.
-
-        Args:
-            All standard generation arguments from GenerationMixin
-            negative_prompt_ids: Negative prompt for CFG in speech generation
-            negative_prompt_attention_mask: Attention mask for negative prompt
-            speech_tensors: Input speech for voice cloning
-            speech_masks: Masks for speech tensors
-            speech_input_mask: Positions to insert speech embeddings
-            return_speech: Whether to decode and return speech outputs
-            cfg_scale: CFG scale for speech generation
-            stop_check_fn: Optional callable that returns True if generation should stop
-
-        Returns:
-            Generated token sequences and optionally speech outputs
-        """
-        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
-        parsed_scripts = kwargs.pop("parsed_scripts", None)
-        all_speakers_list = kwargs.pop("all_speakers_list", None)
-        max_length_times = kwargs.pop("max_length_times", 2)
-
-        if kwargs.get('max_new_tokens', None) is None:
-            kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - kwargs['input_ids'].shape[-1]
-
-        generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs(
-            generation_config, inputs, tokenizer, return_processors=True, **kwargs
-        )
-
-        negative_kwargs = {
-            'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), tokenizer.speech_start_id, dtype=torch.long, device=kwargs['input_ids'].device),
-            'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device),
-            'max_new_tokens': kwargs.get('max_new_tokens', 100)
-        }
-        negative_generation_config, negative_model_kwargs, negative_input_ids = self._build_generate_config_model_kwargs(
-            None, None, tokenizer, return_processors=False, **negative_kwargs
-        )
-
-        acoustic_cache = VibeVoiceTokenizerStreamingCache()
-        semantic_cache = VibeVoiceTokenizerStreamingCache()
-
-        batch_size = input_ids.shape[0]
-        device = input_ids.device
-        finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
-        correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
-        is_prefill = True
-        inputs_embeds = None
-        verbose = kwargs.get("verbose", False)
-
-        # Initialize audio chunks storage for each sample
-        audio_chunks = [[] for _ in range(batch_size)]
-
-        initial_length = input_ids.shape[-1]
-        initial_length_per_sample = model_kwargs['attention_mask'].sum(dim=-1)
-
-        # Define all valid tokens that can be generated
-        valid_tokens = [
-            generation_config.speech_start_id,
-            generation_config.speech_end_id,
-            generation_config.speech_diffusion_id,
-            generation_config.eos_token_id
-        ]
-        # Add bos_token_id if it exists
-        if hasattr(generation_config, 'bos_token_id') and generation_config.bos_token_id is not None:
-            valid_tokens.append(generation_config.bos_token_id)
-
-        # Add custom processor to constrain token generation
-        token_constraint_processor = VibeVoiceTokenConstraintProcessor(valid_tokens, device=device)
-        if logits_processor is None:
-            logits_processor = LogitsProcessorList()
-        logits_processor.append(token_constraint_processor)
-
-        max_steps = min(generation_config.max_length - initial_length, int(max_length_times * initial_length))
-        max_step_per_sample = torch.min(generation_config.max_length - initial_length_per_sample, (max_length_times * initial_length_per_sample).long())
-        reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
-
-        # Create progress iterator if verbose
-        if kwargs.get("show_progress_bar", True):
-            progress_bar = tqdm(range(max_steps), desc="Generating", leave=False)
-        else:
-            progress_bar = range(max_steps)
-
-        for step in progress_bar:
-            # Check for external stop signal
-            if stop_check_fn is not None and stop_check_fn():
-                if verbose:
-                    print(f"Generation stopped externally at step {step + 1}")
-                # End the audio streamer if it exists
-                if audio_streamer is not None:
-                    audio_streamer.end()
-                break
-
-            # Check if audio_streamer has been ended (stopped externally)
-            if audio_streamer is not None and hasattr(audio_streamer, 'finished_flags'):
-                if any(audio_streamer.finished_flags):
-                    if verbose:
-                        print(f"Audio generation stopped externally at step {step + 1}")
-                    break
-
-            if finished_tags.all():
-                if hasattr(progress_bar, 'set_description'):
-                    progress_bar.set_description("Generation complete")
-                break
-
-            if input_ids.shape[-1] >= generation_config.max_length:
-                print(f"Reached maximum generation length {generation_config.max_length}, stopped it.")
-                reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
-                if reached_samples.numel() > 0:
-                    reach_max_step_sample[reached_samples] = True
-                break
-
-            # Update progress bar description with active samples
-            if hasattr(progress_bar, 'set_description'):
-                active_samples = (~finished_tags).sum().item()
-                progress_bar.set_description(f"Generating (active: {active_samples}/{batch_size})")
-
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-            if is_prefill:
-                # we process the speech inputs only during the first generation step
-                prefill_inputs = {
-                    "speech_tensors": speech_tensors.to(device=device),
-                    "speech_masks": speech_masks.to(device),
-                    "speech_input_mask": speech_input_mask.to(device),
-                }
-                is_prefill = False
-            else:
-                _ = model_inputs.pop('inputs_embeds', None)
-                prefill_inputs = {'inputs_embeds': inputs_embeds}
-
-            # Forward pass through the model
-            outputs = self(
-                **model_inputs, **prefill_inputs, logits_to_keep=1, return_dict=True, output_attentions=False, output_hidden_states=False,
-            )
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=False,
-            )
-
-            # Get logits and apply logits processor
-            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
-            # next_token_logits = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device)
-            next_token_scores = logits_processor(input_ids, next_token_logits)
-
-            # token selection
-            if generation_config.do_sample:
-                probs = nn.functional.softmax(next_token_scores, dim=-1)
-                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
-                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-            else:
-                next_tokens = torch.argmax(next_token_scores, dim=-1)
-
-            next_tokens[finished_tags] = generation_config.eos_token_id
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            if not kwargs.get('refresh_negative', True):
-                negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
-                # Forward negative pass through the model
-                if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
-                    negative_model_inputs['inputs_embeds'] = inputs_embeds
-                    negative_model_inputs['input_ids'] = None
-
-                negative_outputs = self(
-                    **negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
-                )
-                negative_model_kwargs = self._update_model_kwargs_for_generation(
-                    negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
-                )
-                negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
-
-            # reached end of generation
-            if (next_tokens == generation_config.eos_token_id).any():
-                eos_indices = (next_tokens == generation_config.eos_token_id).nonzero(as_tuple=False).squeeze(1)
-                # Only print for samples that are newly finished (not already marked as finished)
-                new_eos_indices = eos_indices[~finished_tags[eos_indices]]
-                if new_eos_indices.numel() > 0:
-                    finished_tags[new_eos_indices] = True
-                    if verbose:
-                        print(f"Samples {new_eos_indices.tolist()} reached EOS token at step {step + 1}.", flush=True)
-                    if audio_streamer is not None:
-                        audio_streamer.end(new_eos_indices)
-
-            # Check if any sample reached its maximum generation length
-            max_length_reached = step >= max_step_per_sample
-            new_max_length_indices = torch.nonzero(max_length_reached & ~finished_tags, as_tuple=False).squeeze(1)
-            if new_max_length_indices.numel() > 0:
-                finished_tags[new_max_length_indices] = True
-                reach_max_step_sample[new_max_length_indices] = True
-                if verbose:
-                    print(f"Samples {new_max_length_indices.tolist()} reached max generation length at step {step + 1}.", flush=True)
-                if audio_streamer is not None:
-                    audio_streamer.end(new_max_length_indices)
-
-            # speech_end
-            diffusion_end_indices = (next_tokens == generation_config.speech_end_id).nonzero(as_tuple=False).squeeze(1)
-            if diffusion_end_indices.numel() > 0:
-                # Clear tokenizer caches for samples that reached speech end
-                acoustic_cache.set_to_zero(diffusion_end_indices)
-                semantic_cache.set_to_zero(diffusion_end_indices)
-
-            # speech_begin
-            diffusion_start_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_start_id)]
-            if diffusion_start_indices.numel() > 0 and kwargs.get('refresh_negative', True):
-                # update attention mask
-                for i, sample_idx in enumerate(diffusion_start_indices.tolist()):
-                    negative_model_kwargs['attention_mask'][sample_idx, :] = 0
-                    negative_model_kwargs['attention_mask'][sample_idx, -1] = 1
-                # update past key values
-                for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
-                                                                   negative_model_kwargs['past_key_values'].value_cache)):
-                    # Process each non-diffusion sample
-                    for sample_idx in diffusion_start_indices.tolist():
-                        # Shift cache for this sample
-                        k_cache[sample_idx, :, -1, :] = k_cache[sample_idx, :, 0, :].clone()
-                        v_cache[sample_idx, :, -1, :] = v_cache[sample_idx, :, 0, :].clone()
-                # update negative_input_ids
-                for sample_idx in diffusion_start_indices.tolist():
-                    negative_input_ids[sample_idx, -1] = generation_config.speech_start_id
-
-            # Prepare inputs_embeds for next iteration
-            # Initialize with default embeddings for all tokens
-            next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1)  # [batch_size, 1, hidden_size]
-
-            # forward diffusion
-            # Diffusion indices are those that are not finished and not special tokens
-            diffusion_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_diffusion_id)]
-
-            if diffusion_indices.numel() > 0:
-                if kwargs.get('refresh_negative', True):
-                    negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
-                    # Forward negative pass through the model
-                    if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
-                        negative_model_inputs['inputs_embeds'] = inputs_embeds
-                        negative_model_inputs['input_ids'] = None
-
-                    negative_outputs = self(
-                        **negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
-                    )
-                    negative_model_kwargs = self._update_model_kwargs_for_generation(
-                        negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
-                    )
-                    negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
-                    # correct the non-diffusion indices
-                    # we forward all samples' negative outputs even if
-                    # they are not in diffusion mode to keep the cache consistent
-                    # So we need to correct the kv cache of non-diffusion samples
-                    non_diffusion_mask = ~finished_tags & (next_tokens != generation_config.speech_diffusion_id)
-                    if non_diffusion_mask.any():
-                        non_diffusion_indices = torch.arange(batch_size, device=device)[non_diffusion_mask]
-                        start_indices = correct_cnt[non_diffusion_indices]
-
-                        # 1. Update attention_mask - need to handle each sample separately
-                        seq_len = negative_model_kwargs['attention_mask'].shape[1]
-                        for i, (sample_idx, start_idx) in enumerate(zip(non_diffusion_indices.tolist(), start_indices.tolist())):
-                            # Shift the attention mask for this sample
-                            if start_idx + 1 < seq_len - 1:
-                                negative_model_kwargs['attention_mask'][sample_idx, start_idx+1:] = \
-                                    negative_model_kwargs['attention_mask'][sample_idx, start_idx:-1].clone()
-                            negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0
-
-                        # 2. Update past_key_values
-                        for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
-                                                                           negative_model_kwargs['past_key_values'].value_cache)):
-                            # Process each non-diffusion sample
-                            for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
-                                if start_idx + 1 < k_cache.shape[2] - 1:
-                                    # Shift cache for this sample
-                                    k_cache[sample_idx, :, start_idx+1:, :] = k_cache[sample_idx, :, start_idx:-1, :].clone()
-                                    v_cache[sample_idx, :, start_idx+1:, :] = v_cache[sample_idx, :, start_idx:-1, :].clone()
-
-                        # 3. Update negative_input_ids
-                        for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
-                            if start_idx + 1 < negative_input_ids.shape[1] - 1:
-                                negative_input_ids[sample_idx, start_idx+1:] = \
-                                    negative_input_ids[sample_idx, start_idx:-1].clone()
-
-                        correct_cnt[non_diffusion_indices] += 1
-
-                positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
-                negative_condition = negative_outputs.last_hidden_state[diffusion_indices, -1, :]
-
-                speech_latent = self.sample_speech_tokens(
-                    positive_condition,
-                    negative_condition,
-                    cfg_scale=cfg_scale,
-                ).unsqueeze(1)
-
-                # Decode acoustic latent to audio using acoustic streaming cache
-                scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device)
-                audio_chunk = self.model.acoustic_tokenizer.decode(
-                    scaled_latent.to(self.model.acoustic_tokenizer.device),
-                    cache=acoustic_cache,  # Use acoustic-specific cache
-                    sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
-                    use_cache=True,
-                    debug=False
-                )
-
-                # Store audio chunks for each sample
-                for i, sample_idx in enumerate(diffusion_indices):
-                    idx = sample_idx.item()
-                    # Only append audio chunk if the sample is not finished
-                    if not finished_tags[idx]:
-                        audio_chunks[idx].append(audio_chunk[i])
-
-                # Add streaming support here
-                if audio_streamer is not None:
-                    # Stream the audio chunks immediately
-                    audio_streamer.put(audio_chunk, diffusion_indices)
-
-                # Encode audio to semantic features using semantic streaming cache
-                semantic_features = self.model.semantic_tokenizer.encode(
-                    audio_chunk,
-                    cache=semantic_cache,  # Use semantic-specific cache
-                    sample_indices=diffusion_indices,
-                    use_cache=True,
-                    debug=False
-                ).mean  # semantic tokenizer has no VAE.
-
-                # Combine acoustic and semantic features for next input
-                acoustic_embed = self.model.acoustic_connector(speech_latent)
-                semantic_embed = self.model.semantic_connector(semantic_features)
-                diffusion_embeds = acoustic_embed + semantic_embed
-
-                # Update embeddings for diffusion indices
-                next_inputs_embeds[diffusion_indices] = diffusion_embeds
-
-            # Set inputs_embeds for next iteration
-            inputs_embeds = next_inputs_embeds
-
-        if audio_streamer is not None:
-            audio_streamer.end()
-
-        # Concatenate audio chunks for each sample
-        final_audio_outputs = []
-        for sample_chunks in audio_chunks:
-            if sample_chunks:
-                # Concatenate all chunks along the time dimension (assumed to be the last dimension)
-                concatenated_audio = torch.cat(sample_chunks, dim=-1)
-                final_audio_outputs.append(concatenated_audio)
-            else:
-                # If no audio was generated for this sample, append None
-                final_audio_outputs.append(None)
-
-        return VibeVoiceGenerationOutput(
-            sequences=input_ids,
-            speech_outputs=final_audio_outputs if return_speech else None,
-            reach_max_step_sample=reach_max_step_sample,
-        )
-
-    @torch.no_grad()
-    def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
-        self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps)
-        condition = torch.cat([condition, neg_condition], dim=0).to(self.model.prediction_head.device)
-        speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(condition)
-        for t in self.model.noise_scheduler.timesteps:
-            half = speech[: len(speech) // 2]
-            combined = torch.cat([half, half], dim=0)
-            eps = self.model.prediction_head(combined, t.repeat(combined.shape[0]).to(combined), condition=condition)
-            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
-            half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
-            eps = torch.cat([half_eps, half_eps], dim=0)
-            speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample
-        return speech[: len(speech) // 2]
-
-
-AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGenerationInference)
-
-__all__ = [
    "VibeVoiceForConditionalGenerationInference",
-]
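
The `sample_speech_tokens` method removed above is a compact example of classifier-free guidance (CFG): the positive and negative acoustic conditions are stacked into one batch, the diffusion head denoises the same latents under both, and the two noise estimates are blended as `uncond + cfg_scale * (cond - uncond)` before each scheduler step. Below is a minimal, self-contained sketch of just that update; the lambda denoiser is a toy stand-in, not the deleted prediction head or scheduler.

```python
import torch

def cfg_denoise_step(eps_model, noisy, t, cond, neg_cond, cfg_scale=3.0):
    """One classifier-free-guidance step: run the denoiser on the same latents
    under both conditions, then extrapolate from the unconditional prediction
    toward the conditional one."""
    batch = torch.cat([noisy, noisy], dim=0)                 # duplicate the latents
    conds = torch.cat([cond, neg_cond], dim=0)               # [positive; negative]
    eps = eps_model(batch, t, conds)                         # one fused forward pass
    cond_eps, uncond_eps = eps.chunk(2, dim=0)
    return uncond_eps + cfg_scale * (cond_eps - uncond_eps)  # guided noise estimate

# Toy denoiser standing in for the diffusion head (ignores t for brevity).
eps_model = lambda x, t, c: x * 0.1 + c * 0.01
latents = torch.randn(2, 64)
cond, neg_cond = torch.randn(2, 64), torch.zeros(2, 64)
guided = cfg_denoise_step(eps_model, latents, torch.tensor(10), cond, neg_cond)
print(guided.shape)  # torch.Size([2, 64])
```

With `cfg_scale=1.0` the update collapses to the conditional prediction, matching the default `cfg_scale=1.0` in `generate()` above; larger values push each sample harder toward the positive condition.
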
src/vibevoice/modular/modular_vibevoice_diffusion_head.py DELETED
@@ -1,287 +0,0 @@
-import math
-from typing import Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from transformers.models.auto import AutoModel
-from transformers.modeling_utils import PreTrainedModel
-# from transformers.modeling_layers import GradientCheckpointingLayer
-from transformers.activations import ACT2FN
-from transformers.utils import logging
-
-from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
-
-
-logger = logging.get_logger(__name__)
-
-
-class RMSNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
-        super().__init__()
-        self.dim = dim
-        self.eps = eps
-        self.elementwise_affine = elementwise_affine
-        if self.elementwise_affine:
-            self.weight = nn.Parameter(torch.ones(dim))
-        else:
-            self.register_parameter('weight', None)
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        if self.weight is not None:
-            output = output * self.weight
-        return output
-
-    def extra_repr(self) -> str:
-        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
-
-def modulate(x, shift, scale):
-    """Apply modulation to input tensor."""
-    return x * (1 + scale) + shift
-
-
-class TimestepEmbedder(nn.Module):
-    """
-    Embeds scalar timesteps into vector representations.
-
-    Args:
-        hidden_size (`int`): Size of the output embedding
-        frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
-    """
-    def __init__(self, hidden_size, frequency_embedding_size=256):
-        super().__init__()
-        self.mlp = nn.Sequential(
-            nn.Linear(frequency_embedding_size, hidden_size, bias=False),
-            # nn.SiLU(),
-            ACT2FN['silu'],
-            nn.Linear(hidden_size, hidden_size, bias=False),
-        )
-        self.frequency_embedding_size = frequency_embedding_size
-
-    @staticmethod
-    def timestep_embedding(t, dim, max_period=10000):
-        """
-        Create sinusoidal timestep embeddings.
-
-        Args:
-            t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
-                These may be fractional.
-            dim (`int`): The dimension of the output.
-            max_period (`int`, optional): Controls the minimum frequency of the embeddings.
-
-        Returns:
-            `torch.Tensor`: An [N, D] Tensor of positional embeddings.
-        """
-        half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-        ).to(t.device)
-        args = t[:, None].float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        if dim % 2:
-            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-        return embedding.to(t.dtype)
-
-    def forward(self, t):
-        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
-        t_emb = self.mlp(t_freq)
-        return t_emb
-
-
-class FeedForwardNetwork(nn.Module):
-    """
-    Standard feed-forward network with SwiGLU activation.
-
-    Args:
-        embed_dim (`int`): Input dimension
-        ffn_dim (`int`): Hidden dimension
-    """
-    def __init__(
-        self,
-        embed_dim,
-        ffn_dim,
-    ):
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
-        self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
-        self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
-        self.act_fn = ACT2FN['silu']  # Using SiLU as the activation function
-
-    def forward(self, x):
-        gate = self.gate_proj(x)
-        up = self.up_proj(x)
-
-        # SwiGLU activation
-        # gate = F.silu(gate)
-        gate = self.act_fn(gate)
-        return self.down_proj(gate * up)
-
-
-class HeadLayer(nn.Module):
-    """
-    A layer in the diffusion head.
-
-    Args:
-        embed_dim (`int`): Input dimension
-        ffn_dim (`int`): Hidden dimension
-        cond_dim (`int`): Condition embedding dimension
-        norm_eps (`float`, optional): Epsilon for normalization
-    """
-    def __init__(
-        self,
-        embed_dim,
-        ffn_dim,
-        cond_dim,
-        norm_eps=1e-5,
-    ):
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.cond_dim = cond_dim
-        self.ffn_dim = ffn_dim
-        self.ffn = FeedForwardNetwork(
-            self.embed_dim,
-            self.ffn_dim,
-        )
-        self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
-        self.adaLN_modulation = nn.Sequential(
-            # nn.SiLU(),
-            ACT2FN['silu'],
-            nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
-        )
-
-    def forward(self, x, c):
-        shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
-        x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
-        return x
-
-
-class FinalLayer(nn.Module):
-    """
-    Final layer in the diffusion head.
-
-    Args:
-        hidden_size (`int`): Input dimension
-        output_size (`int`): Output dimension
-        cond_size (`int`): Condition embedding dimension
-        norm_eps (`float`, optional): Epsilon for normalization
-    """
-    def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
-        super().__init__()
-        self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
-        self.linear = nn.Linear(hidden_size, output_size, bias=False)
-        self.adaLN_modulation = nn.Sequential(
-            # nn.SiLU(),
-            ACT2FN['silu'],
-            nn.Linear(cond_size, 2 * hidden_size, bias=False)
-        )
-
-    def forward(self, x, c):
-        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
-        x = modulate(self.norm_final(x), shift, scale)
-        x = self.linear(x)
-        return x
-
-
-class VibeVoiceDiffusionHead(PreTrainedModel):
-    """
-    Diffusion head model for vibevoice.
-
-    Args:
-        config (`VibeVoiceDiffusionHeadConfig`): Model configuration
-        latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`.
-    """
-    config_class = VibeVoiceDiffusionHeadConfig
-    supports_gradient_checkpointing = True
-    _supports_flash_attn_2 = True
-    _supports_sdpa = True
-
-    def __init__(
-        self,
-        config,
-    ):
-        super().__init__(config)
-        self.config = config
-        self.cond_dim = config.hidden_size
-        latent_size = config.latent_size
-
-        self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
-        self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
-        self.t_embedder = TimestepEmbedder(self.cond_dim)
-
-        ffn_dim = int(config.hidden_size * config.head_ffn_ratio)
-
-        # Create the intermediate layers
-        self.layers = nn.ModuleList([
-            HeadLayer(
-                embed_dim=config.hidden_size,
-                ffn_dim=ffn_dim,
-                cond_dim=self.cond_dim,
-                norm_eps=config.rms_norm_eps
-            )
-            for _ in range(config.head_layers)
-        ])
-
-        # Final layer for output
-        self.final_layer = FinalLayer(
-            hidden_size=config.hidden_size,
-            output_size=latent_size,
-            cond_size=self.cond_dim,
-            norm_eps=config.rms_norm_eps
-        )
-
-        self.initialize_weights()
-
-    def initialize_weights(self):
-        """Initialize the weights of the model."""
-        # Initialize timestep embedder
-        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
-        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
-        # Zero-out adaLN modulation layers
-        for layer in self.layers:
-            nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)
-
-        # Zero-out output layers
-        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
-        nn.init.constant_(self.final_layer.linear.weight, 0)
-
-    def forward(
-        self,
-        noisy_images,
-        timesteps,
-        condition,
-    ):
-        """
-        Forward pass of the prediction head.
-
-        Args:
-            noisy_images (`torch.Tensor`): Noisy images/latents to denoise
-            timesteps (`torch.Tensor`): Timesteps for diffusion
-            condition (`torch.Tensor`): Conditioning information
-
-        Returns:
-            `torch.Tensor`: The predicted noise/velocity
-        """
-        x = self.noisy_images_proj(noisy_images)
-        t = self.t_embedder(timesteps)
-        condition = self.cond_proj(condition)
-        c = condition + t
-
-        for layer in self.layers:
-            x = layer(x, c)
-
-        x = self.final_layer(x, c)
-        return x
-
-
-AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)
-
-__all__ = [
-    "VibeVoiceDiffusionHead",
-]
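
Two details of the deleted head are worth noting. First, `TimestepEmbedder` encodes the scalar diffusion step with fixed sinusoids at geometrically spaced frequencies before passing it through the MLP. Second, `initialize_weights` zeroes every `adaLN_modulation` output and the final projection, so at initialization each layer's shift, scale, and gate are zero and every `HeadLayer` reduces to the identity (`x + 0 * ffn(...)`), the usual DiT-style trick for stable training. A small runnable sketch of the embedding, assuming nothing beyond `torch` (even `dim` assumed for brevity):

```python
import math
import torch

def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    """Sinusoidal embedding as in the deleted TimestepEmbedder: half the
    channels carry cosines, half sines, over geometrically spaced frequencies."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]      # [N, half]: timestep x frequency
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = timestep_embedding(torch.tensor([0, 250, 999]), dim=256)
print(emb.shape)       # torch.Size([3, 256])
print(emb[:, 0])       # cos(t) at the highest frequency (freqs[0] == 1)
```
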
src/vibevoice/modular/modular_vibevoice_text_tokenizer.py DELETED
@@ -1,214 +0,0 @@
-"""Tokenization classes for vibevoice."""
-
-from typing import List, Optional, Union
-
-from transformers.utils import logging
-from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
-from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
-
-logger = logging.get_logger(__name__)
-
-
-class VibeVoiceTextTokenizer(Qwen2Tokenizer):
-    """
-    Construct a VibeVoice tokenizer. Based on the Qwen2 tokenizer with additional special tokens for speech.
-
-    Args:
-        vocab_file (`str`):
-            Path to the vocabulary file.
-        merges_file (`str`):
-            Path to the merges file.
-        errors (`str`, *optional*, defaults to `"replace"`):
-            Paradigm to follow when decoding bytes to UTF-8.
-        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
-            The unknown token.
-        bos_token (`str`, *optional*):
-            The beginning of sequence token. Not used for vibevoice.
-        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
-            The end of sequence token.
-        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
-            The token used for padding.
-        add_special_tokens (`bool`, *optional*, defaults to `True`):
-            Whether or not to add special tokens when encoding.
-    """
-
-    model_input_names = ["input_ids", "attention_mask"]
-
-    def __init__(
-        self,
-        vocab_file,
-        merges_file,
-        errors="replace",
-        unk_token="<|endoftext|>",
-        bos_token=None,
-        eos_token="<|endoftext|>",
-        pad_token="<|endoftext|>",
-        add_prefix_space=False,
-        add_special_tokens=True,
-        **kwargs,
-    ):
-        super().__init__(
-            vocab_file=vocab_file,
-            merges_file=merges_file,
-            errors=errors,
-            unk_token=unk_token,
-            bos_token=bos_token,
-            eos_token=eos_token,
-            pad_token=pad_token,
-            add_prefix_space=add_prefix_space,
-            add_special_tokens=add_special_tokens,
-            **kwargs,
-        )
-
-        # Add VibeVoice-specific special tokens
-        self._add_vibevoice_special_tokens()
-
-    def _add_vibevoice_special_tokens(self):
-        """Add VibeVoice-specific special tokens."""
-        special_tokens = {
-            "additional_special_tokens": [
-                "<|vision_start|>",  # Speech start (reusing vision tokens)
-                "<|vision_end|>",  # Speech end
-                "<|vision_pad|>",  # Speech diffusion pad
-            ]
-        }
-        num_added = self.add_special_tokens(special_tokens)
-
-        # Cache special token IDs
-        self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
-        self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
-        self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
-
-        self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
-
-        return num_added
-
-    @property
-    def eos_id(self) -> int:
-        """Id of the end of sequence token."""
-        return self._eos_id
-
-    @property
-    def speech_start_id(self) -> int:
-        """Id of the speech start token."""
-        return self._speech_start_id
-
-    @property
-    def speech_end_id(self) -> int:
-        """Id of the speech end token."""
-        return self._speech_end_id
-
-    @property
-    def speech_diffusion_id(self) -> int:
-        """Id of the speech diffusion token."""
-        return self._speech_diffusion_id
-
-    @property
-    def pad_id(self) -> int:
-        """Id used for padding (returns -100 for loss masking)."""
-        return -100
-
-
-class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
-    """
-    Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library).
-    Based on the Qwen2 tokenizer with additional special tokens for speech.
-
-    Args:
-        vocab_file (`str`, *optional*):
-            Path to the vocabulary file.
-        merges_file (`str`, *optional*):
-            Path to the merges file.
-        tokenizer_file (`str`, *optional*):
-            Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
-        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
-            The unknown token.
-        bos_token (`str`, *optional*):
-            The beginning of sequence token. Not used for vibevoice.
-        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
-            The end of sequence token.
-        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
-            The token used for padding.
-    """
-
-    model_input_names = ["input_ids", "attention_mask"]
-
-    def __init__(
-        self,
-        vocab_file=None,
-        merges_file=None,
-        tokenizer_file=None,
-        unk_token="<|endoftext|>",
-        bos_token=None,
-        eos_token="<|endoftext|>",
-        pad_token="<|endoftext|>",
-        add_prefix_space=False,
-        **kwargs,
-    ):
-        super().__init__(
-            vocab_file=vocab_file,
-            merges_file=merges_file,
-            tokenizer_file=tokenizer_file,
-            unk_token=unk_token,
-            bos_token=bos_token,
-            eos_token=eos_token,
-            pad_token=pad_token,
-            add_prefix_space=add_prefix_space,
-            **kwargs,
-        )
-
-        # Add VibeVoice-specific special tokens
-        self._add_vibevoice_special_tokens()
-
-    def _add_vibevoice_special_tokens(self):
-        """Add VibeVoice-specific special tokens."""
-        special_tokens = {
-            "additional_special_tokens": [
-                "<|vision_start|>",  # Speech start (reusing vision tokens)
-                "<|vision_end|>",  # Speech end
-                "<|vision_pad|>",  # Speech diffusion pad
-            ]
-        }
-        num_added = self.add_special_tokens(special_tokens)
-
-        # Cache special token IDs
-        self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
-        self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
-        self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
-
-        # self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
-        self._eos_id = self.eos_token_id  # qwen2 / qwen3
-        self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')
-
-        return num_added
-
-    @property
-    def eos_id(self) -> int:
-        """Id of the end of sequence token."""
-        return self._eos_id
-
-    @property
-    def speech_start_id(self) -> int:
-        """Id of the speech start token."""
-        return self._speech_start_id
-
-    @property
-    def speech_end_id(self) -> int:
-        """Id of the speech end token."""
-        return self._speech_end_id
-
-    @property
-    def speech_diffusion_id(self) -> int:
-        """Id of the speech diffusion token."""
-        return self._speech_diffusion_id
-
-    @property
-    def pad_id(self) -> int:
-        """Id used for padding (returns -100 for loss masking)."""
-        return self._pad_id
-
-
-__all__ = [
-    "VibeVoiceTextTokenizer",
-    "VibeVoiceTextTokenizerFast",
-]
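
As the comments in the deleted tokenizer note, VibeVoice does not grow the vocabulary: it repurposes Qwen2's reserved vision tokens (`<|vision_start|>`, `<|vision_end|>`, `<|vision_pad|>`) as the speech-start, speech-end, and speech-diffusion markers, so the text stream can delimit the spans where acoustic latents are spliced in. A hedged sketch of the same pattern with a stock tokenizer; the checkpoint name is illustrative only, and VibeVoice ships its own tokenizer files:

```python
from transformers import AutoTokenizer

# Illustrative Qwen2-family checkpoint; not the VibeVoice tokenizer itself.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

# Registering tokens that already exist in the vocab is a no-op; otherwise
# they are appended as new special tokens (the deleted classes do the same).
tok.add_special_tokens({"additional_special_tokens": [
    "<|vision_start|>", "<|vision_end|>", "<|vision_pad|>"]})

text = "Speaker 1: Hello. <|vision_start|><|vision_pad|><|vision_end|>"
ids = tok(text)["input_ids"]
diffusion_id = tok.convert_tokens_to_ids("<|vision_pad|>")
print(diffusion_id in ids)  # True: this position is where speech latents go
```
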
src/vibevoice/modular/modular_vibevoice_tokenizer.py DELETED
@@ -1,1195 +0,0 @@
1
- import math
2
- import typing as tp
3
- from functools import partial
4
- from dataclasses import dataclass, field
5
- from typing import Dict, List, Optional, Tuple, Union
6
- import copy
7
-
8
- import numpy as np
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
-
13
- from transformers.models.auto import AutoModel
14
-
15
- from transformers.configuration_utils import PretrainedConfig
16
- from transformers.utils import logging
17
- from transformers.modeling_utils import PreTrainedModel
18
- from transformers.activations import ACT2FN
19
-
20
- from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig
21
-
22
- logger = logging.get_logger(__name__)
23
-
24
- import os
25
- # Try to import APEX FusedRMSNorm
26
- try:
27
- from apex.normalization.fused_layer_norm import fused_rms_norm_affine
28
- APEX_AVAILABLE = True
29
- logger.info("APEX FusedRMSNorm is available and will be used for optimization")
30
- if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
31
- APEX_AVAILABLE = False
32
- logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
33
- except ImportError:
34
- APEX_AVAILABLE = False
35
- logger.warning("APEX FusedRMSNorm not available, using native implementation")
36
- # APEX_AVAILABLE=False
37
-
38
- # Normalization modules
39
- class ConvLayerNorm(nn.LayerNorm):
40
- """
41
- Convolution-friendly LayerNorm that moves channels to last dimensions
42
- before running the normalization and moves them back to original position right after.
43
- """
44
- def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
45
- super().__init__(normalized_shape, **kwargs)
46
-
47
- def forward(self, x):
48
- x = x.transpose(1, 2) # b ... t -> b t ...
49
- x = nn.functional.layer_norm(x.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(x)
50
- x = x.transpose(1, 2) # b t ... -> b ... t
51
- return x
52
-
53
- class RMSNorm(nn.Module):
54
- def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
55
- super().__init__()
56
- self.dim = dim
57
- self.eps = eps
58
- self.elementwise_affine = elementwise_affine
59
- if self.elementwise_affine:
60
- weight_shape = (dim,) if weight_shape is None else weight_shape
61
- self.weight = nn.Parameter(torch.ones(weight_shape))
62
- else:
63
- self.register_parameter('weight', None)
64
-
65
- def _norm(self, x):
66
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
67
-
68
- def forward(self, x):
69
- output = self._norm(x.float()).type_as(x)
70
- if self.weight is not None:
71
- output = output * self.weight
72
- return output
73
-
74
- def extra_repr(self) -> str:
75
- return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
76
-
77
- class ConvRMSNorm(RMSNorm):
78
- def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
79
- super().__init__(dim, eps, elementwise_affine, weight_shape)
80
-
81
- def forward(self, x):
82
- x = x.transpose(1, 2) # b ... t -> b t ...
83
- if (not APEX_AVAILABLE) or (not self.elementwise_affine):
84
- # Fallback to native implementation
85
- output = self._norm(x.float()).type_as(x)
86
- if self.weight is not None:
87
- output = output * self.weight
88
- else:
89
- output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps)
90
- output = output.transpose(1, 2) # b t ... -> b ... t
91
- return output
92
-
93
- # Convolutional layers and utilities
94
- CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
95
- 'time_layer_norm', 'layer_norm', 'time_group_norm'])
96
-
97
-
98
- def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
99
- assert norm in CONV_NORMALIZATIONS
100
- if norm == 'weight_norm':
101
- return nn.utils.weight_norm(module)
102
- elif norm == 'spectral_norm':
103
- return nn.utils.spectral_norm(module)
104
- else:
105
- # We already check was in CONV_NORMALIZATION, so any other choice
106
- # doesn't need reparametrization.
107
- return module
108
-
109
-
110
- def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
111
- """Return the proper normalization module. If causal is True, this will ensure the returned
112
- module is causal, or return an error if the normalization doesn't support causal evaluation.
113
- """
114
- assert norm in CONV_NORMALIZATIONS
115
- if norm == 'layer_norm':
116
- assert isinstance(module, nn.modules.conv._ConvNd)
117
- return ConvLayerNorm(module.out_channels, **norm_kwargs)
118
- elif norm == 'time_group_norm':
119
- if causal:
120
- raise ValueError("GroupNorm doesn't support causal evaluation.")
121
- assert isinstance(module, nn.modules.conv._ConvNd)
122
- return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
123
- else:
124
- return nn.Identity()
125
-
126
-
127
- def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
128
- padding_total: int = 0) -> int:
129
- """Calculate extra padding needed for convolution to have the same output length"""
130
- length = x.shape[-1]
131
- n_frames = (length - kernel_size + padding_total) / stride + 1
132
- ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
133
- return ideal_length - length
134
-
135
-
136
- def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
137
- """Pad 1D input with handling for small inputs in reflect mode"""
138
- length = x.shape[-1]
139
- padding_left, padding_right = paddings
140
- assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
141
- if mode == 'reflect':
142
- max_pad = max(padding_left, padding_right)
143
- extra_pad = 0
144
- if length <= max_pad:
145
- extra_pad = max_pad - length + 1
146
- x = F.pad(x, (0, extra_pad))
147
- padded = F.pad(x, paddings, mode, value)
148
- end = padded.shape[-1] - extra_pad
149
- return padded[..., :end]
150
- else:
151
- return F.pad(x, paddings, mode, value)
152
-
153
-
154
- def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
155
- """Remove padding from x, handling properly zero padding. Only for 1d!"""
156
- padding_left, padding_right = paddings
157
- assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
158
- assert (padding_left + padding_right) <= x.shape[-1]
159
- end = x.shape[-1] - padding_right
160
- return x[..., padding_left: end]
161
-
162
-
163
- class NormConv1d(nn.Module):
164
- """Wrapper around Conv1d and normalization applied to this conv"""
165
- def __init__(self, *args, causal: bool = False, norm: str = 'none',
166
- norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
167
- super().__init__()
168
- self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
169
- self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
170
- self.norm_type = norm
171
-
172
- def forward(self, x):
173
- x = self.conv(x)
174
- x = self.norm(x)
175
- return x
176
-
177
-
178
- class NormConvTranspose1d(nn.Module):
179
- """Wrapper around ConvTranspose1d and normalization applied to this conv"""
180
- def __init__(self, *args, causal: bool = False, norm: str = 'none',
181
- norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
182
- super().__init__()
183
- self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
184
- self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
185
- self.norm_type = norm
186
-
187
- def forward(self, x):
188
- x = self.convtr(x)
189
- x = self.norm(x)
190
- return x
191
-
192
-
193
- class VibeVoiceTokenizerStreamingCache:
194
- """Cache for streaming convolution, similar to KV cache in attention"""
195
- def __init__(self):
196
- self.cache = {} # Dict mapping (layer_id, sample_idx) to state tensor
197
-
198
- def get(self, layer_id: str, sample_indices: torch.Tensor) -> Optional[torch.Tensor]:
199
- """Get cached states for given layer and sample indices"""
200
- states = []
201
- max_length = 0
202
-
203
- # First pass: collect states and find max length
204
- for idx in sample_indices.tolist():
205
- key = (layer_id, idx)
206
- if key not in self.cache:
207
- return None # If any sample is missing, return None
208
- state = self.cache[key]
209
- states.append(state)
210
- max_length = max(max_length, state.shape[-1])
211
-
212
- # Second pass: pad states to max length if needed
213
- if len(states) > 0 and states[0].dim() >= 2:
214
- padded_states = []
215
- for state in states:
216
- if state.shape[-1] < max_length:
217
- # Pad on the time dimension (last dimension)
218
- pad_size = max_length - state.shape[-1]
219
- # Pad with zeros on the LEFT to align the most recent samples
220
- padded_state = F.pad(state, (pad_size, 0), mode='constant', value=0)
221
- padded_states.append(padded_state)
222
- else:
223
- padded_states.append(state)
224
- return torch.stack(padded_states, dim=0)
225
- else:
226
- return torch.stack(states, dim=0)
227
-
228
- def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
229
- """Set cached states for given layer and sample indices"""
230
- for i, idx in enumerate(sample_indices.tolist()):
231
- key = (layer_id, idx)
232
- self.cache[key] = states[i].detach()
233
-
234
- def set_to_zero(self, sample_indices: torch.Tensor):
235
- """Set all cached states to zero for given sample indices"""
236
- for key in list(self.cache.keys()):
237
- layer_id, sample_idx = key
238
- if sample_idx in sample_indices.tolist():
239
- # Create zero tensor with same shape and dtype as cached tensor
240
- cached_tensor = self.cache[key]
241
- self.cache[key] = torch.zeros_like(cached_tensor)
242
-
243
- def clear(self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None):
244
- """Clear cache for specific layer/samples or everything"""
245
- if layer_id is None and sample_indices is None:
246
- self.cache.clear()
247
- elif layer_id is not None and sample_indices is None:
248
- # Clear all samples for a specific layer
249
- keys_to_remove = [k for k in self.cache.keys() if k[0] == layer_id]
250
- for k in keys_to_remove:
251
- del self.cache[k]
252
- elif layer_id is not None and sample_indices is not None:
253
- # Clear specific samples for a specific layer
254
- for idx in sample_indices.tolist():
255
- key = (layer_id, idx)
256
- self.cache.pop(key, None)
257
-
258
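- # --- editorial sketch, not part of the original file ---
- # The cache keys states by (layer_id, sample_idx), so one cache object can
- # serve a batch whose rows belong to different utterances:
- #
- #   cache = VibeVoiceTokenizerStreamingCache()
- #   ids = torch.tensor([0, 1])
- #   cache.set("layer0", ids, torch.randn(2, 4, 3))  # store per-sample states
- #   cache.get("layer0", ids).shape                  # torch.Size([2, 4, 3])
- #   cache.get("layer0", torch.tensor([0, 2]))       # None: sample 2 unseen
- #   cache.clear("layer0")                           # drop one layer's entries
- # --- end sketch ---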
- class SConv1d(nn.Module):
259
- """Conv1d with built-in handling of asymmetric or causal padding and normalization."""
260
- def __init__(self, in_channels: int, out_channels: int,
261
- kernel_size: int, stride: int = 1, dilation: int = 1,
262
- groups: int = 1, bias: bool = True, causal: bool = False,
263
- norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
264
- pad_mode: str = 'reflect'):
265
- super().__init__()
266
- self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
267
- dilation=dilation, groups=groups, bias=bias, causal=causal,
268
- norm=norm, norm_kwargs=norm_kwargs)
269
- self.causal = causal
270
- self.pad_mode = pad_mode
271
-
272
- # Store configuration
273
- self.kernel_size = kernel_size
274
- self.dilation = dilation
275
- self.stride = stride
276
- self.in_channels = in_channels
277
- self.out_channels = out_channels
278
-
279
- # For causal convolution, we need to maintain kernel_size - 1 samples as context
280
- # TODO: confirm which context_size definition is more suitable
281
- # self.context_size = (kernel_size - 1) * dilation
282
- self.context_size = (kernel_size - 1) * dilation - (stride - 1)
283
-
284
- # For non-streaming mode, calculate padding
285
- self.padding_total = (kernel_size - 1) * dilation - (stride - 1)
286
-
287
- # Create a unique layer ID for cache management
288
- self._layer_id = None
289
-
290
- @property
291
- def layer_id(self):
292
- if self._layer_id is None:
293
- self._layer_id = f"sconv1d_{id(self)}"
294
- return self._layer_id
295
-
296
- def forward(self, x: torch.Tensor,
297
- cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
298
- sample_indices: Optional[torch.Tensor] = None,
299
- use_cache: bool = False,
300
- debug: bool = False) -> torch.Tensor:
301
- """
302
- Forward pass with optional streaming support via cache.
303
-
304
- Args:
305
- x: Input tensor [batch_size, channels, time]
306
- cache: VibeVoiceTokenizerStreamingCache object for maintaining states
307
- sample_indices: Indices identifying each sample for cache management
308
- use_cache: Whether to use cached states for streaming
309
- debug: Whether to print debug information
310
-
311
- Returns:
312
- Output tensor
313
- """
314
- B, C, T = x.shape
315
-
316
- # Non-streaming mode
317
- if not use_cache or cache is None:
318
- return self._forward_non_streaming(x, debug=debug)
319
-
320
- # Streaming mode
321
- assert self.causal, "Streaming mode is only supported for causal convolutions"
322
- assert sample_indices is not None, "sample_indices must be provided for streaming mode"
323
- assert len(sample_indices) == B, "sample_indices must match batch size"
324
-
325
- return self._forward_streaming(x, cache, sample_indices, debug)
326
-
327
- def _forward_streaming(self, x: torch.Tensor,
328
- cache: VibeVoiceTokenizerStreamingCache,
329
- sample_indices: torch.Tensor,
330
- debug: bool = False) -> torch.Tensor:
331
- """Streaming forward pass with cache operations kept separate from compiled code"""
332
- B, C, T = x.shape
333
-
334
- # Cache operations (not compiled)
335
- cached_states = cache.get(self.layer_id, sample_indices)
336
-
337
- if cached_states is None:
338
- # First chunk - initialize with zeros for context
339
- if self.context_size > 0:
340
- cached_states = torch.zeros(B, C, self.context_size, device=x.device, dtype=x.dtype)
341
- if debug:
342
- print(f"[DEBUG] Initialized cache with shape: {cached_states.shape}, context_size={self.context_size}")
343
- else:
344
- cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
345
- if debug:
346
- print(f"[DEBUG] No context needed (kernel_size=stride)")
347
-
348
- # Concatenate cached states with input
349
- if cached_states.shape[2] > 0:
350
- input_with_context = torch.cat([cached_states, x], dim=2)
351
- else:
352
- input_with_context = x
353
-
354
- if debug:
355
- print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")
356
-
357
- # Apply convolution directly - no extra padding in streaming mode
358
- # The conv layer will handle its own padding internally
359
- output = self.conv(input_with_context)
360
-
361
- if debug:
362
- print(f"[DEBUG] Output shape: {output.shape}")
363
-
364
- # Update cache for next chunk
365
- if self.context_size > 0:
366
- # Calculate how many samples to keep
367
- total_input_length = input_with_context.shape[2]
368
-
369
- # Keep the last context_size samples
370
- if total_input_length >= self.context_size:
371
- new_cache_start = total_input_length - self.context_size
372
- new_cache = input_with_context[:, :, new_cache_start:]
373
- else:
374
- # If we have less than context_size samples, keep everything
375
- new_cache = input_with_context
376
-
377
- if debug:
378
- print(f"[DEBUG] New cache shape: {new_cache.shape}")
379
-
380
- cache.set(self.layer_id, sample_indices, new_cache)
381
-
382
- return output
383
-
384
- def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
385
- """Standard forward pass without streaming"""
386
- B, C, T = x.shape
387
- kernel_size = self.kernel_size
388
- stride = self.stride
389
- dilation = self.dilation
390
- padding_total = self.padding_total
391
-
392
- # Compute extra padding for stride alignment
393
- extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
394
-
395
- if debug:
396
- print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}")
397
-
398
- if self.causal:
399
- # Left padding for causal
400
- if self.pad_mode == 'constant':
401
- x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, value=0)
402
- else:
403
- x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
404
- else:
405
- # Symmetric padding for non-causal
406
- padding_right = padding_total // 2
407
- padding_left = padding_total - padding_right
408
- x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
409
-
410
- if debug:
411
- print(f"[DEBUG NON-STREAMING] After padding: {x.shape}")
412
-
413
- output = self.conv(x)
414
-
415
- if debug:
416
- print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}")
417
-
418
- return output
419
-
420
-
421
- class SConvTranspose1d(nn.Module):
422
- """ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization."""
423
- def __init__(self, in_channels: int, out_channels: int,
424
- kernel_size: int, stride: int = 1, causal: bool = False,
425
- norm: str = 'none', trim_right_ratio: float = 1.,
426
- norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True):
427
- super().__init__()
428
- self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
429
- causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
430
- self.causal = causal
431
- self.trim_right_ratio = trim_right_ratio
432
- assert self.causal or self.trim_right_ratio == 1., \
433
- "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
434
- assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
435
-
436
- # Store configuration
437
- self.kernel_size = kernel_size
438
- self.stride = stride
439
- self.in_channels = in_channels
440
- self.out_channels = out_channels
441
-
442
- # For transposed convolution, padding calculation is different
443
- self.padding_total = kernel_size - stride
444
-
445
- # For streaming, we need to keep track of input history
446
- # Transposed conv needs to see multiple input samples to produce correct output
447
- self.context_size = kernel_size - 1
448
-
449
- # Create a unique layer ID for cache management
450
- self._layer_id = None
451
-
452
- @property
453
- def layer_id(self):
454
- if self._layer_id is None:
455
- self._layer_id = f"sconvtr1d_{id(self)}"
456
- return self._layer_id
457
-
458
- def forward(self, x: torch.Tensor,
459
- cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
460
- sample_indices: Optional[torch.Tensor] = None,
461
- use_cache: bool = False,
462
- debug: bool = False) -> torch.Tensor:
463
- """
464
- Forward pass with optional streaming support via cache.
465
- """
466
- B, C, T = x.shape
467
-
468
- # Non-streaming mode
469
- if not use_cache or cache is None:
470
- return self._forward_non_streaming(x, debug=debug)
471
-
472
- # Streaming mode
473
- assert sample_indices is not None, "sample_indices must be provided for streaming mode"
474
- assert len(sample_indices) == B, "sample_indices must match batch size"
475
-
476
- return self._forward_streaming(x, cache, sample_indices, debug)
477
-
478
- def _forward_streaming(self, x: torch.Tensor,
479
- cache: VibeVoiceTokenizerStreamingCache,
480
- sample_indices: torch.Tensor,
481
- debug: bool = False) -> torch.Tensor:
482
- """Streaming forward pass with cache operations kept separate from compiled code"""
483
- B, C, T = x.shape
484
-
485
- # Cache operations (not compiled)
486
- cached_input = cache.get(self.layer_id, sample_indices)
487
-
488
- if cached_input is None:
489
- # First chunk - no history yet
490
- cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
491
- if debug:
492
- print(f"[DEBUG] Initialized empty cache for transposed conv")
493
-
494
- # Concatenate cached input with new input
495
- full_input = torch.cat([cached_input, x], dim=2)
496
-
497
- if debug:
498
- print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}")
499
-
500
- # Apply the transposed convolution over the full (cached + new) input
501
- full_output = self.convtr(full_input)
502
-
503
- if debug:
504
- print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}")
505
-
506
- # Calculate padding to remove
507
- if self.causal:
508
- padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
509
- padding_left = self.padding_total - padding_right
510
- else:
511
- padding_right = self.padding_total // 2
512
- padding_left = self.padding_total - padding_right
513
-
514
- # Remove padding
515
- if padding_left + padding_right > 0:
516
- full_output = unpad1d(full_output, (padding_left, padding_right))
517
-
518
- if debug:
519
- print(f"[DEBUG] After unpadding: {full_output.shape}")
520
-
521
- # Determine which part of the output corresponds to the new input
522
- if cached_input.shape[2] == 0:
523
- # First chunk - return all output
524
- output = full_output
525
- else:
526
- # Subsequent chunks - return only the new output
527
- expected_new_output = T * self.stride
528
-
529
- # Take the last expected_new_output samples
530
- if full_output.shape[2] >= expected_new_output:
531
- output = full_output[:, :, -expected_new_output:]
532
- else:
533
- output = full_output
534
-
535
- if debug:
536
- print(f"[DEBUG] Final streaming output shape: {output.shape}")
537
-
538
- # Update cache
539
- if full_input.shape[2] > self.context_size:
540
- new_cache = full_input[:, :, -self.context_size:]
541
- else:
542
- new_cache = full_input
543
-
544
- if debug:
545
- print(f"[DEBUG] New cache shape: {new_cache.shape}")
546
-
547
- cache.set(self.layer_id, sample_indices, new_cache)
548
-
549
- return output
550
-
551
- def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
552
- """Standard forward pass without streaming"""
553
- if debug:
554
- print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}")
555
-
556
- # Apply transposed convolution
557
- y = self.convtr(x)
558
-
559
- if debug:
560
- print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}")
561
-
562
- # Calculate and remove padding
563
- if self.causal:
564
- padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
565
- padding_left = self.padding_total - padding_right
566
- else:
567
- padding_right = self.padding_total // 2
568
- padding_left = self.padding_total - padding_right
569
-
570
- if padding_left + padding_right > 0:
571
- y = unpad1d(y, (padding_left, padding_right))
572
-
573
- if debug:
574
- print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}")
575
-
576
- return y
577
-
578
- # FFN
579
- class FFN(nn.Module):
580
- def __init__(
581
- self,
582
- embed_dim,
583
- ffn_dim,
584
- bias=False,
585
- ):
586
- super().__init__()
587
- self.embed_dim = embed_dim
588
- self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias)
589
- self.gelu = ACT2FN["gelu"]
590
- self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias)
591
-
592
- def forward(self, x):
593
- x = self.linear1(x)
594
- x = self.gelu(x)
595
- x = self.linear2(x)
596
- return x
597
-
598
-
599
- class Convlayer(nn.Module):
600
- def __init__(
601
- self,
602
- in_channels,
603
- out_channels,
604
- kernel_size,
605
- stride=1,
606
- dilation=1,
607
- groups=1,
608
- bias=True,
609
- pad_mode='zeros',
610
- norm='weight_norm',
611
- causal=True,
612
- ):
613
- super().__init__()
614
- self.conv = SConv1d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
615
- groups=groups, bias=bias, pad_mode=pad_mode, norm=norm, causal=causal)
616
-
617
- def forward(self, x):
618
- return self.conv(x)
619
-
620
- class Block1D(nn.Module):
621
- def __init__(self, dim, kernel_size=7, drop_path=0., mixer_layer='conv',
622
- layer_scale_init_value=1e-6, **kwargs):
623
- super().__init__()
624
-
625
- if kwargs.get('layernorm', 'LN') == 'LN':
626
- self.norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
627
- self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
628
- elif kwargs.get('layernorm', 'RMSNorm') == 'RMSNorm':
629
- self.norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
630
- self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
631
-
632
- if mixer_layer == 'conv':
633
- self.mixer = Convlayer(dim, dim, groups=kwargs.get('groups', 1),
634
- kernel_size=kernel_size,
635
- pad_mode=kwargs.get('pad_mode', 'reflect'),
636
- norm=kwargs.get('norm', 'none'),
637
- causal=kwargs.get('causal', True),
638
- bias=kwargs.get('bias', True),
639
- )
640
- elif mixer_layer == 'depthwise_conv':
641
- self.mixer = Convlayer(dim, dim, groups=dim,
642
- kernel_size=kernel_size,
643
- pad_mode=kwargs.get('pad_mode', 'reflect'),
644
- norm=kwargs.get('norm', 'none'),
645
- causal=kwargs.get('causal', True),
646
- bias=kwargs.get('bias', True),
647
- )
648
- else:
649
- raise ValueError(f"Unsupported mixer layer: {mixer_layer}")
650
-
651
- self.ffn = FFN(
652
- dim,
653
- kwargs.get('ffn_expansion', 4) * dim,
654
- bias=kwargs.get('bias', False),
655
- )
656
- self.drop_path = nn.Identity() if drop_path <= 0. else nn.modules.DropPath(drop_path)  # NOTE: torch.nn has no DropPath; a timm-style DropPath module is required when drop_path > 0
657
-
658
- if layer_scale_init_value > 0:
659
- self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
660
- self.ffn_gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
661
- else:
662
- self.gamma = None
663
- self.ffn_gamma = None
664
-
665
- def forward(self, x):
666
- # mixer
667
- residual = x
668
- x = self.norm(x)
669
- x = self.mixer(x)
670
- if self.gamma is not None:
671
- x = x * self.gamma.unsqueeze(-1)
672
- x = residual + self.drop_path(x)
673
-
674
- # ffn
675
- residual = x
676
- x = self.ffn_norm(x)
677
- x = x.permute(0, 2, 1)
678
- x = self.ffn(x)
679
- x = x.permute(0, 2, 1)
680
- if self.ffn_gamma is not None:
681
- x = x * self.ffn_gamma.unsqueeze(-1)
682
- x = residual + self.drop_path(x)
683
-
684
- return x
685
-
686
-
687
- class TokenizerEncoder(nn.Module):
688
- """
689
- Encoder component for the VibeVoice tokenizer that converts audio to latent representations.
690
-
691
- Args:
692
- config: Configuration object with model parameters
693
- """
694
- def __init__(self, config):
695
- super().__init__()
696
-
697
- # Extract parameters from config
698
- self.channels = config.channels
699
- self.dimension = config.dimension
700
- self.n_filters = config.n_filters
701
- self.ratios = list(reversed(config.ratios))
702
- self.depths = config.depths
703
- self.n_residual_layers = getattr(config, "n_residual_layers", 1)
704
- self.hop_length = np.prod(self.ratios)
705
- self.causal = config.causal
706
-
707
- # Additional config parameters with defaults
708
- kernel_size = getattr(config, "kernel_size", 7)
709
- last_kernel_size = getattr(config, "last_kernel_size", 7)
710
- norm = getattr(config, "norm", "none")
711
- norm_params = getattr(config, "norm_params", {})
712
- pad_mode = getattr(config, "pad_mode", "reflect")
713
- bias = getattr(config, "bias", True)
714
- layernorm = getattr(config, "layernorm", "LN")
715
- layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
716
- layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
717
- drop_path_rate = getattr(config, "drop_path_rate", 0.0)
718
- mixer_layer = getattr(config, "mixer_layer", "conv")
719
- layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
720
- disable_last_norm = getattr(config, "disable_last_norm", False)
721
-
722
- # determine the norm type based on layernorm
723
- if layernorm == 'LN':
724
- norm_type = ConvLayerNorm
725
- elif layernorm == 'RMSNorm':
726
- norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
727
- else:
728
- raise ValueError(f"Unsupported norm type: {layernorm}")
729
-
730
- # stem and intermediate downsampling conv layers
731
- stem = nn.Sequential(
732
- SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
733
- )
734
-
735
- self.downsample_layers = nn.ModuleList()
736
- self.downsample_layers.append(stem)
737
- for i in range(len(self.ratios)):
738
- in_ch = self.n_filters * (2 ** i)
739
- out_ch = self.n_filters * (2 ** (i + 1))
740
- downsample_layer = nn.Sequential(
741
- SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
742
- )
743
- self.downsample_layers.append(downsample_layer)
744
-
745
- # configure the transformer blocks
746
- layer_type = partial(
747
- Block1D,
748
- mixer_layer=mixer_layer,
749
- layernorm=layernorm,
750
- eps=layernorm_eps,
751
- causal=self.causal,
752
- pad_mode=pad_mode,
753
- norm=norm,
754
- bias=bias,
755
- layer_scale_init_value=layer_scale_init_value,
756
- )
757
-
758
- self.stages = nn.ModuleList()
759
- dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
760
- cur = 0
761
-
762
- for i in range(len(self.depths)):
763
- in_ch = self.n_filters * (2 ** i)
764
- stage = nn.Sequential(
765
- *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
766
- )
767
- self.stages.append(stage)
768
- cur += self.depths[i]
769
-
770
- if not disable_last_norm:
771
- self.norm = norm_type(in_ch, eps=layernorm_eps)
772
- else:
773
- self.norm = nn.Identity()
774
- self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
775
-
776
- def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
777
- for i in range(len(self.depths)):
778
- # Apply downsampling
779
- for layer in self.downsample_layers[i]:
780
- if isinstance(layer, SConv1d):
781
- x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
782
- else:
783
- x = layer(x)
784
-
785
- # Apply stage (Block1D contains Convlayer which contains SConv1d)
786
- for block in self.stages[i]:
787
- if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
788
- # Block1D forward with cache support
789
- residual = x
790
- x = block.norm(x)
791
- x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
792
- if block.gamma is not None:
793
- x = x * block.gamma.unsqueeze(-1)
794
- x = residual + x
795
-
796
- # FFN part
797
- residual = x
798
- x = block.ffn_norm(x)
799
- x = x.permute(0, 2, 1)
800
- x = block.ffn(x)
801
- x = x.permute(0, 2, 1)
802
- if block.ffn_gamma is not None:
803
- x = x * block.ffn_gamma.unsqueeze(-1)
804
- x = residual + x
805
- else:
806
- x = block(x)
807
-
808
- return self.norm(x)
809
-
810
- def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
811
- x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
812
- x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
813
- return x
814
-
815
-
816
- class TokenizerDecoder(nn.Module):
817
- """
818
- Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.
819
-
820
- Args:
821
- config: Configuration object with model parameters
822
- """
823
- def __init__(self, config):
824
- super().__init__()
825
-
826
- # Extract parameters from config
827
- self.dimension = config.dimension
828
- self.channels = config.channels
829
- self.n_filters = config.n_filters
830
- self.ratios = config.ratios
831
-
832
- # IMPORTANT CHANGE: Don't reverse depths again since they're already reversed in VibeVoiceAcousticTokenizerModel
833
- self.depths = config.depths # Changed from list(reversed(config.depths))
834
-
835
- self.n_residual_layers = getattr(config, "n_residual_layers", 1)
836
- self.hop_length = np.prod(self.ratios)
837
- self.causal = config.causal
838
-
839
- # Additional config parameters with defaults
840
- kernel_size = getattr(config, "kernel_size", 7)
841
- last_kernel_size = getattr(config, "last_kernel_size", 7)
842
- norm = getattr(config, "norm", "none")
843
- norm_params = getattr(config, "norm_params", {})
844
- pad_mode = getattr(config, "pad_mode", "reflect")
845
- bias = getattr(config, "bias", True)
846
- layernorm = getattr(config, "layernorm", "LN")
847
- layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
848
- trim_right_ratio = getattr(config, "trim_right_ratio", 1.0)
849
- layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
850
- drop_path_rate = getattr(config, "drop_path_rate", 0.0)
851
- mixer_layer = getattr(config, "mixer_layer", "conv")
852
- layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
853
- disable_last_norm = getattr(config, "disable_last_norm", False)
854
-
855
- # determine the norm type based on layernorm
856
- if layernorm == 'LN':
857
- norm_type = ConvLayerNorm
858
- elif layernorm == 'RMSNorm':
859
- norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
860
- else:
861
- raise ValueError(f"Unsupported norm type: {layernorm}")
862
-
863
- # stem and upsampling layers
864
- stem = nn.Sequential(
865
- SConv1d(self.dimension, self.n_filters * 2 ** (len(self.depths) - 1), kernel_size, norm=norm,
866
- norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
867
- )
868
-
869
- self.upsample_layers = nn.ModuleList()
870
- self.upsample_layers.append(stem)
871
- for i in range(len(self.ratios)):
872
- in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
873
- out_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i - 1))
874
- upsample_layer = nn.Sequential(
875
- SConvTranspose1d(in_ch, out_ch,
876
- kernel_size=self.ratios[i] * 2, stride=self.ratios[i],
877
- norm=norm, norm_kwargs=norm_params, bias=bias,
878
- causal=self.causal, trim_right_ratio=trim_right_ratio),
879
- )
880
- self.upsample_layers.append(upsample_layer)
881
-
882
- # configure transformer blocks
883
- layer_type = partial(
884
- Block1D,
885
- mixer_layer=mixer_layer,
886
- layernorm=layernorm,
887
- eps=layernorm_eps,
888
- causal=self.causal,
889
- pad_mode=pad_mode,
890
- norm=norm,
891
- bias=bias,
892
- layer_scale_init_value=layer_scale_init_value,
893
- )
894
-
895
- self.stages = nn.ModuleList()
896
- dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
897
- cur = 0
898
-
899
- # Create stages in the same order as the original model
900
- for i in range(len(self.depths)):
901
- in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
902
- stage = nn.Sequential(
903
- *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
904
- )
905
- self.stages.append(stage)
906
- cur += self.depths[i]
907
-
908
- if not disable_last_norm:
909
- self.norm = norm_type(in_ch, eps=layernorm_eps)
910
- else:
911
- self.norm = nn.Identity()
912
- self.head = SConv1d(in_ch, self.channels, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
913
-
914
- def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
915
- for i in range(len(self.depths)):
916
- # Apply upsampling
917
- for layer in self.upsample_layers[i]:
918
- if isinstance(layer, (SConv1d, SConvTranspose1d)):
919
- x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
920
- else:
921
- x = layer(x)
922
-
923
- # Apply stage (Block1D contains Convlayer which contains SConv1d)
924
- for block in self.stages[i]:
925
- if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
926
- # Block1D forward with cache support
927
- residual = x
928
- x = block.norm(x)
929
- x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
930
- if block.gamma is not None:
931
- x = x * block.gamma.unsqueeze(-1)
932
- x = residual + x
933
-
934
- # FFN part
935
- residual = x
936
- x = block.ffn_norm(x)
937
- x = x.permute(0, 2, 1)
938
- x = block.ffn(x)
939
- x = x.permute(0, 2, 1)
940
- if block.ffn_gamma is not None:
941
- x = x * block.ffn_gamma.unsqueeze(-1)
942
- x = residual + x
943
- else:
944
- x = block(x)
945
-
946
- return self.norm(x)
947
-
948
- def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
949
- x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
950
- x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
951
- return x
952
-
953
-
954
- @dataclass
955
- class VibeVoiceTokenizerEncoderOutput:
956
- """
957
- Output of VibeVoice tokenizer encoder, representing a Gaussian distribution with fixed variance.
958
-
959
- Args:
960
- mean (`torch.FloatTensor`): The mean parameters of the distribution.
961
- std (`float` or `torch.FloatTensor`): Fixed standard deviation value.
962
- """
963
- mean: torch.Tensor
964
- std: Optional[Union[float, torch.Tensor]] = None
965
-
966
- def sample(self, dist_type='fix'):
967
- """
968
- Sample from the distribution.
969
-
970
- Args:
971
- dist_type (`str`): Sampling method, either 'fix' or 'gaussian'.
972
-
973
- Returns:
974
- `torch.FloatTensor`: Sampled values.
975
- `torch.FloatTensor` or `float`: The standard deviation used for sampling (returned in all modes).
976
- """
977
- if dist_type == 'fix':
978
- x = self.mean + self.std * torch.randn_like(self.mean)
979
- return x, self.std
980
- elif dist_type == 'gaussian':
981
- batch_size = self.mean.size(0)
982
- value = self.std / 0.8
983
- std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value
984
-
985
- while std.dim() < self.mean.dim():
986
- std = std.unsqueeze(-1)
987
-
988
- x = self.mean + std * torch.randn_like(self.mean)
989
- return x, std
990
- else:
991
- return self.mean, self.std
992
-
993
- def kl(self):
994
- """Compute KL divergence between this distribution and a standard normal."""
995
- target = torch.zeros_like(self.mean)
996
- return F.mse_loss(self.mean, target, reduction='none')
997
-
998
- def mode(self):
999
- """Return the distribution mode (which is the mean for Gaussian)."""
1000
- return self.mean
1001
-
1002
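- # --- editorial sketch, not part of the original file ---
- # With dist_type='fix' the latent is mean + std * eps with eps ~ N(0, I) and
- # a scalar std; with dist_type='gaussian' a per-sample std is first drawn as
- # randn * (std / 0.8). Example:
- #
- #   out = VibeVoiceTokenizerEncoderOutput(mean=torch.zeros(2, 10, 64), std=0.1)
- #   z, used_std = out.sample('fix')  # z.shape == (2, 10, 64); used_std == 0.1
- # --- end sketch ---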
- class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
1003
- """VibeVoice speech tokenizer model combining encoder and decoder for acoustic tokens"""
1004
-
1005
- config_class = VibeVoiceAcousticTokenizerConfig
1006
- base_model_prefix = "vibevoice_acoustic_tokenizer"
1007
- _supports_flash_attn_2 = True
1008
- _supports_sdpa = True
1009
- _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]
1010
-
1011
- def __init__(self, config):
1012
- super().__init__(config)
1013
-
1014
- self.register_buffer('fix_std', torch.tensor(config.fix_std), persistent=False)
1015
- self.std_dist_type = getattr(config, "std_dist_type", "fix")
1016
-
1017
- # Parse encoder depths
1018
- if isinstance(config.encoder_depths, str):
1019
- encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
1020
- else:
1021
- encoder_depths = config.encoder_depths
1022
-
1023
- # Parse decoder depths if provided
1024
- if config.decoder_depths is not None and isinstance(config.decoder_depths, str):
1025
- decoder_depths = [int(d) for d in config.decoder_depths.split('-')]
1026
- else:
1027
- # Default: use reversed encoder depths if decoder_depths is None
1028
- decoder_depths = list(reversed(encoder_depths))
1029
-
1030
- # Create encoder config
1031
- encoder_config = copy.deepcopy(config)
1032
- encoder_config.dimension = config.vae_dim
1033
- encoder_config.n_filters = config.encoder_n_filters
1034
- encoder_config.ratios = config.encoder_ratios
1035
- encoder_config.depths = encoder_depths
1036
- encoder_config.norm = config.conv_norm
1037
- encoder_config.pad_mode = config.pad_mode
1038
- encoder_config.bias = config.conv_bias
1039
- encoder_config.layernorm_eps = config.layernorm_eps
1040
- encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
1041
- encoder_config.mixer_layer = config.mixer_layer
1042
- encoder_config.layer_scale_init_value = config.layer_scale_init_value
1043
- encoder_config.disable_last_norm = config.disable_last_norm
1044
-
1045
- # Create decoder config
1046
- decoder_config = copy.deepcopy(config)
1047
- decoder_config.dimension = config.vae_dim
1048
- decoder_config.n_filters = config.decoder_n_filters
1049
- decoder_config.ratios = config.decoder_ratios
1050
- decoder_config.depths = decoder_depths
1051
- decoder_config.norm = config.conv_norm
1052
- decoder_config.pad_mode = config.pad_mode
1053
- decoder_config.bias = config.conv_bias
1054
- decoder_config.layernorm_eps = config.layernorm_eps
1055
- decoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
1056
- decoder_config.mixer_layer = config.mixer_layer
1057
- decoder_config.layer_scale_init_value = config.layer_scale_init_value
1058
- decoder_config.disable_last_norm = config.disable_last_norm
1059
-
1060
- # Initialize encoder and decoder
1061
- self.encoder = TokenizerEncoder(encoder_config)
1062
- self.decoder = TokenizerDecoder(decoder_config)
1063
-
1064
- # Initialize weights
1065
- self.apply(self._init_weights)
1066
-
1067
- def _init_weights(self, module):
1068
- """Initialize weights for the model"""
1069
- if isinstance(module, nn.Linear):
1070
- nn.init.normal_(module.weight, std=self.config.weight_init_value)
1071
- if module.bias is not None:
1072
- nn.init.zeros_(module.bias)
1073
- elif isinstance(module, nn.LayerNorm):
1074
- nn.init.ones_(module.weight)
1075
- nn.init.zeros_(module.bias)
1076
- elif isinstance(module, nn.Conv1d):
1077
- nn.init.normal_(module.weight, std=self.config.weight_init_value)
1078
- if module.bias is not None:
1079
- nn.init.zeros_(module.bias)
1080
-
1081
- @torch.no_grad()
1082
- def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
1083
- """Convert audio to latent representations"""
1084
- latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1085
- return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)
1086
-
1087
- @torch.no_grad()
1088
- def sampling(self, encoder_output, dist_type=None):
1089
- """Sample from the encoder output distribution"""
1090
- dist_type = dist_type or self.std_dist_type
1091
-
1092
- if dist_type == 'fix':
1093
- return encoder_output.sample(dist_type='fix')
1094
- elif dist_type == 'gaussian':
1095
- return encoder_output.sample(dist_type='gaussian')
1096
- else:
1097
- raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'")
1098
-
1099
- @torch.no_grad()
1100
- def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
1101
- """Convert latent representations back to audio"""
1102
- if latents.shape[1] == self.config.vae_dim:
1103
- pass
1104
- else:
1105
- latents = latents.permute(0, 2, 1)
1106
-
1107
- audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1108
- return audio
1109
-
1110
- def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
1111
- """Full forward pass: encode audio to latents, then decode back to audio"""
1112
- encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1113
- sampled_latents, _ = self.sampling(encoder_output)
1114
- reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1115
- return reconstructed, sampled_latents
1116
-
1117
-
1118
- class VibeVoiceSemanticTokenizerModel(PreTrainedModel):
1119
- """VibeVoice speech tokenizer model with only encoder for semantic tokens"""
1120
-
1121
- config_class = VibeVoiceSemanticTokenizerConfig
1122
- base_model_prefix = "vibevoice_semantic_tokenizer"
1123
- _supports_flash_attn_2 = True
1124
- _supports_sdpa = True
1125
- _no_split_modules = ["TokenizerEncoder"]
1126
-
1127
- def __init__(self, config):
1128
- super().__init__(config)
1129
-
1130
- # Parse encoder depths
1131
- if isinstance(config.encoder_depths, str):
1132
- encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
1133
- else:
1134
- encoder_depths = config.encoder_depths
1135
-
1136
- # Create encoder config
1137
- encoder_config = copy.deepcopy(config)
1138
- encoder_config.dimension = config.vae_dim
1139
- encoder_config.n_filters = config.encoder_n_filters
1140
- encoder_config.ratios = config.encoder_ratios
1141
- encoder_config.depths = encoder_depths
1142
- encoder_config.norm = config.conv_norm
1143
- encoder_config.pad_mode = config.pad_mode
1144
- encoder_config.bias = config.conv_bias
1145
- encoder_config.layernorm_eps = config.layernorm_eps
1146
- encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
1147
- encoder_config.mixer_layer = config.mixer_layer
1148
- encoder_config.layer_scale_init_value = config.layer_scale_init_value
1149
- encoder_config.disable_last_norm = config.disable_last_norm
1150
-
1151
- # Initialize encoder (the semantic tokenizer has no decoder)
1152
- self.encoder = TokenizerEncoder(encoder_config)
1153
-
1154
- # Initialize weights
1155
- self.apply(self._init_weights)
1156
-
1157
- def _init_weights(self, module):
1158
- """Initialize weights for the model"""
1159
- if isinstance(module, nn.Linear):
1160
- nn.init.normal_(module.weight, std=self.config.weight_init_value)
1161
- if module.bias is not None:
1162
- nn.init.zeros_(module.bias)
1163
- elif isinstance(module, nn.LayerNorm):
1164
- nn.init.ones_(module.weight)
1165
- nn.init.zeros_(module.bias)
1166
- elif isinstance(module, nn.Conv1d):
1167
- nn.init.normal_(module.weight, std=self.config.weight_init_value)
1168
- if module.bias is not None:
1169
- nn.init.zeros_(module.bias)
1170
-
1171
- @torch.no_grad()
1172
- def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
1173
- """Convert audio to latent representations"""
1174
- latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1175
- return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))
1176
-
1177
- @torch.no_grad()
1178
- def sampling(self, encoder_output, dist_type=None):
1179
- """Sample from the encoder output distribution"""
1180
- return encoder_output.sample(dist_type='none')
1181
-
1182
- def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
1183
- """Full forward pass: encode audio to latents, then decode back to audio"""
1184
- encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1185
- sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
1186
- return None, sampled_latents
1187
-
1188
- AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
1189
- AutoModel.register(VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel)
1190
-
1191
- __all__ = [
1192
- "VibeVoiceTokenizerStreamingCache",
1193
- "VibeVoiceAcousticTokenizerModel",
1194
- "VibeVoiceSemanticTokenizerModel",
1195
- ]
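
The tokenizer module deleted above is easiest to sanity-check end to end: chunked, cache-driven encoding should reproduce the one-shot result. Below is a minimal sketch, not part of the commit, assuming a checkout that predates this change so the module is still importable (the import path is an assumption). `pad_mode='constant'` is chosen deliberately so the zero-initialized streaming cache matches the non-streaming left padding.

import torch
# Assumed import path for the file removed above.
from vibevoice.modular.modular_vibevoice_tokenizer import (
    SConv1d,
    VibeVoiceTokenizerStreamingCache,
)

torch.manual_seed(0)

# Causal conv; constant padding so full and streaming passes see the same
# zero left-context. context_size = (7 - 1) * 1 - (4 - 1) = 3 samples.
conv = SConv1d(1, 4, kernel_size=7, stride=4, causal=True, pad_mode='constant').eval()
x = torch.randn(2, 1, 64)  # [batch, channels, time]

with torch.no_grad():
    full = conv(x)  # reference: one pass over the whole signal

    # Streaming: feed 16-sample chunks; the cache carries the left context.
    cache = VibeVoiceTokenizerStreamingCache()
    ids = torch.arange(x.shape[0])
    streamed = torch.cat(
        [conv(c, cache=cache, sample_indices=ids, use_cache=True)
         for c in x.split(16, dim=-1)],
        dim=-1,
    )

print(full.shape, streamed.shape)      # both torch.Size([2, 4, 16])
print(torch.allclose(full, streamed))  # True

The chunk length matters here: with kernel 7, stride 4 and context 3, each 16-sample chunk is consumed exactly (the last window ends on the chunk's final sample), so no partial frames straddle chunk boundaries.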
src/vibevoice/modular/streamer.py DELETED
@@ -1,264 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import torch
4
-
5
- import asyncio
6
- from queue import Empty, Queue
7
- from typing import Any, Optional
8
-
9
-
10
- from transformers.generation.streamers import BaseStreamer
11
-
12
-
13
- class AudioStreamer(BaseStreamer):
14
- """
15
- Audio streamer that stores audio chunks in queues for each sample in the batch.
16
- This allows streaming audio generation for multiple samples simultaneously.
17
-
18
- Parameters:
19
- batch_size (`int`):
20
- The batch size for generation
21
- stop_signal (`Any`, *optional*):
22
- The signal to put in the queue when generation ends. Defaults to None.
23
- timeout (`float`, *optional*):
24
- The timeout for the audio queue. If `None`, the queue will block indefinitely.
25
- """
26
-
27
- def __init__(
28
- self,
29
- batch_size: int,
30
- stop_signal: Optional[Any] = None,
31
- timeout: Optional[float] = None,
32
- ):
33
- self.batch_size = batch_size
34
- self.stop_signal = stop_signal
35
- self.timeout = timeout
36
-
37
- # Create a queue for each sample in the batch
38
- self.audio_queues = [Queue() for _ in range(batch_size)]
39
- self.finished_flags = [False for _ in range(batch_size)]
40
- self.sample_indices_map = {} # Maps from sample index to queue index
41
-
42
- def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
43
- """
44
- Receives audio chunks and puts them in the appropriate queues.
45
-
46
- Args:
47
- audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks
48
- sample_indices: Tensor indicating which samples these chunks belong to
49
- """
50
- for i, sample_idx in enumerate(sample_indices):
51
- idx = sample_idx.item()
52
- if idx < self.batch_size and not self.finished_flags[idx]:
53
- # Convert to numpy or keep as tensor based on preference
54
- audio_chunk = audio_chunks[i].detach().cpu()
55
- self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)
56
-
57
- def end(self, sample_indices: Optional[torch.Tensor] = None):
58
- """
59
- Signals the end of generation for specified samples or all samples.
60
-
61
- Args:
62
- sample_indices: Optional tensor of sample indices to end. If None, ends all.
63
- """
64
- if sample_indices is None:
65
- # End all samples
66
- for idx in range(self.batch_size):
67
- if not self.finished_flags[idx]:
68
- self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
69
- self.finished_flags[idx] = True
70
- else:
71
- # End specific samples
72
- for sample_idx in sample_indices:
73
- idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
74
- if idx < self.batch_size and not self.finished_flags[idx]:
75
- self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
76
- self.finished_flags[idx] = True
77
-
78
- def __iter__(self):
79
- """Returns an iterator over the batch of audio streams."""
80
- return AudioBatchIterator(self)
81
-
82
- def get_stream(self, sample_idx: int):
83
- """Get the audio stream for a specific sample."""
84
- if sample_idx >= self.batch_size:
85
- raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
86
- return AudioSampleIterator(self, sample_idx)
87
-
88
-
89
- class AudioSampleIterator:
90
- """Iterator for a single audio stream from the batch."""
91
-
92
- def __init__(self, streamer: AudioStreamer, sample_idx: int):
93
- self.streamer = streamer
94
- self.sample_idx = sample_idx
95
-
96
- def __iter__(self):
97
- return self
98
-
99
- def __next__(self):
100
- value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
101
- if value == self.streamer.stop_signal:
102
- raise StopIteration()
103
- return value
104
-
105
-
106
- class AudioBatchIterator:
107
- """Iterator that yields audio chunks for all samples in the batch."""
108
-
109
- def __init__(self, streamer: AudioStreamer):
110
- self.streamer = streamer
111
- self.active_samples = set(range(streamer.batch_size))
112
-
113
- def __iter__(self):
114
- return self
115
-
116
- def __next__(self):
117
- if not self.active_samples:
118
- raise StopIteration()
119
-
120
- batch_chunks = {}
121
- samples_to_remove = set()
122
-
123
- # Try to get chunks from all active samples
124
- for idx in self.active_samples:
125
- try:
126
- value = self.streamer.audio_queues[idx].get(block=False)
127
- if value == self.streamer.stop_signal:
128
- samples_to_remove.add(idx)
129
- else:
130
- batch_chunks[idx] = value
131
- except Empty:
132
- # Queue is empty for this sample, skip it this iteration
133
- pass
134
-
135
- # Remove finished samples
136
- self.active_samples -= samples_to_remove
137
-
138
- if batch_chunks:
139
- return batch_chunks
140
- elif self.active_samples:
141
- # If no chunks were ready but we still have active samples,
142
- # wait a bit and try again
143
- import time
144
- time.sleep(0.01)
145
- return self.__next__()
146
- else:
147
- raise StopIteration()
148
-
149
-
150
- class AsyncAudioStreamer(AudioStreamer):
151
- """
152
- Async version of AudioStreamer for use in async contexts.
153
- """
154
-
155
- def __init__(
156
- self,
157
- batch_size: int,
158
- stop_signal: Optional[any] = None,
159
- timeout: Optional[float] = None,
160
- ):
161
- super().__init__(batch_size, stop_signal, timeout)
162
- # Replace regular queues with async queues
163
- self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
164
- self.loop = asyncio.get_running_loop()
165
-
166
- def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
167
- """Put audio chunks in the appropriate async queues."""
168
- for i, sample_idx in enumerate(sample_indices):
169
- idx = sample_idx.item()
170
- if idx < self.batch_size and not self.finished_flags[idx]:
171
- audio_chunk = audio_chunks[i].detach().cpu()
172
- self.loop.call_soon_threadsafe(
173
- self.audio_queues[idx].put_nowait, audio_chunk
174
- )
175
-
176
- def end(self, sample_indices: Optional[torch.Tensor] = None):
177
- """Signal the end of generation for specified samples."""
178
- if sample_indices is None:
179
- indices_to_end = range(self.batch_size)
180
- else:
181
- indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
182
-
183
- for idx in indices_to_end:
184
- if idx < self.batch_size and not self.finished_flags[idx]:
185
- self.loop.call_soon_threadsafe(
186
- self.audio_queues[idx].put_nowait, self.stop_signal
187
- )
188
- self.finished_flags[idx] = True
189
-
190
- async def get_stream(self, sample_idx: int):
191
- """Get async iterator for a specific sample's audio stream."""
192
- if sample_idx >= self.batch_size:
193
- raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
194
-
195
- while True:
196
- value = await self.audio_queues[sample_idx].get()
197
- if value == self.stop_signal:
198
- break
199
- yield value
200
-
201
- def __aiter__(self):
202
- """Returns an async iterator over all audio streams."""
203
- return AsyncAudioBatchIterator(self)
204
-
205
-
206
- class AsyncAudioBatchIterator:
207
- """Async iterator for batch audio streaming."""
208
-
209
- def __init__(self, streamer: AsyncAudioStreamer):
210
- self.streamer = streamer
211
- self.active_samples = set(range(streamer.batch_size))
212
-
213
- def __aiter__(self):
214
- return self
215
-
216
- async def __anext__(self):
217
- if not self.active_samples:
218
- raise StopAsyncIteration()
219
-
220
- batch_chunks = {}
221
- samples_to_remove = set()
222
-
223
- # Create tasks for all active samples
224
- tasks = {
225
- idx: asyncio.create_task(self._get_chunk(idx))
226
- for idx in self.active_samples
227
- }
228
-
229
- # Wait for at least one chunk to be ready
230
- done, pending = await asyncio.wait(
231
- tasks.values(),
232
- return_when=asyncio.FIRST_COMPLETED,
233
- timeout=self.streamer.timeout
234
- )
235
-
236
- # Cancel pending tasks
237
- for task in pending:
238
- task.cancel()
239
-
240
- # Process completed tasks
241
- for idx, task in tasks.items():
242
- if task in done:
243
- try:
244
- value = await task
245
- if value == self.streamer.stop_signal:
246
- samples_to_remove.add(idx)
247
- else:
248
- batch_chunks[idx] = value
249
- except asyncio.CancelledError:
250
- pass
251
-
252
- self.active_samples -= samples_to_remove
253
-
254
- if batch_chunks:
255
- return batch_chunks
256
- elif self.active_samples:
257
- # Try again if we still have active samples
258
- return await self.__anext__()
259
- else:
260
- raise StopAsyncIteration()
261
-
262
- async def _get_chunk(self, idx):
263
- """Helper to get a chunk from a specific queue."""
264
- return await self.streamer.audio_queues[idx].get()
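
The streamer deleted above is designed to be fed from a generation loop in a worker thread while consumers drain per-sample queues. A minimal sketch with a stand-in producer (the import path is an assumption; in real use the producer is the model's generation loop calling `put` and `end`):

import threading
import torch

# Assumed import path for the file removed above.
from vibevoice.modular.streamer import AudioStreamer

streamer = AudioStreamer(batch_size=2)

def produce():
    # Stand-in for generation: three chunks per sample, then the end signal.
    ids = torch.tensor([0, 1])
    for _ in range(3):
        streamer.put(torch.randn(2, 3200), ids)  # arbitrary chunk length
    streamer.end()

threading.Thread(target=produce, daemon=True).start()

# Iteration over a single sample's stream stops once end() queues the
# stop signal (None by default).
received = list(streamer.get_stream(0))
print(len(received), received[0].shape)  # 3 torch.Size([3200])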
src/vibevoice/processor/__init__.py DELETED
File without changes
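
For context before the processor source below: `VibeVoiceProcessor` bundles the text tokenizer and the audio processor behind a single `__call__`, returning `input_ids`, `attention_mask`, `speech_tensors`, `speech_masks`, and `speech_input_mask`. A minimal usage sketch, again assuming a pre-deletion checkout plus the public `microsoft/VibeVoice-1.5B` checkpoint; the script format and the voice-sample paths are illustrative placeholders:

from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

processor = VibeVoiceProcessor.from_pretrained("microsoft/VibeVoice-1.5B")

# One podcast-style script plus one reference clip per speaker (placeholders).
script = "Speaker 1: Welcome to the show.\nSpeaker 2: Glad to be here."
inputs = processor(
    text=script,
    voice_samples=["speaker1.wav", "speaker2.wav"],
    return_tensors="pt",
)
print(sorted(inputs.keys()))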
src/vibevoice/processor/vibevoice_processor.py DELETED
@@ -1,701 +0,0 @@
1
- import math
2
- import warnings
3
- from typing import List, Optional, Union, Dict, Any, Tuple
4
- import os
5
- import re
6
-
7
- import numpy as np
8
- import torch
9
-
10
- from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
11
- from transformers.utils import TensorType, logging
12
- from .vibevoice_tokenizer_processor import AudioNormalizer
13
-
14
- logger = logging.get_logger(__name__)
15
-
16
-
17
- class VibeVoiceProcessor:
18
- r"""
19
- Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor.
20
-
21
- [`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`].
22
- See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information.
23
-
24
- Args:
25
- tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
26
- The tokenizer for text processing.
27
- audio_processor (`VibeVoiceTokenizerProcessor`):
28
- The audio processor for speech processing.
29
- speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
30
- The compression ratio for speech tokenization.
31
- db_normalize (`bool`, *optional*, defaults to True):
32
- Whether to apply decibel normalization to audio inputs.
33
- """
34
-
35
- def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
36
- self.tokenizer = tokenizer
37
- self.audio_processor = audio_processor
38
- self.speech_tok_compress_ratio = speech_tok_compress_ratio
39
- self.db_normalize = db_normalize
40
- self.audio_normalizer = AudioNormalizer() if db_normalize else None
41
- self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"
42
-
43
- @classmethod
44
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
45
- """
46
- Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.
47
-
48
- Args:
49
- pretrained_model_name_or_path (`str` or `os.PathLike`):
50
- This can be either:
51
- - a string, the *model id* of a pretrained model
52
- - a path to a *directory* containing processor config
53
-
54
- Returns:
55
- [`VibeVoiceProcessor`]: The processor object instantiated from pretrained model.
56
- """
57
- import os
58
- import json
59
- from transformers.utils import cached_file
60
- from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
61
- from vibevoice.modular.modular_vibevoice_text_tokenizer import (
62
- VibeVoiceTextTokenizer,
63
- VibeVoiceTextTokenizerFast
64
- )
65
-
66
- # Try to load from local path first, then from HF hub
67
- config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
68
- config = None
69
-
70
- if os.path.exists(config_path):
71
- # Local path exists
72
- with open(config_path, 'r') as f:
73
- config = json.load(f)
74
- else:
75
- # Try to load from HF hub
76
- try:
77
- config_file = cached_file(
78
- pretrained_model_name_or_path,
79
- "preprocessor_config.json",
80
- **kwargs
81
- )
82
- with open(config_file, 'r') as f:
83
- config = json.load(f)
84
- except Exception as e:
85
- logger.warning(f"Could not load preprocessor_config.json from {pretrained_model_name_or_path}: {e}")
86
- logger.warning("Using default configuration")
87
- config = {
88
- "speech_tok_compress_ratio": 3200,
89
- "db_normalize": True,
90
- }
91
-
92
- # Extract main processor parameters
93
- speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
94
- db_normalize = config.get("db_normalize", True)
95
-
96
- # Load tokenizer - try from model path first, then fallback to Qwen
97
- language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
98
- logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
99
- if 'qwen' in language_model_pretrained_name.lower():
100
- tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
101
- language_model_pretrained_name,
102
- **kwargs
103
- )
104
- else:
105
- raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen, Llama, Gemma.")
106
-
107
- # Load audio processor
108
- if "audio_processor" in config:
109
- # Create audio processor from config
110
- audio_config = config["audio_processor"]
111
- audio_processor = VibeVoiceTokenizerProcessor(
112
- sampling_rate=audio_config.get("sampling_rate", 24000),
113
- normalize_audio=audio_config.get("normalize_audio", True),
114
- target_dB_FS=audio_config.get("target_dB_FS", -25),
115
- eps=audio_config.get("eps", 1e-6),
116
- )
117
- else:
118
- # Create default audio processor
119
- audio_processor = VibeVoiceTokenizerProcessor()
120
-
121
- # Create and return the processor
122
- return cls(
123
- tokenizer=tokenizer,
124
- audio_processor=audio_processor,
125
- speech_tok_compress_ratio=speech_tok_compress_ratio,
126
- db_normalize=db_normalize,
127
- )
128
-
129
- def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
130
- """
131
- Save a processor to a directory, so that it can be re-loaded using the
132
- [`~VibeVoiceProcessor.from_pretrained`] class method.
133
-
134
- Args:
135
- save_directory (`str` or `os.PathLike`):
136
- Directory where the processor will be saved.
137
- """
138
- import os
139
- import json
140
-
141
- os.makedirs(save_directory, exist_ok=True)
142
-
143
- # Save processor configuration
144
- processor_config = {
145
- "processor_class": "VibeVoiceProcessor",
146
- "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
147
- "db_normalize": self.db_normalize,
148
- "audio_processor": {
149
- "feature_extractor_type": "VibeVoiceTokenizerProcessor",
150
- "sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
151
- "normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
152
- "target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
153
- "eps": getattr(self.audio_processor, 'eps', 1e-6),
154
- }
155
- }
156
-
157
- config_path = os.path.join(save_directory, "preprocessor_config.json")
158
- with open(config_path, 'w') as f:
159
- json.dump(processor_config, f, indent=2)
160
-
161
- logger.info(f"Processor configuration saved in {config_path}")
162
-
163
- def __call__(
164
- self,
165
- text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
166
- voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
167
- padding: Union[bool, str, PaddingStrategy] = True,
168
- truncation: Union[bool, str, TruncationStrategy] = False,
169
- max_length: Optional[int] = None,
170
- return_tensors: Optional[Union[str, TensorType]] = None,
171
- return_attention_mask: bool = True,
172
- **kwargs,
173
- ) -> BatchEncoding:
174
- """
175
- Main method to process one or more podcast scripts with optional voice samples.
176
-
177
- Args:
178
- text (`str`, `List[str]`):
179
- The input text(s) to process. Can be:
180
- - A single script string
181
- - A list of script strings for batch processing
182
- - A path to a .json or .txt file
183
- - A list of paths
184
- voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):
185
- Voice samples for each script. Can be:
186
- - A list of samples for a single script
187
- - A list of lists for batch processing
188
- padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
189
- Whether to pad sequences to the same length
190
- truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
191
- Whether to truncate sequences
192
- max_length (`int`, *optional*):
193
- Maximum length of the returned sequences
194
- return_tensors (`str` or `TensorType`, *optional*):
195
- If set, will return tensors of a particular framework
196
- return_attention_mask (`bool`, defaults to `True`):
197
- Whether to return the attention mask
198
-
199
- Returns:
200
- `BatchEncoding`: A BatchEncoding with the following fields:
201
- - **input_ids** -- List of token id sequences or tensor
202
- - **attention_mask** -- List of attention masks or tensor
203
- - **speech_tensors** -- Padded speech inputs (if voice_samples provided)
204
- - **speech_masks** -- Speech masks (if voice_samples provided)
205
- - **speech_input_mask** -- Boolean masks indicating speech token positions
206
- """
207
- # Handle single vs batch input
208
- if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
209
- # Single input
210
- texts = [text]
211
- is_batched = False
212
- else:
213
- # Batch input
214
- texts = text
215
- is_batched = True
216
-
217
- # Handle voice samples
218
- if voice_samples is not None:
219
- if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
220
- # Single set of voice samples
221
- voice_samples_list = [voice_samples]
222
- else:
223
- # Batch of voice samples
224
- voice_samples_list = voice_samples
225
- else:
226
- voice_samples_list = [None] * len(texts)
227
-
228
- # Process each input
229
- all_encodings = []
230
- for text_input, voice_input in zip(texts, voice_samples_list):
231
- encoding = self._process_single(text_input, voice_input)
232
- all_encodings.append(encoding)
233
-
234
- # Combine batch
235
- batch_encoding = self._batch_encode(
236
- all_encodings,
237
- padding=padding,
238
- truncation=truncation,
239
- max_length=max_length,
240
- return_tensors=return_tensors,
241
- return_attention_mask=return_attention_mask,
242
- )
243
-
244
- return batch_encoding
245
-
246
- def _process_single(
247
- self,
248
- text: Union[str, TextInput],
249
- voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
250
- ) -> Dict[str, Any]:
251
- """Process a single podcast script."""
252
- # Determine if text is a file path or direct script
253
- script = None
254
- if isinstance(text, str):
255
- # Check if it's a file path
256
- if text.endswith('.json') and os.path.exists(text):
257
- script = self._convert_json_to_script(text)
258
- elif text.endswith('.txt') and os.path.exists(text):
259
- script = self._convert_text_to_script(text)
260
- else:
261
- # Assume it's the script content directly
262
- script = text
263
-
264
- if script is None:
265
- raise ValueError(f"Could not process input text: {text}")
266
-
267
- # Parse the script
268
- parsed_lines = self._parse_script(script)
269
- all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))
270
-
271
- # Create system prompt
272
- # system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
273
- system_tokens = self.tokenizer.encode(self.system_prompt)
274
-
275
- # Process voice samples if provided
276
- if voice_samples:
277
- voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
278
- else:
279
- voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []
280
-
281
- # Build full token sequence
282
- full_tokens = system_tokens + voice_tokens
283
- speech_input_mask = [False] * len(system_tokens) + voice_speech_masks
284
-
285
- # Add text input section
286
- full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
287
- speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))
288
-
289
- for speaker_id, speaker_text in parsed_lines:
290
- speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
291
- full_tokens += speaker_text_tokens
292
- speech_input_mask += [False] * len(speaker_text_tokens)
293
-
294
- # Add speech output section
295
- full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
296
- speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)
297
-
298
- return {
299
- "input_ids": full_tokens,
300
- "speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
301
- "speech_input_mask": speech_input_mask,
302
- "parsed_script": parsed_lines,
303
- "all_speakers": all_speakers,
304
- }
305
-
306
- def _batch_encode(
307
- self,
308
- encodings: List[Dict[str, Any]],
309
- padding: Union[bool, str, PaddingStrategy] = True,
310
- truncation: Union[bool, str, TruncationStrategy] = False,
311
- max_length: Optional[int] = None,
312
- return_tensors: Optional[Union[str, TensorType]] = None,
313
- return_attention_mask: bool = True,
314
- ) -> BatchEncoding:
315
- """Combine multiple encodings into a batch with padding."""
316
- # Extract input_ids and create attention_mask
317
- input_ids_list = [enc["input_ids"] for enc in encodings]
318
- speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]
319
-
320
- # Determine padding strategy
321
- if isinstance(padding, bool):
322
- padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
323
- elif isinstance(padding, str):
324
- padding_strategy = PaddingStrategy(padding)
325
- else:
326
- padding_strategy = padding
327
-
328
- # Apply padding to input_ids
329
- if padding_strategy != PaddingStrategy.DO_NOT_PAD:
330
- if padding_strategy == PaddingStrategy.LONGEST:
331
- max_len = max(len(ids) for ids in input_ids_list)
332
- elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
333
- max_len = max_length
334
- else:
335
- max_len = max(len(ids) for ids in input_ids_list)
336
-
337
- # Pad sequences
338
- padded_input_ids = []
339
- attention_masks = []
340
- padded_speech_input_masks = []
341
-
342
- for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
343
- # Truncate if needed
344
- if truncation and len(input_ids) > max_len:
345
- input_ids = input_ids[:max_len]
346
- speech_mask = speech_mask[:max_len]
347
-
348
- # Pad
349
- padding_length = max_len - len(input_ids)
350
- # padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids
351
- padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids
352
- attention_mask = [0] * padding_length + [1] * len(input_ids)
353
- padded_speech_mask = [False] * padding_length + speech_mask
354
-
355
- padded_input_ids.append(padded_ids)
356
- attention_masks.append(attention_mask)
357
- padded_speech_input_masks.append(padded_speech_mask)
358
-
359
- input_ids_list = padded_input_ids
360
- speech_input_masks_list = padded_speech_input_masks
361
- else:
362
- # No padding, just create attention masks
363
- attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None
364
-
365
- # Process speech inputs
366
- all_speech_inputs = []
367
- has_speech = False
368
- for enc in encodings:
369
- if enc["speech_inputs"] is not None:
370
- all_speech_inputs.extend(enc["speech_inputs"])
371
- has_speech = True
372
-
373
- # Prepare batch encoding
374
- batch_encoding = BatchEncoding()
375
-
376
- # Handle tensor conversion
377
- if return_tensors is not None:
378
- batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
379
- if return_attention_mask and attention_masks is not None:
380
- batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
381
- batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
382
- else:
383
- batch_encoding["input_ids"] = input_ids_list
384
- if return_attention_mask and attention_masks is not None:
385
- batch_encoding["attention_mask"] = attention_masks
386
- batch_encoding["speech_input_mask"] = speech_input_masks_list
387
-
388
- # Process speech tensors if present
389
- if has_speech:
390
- speech_dict = self.prepare_speech_inputs(
391
- all_speech_inputs,
392
- return_tensors=return_tensors,
393
- )
394
- batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
395
- batch_encoding["speech_masks"] = speech_dict["speech_masks"]
396
- else:
397
- batch_encoding["speech_tensors"] = None
398
- batch_encoding["speech_masks"] = None
399
-
400
- # Add metadata
401
- batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
402
- batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]
403
-
404
- return batch_encoding
405
-
406
- def _create_voice_prompt(
407
- self,
408
- speaker_samples: List[Union[str, np.ndarray]]
409
- ) -> Tuple[List[int], List[np.ndarray], List[bool]]:
410
- """
411
- Create voice prompt tokens and process audio samples.
412
-
413
- Returns:
414
- tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
415
- """
416
- vae_token_id = self.tokenizer.speech_diffusion_id
417
-
418
- voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
419
- voice_speech_inputs = []
420
- voice_speech_masks = [False] * len(voice_full_tokens)
421
-
422
- for speaker_id, speaker_audio in enumerate(speaker_samples):
423
- prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)
424
-
425
- # Process audio
426
- if isinstance(speaker_audio, str):
427
- # Load audio from file
428
- wav = self.audio_processor._load_audio_from_path(speaker_audio)
429
- else:
430
- wav = np.array(speaker_audio, dtype=np.float32)
431
-
432
- # Apply normalization if needed
433
- if self.db_normalize and self.audio_normalizer:
434
- wav = self.audio_normalizer(wav)
435
-
436
- # Calculate token length based on compression ratio
437
- # if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
438
- # vae_tok_len = wav.shape[0]
439
- # else:
440
- vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
441
-
442
- # Build tokens and masks
443
- speaker_tokens = (prefix_tokens +
444
- [self.tokenizer.speech_start_id] +
445
- [vae_token_id] * vae_tok_len +
446
- [self.tokenizer.speech_end_id] +
447
- self.tokenizer.encode('\n', add_special_tokens=False))
448
-
449
- vae_input_mask = ([False] * len(prefix_tokens) +
450
- [False] +
451
- [True] * vae_tok_len +
452
- [False] +
453
- [False])
454
-
455
- voice_full_tokens.extend(speaker_tokens)
456
- voice_speech_masks.extend(vae_input_mask)
457
- voice_speech_inputs.append(wav)
458
-
459
- return voice_full_tokens, voice_speech_inputs, voice_speech_masks
460
-
461
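The arithmetic in `_create_voice_prompt` is worth making concrete: one placeholder speech token is inserted per `speech_tok_compress_ratio` (default 3200) waveform samples, rounded up. A minimal sketch of that token-budget math, assuming the 24 kHz default sampling rate used elsewhere in this file:

import math

SPEECH_TOK_COMPRESS_RATIO = 3200  # samples per speech token (default config value)
SAMPLING_RATE = 24000             # Hz, the processor default

def speech_token_count(num_samples: int) -> int:
    # One vae_token_id placeholder per 3200 samples, rounded up.
    return math.ceil(num_samples / SPEECH_TOK_COMPRESS_RATIO)

# A 3-second reference clip: 72000 samples -> 23 speech tokens,
# i.e. roughly 7.5 tokens per second of audio.
print(speech_token_count(3 * SAMPLING_RATE))  # 23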
-     def prepare_speech_inputs(
-         self,
-         speech_inputs: List[np.ndarray],
-         return_tensors: Optional[Union[str, TensorType]] = None,
-         device: Optional[Union[str, torch.device]] = None,
-         dtype: Optional[torch.dtype] = None,
-     ) -> Dict[str, Any]:
-         """
-         Prepare speech inputs for model consumption.
-
-         Args:
-             speech_inputs: List of speech arrays
-             return_tensors: Output tensor type
-             device: Device to place tensors on
-             dtype: Data type for tensors
-
-         Returns:
-             Dictionary with padded_speeches and speech_masks
-         """
-         if not speech_inputs:
-             return {"padded_speeches": None, "speech_masks": None}
-
-         # Calculate sequence lengths
-         vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
-         # vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
-         max_speech_length = max(s.shape[0] for s in speech_inputs)
-
-         # Pad speeches
-         if speech_inputs[0].ndim == 1:
-             padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
-         else:
-             padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
-         speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)
-
-         for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
-             padded_speeches[i, :len(speech)] = speech
-             speech_masks[i, :vae_tok_length] = True
-
-         result = {
-             "padded_speeches": padded_speeches,
-             "speech_masks": speech_masks,
-         }
-
-         # Convert to tensors if requested
-         if return_tensors == "pt":
-             result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
-             result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)
-
-         return result
-
-     def _convert_json_to_script(self, json_file: str) -> str:
-         """
-         Convert JSON format to script format.
-         Expected JSON format:
-         [
-             {"speaker": "1", "text": "Hello everyone..."},
-             {"speaker": "2", "text": "Great to be here..."}
-         ]
-         """
-         with open(json_file, 'r', encoding='utf-8') as f:
-             data = json.load(f)
-
-         if not isinstance(data, list):
-             raise ValueError("JSON file must contain a list of speaker entries")
-
-         script_lines = []
-         for item in data:
-             if not isinstance(item, dict):
-                 logger.warning(f"Skipping non-dict entry: {item}")
-                 continue
-
-             speaker = item.get('speaker')
-             text = item.get('text')
-
-             if speaker is None or text is None:
-                 logger.warning(f"Skipping entry missing speaker or text: {item}")
-                 continue
-
-             # Ensure speaker ID is valid
-             try:
-                 speaker_id = int(speaker)
-             except (ValueError, TypeError):
-                 logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
-                 continue
-
-             # Clean up text
-             text = text.strip()
-             if text:
-                 script_lines.append(f"Speaker {speaker_id}: {text}")
-
-         if not script_lines:
-             raise ValueError("No valid entries found in JSON file")
-
-         return "\n".join(script_lines)
-
-     def _convert_text_to_script(self, text_file: str) -> str:
-         """
-         Convert text file to script format.
-         Handles multiple formats:
-         1. Already formatted as "Speaker X: text"
-         2. Plain text (assigned to Speaker 1)
-
-         Handles edge cases like multiple colons in a line.
-         """
-         with open(text_file, 'r', encoding='utf-8') as f:
-             lines = f.readlines()
-
-         script_lines = []
-         current_speaker = 1
-
-         for line in lines:
-             line = line.strip()
-             if not line:
-                 continue
-
-             # Try to parse as "Speaker X: text" format
-             # Use regex to be more robust
-             speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
-
-             if speaker_match:
-                 speaker_id = int(speaker_match.group(1))
-                 text = speaker_match.group(2).strip()
-                 if text:
-                     script_lines.append(f"Speaker {speaker_id}: {text}")
-             else:
-                 # Treat as plain text - assign to current speaker
-                 script_lines.append(f"Speaker {current_speaker}: {line}")
-
-         if not script_lines:
-             raise ValueError("No valid content found in text file")
-
-         return "\n".join(script_lines)
-
-     def _parse_script(self, script: str) -> List[Tuple[int, str]]:
-         """Parse script into list of (speaker_id, text) tuples."""
-         lines = script.strip().split("\n")
-         parsed_lines = []
-         speaker_ids = []
-
-         # First pass: parse all lines and collect speaker IDs
-         for line in lines:
-             if not line.strip():
-                 continue
-
-             # Use regex to handle edge cases like multiple colons
-             match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
-
-             if match:
-                 speaker_id = int(match.group(1))
-                 text = ' ' + match.group(2).strip()
-                 parsed_lines.append((speaker_id, text))
-                 speaker_ids.append(speaker_id)
-             else:
-                 logger.warning(f"Could not parse line: '{line}'")
-
-         if not parsed_lines:
-             if script.strip():
-                 # Fall back to treating the entire script as a single line from speaker 0,
-                 # keeping the (speaker_id, text) tuple format used everywhere else.
-                 return [(0, ' ' + script.strip())]
-             raise ValueError("No valid speaker lines found in script")
-
-         # Check if we need to normalize speaker IDs (only if all are > 0)
-         min_speaker_id = min(speaker_ids)
-         if min_speaker_id > 0:
-             # Normalize to start from 0
-             normalized_lines = []
-             for speaker_id, text in parsed_lines:
-                 normalized_lines.append((speaker_id - 1, text))
-             return normalized_lines
-         else:
-             # Keep original IDs
-             return parsed_lines
-
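For reference, `_parse_script` expects scripts in the "Speaker N: text" form sketched below; speaker numbering that starts at 1 is shifted down to start at 0:

script = (
    "Speaker 1: Welcome to the show.\n"
    "Speaker 2: Thanks, great to be here."
)
# _parse_script(script) yields, after ID normalization:
# [(0, ' Welcome to the show.'), (1, ' Thanks, great to be here.')]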
-     def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
-         """Merge text and audio inputs into a single BatchEncoding."""
-         # Start with text inputs
-         merged = BatchEncoding(text_inputs)
-
-         # Add audio-specific fields
-         if "audio" in audio_inputs:
-             merged["speech_inputs"] = audio_inputs["audio"]
-         if "streaming" in audio_inputs:
-             merged["streaming"] = audio_inputs["streaming"]
-
-         return merged
-
-     def batch_decode(self, *args, **kwargs):
-         """
-         This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].
-         Please refer to the docstring of that method for more information.
-         """
-         return self.tokenizer.batch_decode(*args, **kwargs)
-
-     def decode(self, *args, **kwargs):
-         """
-         This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].
-         Please refer to the docstring of that method for more information.
-         """
-         return self.tokenizer.decode(*args, **kwargs)
-
-     @property
-     def model_input_names(self):
-         """
-         Return the list of inputs accepted by the model.
-         """
-         tokenizer_input_names = self.tokenizer.model_input_names
-         audio_processor_input_names = self.audio_processor.model_input_names
-         return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))
-
-     def save_audio(self,
-                    audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
-                    output_path: str = "output.wav",
-                    sampling_rate: Optional[int] = None,
-                    normalize: bool = False,
-                    batch_prefix: str = "audio_",
-                    ) -> List[str]:
-         """
-         Save audio data to a file.
-         Args:
-             audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
-                 The audio data to save. Can be a single tensor/array or a list of them.
-             output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
-             sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
-             normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
-             batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".
-         Returns:
-             List[str]: The paths to the saved audio file(s).
-         """
-         return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix)
-
-
- __all__ = [
-     "VibeVoiceProcessor",
- ]
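End-to-end, the deleted processor was meant to be driven roughly as below. This is a minimal sketch, not the app's exact call sequence; the file paths are hypothetical, and it assumes `src/vibevoice` is on `sys.path`:

# Hypothetical usage sketch of the removed VibeVoiceProcessor.
from processor.vibevoice_processor import VibeVoiceProcessor

processor = VibeVoiceProcessor.from_pretrained("microsoft/VibeVoice-1.5B")
inputs = processor(
    text="Speaker 1: Hello!\nSpeaker 2: Hi there.",
    voice_samples=["speaker0_ref.wav", "speaker1_ref.wav"],  # placeholder paths
    return_tensors="pt",
)
# inputs carries input_ids, attention_mask, speech_tensors, speech_masks
# and speech_input_mask, as documented in __call__ above.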
src/vibevoice/processor/vibevoice_tokenizer_processor.py DELETED
@@ -1,483 +0,0 @@
- """
- Processor class for VibeVoice models.
- """
-
- import os
- import json
- import warnings
- from typing import List, Optional, Union, Dict, Any
-
- import numpy as np
- import torch
-
- from transformers.feature_extraction_utils import FeatureExtractionMixin
- from transformers.utils import logging
-
- logger = logging.get_logger(__name__)
-
-
- class AudioNormalizer:
-     """
-     Audio normalization class for the VibeVoice tokenizer.
-
-     This class provides audio normalization to ensure consistent input levels
-     for the VibeVoice tokenizer while maintaining audio quality.
-     """
-
-     def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
-         """
-         Initialize the audio normalizer.
-
-         Args:
-             target_dB_FS (float): Target dB FS level for the audio. Default: -25
-             eps (float): Small value to avoid division by zero. Default: 1e-6
-         """
-         self.target_dB_FS = target_dB_FS
-         self.eps = eps
-
-     def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
-         """
-         Adjust the audio to the target dB FS level.
-
-         Args:
-             audio (np.ndarray): Input audio signal
-
-         Returns:
-             tuple: (normalized_audio, rms, scalar)
-         """
-         rms = np.sqrt(np.mean(audio**2))
-         scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
-         normalized_audio = audio * scalar
-         return normalized_audio, rms, scalar
-
-     def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
-         """
-         Avoid clipping by scaling down if necessary.
-
-         Args:
-             audio (np.ndarray): Input audio signal
-             scalar (float, optional): Explicit scaling factor
-
-         Returns:
-             tuple: (normalized_audio, scalar)
-         """
-         if scalar is None:
-             max_val = np.max(np.abs(audio))
-             if max_val > 1.0:
-                 scalar = max_val + self.eps
-             else:
-                 scalar = 1.0
-
-         return audio / scalar, scalar
-
-     def __call__(self, audio: np.ndarray) -> np.ndarray:
-         """
-         Normalize the audio by adjusting to target dB FS and avoiding clipping.
-
-         Args:
-             audio (np.ndarray): Input audio signal
-
-         Returns:
-             np.ndarray: Normalized audio signal
-         """
-         # First adjust to target dB FS
-         audio, _, _ = self.tailor_dB_FS(audio)
-         # Then avoid clipping
-         audio, _ = self.avoid_clipping(audio)
-         return audio
-
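The normalizer above is plain gain scaling: scalar = 10^(target_dB_FS / 20) / (rms + eps), followed by a clipping guard. A small self-contained check of that arithmetic:

import numpy as np

target_dB_FS, eps = -25.0, 1e-6
audio = np.sin(2 * np.pi * 440 * np.arange(24000) / 24000).astype(np.float32)

rms = np.sqrt(np.mean(audio ** 2))                # ~0.707 for a full-scale sine
scalar = 10 ** (target_dB_FS / 20) / (rms + eps)
normalized = audio * scalar

# The normalized RMS sits at -25 dB FS, i.e. about 0.056.
print(np.sqrt(np.mean(normalized ** 2)))          # ~0.0562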
-
- # Change from ProcessorMixin to FeatureExtractionMixin, which is designed for single components
- class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
-     """
-     Processor for VibeVoice acoustic tokenizer models.
-
-     This processor handles audio preprocessing for VibeVoice models, including:
-     - Audio format conversion (stereo to mono)
-     - Optional audio normalization
-     - Streaming support for infinite-length audio
-
-     Args:
-         sampling_rate (int, optional): Expected sampling rate. Defaults to 24000.
-         normalize_audio (bool, optional): Whether to normalize audio. Defaults to True.
-         target_dB_FS (float, optional): Target dB FS for normalization. Defaults to -25.
-         eps (float, optional): Small value for numerical stability. Defaults to 1e-6.
-     """
-     model_input_names = ["input_features"]
-
-     def __init__(
-         self,
-         sampling_rate: int = 24000,
-         normalize_audio: bool = True,
-         target_dB_FS: float = -25,
-         eps: float = 1e-6,
-         **kwargs,
-     ):
-         super().__init__(**kwargs)
-
-         self.sampling_rate = sampling_rate
-         self.normalize_audio = normalize_audio
-
-         # Initialize audio normalizer if needed
-         if self.normalize_audio:
-             self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps)
-         else:
-             self.normalizer = None
-
-         # Save config
-         self.feature_extractor_dict = {
-             "sampling_rate": sampling_rate,
-             "normalize_audio": normalize_audio,
-             "target_dB_FS": target_dB_FS,
-             "eps": eps,
-         }
-
-     def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
-         """
-         Convert stereo audio to mono if needed.
-
-         Args:
-             audio (np.ndarray): Input audio array
-
-         Returns:
-             np.ndarray: Mono audio array
-         """
-         if len(audio.shape) == 1:
-             return audio
-         elif len(audio.shape) == 2:
-             if audio.shape[0] == 2:  # (2, time)
-                 return np.mean(audio, axis=0)
-             elif audio.shape[1] == 2:  # (time, 2)
-                 return np.mean(audio, axis=1)
-             else:
-                 # If one dimension is 1, squeeze it
-                 if audio.shape[0] == 1:
-                     return audio.squeeze(0)
-                 elif audio.shape[1] == 1:
-                     return audio.squeeze(1)
-                 else:
-                     raise ValueError(f"Unexpected audio shape: {audio.shape}")
-         else:
-             raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}")
-
-     def _process_single_audio(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray:
-         """
-         Process a single audio array.
-
-         Args:
-             audio: Single audio input
-
-         Returns:
-             np.ndarray: Processed audio
-         """
-         # Convert to numpy array
-         if not isinstance(audio, np.ndarray):
-             audio = np.array(audio, dtype=np.float32)
-         else:
-             audio = audio.astype(np.float32)
-
-         # Ensure mono
-         audio = self._ensure_mono(audio)
-
-         # Normalize if requested
-         if self.normalize_audio and self.normalizer is not None:
-             audio = self.normalizer(audio)
-
-         return audio
-
-     def __call__(
-         self,
-         audio: Union[str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str]] = None,
-         sampling_rate: Optional[int] = None,
-         return_tensors: Optional[str] = None,
-         **kwargs,
-     ):
-         """
-         Process audio for VibeVoice models.
-
-         Args:
-             audio: Audio input(s) to process. Can be:
-                 - str: Path to an audio file
-                 - np.ndarray: Audio array
-                 - List[float]: Audio as a list of floats
-                 - List[np.ndarray]: Batch of audio arrays
-                 - List[str]: Batch of audio file paths
-             sampling_rate (int, optional): Sampling rate of the input audio
-             return_tensors (str, optional): Return format ('pt' for PyTorch, 'np' for NumPy)
-
-         Returns:
-             dict: Processed audio inputs with keys:
-                 - audio: Audio tensor(s) ready for the model
-         """
-         if audio is None:
-             raise ValueError("Audio input is required")
-
-         # Validate sampling rate
-         if sampling_rate is not None and sampling_rate != self.sampling_rate:
-             logger.warning(
-                 f"Input sampling rate ({sampling_rate}) differs from expected "
-                 f"sampling rate ({self.sampling_rate}). Please resample your audio."
-             )
-
-         # Handle different input types
-         if isinstance(audio, str):
-             # Single audio file path
-             audio = self._load_audio_from_path(audio)
-             is_batched = False
-         elif isinstance(audio, list):
-             if len(audio) == 0:
-                 raise ValueError("Empty audio list provided")
-
-             # Check if it's a list of file paths
-             if all(isinstance(item, str) for item in audio):
-                 # Batch of audio file paths
-                 audio = [self._load_audio_from_path(path) for path in audio]
-                 is_batched = True
-             else:
-                 # Check if it's batched audio arrays
-                 is_batched = isinstance(audio[0], (np.ndarray, list))
-         else:
-             # Single audio array or list
-             is_batched = False
-
-         # Process audio
-         if is_batched:
-             processed_audio = [self._process_single_audio(a) for a in audio]
-         else:
-             processed_audio = [self._process_single_audio(audio)]
-
-         # Convert to tensors if requested
-         if return_tensors == "pt":
-             if len(processed_audio) == 1:
-                 # Create a proper batch and channel dimension: (1, 1, T)
-                 input_features = torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1)
-             else:
-                 # Stack batched inputs into (B, 1, T); note torch.stack requires equal-length clips
-                 input_features = torch.stack([torch.from_numpy(a) for a in processed_audio]).unsqueeze(1)
-         elif return_tensors == "np":
-             if len(processed_audio) == 1:
-                 input_features = processed_audio[0][np.newaxis, np.newaxis, :]
-             else:
-                 input_features = np.stack(processed_audio)[:, np.newaxis, :]
-         else:
-             input_features = processed_audio[0] if len(processed_audio) == 1 else processed_audio
-
-         outputs = {
-             "audio": input_features,  # Use "audio" instead of "input_features"
-         }
-
-         return outputs
-
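The shape convention of `__call__` is easy to miss: with `return_tensors="pt"`, a single clip comes back under the "audio" key as a (1, 1, T) tensor. A quick sketch against the class above:

import numpy as np

proc = VibeVoiceTokenizerProcessor()        # defaults: 24 kHz, normalization on
clip = np.zeros(24000, dtype=np.float32)    # one second of silence

out = proc(audio=clip, return_tensors="pt")
print(out["audio"].shape)                   # torch.Size([1, 1, 24000])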
-     def _load_audio_from_path(self, audio_path: str) -> np.ndarray:
-         """
-         Load audio from a file path.
-
-         Args:
-             audio_path (str): Path to audio file
-
-         Returns:
-             np.ndarray: Loaded audio array
-         """
-         # Get file extension to determine loading method
-         file_ext = os.path.splitext(audio_path)[1].lower()
-
-         if file_ext in ['.wav', '.mp3', '.flac', '.m4a', '.ogg']:
-             # Audio file - use librosa
-             import librosa
-             audio_array, sr = librosa.load(
-                 audio_path,
-                 sr=self.sampling_rate,
-                 mono=True
-             )
-             return audio_array
-         elif file_ext == '.pt':
-             # PyTorch tensor file
-             audio_tensor = torch.load(audio_path, map_location='cpu').squeeze()
-             if isinstance(audio_tensor, torch.Tensor):
-                 audio_array = audio_tensor.numpy()
-             else:
-                 audio_array = np.array(audio_tensor)
-             return audio_array.astype(np.float32)
-         elif file_ext == '.npy':
-             # NumPy file
-             audio_array = np.load(audio_path)
-             return audio_array.astype(np.float32)
-         else:
-             raise ValueError(
-                 f"Unsupported file format: {file_ext}. "
-                 f"Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy"
-             )
-
-     def preprocess_audio(
-         self,
-         audio_path_or_array: Union[str, np.ndarray],
-         normalize: Optional[bool] = None,
-     ) -> np.ndarray:
-         """
-         Convenience method to preprocess audio from a file path or array.
-         This method is kept for backward compatibility, but __call__ is recommended.
-
-         Args:
-             audio_path_or_array: Path to audio file or numpy array
-             normalize: Whether to normalize (overrides the default setting)
-
-         Returns:
-             np.ndarray: Preprocessed audio array
-         """
-         if isinstance(audio_path_or_array, str):
-             audio_array = self._load_audio_from_path(audio_path_or_array)
-         else:
-             audio_array = np.array(audio_path_or_array, dtype=np.float32)
-
-         # Override normalization setting if specified
-         original_normalize = self.normalize_audio
-         if normalize is not None:
-             self.normalize_audio = normalize
-
-         try:
-             processed = self._process_single_audio(audio_array)
-         finally:
-             # Restore original setting
-             self.normalize_audio = original_normalize
-
-         return processed
-
-     # Override to_dict method for configuration saving
-     def to_dict(self) -> Dict[str, Any]:
-         """
-         Convert the object to a dict containing all attributes needed for serialization.
-         """
-         return self.feature_extractor_dict
-
-     def save_audio(
-         self,
-         audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
-         output_path: str = "output.wav",
-         sampling_rate: Optional[int] = None,
-         normalize: bool = False,
-         batch_prefix: str = "audio_",
-     ):
-         """
-         Save audio data to WAV file(s).
-
-         Args:
-             audio: Audio data to save. Can be:
-                 - torch.Tensor: PyTorch tensor with shape (B, C, T), (B, T) or (T,)
-                 - np.ndarray: NumPy array with shape (B, C, T), (B, T) or (T,)
-                 - List of tensors or arrays
-             output_path: Where to save the audio. When saving multiple files,
-                 this is treated as a directory and individual files are saved inside it.
-             sampling_rate: Sampling rate for the saved audio. Defaults to the processor's rate.
-             normalize: Whether to normalize audio before saving.
-             batch_prefix: Prefix for batch files when saving multiple audios.
-
-         Returns:
-             List[str]: Paths to the saved audio files.
-         """
-         if sampling_rate is None:
-             sampling_rate = self.sampling_rate
-
-         try:
-             import soundfile as sf
-         except ImportError:
-             raise ImportError(
-                 "soundfile is required to save audio files. "
-                 "Install it with: pip install soundfile"
-             )
-
-         # Ensure audio is in the right format
-         if isinstance(audio, torch.Tensor):
-             # Convert PyTorch tensor to numpy
-             audio_np = audio.float().detach().cpu().numpy()
-         elif isinstance(audio, np.ndarray):
-             audio_np = audio
-         elif isinstance(audio, list):
-             # Handle list of tensors or arrays
-             if all(isinstance(a, torch.Tensor) for a in audio):
-                 audio_np = [a.float().detach().cpu().numpy() for a in audio]
-             else:
-                 audio_np = audio
-         else:
-             raise ValueError(f"Unsupported audio type: {type(audio)}")
-
-         saved_paths = []
-
-         # Handle based on shape or type
-         if isinstance(audio_np, list):
-             # Multiple separate audios to save
-             output_dir = output_path
-
-             # Ensure output directory exists
-             os.makedirs(output_dir, exist_ok=True)
-
-             # Save each audio
-             for i, audio_item in enumerate(audio_np):
-                 audio_item = self._prepare_audio_for_save(audio_item, normalize)
-                 file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
-                 sf.write(file_path, audio_item, sampling_rate)
-                 saved_paths.append(file_path)
-
-         else:
-             # Handle different dimensions
-             if len(audio_np.shape) >= 3:  # (B, C, T) or similar
-                 # Get batch size
-                 batch_size = audio_np.shape[0]
-
-                 if batch_size > 1:
-                     # Multiple audios in a batch
-                     output_dir = output_path
-
-                     # Ensure output directory exists
-                     os.makedirs(output_dir, exist_ok=True)
-
-                     # Save each audio in the batch
-                     for i in range(batch_size):
-                         # Extract single audio and remove channel dim if present
-                         single_audio = audio_np[i]
-                         if len(single_audio.shape) > 1:
-                             if single_audio.shape[0] == 1:  # (1, T)
-                                 single_audio = single_audio.squeeze(0)
-
-                         single_audio = self._prepare_audio_for_save(single_audio, normalize)
-                         file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
-                         sf.write(file_path, single_audio, sampling_rate)
-                         saved_paths.append(file_path)
-                 else:
-                     # Single audio with batch and channel dims
-                     audio_item = audio_np.squeeze()  # Remove batch and channel dimensions
-                     audio_item = self._prepare_audio_for_save(audio_item, normalize)
-                     sf.write(output_path, audio_item, sampling_rate)
-                     saved_paths.append(output_path)
-             else:
-                 # Single audio without batch dimension
-                 audio_item = self._prepare_audio_for_save(audio_np, normalize)
-                 sf.write(output_path, audio_item, sampling_rate)
-                 saved_paths.append(output_path)
-
-         return saved_paths
-
-     def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
-         """
-         Prepare audio for saving by ensuring it's the right shape and optionally normalizing.
-
-         Args:
-             audio: Audio data as numpy array
-             normalize: Whether to normalize audio
-
-         Returns:
-             np.ndarray: Processed audio ready for saving
-         """
-         # Ensure right dimensionality
-         if len(audio.shape) > 1 and audio.shape[0] == 1:  # (1, T)
-             audio = audio.squeeze(0)
-
-         # Normalize if requested
-         if normalize:
-             max_val = np.abs(audio).max()
-             if max_val > 0:
-                 audio = audio / max_val
-
-         return audio
-
-
- __all__ = ["VibeVoiceTokenizerProcessor", "AudioNormalizer"]
src/vibevoice/schedule/__init__.py DELETED
File without changes
src/vibevoice/schedule/dpm_solver.py DELETED
@@ -1,1065 +0,0 @@
- # Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
-
- import math
- from typing import List, Optional, Tuple, Union
-
- import numpy as np
- import torch
-
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.utils import deprecate
- from diffusers.utils.torch_utils import randn_tensor
- from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
-
-
- def betas_for_alpha_bar(
-     num_diffusion_timesteps,
-     max_beta=0.999,
-     alpha_transform_type="cosine",
- ):
-     """
-     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
-     (1-beta) over time from t = [0,1].
-
-     Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
-     to that part of the diffusion process.
-
-     Args:
-         num_diffusion_timesteps (`int`): the number of betas to produce.
-         max_beta (`float`): the maximum beta to use; use values lower than 1 to
-             prevent singularities.
-         alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
-             Choose from `cosine`, `exp`, `cauchy`, or `laplace`.
-
-     Returns:
-         betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
-     """
-     if alpha_transform_type == "cosine":
-
-         def alpha_bar_fn(t):
-             return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
-             # return math.cos(t * math.pi / 2 * 0.95) ** 2
-
-     elif alpha_transform_type == "exp":
-
-         def alpha_bar_fn(t):
-             return math.exp(t * -12.0)
-
-     elif alpha_transform_type == "cauchy":
-         # µ + γ tan(π (0.5 - x)), with γ = 1, µ = 3
-         # alpha^2 = 1 - 1/(exp(λ) + 1)
-         def alpha_bar_fn(t, gamma=1, mu=3):
-             snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9)
-             return 1 - 1 / (math.exp(snr) + 1.1)
-
-     elif alpha_transform_type == "laplace":
-         # µ − b sgn(0.5 − t) log(1 − 2|t − 0.5|), with µ = 0, b = 1
-         def alpha_bar_fn(t, mu=0, b=1):
-             snr = mu - b * math.copysign(1, 0.5 - t) * math.log(1 - 2 * abs(t - 0.5) * 0.98)
-             return 1 - 1 / (math.exp(snr) + 1.02)
-
-     else:
-         raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
-
-     betas = []
-     for i in range(num_diffusion_timesteps):
-         t1 = i / num_diffusion_timesteps
-         t2 = (i + 1) / num_diffusion_timesteps
-         betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
-     return torch.tensor(betas, dtype=torch.float32)
-
-
- # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
- def rescale_zero_terminal_snr(betas):
-     """
-     Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
-
-     Args:
-         betas (`torch.Tensor`):
-             the betas that the scheduler is being initialized with.
-
-     Returns:
-         `torch.Tensor`: rescaled betas with zero terminal SNR
-     """
-     # Convert betas to alphas_bar_sqrt
-     alphas = 1.0 - betas
-     alphas_cumprod = torch.cumprod(alphas, dim=0)
-     alphas_bar_sqrt = alphas_cumprod.sqrt()
-
-     # Store old values.
-     alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
-     alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
-
-     # Shift so the last timestep is zero.
-     alphas_bar_sqrt -= alphas_bar_sqrt_T
-
-     # Scale so the first timestep is back to the old value.
-     alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
-
-     # Convert alphas_bar_sqrt to betas
-     alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
-     alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
-     alphas = torch.cat([alphas_bar[0:1], alphas])
-     betas = 1 - alphas
-
-     return betas
-
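The cosine branch of `betas_for_alpha_bar` can be sanity-checked in a few lines: each beta is 1 - alpha_bar(t2) / alpha_bar(t1), clipped at max_beta. A minimal standalone check:

import math

def cosine_alpha_bar(t: float) -> float:
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

num_steps, max_beta = 1000, 0.999
betas = [
    min(1 - cosine_alpha_bar((i + 1) / num_steps) / cosine_alpha_bar(i / num_steps), max_beta)
    for i in range(num_steps)
]
# Betas grow along the schedule: ~4e-5 at the first step, capped at 0.999 at
# the end (alpha_bar(1.0) is exactly zero for this transform).
print(betas[0], betas[-1])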
- class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
-     """
-     `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
-
-     This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
-     methods the library implements for all schedulers such as loading and saving.
-
-     Args:
-         num_train_timesteps (`int`, defaults to 1000):
-             The number of diffusion steps to train the model.
-         beta_start (`float`, defaults to 0.0001):
-             The starting `beta` value of inference.
-         beta_end (`float`, defaults to 0.02):
-             The final `beta` value.
-         beta_schedule (`str`, defaults to `"linear"`):
-             The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-             `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
-         trained_betas (`np.ndarray`, *optional*):
-             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
-         solver_order (`int`, defaults to 2):
-             The DPMSolver order, which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided
-             sampling, and `solver_order=3` for unconditional sampling.
-         prediction_type (`str`, defaults to `epsilon`, *optional*):
-             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-             `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
-             Video](https://imagen.research.google/video/paper.pdf) paper).
-         thresholding (`bool`, defaults to `False`):
-             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
-             as Stable Diffusion.
-         dynamic_thresholding_ratio (`float`, defaults to 0.995):
-             The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
-         sample_max_value (`float`, defaults to 1.0):
-             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
-             `algorithm_type="dpmsolver++"`.
-         algorithm_type (`str`, defaults to `dpmsolver++`):
-             Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
-             `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
-             paper, and the `dpmsolver++` type implements the algorithms in the
-             [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
-             `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
-         solver_type (`str`, defaults to `midpoint`):
-             Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
-             sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
-         lower_order_final (`bool`, defaults to `True`):
-             Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
-             stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
-         euler_at_final (`bool`, defaults to `False`):
-             Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
-             richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference
-             steps, but sometimes may result in blurring.
-         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
-             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
-             the sigmas are determined according to a sequence of noise levels {σi}.
-         use_lu_lambdas (`bool`, *optional*, defaults to `False`):
-             Whether to use the uniform-logSNR step sizes proposed by Lu's DPM-Solver in the noise schedule during
-             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
-             `lambda(t)`.
-         final_sigmas_type (`str`, defaults to `"zero"`):
-             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
-             sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
-         lambda_min_clipped (`float`, defaults to `-inf`):
-             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
-             cosine (`squaredcos_cap_v2`) noise schedule.
-         variance_type (`str`, *optional*):
-             Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
-             contains the predicted Gaussian variance.
-         timestep_spacing (`str`, defaults to `"linspace"`):
-             The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and
-             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
-         steps_offset (`int`, defaults to 0):
-             An offset added to the inference steps, as required by some model families.
-         rescale_betas_zero_snr (`bool`, defaults to `False`):
-             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
-             dark samples instead of limiting it to samples with medium brightness. Loosely related to
-             [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
-     """
-
-     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
-     order = 1
-
-     @register_to_config
-     def __init__(
-         self,
-         num_train_timesteps: int = 1000,
-         beta_start: float = 0.0001,
-         beta_end: float = 0.02,
-         beta_schedule: str = "linear",
-         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
-         solver_order: int = 2,
-         prediction_type: str = "epsilon",
-         thresholding: bool = False,
-         dynamic_thresholding_ratio: float = 0.995,
-         sample_max_value: float = 1.0,
-         algorithm_type: str = "dpmsolver++",
-         solver_type: str = "midpoint",
-         lower_order_final: bool = True,
-         euler_at_final: bool = False,
-         use_karras_sigmas: Optional[bool] = False,
-         use_lu_lambdas: Optional[bool] = False,
-         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
-         lambda_min_clipped: float = -float("inf"),
-         variance_type: Optional[str] = None,
-         timestep_spacing: str = "linspace",
-         steps_offset: int = 0,
-         rescale_betas_zero_snr: bool = False,
-     ):
-         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
-             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
-             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
-
-         if trained_betas is not None:
-             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
-         elif beta_schedule == "linear":
-             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
-         elif beta_schedule == "scaled_linear":
-             # this schedule is very specific to the latent diffusion model.
-             self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-         elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine":
-             # Glide cosine schedule
-             self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
-         elif beta_schedule == "cauchy":
-             self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cauchy")
-         elif beta_schedule == "laplace":
-             self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
-         else:
-             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
-
-         if rescale_betas_zero_snr:
-             self.betas = rescale_zero_terminal_snr(self.betas)
-
-         self.alphas = 1.0 - self.betas
-         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
-
-         if rescale_betas_zero_snr:
-             # Close to 0 without being 0 so the first sigma is not inf
-             # FP16 smallest positive subnormal works well here
-             self.alphas_cumprod[-1] = 2**-24
-
-         # Currently we only support VP-type noise schedule
-         self.alpha_t = torch.sqrt(self.alphas_cumprod)
-         self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
-         self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
-         self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
-
-         # standard deviation of the initial noise distribution
-         self.init_noise_sigma = 1.0
-
-         # settings for DPM-Solver
-         if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
-             if algorithm_type == "deis":
-                 self.register_to_config(algorithm_type="dpmsolver++")
-             else:
-                 raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
-
-         if solver_type not in ["midpoint", "heun"]:
-             if solver_type in ["logrho", "bh1", "bh2"]:
-                 self.register_to_config(solver_type="midpoint")
-             else:
-                 raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
-
-         if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
-             raise ValueError(
-                 f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
-             )
-
-         # setable values
-         self.num_inference_steps = None
-         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
-         self.timesteps = torch.from_numpy(timesteps)
-         self.model_outputs = [None] * solver_order
-         self.lower_order_nums = 0
-         self._step_index = None
-         self._begin_index = None
-         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
-
-     @property
-     def step_index(self):
-         """
-         The index counter for the current timestep. It increases by 1 after each scheduler step.
-         """
-         return self._step_index
-
-     @property
-     def begin_index(self):
-         """
-         The index of the first timestep. It should be set from the pipeline with the `set_begin_index` method.
-         """
-         return self._begin_index
-
-     def set_begin_index(self, begin_index: int = 0):
-         """
-         Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
-
-         Args:
-             begin_index (`int`):
-                 The begin index for the scheduler.
-         """
-         self._begin_index = begin_index
-
-     def set_timesteps(
-         self,
-         num_inference_steps: int = None,
-         device: Union[str, torch.device] = None,
-         timesteps: Optional[List[int]] = None,
-     ):
-         """
-         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
-
-         Args:
-             num_inference_steps (`int`):
-                 The number of diffusion steps used when generating samples with a pre-trained model.
-             device (`str` or `torch.device`, *optional*):
-                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
-             timesteps (`List[int]`, *optional*):
-                 Custom timesteps used to support an arbitrary timestep schedule. If `None`, timesteps are generated
-                 based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` must be
-                 `None`, and the `timestep_spacing` attribute is ignored.
-         """
-         if num_inference_steps is None and timesteps is None:
-             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
-         if num_inference_steps is not None and timesteps is not None:
-             raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
-         if timesteps is not None and self.config.use_karras_sigmas:
-             raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
-         if timesteps is not None and self.config.use_lu_lambdas:
-             raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
-
-         if timesteps is not None:
-             timesteps = np.array(timesteps).astype(np.int64)
-         else:
-             # Clipping the minimum of all lambda(t) for numerical stability.
-             # This is critical for the cosine (squaredcos_cap_v2) noise schedule.
-             clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
-             last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
-
-             # "linspace", "leading", "trailing" correspond to the annotation of Table 2 of https://arxiv.org/abs/2305.08891
-             if self.config.timestep_spacing == "linspace":
-                 timesteps = (
-                     np.linspace(0, last_timestep - 1, num_inference_steps + 1)
-                     .round()[::-1][:-1]
-                     .copy()
-                     .astype(np.int64)
-                 )
-             elif self.config.timestep_spacing == "leading":
-                 step_ratio = last_timestep // (num_inference_steps + 1)
-                 # creates integer timesteps by multiplying by ratio
-                 # casting to int to avoid issues when num_inference_steps is a power of 3
-                 timesteps = (
-                     (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
-                 )
-                 timesteps += self.config.steps_offset
-             elif self.config.timestep_spacing == "trailing":
-                 step_ratio = self.config.num_train_timesteps / num_inference_steps
-                 # creates integer timesteps by multiplying by ratio
-                 # casting to int to avoid issues when num_inference_steps is a power of 3
-                 timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
-                 timesteps -= 1
-             else:
-                 raise ValueError(
-                     f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
-                 )
-
-         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-         log_sigmas = np.log(sigmas)
-
-         if self.config.use_karras_sigmas:
-             sigmas = np.flip(sigmas).copy()
-             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-         elif self.config.use_lu_lambdas:
-             lambdas = np.flip(log_sigmas.copy())
-             lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
-             sigmas = np.exp(lambdas)
-             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-         else:
-             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-
-         if self.config.final_sigmas_type == "sigma_min":
-             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
-         elif self.config.final_sigmas_type == "zero":
-             sigma_last = 0
-         else:
-             raise ValueError(
-                 f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
-             )
-
-         sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
-
-         self.sigmas = torch.from_numpy(sigmas)
-         self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
-
-         self.num_inference_steps = len(timesteps)
-
-         self.model_outputs = [
-             None,
-         ] * self.config.solver_order
-         self.lower_order_nums = 0
-
-         # add an index counter for schedulers that allow duplicated timesteps
-         self._step_index = None
-         self._begin_index = None
-         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
-
-     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-     def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
-         """
-         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
-         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
-         s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
431
- pixels from saturation at each step. We find that dynamic thresholding results in significantly better
432
- photorealism as well as better image-text alignment, especially when using very large guidance weights."
433
-
434
- https://arxiv.org/abs/2205.11487
435
- """
436
- dtype = sample.dtype
437
- batch_size, channels, *remaining_dims = sample.shape
438
-
439
- if dtype not in (torch.float32, torch.float64):
440
- sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
441
-
442
- # Flatten sample for doing quantile calculation along each image
443
- sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
444
-
445
- abs_sample = sample.abs() # "a certain percentile absolute pixel value"
446
-
447
- s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
448
- s = torch.clamp(
449
- s, min=1, max=self.config.sample_max_value
450
- ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
451
- s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
452
- sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
453
-
454
- sample = sample.reshape(batch_size, channels, *remaining_dims)
455
- sample = sample.to(dtype)
456
-
457
- return sample
458
-
459
- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
460
- def _sigma_to_t(self, sigma, log_sigmas):
461
- # get log sigma
462
- log_sigma = np.log(np.maximum(sigma, 1e-10))
463
-
464
- # get distribution
465
- dists = log_sigma - log_sigmas[:, np.newaxis]
466
-
467
- # get sigmas range
468
- low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
469
- high_idx = low_idx + 1
470
-
471
- low = log_sigmas[low_idx]
472
- high = log_sigmas[high_idx]
473
-
474
- # interpolate sigmas
475
- w = (low - log_sigma) / (low - high)
476
- w = np.clip(w, 0, 1)
477
-
478
- # transform interpolation to time range
479
- t = (1 - w) * low_idx + w * high_idx
480
- t = t.reshape(sigma.shape)
481
- return t
482
-
483
- def _sigma_to_alpha_sigma_t(self, sigma):
484
- alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
485
- sigma_t = sigma * alpha_t
486
-
487
- return alpha_t, sigma_t
488
-
489
- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
490
- def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
491
- """Constructs the noise schedule of Karras et al. (2022)."""
492
-
493
- # Hack to make sure that other schedulers which copy this function don't break
494
- # TODO: Add this logic to the other schedulers
495
- if hasattr(self.config, "sigma_min"):
496
- sigma_min = self.config.sigma_min
497
- else:
498
- sigma_min = None
499
-
500
- if hasattr(self.config, "sigma_max"):
501
- sigma_max = self.config.sigma_max
502
- else:
503
- sigma_max = None
504
-
505
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
506
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
507
-
508
- rho = 7.0 # 7.0 is the value used in the paper
509
- ramp = np.linspace(0, 1, num_inference_steps)
510
- min_inv_rho = sigma_min ** (1 / rho)
511
- max_inv_rho = sigma_max ** (1 / rho)
512
- sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
513
- return sigmas
514
-
515
- def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
516
- """Constructs the noise schedule of Lu et al. (2022)."""
517
-
518
- lambda_min: float = in_lambdas[-1].item()
519
- lambda_max: float = in_lambdas[0].item()
520
-
521
- rho = 1.0 # 1.0 is the value used in the paper
522
- ramp = np.linspace(0, 1, num_inference_steps)
523
- min_inv_rho = lambda_min ** (1 / rho)
524
- max_inv_rho = lambda_max ** (1 / rho)
525
- lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
526
- return lambdas
527
-
528
- def convert_model_output(
529
- self,
530
- model_output: torch.Tensor,
531
- *args,
532
- sample: torch.Tensor = None,
533
- **kwargs,
534
- ) -> torch.Tensor:
535
- """
536
- Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
537
- designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
538
- integral of the data prediction model.
539
-
540
- <Tip>
541
-
542
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
543
- prediction and data prediction models.
544
-
545
- </Tip>
546
-
547
- Args:
548
- model_output (`torch.Tensor`):
549
- The direct output from the learned diffusion model.
550
- sample (`torch.Tensor`):
551
- A current instance of a sample created by the diffusion process.
552
-
553
- Returns:
554
- `torch.Tensor`:
555
- The converted model output.
556
- """
557
- timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
558
- if sample is None:
559
- if len(args) > 1:
560
- sample = args[1]
561
- else:
562
- raise ValueError("missing `sample` as a required keyward argument")
563
- if timestep is not None:
564
- deprecate(
565
- "timesteps",
566
- "1.0.0",
567
- "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
568
- )
569
-
570
- # DPM-Solver++ needs to solve an integral of the data prediction model.
571
- if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
572
- if self.config.prediction_type == "epsilon":
573
- # DPM-Solver and DPM-Solver++ only need the "mean" output.
574
- if self.config.variance_type in ["learned", "learned_range"]:
575
- model_output = model_output[:, :3]
576
- sigma = self.sigmas[self.step_index]
577
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
578
- x0_pred = (sample - sigma_t * model_output) / alpha_t
579
- elif self.config.prediction_type == "sample":
580
- x0_pred = model_output
581
- elif self.config.prediction_type == "v_prediction":
582
- sigma = self.sigmas[self.step_index]
583
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
584
- x0_pred = alpha_t * sample - sigma_t * model_output
585
- else:
586
- raise ValueError(
587
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
588
- " `v_prediction` for the DPMSolverMultistepScheduler."
589
- )
590
-
591
- if self.config.thresholding:
592
- x0_pred = self._threshold_sample(x0_pred)
593
-
594
- return x0_pred
595
-
596
- # DPM-Solver needs to solve an integral of the noise prediction model.
597
- elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
598
- if self.config.prediction_type == "epsilon":
599
- # DPM-Solver and DPM-Solver++ only need the "mean" output.
600
- if self.config.variance_type in ["learned", "learned_range"]:
601
- epsilon = model_output[:, :3]
602
- else:
603
- epsilon = model_output
604
- elif self.config.prediction_type == "sample":
605
- sigma = self.sigmas[self.step_index]
606
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
607
- epsilon = (sample - alpha_t * model_output) / sigma_t
608
- elif self.config.prediction_type == "v_prediction":
609
- sigma = self.sigmas[self.step_index]
610
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
611
- epsilon = alpha_t * model_output + sigma_t * sample
612
- else:
613
- raise ValueError(
614
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
615
- " `v_prediction` for the DPMSolverMultistepScheduler."
616
- )
617
-
618
- if self.config.thresholding:
619
- sigma = self.sigmas[self.step_index]
620
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
621
- x0_pred = (sample - sigma_t * epsilon) / alpha_t
622
- x0_pred = self._threshold_sample(x0_pred)
623
- epsilon = (sample - alpha_t * x0_pred) / sigma_t
624
-
625
- return epsilon
626
-
627
- def dpm_solver_first_order_update(
628
- self,
629
- model_output: torch.Tensor,
630
- *args,
631
- sample: torch.Tensor = None,
632
- noise: Optional[torch.Tensor] = None,
633
- **kwargs,
634
- ) -> torch.Tensor:
635
- """
636
- One step for the first-order DPMSolver (equivalent to DDIM).
637
-
638
- Args:
639
- model_output (`torch.Tensor`):
640
- The direct output from the learned diffusion model.
641
- sample (`torch.Tensor`):
642
- A current instance of a sample created by the diffusion process.
643
-
644
- Returns:
645
- `torch.Tensor`:
646
- The sample tensor at the previous timestep.
647
- """
648
- timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
649
- prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
650
- if sample is None:
651
- if len(args) > 2:
652
- sample = args[2]
653
- else:
654
- raise ValueError(" missing `sample` as a required keyward argument")
655
- if timestep is not None:
656
- deprecate(
657
- "timesteps",
658
- "1.0.0",
659
- "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
660
- )
661
-
662
- if prev_timestep is not None:
663
- deprecate(
664
- "prev_timestep",
665
- "1.0.0",
666
- "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
667
- )
668
-
669
- sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
670
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
671
- alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
672
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
673
- lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
674
-
675
- h = lambda_t - lambda_s
676
- if self.config.algorithm_type == "dpmsolver++":
677
- x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
678
- elif self.config.algorithm_type == "dpmsolver":
679
- x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
680
- elif self.config.algorithm_type == "sde-dpmsolver++":
681
- assert noise is not None
682
- x_t = (
683
- (sigma_t / sigma_s * torch.exp(-h)) * sample
684
- + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
685
- + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
686
- )
687
- elif self.config.algorithm_type == "sde-dpmsolver":
688
- assert noise is not None
689
- x_t = (
690
- (alpha_t / alpha_s) * sample
691
- - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
692
- + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
693
- )
694
- return x_t
695
-
696
- def multistep_dpm_solver_second_order_update(
697
- self,
698
- model_output_list: List[torch.Tensor],
699
- *args,
700
- sample: torch.Tensor = None,
701
- noise: Optional[torch.Tensor] = None,
702
- **kwargs,
703
- ) -> torch.Tensor:
704
- """
705
- One step for the second-order multistep DPMSolver.
706
-
707
- Args:
708
- model_output_list (`List[torch.Tensor]`):
709
- The direct outputs from learned diffusion model at current and latter timesteps.
710
- sample (`torch.Tensor`):
711
- A current instance of a sample created by the diffusion process.
712
-
713
- Returns:
714
- `torch.Tensor`:
715
- The sample tensor at the previous timestep.
716
- """
717
- timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
718
- prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
719
- if sample is None:
720
- if len(args) > 2:
721
- sample = args[2]
722
- else:
723
- raise ValueError(" missing `sample` as a required keyward argument")
724
- if timestep_list is not None:
725
- deprecate(
726
- "timestep_list",
727
- "1.0.0",
728
- "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
729
- )
730
-
731
- if prev_timestep is not None:
732
- deprecate(
733
- "prev_timestep",
734
- "1.0.0",
735
- "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
736
- )
737
-
738
- sigma_t, sigma_s0, sigma_s1 = (
739
- self.sigmas[self.step_index + 1],
740
- self.sigmas[self.step_index],
741
- self.sigmas[self.step_index - 1],
742
- )
743
-
744
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
745
- alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
746
- alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
747
-
748
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
749
- lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
750
- lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
751
-
752
- m0, m1 = model_output_list[-1], model_output_list[-2]
753
-
754
- h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
755
- r0 = h_0 / h
756
- D0, D1 = m0, (1.0 / r0) * (m0 - m1)
757
- if self.config.algorithm_type == "dpmsolver++":
758
- # See https://arxiv.org/abs/2211.01095 for detailed derivations
759
- if self.config.solver_type == "midpoint":
760
- x_t = (
761
- (sigma_t / sigma_s0) * sample
762
- - (alpha_t * (torch.exp(-h) - 1.0)) * D0
763
- - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
764
- )
765
- elif self.config.solver_type == "heun":
766
- x_t = (
767
- (sigma_t / sigma_s0) * sample
768
- - (alpha_t * (torch.exp(-h) - 1.0)) * D0
769
- + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
770
- )
771
- elif self.config.algorithm_type == "dpmsolver":
772
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
773
- if self.config.solver_type == "midpoint":
774
- x_t = (
775
- (alpha_t / alpha_s0) * sample
776
- - (sigma_t * (torch.exp(h) - 1.0)) * D0
777
- - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
778
- )
779
- elif self.config.solver_type == "heun":
780
- x_t = (
781
- (alpha_t / alpha_s0) * sample
782
- - (sigma_t * (torch.exp(h) - 1.0)) * D0
783
- - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
784
- )
785
- elif self.config.algorithm_type == "sde-dpmsolver++":
786
- assert noise is not None
787
- if self.config.solver_type == "midpoint":
788
- x_t = (
789
- (sigma_t / sigma_s0 * torch.exp(-h)) * sample
790
- + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
791
- + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
792
- + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
793
- )
794
- elif self.config.solver_type == "heun":
795
- x_t = (
796
- (sigma_t / sigma_s0 * torch.exp(-h)) * sample
797
- + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
798
- + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
799
- + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
800
- )
801
- elif self.config.algorithm_type == "sde-dpmsolver":
802
- assert noise is not None
803
- if self.config.solver_type == "midpoint":
804
- x_t = (
805
- (alpha_t / alpha_s0) * sample
806
- - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
807
- - (sigma_t * (torch.exp(h) - 1.0)) * D1
808
- + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
809
- )
810
- elif self.config.solver_type == "heun":
811
- x_t = (
812
- (alpha_t / alpha_s0) * sample
813
- - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
814
- - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
815
- + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
816
- )
817
- return x_t
818
-
819
- def multistep_dpm_solver_third_order_update(
820
- self,
821
- model_output_list: List[torch.Tensor],
822
- *args,
823
- sample: torch.Tensor = None,
824
- **kwargs,
825
- ) -> torch.Tensor:
826
- """
827
- One step for the third-order multistep DPMSolver.
828
-
829
- Args:
830
- model_output_list (`List[torch.Tensor]`):
831
- The direct outputs from learned diffusion model at current and latter timesteps.
832
- sample (`torch.Tensor`):
833
- A current instance of a sample created by diffusion process.
834
-
835
- Returns:
836
- `torch.Tensor`:
837
- The sample tensor at the previous timestep.
838
- """
839
-
840
- timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
841
- prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
842
- if sample is None:
843
- if len(args) > 2:
844
- sample = args[2]
845
- else:
846
- raise ValueError(" missing`sample` as a required keyward argument")
847
- if timestep_list is not None:
848
- deprecate(
849
- "timestep_list",
850
- "1.0.0",
851
- "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
852
- )
853
-
854
- if prev_timestep is not None:
855
- deprecate(
856
- "prev_timestep",
857
- "1.0.0",
858
- "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
859
- )
860
-
861
- sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
862
- self.sigmas[self.step_index + 1],
863
- self.sigmas[self.step_index],
864
- self.sigmas[self.step_index - 1],
865
- self.sigmas[self.step_index - 2],
866
- )
867
-
868
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
869
- alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
870
- alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
871
- alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
872
-
873
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
874
- lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
875
- lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
876
- lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
877
-
878
- m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
879
-
880
- h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
881
- r0, r1 = h_0 / h, h_1 / h
882
- D0 = m0
883
- D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
884
- D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
885
- D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
886
- if self.config.algorithm_type == "dpmsolver++":
887
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
888
- x_t = (
889
- (sigma_t / sigma_s0) * sample
890
- - (alpha_t * (torch.exp(-h) - 1.0)) * D0
891
- + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
892
- - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
893
- )
894
- elif self.config.algorithm_type == "dpmsolver":
895
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
896
- x_t = (
897
- (alpha_t / alpha_s0) * sample
898
- - (sigma_t * (torch.exp(h) - 1.0)) * D0
899
- - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
900
- - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
901
- )
902
- return x_t
903
-
904
- def index_for_timestep(self, timestep, schedule_timesteps=None):
905
- if schedule_timesteps is None:
906
- schedule_timesteps = self.timesteps
907
-
908
- index_candidates = (schedule_timesteps == timestep).nonzero()
909
-
910
- if len(index_candidates) == 0:
911
- step_index = len(self.timesteps) - 1
912
- # The sigma index that is taken for the **very** first `step`
913
- # is always the second index (or the last index if there is only 1)
914
- # This way we can ensure we don't accidentally skip a sigma in
915
- # case we start in the middle of the denoising schedule (e.g. for image-to-image)
916
- elif len(index_candidates) > 1:
917
- step_index = index_candidates[1].item()
918
- else:
919
- step_index = index_candidates[0].item()
920
-
921
- return step_index
922
-
923
- def _init_step_index(self, timestep):
924
- """
925
- Initialize the step_index counter for the scheduler.
926
- """
927
-
928
- if self.begin_index is None:
929
- if isinstance(timestep, torch.Tensor):
930
- timestep = timestep.to(self.timesteps.device)
931
- self._step_index = self.index_for_timestep(timestep)
932
- else:
933
- self._step_index = self._begin_index
934
-
935
- def step(
936
- self,
937
- model_output: torch.Tensor,
938
- timestep: int,
939
- sample: torch.Tensor,
940
- generator=None,
941
- variance_noise: Optional[torch.Tensor] = None,
942
- return_dict: bool = True,
943
- ) -> Union[SchedulerOutput, Tuple]:
944
- """
945
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
946
- the multistep DPMSolver.
947
-
948
- Args:
949
- model_output (`torch.Tensor`):
950
- The direct output from learned diffusion model.
951
- timestep (`int`):
952
- The current discrete timestep in the diffusion chain.
953
- sample (`torch.Tensor`):
954
- A current instance of a sample created by the diffusion process.
955
- generator (`torch.Generator`, *optional*):
956
- A random number generator.
957
- variance_noise (`torch.Tensor`):
958
- Alternative to generating noise with `generator` by directly providing the noise for the variance
959
- itself. Useful for methods such as [`LEdits++`].
960
- return_dict (`bool`):
961
- Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
962
-
963
- Returns:
964
- [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
965
- If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
966
- tuple is returned where the first element is the sample tensor.
967
-
968
- """
969
- if self.num_inference_steps is None:
970
- raise ValueError(
971
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
972
- )
973
-
974
- if self.step_index is None:
975
- self._init_step_index(timestep)
976
-
977
- # Improve numerical stability for small number of steps
978
- lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
979
- self.config.euler_at_final
980
- or (self.config.lower_order_final and len(self.timesteps) < 15)
981
- or self.config.final_sigmas_type == "zero"
982
- )
983
- lower_order_second = (
984
- (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
985
- )
986
-
987
- model_output = self.convert_model_output(model_output, sample=sample)
988
- for i in range(self.config.solver_order - 1):
989
- self.model_outputs[i] = self.model_outputs[i + 1]
990
- self.model_outputs[-1] = model_output
991
-
992
- # Upcast to avoid precision issues when computing prev_sample
993
- sample = sample.to(torch.float32)
994
- if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
995
- noise = randn_tensor(
996
- model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
997
- )
998
- elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
999
- noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
1000
- else:
1001
- noise = None
1002
-
1003
- if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
1004
- prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
1005
- elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
1006
- prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
1007
- else:
1008
- prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
1009
-
1010
- if self.lower_order_nums < self.config.solver_order:
1011
- self.lower_order_nums += 1
1012
-
1013
- # Cast sample back to expected dtype
1014
- prev_sample = prev_sample.to(model_output.dtype)
1015
-
1016
- # upon completion increase step index by one
1017
- self._step_index += 1
1018
-
1019
- if not return_dict:
1020
- return (prev_sample,)
1021
-
1022
- return SchedulerOutput(prev_sample=prev_sample)
1023
-
1024
- def add_noise(
1025
- self,
1026
- original_samples: torch.Tensor,
1027
- noise: torch.Tensor,
1028
- timesteps: torch.IntTensor,
1029
- ) -> torch.Tensor:
1030
- # Make sure sigmas and timesteps have the same device and dtype as original_samples
1031
- # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
1032
- # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
1033
- alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
1034
- sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
1035
- timesteps = timesteps.to(original_samples.device)
1036
- alpha_t = alpha_t[timesteps].flatten()
1037
- while len(alpha_t.shape) < len(original_samples.shape):
1038
- alpha_t = alpha_t.unsqueeze(-1)
1039
-
1040
- sigma_t = sigma_t[timesteps].flatten()
1041
- while len(sigma_t.shape) < len(original_samples.shape):
1042
- sigma_t = sigma_t.unsqueeze(-1)
1043
- noisy_samples = alpha_t * original_samples + sigma_t * noise
1044
- return noisy_samples
1045
-
1046
- def get_velocity(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
1047
- # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
1048
- # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
1049
- alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
1050
- sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
1051
-
1052
- timesteps = timesteps.to(original_samples.device)
1053
- alpha_t = alpha_t[timesteps].flatten()
1054
- while len(alpha_t.shape) < len(original_samples.shape):
1055
- alpha_t = alpha_t.unsqueeze(-1)
1056
-
1057
- sigma_t = sigma_t[timesteps].flatten()
1058
- while len(sigma_t.shape) < len(original_samples.shape):
1059
- sigma_t = sigma_t.unsqueeze(-1)
1060
-
1061
- velocity = alpha_t * noise - sigma_t * original_samples
1062
- return velocity
1063
-
1064
- def __len__(self):
1065
- return self.config.num_train_timesteps
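For orientation, the scheduler deleted above is a vendored copy of diffusers' `DPMSolverMultistepScheduler`: `set_timesteps` builds the sigma/timestep schedule, `convert_model_output` maps the network output to a data prediction (DPM-Solver++) or noise prediction (DPM-Solver), and `step` applies a first-, second-, or third-order multistep update. A minimal sketch of the driving loop this interface implies; `denoiser` is a hypothetical noise-prediction network, and any scheduler exposing this interface would do:

```python
import torch

def sample_loop(scheduler, denoiser, shape, num_inference_steps=20, device="cpu"):
    # Build the inference schedule (sigmas plus integer timesteps).
    scheduler.set_timesteps(num_inference_steps, device=device)
    sample = torch.randn(shape, device=device)  # start from pure noise
    for t in scheduler.timesteps:
        model_output = denoiser(sample, t)  # epsilon / x0 / v, per prediction_type
        # convert_model_output + the multistep DPM-Solver update happen inside step().
        sample = scheduler.step(model_output, t, sample).prev_sample
    return sample
```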
src/vibevoice/schedule/timestep_sampler.py DELETED
@@ -1,19 +0,0 @@
- import math
- import torch
-
-
- class UniformSampler:
-     def __init__(self, timesteps=1000):
-         self.timesteps = timesteps
-     def sample(self, batch_size, device):
-         return torch.randint(0, self.timesteps, (batch_size,), device=device)
-
- class LogitNormalSampler:
-     def __init__(self, timesteps=1000, m=0, s=1):
-         self.timesteps = timesteps
-         timesteps = torch.linspace(0, 1, timesteps)
-         logit = torch.log(timesteps / (1 - timesteps))
-         self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi))
-     def sample(self, batch_size, device):
-         return torch.multinomial(self.prob, batch_size, replacement=True).to(device)
-
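The two samplers deleted above draw training timesteps. `UniformSampler` is uniform over `[0, timesteps)`; `LogitNormalSampler` discretizes a logit-normal density over the unit interval and samples indices from it with `torch.multinomial`, biasing draws toward mid-schedule steps (the endpoint probabilities underflow to zero, so the first and last timesteps are never drawn). A short usage sketch, assuming the classes above are importable:

```python
import torch

# Hypothetical usage of the deleted samplers during a training step.
uniform = UniformSampler(timesteps=1000)
t = uniform.sample(batch_size=8, device="cpu")  # LongTensor, shape (8,), values in [0, 1000)

logit_normal = LogitNormalSampler(timesteps=1000, m=0, s=1)
t = logit_normal.sample(batch_size=8, device="cpu")  # indices concentrated near mid-schedule
```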
src/vibevoice/scripts/__init__.py DELETED
File without changes
src/vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py DELETED
@@ -1,166 +0,0 @@
- #!/usr/bin/env python
- # coding=utf-8
-
- import argparse
- import json
- import os
- from pathlib import Path
- import re
- import torch
- from typing import Dict, List, Tuple
-
- from vibevoice.modular.configuration_vibevoice import (
-     VibeVoiceConfig
- )
- from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
- from transformers.utils import logging
-
- logger = logging.get_logger(__name__)
-
- def convert_vibevoice_nnscaler_checkpoint_to_hf(
-     checkpoint_path: str,
-     pytorch_dump_folder_path: str,
-     config_path: str = None,
- ):
-     """
-     Convert an nnscaler VibeVoice checkpoint to HuggingFace format.
-     Supports both regular checkpoints and tensor parallel checkpoints.
-     """
-
-     # Load regular checkpoint
-     logger.info(f"Loading regular checkpoint from {checkpoint_path}")
-     checkpoint = torch.load(checkpoint_path, map_location="cpu")  # keys: ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader']
-
-     # config = checkpoint['train_args']
-     init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path']
-     pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path']
-
-     init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1]
-     if init_config_path.exists():
-         logger.info(f"Loading initial config from {init_config_path}")
-         with open(init_config_path, 'r') as f:
-             init_config = json.load(f)
-     else:
-         raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.")
-
-     tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True)
-     logger.info(f"Tie word embeddings: {tie_word_embeddings}")
-
-     init_config['decoder_config']['use_cache'] = True
-     config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings)
-
-     # Extract the model state dict
-     model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')}
-     if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys():
-         # If not tying weights, we need to add the lm_head weight separately
-         model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight']
-
-     # Override with provided config if available
-     if config_path:
-         logger.info(f"Loading config from {config_path}")
-         with open(config_path, 'r') as f:
-             config_dict = json.load(f)
-         config = VibeVoiceConfig.from_dict(config_dict)
-
-     # Set the default dtype to bfloat16 before creating the model
-     original_dtype = torch.get_default_dtype()
-     torch.set_default_dtype(torch.bfloat16)
-
-     # Create the HuggingFace model
-     logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model")
-     model = VibeVoiceForConditionalGeneration(config)
-
-     # Restore original dtype
-     torch.set_default_dtype(original_dtype)
-
-     # Load the state dict
-     logger.info("Loading weights into model")
-     missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
-
-     if missing_keys:
-         logger.warning(f"Missing keys: {missing_keys}")
-     if unexpected_keys:
-         logger.warning(f"Unexpected keys: {unexpected_keys}")
-
-     # Create output directory
-     os.makedirs(pytorch_dump_folder_path, exist_ok=True)
-
-     # Save the model and config
-     logger.info(f"Saving model to {pytorch_dump_folder_path}")
-
-     # Save config
-     config.save_pretrained(pytorch_dump_folder_path)
-
-     # Save VibeVoiceProcessor configuration
-     logger.info("Saving VibeVoiceProcessor configuration")
-     processor_config = {
-         "processor_class": "VibeVoiceProcessor",
-         "speech_tok_compress_ratio": 3200,
-         "db_normalize": True,
-         # Audio processor configuration
-         "audio_processor": {
-             "feature_extractor_type": "VibeVoiceTokenizerProcessor",
-             "sampling_rate": 24000,
-             "normalize_audio": True,
-             "target_dB_FS": -25,
-             "eps": 1e-6,
-         },
-         "language_model_pretrained_name": pretrained_name,
-     }
-
-     processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json")
-     with open(processor_config_path, 'w') as f:
-         json.dump(processor_config, f, indent=2)
-     logger.info(f"Saved processor config to {processor_config_path}")
-
-     # Save model with sharding
-     # save_pretrained handles tied weights automatically
-     logger.info("Saving model weights with sharding...")
-     model.save_pretrained(
-         pytorch_dump_folder_path,
-         max_shard_size="2GB",  # set maximum size for each shard
-         safe_serialization=True  # ensure saving in .safetensors format
-     )
-     logger.info(f"Model weights saved to {pytorch_dump_folder_path}")
-
-     logger.info("Conversion complete!")
-
-     # Verify the saved model can be loaded
-     logger.info("Verifying saved model...")
-     loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path)
-     logger.info("Model successfully loaded from saved checkpoint!")
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "--nnscaler_checkpoint_path",
-         type=str,
-         required=True,
-         help="Path to the nnscaler checkpoint (.pt file). For tensor parallel checkpoints, "
-         "provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), "
-         "and the script will automatically detect and merge all parts.",
-     )
-     parser.add_argument(
-         "--pytorch_dump_folder_path",
-         type=str,
-         required=True,
-         help="Path to the output PyTorch model directory",
-     )
-     parser.add_argument(
-         "--config_path",
-         type=str,
-         default=None,
-         help="Optional path to a config JSON file to override extracted config",
-     )
-
-     args = parser.parse_args()
-
-     convert_vibevoice_nnscaler_checkpoint_to_hf(
-         args.nnscaler_checkpoint_path,
-         args.pytorch_dump_folder_path,
-         args.config_path,
-     )
-
-
- if __name__ == "__main__":
-     main()
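The deleted converter was a CLI entry point; an equivalent programmatic call (the paths below are placeholders, and the import assumes `src/` is on `PYTHONPATH`) would have looked like:

```python
from vibevoice.scripts.convert_nnscaler_checkpoint_to_transformers import (
    convert_vibevoice_nnscaler_checkpoint_to_hf,
)

convert_vibevoice_nnscaler_checkpoint_to_hf(
    checkpoint_path="checkpoints/checkpoint_last.pt",  # nnscaler .pt checkpoint (placeholder path)
    pytorch_dump_folder_path="out/vibevoice-hf",       # receives sharded safetensors + configs
    config_path=None,                                  # optional JSON override for the extracted config
)
```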
src/voices/vibe_voices/en-Alice_woman.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:c27ae47421436287a6bd2c3062de2dc2a2855b78c0bb626d472202c359704203
- size 296684
src/voices/vibe_voices/en-Carter_man.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:9dd7b12f25bf279d878a9f7a3125f64bff2b312a189959090acff9138a55e8dd
- size 1331244
src/voices/vibe_voices/en-Frank_man.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:aa77c4794a005c4b05a52bbce5f30e77f0d28987b9a9e737401a5d30fd1ebcb5
- size 1158444
src/voices/vibe_voices/en-Mary_woman_bgm.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:c421eeab1af5b3ddae8d14cfcf6b65e496047ad2228325d61d1b6967fca11700
- size 1292878
src/voices/vibe_voices/en-Maya_woman.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:eb1288bc02546c7f1117698fb78e994f060e623af148be8ccbf93dd0bea79e32
- size 1305644
src/voices/vibe_voices/in-Samuel_man.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:76b07b5a12ca0b24a1e4a88100c4e2e47a2552ebb96807d52f116cf05fc46b50
- size 1273644
src/voices/vibe_voices/zh-Anchen_man_bgm.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:f71aeb33ed66c449dedb75d8a505478d86d47ec49e0e4c33c1fd0f8324d781fb
- size 1177644
src/voices/vibe_voices/zh-Bowen_man.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:0cef6c018e73e9fb6a1269fd61ded08144ae6380cdec242eebb1cc8aca49fed1
- size 1419940
src/voices/vibe_voices/zh-Xinran_woman.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:dbcb9e28bcc544675ef75a8ba12528bf09e713eb53a8c0c819dec3daf2d486d3
- size 1337644