# app.py — Gradio demo for Microsoft VibeVoice-Realtime-0.5B streaming text-to-speech.
import os
import shutil
import threading
import torch
import numpy as np
import gradio as gr
from pathlib import Path
from typing import Dict, Optional, Tuple, Iterator, Any
import copy
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
VibeVoiceStreamingForConditionalGenerationInference,
)
from vibevoice.processor.vibevoice_streaming_processor import (
VibeVoiceStreamingProcessor,
)
from vibevoice.modular.streamer import AudioStreamer
SAMPLE_RATE = 24_000  # sample rate (Hz) of the audio chunks yielded to Gradio
class StreamingTTSService:
    """Streaming text-to-speech service around a VibeVoice realtime model.

    Loads the processor/model once (``load``), discovers and caches voice
    presets, and exposes ``stream``, which yields ``(sample_rate, chunk)``
    tuples suitable for a streaming Gradio ``Audio`` output.
    """

    def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5) -> None:
        """Record configuration; heavy loading is deferred to ``load()``.

        Args:
            model_path: HF hub id or local path of the VibeVoice checkpoint.
            device: "cuda" or "cpu"; silently falls back to CPU when CUDA
                is unavailable.
            inference_steps: number of DDPM inference steps for the decoder.
        """
        self.model_path = model_path
        self.inference_steps = inference_steps
        self.sample_rate = SAMPLE_RATE
        # Populated by load(); None until then.
        self.processor: Optional[VibeVoiceStreamingProcessor] = None
        self.model: Optional[VibeVoiceStreamingForConditionalGenerationInference] = None
        self.voice_presets: Dict[str, Path] = {}
        self.default_voice_key: Optional[str] = None
        # voice key -> prefilled prompt object loaded from the preset .pt file.
        # (Annotation fixed: values are whatever torch.load returns, not a tuple.)
        self._voice_cache: Dict[str, Any] = {}
        if device == "cuda" and not torch.cuda.is_available():
            print("Warning: CUDA not available. Falling back to CPU.")
            device = "cpu"
        self.device = device
        self._torch_device = torch.device(device)

    def load(self) -> None:
        """Load processor and model, configure the diffusion scheduler, and
        discover voice presets. Must be called before ``stream``."""
        print(f"[startup] Loading processor from {self.model_path}")
        self.processor = VibeVoiceStreamingProcessor.from_pretrained(self.model_path)
        if self.device == "cuda":
            load_dtype = torch.bfloat16
            device_map = 'cuda'
            attn_impl_primary = "flash_attention_2"
        else:
            load_dtype = torch.float32
            device_map = 'cpu'
            attn_impl_primary = "sdpa"
        print(f"Using device: {device_map}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
        try:
            self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                self.model_path,
                torch_dtype=load_dtype,
                device_map=device_map,
                attn_implementation=attn_impl_primary,
            )
        except Exception as e:
            print(f"Error loading model with {attn_impl_primary}: {e}")
            if attn_impl_primary == 'flash_attention_2':
                # flash-attn is often missing or unbuildable; retry with SDPA.
                print("Falling back to SDPA...")
                self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map=self.device,
                    attn_implementation='sdpa',
                )
            else:
                raise  # bare raise preserves the original traceback
        self.model.eval()
        # Swap in an SDE DPM-Solver++ scheduler configuration for few-step sampling.
        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
            self.model.model.noise_scheduler.config,
            algorithm_type="sde-dpmsolver++",
            beta_schedule="squaredcos_cap_v2",
        )
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
        self.voice_presets = self._load_voice_presets()
        self.default_voice_key = self._determine_voice_key(None)

    def _load_voice_presets(self) -> Dict[str, Path]:
        """Discover voice preset ``.pt`` files under ./demo/voices/streaming_model.

        Returns:
            Mapping of preset stem name -> path, sorted by name.

        Raises:
            RuntimeError: if neither the voices dir nor ./demo exists, or if
                no ``.pt`` files are found.
        """
        voices_dir = Path("./demo/voices/streaming_model")
        # Tolerate a missing voices subdir as long as ./demo exists; the
        # empty-presets check below still produces a clear error in that case.
        if not voices_dir.exists() and not Path("demo").exists():
            raise RuntimeError(f"Cannot find voices dir at {voices_dir}")
        presets = {pt_path.stem: pt_path for pt_path in voices_dir.rglob("*.pt")}
        if not presets:
            raise RuntimeError(f"No voice preset (.pt) files found in {voices_dir}")
        print(f"[startup] Found {len(presets)} voice presets")
        return dict(sorted(presets.items()))

    def _determine_voice_key(self, name: Optional[str]) -> str:
        """Resolve a requested voice name to an available preset key.

        Preference order: exact match, then the known-good default
        "en-WHTest_man", then the first preset key (alphabetical, since
        ``_load_voice_presets`` returns a sorted dict).
        """
        if name and name in self.voice_presets:
            return name
        for candidate in ["en-WHTest_man"]:
            if candidate in self.voice_presets:
                return candidate
        return next(iter(self.voice_presets))

    def _ensure_voice_cached(self, key: str) -> object:
        """Return the prefilled prompt for ``key``, loading and caching on miss.

        Unknown keys fall back to the default voice.
        """
        if key not in self.voice_presets:
            key = self.default_voice_key
        if key not in self._voice_cache:
            preset_path = self.voice_presets[key]
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # acceptable only because presets ship with this repo (trusted files).
            self._voice_cache[key] = torch.load(
                preset_path, map_location=self._torch_device, weights_only=False
            )
        return self._voice_cache[key]

    def _prepare_inputs(self, text: str, prefilled_outputs: object):
        """Process text against the cached voice prompt and move tensors to device."""
        processed = self.processor.process_input_with_cached_prompt(
            text=text.strip(),
            cached_prompt=prefilled_outputs,
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        # Move only tensor-like values (things with .to); leave the rest as-is.
        return {
            key: value.to(self._torch_device) if hasattr(value, "to") else value
            for key, value in processed.items()
        }

    def _run_generation(self, inputs, audio_streamer, errors, stop_event, prefilled_outputs):
        """Worker-thread entry point: run model.generate, recording any failure.

        Appends exceptions to ``errors`` (read by ``stream``) instead of
        raising, and always closes the streamer so the consumer unblocks.
        """
        try:
            self.model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=1.5,
                tokenizer=self.processor.tokenizer,
                generation_config={"do_sample": True, "temperature": 1.0, "top_p": 1.0},
                audio_streamer=audio_streamer,
                stop_check_fn=stop_event.is_set,
                verbose=False,
                refresh_negative=True,
                # Deep-copy so generate() cannot mutate the shared voice cache.
                all_prefilled_outputs=copy.deepcopy(prefilled_outputs),
            )
        except Exception as e:
            errors.append(e)
            print(f"Generation error: {e}")
        audio_streamer.end()

    def stream(self, text: str, voice_key: str) -> Iterator[Tuple[int, np.ndarray]]:
        """Generate speech for ``text``, yielding ``(SAMPLE_RATE, float32 chunk)``.

        Generation runs on a background thread while this generator drains
        the AudioStreamer so chunks can be played as they are produced.
        Yields nothing for blank input; re-raises the first generation error.
        """
        if not text.strip():
            return
        prefilled_outputs = self._ensure_voice_cached(voice_key)
        audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None)
        stop_event = threading.Event()
        errors = []
        inputs = self._prepare_inputs(text, prefilled_outputs)
        thread = threading.Thread(
            target=self._run_generation,
            kwargs={
                "inputs": inputs,
                "audio_streamer": audio_streamer,
                "errors": errors,
                "stop_event": stop_event,
                "prefilled_outputs": prefilled_outputs,
            },
            daemon=True,
        )
        thread.start()
        try:
            for audio_chunk in audio_streamer.get_stream(0):
                if torch.is_tensor(audio_chunk):
                    audio_chunk = audio_chunk.detach().cpu().float().numpy()
                else:
                    audio_chunk = np.asarray(audio_chunk, dtype=np.float32)
                if audio_chunk.ndim > 1:
                    # Flatten e.g. (1, n) arrays to mono 1-D for Gradio.
                    audio_chunk = audio_chunk.reshape(-1)
                yield (SAMPLE_RATE, audio_chunk)
        finally:
            # Runs on normal exhaustion and on early consumer disconnect:
            # tell the worker to stop, close the streamer, then surface errors.
            stop_event.set()
            audio_streamer.end()
            thread.join()
            if errors:
                raise errors[0]
