| """Space: TTS Test β Test your trained voice model |
| |
| Downloads voice model from Hub -> generates speech from text. |
| GPU: T4 medium (F5-TTS inference only) |
| """ |
| import gc |
| import logging |
| import os |
| import shutil |
| import traceback |
| from pathlib import Path |
|
|
| import gradio as gr |
| import soundfile as sf |
| import torch |
|
|
| from hub_utils import download_step |
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger(__name__)


# On a HF Space, /data is the persistent-storage volume. Use it only when it
# exists AND is writable (persistent storage is an opt-in Space feature);
# otherwise fall back to a relative "data" dir next to the app.
IS_HF_SPACE = os.environ.get("SPACE_ID") is not None
_data_path = Path("/data")
if IS_HF_SPACE and _data_path.exists() and os.access(_data_path, os.W_OK):
    BASE_DIR = _data_path
else:
    BASE_DIR = Path("data")

VOICE_MODEL_DIR = BASE_DIR / "voice_model"  # downloaded checkpoint(s) + reference.wav
TEMP_DIR = BASE_DIR / "temp"  # generated wav outputs
HF_CACHE_DIR = BASE_DIR / "hf_cache"  # Hugging Face download cache

for d in [VOICE_MODEL_DIR, TEMP_DIR, HF_CACHE_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Redirect HF library caches to BASE_DIR (persistent on Spaces with /data).
# Must happen before any HF library downloads anything.
os.environ["HF_HOME"] = str(HF_CACHE_DIR)
os.environ["TRANSFORMERS_CACHE"] = str(HF_CACHE_DIR)

# NOTE(review): F5_SPANISH_MODEL_ID is not referenced elsewhere in this file —
# presumably F5TTS() defaults to this base model; confirm against f5_tts docs.
F5_SPANISH_MODEL_ID = "jpgallegoar/F5-Spanish"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
APP_VERSION = "1.1.0"

# Lazily-initialized F5-TTS singleton and a per-audio-path cache of
# reference transcriptions (filled by _get_ref_text).
_f5_model = None
_ref_text_cache = {}
|
|
|
|
| def _clear_cache(): |
| gc.collect() |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
|
|
|
|
def _load_tts():
    """Load F5-TTS into the module-level ``_f5_model`` singleton.

    Prefers the fine-tuned checkpoint in VOICE_MODEL_DIR (``model_last.pt``,
    or any other *.pt / *.safetensors that is not a base-model artifact);
    falls back to the default F5-Spanish base model when none is found.
    No-op if a model is already loaded.
    """
    global _f5_model
    if _f5_model is not None:
        return

    # Imported lazily so the app can start (and show the UI) before the
    # heavy TTS stack is pulled in.
    from f5_tts.api import F5TTS

    finetuned_path = VOICE_MODEL_DIR / "model_last.pt"
    if not finetuned_path.exists():
        checkpoints = list(VOICE_MODEL_DIR.glob("*.pt")) + list(VOICE_MODEL_DIR.glob("*.safetensors"))
        # Skip base-model files that may sit next to the fine-tuned one
        # ("pretrained*" and the 1250000-step base checkpoint).
        checkpoints = [c for c in checkpoints if "pretrained" not in c.name and "1250000" not in c.name]
        finetuned_path = checkpoints[0] if checkpoints else None

    if finetuned_path and finetuned_path.exists():
        # Lazy %-style args: no string formatting unless the record is emitted
        # (also fixes the original f-string-without-placeholders below).
        logger.info("Loading fine-tuned F5-TTS from %s", finetuned_path)
        _f5_model = F5TTS(ckpt_file=str(finetuned_path), device=DEVICE)
    else:
        logger.info("No fine-tuned model found, loading base F5-Spanish")
        _f5_model = F5TTS(device=DEVICE)

    logger.info("F5-TTS loaded")
|
|
|
|
def _unload_tts():
    """Drop the cached F5-TTS instance (if any) and reclaim its memory."""
    global _f5_model
    if _f5_model is None:
        return
    del _f5_model
    _f5_model = None
    _clear_cache()
|
|
|
|
def _get_reference_audio():
    """Return the path of the stored reference WAV.

    Raises:
        FileNotFoundError: if the model (and its reference.wav) has not
            been downloaded yet.
    """
    ref_path = VOICE_MODEL_DIR / "reference.wav"
    if not ref_path.exists():
        raise FileNotFoundError("No hay reference.wav. Descarga el modelo primero.")
    return str(ref_path)
|
|
|
|
def _get_ref_text(audio_path):
    """Return the Spanish transcription of *audio_path*, memoized per path.

    The language is pinned to Spanish so Whisper cannot auto-detect the
    wrong language on short reference clips.
    """
    try:
        return _ref_text_cache[audio_path]
    except KeyError:
        pass
    _load_tts()
    logger.info(f"Transcribing reference audio as Spanish: {audio_path}")
    transcription = _f5_model.transcribe(audio_path, language="spanish")
    logger.info(f"Reference transcription: {transcription}")
    _ref_text_cache[audio_path] = transcription
    return transcription
|
|
|
|
| |
|
|
def download_model(project_name, progress=gr.Progress()):
    """Download the trained voice model for *project_name* from the Hub.

    Wipes any previously downloaded model, pulls the ``step3_voice``
    artifacts, flattens them into VOICE_MODEL_DIR and eagerly loads the
    TTS model. Returns a human-readable (Spanish, UI-facing) status string.
    """
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"
    name = project_name.strip()

    try:
        # Release the current model so its files can be replaced on disk.
        _unload_tts()

        if VOICE_MODEL_DIR.exists():
            shutil.rmtree(VOICE_MODEL_DIR)
        VOICE_MODEL_DIR.mkdir(parents=True)

        progress(0.1, desc="Descargando modelo de voz...")
        download_step(name, "step3_voice", str(BASE_DIR))

        # download_step drops files under <BASE_DIR>/<name>/step3_voice;
        # move them into VOICE_MODEL_DIR and remove the scratch directory.
        src = BASE_DIR / name / "step3_voice"
        if src.exists():
            for f in src.iterdir():
                shutil.move(str(f), str(VOICE_MODEL_DIR / f.name))
        shutil.rmtree(BASE_DIR / name, ignore_errors=True)

        models = list(VOICE_MODEL_DIR.glob("*.pt")) + list(VOICE_MODEL_DIR.glob("*.safetensors"))
        has_ref = (VOICE_MODEL_DIR / "reference.wav").exists()
        files_str = ", ".join(f.name for f in models)

        progress(0.9, desc="Cargando modelo...")
        _load_tts()

        return f"OK - Modelo descargado ({files_str}), referencia: {'si' if has_ref else 'no'}"
    except Exception as e:
        # Log the full traceback (the UI only shows the short message) —
        # consistent with generate_speech / generate_with_custom_ref.
        logger.error(f"Download failed:\n{traceback.format_exc()}")
        return f"Error: {e}"
|
|
|
|
def generate_speech(project_name, text, speed, progress=gr.Progress()):
    """Synthesize *text* with the saved reference voice.

    Returns a ``(audio_path_or_None, status_message)`` tuple for the UI.
    """
    if not project_name or not project_name.strip():
        return None, "Error: Debes introducir un nombre de proyecto"
    if not text or not text.strip():
        return None, "Error: Introduce texto para generar"

    try:
        progress(0.1, desc="Cargando modelo...")
        _load_tts()

        ref_audio = _get_reference_audio()
        ref_text = _get_ref_text(ref_audio)
        out_file = TEMP_DIR / "tts_output.wav"

        progress(0.3, desc="Generando voz...")
        logger.info(f"Generating: '{text[:80]}...'")

        wave, sample_rate, _ = _f5_model.infer(
            ref_file=ref_audio,
            ref_text=ref_text,
            gen_text=text,
            speed=speed,
        )

        sf.write(str(out_file), wave, sample_rate)
        seconds = len(wave) / sample_rate
        logger.info(f"Generated {seconds:.1f}s audio")

        return str(out_file), f"OK - {seconds:.1f}s de audio generado"
    except Exception as e:
        logger.error(f"TTS failed:\n{traceback.format_exc()}")
        return None, f"Error: {e}"
|
|
|
|
def generate_with_custom_ref(project_name, text, ref_audio_path, speed, progress=gr.Progress()):
    """Synthesize *text* cloning the voice of a user-uploaded reference clip.

    Returns a ``(audio_path_or_None, status_message)`` tuple for the UI.
    """
    if not project_name or not project_name.strip():
        return None, "Error: Debes introducir un nombre de proyecto"
    if not text or not text.strip():
        return None, "Error: Introduce texto para generar"
    if ref_audio_path is None:
        return None, "Error: Sube un audio de referencia"

    try:
        progress(0.1, desc="Cargando modelo...")
        _load_tts()

        out_file = TEMP_DIR / "tts_custom_output.wav"

        progress(0.3, desc="Generando voz...")
        logger.info(f"Generating with custom ref: '{text[:80]}...'")

        # Transcription of the uploaded clip is cached per path.
        ref_text = _get_ref_text(ref_audio_path)
        wave, sample_rate, _ = _f5_model.infer(
            ref_file=ref_audio_path,
            ref_text=ref_text,
            gen_text=text,
            speed=speed,
        )

        sf.write(str(out_file), wave, sample_rate)
        seconds = len(wave) / sample_rate

        return str(out_file), f"OK - {seconds:.1f}s de audio generado (ref custom)"
    except Exception as e:
        logger.error(f"TTS failed:\n{traceback.format_exc()}")
        return None, f"Error: {e}"
|
|
|
|
| |
|
|
# ---------------------------------------------------------------------------
# Gradio UI: one shared project-name box, a download step, and two generation
# tabs (stored reference vs. user-uploaded reference).
# ---------------------------------------------------------------------------
with gr.Blocks(title="Talking Head - TTS Test", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# Talking Head - Test TTS `v{APP_VERSION}`\nPrueba tu modelo de voz entrenado con F5-TTS")

    # Must match the project name used during training on the Hub.
    project_name = gr.Textbox(
        label="Nombre del proyecto",
        placeholder="primer_intento",
        info="Obligatorio. Debe coincidir con el proyecto que entrenaste.",
    )

    # Step 1: pull the trained voice model from the Hub.
    gr.Markdown("### 1. Descargar modelo de voz")
    download_btn = gr.Button("Descargar modelo del Hub", variant="secondary")
    download_status = gr.Textbox(label="Estado", interactive=False)

    # Step 2: text-to-speech generation.
    gr.Markdown("### 2. Generar voz")
    with gr.Tabs():
        # Uses the reference.wav bundled with the downloaded model.
        with gr.Tab("Referencia guardada"):
            text_input = gr.Textbox(
                label="Texto a hablar (espanol)",
                placeholder="Hola, soy un avatar digital hiperrealista creado con inteligencia artificial.",
                lines=4,
            )
            speed_slider = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Velocidad")
            gen_btn = gr.Button("Generar Voz", variant="primary")
            audio_output = gr.Audio(label="Audio generado")
            gen_status = gr.Textbox(label="Estado", interactive=False)

        # Lets the user upload a different reference clip to clone.
        with gr.Tab("Referencia custom"):
            gr.Markdown("Sube un audio de referencia diferente para clonar otra voz")
            text_input_custom = gr.Textbox(
                label="Texto a hablar (espanol)",
                placeholder="Hola, esta es una prueba con referencia personalizada.",
                lines=4,
            )
            ref_audio = gr.Audio(label="Audio de referencia (WAV)", type="filepath")
            speed_slider_custom = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Velocidad")
            gen_btn_custom = gr.Button("Generar Voz (Custom)", variant="primary")
            audio_output_custom = gr.Audio(label="Audio generado")
            gen_status_custom = gr.Textbox(label="Estado", interactive=False)

    # Wire buttons to their handlers.
    download_btn.click(download_model, inputs=[project_name], outputs=[download_status])
    gen_btn.click(
        generate_speech,
        inputs=[project_name, text_input, speed_slider],
        outputs=[audio_output, gen_status],
    )
    gen_btn_custom.click(
        generate_with_custom_ref,
        inputs=[project_name, text_input_custom, ref_audio, speed_slider_custom],
        outputs=[audio_output_custom, gen_status_custom],
    )
|
|
| if __name__ == "__main__": |
| demo.queue().launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False) |
|
|