|
|
|
|
|
""" |
|
|
This is an extra gRPC server of LocalAI for VibeVoice |
|
|
""" |
|
|
from concurrent import futures |
|
|
import time |
|
|
import argparse |
|
|
import signal |
|
|
import sys |
|
|
import os |
|
|
import copy |
|
|
import traceback |
|
|
from pathlib import Path |
|
|
import backend_pb2 |
|
|
import backend_pb2_grpc |
|
|
import torch |
|
|
from vibevoice.modular.modeling_vibevoice_streaming_inference import VibeVoiceStreamingForConditionalGenerationInference |
|
|
from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor |
|
|
|
|
|
import grpc |
|
|
|
|
|
def is_float(s):
    """Return True when *s* can be parsed as a float, False otherwise."""
    try:
        float(s)
    except ValueError:
        return False
    return True
|
|
def is_int(s):
    """Return True when *s* can be parsed as an int, False otherwise."""
    try:
        int(s)
    except ValueError:
        return False
    return True
|
|
|
|
|
# Sleep interval for the keep-alive loop in serve().
_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# gRPC thread-pool size; override via the PYTHON_GRPC_MAX_WORKERS env var.
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
|
|
|
|
|
|
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer implementing the LocalAI backend API for VibeVoice TTS."""

    def Health(self, request, context):
        """Liveness probe: always answer OK."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
