# VibeVoice-Realtime streaming text-to-speech Gradio demo.
import os
import shutil
import threading
import torch
import numpy as np
import gradio as gr
from pathlib import Path
from typing import Dict, Optional, Tuple, Iterator, Any
import copy
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
VibeVoiceStreamingForConditionalGenerationInference,
)
from vibevoice.processor.vibevoice_streaming_processor import (
VibeVoiceStreamingProcessor,
)
from vibevoice.modular.streamer import AudioStreamer
SAMPLE_RATE = 24_000
class StreamingTTSService:
    """Wraps a VibeVoice streaming TTS model behind a chunked-audio generator API.

    Lifecycle: construct, call :meth:`load` once (loads processor, model and
    voice presets — heavy, GPU-touching work), then call :meth:`stream` per
    request to get an iterator of ``(sample_rate, np.ndarray)`` audio chunks.
    """

    def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5) -> None:
        """Record configuration only; no model loading happens here.

        Args:
            model_path: Hugging Face repo id or local path of the checkpoint.
            device: ``"cuda"`` or ``"cpu"``; silently falls back to CPU when
                CUDA is not available.
            inference_steps: number of DDPM denoising steps per generated chunk.
        """
        self.model_path = model_path
        self.inference_steps = inference_steps
        self.sample_rate = SAMPLE_RATE
        # Populated by load(). Annotations are quoted forward references so
        # they are not evaluated at runtime.
        self.processor: Optional["VibeVoiceStreamingProcessor"] = None
        self.model: Optional["VibeVoiceStreamingForConditionalGenerationInference"] = None
        self.voice_presets: Dict[str, Path] = {}
        self.default_voice_key: Optional[str] = None
        # Voice key -> prefilled prompt outputs loaded from the preset .pt
        # file (opaque object consumed by the model's generate()).
        self._voice_cache: Dict[str, object] = {}
        if device == "cuda" and not torch.cuda.is_available():
            print("Warning: CUDA not available. Falling back to CPU.")
            device = "cpu"
        self.device = device
        self._torch_device = torch.device(device)

    def load(self) -> None:
        """Load processor, model and voice presets. Call once before stream()."""
        print(f"[startup] Loading processor from {self.model_path}")
        self.processor = VibeVoiceStreamingProcessor.from_pretrained(self.model_path)
        if self.device == "cuda":
            load_dtype = torch.bfloat16
            device_map = 'cuda'
            attn_impl_primary = "flash_attention_2"
        else:
            load_dtype = torch.float32
            device_map = 'cpu'
            attn_impl_primary = "sdpa"
        print(f"Using device: {device_map}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
        try:
            self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                self.model_path,
                torch_dtype=load_dtype,
                device_map=device_map,
                attn_implementation=attn_impl_primary,
            )
        except Exception as e:
            print(f"Error loading model with {attn_impl_primary}: {e}")
            if attn_impl_primary != 'flash_attention_2':
                # Nothing to fall back to; re-raise with the original traceback.
                raise
            print("Falling back to SDPA...")
            self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                self.model_path,
                torch_dtype=load_dtype,
                device_map=self.device,
                attn_implementation='sdpa',
            )
        self.model.eval()
        # Rebuild the noise scheduler with the SDE-DPM-Solver++ configuration
        # used for streaming inference.
        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
            self.model.model.noise_scheduler.config,
            algorithm_type="sde-dpmsolver++",
            beta_schedule="squaredcos_cap_v2",
        )
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
        self.voice_presets = self._load_voice_presets()
        self.default_voice_key = self._determine_voice_key(None)

    def _load_voice_presets(self) -> Dict[str, Path]:
        """Discover voice preset ``.pt`` files, keyed by file stem, sorted by name.

        Raises:
            RuntimeError: if the voices directory is missing or contains no presets.
        """
        voices_dir = Path("./demo/voices/streaming_model")
        if not voices_dir.exists():
            raise RuntimeError(f"Cannot find voices dir at {voices_dir}")
        presets: Dict[str, Path] = {pt_path.stem: pt_path for pt_path in voices_dir.rglob("*.pt")}
        if not presets:
            raise RuntimeError(f"No voice preset (.pt) files found in {voices_dir}")
        print(f"[startup] Found {len(presets)} voice presets")
        return dict(sorted(presets.items()))

    def _determine_voice_key(self, name: Optional[str]) -> str:
        """Resolve a requested voice name to a known preset key.

        Preference order: exact match, then a preferred default candidate,
        then the first preset in sorted order.
        """
        if name and name in self.voice_presets:
            return name
        candidates = ["en-WHTest_man"]
        for candidate in candidates:
            if candidate in self.voice_presets:
                return candidate
        return next(iter(self.voice_presets))

    def _ensure_voice_cached(self, key: str) -> object:
        """Return the prefilled prompt outputs for *key*, loading them on first use."""
        if key not in self.voice_presets:
            key = self.default_voice_key
        if key not in self._voice_cache:
            preset_path = self.voice_presets[key]
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load preset files from trusted sources.
            prefilled_outputs = torch.load(
                preset_path, map_location=self._torch_device, weights_only=False
            )
            self._voice_cache[key] = prefilled_outputs
        return self._voice_cache[key]

    def _prepare_inputs(self, text: str, prefilled_outputs: object):
        """Tokenize *text* against the cached voice prompt and move tensors to the device."""
        processor_kwargs = {
            "text": text.strip(),
            "cached_prompt": prefilled_outputs,
            "padding": True,
            "return_tensors": "pt",
            "return_attention_mask": True,
        }
        processed = self.processor.process_input_with_cached_prompt(**processor_kwargs)
        # Move only tensor-like values; leave everything else untouched.
        return {
            key: value.to(self._torch_device) if hasattr(value, "to") else value
            for key, value in processed.items()
        }

    def _run_generation(self, inputs, audio_streamer, errors, stop_event, prefilled_outputs):
        """Worker-thread body: run model.generate, reporting failures via *errors*.

        The streamer is always closed so the consumer in stream() never blocks
        forever, even on unexpected exceptions.
        """
        try:
            self.model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=1.5,
                tokenizer=self.processor.tokenizer,
                generation_config={"do_sample": True, "temperature": 1.0, "top_p": 1.0},
                audio_streamer=audio_streamer,
                stop_check_fn=stop_event.is_set,
                verbose=False,
                refresh_negative=True,
                # Deep-copied so generation cannot mutate the cached prompt.
                all_prefilled_outputs=copy.deepcopy(prefilled_outputs),
            )
        except Exception as e:
            errors.append(e)
            print(f"Generation error: {e}")
        finally:
            audio_streamer.end()

    def stream(self, text: str, voice_key: str) -> Iterator[Tuple[int, np.ndarray]]:
        """Yield ``(sample_rate, mono float32 chunk)`` tuples for *text*.

        Generation runs on a background thread; this generator drains the
        audio streamer and re-raises the first generation error, if any.
        Blank input yields nothing.
        """
        if not text.strip():
            return
        prefilled_outputs = self._ensure_voice_cached(voice_key)
        audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None)
        stop_event = threading.Event()
        errors: list = []
        inputs = self._prepare_inputs(text, prefilled_outputs)
        thread = threading.Thread(
            target=self._run_generation,
            kwargs={
                "inputs": inputs,
                "audio_streamer": audio_streamer,
                "errors": errors,
                "stop_event": stop_event,
                "prefilled_outputs": prefilled_outputs,
            },
            daemon=True,
        )
        thread.start()
        try:
            stream = audio_streamer.get_stream(0)
            for audio_chunk in stream:
                if torch.is_tensor(audio_chunk):
                    audio_chunk = audio_chunk.detach().cpu().float().numpy()
                else:
                    audio_chunk = np.asarray(audio_chunk, dtype=np.float32)
                if audio_chunk.ndim > 1:
                    # Flatten (1, N) chunks to mono 1-D for Gradio.
                    audio_chunk = audio_chunk.reshape(-1)
                yield (SAMPLE_RATE, audio_chunk)
        finally:
            # Signal the worker to stop, unblock the streamer, and surface errors.
            stop_event.set()
            audio_streamer.end()
            thread.join()
            if errors:
                raise errors[0]
