File size: 3,560 Bytes
86e8346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '../src'))

import torch
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from peft import PeftModel

# Configuration
# NOTE: these are plain relative paths, so they resolve against the current working
# directory (unlike the sys.path entry above, which is built from __file__). Run the
# script from its own directory or replace them with absolute paths.
MODEL_DIR = ".."  # Path to VibeVoice-1.5B directory; load_model also loads the processor from its src/vibevoice/processor
LORA_DIR = "../finetune_elise_single_speaker/lora"  # Path to your fine-tuned LoRA weights; must also contain diffusion_head_full.bin
OUTPUT_DIR = "output_audio"  # Directory main() writes output_XX.wav files into (created if missing)

def load_model():
    """Load the base VibeVoice model together with the fine-tuned weights.

    Steps:
      1. Load the base checkpoint from ``MODEL_DIR`` in bfloat16 on CUDA.
      2. Wrap ``model.model.language_model`` with the PEFT/LoRA adapter in ``LORA_DIR``.
      3. Overwrite ``model.model.prediction_head`` with the state dict stored in
         ``LORA_DIR/diffusion_head_full.bin``.
      4. Load the processor from ``MODEL_DIR/src/vibevoice/processor``.

    Returns:
        tuple: ``(model, processor)``. The model is in eval mode with 20 DDPM
        inference steps configured. The processor is a ``VibeVoiceProcessor``.
    """
    print("Loading model...")

    # Base model: bf16 weights placed directly on the GPU.
    # NOTE(review): flash_attention_2 needs the flash-attn package installed.
    model = VibeVoiceForConditionalGenerationInference.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        attn_implementation="flash_attention_2",
    )

    # Only the language model is wrapped with the LoRA adapter.
    model.model.language_model = PeftModel.from_pretrained(
        model.model.language_model,
        LORA_DIR,
    )

    # The diffusion head is loaded as a plain state dict into the prediction head.
    # weights_only=True restricts unpickling to tensors and basic containers, so a
    # tampered checkpoint cannot run arbitrary code. A state dict needs nothing more.
    diffusion_path = os.path.join(LORA_DIR, "diffusion_head_full.bin")
    diffusion_state = torch.load(diffusion_path, map_location="cpu", weights_only=True)
    model.model.prediction_head.load_state_dict(diffusion_state)

    processor = VibeVoiceProcessor.from_pretrained(
        os.path.join(MODEL_DIR, "src", "vibevoice", "processor")
    )

    model.eval()
    model.set_ddpm_inference_steps(num_steps=20)

    return model, processor

def generate_speech(model, processor, text, voice_sample_path=None, *, silence_samples=4800):
    """Synthesize speech for ``text`` with the fine-tuned single-speaker model.

    Args:
        model: Loaded VibeVoice inference model (as returned by ``load_model``).
        processor: The matching ``VibeVoiceProcessor``.
        text: Text to speak. It is sent to the processor as ``"Speaker 0: <text>"``.
        voice_sample_path: Optional path to a reference audio file. When ``None``,
            a clip from the training set is used as a placeholder.
        silence_samples: Number of zero samples appended after the generated audio
            on the last axis (default 4800, i.e. 200 ms at 24 kHz).

    Returns:
        The generated waveform tensor with exactly ``silence_samples`` zeros
        appended on its last axis, or ``None`` if the model produced no speech.
    """
    # The "Speaker 0:" prefix is required by the processor's input format.
    prompt = f"Speaker 0: {text}"

    # The processor call below always receives a voice sample. Per the training setup
    # (voice_prompt_drop_rate=1.0), the model is expected to ignore it.
    # NOTE(review): confirm against the training config.
    if voice_sample_path is None:
        # At least one audio file from the training set must exist at this path.
        voice_sample_path = "../elise_cleaned/wavs/sample_000009.wav"

    inputs = processor(
        text=[prompt],
        voice_samples=[[voice_sample_path]],
        return_tensors="pt",
    )

    # Move tensor inputs to wherever the model lives instead of hard-coding "cuda".
    # Reassigning existing keys while iterating is safe (the dict size does not change).
    device = model.device
    for key, value in inputs.items():
        if torch.is_tensor(value):
            inputs[key] = value.to(device)

    outputs = model.generate(
        **inputs,
        cfg_scale=2.0,
        tokenizer=processor.tokenizer,
        generation_config={'do_sample': False},
        verbose=False,
    )

    speech = outputs.speech_outputs
    if not speech or speech[0] is None:
        return None
    audio = speech[0]

    # Build the silence explicitly. Slicing audio[..., :n] would give fewer than n
    # samples for clips shorter than n, making the pad (and callers' duration math) wrong.
    silence = audio.new_zeros((*audio.shape[:-1], silence_samples))
    return torch.cat([audio, silence], dim=-1)

def main():
    """Synthesize a few demo sentences and save each one as a WAV file in OUTPUT_DIR."""
    # Loading is expensive, so do it once and reuse the model for every sentence.
    model, processor = load_model()
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    demo_sentences = (
        "Hello! This is the Elise voice model.",
        "I can generate speech without needing voice samples.",
        "Thank you for using this model!",
    )
    pad_samples = 4800    # trailing silence appended by generate_speech
    sample_rate = 24000   # assumed output sample rate, used for the duration estimate

    for idx, sentence in enumerate(demo_sentences):
        print(f"\nGenerating: {sentence}")
        waveform = generate_speech(model, processor, sentence)

        if waveform is None:
            print("Failed to generate audio")
            continue

        out_file = f"{OUTPUT_DIR}/output_{idx:02d}.wav"
        processor.save_audio(waveform, out_file)

        # Report the spoken length, excluding the appended silence.
        seconds = (waveform.shape[-1] - pad_samples) / sample_rate
        print(f"Saved: {out_file} ({seconds:.2f}s)")

# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()