|
|
|
|
def LoadModel(self, request, context):
    """
    Load the VibeVoice streaming model and discover voice presets.

    Selects the torch device (CUDA > MPS > CPU), parses ``key:value``
    entries from request.Options into self.options, resolves the voice
    preset directory, then loads the processor and the model.

    Returns:
        backend_pb2.Result with a success flag and a message.
    """
    # ---- device selection ------------------------------------------------
    if torch.cuda.is_available():
        print("CUDA is available", file=sys.stderr)
        device = "cuda"
    else:
        print("CUDA is not available", file=sys.stderr)
        device = "cpu"
    mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if mps_available:
        device = "mps"
    if not torch.cuda.is_available() and request.CUDA:
        return backend_pb2.Result(success=False, message="CUDA is not available")

    # Legacy alias some clients used for Apple silicon.
    if device == "mpx":
        print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr)
        device = "mps"

    if device == "mps" and not torch.backends.mps.is_available():
        print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr)
        device = "cpu"

    self.device = device
    self._torch_device = torch.device(device)

    # ---- option parsing --------------------------------------------------
    options = request.Options
    self.options = {}
    for opt in options:
        if ":" not in opt:
            continue
        key, value = opt.split(":", 1)
        # BUGFIX: try int before float. is_float() accepts integer strings
        # too, so the old float-first order turned options such as
        # "inference_steps:10" into 10.0, which the isinstance(int)
        # validation below then silently discarded.
        if is_int(value):
            value = int(value)
        elif is_float(value):
            value = float(value)
        elif value.lower() in ["true", "false"]:
            value = value.lower() == "true"
        self.options[key] = value

    model_path = request.Model
    if not model_path:
        model_path = "microsoft/VibeVoice-Realtime-0.5B"

    # Diffusion (DDPM) denoising steps; kept small for low latency.
    self.inference_steps = self.options.get("inference_steps", 5)
    if not isinstance(self.inference_steps, int) or self.inference_steps <= 0:
        self.inference_steps = 5

    # Classifier-free-guidance scale for generation.
    self.cfg_scale = self.options.get("cfg_scale", 1.5)
    if not isinstance(self.cfg_scale, (int, float)) or self.cfg_scale <= 0:
        self.cfg_scale = 1.5

    # ---- voice preset directory resolution -------------------------------
    voices_dir = None

    # 1) explicit "voices_dir" option; relative paths are resolved against
    #    the model location, then made absolute.
    if "voices_dir" in self.options:
        voices_dir_option = self.options["voices_dir"]
        if isinstance(voices_dir_option, str) and voices_dir_option.strip():
            voices_dir = voices_dir_option.strip()
            if not os.path.isabs(voices_dir):
                if hasattr(request, 'ModelPath') and request.ModelPath:
                    voices_dir = os.path.join(request.ModelPath, voices_dir)
                elif request.ModelFile:
                    model_file_base = os.path.dirname(request.ModelFile)
                    voices_dir = os.path.join(model_file_base, voices_dir)
            if not os.path.isabs(voices_dir):
                voices_dir = os.path.abspath(voices_dir)
            if not os.path.exists(voices_dir):
                print(f"Warning: voices_dir option specified but directory does not exist: {voices_dir}", file=sys.stderr)
                voices_dir = None

    # 2) voices/streaming_model next to the model file.
    if not voices_dir and request.ModelFile:
        model_file_base = os.path.dirname(request.ModelFile)
        voices_dir = os.path.join(model_file_base, "voices", "streaming_model")
        if not os.path.exists(voices_dir):
            voices_dir = None

    # 3) voices/streaming_model under the model path.
    if not voices_dir and hasattr(request, 'ModelPath') and request.ModelPath:
        voices_dir = os.path.join(request.ModelPath, "voices", "streaming_model")
        if not os.path.exists(voices_dir):
            voices_dir = None

    # 4) voices bundled with this backend, else the directory of AudioPath.
    if not voices_dir:
        backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        voices_dir = os.path.join(backend_dir, "vibevoice", "voices", "streaming_model")
        if not os.path.exists(voices_dir):
            if request.AudioPath and os.path.isabs(request.AudioPath):
                voices_dir = os.path.dirname(request.AudioPath)
            else:
                voices_dir = None

    self.voices_dir = voices_dir
    self.voice_presets = {}
    self._voice_cache = {}  # voice file path -> loaded prefilled prompt
    self.default_voice_key = None

    if self.voices_dir and os.path.exists(self.voices_dir):
        self._load_voice_presets()
    else:
        print("Warning: Voices directory not found. Voice presets will not be available.", file=sys.stderr)

    # ---- model loading ---------------------------------------------------
    try:
        print(f"Loading processor & model from {model_path}", file=sys.stderr)
        self.processor = VibeVoiceStreamingProcessor.from_pretrained(model_path)

        # Per-device dtype and attention implementation.
        if self.device == "mps":
            load_dtype = torch.float32
            attn_impl_primary = "sdpa"
        elif self.device == "cuda":
            load_dtype = torch.bfloat16
            attn_impl_primary = "flash_attention_2"
        else:
            load_dtype = torch.float32
            attn_impl_primary = "sdpa"

        print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}", file=sys.stderr)

        try:
            if self.device == "mps":
                # MPS: load without a device map, then move the weights over.
                self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    model_path,
                    torch_dtype=load_dtype,
                    attn_implementation=attn_impl_primary,
                    device_map=None,
                )
                self.model.to("mps")
            elif self.device == "cuda":
                self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    model_path,
                    torch_dtype=load_dtype,
                    device_map="cuda",
                    attn_implementation=attn_impl_primary,
                )
            else:
                self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    model_path,
                    torch_dtype=load_dtype,
                    device_map="cpu",
                    attn_implementation=attn_impl_primary,
                )
        except Exception as e:
            # flash_attention_2 may be unavailable; retry once with SDPA.
            if attn_impl_primary == 'flash_attention_2':
                print(f"[ERROR] : {type(e).__name__}: {e}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.", file=sys.stderr)
                self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    model_path,
                    torch_dtype=load_dtype,
                    device_map=(self.device if self.device in ("cuda", "cpu") else None),
                    attn_implementation='sdpa'
                )
                if self.device == "mps":
                    self.model.to("mps")
            else:
                raise

        self.model.eval()
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)

        # Pick the default voice; the VOICE_PRESET env var wins when valid.
        if self.voice_presets:
            preset_name = os.environ.get("VOICE_PRESET")
            self.default_voice_key = self._determine_voice_key(preset_name)
            print(f"Default voice preset: {self.default_voice_key}", file=sys.stderr)
        else:
            print("Warning: No voice presets available. Voice selection will not work.", file=sys.stderr)

    except Exception as err:
        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

    return backend_pb2.Result(message="Model loaded successfully", success=True)