# Hugging Face model id for the realtime 0.5B VibeVoice checkpoint.
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
# Instantiate and load the model once at import time so the app is ready to serve.
service = StreamingTTSService(MODEL_ID)
service.load()
def tts_generate(text, voice):
    """Gradio callback: relay (sample_rate, chunk) tuples from the TTS service."""
    for rate_and_chunk in service.stream(text, voice):
        yield rate_and_chunk
# Build the Gradio UI. Component creation order inside the Blocks context
# defines the on-page layout, so these statements must stay in this order.
with gr.Blocks(title="VibeVoice-Realtime Demo") as demo:
    gr.Markdown("# Microsoft VibeVoice-Realtime-0.5B")
    with gr.Row():
        text_input = gr.Textbox(label="Input Text", value="Hello world! This is VibeVoice speaking realtime.")
        voice_dropdown = gr.Dropdown(choices=list(service.voice_presets.keys()), value=service.default_voice_key, label="Voice Preset")
    # streaming=True lets the Audio component play chunks as the generator yields them.
    audio_output = gr.Audio(label="Generated Audio", streaming=True, autoplay=True)
    btn = gr.Button("Generate", variant="primary")
    btn.click(tts_generate, inputs=[text_input, voice_dropdown], outputs=[audio_output])
if __name__ == "__main__":
    # queue() enables generator (streaming) outputs; 0.0.0.0 exposes the server externally.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)