# VibeVoice-Realtime streaming text-to-speech demo (Hugging Face Space).
| import os | |
| import shutil | |
| import threading | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from pathlib import Path | |
| from typing import Dict, Optional, Tuple, Iterator, Any | |
| import copy | |
| from vibevoice.modular.modeling_vibevoice_streaming_inference import ( | |
| VibeVoiceStreamingForConditionalGenerationInference, | |
| ) | |
| from vibevoice.processor.vibevoice_streaming_processor import ( | |
| VibeVoiceStreamingProcessor, | |
| ) | |
| from vibevoice.modular.streamer import AudioStreamer | |
# Output audio sample rate in Hz; every yielded chunk is tagged with this rate.
SAMPLE_RATE = 24000
class StreamingTTSService:
    """Streaming text-to-speech service wrapping a VibeVoice checkpoint.

    The heavy work (processor/model weights, voice-preset discovery) happens
    once in `load`; `stream` then yields audio chunks incrementally so a UI
    can start playback while generation is still running.
    """

    def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5) -> None:
        """Record configuration only; call `load` to actually load the model.

        Args:
            model_path: HF hub id or local path of the VibeVoice checkpoint.
            device: "cuda" or "cpu"; silently falls back to CPU when CUDA
                is unavailable.
            inference_steps: number of DDPM denoising steps per generation.
        """
        self.model_path = model_path
        self.inference_steps = inference_steps
        self.sample_rate = SAMPLE_RATE
        self.processor: Optional[VibeVoiceStreamingProcessor] = None
        self.model: Optional[VibeVoiceStreamingForConditionalGenerationInference] = None
        self.voice_presets: Dict[str, Path] = {}
        self.default_voice_key: Optional[str] = None
        # Maps preset key -> the deserialized prefilled-prompt object loaded
        # from its .pt file. (Annotation fixed: the original declared
        # Dict[str, Tuple[object, Path, str]] but only the object is stored.)
        self._voice_cache: Dict[str, object] = {}
        if device == "cuda" and not torch.cuda.is_available():
            print("Warning: CUDA not available. Falling back to CPU.")
            device = "cpu"
        self.device = device
        self._torch_device = torch.device(device)

    def load(self) -> None:
        """Load processor and model weights, then discover voice presets.

        Prefers flash-attention-2 on CUDA and falls back to SDPA when that
        kernel is unavailable; on CPU uses float32 + SDPA directly.
        """
        print(f"[startup] Loading processor from {self.model_path}")
        self.processor = VibeVoiceStreamingProcessor.from_pretrained(self.model_path)
        if self.device == "cuda":
            load_dtype = torch.bfloat16
            device_map = 'cuda'
            attn_impl_primary = "flash_attention_2"
        else:
            load_dtype = torch.float32
            device_map = 'cpu'
            attn_impl_primary = "sdpa"
        print(f"Using device: {device_map}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
        try:
            self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                self.model_path,
                torch_dtype=load_dtype,
                device_map=device_map,
                attn_implementation=attn_impl_primary,
            )
        except Exception as e:
            print(f"Error loading model with {attn_impl_primary}: {e}")
            if attn_impl_primary == 'flash_attention_2':
                # flash-attn may simply not be installed; retry with the
                # portable SDPA implementation.
                print("Falling back to SDPA...")
                self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map=self.device,
                    attn_implementation='sdpa',
                )
            else:
                # Bare raise preserves the original traceback (was `raise e`).
                raise
        self.model.eval()
        # Re-create the noise scheduler with the SDE-DPMSolver++ config used
        # for fast few-step inference.
        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
            self.model.model.noise_scheduler.config,
            algorithm_type="sde-dpmsolver++",
            beta_schedule="squaredcos_cap_v2",
        )
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
        self.voice_presets = self._load_voice_presets()
        self.default_voice_key = self._determine_voice_key(None)

    def _load_voice_presets(self) -> Dict[str, Path]:
        """Scan ./demo/voices/streaming_model for *.pt presets.

        Returns a name -> path mapping sorted by preset name.

        Raises:
            RuntimeError: when neither the voices dir nor a top-level "demo"
                dir exists, or when no .pt files are found.
        """
        voices_dir = Path("./demo/voices/streaming_model")
        # Tolerate a missing voices dir as long as "demo" exists — rglob on
        # a nonexistent directory just yields nothing, so the empty-preset
        # error below still fires. (Flattened from an if/pass/else chain.)
        if not voices_dir.exists() and not Path("demo").exists():
            raise RuntimeError(f"Cannot find voices dir at {voices_dir}")
        presets: Dict[str, Path] = {pt_path.stem: pt_path for pt_path in voices_dir.rglob("*.pt")}
        if not presets:
            raise RuntimeError(f"No voice preset (.pt) files found in {voices_dir}")
        print(f"[startup] Found {len(presets)} voice presets")
        return dict(sorted(presets.items()))

    def _determine_voice_key(self, name: Optional[str]) -> str:
        """Resolve a usable preset key: requested name, preferred default, else first."""
        if name and name in self.voice_presets:
            return name
        for candidate in ["en-WHTest_man"]:
            if candidate in self.voice_presets:
                return candidate
        return next(iter(self.voice_presets))

    def _ensure_voice_cached(self, key: str) -> object:
        """Return the prefilled-prompt object for `key`, loading it on first use.

        Unknown keys silently fall back to the default preset.
        """
        if key not in self.voice_presets:
            key = self.default_voice_key
        if key not in self._voice_cache:
            preset_path = self.voice_presets[key]
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load presets shipped with this repo, never untrusted files.
            prefilled_outputs = torch.load(preset_path, map_location=self._torch_device, weights_only=False)
            self._voice_cache[key] = prefilled_outputs
        return self._voice_cache[key]

    def _prepare_inputs(self, text: str, prefilled_outputs: object):
        """Process `text` against the cached voice prompt and move tensors to the device."""
        processed = self.processor.process_input_with_cached_prompt(
            text=text.strip(),
            cached_prompt=prefilled_outputs,
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        return {
            key: value.to(self._torch_device) if hasattr(value, "to") else value
            for key, value in processed.items()
        }

    def _run_generation(self, inputs, audio_streamer, errors, stop_event, prefilled_outputs):
        """Worker-thread entry point: run generate(), pushing audio into `audio_streamer`.

        Any exception is appended to `errors` for re-raising on the consumer
        thread. The streamer is closed in `finally` (was only closed in the
        except path) so the consuming generator can never hang.
        """
        try:
            self.model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=1.5,
                tokenizer=self.processor.tokenizer,
                generation_config={"do_sample": True, "temperature": 1.0, "top_p": 1.0},
                audio_streamer=audio_streamer,
                stop_check_fn=stop_event.is_set,
                verbose=False,
                refresh_negative=True,
                # Deep-copy so generate() cannot mutate the shared cached prompt.
                all_prefilled_outputs=copy.deepcopy(prefilled_outputs),
            )
        except Exception as e:
            errors.append(e)
            print(f"Generation error: {e}")
        finally:
            audio_streamer.end()

    def stream(self, text: str, voice_key: str) -> Iterator[Tuple[int, np.ndarray]]:
        """Yield (sample_rate, mono float32 ndarray) chunks for `text`.

        Generation runs on a daemon thread while this generator consumes
        chunks as they arrive. On exit (normal completion or consumer
        cancellation) the worker is signalled to stop and joined, and the
        first worker error, if any, is re-raised here.
        """
        if not text.strip():
            return
        prefilled_outputs = self._ensure_voice_cached(voice_key)
        audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None)
        stop_event = threading.Event()
        errors = []
        inputs = self._prepare_inputs(text, prefilled_outputs)
        worker = threading.Thread(
            target=self._run_generation,
            kwargs={
                "inputs": inputs,
                "audio_streamer": audio_streamer,
                "errors": errors,
                "stop_event": stop_event,
                "prefilled_outputs": prefilled_outputs,
            },
            daemon=True,
        )
        worker.start()
        try:
            for audio_chunk in audio_streamer.get_stream(0):
                if torch.is_tensor(audio_chunk):
                    audio_chunk = audio_chunk.detach().cpu().float().numpy()
                else:
                    audio_chunk = np.asarray(audio_chunk, dtype=np.float32)
                if audio_chunk.ndim > 1:
                    # Flatten e.g. [1, T] to 1-D mono for streaming playback.
                    audio_chunk = audio_chunk.reshape(-1)
                yield (SAMPLE_RATE, audio_chunk)
        finally:
            stop_event.set()
            audio_streamer.end()
            worker.join()
            if errors:
                raise errors[0]
# Checkpoint to serve. The service is built and loaded eagerly at import
# time so the first request does not pay the model-load cost.
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
service = StreamingTTSService(MODEL_ID)
service.load()


def tts_generate(text, voice):
    """Gradio callback: relay (sample_rate, chunk) tuples from the service."""
    for chunk in service.stream(text, voice):
        yield chunk
# ----- Gradio UI wiring -------------------------------------------------
with gr.Blocks(title="VibeVoice-Realtime Demo") as demo:
    gr.Markdown("# Microsoft VibeVoice-Realtime-0.5B")
    with gr.Row():
        text_input = gr.Textbox(
            label="Input Text",
            value="Hello world! This is VibeVoice speaking realtime.",
        )
        voice_dropdown = gr.Dropdown(
            choices=list(service.voice_presets.keys()),
            value=service.default_voice_key,
            label="Voice Preset",
        )
    # streaming=True lets the Audio component play chunks as they arrive.
    audio_output = gr.Audio(label="Generated Audio", streaming=True, autoplay=True)
    btn = gr.Button("Generate", variant="primary")
    btn.click(tts_generate, inputs=[text_input, voice_dropdown], outputs=[audio_output])

if __name__ == "__main__":
    # queue() is required for generator (streaming) callbacks.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)