"""Gradio streaming TTS demo for microsoft/VibeVoice-Realtime-0.5B.

Loads the VibeVoice streaming model once at startup, then streams generated
audio chunks to a Gradio ``Audio(streaming=True)`` component as they are
produced by a background generation thread.
"""

import copy
import os
import shutil
import threading
from pathlib import Path
from typing import Any, Dict, Iterator, Optional, Tuple

import gradio as gr
import numpy as np
import torch

from vibevoice.modular.modeling_vibevoice_streaming_inference import (
    VibeVoiceStreamingForConditionalGenerationInference,
)
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_streaming_processor import (
    VibeVoiceStreamingProcessor,
)

# Output sample rate of the VibeVoice streaming vocoder (Hz).
SAMPLE_RATE = 24_000


class StreamingTTSService:
    """Owns the model/processor, caches voice presets, and streams audio.

    Typical lifecycle: construct, call :meth:`load` once, then call
    :meth:`stream` per request (it yields ``(sample_rate, chunk)`` tuples
    suitable for a streaming Gradio audio output).
    """

    def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5) -> None:
        """Record configuration; heavy loading is deferred to :meth:`load`.

        Args:
            model_path: HuggingFace model id or local checkpoint path.
            device: "cuda" or "cpu"; silently downgraded to "cpu" when CUDA
                is unavailable.
            inference_steps: DDPM inference steps passed to the model.
        """
        self.model_path = model_path
        self.inference_steps = inference_steps
        self.sample_rate = SAMPLE_RATE
        self.processor: Optional[VibeVoiceStreamingProcessor] = None
        self.model: Optional[VibeVoiceStreamingForConditionalGenerationInference] = None
        self.voice_presets: Dict[str, Path] = {}
        self.default_voice_key: Optional[str] = None
        # Cache of torch.load()-ed prefilled prompt objects keyed by preset name.
        self._voice_cache: Dict[str, Any] = {}
        if device == "cuda" and not torch.cuda.is_available():
            print("Warning: CUDA not available. Falling back to CPU.")
            device = "cpu"
        self.device = device
        self._torch_device = torch.device(device)

    def load(self) -> None:
        """Load processor + model, configure the noise scheduler, find voices."""
        print(f"[startup] Loading processor from {self.model_path}")
        self.processor = VibeVoiceStreamingProcessor.from_pretrained(self.model_path)

        if self.device == "cuda":
            load_dtype = torch.bfloat16
            device_map = "cuda"
            attn_impl_primary = "flash_attention_2"
        else:
            load_dtype = torch.float32
            device_map = "cpu"
            attn_impl_primary = "sdpa"
        print(f"Using device: {device_map}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

        try:
            self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                self.model_path,
                torch_dtype=load_dtype,
                device_map=device_map,
                attn_implementation=attn_impl_primary,
            )
        except Exception as e:
            print(f"Error loading model with {attn_impl_primary}: {e}")
            if attn_impl_primary == "flash_attention_2":
                # flash-attn may be missing or unsupported on this GPU; SDPA
                # is the portable fallback.
                print("Falling back to SDPA...")
                self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map=self.device,
                    attn_implementation="sdpa",
                )
            else:
                raise

        self.model.eval()
        # Rebuild the scheduler with the SDE-DPMSolver++ config used for
        # low-step streaming inference.
        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
            self.model.model.noise_scheduler.config,
            algorithm_type="sde-dpmsolver++",
            beta_schedule="squaredcos_cap_v2",
        )
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)

        self.voice_presets = self._load_voice_presets()
        self.default_voice_key = self._determine_voice_key(None)

    def _load_voice_presets(self) -> Dict[str, Path]:
        """Discover ``*.pt`` voice presets under ./demo/voices/streaming_model.

        Returns:
            Mapping of preset stem -> file path, sorted by name.

        Raises:
            RuntimeError: when neither the voices dir nor any preset exists.
        """
        voices_dir = Path("./demo/voices/streaming_model")
        if not voices_dir.exists():
            # NOTE(review): when ./demo exists but the streaming_model subdir
            # does not, we fall through and fail below with "No voice preset";
            # only a missing ./demo raises immediately.
            if Path("demo").exists():
                pass
            else:
                raise RuntimeError(f"Cannot find voices dir at {voices_dir}")
        presets: Dict[str, Path] = {}
        for pt_path in voices_dir.rglob("*.pt"):
            presets[pt_path.stem] = pt_path
        if not presets:
            raise RuntimeError(f"No voice preset (.pt) files found in {voices_dir}")
        print(f"[startup] Found {len(presets)} voice presets")
        return dict(sorted(presets.items()))

    def _determine_voice_key(self, name: Optional[str]) -> str:
        """Resolve a requested voice name to an existing preset key.

        Falls back to a preferred default, then to the first preset.
        """
        if name and name in self.voice_presets:
            return name
        candidates = ["en-WHTest_man"]
        for c in candidates:
            if c in self.voice_presets:
                return c
        return next(iter(self.voice_presets))

    def _ensure_voice_cached(self, key: str) -> Any:
        """Return the prefilled prompt for ``key``, loading and caching it once."""
        if key not in self.voice_presets:
            key = self.default_voice_key
        if key not in self._voice_cache:
            preset_path = self.voice_presets[key]
            # weights_only=False: preset files contain arbitrary pickled
            # objects, not just tensors. Only load trusted presets.
            prefilled_outputs = torch.load(
                preset_path, map_location=self._torch_device, weights_only=False
            )
            self._voice_cache[key] = prefilled_outputs
        return self._voice_cache[key]

    def _prepare_inputs(self, text: str, prefilled_outputs: Any) -> Dict[str, Any]:
        """Tokenize ``text`` against the cached voice prompt and move tensors to device."""
        processor_kwargs = {
            "text": text.strip(),
            "cached_prompt": prefilled_outputs,
            "padding": True,
            "return_tensors": "pt",
            "return_attention_mask": True,
        }
        processed = self.processor.process_input_with_cached_prompt(**processor_kwargs)
        # Move tensor-like values to the model device; pass other values through.
        prepared = {
            key: value.to(self._torch_device) if hasattr(value, "to") else value
            for key, value in processed.items()
        }
        return prepared

    def _run_generation(self, inputs, audio_streamer, errors, stop_event, prefilled_outputs) -> None:
        """Worker-thread body: run generate(); report exceptions via ``errors``.

        Always signals end-of-stream so the consumer in :meth:`stream` never
        blocks forever.
        """
        try:
            self.model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=1.5,
                tokenizer=self.processor.tokenizer,
                generation_config={"do_sample": True, "temperature": 1.0, "top_p": 1.0},
                audio_streamer=audio_streamer,
                stop_check_fn=stop_event.is_set,
                verbose=False,
                refresh_negative=True,
                # Deep-copied so generate() cannot mutate the cached prompt.
                all_prefilled_outputs=copy.deepcopy(prefilled_outputs),
            )
        except Exception as e:
            errors.append(e)
            print(f"Generation error: {e}")
        finally:
            audio_streamer.end()

    def stream(self, text: str, voice_key: str) -> Iterator[Tuple[int, np.ndarray]]:
        """Yield ``(SAMPLE_RATE, float32 mono ndarray)`` chunks for ``text``.

        Generation runs in a daemon thread; this generator consumes its
        AudioStreamer queue. On exit (including generator close) the worker
        is stopped and joined, and the first worker exception is re-raised.
        """
        if not text.strip():
            return
        prefilled_outputs = self._ensure_voice_cached(voice_key)
        audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None)
        stop_event = threading.Event()
        errors: list = []
        inputs = self._prepare_inputs(text, prefilled_outputs)
        thread = threading.Thread(
            target=self._run_generation,
            kwargs={
                "inputs": inputs,
                "audio_streamer": audio_streamer,
                "errors": errors,
                "stop_event": stop_event,
                "prefilled_outputs": prefilled_outputs,
            },
            daemon=True,
        )
        thread.start()
        try:
            stream = audio_streamer.get_stream(0)
            for audio_chunk in stream:
                if torch.is_tensor(audio_chunk):
                    audio_chunk = audio_chunk.detach().cpu().float().numpy()
                else:
                    audio_chunk = np.asarray(audio_chunk, dtype=np.float32)
                if audio_chunk.ndim > 1:
                    # Flatten to mono; Gradio streaming audio expects 1-D.
                    audio_chunk = audio_chunk.reshape(-1)
                yield (SAMPLE_RATE, audio_chunk)
        finally:
            stop_event.set()
            audio_streamer.end()
            thread.join()
        if errors:
            raise errors[0]


MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"

service = StreamingTTSService(MODEL_ID)
service.load()


def tts_generate(text, voice):
    """Gradio callback: proxy straight to the service's streaming generator."""
    yield from service.stream(text, voice)


with gr.Blocks(title="VibeVoice-Realtime Demo") as demo:
    gr.Markdown("# Microsoft VibeVoice-Realtime-0.5B")
    with gr.Row():
        text_input = gr.Textbox(
            label="Input Text",
            value="Hello world! This is VibeVoice speaking realtime.",
        )
        voice_dropdown = gr.Dropdown(
            choices=list(service.voice_presets.keys()),
            value=service.default_voice_key,
            label="Voice Preset",
        )
    audio_output = gr.Audio(label="Generated Audio", streaming=True, autoplay=True)
    btn = gr.Button("Generate", variant="primary")
    btn.click(tts_generate, inputs=[text_input, voice_dropdown], outputs=[audio_output])

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)