|
|
from typing import Dict, Any |
|
|
import tempfile |
|
|
import torchaudio |
|
|
import soundfile as sf |
|
|
import re |
|
|
from num2words import num2words |
|
|
from f5_tts.model import DiT |
|
|
from f5_tts.infer.utils_infer import ( |
|
|
load_vocoder, |
|
|
load_model, |
|
|
preprocess_ref_audio_text, |
|
|
infer_process, |
|
|
remove_silence_for_generated_wav, |
|
|
) |
|
|
import base64 |
|
|
import io |
|
|
import numpy as np |
|
|
from huggingface_hub import hf_hub_download |
|
|
import traceback |
|
|
|
|
|
class EndpointHandler:
    """Request handler that synthesises Spanish speech with an F5-TTS model.

    The handler is callable: it takes a request dictionary holding a
    base64-encoded reference recording plus the text to speak, and returns
    a dictionary holding either the base64-encoded WAV result or an error
    message.
    """

    # Digits glued to ASCII letters (e.g. "abc123") are split apart first so
    # that the whole-number pattern below can see them as standalone words.
    _LETTER_THEN_DIGIT = re.compile(r'([A-Za-z])(\d)')
    _DIGIT_THEN_LETTER = re.compile(r'(\d)([A-Za-z])')
    _WHOLE_NUMBER = re.compile(r'\b\d+\b')

    def __init__(self, path=""):
        """Load the vocoder and the F5-Spanish checkpoint.

        Args:
            path: Accepted for interface compatibility; not used here.
        """
        self.vocoder = load_vocoder()
        model_cfg = dict(
            dim=1024,
            depth=22,
            heads=16,
            ff_mult=2,
            text_dim=512,
            conv_layers=4,
        )

        model_path = hf_hub_download(
            repo_id="jpgallegoar/F5-Spanish",
            filename="model_1200000.safetensors",
        )

        self.ema_model = load_model(DiT, model_cfg, model_path)

    @staticmethod
    def _failure(message):
        """Build the error response dictionary returned to the caller."""
        return {
            "success": False,
            "error": message,
        }

    def traducir_numero_a_texto(self, texto):
        """Replace every whole number in ``texto`` with its Spanish words.

        Digits directly adjacent to ASCII letters are first separated by a
        space, then each standalone run of digits is converted with
        ``num2words(..., lang='es')``.

        Args:
            texto: Text that may contain digit sequences.

        Returns:
            The text with standalone digit runs spelled out in Spanish.
        """
        texto_separado = self._LETTER_THEN_DIGIT.sub(r'\1 \2', texto)
        texto_separado = self._DIGIT_THEN_LETTER.sub(r'\1 \2', texto_separado)

        def reemplazar_numero(match):
            return num2words(int(match.group()), lang='es')

        return self._WHOLE_NUMBER.sub(reemplazar_numero, texto_separado)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate speech for ``gen_text`` in the voice of ``ref_audio``.

        Args:
            data: Request dictionary. Recognised keys:
                ``ref_audio`` (required): base64-encoded audio file bytes.
                ``ref_text``: transcript of the reference audio (default "").
                ``gen_text`` (required): text to synthesise.
                ``remove_silence``: post-process the result to strip
                    silence (default True).
                ``cross_fade_duration``: forwarded to ``infer_process``
                    (default 0.15).
                ``speed``: forwarded to ``infer_process`` (default 1.0).

        Returns:
            ``{"success": True, "audio_base64": <base64 WAV>}`` on success,
            otherwise ``{"success": False, "error": <message>}``. Exceptions
            are never propagated; they are printed and reported in ``error``.
        """
        try:
            ref_audio_base64 = data.get("ref_audio")
            if not ref_audio_base64:
                return self._failure("Missing required field: 'ref_audio'")

            # Every scratch file lives in a per-request directory so that it
            # is deleted on all exit paths (early returns and exceptions),
            # instead of accumulating one or two .wav files per request.
            with tempfile.TemporaryDirectory() as workdir:
                try:
                    audio_bytes = base64.b64decode(ref_audio_base64)
                    with tempfile.NamedTemporaryFile(
                        dir=workdir, delete=False, suffix=".wav"
                    ) as temp_audio_file:
                        temp_audio_file.write(audio_bytes)
                        temp_audio_path = temp_audio_file.name
                except Exception as e:
                    return self._failure(
                        f"Invalid audio data: {type(e).__name__}: {str(e)}"
                    )

                ref_text = data.get("ref_text", "")
                gen_text = data.get("gen_text", "")
                if not gen_text:
                    return self._failure("Missing required field: 'gen_text'")

                remove_silence = data.get("remove_silence", True)
                cross_fade_duration = data.get("cross_fade_duration", 0.15)
                speed = data.get("speed", 1.0)

                # NOTE(review): assumes the returned ref_audio does not need
                # temp_audio_path to outlive this request — confirm against
                # the f5_tts version in use.
                ref_audio, ref_text = preprocess_ref_audio_text(
                    temp_audio_path, ref_text, show_info=print
                )

                # Pad the text with a leading space and a trailing ". ",
                # then lowercase it and spell out numbers before inference.
                if not gen_text.startswith(" "):
                    gen_text = " " + gen_text
                if not gen_text.endswith(". "):
                    gen_text += ". "
                gen_text = self.traducir_numero_a_texto(gen_text.lower())

                final_wave, final_sample_rate, _ = infer_process(
                    ref_audio,
                    ref_text,
                    gen_text,
                    self.ema_model,
                    self.vocoder,
                    cross_fade_duration=cross_fade_duration,
                    speed=speed,
                    show_info=print,
                )

                if remove_silence:
                    # Reserve a file name and close the handle first: the
                    # helpers below re-open the file by path, which fails on
                    # some platforms while the original handle is still open.
                    with tempfile.NamedTemporaryFile(
                        dir=workdir, delete=False, suffix=".wav"
                    ) as scratch_file:
                        scratch_path = scratch_file.name
                    sf.write(scratch_path, final_wave, final_sample_rate)
                    remove_silence_for_generated_wav(scratch_path)
                    final_wave, _ = torchaudio.load(scratch_path)
                    final_wave = final_wave.squeeze().cpu().numpy()

                with io.BytesIO() as buffer:
                    sf.write(buffer, final_wave, final_sample_rate, format="WAV")
                    encoded_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")

                return {
                    "success": True,
                    "audio_base64": encoded_audio,
                }

        except Exception as e:
            # Top-level boundary: log the full traceback and report the
            # failure in the response rather than raising to the caller.
            print("==== Exception Traceback ====")
            traceback.print_exc()
            print("==== End Traceback ====")
            return self._failure(f"{type(e).__name__}: {str(e)}")
|
|
|