# Uploaded via huggingface_hub by baenacoco (commit 781f017, verified)
"""Space: TTS Test β€” Test your trained voice model
Downloads voice model from Hub -> generates speech from text.
GPU: T4 medium (F5-TTS inference only)
"""
import gc
import logging
import os
import shutil
import traceback
from pathlib import Path
import gradio as gr
import soundfile as sf
import torch
from hub_utils import download_step
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger(__name__)
# ── Config ──
# On HF Spaces the SPACE_ID env var is set; /data is the persistent-storage
# mount and is preferred whenever it exists and is writable.
IS_HF_SPACE = os.environ.get("SPACE_ID") is not None
_data_path = Path("/data")
if IS_HF_SPACE and _data_path.exists() and os.access(_data_path, os.W_OK):
    BASE_DIR = _data_path
else:
    # Local run (or no persistent storage): keep everything under ./data
    BASE_DIR = Path("data")
VOICE_MODEL_DIR = BASE_DIR / "voice_model"  # downloaded checkpoint(s) + reference.wav
TEMP_DIR = BASE_DIR / "temp"                # generated wav outputs
HF_CACHE_DIR = BASE_DIR / "hf_cache"        # Hugging Face download cache
for d in [VOICE_MODEL_DIR, TEMP_DIR, HF_CACHE_DIR]:
    d.mkdir(parents=True, exist_ok=True)
# Redirect HF caches to the writable base dir.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME — confirm before removing.
os.environ["HF_HOME"] = str(HF_CACHE_DIR)
os.environ["TRANSFORMERS_CACHE"] = str(HF_CACHE_DIR)
F5_SPANISH_MODEL_ID = "jpgallegoar/F5-Spanish"  # base Spanish F5-TTS model id
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
APP_VERSION = "1.1.0"
_f5_model = None  # lazily-loaded F5TTS instance (see _load_tts/_unload_tts)
_ref_text_cache = {}  # {audio_path: transcribed_text}
def _clear_cache():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _load_tts():
    """Lazily load F5-TTS into the module-global ``_f5_model``.

    Resolution order:
      1. ``model_last.pt`` in VOICE_MODEL_DIR, if present;
      2. otherwise the alphabetically-first other checkpoint there
         (excluding pretrained base weights);
      3. otherwise the stock F5-Spanish model.

    No-op if a model is already loaded.
    """
    global _f5_model
    if _f5_model is not None:
        return  # already loaded
    from f5_tts.api import F5TTS
    finetuned_path = VOICE_MODEL_DIR / "model_last.pt"
    if not finetuned_path.exists():
        checkpoints = list(VOICE_MODEL_DIR.glob("*.pt")) + list(VOICE_MODEL_DIR.glob("*.safetensors"))
        # Exclude the pretrained base model that may sit next to the
        # fine-tuned weights ("1250000" is the base checkpoint's step count).
        checkpoints = [c for c in checkpoints if "pretrained" not in c.name and "1250000" not in c.name]
        # glob() order is filesystem-dependent; sort so the pick is deterministic.
        finetuned_path = sorted(checkpoints)[0] if checkpoints else None
    if finetuned_path and finetuned_path.exists():
        logger.info(f"Loading fine-tuned F5-TTS from {finetuned_path}")
        _f5_model = F5TTS(ckpt_file=str(finetuned_path), device=DEVICE)
    else:
        logger.info("No fine-tuned model found, loading base F5-Spanish")
        _f5_model = F5TTS(device=DEVICE)
    logger.info("F5-TTS loaded")
def _unload_tts():
    """Drop the global F5-TTS instance (if any) and reclaim its memory."""
    global _f5_model
    if _f5_model is None:
        return
    del _f5_model
    _f5_model = None
    _clear_cache()
def _get_reference_audio():
    """Return the path of the saved reference WAV.

    Raises:
        FileNotFoundError: if the model has not been downloaded yet.
    """
    ref_path = VOICE_MODEL_DIR / "reference.wav"
    if not ref_path.exists():
        raise FileNotFoundError("No hay reference.wav. Descarga el modelo primero.")
    return str(ref_path)
