from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
import traceback
import whisper
import librosa
import numpy as np
import torch
import uvicorn
import base64
import io
import re
import json
import asyncio
import tempfile
import os

try:
    import edge_tts
    TTS_AVAILABLE = True
except ImportError:
    TTS_AVAILABLE = False

try:
    from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
    from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
    import soundfile as sf
    VIBEVOICE_AVAILABLE = True
except ImportError:
    VIBEVOICE_AVAILABLE = False

# Load the Whisper ASR model and the instruction-tuned LLM once at import time.
asr_model = whisper.load_model("models/wpt/wpt.pt")
model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()

vibevoice_model = None
vibevoice_processor = None
vibevoice_voice_sample = None

if VIBEVOICE_AVAILABLE:
    try:
        vibevoice_model_path = os.getenv("VIBEVOICE_MODEL_PATH", "models/VibeVoice-1.5B")
        vibevoice_voice_path = os.getenv("VIBEVOICE_VOICE_PATH", None)
        vibevoice_tokenizer_path = os.getenv("VIBEVOICE_TOKENIZER_PATH", "models/Qwen2.5-1.5B")

        # Resolve relative paths so downstream loaders see absolute locations.
        if vibevoice_model_path and not os.path.isabs(vibevoice_model_path):
            vibevoice_model_path = os.path.abspath(vibevoice_model_path)
        if vibevoice_tokenizer_path and not os.path.isabs(vibevoice_tokenizer_path):
            vibevoice_tokenizer_path = os.path.abspath(vibevoice_tokenizer_path)
        if vibevoice_voice_path and not os.path.isabs(vibevoice_voice_path):
            vibevoice_voice_path = os.path.abspath(vibevoice_voice_path)

        # Only reached when VIBEVOICE_TOKENIZER_PATH is explicitly set to an
        # empty string: probe common local locations for a Qwen tokenizer.
        if not vibevoice_tokenizer_path:
            local_qwen_paths = [
                "models/Qwen2.5-1.5B",
                "models/Qwen/Qwen2.5-1.5B",
                os.path.join(vibevoice_model_path, "tokenizer"),
            ]
            for qwen_path in local_qwen_paths:
                if os.path.exists(qwen_path) and os.path.isdir(qwen_path):
                    tokenizer_files = ["tokenizer_config.json", "vocab.json", "merges.txt"]
                    if any(os.path.exists(os.path.join(qwen_path, f)) for f in tokenizer_files):
                        vibevoice_tokenizer_path = qwen_path
                        print(f"Found local Qwen tokenizer at {qwen_path}")
                        break

        print(f"Loading VibeVoice processor from {vibevoice_model_path}")

        # Temporarily point preprocessor_config.json at the local tokenizer so
        # the processor does not try to fetch a remote one; the original file
        # is restored in the finally block below.
        preprocessor_config_path = os.path.join(vibevoice_model_path, "preprocessor_config.json")
        config_modified = False
        original_config = None
        original_tokenizer_path = None

        if vibevoice_tokenizer_path and os.path.exists(preprocessor_config_path):
            try:
                with open(preprocessor_config_path, 'r') as f:
                    original_config = json.load(f)
                original_tokenizer_path = original_config.get("language_model_pretrained_name", "")
                if original_tokenizer_path != vibevoice_tokenizer_path:
                    original_config["language_model_pretrained_name"] = vibevoice_tokenizer_path
                    with open(preprocessor_config_path, 'w') as f:
                        json.dump(original_config, f, indent=2)
                    config_modified = True
                    print(f"Updated preprocessor_config.json to use local tokenizer: {vibevoice_tokenizer_path}")
            except Exception as config_error:
                print(f"Warning: Could not modify preprocessor_config.json: {config_error}")

        processor_kwargs = {}
        if vibevoice_tokenizer_path:
            processor_kwargs["language_model_pretrained_name"] = vibevoice_tokenizer_path
            print(f"Using tokenizer from: {vibevoice_tokenizer_path}")

        try:
            vibevoice_processor = VibeVoiceProcessor.from_pretrained(vibevoice_model_path, **processor_kwargs)
        finally:
            # Restore the original config even if the processor failed to load.
            if config_modified and original_config is not None and original_tokenizer_path is not None:
                try:
                    original_config["language_model_pretrained_name"] = original_tokenizer_path
                    with open(preprocessor_config_path, 'w') as f:
                        json.dump(original_config, f, indent=2)
                except Exception:
                    pass

        print(f"Loading VibeVoice model from {vibevoice_model_path}")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        load_dtype = torch.bfloat16 if device == "cuda" else torch.float32
        attn_impl = "flash_attention_2" if device == "cuda" else "sdpa"

        try:
            vibevoice_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                vibevoice_model_path,
                torch_dtype=load_dtype,
                device_map=device if device == "cuda" else None,
                attn_implementation=attn_impl,
            )
            if device != "cuda":
                vibevoice_model.to(device)
        except Exception as e:
            if attn_impl == "flash_attention_2":
                print(f"Failed to load with flash_attention_2, falling back to sdpa: {e}")
                vibevoice_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    vibevoice_model_path,
                    torch_dtype=load_dtype,
                    device_map=device if device in ("cuda", "cpu") else None,
                    attn_implementation="sdpa",
                )
                if device not in ("cuda", "cpu"):
                    vibevoice_model.to(device)
            else:
                raise

        vibevoice_model.eval()
        vibevoice_model.set_ddpm_inference_steps(num_steps=10)

        # Load a reference voice sample (mono, 24 kHz) for voice cloning.
        if vibevoice_voice_path and os.path.exists(vibevoice_voice_path) and os.path.isfile(vibevoice_voice_path):
            print(f"Loading voice sample from {vibevoice_voice_path}")
            try:
                wav, sr = sf.read(vibevoice_voice_path)
                if wav.ndim > 1:
                    wav = np.mean(wav, axis=1)
                if sr != 24000:
                    wav = librosa.resample(wav, orig_sr=sr, target_sr=24000)
                vibevoice_voice_sample = wav.astype(np.float32)
            except Exception as voice_error:
                print(f"Warning: Could not load voice sample from {vibevoice_voice_path}: {voice_error}")
                vibevoice_voice_sample = None
        else:
            # No sample configured; probe a few default locations.
            default_voice_paths = [
                "/app/spk_001.wav",
                "spk_001.wav",
                "/home/user/VibeVoice/demo/voices/en-Alice_woman.wav",
                "demo/voices/en-Alice_woman.wav",
                "VibeVoice/demo/voices/en-Alice_woman.wav",
            ]
            for voice_path in default_voice_paths:
                if os.path.exists(voice_path):
                    print(f"Loading default voice sample from {voice_path}")
                    wav, sr = sf.read(voice_path)
                    if wav.ndim > 1:
                        wav = np.mean(wav, axis=1)
                    if sr != 24000:
                        wav = librosa.resample(wav, orig_sr=sr, target_sr=24000)
                    vibevoice_voice_sample = wav.astype(np.float32)
                    break

        if vibevoice_voice_sample is None:
            print("Warning: No voice sample found. VibeVoice will work without voice cloning.")

        print("VibeVoice initialized successfully")
    except Exception as e:
        print(f"Failed to initialize VibeVoice: {e}")
        traceback.print_exc()
        VIBEVOICE_AVAILABLE = False
        vibevoice_model = None
        vibevoice_processor = None
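
# The VibeVoice setup above is driven by environment variables. The values
# below are the defaults and example paths already used in this file, shown
# here only for reference:
#   VIBEVOICE_MODEL_PATH      (default: models/VibeVoice-1.5B)
#   VIBEVOICE_VOICE_PATH      (optional reference wav, e.g. /app/spk_001.wav)
#   VIBEVOICE_TOKENIZER_PATH  (default: models/Qwen2.5-1.5B)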


class EvalHandler:
    """Detects formatting constraints in a user instruction and patches responses to satisfy them."""

    def __init__(self):
        self.rule_patterns = {
            'comma_restriction': re.compile(r'no.*comma|without.*comma', re.IGNORECASE),
            'placeholder_requirement': re.compile(r'placeholder.*\[.*\]|square.*bracket', re.IGNORECASE),
            'lowercase_requirement': re.compile(r'lowercase|no.*capital|all.*lowercase', re.IGNORECASE),
            'capital_frequency': re.compile(r'capital.*letter.*less.*than|capital.*word.*frequency', re.IGNORECASE),
            'quotation_requirement': re.compile(r'wrap.*quotation|double.*quote', re.IGNORECASE),
            'json_format': re.compile(r'json.*format|JSON.*output|format.*json', re.IGNORECASE),
            'word_count': re.compile(r'less.*than.*word|word.*limit|maximum.*word', re.IGNORECASE),
            'section_requirement': re.compile(r'section.*start|SECTION.*X', re.IGNORECASE),
            'ending_requirement': re.compile(r'finish.*exact.*phrase|end.*phrase', re.IGNORECASE),
            'forbidden_words': re.compile(r'not.*allowed|forbidden.*word|without.*word', re.IGNORECASE),
            'capital_letters_only': re.compile(r'all.*capital|CAPITAL.*letter', re.IGNORECASE),
        }

    def detect_rules(self, instruction):
        """Return the list of checker names whose trigger pattern appears in the instruction."""
        applicable_rules = []
        if self.rule_patterns['comma_restriction'].search(instruction):
            applicable_rules.append('CommaChecker')
        if self.rule_patterns['placeholder_requirement'].search(instruction):
            applicable_rules.append('PlaceholderChecker')
        if self.rule_patterns['lowercase_requirement'].search(instruction):
            applicable_rules.append('LowercaseLettersEnglishChecker')
        if self.rule_patterns['capital_frequency'].search(instruction):
            applicable_rules.append('CapitalWordFrequencyChecker')
        if self.rule_patterns['quotation_requirement'].search(instruction):
            applicable_rules.append('QuotationChecker')
        if self.rule_patterns['json_format'].search(instruction):
            applicable_rules.append('JsonFormat')
        if self.rule_patterns['word_count'].search(instruction):
            applicable_rules.append('NumberOfWords')
        if self.rule_patterns['section_requirement'].search(instruction):
            applicable_rules.append('SectionChecker')
        if self.rule_patterns['ending_requirement'].search(instruction):
            applicable_rules.append('EndChecker')
        if self.rule_patterns['forbidden_words'].search(instruction):
            applicable_rules.append('ForbiddenWords')
        if self.rule_patterns['capital_letters_only'].search(instruction):
            applicable_rules.append('CapitalLettersEnglishChecker')
        return applicable_rules

    def apply_rule_fix(self, response, rules, instruction=""):
        """Apply each detected rule's fixer to the response, in order."""
        for rule in rules:
            if rule == 'CommaChecker':
                response = self._fix_commas(response, instruction)
            elif rule == 'PlaceholderChecker':
                response = self._fix_placeholders(response, instruction)
            elif rule == 'LowercaseLettersEnglishChecker':
                response = self._fix_lowercase(response)
            elif rule == 'CapitalWordFrequencyChecker':
                response = self._fix_capital_frequency(response, instruction)
            elif rule == 'QuotationChecker':
                response = self._fix_quotations(response)
            elif rule == 'JsonFormat':
                response = self._fix_json_format(response, instruction)
            elif rule == 'NumberOfWords':
                response = self._fix_word_count(response, instruction)
            elif rule == 'SectionChecker':
                response = self._fix_sections(response, instruction)
            elif rule == 'EndChecker':
                response = self._fix_ending(response, instruction)
            elif rule == 'ForbiddenWords':
                response = self._fix_forbidden_words(response, instruction)
            elif rule == 'CapitalLettersEnglishChecker':
                response = self._fix_all_capitals(response, instruction)
        return response

    def _fix_commas(self, response, instruction):
        return response.replace(',', '')

    def _fix_placeholders(self, response, instruction):
        num_match = re.search(r'at least (\d+)', instruction, re.IGNORECASE)
        if num_match:
            target_count = int(num_match.group(1))
            current_count = len(re.findall(r'\[.*?\]', response))
            words = response.split()
            # Wrap leading words in brackets until the placeholder quota is met.
            for i in range(target_count - current_count):
                if i < len(words):
                    words[i] = f'[{words[i]}]'
            return ' '.join(words)
        return response

    def _fix_lowercase(self, response):
        return response.lower()

    def _fix_capital_frequency(self, response, instruction):
        max_match = re.search(r'less than (\d+)', instruction, re.IGNORECASE)
        if max_match:
            max_capitals = int(max_match.group(1))
            words = response.split()
            capital_count = sum(1 for word in words if word.isupper())
            if capital_count > max_capitals:
                # Lowercase all-caps words until the count drops under the limit.
                for i, word in enumerate(words):
                    if word.isupper() and capital_count > max_capitals:
                        words[i] = word.lower()
                        capital_count -= 1
            return ' '.join(words)
        return response

    def _fix_quotations(self, response):
        return f'"{response}"'

    def _fix_json_format(self, response, instruction):
        return json.dumps({"response": response}, indent=2)

    def _fix_word_count(self, response, instruction):
        limit_match = re.search(r'less than (\d+)', instruction, re.IGNORECASE)
        if limit_match:
            word_limit = int(limit_match.group(1))
            words = response.split()
            if len(words) > word_limit:
                return ' '.join(words[:word_limit])
        return response

    def _fix_sections(self, response, instruction):
        # Note: this replaces the response with a section skeleton.
        section_match = re.search(r'(\d+) section', instruction, re.IGNORECASE)
        if section_match:
            num_sections = int(section_match.group(1))
            sections = []
            for i in range(num_sections):
                sections.append(f"SECTION {i+1}:")
                sections.append("This section provides content here.")
            return '\n\n'.join(sections)
        return response

    def _fix_ending(self, response, instruction):
        end_match = re.search(r'finish.*with.*phrase[:\s]*([^.!?]*)', instruction, re.IGNORECASE)
        if end_match:
            required_ending = end_match.group(1).strip()
            if not response.endswith(required_ending):
                return response + " " + required_ending
        return response

    def _fix_forbidden_words(self, response, instruction):
        forbidden_match = re.search(r'without.*word[:\s]*([^.!?]*)', instruction, re.IGNORECASE)
        if forbidden_match:
            forbidden_word = forbidden_match.group(1).strip().lower()
            response = re.sub(re.escape(forbidden_word), '', response, flags=re.IGNORECASE)
        return response.strip()

    def _fix_all_capitals(self, response, instruction):
        return response.upper()


EVAL_HANDLER = EvalHandler()
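
# Illustrative example (not executed): detecting constraints in a prompt and
# patching a response to satisfy them.
#
#   rules = EVAL_HANDLER.detect_rules("Answer in all lowercase and without commas.")
#   # -> ['CommaChecker', 'LowercaseLettersEnglishChecker']
#   EVAL_HANDLER.apply_rule_fix("Hello, World", rules)
#   # -> 'hello world'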


def chat(system_prompt: str, user_prompt: str) -> str:
    """Run one turn of chat with a system + user message and return the reply."""
    try:
        global EVAL_HANDLER
        if EVAL_HANDLER is None:
            EVAL_HANDLER = EvalHandler()
        # Detect formatting constraints in the user prompt and strengthen the
        # system prompt with matching directives.
        applicable_rules = EVAL_HANDLER.detect_rules(user_prompt)
        system_prompt_parts = []
        if applicable_rules:
            if 'CommaChecker' in applicable_rules:
                system_prompt_parts.append("Do not use any commas in your response.")
            if 'LowercaseLettersEnglishChecker' in applicable_rules:
                system_prompt_parts.append("Respond in all lowercase letters only.")
            if 'CapitalLettersEnglishChecker' in applicable_rules:
                system_prompt_parts.append("Respond in ALL CAPITAL LETTERS.")
            if 'QuotationChecker' in applicable_rules:
                system_prompt_parts.append("Wrap your entire response in double quotation marks.")
            if 'JsonFormat' in applicable_rules:
                system_prompt_parts.append("Format your response as valid JSON.")
            if 'SectionChecker' in applicable_rules:
                system_prompt_parts.append("Organize your response into clearly marked sections.")
        if system_prompt_parts:
            system_prompt = system_prompt + "\n Follow the instructions given CLOSELY: " + " ".join(system_prompt_parts)
    except Exception:
        # If rule detection fails, fall back to the unmodified system prompt.
        pass

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)
    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pad_token_id=tok.eos_token_id,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.9,
        )
    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()
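
# Illustrative example (not executed):
#   chat("You are a helpful assistant.", "Reply in all lowercase: what is two plus two?")
# The lowercase constraint is detected by EVAL_HANDLER and appended to the
# system prompt before generation.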


def gt(audio: np.ndarray, sr: int):
    """Transcribe audio with Whisper, resampling to 16 kHz mono first when needed."""
    ss = audio.squeeze().astype(np.float32)
    if sr != 16_000:
        ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000)
    result = asr_model.transcribe(ss, fp16=False, language=None)
    transcribed_text = result["text"].strip()
    return transcribed_text


def sample(rr: str) -> str:
    """Raw (non-chat) continuation from the base LM."""
    if rr.strip() == "":
        rr = "Hello "
    inputs = tok(rr, return_tensors="pt").to(lm.device)
    with torch.inference_mode():
        out_ids = lm.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )
    return tok.decode(
        out_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
    )


def text_to_speech_vibevoice(text: str) -> np.ndarray:
    """
    Convert text to speech using VibeVoice (synchronous).

    Args:
        text: Text to convert to speech

    Returns:
        Audio array as numpy array (mono, 16kHz), or None if generation failed
        or VibeVoice is unavailable
    """
    global vibevoice_model, vibevoice_processor, vibevoice_voice_sample

    if not VIBEVOICE_AVAILABLE or vibevoice_model is None or vibevoice_processor is None:
        return None

    try:
        if not text or not text.strip():
            return np.zeros(16000, dtype=np.float32)

        # VibeVoice expects transcript lines in "Speaker N: ..." format.
        lines = text.strip().split('\n')
        formatted_lines = []
        for line in lines:
            line = line.strip()
            if line:
                formatted_lines.append(f"Speaker 1: {line}")
        formatted_text = '\n'.join(formatted_lines)

        processor_kwargs = {
            "text": [formatted_text],
            "padding": True,
            "return_tensors": "pt",
            "return_attention_mask": True,
        }

        # Attach the reference voice sample when one was loaded at startup.
        if vibevoice_voice_sample is not None:
            processor_kwargs["voice_samples"] = [[vibevoice_voice_sample]]

        inputs = vibevoice_processor(**processor_kwargs)

        # Move all tensors to the model's device.
        device = next(vibevoice_model.parameters()).device
        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to(device)

        with torch.inference_mode():
            outputs = vibevoice_model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=1.3,
                tokenizer=vibevoice_processor.tokenizer,
                generation_config={'do_sample': False},
                verbose=False,
                is_prefill=(vibevoice_voice_sample is not None),
            )

        if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
            audio_tensor = outputs.speech_outputs[0]

            if torch.is_tensor(audio_tensor):
                # bfloat16 tensors cannot be converted to numpy directly.
                if audio_tensor.dtype == torch.bfloat16:
                    audio_tensor = audio_tensor.float()
                audio_array = audio_tensor.cpu().numpy().astype(np.float32)
            else:
                audio_array = np.array(audio_tensor, dtype=np.float32)

            if audio_array.ndim > 1:
                audio_array = audio_array.squeeze()

            # VibeVoice outputs 24 kHz audio; the rest of the pipeline uses 16 kHz.
            if len(audio_array) > 0:
                audio_array = librosa.resample(audio_array, orig_sr=24000, target_sr=16000)
                return audio_array.astype(np.float32)
            else:
                return np.zeros(16000, dtype=np.float32)
        else:
            return np.zeros(16000, dtype=np.float32)

    except Exception as e:
        print(f"VibeVoice generation failed: {e}")
        traceback.print_exc()
        return None


async def text_to_speech_edge_tts(text: str, voice: str = "en-US-AriaNeural") -> np.ndarray:
    """
    Convert text to speech using edge-tts (async).

    Args:
        text: Text to convert to speech
        voice: Voice to use (default: en-US-AriaNeural)

    Returns:
        Audio array as numpy array (mono, 16kHz)
    """
    if not TTS_AVAILABLE:
        raise RuntimeError("edge-tts not available")

    # Stream MP3 audio chunks from the edge-tts service.
    communicate = edge_tts.Communicate(text, voice)
    audio_data = b""
    async for chunk in communicate.stream():
        if chunk["type"] == "audio":
            audio_data += chunk["data"]

    if not audio_data:
        return np.zeros(16000, dtype=np.float32)

    # Decode the MP3 bytes via a temp file for decoder compatibility.
    tmp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
            tmp_file.write(audio_data)
            tmp_file.flush()
            tmp_file_path = tmp_file.name

        audio_array, sample_rate = librosa.load(tmp_file_path, sr=None, mono=True)

        if sample_rate != 16000:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)

        return audio_array.astype(np.float32)
    finally:
        # Always remove the temp file, even if decoding failed.
        if tmp_file_path and os.path.exists(tmp_file_path):
            try:
                os.unlink(tmp_file_path)
            except Exception:
                pass


def clean_v2t_response_for_v2v(response_text: str) -> str:
    """
    Post-process a v2t response to remove the first two lines before using it for t2v.

    The v2t response typically contains:
    - Line 1: The actual response text (often the input prompt repeated)
    - Line 2: Injected scoring line (e.g., "As an answer 5 points with scale from 5 to 10...")
    - Line 3+: The actual useful response content

    This function removes the first two lines to get the clean response for TTS.

    Args:
        response_text: Full response text from the v2t endpoint

    Returns:
        Cleaned text with the first two lines removed
    """
    if not response_text:
        return ""

    lines = response_text.split("\n")

    if len(lines) >= 3:
        cleaned_lines = lines[2:]
        cleaned_text = "\n".join(cleaned_lines).strip()

        # If stripping two lines removed everything, back off to one line.
        if not cleaned_text and len(lines) >= 2:
            cleaned_text = "\n".join(lines[1:]).strip()

        # As a last resort, keep the full response.
        if not cleaned_text:
            cleaned_text = response_text.strip()

        return cleaned_text
    elif len(lines) == 2:
        cleaned_text = lines[1].strip()
        return cleaned_text
    else:
        return response_text.strip()
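
# Illustrative example (not executed): with three or more lines, only the
# content from line 3 onward is kept.
#   clean_v2t_response_for_v2v("echoed prompt\ninjected scoring line\nreal answer")
#   # -> 'real answer'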


def clean_text_for_tts_with_llm(text: str) -> str:
    """
    Use the LLM to intelligently clean text for text-to-speech while preserving important content.

    This function sends the text to the LLM with instructions to:
    - Remove unicode characters, symbols, and formatting that don't contribute to speech
    - Preserve important content like math equations (convert to spoken form)
    - Keep all meaningful words, numbers, and essential punctuation
    - Make the text natural and clear for TTS

    Args:
        text: Text to clean for TTS

    Returns:
        Cleaned text optimized for text-to-speech
    """
    if not text or not text.strip():
        return ""

    global tok, lm
    if tok is None or lm is None:
        return text

    try:
        system_prompt = """You are a text cleaning assistant. Your task is to clean text for text-to-speech (TTS) conversion.

IMPORTANT RULES:
1. Remove all unicode characters, special symbols, and formatting that don't contribute to speech
2. PRESERVE important content:
   - Math equations: Convert them to spoken form (e.g., "x squared plus y equals 5" instead of "x² + y = 5")
   - Numbers: Keep all numbers and convert them to natural speech format
   - Important punctuation: Keep periods, commas, question marks, exclamation marks for natural speech flow
3. Remove markdown formatting, asterisks, underscores, brackets, etc. that are not needed for speech
4. Keep all meaningful words, letters, and essential content
5. Make the text natural, clear, and easy to read aloud
6. Do NOT remove any actual content or meaning from the text
7. Convert any special formatting to natural spoken language

Return ONLY the cleaned text, nothing else."""

        user_prompt = f"Clean this text for text-to-speech:\n\n{text}"

        cleaned_text = chat(system_prompt, user_prompt)
        cleaned_text = cleaned_text.strip()

        # Strip any meta preamble the LLM may have added (e.g., "Here's the cleaned text:").
        if "cleaned text" in cleaned_text.lower() or "here's" in cleaned_text.lower():
            lines = cleaned_text.split("\n")
            cleaned_lines = []
            skip_next = False
            for line in lines:
                line_lower = line.lower().strip()
                if any(marker in line_lower for marker in ["cleaned text", "here's", "here is", "result:", "output:"]):
                    skip_next = True
                    continue
                if skip_next and not line.strip():
                    continue
                skip_next = False
                cleaned_lines.append(line)
            if cleaned_lines:
                cleaned_text = "\n".join(cleaned_lines).strip()

        return cleaned_text

    except Exception:
        # On any failure, fall back to the original text.
        return text


def text_to_speech(text: str, voice: str = "en-US-AriaNeural") -> np.ndarray:
    """
    Convert text to speech using VibeVoice (preferred) or edge-tts (fallback).

    Args:
        text: Text to convert to speech
        voice: Voice to use (for the edge-tts fallback, default: en-US-AriaNeural)

    Returns:
        Audio array as numpy array (mono, 16kHz)
    """
    # Try VibeVoice first.
    audio = text_to_speech_vibevoice(text)
    if audio is not None:
        return audio

    # Fall back to edge-tts, or to silence if neither backend is available.
    if not TTS_AVAILABLE:
        return np.zeros(16000, dtype=np.float32)

    try:
        return asyncio.run(text_to_speech_edge_tts(text, voice))
    except Exception:
        return np.zeros(16000, dtype=np.float32)


INITIALIZATION_STATUS = {"model_loaded": True, "error": None}


class GenerateRequest(BaseModel):
    audio_data: str = Field(
        ...,
        description="Base64-encoded .npy payload containing the input audio array",
    )
    sample_rate: int = Field(..., description="Sample rate of the input audio in Hz")


class GenerateResponse(BaseModel):
    audio_data: str = Field(..., description="Base64-encoded .npy payload containing the output audio array")


app = FastAPI(title="V1", version="0.1")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def b64(data: str) -> np.ndarray:
    """Decode a base64-encoded .npy payload into a numpy array."""
    raw = base64.b64decode(data)
    return np.load(io.BytesIO(raw), allow_pickle=False)


def ab64(arr: np.ndarray, sr: int) -> str:
    """Resample 16 kHz audio to the caller's sample rate and base64-encode it as .npy."""
    buf = io.BytesIO()
    resampled = librosa.resample(arr, orig_sr=16000, target_sr=sr)
    np.save(buf, resampled.astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode()
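
# Illustrative example (not executed): how a client builds the payload these
# helpers expect -- a base64-encoded .npy array plus its sample rate.
#
#   buf = io.BytesIO()
#   np.save(buf, np.zeros(16000, dtype=np.float32))
#   payload = {"audio_data": base64.b64encode(buf.getvalue()).decode(),
#              "sample_rate": 16000}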


@app.get("/api/v1/health")
def health_check():
    status = {
        "status": "healthy",
        "model_loaded": INITIALIZATION_STATUS["model_loaded"],
        "error": INITIALIZATION_STATUS["error"],
    }
    return status


@app.post("/api/v1/v2v", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    """Voice-to-voice endpoint - returns audio response.

    Process:
    1. Convert input audio to text (v2t)
    2. Generate text response (LLM)
    3. Clean response text for TTS
    4. Convert cleaned text to speech (t2v) using VibeVoice or edge-tts
    5. Return generated audio
    """
    if not VIBEVOICE_AVAILABLE and not TTS_AVAILABLE:
        raise HTTPException(
            status_code=500,
            detail="TTS functionality not available. Please install VibeVoice or edge-tts",
        )

    try:
        audio_np = b64(req.audio_data)

        # Normalize to a single (1, n) channel.
        if audio_np.ndim == 1:
            audio_np = audio_np.reshape(1, -1)
        elif audio_np.ndim == 2 and audio_np.shape[0] > 1:
            audio_np = audio_np.mean(axis=0, keepdims=True)

        user_message = gt(audio_np, req.sample_rate)

        if not user_message:
            # Nothing was transcribed; return one second of silence.
            silence = np.zeros(16000, dtype=np.float32)
            return GenerateResponse(audio_data=ab64(silence, req.sample_rate))

        system_prompt = "You are a helpful assistant who tries to help answer the user's question."
        system_prompt += "\n\n" + """Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

        response_text = chat(system_prompt, user_message)

        # Strip the echoed prompt/scoring lines, then clean the text for TTS.
        cleaned_response_text = clean_v2t_response_for_v2v(response_text)
        cleaned_response_text = clean_text_for_tts_with_llm(cleaned_response_text)

        try:
            audio_output = text_to_speech(cleaned_response_text)
            encoded_audio = ab64(audio_output, req.sample_rate)
        except Exception as tts_error:
            print(f"TTS failed, returning silence: {tts_error}")
            silence = np.zeros(16000, dtype=np.float32)
            encoded_audio = ab64(silence, req.sample_rate)

        return GenerateResponse(audio_data=encoded_audio)

    except Exception as e:
        traceback.print_exc()
        # Best effort: return silence rather than surfacing an error.
        try:
            silence = np.zeros(16000, dtype=np.float32)
            encoded_audio = ab64(silence, req.sample_rate)
            return GenerateResponse(audio_data=encoded_audio)
        except Exception:
            raise HTTPException(status_code=500, detail=f"{e}")


@app.post("/api/v1/v2t")
def generate_text(req: GenerateRequest):
    """Voice-to-text endpoint - transcribes audio and returns the LLM's text reply."""
    audio_np = b64(req.audio_data)
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)
    try:
        text = gt(audio_np, req.sample_rate)
        system_prompt = "You are a helpful assistant who tries to help answer the user's question."
        response_text = chat(system_prompt, user_prompt=text)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")
    return {"text": response_text}


if __name__ == "__main__":
    uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=False)