TGPro1 committed on
Commit
27564df
·
verified ·
1 Parent(s): a3d8979

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import threading
4
+ import torch
5
+ import numpy as np
6
+ import gradio as gr
7
+ from pathlib import Path
8
+ from typing import Dict, Optional, Tuple, Iterator, Any
9
+ import copy
10
+
11
+ from vibevoice.modular.modeling_vibevoice_streaming_inference import (
12
+ VibeVoiceStreamingForConditionalGenerationInference,
13
+ )
14
+ from vibevoice.processor.vibevoice_streaming_processor import (
15
+ VibeVoiceStreamingProcessor,
16
+ )
17
+ from vibevoice.modular.streamer import AudioStreamer
18
+
19
+ SAMPLE_RATE = 24_000
20
+
21
+ class StreamingTTSService:
22
+ def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5) -> None:
23
+ self.model_path = model_path
24
+ self.inference_steps = inference_steps
25
+ self.sample_rate = SAMPLE_RATE
26
+ self.processor: Optional[VibeVoiceStreamingProcessor] = None
27
+ self.model: Optional[VibeVoiceStreamingForConditionalGenerationInference] = None
28
+ self.voice_presets: Dict[str, Path] = {}
29
+ self.default_voice_key: Optional[str] = None
30
+ self._voice_cache: Dict[str, Tuple[object, Path, str]] = {}
31
+ if device == "cuda" and not torch.cuda.is_available():
32
+ print("Warning: CUDA not available. Falling back to CPU.")
33
+ device = "cpu"
34
+ self.device = device
35
+ self._torch_device = torch.device(device)
36
+
37
+ def load(self) -> None:
38
+ print(f"[startup] Loading processor from {self.model_path}")
39
+ self.processor = VibeVoiceStreamingProcessor.from_pretrained(self.model_path)
40
+ if self.device == "cuda":
41
+ load_dtype = torch.bfloat16
42
+ device_map = 'cuda'
43
+ attn_impl_primary = "flash_attention_2"
44
+ else:
45
+ load_dtype = torch.float32
46
+ device_map = 'cpu'
47
+ attn_impl_primary = "sdpa"
48
+ print(f"Using device: {device_map}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
49
+ try:
50
+ self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
51
+ self.model_path, torch_dtype=load_dtype, device_map=device_map, attn_implementation=attn_impl_primary,
52
+ )
53
+ except Exception as e:
54
+ print(f"Error loading model with {attn_impl_primary}: {e}")
55
+ if attn_impl_primary == 'flash_attention_2':
56
+ print("Falling back to SDPA...")
57
+ self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
58
+ self.model_path, torch_dtype=load_dtype, device_map=self.device, attn_implementation='sdpa',
59
+ )
60
+ else:
61
+ raise e
62
+ self.model.eval()
63
+ self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
64
+ self.model.model.noise_scheduler.config, algorithm_type="sde-dpmsolver++", beta_schedule="squaredcos_cap_v2",
65
+ )
66
+ self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
67
+ self.voice_presets = self._load_voice_presets()
68
+ self.default_voice_key = self._determine_voice_key(None)
69
+
70
+ def _load_voice_presets(self) -> Dict[str, Path]:
71
+ voices_dir = Path("./demo/voices/streaming_model")
72
+ if not voices_dir.exists():
73
+ if Path("demo").exists(): pass
74
+ else: raise RuntimeError(f"Cannot find voices dir at {voices_dir}")
75
+ presets: Dict[str, Path] = {}
76
+ for pt_path in voices_dir.rglob("*.pt"):
77
+ presets[pt_path.stem] = pt_path
78
+ if not presets: raise RuntimeError(f"No voice preset (.pt) files found in {voices_dir}")
79
+ print(f"[startup] Found {len(presets)} voice presets")
80
+ return dict(sorted(presets.items()))
81
+
82
+ def _determine_voice_key(self, name: Optional[str]) -> str:
83
+ if name and name in self.voice_presets: return name
84
+ candidates = ["en-WHTest_man"]
85
+ for c in candidates:
86
+ if c in self.voice_presets: return c
87
+ return next(iter(self.voice_presets))
88
+
89
+ def _ensure_voice_cached(self, key: str) -> object:
90
+ if key not in self.voice_presets: key = self.default_voice_key
91
+ if key not in self._voice_cache:
92
+ preset_path = self.voice_presets[key]
93
+ prefilled_outputs = torch.load(preset_path, map_location=self._torch_device, weights_only=False)
94
+ self._voice_cache[key] = prefilled_outputs
95
+ return self._voice_cache[key]
96
+
97
+ def _prepare_inputs(self, text: str, prefilled_outputs: object):
98
+ processor_kwargs = {"text": text.strip(), "cached_prompt": prefilled_outputs, "padding": True, "return_tensors": "pt", "return_attention_mask": True}
99
+ processed = self.processor.process_input_with_cached_prompt(**processor_kwargs)
100
+ prepared = {key: value.to(self._torch_device) if hasattr(value, "to") else value for key, value in processed.items()}
101
+ return prepared
102
+
103
+ def _run_generation(self, inputs, audio_streamer, errors, stop_event, prefilled_outputs):
104
+ try:
105
+ self.model.generate(**inputs, max_new_tokens=None, cfg_scale=1.5, tokenizer=self.processor.tokenizer, generation_config={"do_sample": True, "temperature": 1.0, "top_p": 1.0}, audio_streamer=audio_streamer, stop_check_fn=stop_event.is_set, verbose=False, refresh_negative=True, all_prefilled_outputs=copy.deepcopy(prefilled_outputs))
106
+ except Exception as e:
107
+ errors.append(e)
108
+ print(f"Generation error: {e}")
109
+ audio_streamer.end()
110
+
111
+ def stream(self, text: str, voice_key: str) -> Iterator[Tuple[int, np.ndarray]]:
112
+ if not text.strip(): return
113
+ prefilled_outputs = self._ensure_voice_cached(voice_key)
114
+ audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None)
115
+ stop_event = threading.Event()
116
+ errors = []
117
+ inputs = self._prepare_inputs(text, prefilled_outputs)
118
+ thread = threading.Thread(target=self._run_generation, kwargs={"inputs": inputs, "audio_streamer": audio_streamer, "errors": errors, "stop_event": stop_event, "prefilled_outputs": prefilled_outputs}, daemon=True)
119
+ thread.start()
120
+ try:
121
+ stream = audio_streamer.get_stream(0)
122
+ for audio_chunk in stream:
123
+ if torch.is_tensor(audio_chunk): audio_chunk = audio_chunk.detach().cpu().float().numpy()
124
+ else: audio_chunk = np.asarray(audio_chunk, dtype=np.float32)
125
+ if audio_chunk.ndim > 1: audio_chunk = audio_chunk.reshape(-1)
126
+ yield (SAMPLE_RATE, audio_chunk)
127
+ finally:
128
+ stop_event.set()
129
+ audio_streamer.end()
130
+ thread.join()
131
+ if errors: raise errors[0]
132
+
133
+ MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
134
+ service = StreamingTTSService(MODEL_ID)
135
+ service.load()
136
+
137
+ def tts_generate(text, voice):
138
+ yield from service.stream(text, voice)
139
+
140
+ with gr.Blocks(title="VibeVoice-Realtime Demo") as demo:
141
+ gr.Markdown("# Microsoft VibeVoice-Realtime-0.5B")
142
+ with gr.Row():
143
+ text_input = gr.Textbox(label="Input Text", value="Hello world! This is VibeVoice speaking realtime.")
144
+ voice_dropdown = gr.Dropdown(choices=list(service.voice_presets.keys()), value=service.default_voice_key, label="Voice Preset")
145
+ audio_output = gr.Audio(label="Generated Audio", streaming=True, autoplay=True)
146
+ btn = gr.Button("Generate", variant="primary")
147
+ btn.click(tts_generate, inputs=[text_input, voice_dropdown], outputs=[audio_output])
148
+
149
+ if __name__ == "__main__":
150
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)