import sys
import os
# Make the in-repo `src/` package tree importable without installation,
# so the `vibevoice` imports below resolve when this script is run from
# the repository root.
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))


import torch
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from peft import PeftModel
import json
|
|
# --- Model setup -----------------------------------------------------------
# Load the base VibeVoice inference model from the current directory in
# bfloat16 on the GPU, with FlashAttention-2 (requires a compatible GPU
# and the flash-attn package to be installed).
print("Loading fadeout model (4 epochs, fadeout + padding dataset)...")
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
    ".",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    attn_implementation="flash_attention_2"
)

# Wrap the inner language model with the fine-tuned LoRA adapter.
# NOTE: this must happen before inference; it replaces
# model.model.language_model with a PeftModel in place.
model.model.language_model = PeftModel.from_pretrained(
    model.model.language_model,
    "finetune_elise_fadeout/lora"
)

# The diffusion (acoustic prediction) head was fully fine-tuned rather than
# LoRA-adapted, so its complete state dict is loaded separately.
diffusion_state = torch.load("finetune_elise_fadeout/lora/diffusion_head_full.bin", map_location="cpu")
model.model.prediction_head.load_state_dict(diffusion_state)

# Text/audio pre- and post-processing (tokenization, voice-sample handling,
# audio saving) comes from the repo's bundled processor config.
processor = VibeVoiceProcessor.from_pretrained("src/vibevoice/processor")
model.eval()

# Fewer DDPM denoising steps = faster generation at some quality cost.
model.set_ddpm_inference_steps(num_steps=20)
|
|
| |
# Pull the audio path out of the first record of the training split.  The
# processor API requires a voice sample, but the model was trained with
# voice_prompt_drop_rate=1.0, so the actual voice content is ignored.
with open("jinsaryko_elise_fadeout/elise_train_split.jsonl", 'r') as split_file:
    first_record = split_file.readline()
dummy_voice_path = json.loads(first_record)['audio']

print(f"\nUsing dummy voice (ignored): {os.path.basename(dummy_voice_path)}")
print("Testing fadeout model with various length statements...\n")
|
|
| |
# Test sentences grouped by length.  The generation loop tags each output
# file by character count: <30 chars -> "short", <100 -> "medium",
# otherwise "long".
test_sentences = [
    # Short (< 30 characters)
    "Hello!",
    "Good morning everyone...",
    "Welcome to my channel...",
    "Thanks for watching!",

    # Medium (30-99 characters)
    "Today we're going to learn something amazing together...",
    "I'm really excited to share this with all of you...",
    "Let me show you how this incredible feature works...",
    "Have you ever wondered about the mysteries of the universe?",

    # Long (>= 100 characters) — exercises the fadeout behavior on
    # extended generations.
    "Welcome back to the channel! Today I have something really special to share with you, and I think you're going to absolutely love what we're about to explore together...",
    "Throughout my journey of learning and discovery, I've come across many fascinating concepts, but this one in particular has completely transformed the way I think about technology and innovation...",
    "The beauty of machine learning lies not just in its complexity, but in how it can bring seemingly impossible ideas to life, creating experiences that were once confined to the realm of science fiction...",
    "As we dive deeper into this topic, I want you to imagine the endless possibilities that await us, and consider how these advancements might shape our future in ways we can barely comprehend today..."
]
|
|
os.makedirs("test_fadeout_output", exist_ok=True)

# Audio constants.  24 kHz matches the duration math used below — presumably
# the VibeVoice output rate; confirm against the processor/model config.
SAMPLE_RATE = 24000
TRAILING_SILENCE_SAMPLES = int(SAMPLE_RATE * 0.2)  # 200 ms pad (was hard-coded 4800)

for i, text in enumerate(test_sentences):
    # Bucket by character count; used only for the log tag and filename.
    length = "short" if len(text) < 30 else "medium" if len(text) < 100 else "long"
    print(f"\n[{i+1}/{len(test_sentences)}] [{length.upper()}] {text}")

    # Single-speaker script format expected by the processor.
    prompt = f"Speaker 0: {text}"

    inputs = processor(
        text=[prompt],
        voice_samples=[[dummy_voice_path]],
        return_tensors="pt"
    )

    # Move every tensor input onto the GPU the model lives on.
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to("cuda")

    outputs = model.generate(
        **inputs,
        cfg_scale=2.0,
        tokenizer=processor.tokenizer,
        generation_config={'do_sample': False},
        verbose=False
    )

    if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
        audio = outputs.speech_outputs[0]

        # Append trailing silence so players don't clip the fadeout.
        # Built with an explicit length rather than zeros_like(audio[..., :N]):
        # the old form silently produced a shorter pad for clips under
        # N samples long.
        silence = torch.zeros(
            *audio.shape[:-1], TRAILING_SILENCE_SAMPLES,
            dtype=audio.dtype, device=audio.device,
        )
        padded = torch.cat([audio, silence], dim=-1)

        output_path = f"test_fadeout_output/test_{i:02d}_{length}.wav"
        processor.save_audio(padded, output_path)

        # Report the generated speech duration (pad excluded).
        duration = audio.shape[-1] / SAMPLE_RATE
        print(f" ✓ Generated {duration:.2f}s → {output_path}")
    else:
        # Previously a failed generation was skipped with no log line and no
        # output file; surface it explicitly instead.
        print(" ✗ No audio generated for this sentence — skipped")
|
|
# Final summary banner: single write, byte-identical stdout to the
# original sequence of print() calls.
banner = "=" * 60
summary = [
    "\n" + banner,
    "Fadeout model test complete!",
    "Files saved in test_fadeout_output/",
    "\nModel stats:",
    "- 4 epochs on fadeout dataset (100ms fade + 250ms padding)",
    "- Final CE loss: ~5.25",
    "- Final Diffusion loss: ~0.559",
    "- voice_prompt_drop_rate: 1.0 (no voice prompts)",
    "- All training audio had smooth fadeouts!",
    banner,
]
print("\n".join(summary))