def _get_ref_text(audio_path):
    """Return a Spanish transcription of the reference clip, caching per path.

    The language is pinned to Spanish so Whisper cannot auto-detect the
    wrong language on short reference clips.
    """
    try:
        return _ref_text_cache[audio_path]
    except KeyError:
        pass  # first request for this clip — transcribe below
    _load_tts()
    logger.info(f"Transcribing reference audio as Spanish: {audio_path}")
    transcription = _f5_model.transcribe(audio_path, language="spanish")
    logger.info(f"Reference transcription: {transcription}")
    _ref_text_cache[audio_path] = transcription
    return transcription
# ── Gradio handlers ──
def download_model(project_name, progress=gr.Progress()):
    """Download the trained voice model for *project_name* from the Hub.

    Unloads any in-memory model, wipes VOICE_MODEL_DIR, fetches the
    ``step3_voice`` artifacts, flattens them into VOICE_MODEL_DIR and
    eagerly loads the TTS model so the first generation is fast.

    Returns a user-facing (Spanish) status string.
    """
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"
    name = project_name.strip()
    try:
        # Release the in-memory model before replacing its files on disk.
        _unload_tts()
        if VOICE_MODEL_DIR.exists():
            shutil.rmtree(VOICE_MODEL_DIR)
        VOICE_MODEL_DIR.mkdir(parents=True)
        progress(0.1, desc="Descargando modelo de voz...")
        download_step(name, "step3_voice", str(BASE_DIR))
        # download_step materializes files under BASE_DIR/<name>/step3_voice;
        # move them into the flat VOICE_MODEL_DIR layout the loader expects.
        src = BASE_DIR / name / "step3_voice"
        if src.exists():
            for f in src.iterdir():
                shutil.move(str(f), str(VOICE_MODEL_DIR / f.name))
            shutil.rmtree(BASE_DIR / name, ignore_errors=True)
        models = list(VOICE_MODEL_DIR.glob("*.pt")) + list(VOICE_MODEL_DIR.glob("*.safetensors"))
        if not models:
            # _load_tts will fall back to the base model in this case.
            logger.warning(f"No checkpoint files found after download for '{name}'")
        has_ref = (VOICE_MODEL_DIR / "reference.wav").exists()
        files_str = ", ".join(f.name for f in models)
        progress(0.9, desc="Cargando modelo...")
        _load_tts()
        return f"OK - Modelo descargado ({files_str}), referencia: {'si' if has_ref else 'no'}"
    except Exception as e:
        # Log the full traceback (consistent with the generation handlers);
        # the UI only shows the short message.
        logger.error(f"Download failed:\n{traceback.format_exc()}")
        return f"Error: {e}"
def generate_speech(project_name, text, speed, progress=gr.Progress()):
    """Synthesize *text* using the saved reference voice.

    Returns:
        (audio_path, status): path to the generated WAV plus a user-facing
        status string; audio_path is None on validation or inference errors.
    """
    if not project_name or not project_name.strip():
        return None, "Error: Debes introducir un nombre de proyecto"
    if not text or not text.strip():
        return None, "Error: Introduce texto para generar"
    try:
        progress(0.1, desc="Cargando modelo...")
        _load_tts()
        reference = _get_reference_audio()
        reference_text = _get_ref_text(reference)
        out_file = str(TEMP_DIR / "tts_output.wav")
        progress(0.3, desc="Generando voz...")
        logger.info(f"Generating: '{text[:80]}...'")
        wav, sample_rate, _ = _f5_model.infer(
            ref_file=reference,
            ref_text=reference_text,
            gen_text=text,
            speed=speed,
        )
        sf.write(out_file, wav, sample_rate)
        seconds = len(wav) / sample_rate
        logger.info(f"Generated {seconds:.1f}s audio")
        return out_file, f"OK - {seconds:.1f}s de audio generado"
    except Exception as e:
        logger.error(f"TTS failed:\n{traceback.format_exc()}")
        return None, f"Error: {e}"