# Hugging Face model id of the realtime 0.5B VibeVoice checkpoint.
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
# Eagerly load the model at import time so the Gradio app can serve
# requests as soon as the UI comes up.
service = StreamingTTSService(MODEL_ID)
service.load()
def tts_generate(text, voice):
    """Gradio callback: relay streamed (sample_rate, chunk) tuples from the service."""
    for chunk in service.stream(text, voice):
        yield chunk
# Build the Gradio UI. Component creation order defines the layout,
# so these statements must stay in sequence.
with gr.Blocks(title="VibeVoice-Realtime Demo") as demo:
    gr.Markdown("# Microsoft VibeVoice-Realtime-0.5B")
    with gr.Row():
        text_input = gr.Textbox(label="Input Text", value="Hello world! This is VibeVoice speaking realtime.")
        voice_dropdown = gr.Dropdown(choices=list(service.voice_presets.keys()), value=service.default_voice_key, label="Voice Preset")
    # streaming=True lets the generator callback push audio chunks as they arrive.
    audio_output = gr.Audio(label="Generated Audio", streaming=True, autoplay=True)
    btn = gr.Button("Generate", variant="primary")
    btn.click(tts_generate, inputs=[text_input, voice_dropdown], outputs=[audio_output])
if __name__ == "__main__":
    # queue() is required for streaming generator outputs; bind to all
    # interfaces on the conventional Gradio port.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)