# Spanish-F5-API / handler.py
# Last commit df516aa by eloicito333: "fix typo in success key in EndpointHandler response"
import base64
import io
import os
import re
import tempfile
import traceback
from typing import Dict, Any

import numpy as np
import soundfile as sf
import torchaudio
from f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
)
from f5_tts.model import DiT
from huggingface_hub import hf_hub_download
from num2words import num2words
class EndpointHandler:
    """Inference-endpoint handler for Spanish F5-TTS voice cloning.

    The vocoder and the ``jpgallegoar/F5-Spanish`` DiT checkpoint are loaded
    once in ``__init__``. Each call takes a base64-encoded reference clip plus
    the text to speak, and returns the synthesized speech as a base64-encoded
    WAV file.
    """

    # ``[^\W\d_]`` matches any Unicode letter (á, é, ñ, ü, ...). The former
    # ``[A-Za-z]`` skipped accented Spanish letters, so text like "café3" was
    # never split and its digits were never spelled out.
    _LETRA_DIGITO = re.compile(r"([^\W\d_])(\d)")
    _DIGITO_LETRA = re.compile(r"(\d)([^\W\d_])")
    # A run of digits standing on its own as a word.
    _NUMERO = re.compile(r"\b\d+\b")

    def __init__(self, path=""):
        """Load the vocoder and the Spanish F5-TTS model.

        Args:
            path: Unused. The checkpoint is always downloaded from the Hub.
                The parameter is presumably kept because the hosting runtime
                passes it (verify against the deployment).
        """
        self.vocoder = load_vocoder()
        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
        model_path = hf_hub_download(
            repo_id="jpgallegoar/F5-Spanish",
            filename="model_1200000.safetensors",
        )
        self.ema_model = load_model(DiT, model_cfg, model_path)

    @classmethod
    def _separar_letras_y_digitos(cls, texto):
        """Insert a space at every letter/digit boundary ("café3" -> "café 3")."""
        texto = cls._LETRA_DIGITO.sub(r"\1 \2", texto)
        return cls._DIGITO_LETRA.sub(r"\1 \2", texto)

    def traducir_numero_a_texto(self, texto):
        """Spell out every standalone integer in *texto* as Spanish words.

        Digits glued to letters are split off first, so "café3" becomes
        "café tres". An integer that cannot be converted is left as digits
        instead of failing the whole request. This covers num2words raising
        OverflowError for huge values and ``int`` raising ValueError past
        the interpreter's digit limit.
        """

        def reemplazar_numero(match):
            numero = match.group()
            try:
                return num2words(int(numero), lang="es")
            except (OverflowError, ValueError):
                return numero

        texto_separado = self._separar_letras_y_digitos(texto)
        return self._NUMERO.sub(reemplazar_numero, texto_separado)

    def _preparar_texto(self, gen_text):
        """Pad, lowercase and number-expand the text to synthesize."""
        if not gen_text.startswith(" "):
            gen_text = " " + gen_text
        # NOTE(review): text already ending in "." (no trailing space) gains a
        # second period here. Kept unchanged so synthesis output is unaffected.
        if not gen_text.endswith(". "):
            gen_text += ". "
        return self.traducir_numero_a_texto(gen_text.lower())

    @staticmethod
    def _nuevo_wav_temporal():
        """Create an empty, *closed* ``.wav`` temp file and return its path.

        The handle is closed before returning so other libraries can reopen the
        file by name. The caller owns the file and must delete it.
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            ruta = tmp.name
        return ruta

    @staticmethod
    def _borrar_archivo(ruta):
        """Best-effort delete of *ruta*; a failed cleanup never fails a request."""
        try:
            os.remove(ruta)
        except OSError:
            pass

    @classmethod
    def _quitar_silencio(cls, wave, sample_rate):
        """Apply ``remove_silence_for_generated_wav`` to *wave* via a temp file.

        Returns:
            ``(wave, sample_rate)``. The wave is the squeezed numpy array read
            back with ``torchaudio.load``. The rate is the one stored in the
            processed file.
        """
        ruta = cls._nuevo_wav_temporal()
        try:
            sf.write(ruta, wave, sample_rate)
            remove_silence_for_generated_wav(ruta)
            wave_tensor, sample_rate = torchaudio.load(ruta)
            return wave_tensor.squeeze().cpu().numpy(), sample_rate
        finally:
            cls._borrar_archivo(ruta)

    @staticmethod
    def _wav_a_base64(wave, sample_rate):
        """Encode *wave* as an in-memory WAV file and return it as base64 text."""
        with io.BytesIO() as buffer:
            sf.write(buffer, wave, sample_rate, format="WAV")
            return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Clone the reference voice and synthesize ``gen_text``.

        Args:
            data: Request payload. Keys read:
                ``ref_audio`` (required): base64-encoded reference audio.
                ``gen_text`` (required): text to synthesize.
                ``ref_text`` (default ``""``): transcript of the reference
                clip, passed to ``preprocess_ref_audio_text``.
                ``remove_silence`` (default ``True``).
                ``cross_fade_duration`` (default ``0.15``).
                ``speed`` (default ``1.0``).

        Returns:
            ``{"success": True, "audio_base64": <base64 WAV>}`` on success.
            Otherwise ``{"success": False, "error": <message>}``. Exceptions
            are not propagated: they are printed with their traceback and
            reported through ``"error"``.
        """
        try:
            ref_audio_base64 = data.get("ref_audio")
            if not ref_audio_base64:
                return {
                    "success": False,
                    "error": "Missing required field: 'ref_audio'",
                }
            try:
                audio_bytes = base64.b64decode(ref_audio_base64)
            except Exception as e:
                return {
                    "success": False,
                    "error": f"Invalid audio data: {type(e).__name__}: {str(e)}",
                }

            ref_text = data.get("ref_text", "")
            gen_text = data.get("gen_text", "")
            if not gen_text:
                return {
                    "success": False,
                    "error": "Missing required field: 'gen_text'",
                }
            remove_silence = data.get("remove_silence", True)
            cross_fade_duration = data.get("cross_fade_duration", 0.15)
            speed = data.get("speed", 1.0)

            # All validation happens before touching the filesystem, so
            # rejected requests never create a temp file. The finally removes
            # the reference clip on success and on failure. Inference runs
            # inside it in case the preprocessed path refers back to this file.
            ref_audio_path = self._nuevo_wav_temporal()
            try:
                with open(ref_audio_path, "wb") as f:
                    f.write(audio_bytes)
                ref_audio, ref_text = preprocess_ref_audio_text(
                    ref_audio_path, ref_text, show_info=print
                )
                gen_text = self._preparar_texto(gen_text)
                final_wave, final_sample_rate, _ = infer_process(
                    ref_audio,
                    ref_text,
                    gen_text,
                    self.ema_model,
                    self.vocoder,
                    cross_fade_duration=cross_fade_duration,
                    speed=speed,
                    show_info=print,
                )
            finally:
                self._borrar_archivo(ref_audio_path)

            if remove_silence:
                final_wave, final_sample_rate = self._quitar_silencio(
                    final_wave, final_sample_rate
                )

            return {
                "success": True,
                "audio_base64": self._wav_a_base64(final_wave, final_sample_rate),
            }
        except Exception as e:
            print("==== Exception Traceback ====")
            traceback.print_exc()
            print("==== End Traceback ====")
            return {
                "success": False,
                "error": f"{type(e).__name__}: {str(e)}",
            }