File size: 7,504 Bytes
27564df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import shutil
import threading
import torch
import numpy as np
import gradio as gr
from pathlib import Path
from typing import Dict, Optional, Tuple, Iterator, Any
import copy

from vibevoice.modular.modeling_vibevoice_streaming_inference import (
    VibeVoiceStreamingForConditionalGenerationInference,
)
from vibevoice.processor.vibevoice_streaming_processor import (
    VibeVoiceStreamingProcessor,
)
from vibevoice.modular.streamer import AudioStreamer

# Output audio sample rate (Hz) used by StreamingTTSService for every chunk.
SAMPLE_RATE = 24000

class StreamingTTSService:
    """Wrap a VibeVoice streaming model and expose a chunked TTS generator.

    Typical use: construct, call :meth:`load` once, then iterate
    :meth:`stream` for each request. Voice presets are ``*.pt`` files found
    recursively under ``voices_dir`` and keyed by file stem.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        inference_steps: int = 5,
        voices_dir: str = "./demo/voices/streaming_model",
    ) -> None:
        """Record configuration; no weights are loaded until :meth:`load`.

        Args:
            model_path: model id or local path passed to ``from_pretrained``
                for both the processor and the model.
            device: torch device string. Any CUDA device (``"cuda"``,
                ``"cuda:1"``, ...) falls back to CPU, with a warning, when
                CUDA is unavailable.
            inference_steps: DDPM inference steps configured in :meth:`load`.
            voices_dir: directory searched recursively for ``*.pt`` presets.
        """
        self.model_path = model_path
        self.inference_steps = inference_steps
        self.sample_rate = SAMPLE_RATE
        self.voices_dir = Path(voices_dir)
        self.processor: Optional[VibeVoiceStreamingProcessor] = None
        self.model: Optional[VibeVoiceStreamingForConditionalGenerationInference] = None
        self.voice_presets: Dict[str, Path] = {}
        self.default_voice_key: Optional[str] = None
        # preset key -> object returned by torch.load for that preset file
        self._voice_cache: Dict[str, object] = {}
        # Compare the device *type* so indexed devices like "cuda:1" are handled too.
        if torch.device(device).type == "cuda" and not torch.cuda.is_available():
            print("Warning: CUDA not available. Falling back to CPU.")
            device = "cpu"
        self.device = device
        self._torch_device = torch.device(device)

    def load(self) -> None:
        """Load the processor, model, scheduler settings and voice presets.

        On CUDA the model is loaded in bfloat16 with ``flash_attention_2``;
        if that raises, loading is retried with ``sdpa``. Other devices use
        float32 with ``sdpa``. The noise scheduler is rebuilt from its own
        config with ``sde-dpmsolver++`` and ``squaredcos_cap_v2``.

        Raises:
            Exception: whatever ``from_pretrained`` raises for the final
                (``sdpa``) attempt.
            RuntimeError: if no voice presets can be found.
        """
        print(f"[startup] Loading processor from {self.model_path}")
        self.processor = VibeVoiceStreamingProcessor.from_pretrained(self.model_path)
        # Place weights on the same device that inputs are moved to in
        # _prepare_inputs, so the two can never disagree.
        device_map = self.device
        if self._torch_device.type == "cuda":
            load_dtype = torch.bfloat16
            attn_impl_primary = "flash_attention_2"
        else:
            load_dtype = torch.float32
            attn_impl_primary = "sdpa"
        print(f"Using device: {device_map}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
        try:
            self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                self.model_path, torch_dtype=load_dtype, device_map=device_map, attn_implementation=attn_impl_primary,
            )
        except Exception as e:
            print(f"Error loading model with {attn_impl_primary}: {e}")
            if attn_impl_primary != "flash_attention_2":
                raise
            print("Falling back to SDPA...")
            self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                self.model_path, torch_dtype=load_dtype, device_map=device_map, attn_implementation='sdpa',
            )
        self.model.eval()
        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
            self.model.model.noise_scheduler.config, algorithm_type="sde-dpmsolver++", beta_schedule="squaredcos_cap_v2",
        )
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
        self.voice_presets = self._load_voice_presets()
        self.default_voice_key = self._determine_voice_key(None)

    def _load_voice_presets(self) -> Dict[str, Path]:
        """Return ``{file_stem: path}`` for every ``*.pt`` under ``voices_dir``, sorted by key.

        Raises:
            RuntimeError: if the directory does not exist or holds no ``.pt`` files.
        """
        voices_dir = Path(self.voices_dir)
        if not voices_dir.is_dir():
            raise RuntimeError(f"Cannot find voices dir at {voices_dir}")
        # Keyed by stem: same-named presets in different subdirectories
        # overwrite each other (the last one yielded by rglob wins).
        presets: Dict[str, Path] = {pt_path.stem: pt_path for pt_path in voices_dir.rglob("*.pt")}
        if not presets:
            raise RuntimeError(f"No voice preset (.pt) files found in {voices_dir}")
        print(f"[startup] Found {len(presets)} voice presets")
        return dict(sorted(presets.items()))

    def _determine_voice_key(self, name: Optional[str]) -> str:
        """Return ``name`` if it is a known preset, else a fallback key.

        The fallback is ``"en-WHTest_man"`` when present, otherwise the first
        key of ``self.voice_presets``. This raises ``StopIteration`` if there
        are no presets at all; :meth:`_load_voice_presets` prevents that.
        """
        if name and name in self.voice_presets:
            return name
        for candidate in ("en-WHTest_man",):
            if candidate in self.voice_presets:
                return candidate
        return next(iter(self.voice_presets))

    def _ensure_voice_cached(self, key: str) -> object:
        """Return the loaded preset for ``key``, loading it from disk on first use.

        Unknown keys, including ``None``, fall back to ``self.default_voice_key``.
        """
        if key not in self.voice_presets:
            key = self.default_voice_key
        if key not in self._voice_cache:
            # NOTE: weights_only=False unpickles arbitrary objects; voices_dir
            # must only contain trusted preset files.
            self._voice_cache[key] = torch.load(
                self.voice_presets[key], map_location=self._torch_device, weights_only=False
            )
        return self._voice_cache[key]

    def _prepare_inputs(self, text: str, prefilled_outputs: object):
        """Build model inputs for ``text`` against a cached voice prompt.

        The processor's ``process_input_with_cached_prompt`` is called, and
        every value that has a ``.to`` method is moved to the service device.
        """
        processor_kwargs = {"text": text.strip(), "cached_prompt": prefilled_outputs, "padding": True, "return_tensors": "pt", "return_attention_mask": True}
        processed = self.processor.process_input_with_cached_prompt(**processor_kwargs)
        return {key: value.to(self._torch_device) if hasattr(value, "to") else value for key, value in processed.items()}

    def _run_generation(self, inputs, audio_streamer, errors, stop_event, prefilled_outputs):
        """Worker-thread body: run ``model.generate`` and feed ``audio_streamer``.

        An exception is appended to ``errors`` so the consuming thread can
        re-raise it. The streamer is always ended afterwards, so the
        consumer's ``get_stream`` loop cannot block forever.
        """
        try:
            self.model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=1.5,
                tokenizer=self.processor.tokenizer,
                generation_config={"do_sample": True, "temperature": 1.0, "top_p": 1.0},
                audio_streamer=audio_streamer,
                stop_check_fn=stop_event.is_set,
                verbose=False,
                refresh_negative=True,
                # Pass a private copy so the preset held in _voice_cache is
                # not shared with generation.
                all_prefilled_outputs=copy.deepcopy(prefilled_outputs),
            )
        except Exception as e:
            errors.append(e)
            print(f"Generation error: {e}")
        finally:
            audio_streamer.end()

    def stream(self, text: str, voice_key: str) -> Iterator[Tuple[int, np.ndarray]]:
        """Synthesize ``text`` and yield ``(sample_rate, chunk)`` tuples.

        Each ``chunk`` is a 1-D float32 numpy array. Generation runs on a
        daemon worker thread. Empty, whitespace-only or ``None`` text yields
        nothing. If the consumer stops early (generator closed), the worker
        is signalled via ``stop_check_fn`` and joined.

        Raises:
            Exception: the exception raised by generation, re-raised after
                the stream has been fully drained.
        """
        if not text or not text.strip():
            return
        prefilled_outputs = self._ensure_voice_cached(voice_key)
        audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None)
        stop_event = threading.Event()
        errors: list = []
        inputs = self._prepare_inputs(text, prefilled_outputs)
        worker = threading.Thread(
            target=self._run_generation,
            kwargs={"inputs": inputs, "audio_streamer": audio_streamer, "errors": errors, "stop_event": stop_event, "prefilled_outputs": prefilled_outputs},
            daemon=True,
        )
        worker.start()
        try:
            for audio_chunk in audio_streamer.get_stream(0):
                if torch.is_tensor(audio_chunk):
                    # .float() first: numpy has no bfloat16 dtype.
                    audio_chunk = audio_chunk.detach().cpu().float().numpy()
                else:
                    audio_chunk = np.asarray(audio_chunk, dtype=np.float32)
                if audio_chunk.ndim > 1:
                    audio_chunk = audio_chunk.reshape(-1)
                yield (self.sample_rate, audio_chunk)
        finally:
            stop_event.set()
            audio_streamer.end()
            worker.join()
        # Raised outside ``finally`` so a consumer-side exception (e.g.
        # GeneratorExit on early close) is not replaced by a generation error.
        if errors:
            raise errors[0]

# Model loaded once at import time and shared by every Gradio request.
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
service = StreamingTTSService(model_path=MODEL_ID)
service.load()

def tts_generate(text, voice):
    """Gradio click handler: stream ``(sample_rate, chunk)`` tuples for *text* in *voice*."""
    chunks = service.stream(text, voice)
    yield from chunks

# UI: text + voice preset in, streamed audio out.
with gr.Blocks(title="VibeVoice-Realtime Demo") as demo:
    gr.Markdown(value="# Microsoft VibeVoice-Realtime-0.5B")
    with gr.Row():
        text_input = gr.Textbox(
            value="Hello world! This is VibeVoice speaking realtime.",
            label="Input Text",
        )
        voice_dropdown = gr.Dropdown(
            label="Voice Preset",
            choices=list(service.voice_presets),
            value=service.default_voice_key,
        )
    audio_output = gr.Audio(streaming=True, autoplay=True, label="Generated Audio")
    btn = gr.Button(value="Generate", variant="primary")
    btn.click(
        tts_generate,
        inputs=[text_input, voice_dropdown],
        outputs=[audio_output],
    )

if __name__ == "__main__":
    queued_app = demo.queue()
    queued_app.launch(server_port=7860, server_name="0.0.0.0")