def generate_with_custom_ref(project_name, text, ref_audio_path, speed, progress=gr.Progress()):
    """Synthesize *text* cloning the voice from an uploaded reference clip.

    Returns:
        (audio_path, status): path to the generated WAV plus a user-facing
        status string; audio_path is None on validation or inference errors.
    """
    if not project_name or not project_name.strip():
        return None, "Error: Debes introducir un nombre de proyecto"
    if not text or not text.strip():
        return None, "Error: Introduce texto para generar"
    if ref_audio_path is None:
        return None, "Error: Sube un audio de referencia"
    try:
        progress(0.1, desc="Cargando modelo...")
        _load_tts()
        out_file = str(TEMP_DIR / "tts_custom_output.wav")
        progress(0.3, desc="Generando voz...")
        logger.info(f"Generating with custom ref: '{text[:80]}...'")
        reference_text = _get_ref_text(ref_audio_path)
        wav, sample_rate, _ = _f5_model.infer(
            ref_file=ref_audio_path,
            ref_text=reference_text,
            gen_text=text,
            speed=speed,
        )
        sf.write(out_file, wav, sample_rate)
        seconds = len(wav) / sample_rate
        return out_file, f"OK - {seconds:.1f}s de audio generado (ref custom)"
    except Exception as e:
        logger.error(f"TTS failed:\n{traceback.format_exc()}")
        return None, f"Error: {e}"
# ── UI ──
# Layout: one shared project-name field, a download step, then two tabs —
# generation with the saved reference vs. a user-uploaded reference clip.
with gr.Blocks(title="Talking Head - TTS Test", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# Talking Head - Test TTS `v{APP_VERSION}`\nPrueba tu modelo de voz entrenado con F5-TTS")
    # Shared by the download handler and both generation handlers.
    project_name = gr.Textbox(
        label="Nombre del proyecto",
        placeholder="primer_intento",
        info="Obligatorio. Debe coincidir con el proyecto que entrenaste.",
    )
    gr.Markdown("### 1. Descargar modelo de voz")
    download_btn = gr.Button("Descargar modelo del Hub", variant="secondary")
    download_status = gr.Textbox(label="Estado", interactive=False)
    gr.Markdown("### 2. Generar voz")
    with gr.Tabs():
        # Tab 1: uses the reference.wav bundled with the downloaded model.
        with gr.Tab("Referencia guardada"):
            text_input = gr.Textbox(
                label="Texto a hablar (espanol)",
                placeholder="Hola, soy un avatar digital hiperrealista creado con inteligencia artificial.",
                lines=4,
            )
            speed_slider = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Velocidad")
            gen_btn = gr.Button("Generar Voz", variant="primary")
            audio_output = gr.Audio(label="Audio generado")
            gen_status = gr.Textbox(label="Estado", interactive=False)
        # Tab 2: clones a different voice from a user-uploaded clip.
        with gr.Tab("Referencia custom"):
            gr.Markdown("Sube un audio de referencia diferente para clonar otra voz")
            text_input_custom = gr.Textbox(
                label="Texto a hablar (espanol)",
                placeholder="Hola, esta es una prueba con referencia personalizada.",
                lines=4,
            )
            # type="filepath" so the handler receives a path on disk.
            ref_audio = gr.Audio(label="Audio de referencia (WAV)", type="filepath")
            speed_slider_custom = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Velocidad")
            gen_btn_custom = gr.Button("Generar Voz (Custom)", variant="primary")
            audio_output_custom = gr.Audio(label="Audio generado")
            gen_status_custom = gr.Textbox(label="Estado", interactive=False)
    # Event wiring: buttons → handlers.
    download_btn.click(download_model, inputs=[project_name], outputs=[download_status])
    gen_btn.click(
        generate_speech,
        inputs=[project_name, text_input, speed_slider],
        outputs=[audio_output, gen_status],
    )
    gen_btn_custom.click(
        generate_with_custom_ref,
        inputs=[project_name, text_input_custom, ref_audio, speed_slider_custom],
        outputs=[audio_output_custom, gen_status_custom],
    )
if __name__ == "__main__":
demo.queue().launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)