|
|
|
|
|
def _load_voice_presets(self):
    """Scan the voices directory and index every ``*.pt`` preset by file stem."""
    if not self.voices_dir or not os.path.exists(self.voices_dir):
        self.voice_presets = {}
        return

    self.voice_presets = {}
    for entry in os.listdir(self.voices_dir):
        candidate = os.path.join(self.voices_dir, entry)
        if entry.lower().endswith('.pt') and os.path.isfile(candidate):
            self.voice_presets[os.path.splitext(entry)[0]] = candidate

    # Keep the mapping alphabetically ordered for stable listings.
    self.voice_presets = dict(sorted(self.voice_presets.items()))

    print(f"Found {len(self.voice_presets)} voice files in {self.voices_dir}", file=sys.stderr)
    if self.voice_presets:
        print(f"Available voices: {', '.join(self.voice_presets.keys())}", file=sys.stderr)
|
|
|
|
|
def _determine_voice_key(self, name):
    """Resolve a usable voice-preset key.

    Preference order: the requested *name*, the built-in default
    'en-WHTest_man', then the first preset found; None when no presets exist.
    """
    presets = self.voice_presets
    if name and name in presets:
        return name
    if "en-WHTest_man" in presets:
        return "en-WHTest_man"
    if not presets:
        return None
    fallback = next(iter(presets))
    print(f"Using fallback voice preset: {fallback}", file=sys.stderr)
    return fallback
|
|
|
|
|
def _get_voice_path(self, speaker_name):
    """Map *speaker_name* to a preset file path, with fuzzy and default fallbacks."""
    presets = self.voice_presets
    if not presets:
        return None

    # Exact preset name.
    if speaker_name and speaker_name in presets:
        return presets[speaker_name]

    # Case-insensitive substring match, in either direction.
    if speaker_name:
        wanted = speaker_name.lower()
        for preset_name, path in presets.items():
            lowered = preset_name.lower()
            if lowered in wanted or wanted in lowered:
                return path

    # Configured default voice.
    if self.default_voice_key and self.default_voice_key in presets:
        return presets[self.default_voice_key]

    # Last resort: first preset on disk.
    if presets:
        default_voice = next(iter(presets.values()))
        print(f"Warning: No voice preset found for '{speaker_name}', using default voice: {default_voice}", file=sys.stderr)
        return default_voice

    return None
|
|
|
|
|
def _ensure_voice_cached(self, voice_path):
    """Return the prefilled-prompt object for *voice_path*, loading it on a cache miss."""
    if not voice_path or not os.path.exists(voice_path):
        return None

    if voice_path not in self._voice_cache:
        print(f"Loading prefilled prompt from {voice_path}", file=sys.stderr)
        # NOTE: weights_only=False unpickles arbitrary objects; voice preset
        # files are assumed to be trusted local artifacts — do not point this
        # at untrusted data.
        loaded = torch.load(
            voice_path,
            map_location=self._torch_device,
            weights_only=False,
        )
        self._voice_cache[voice_path] = loaded

    return self._voice_cache[voice_path]
|
|
|
|
|
def TTS(self, request, context):
    """
    Synthesize speech for request.text and write the audio to request.dst.

    Voice resolution order: named preset (request.voice), explicit voice
    file (request.AudioPath, resolved against the model location when
    relative), then the default preset chosen at load time.

    Returns:
        backend_pb2.Result with a success flag (and a message on failure).
    """
    try:
        # ---- resolve the voice file --------------------------------------
        voice_path = None
        if request.voice:
            voice_path = self._get_voice_path(request.voice)
        elif request.AudioPath:
            if os.path.isabs(request.AudioPath):
                voice_path = request.AudioPath
            elif request.ModelFile:
                model_file_base = os.path.dirname(request.ModelFile)
                voice_path = os.path.join(model_file_base, request.AudioPath)
            elif hasattr(request, 'ModelPath') and request.ModelPath:
                voice_path = os.path.join(request.ModelPath, request.AudioPath)
            else:
                voice_path = request.AudioPath
        elif self.default_voice_key:
            voice_path = self._get_voice_path(self.default_voice_key)

        if not voice_path or not os.path.exists(voice_path):
            return backend_pb2.Result(
                success=False,
                message=f"Voice file not found: {voice_path}. Please provide a valid voice preset or AudioPath."
            )

        prefilled_outputs = self._ensure_voice_cached(voice_path)
        if prefilled_outputs is None:
            return backend_pb2.Result(
                success=False,
                message=f"Failed to load voice preset from {voice_path}"
            )

        # ---- generation knobs (options may override LoadModel defaults) --
        cfg_scale = self.options.get("cfg_scale", self.cfg_scale)
        inference_steps = self.options.get("inference_steps", self.inference_steps)
        do_sample = self.options.get("do_sample", False)
        temperature = self.options.get("temperature", 0.9)
        top_p = self.options.get("top_p", 0.9)

        # Reconfigure the scheduler only when the step count actually changed.
        if inference_steps != self.inference_steps:
            self.model.set_ddpm_inference_steps(num_steps=inference_steps)
            self.inference_steps = inference_steps

        # BUGFIX: normalize typographic quotes to ASCII. The previous code
        # replaced each character with itself (a no-op) — the curly-quote
        # source characters had been mangled into plain ASCII.
        text = (request.text.strip()
                .replace("\u2019", "'")    # ’ right single quote -> '
                .replace("\u201c", '"')    # “ left double quote  -> "
                .replace("\u201d", '"'))   # ” right double quote -> "

        inputs = self.processor.process_input_with_cached_prompt(
            text=text,
            cached_prompt=prefilled_outputs,
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )

        # Move every tensor onto the model's device.
        target_device = self._torch_device
        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to(target_device)

        print(f"Generating audio with cfg_scale: {cfg_scale}, inference_steps: {inference_steps}", file=sys.stderr)

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=None,
            cfg_scale=cfg_scale,
            tokenizer=self.processor.tokenizer,
            generation_config={
                'do_sample': do_sample,
                'temperature': temperature if do_sample else 1.0,
                'top_p': top_p if do_sample else 1.0,
            },
            verbose=False,
            # Deep copy so generation cannot mutate the cached prompt.
            all_prefilled_outputs=copy.deepcopy(prefilled_outputs) if prefilled_outputs is not None else None,
        )

        # ---- persist the result ------------------------------------------
        if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
            self.processor.save_audio(
                outputs.speech_outputs[0],
                output_path=request.dst,
            )
            print(f"Saved output to {request.dst}", file=sys.stderr)
        else:
            return backend_pb2.Result(
                success=False,
                message="No audio output generated"
            )

    except Exception as err:
        print(f"Error in TTS: {err}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

    return backend_pb2.Result(success=True)
|
|
|
|
|
def serve(address):
    """Start the gRPC backend on *address* and block until terminated."""
    max_msg = 50 * 1024 * 1024  # allow large request/response payloads
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', max_msg),
            ('grpc.max_send_message_length', max_msg),
            ('grpc.max_receive_message_length', max_msg),
        ],
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Shut down cleanly on SIGINT/SIGTERM.
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Keep the main thread alive; gRPC workers run in the background.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse the bind address, then run the server.
    arg_parser = argparse.ArgumentParser(description="Run the gRPC server.")
    arg_parser.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    cli_args = arg_parser.parse_args()
    serve(cli_args.addr)
|
|
|